diff --git a/.gitignore b/.gitignore index 254a8aa..4b76236 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ logs/ # data -- include the README but ignore everything else data/* !data/README.md + +evaluation_checkpoint/* diff --git a/README.md b/README.md index fbfadf5..e149fc3 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,17 @@ To generate new data, a minimal example is python main.py --mode generate --num_samples [num_samples] --load [path/to/parameters.pt] ``` +### Evaluate +To evaluate the generated samples, a minimal example is (here assuming WBM is the reference set) +``` +python main.py --mode evaluate --train_datafile data/wbm/raw/wbm_train.csv --generated_datafile [path/to/generated_samples.pt] +``` +This will default to using 10k samples (from both the generated and reference/train set) when computing all metrics. If you prefer another number, this can be set with `--num_samples_in_evaluation`. + +**Note regarding pre-trained weights:** The pre-trained weights of the Wrenformer used for computing FWD will be automatically downloaded. If you prefer to download these weights yourself, create a new directory called `evaluation_checkpoint` and place the weights called ``checkpoint.pth`` there. The code will run a checksum to verify that they are the same weights. + +**Note regarding compatibility with other models:** The evaluation code also supports reading data generated by other models that we compared to in the paper (CDVAE, DiffCSP++, SymmCD). The data will be converted to protostructures. 
+ ### Parse generated data To convert generated data to protostructures and prototypes, run ``` diff --git a/wyckoff_generation/common/args_and_config.py b/wyckoff_generation/common/args_and_config.py index 6210ecc..c5932c7 100644 --- a/wyckoff_generation/common/args_and_config.py +++ b/wyckoff_generation/common/args_and_config.py @@ -34,6 +34,8 @@ "mlp_activation": "SiLU", # Generation "num_samples": 10000, + # Evaluation + "num_samples_in_evaluation": 10000, } @@ -242,6 +244,26 @@ def get_parser(): help="Number of samples to generate", ) + # Evaluation + parser.add_argument( + "--generated_datafile", + type=str, + help="Path to generated data", + ) + + parser.add_argument( + "--train_datafile", + type=str, + help="Path to csv file with training data (for evaluation)", + ) + + parser.add_argument( + "--num_samples_in_evaluation", + type=int, + default=default_args_dict["num_samples_in_evaluation"], + help="Number of samples to compute evaluation statistics on", + ) + # Post processing of generated samples parser.add_argument( "--post_process_all", diff --git a/wyckoff_generation/common/utils.py b/wyckoff_generation/common/utils.py index de31a8a..c7d51fa 100644 --- a/wyckoff_generation/common/utils.py +++ b/wyckoff_generation/common/utils.py @@ -19,6 +19,7 @@ """ +import hashlib import importlib import os import re @@ -34,6 +35,14 @@ def get_pretrained_checkpoint(load_path, best=True): return checkpoint +def compare_hash(data_path, correct_hash): + sha256_hash = hashlib.sha256() + with open(data_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() == correct_hash + + def increment_filename(file_path): # Split the file path into directory, filename, and extension directory, filename = os.path.split(file_path) diff --git a/wyckoff_generation/datasets/data_utils.py b/wyckoff_generation/datasets/data_utils.py index 018e14c..937e823 100644 --- a/wyckoff_generation/datasets/data_utils.py 
+++ b/wyckoff_generation/datasets/data_utils.py @@ -3,6 +3,7 @@ import aviary.wren.data as aviary_wren_data import aviary.wren.utils as aviary_wren_utils +import pandas as pd import torch from aviary.wren.utils import ( canonicalize_element_wyckoffs, @@ -304,3 +305,82 @@ def enrich_dataset( return dataset, all_aflow_labels, all_prototype_labels return dataset + + +def compare_generated_with_training_dataset_fast_label_list( + generated_dataset: list[list[int, str, str]], + training_dataset: list[str], + set_name: str, + return_duplicate_names: bool = False, +) -> pd.DataFrame: + + """ + + generated_dataset: list of lists, where second level contains [original_index, protostructure_label, prototype_label] + training_dataset: list of protostructure labels from reference set (e.g., train/val/test) + set_name: name of dataset, include split for explicitness + return_duplicate_names: return the duplicate names if needed for other function + + returns: pd.DataFrame containing protostructures, prototypes, novel, novel_prototype, duplicates_{set_name}_set", duplicates_{set_name}_set_prototype + """ + + print( + f"Identifying novelty of generated data compared with {set_name} dataset, saving to attribute 'novel' and 'duplicates_{set_name}_set'...", + file=sys.stdout, + ) + + # Create a dataframe from the aflow label lists and canonical prototypes in the generated dataset + gen_df = pd.DataFrame( + generated_dataset, columns=["original_index", "protostructures", "prototypes"] + ) + + # --- Protostructures and prototypes matching + + train_df = pd.DataFrame(training_dataset, columns=["protostructures"]) + train_df["prototypes"] = train_df.protostructures.apply( + aviary_wren_utils.get_prototype_from_protostructure + ) + + # Check if novel (note that novel is opposite to is-in) + gen_df["novel"] = ~gen_df.protostructures.isin(train_df.protostructures) + gen_df["novel_prototype"] = ~gen_df.prototypes.isin(train_df.prototypes) + + # Collect duplicates + # Protostructures + 
duplicate_attribute_name = f"duplicates_{set_name}_set" + train_grouped_protostructures = ( + train_df.groupby("protostructures")["protostructures"] + .apply(list) + .reset_index(name=duplicate_attribute_name) + ) + gen_df = gen_df.merge( + train_grouped_protostructures, + how="left", + left_on="protostructures", + right_on="protostructures", + ) + gen_df[duplicate_attribute_name] = gen_df[duplicate_attribute_name].apply( + lambda x: x if isinstance(x, list) else [] + ) + + # Prototypes + duplicate_prototype_attribute_name = f"duplicates_{set_name}_set_prototype" + train_grouped_prototypes = ( + train_df.groupby("prototypes")["prototypes"] + .apply(list) + .reset_index(name=duplicate_prototype_attribute_name) + ) + gen_df = gen_df.merge( + train_grouped_prototypes, + how="left", + left_on="prototypes", + right_on="prototypes", + ) + gen_df[duplicate_prototype_attribute_name] = gen_df[ + duplicate_prototype_attribute_name + ].apply(lambda x: x if isinstance(x, list) else []) + + if return_duplicate_names: + return gen_df, duplicate_attribute_name, duplicate_prototype_attribute_name + + return gen_df diff --git a/wyckoff_generation/datasets/dataset.py b/wyckoff_generation/datasets/dataset.py index 05063dc..0767180 100644 --- a/wyckoff_generation/datasets/dataset.py +++ b/wyckoff_generation/datasets/dataset.py @@ -11,6 +11,7 @@ from torch_geometric.loader import DataLoader from wyckoff_generation.common.registry import registry +from wyckoff_generation.common.utils import compare_hash from wyckoff_generation.datasets.lookup_tables import ( element_number, spg_wyckoff, @@ -199,14 +200,6 @@ def preprocess(raw_file_path) -> pd.DataFrame: return parsed_aflow_labels -def compare_hash(data_path, correct_hash): - sha256_hash = hashlib.sha256() - with open(data_path, "rb") as f: - for byte_block in iter(lambda: f.read(4096), b""): - sha256_hash.update(byte_block) - return sha256_hash.hexdigest() == correct_hash - - def decompress_bz2_file( compressed_file_path, 
remove_original=False, decompressed_file_path=None ): diff --git a/wyckoff_generation/evaluation/compute_fwd.py b/wyckoff_generation/evaluation/compute_fwd.py new file mode 100644 index 0000000..fba8535 --- /dev/null +++ b/wyckoff_generation/evaluation/compute_fwd.py @@ -0,0 +1,146 @@ +import torch +from aviary.wrenformer.data import df_to_in_mem_dataloader +from aviary.wrenformer.model import Wrenformer +from tqdm import tqdm + +from wyckoff_generation.evaluation import read_file_utils +from wyckoff_generation.evaluation.frechet_distance import ( + frechet_distance_from_embeddings, +) +from wyckoff_generation.evaluation.novelty_helper import ( + get_enriched_df, + get_statistics_from_df, +) + + +def get_embeddings(model, dataset): + store = [] + + def hook_fn(module, input, output): + store.append(output) + return output + + target_layer = list(model.children())[-2] + hook_handle = target_layer.register_forward_hook(hook_fn) + + # ids_list = [] + with torch.no_grad(): + for d in dataset: + (padded_features, mask, equivalence_counts), targets, ids = d + # ids_list.extend(ids.tolist()) + output = model(padded_features, mask, equivalence_counts) + hook_handle.remove() + + return torch.cat(store, dim=0) + + +def main(args): + train_data_df_full = read_file_utils.get_dataset_df( + args["train_datafile"], + ) + print("Parsing generated materials") + gen_data_df_full = read_file_utils.get_dataset_df(args["generated_datafile"]) + + assert len(train_data_df_full.index) >= args["num_samples_in_evaluation"], len( + train_data_df_full.index + ) + assert len(gen_data_df_full.index) >= args["num_samples_in_evaluation"], len( + gen_data_df_full.index + ) + + gen_data_df_enriched = get_enriched_df(gen_data_df_full, train_data_df_full) + gen_data_df_subsampled = gen_data_df_enriched.sample( + n=args["num_samples_in_evaluation"], + replace=False, + ignore_index=True, + random_state=42, + ) + train_data_df_subsampled = train_data_df_full.sample( + 
n=args["num_samples_in_evaluation"], + replace=False, + ignore_index=True, + random_state=42, + ) + + gen_data_fwd = df_to_in_mem_dataloader( + gen_data_df_subsampled, + input_col="protostructures", + batch_size=args["batch_size"], + shuffle=False, + ) + train_data_fwd = df_to_in_mem_dataloader( + train_data_df_subsampled, + input_col="wyckoff", + batch_size=args["batch_size"], + shuffle=False, + ) + + state_dict = torch.load("evaluation_checkpoint/checkpoint.pth", map_location="cpu") + model = Wrenformer(**state_dict["model_params"]).to(args["device"]) + model.load_state_dict(state_dict["model_state"]) + model.train(False) + assert not model.training + print("Computing training embeddings") + + # to improve stability, use double precision + print("Computing Wrenformer embeddings of generated and training materials") + train_embeddings = get_embeddings(model, train_data_fwd).double() + gen_embeddings = get_embeddings(model, gen_data_fwd).double() + fwd = float(frechet_distance_from_embeddings(train_embeddings, gen_embeddings)) + + stats_subsampled = get_statistics_from_df(gen_data_df_subsampled) + stats_subsampled["fwd"] = fwd + print("\n\n----Stats for generated materials----") + for key, value in stats_subsampled.items(): + if isinstance(value, float): + print(f"{key}: {value}") + + gen_data_novel_only = gen_data_df_enriched.loc[gen_data_df_enriched["novel"]] + assert len(gen_data_novel_only.index) >= args["num_samples_in_evaluation"] + gen_data_novel_subsampled = gen_data_novel_only.sample( + n=args["num_samples_in_evaluation"], + replace=False, + ignore_index=True, + random_state=42, + ) + gen_data_novel_fwd = df_to_in_mem_dataloader( + gen_data_novel_subsampled, + input_col="protostructures", + batch_size=args["batch_size"], + shuffle=False, + ) + + gen_novel_embeddings = get_embeddings(model, gen_data_novel_fwd).double() + fwd_novel = float( + frechet_distance_from_embeddings(train_embeddings, gen_novel_embeddings) + ) + stats_novel = 
get_statistics_from_df(gen_data_novel_subsampled) + stats_novel["fwd"] = fwd_novel + print("\n\n----Stats for generated novel materials----") + for key, value in stats_novel.items(): + if isinstance(value, float): + print(f"{key}: {value}") + + result_string = " & ".join( + [ + f"{stats_subsampled['fwd']:.2f}", + f"{stats_subsampled['novelty']*100:.1f}", + f"{stats_subsampled['uniqueness']*100:.1f}", + f"{stats_novel['fwd']:.2f}", + f"{stats_novel['uniqueness']*100:.1f}", + ] + ) + print("\n\n----Results string for LaTeX table----\n", result_string) + + full_results_dict = { + "stats_subsampled": { + key: value + for key, value in stats_subsampled.items() + if isinstance(value, float) + }, + "stats_novel": { + key: value for key, value in stats_novel.items() if isinstance(value, float) + }, + "result_string": result_string, + } + return full_results_dict, gen_data_df_subsampled, gen_data_novel_subsampled diff --git a/wyckoff_generation/evaluation/compute_prototype_stats.py b/wyckoff_generation/evaluation/compute_prototype_stats.py new file mode 100644 index 0000000..7f87b9b --- /dev/null +++ b/wyckoff_generation/evaluation/compute_prototype_stats.py @@ -0,0 +1,19 @@ +import os +import sys + +import pandas as pd + + +def main(folder, num_samples): + file = os.path.join( + folder, f"gen_data_novel_subsampled_num_samples={num_samples}.csv" + ) + df = pd.read_csv(file) + assert len(df.index) == num_samples + assert df["novel"].all() + + novel_prototype_df = df[df["novel_prototype"]] + # print(novel_prototype_df.head()) + + unique_prototypes = novel_prototype_df["prototypes"].unique().tolist() + print(len(unique_prototypes)) diff --git a/wyckoff_generation/evaluation/eval_utils.py b/wyckoff_generation/evaluation/eval_utils.py new file mode 100644 index 0000000..ccab7dd --- /dev/null +++ b/wyckoff_generation/evaluation/eval_utils.py @@ -0,0 +1,140 @@ +from collections import Counter + +import numpy as np + +chemical_symbols = [ + # 0 + "X", + # 1 + "H", + "He", + # 2 + "Li", + 
"Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + # 3 + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + # 4 + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + # 5 + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + # 6 + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + # 7 + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Nh", + "Fl", + "Mc", + "Lv", + "Ts", + "Og", +] + + +def get_composition(atom_types): + composition = sorted(Counter(atom_types.tolist()).items()) + elem_idxs, counts = zip(*composition) + counts = np.array(counts) / np.gcd.reduce(counts) + return elem_idxs, tuple(counts.astype("int").tolist()) diff --git a/wyckoff_generation/evaluation/frechet_distance.py b/wyckoff_generation/evaluation/frechet_distance.py new file mode 100644 index 0000000..316763b --- /dev/null +++ b/wyckoff_generation/evaluation/frechet_distance.py @@ -0,0 +1,27 @@ +import torch + + +def frechet_distance(mu_1, sigma_1, mu_2, sigma_2): + assert len(mu_1.shape) == len(mu_2.shape) == 1 + assert len(sigma_1.shape) == len(sigma_2.shape) == 2 + assert mu_1.shape == mu_2.shape + assert sigma_1.shape == sigma_2.shape + assert sigma_1.shape[0] == sigma_1.shape[1] == mu_1.shape[0] + mean_diff_norm = (mu_1 - mu_2).square().sum() + cov_traces = sigma_1.trace() + sigma_2.trace() + cov_product_trace = torch.linalg.eigvals(sigma_1 @ sigma_2).sqrt().real.sum() + return mean_diff_norm + cov_traces - 2 * cov_product_trace + + +def 
frechet_distance_from_embeddings(feats_1, feats_2): + assert len(feats_1.shape) == len(feats_2.shape) == 2 + assert feats_1.shape[-1] == feats_2.shape[-1] + # for improved stability, use double precision + feats_1 = feats_1.double() + feats_2 = feats_2.double() + mu_1 = torch.mean(feats_1, dim=0) + mu_2 = torch.mean(feats_2, dim=0) + cov_1 = torch.cov(feats_1.T) + cov_2 = torch.cov(feats_2.T) + + return frechet_distance(mu_1, cov_1, mu_2, cov_2) diff --git a/wyckoff_generation/evaluation/novelty_helper.py b/wyckoff_generation/evaluation/novelty_helper.py new file mode 100644 index 0000000..6fd8816 --- /dev/null +++ b/wyckoff_generation/evaluation/novelty_helper.py @@ -0,0 +1,49 @@ +import numpy as np +import pandas as pd +from aviary.wren.utils import get_prototype_from_protostructure + +from wyckoff_generation.datasets.data_utils import ( + compare_generated_with_training_dataset_fast_label_list, +) + + +def get_statistics_from_df(dataset_comparison): + result_dict = {} + + # protostructures + protostructures = list(dataset_comparison["protostructures"]) + unique_protostructures = set(protostructures) + result_dict["uniqueness"] = len(unique_protostructures) / len(protostructures) + novel_array = np.array(dataset_comparison["novel"]) + result_dict["novelty"] = novel_array.mean() + novel_protostructures = list(dataset_comparison["protostructures"][novel_array]) + novel_unique_protostructures = list(set(novel_protostructures)) + result_dict["novel_uniqueness"] = len(novel_unique_protostructures) / len( + novel_protostructures + ) + result_dict["novel_and_unique_protostructures"] = novel_unique_protostructures + + # prototypes + prototypes = list(dataset_comparison["prototypes"]) + unique_prototypes = set(prototypes) + result_dict["uniqueness_prototypes"] = len(unique_prototypes) / len(prototypes) + novel_prototype_array = np.array(dataset_comparison["novel_prototype"]) + result_dict["novel_prototypes"] = novel_prototype_array.mean() + novel_prototypes = 
list(dataset_comparison["prototypes"][novel_prototype_array]) + novel_unique_prototypes = set(novel_prototypes) + result_dict["novel_prototypes_uniqueness"] = len(novel_unique_prototypes) / len( + novel_prototypes + ) + result_dict["novel_and_unique_prototypes"] = novel_unique_prototypes + return result_dict + + +def get_enriched_df(gen_data_df, train_df): + gen_data_list = [ + [i, row["wyckoff"], get_prototype_from_protostructure(row["wyckoff"])] + for i, row in gen_data_df.iterrows() + ] + + return compare_generated_with_training_dataset_fast_label_list( + gen_data_list, list(train_df["wyckoff"]), "training" + ) diff --git a/wyckoff_generation/evaluation/read_file_utils.py b/wyckoff_generation/evaluation/read_file_utils.py new file mode 100644 index 0000000..9195450 --- /dev/null +++ b/wyckoff_generation/evaluation/read_file_utils.py @@ -0,0 +1,131 @@ +import warnings + +import pandas as pd +import torch +from aviary.wren.utils import get_protostructure_label_from_spglib +from pymatgen.core import Lattice, Structure +from tqdm import tqdm + +from wyckoff_generation.datasets.data_utils import build_protostructure + +warnings.filterwarnings("ignore") + + +def cifstrings_to_protostructs(strings_list): + aflow_labels = [] + for cif in strings_list: + aflow_labels.append( + get_protostructure_label_from_spglib(Structure.from_str(cif, "cif")) + ) + df = pd.DataFrame(aflow_labels, columns=["wyckoff"]) + return df + + +def data_from_csv(file_path): + df = pd.read_csv(file_path) + if "wyckoff" in df.columns: + pass + elif "wyckoff_spglib" in df.columns: + df["wyckoff"] = df["wyckoff_spglib"] + elif "cif" in df.columns: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + df["wyckoff"] = cifstrings_to_protostructs(df["cif"]) + else: + raise ValueError( + "No column that is readable. 
Check input file, or implement function that creates protostructure" + ) + + return df[["wyckoff"]] + + +def _safe_build(d): + try: + return build_protostructure(d) + except ValueError: + print("Error with material, skipping") + return None + + +import time +from concurrent.futures import ProcessPoolExecutor + + +def data_from_wyckoffdiff_list(data): + st = time.time() + with ProcessPoolExecutor() as executor: + results = list(executor.map(_safe_build, data)) + + results = [r for r in results if r is not None] # drop failed entries + et = time.time() + print(f"Total time for reading WyckoffDiff data: {et - st:.2f} s") + return pd.DataFrame(results, columns=["wyckoff"]) + + +def data_from_cdvae_dict(data): + lengths = data["lengths"] + if lengths.dim() == 3: # CDVAE + assert lengths.shape[0] == 1 + lengths = lengths.squeeze(0) + angles = data["angles"].squeeze(0) + num_atoms_per_material = data["num_atoms"].squeeze(0) + frac_coords = data["frac_coords"].squeeze(0) + frac_coords = torch.split(frac_coords, num_atoms_per_material.tolist()) + atom_types = data["atom_types"].squeeze(0) + if atom_types.dim() == 2: # SymmCD + atom_types = torch.where(atom_types == 1)[1] + 1 + atom_types = torch.split(atom_types, num_atoms_per_material.tolist()) + aflow_labels = [] + for l, ang, frac, atms in zip(lengths, angles, frac_coords, atom_types): + # Check for NaN, inf, or out-of-range values + if ( + torch.isnan(l).any() + or torch.isnan(ang).any() + or torch.isnan(frac).any() + or torch.isinf(l).any() + or torch.isinf(ang).any() + or torch.isinf(frac).any() + or (l > 1e4).any() + ): + print("Skipping due to NaN or inf values") + continue + structure = Structure( + Lattice.from_parameters(*l, *ang), atms, frac, coords_are_cartesian=False + ) + protostructure = get_protostructure_label_from_spglib(structure) + # Filter out NULL entries from spglib outputs + if protostructure: + aflow_labels.append(protostructure) + df = pd.DataFrame(aflow_labels, columns=["wyckoff"]) + 
print(f"Effective data size: {df.shape[0]}") + return df + + +def data_from_pt(file_path): + data = torch.load(file_path, map_location="cpu") + if isinstance(data, dict): + return data_from_cdvae_dict(data) + elif isinstance(data, list): + return data_from_wyckoffdiff_list(data) + else: + raise ValueError(f"Unknown datatype of {file_path}: {type(data)}") + + +def get_dataset_df(file_path, num_samples=None): + if file_path.endswith(".csv"): + df = data_from_csv(file_path) + elif file_path.endswith(".pt"): + df = data_from_pt(file_path) + else: + raise ValueError("file extension not known") + + if num_samples is not None: + if len(df.index) < num_samples: + raise RuntimeError( + f"The file {file_path} contains less than the requested {num_samples} materials" + ) + if len(df.index) > num_samples: + df = df.sample( + n=num_samples, replace=False, ignore_index=True, random_state=42 + ) + return df diff --git a/wyckoff_generation/runners/evaluation_runner.py b/wyckoff_generation/runners/evaluation_runner.py new file mode 100644 index 0000000..d69c9a0 --- /dev/null +++ b/wyckoff_generation/runners/evaluation_runner.py @@ -0,0 +1,98 @@ +import hashlib +import json +import os + +import pandas as pd +import wandb + +from wyckoff_generation.common.registry import registry +from wyckoff_generation.common.utils import compare_hash +from wyckoff_generation.evaluation import compute_fwd +from wyckoff_generation.runners.base_runner import BaseRunner + + +@registry.register_runner("evaluate") +class EvaluationRunner(BaseRunner): + def __init__(self, config): + super().__init__(config) + self.config = config + self.get_eval_checkpoint() + + def init_model(self, config): + pass + + def init_dataloaders(self, config): + pass + + def init_optimizer(self, config): + pass + + def load_checkpoint(self, config): + pass + + def run(self): + folder = self.compute_fwd(self.config) + self.compute_prototype_stats(folder, self.config["num_samples_in_evaluation"]) + + def get_eval_checkpoint(self): + 
folder = "evaluation_checkpoint" + checkpoint_file_path = os.path.join(folder, "checkpoint.pth") + if not os.path.isfile(checkpoint_file_path): + print("Wrenformer checkpoint not found, will download") + os.makedirs(folder, exist_ok=True) + run = wandb.Api().run("janosh/matbench-discovery/2kozbp4q") + run.file("checkpoint.pth").download(root=folder) + print("Downloaded") + else: + print("Wrenformer checkpoint already downloaded") + assert compare_hash( + checkpoint_file_path, + "20cecd1560e5fc71851cf8675216ed7e30ccb53f18e86de6cfd8ffd090d35f36", + ), f"Hash mismatch for Wrenformer checkpoint at {checkpoint_file_path}" + + def compute_fwd(self, args): + folder, file = os.path.split(args["generated_datafile"]) + ( + results_dict, + gen_data_subsampled_df, + gen_data_novel_subsampled_df, + ) = compute_fwd.main(args) + with open( + os.path.join( + folder, + f"statistics_num_samples={args['num_samples_in_evaluation']}.json", + ), + "w", + ) as f: + json.dump(results_dict, f, indent=4) + columns_to_save = ["protostructures", "prototypes", "novel", "novel_prototype"] + gen_data_subsampled_df.to_csv( + os.path.join( + folder, + f"gen_data_subsampled_num_samples={args['num_samples_in_evaluation']}.csv", + ), + columns=columns_to_save, + index=False, + ) + gen_data_novel_subsampled_df.to_csv( + os.path.join( + folder, + f"gen_data_novel_subsampled_num_samples={args['num_samples_in_evaluation']}.csv", + ), + columns=columns_to_save, + index=False, + ) + return folder + + def compute_prototype_stats(self, folder, num_samples): + file = os.path.join( + folder, f"gen_data_novel_subsampled_num_samples={num_samples}.csv" + ) + df = pd.read_csv(file) + assert len(df.index) == num_samples + assert df["novel"].all() + + novel_prototype_df = df[df["novel_prototype"]] + + unique_prototypes = novel_prototype_df["prototypes"].unique().tolist() + print("\n\nNumber of novel and unique prototypes: ", len(unique_prototypes))