diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index aeda4729..3fe3ff9e 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -4520,3 +4520,4 @@ b [224RaH2] [226RaH2] [228RaH2] +[*-:0] diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index e295a3ed..03054278 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -907,7 +907,9 @@ def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: print(f"Missing processed data file (`{processed_name}` file)") os.makedirs(self.processed_dir_main, exist_ok=True) data_path = self._download_required_data() - g = self._extract_class_hierarchy(data_path) + from chebi_utils import build_chebi_graph + + g = build_chebi_graph(data_path) data_df = self._graph_to_raw_dataset(g) self.save_processed(data_df, processed_name) @@ -921,26 +923,11 @@ def _download_required_data(self) -> str: """ pass - @abstractmethod - def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": - """ - Extracts the class hierarchy from the data. - Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from - the term documents. - - Args: - data_path (str): Path to the data. - - Returns: - nx.DiGraph: The class hierarchy graph. - """ - pass - @abstractmethod def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. - Uses the graph created by `_extract_class_hierarchy` method to extract the + Uses the graph created by chebi_utils to extract the raw data in Dataframe format with additional columns corresponding to each multi-label class. Args: @@ -951,21 +938,6 @@ def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame: """ pass - @abstractmethod - def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: - """ - Selects classes from the dataset based on a specified criteria. - - Args: - g (nx.Graph): The graph representing the dataset. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - List: A sorted list of node IDs that meet the specified criteria. - """ - pass - def save_processed(self, data: pd.DataFrame, filename: str) -> None: """ Save the processed dataset to a pickle file. @@ -1123,120 +1095,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ pass - def get_test_split( - self, df: pd.DataFrame, seed: Optional[int] = None - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """ - Split the input DataFrame into training and testing sets based on multilabel stratified sampling. - - This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels - in the training and testing sets is approximately the same. The split is based on the "labels" column - in the DataFrame. - - Args: - df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column - named "labels" with the multilabel data. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. - - Raises: - ValueError: If the DataFrame does not contain a column named "labels". - """ - from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit - from sklearn.model_selection import StratifiedShuffleSplit - - print("Get test data split") - - labels_list = df["labels"].tolist() - - if len(labels_list[0]) > 1: - splitter = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=self.test_split, random_state=seed - ) - else: - splitter = StratifiedShuffleSplit( - n_splits=1, test_size=self.test_split, random_state=seed - ) - - train_indices, test_indices = next(splitter.split(labels_list, labels_list)) - - df_train = df.iloc[train_indices] - df_test = df.iloc[test_indices] - return df_train, df_test - - def get_train_val_splits_given_test( - self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None - ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: - """ - Split the dataset into train and validation sets, given a test set. - Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap - - Args: - df (pd.DataFrame): The original dataset. - test_df (pd.DataFrame): The test dataset. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and - validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train - and validation DataFrames. The keys are the names of the train and validation sets, and the values - are the corresponding DataFrames. - """ - from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, - ) - from sklearn.model_selection import StratifiedShuffleSplit - - print("Split dataset into train / val with given test set") - - test_ids = test_df["ident"].tolist() - df_trainval = df[~df["ident"].isin(test_ids)] - labels_list_trainval = df_trainval["labels"].tolist() - - if self.use_inner_cross_validation: - folds = {} - kfold = MultilabelStratifiedKFold( - n_splits=self.inner_k_folds, random_state=seed - ) - for fold, (train_ids, val_ids) in enumerate( - kfold.split( - labels_list_trainval, - labels_list_trainval, - ) - ): - df_validation = df_trainval.iloc[val_ids] - df_train = df_trainval.iloc[train_ids] - folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train - folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( - df_validation - ) - - return folds - - if len(labels_list_trainval[0]) > 1: - splitter = MultilabelStratifiedShuffleSplit( - n_splits=1, - test_size=self.validation_split / (1 - self.test_split), - random_state=seed, - ) - else: - splitter = StratifiedShuffleSplit( - n_splits=1, - test_size=self.validation_split / (1 - self.test_split), - random_state=seed, - ) - - train_indices, validation_indices = next( - splitter.split(labels_list_trainval, labels_list_trainval) - ) - - df_validation = df_trainval.iloc[validation_indices] - df_train = df_trainval.iloc[train_indices] - return df_train, df_validation - def _retrieve_splits_from_csv(self) -> None: """ Retrieve previously saved data splits from splits.csv file or from provided file path. diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index d8296530..64036d4e 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -6,38 +6,19 @@ import os import random from abc import ABC -from collections import OrderedDict from itertools import cycle, permutations, product -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Generator, Literal, Optional import numpy as np import pandas as pd -import torch from rdkit import Chem -from tqdm import tqdm from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset if TYPE_CHECKING: - import fastobo import networkx as nx -# exclude some entities from the dataset because the violate disjointness axioms -CHEBI_BLACKLIST = [ - 194026, - 144321, - 156504, - 167175, - 167174, - 167178, - 183506, - 74635, - 3311, - 190439, - 92386, -] - class _ChEBIDataExtractor(_DynamicDataset, ABC): """ @@ -60,13 +41,12 @@ class _ChEBIDataExtractor(_DynamicDataset, ABC): # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------ # "id" at row index 0 - # "name" at row index 1 - # "SMILES" at row index 2 - # "mol" at row index 3 - # labels starting from row index 4 + # "mol" at row index 1 + # labels starting from row index 2 _ID_IDX: int = 0 - _DATA_REPRESENTATION_IDX: int = 3 - _LABELS_START_IDX: int = 4 + _DATA_REPRESENTATION_IDX: int = 1 + _LABELS_START_IDX: int = 2 + THRESHOLD: int = None def __init__( self, @@ -115,6 +95,10 @@ def __init__( **_init_kwargs, ) + from rdkit import RDLogger + + RDLogger.DisableLog("rdApp.*") + # ------------------------------ Phase: Prepare data ----------------------------------- def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: """ @@ -175,7 +159,8 @@ def _load_chebi(self, version: Optional[int] = None) -> str: Returns: str: The file path of the loaded ChEBI ontology. """ - import requests + if version is None: + version = self.chebi_version if version is None: version = self.chebi_version @@ -186,12 +171,9 @@ def _load_chebi(self, version: Optional[int] = None) -> str: print( f"Missing raw ChEBI data related for version v{version}. Downloading..." ) - if version < 245: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" - else: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/ontology/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) + from chebi_utils import download_chebi_obo + + download_chebi_obo(version, dest_dir=self.raw_dir, filename=chebi_name) return chebi_path def _load_sdf(self, version: Optional[int] = None) -> str: @@ -204,10 +186,6 @@ def _load_sdf(self, version: Optional[int] = None) -> str: Returns: str: The file path of the loaded ChEBI SDF file. """ - import requests - import gzip - import shutil - if version is None: version = self.chebi_version @@ -215,132 +193,36 @@ def _load_sdf(self, version: Optional[int] = None) -> str: sdf_path = os.path.join(self.raw_dir, sdf_name) if not os.path.isfile(sdf_path): print(f"Missing raw SDF data related to version v{version}. Downloading...") - if version < 245: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/chebi_legacy/archive/rel{version}/ontology/chebi.obo" - else: - url = f"https://ftp.ebi.ac.uk/pub/databases/chebi/archive/rel{version}/SDF/chebi.sdf.gz" - r = requests.get(url, allow_redirects=True, stream=True) - open(sdf_path + ".gz", "wb").write(r.content) - with gzip.open(sdf_path + ".gz", "rb") as f_in: - with open(sdf_path, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - return sdf_path + from chebi_utils import download_chebi_sdf - def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph": - """ - Extracts the class hierarchy from the ChEBI ontology. - Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from - the chebi term documents from `.obo` file. - - Args: - data_path (str): The path to the ChEBI ontology. - - Returns: - nx.DiGraph: The class hierarchy. - """ - import fastobo - import networkx as nx - - with open(data_path, encoding="utf-8") as chebi: - chebi = "\n".join(line for line in chebi if not line.startswith("xref:")) - - elements = [] - for term_doc in fastobo.loads(chebi): - if ( - term_doc - and isinstance(term_doc.id, fastobo.id.PrefixedIdent) - and term_doc.id.prefix == "CHEBI" - ): - term_dict = term_callback(term_doc) - if term_dict and ( - not self.subset - or ( - "subset" in term_dict - and term_dict["subset"] is not None - and term_dict["subset"][0] == self.subset[0] - ) # match 3:STAR to 3_STAR, 3star, 3_star, etc. - ): - elements.append(term_dict) - - g = nx.DiGraph() - for n in elements: - g.add_node(n["id"], **n) - - # Only take the edges which connects the existing nodes, to avoid internal creation of obsolete nodes - # https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142 - g.add_edges_from( - [(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)] - ) - - print("Compute transitive closure") - return nx.transitive_closure_dag(g) + download_chebi_sdf(version, dest_dir=self.raw_dir, filename=sdf_name) + return sdf_path def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: """ Converts the graph to a raw dataset. - Uses the graph created by `_extract_class_hierarchy` method to extract the + Uses the graph to extract the raw data in Dataframe format with additional columns corresponding to each multi-label class. + Uses :func:`chebi_utils.sdf_extractor.extract_molecules` for SDF parsing. + Args: g (nx.DiGraph): The class hierarchy graph. Returns: pd.DataFrame: The raw dataset created from the graph. """ - import networkx as nx - smiles = nx.get_node_attributes(g, "smiles") - names = nx.get_node_attributes(g, "name") + # Extract mol objects from SDF using chebi-utils + from chebi_utils import build_labeled_dataset, extract_molecules - print(f"Processing {g}") + sdf_path = os.path.join(self.raw_dir, self.raw_file_names_dict["sdf"]) + mol_df = extract_molecules(sdf_path) + mol_df = mol_df[mol_df["STAR"] == self.subset[0]] if self.subset else mol_df + data, labels = build_labeled_dataset(g, mol_df, self.THRESHOLD) - molecules, smiles_list = zip( - *( - (n, smiles) - for n, smiles in ((n, smiles.get(n)) for n in smiles.keys()) - if smiles - ) - ) - data = OrderedDict(id=molecules) # `id` column at index 0 - data["name"] = [ - names.get(node) for node in molecules - ] # `name` column at index 1 - data["SMILES"] = smiles_list # `SMILES` (data representation) column at index 2 - - # # `mol` (RDKit Mol object) column at index 3 - from chembl_structure_pipeline.standardizer import ( - parse_molblock, - ) - - with open( - os.path.join(self.raw_dir, self.raw_file_names_dict["sdf"]), "rb" - ) as f: - # split input into blocks separated by "$$$$" - blocks = f.read().decode("utf-8").split("$$$$\n") - id_to_mol = dict() - for molfile in tqdm(blocks, desc="Processing SDF molecules"): - if "" not in molfile: - print(f"Skipping molfile without ChEBI ID: {molfile[:30]}...") - continue - ident = int(molfile.split("")[1].split(">")[0].split("CHEBI:")[1]) - # use same parsing strategy as CHEBI: github.com/chembl/libRDChEBI/blob/main/libRDChEBI/formats.py - mol = parse_molblock(molfile) - if mol is None: - print(f"Failed to parse molfile for CHEBI:{ident}") - continue - mol = sanitize_molecule(mol) - id_to_mol[ident] = mol - data["mol"] = [id_to_mol.get(node) for node in molecules] - - # Labels columns from index 4 onwards - for n in self.select_classes(g): - data[n] = [ - ((n in g.predecessors(node)) or (n == node)) for node in molecules - ] - - data = pd.DataFrame(data) - data = data[~data["mol"].isnull()] - data = data[~data["name"].isin(CHEBI_BLACKLIST)] + with open(os.path.join(self.classes_txt_file_path), "wt") as fout: + fout.writelines(str(label) + "\n" for label in labels) return data @@ -491,85 +373,26 @@ def _load_dict(self, input_file_path: str) -> Generator[dict[str, Any], None, No for feat, labels, ident in zip(features, all_labels, idents): yield dict(features=feat, labels=labels, ident=ident) - # ------------------------------ Phase: Dynamic Splits ----------------------------------- def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ Loads encoded/transformed data and generates training, validation, and test splits. - - This method first loads encoded data from a file named `data.pt`, which is derived from either - `chebi_version` or `chebi_version_train`. It then splits the data into training, validation, and test sets. - - If `chebi_version_train` is provided: - - Loads additional encoded data from `chebi_version_train`. - - Splits this data into training and validation sets, while using the test set from `chebi_version`. - - Prunes the test set from `chebi_version` to include only labels that exist in `chebi_version_train`. - - If `chebi_version_train` is not provided: - - Splits the data from `chebi_version` into training, validation, and test sets without modification. - - Raises: - FileNotFoundError: If the required `data.pt` file(s) do not exist. Ensure that `prepare_data` - and/or `setup` methods have been called to generate the dataset files. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: - - Training set - - Validation set - - Test set """ - try: - filename = self.processed_file_names_dict["data"] - data_chebi_version = self.load_processed_data_from_file(filename) - except FileNotFoundError: - raise FileNotFoundError( - "File data.pt doesn't exists. " - "Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" - ) - df_chebi_version = pd.DataFrame(data_chebi_version) - train_df_chebi_ver, df_test_chebi_ver = self.get_test_split( - df_chebi_version, seed=self.dynamic_data_split_seed - ) - - if self.chebi_version_train is not None: - # Load encoded data derived from "chebi_version_train" - try: - filename_train = ( - self._chebi_version_train_obj.processed_file_names_dict["data"] - ) - data_chebi_train_version = torch.load( - os.path.join( - self._chebi_version_train_obj.processed_dir, filename_train - ), - weights_only=False, - ) - except FileNotFoundError: - raise FileNotFoundError( - f"File data.pt doesn't exists related to chebi_version_train {self.chebi_version_train}." - f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" - ) + filename = self.processed_file_names_dict["data"] + data = self.load_processed_data_from_file(filename) + df_data = pd.DataFrame(data) - df_chebi_train_version = pd.DataFrame(data_chebi_train_version) - # Get train/val split of data based on "chebi_version_train", but - # using test set from "chebi_version" - df_train, df_val = self.get_train_val_splits_given_test( - df_chebi_train_version, - df_test_chebi_ver, - seed=self.dynamic_data_split_seed, - ) - # Modify test set from "chebi_version" to only include the labels that - # exists in "chebi_version_train", all other entries remains same. - df_test = self._setup_pruned_test_set(df_test_chebi_ver) - else: - # Get all splits based on "chebi_version" - df_train, df_val = self.get_train_val_splits_given_test( - train_df_chebi_ver, - df_test_chebi_ver, - seed=self.dynamic_data_split_seed, - ) - df_test = df_test_chebi_ver + from chebi_utils import create_multilabel_splits - return df_train, df_val, df_test + splits = create_multilabel_splits( + df_data, + self._LABELS_START_IDX, + 1 - self.validation_split - self.test_split, + self.validation_split, + self.test_split, + self.dynamic_data_split_seed, + ) + return splits["train"], splits["val"], splits["test"] def _setup_pruned_test_set( self, df_test_chebi_version: pd.DataFrame @@ -583,15 +406,11 @@ def _setup_pruned_test_set( Returns: pd.DataFrame: The pruned test dataset. """ - classes_file_name = "classes.txt" - # Load original and new classes - with open(os.path.join(self.processed_dir_main, classes_file_name), "r") as f: + with open(os.path.join(self.classes_txt_file_path), "r") as f: orig_classes = f.readlines() with open( - os.path.join( - self._chebi_version_train_obj.processed_dir_main, classes_file_name - ), + os.path.join(self._chebi_version_train_obj.classes_txt_file_path), "r", ) as f: new_classes = f.readlines() @@ -666,7 +485,7 @@ def processed_dir(self) -> str: @property def raw_file_names_dict(self) -> dict: - return {"chebi": "chebi.obo", "sdf": "chebi.sdf"} + return {"chebi": "chebi.obo", "sdf": "chebi.sdf.gz"} @property def processed_main_file_names_dict(self) -> dict: @@ -714,7 +533,6 @@ class ChEBIOverX(_ChEBIDataExtractor): """ READER: dr.ChemDataReader = dr.ChemDataReader - THRESHOLD: int = None @property def _name(self) -> str: @@ -726,51 +544,6 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: - """ - Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. - - This method iterates over the nodes in the graph, counting the number of successors for each node. - Nodes with a number of successors greater than or equal to the defined threshold are selected. - - Note: - The input graph must be transitive closure of a directed acyclic graph. - - Args: - g (nx.Graph): The graph representing the dataset. - *args: Additional positional arguments (not used). - **kwargs: Additional keyword arguments (not used). - - Returns: - List: A sorted list of node IDs that meet the successor threshold criteria. - - Side Effects: - Writes the list of selected nodes to a file named "classes.txt" in the specified processed directory. - - Notes: - - The `THRESHOLD` attribute should be defined in the subclass of this class. - - Nodes without a 'smiles' attribute are ignored in the successor count. - """ - import networkx as nx - - smiles = nx.get_node_attributes(g, "smiles") - nodes = list( - sorted( - { - node - for node in g.nodes - if sum( - 1 if smiles[s] is not None else 0 for s in g.successors(node) - ) - >= self.THRESHOLD - } - ) - ) - filename = "classes.txt" - with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: - fout.writelines(str(node) + "\n" for node in nodes) - return nodes - class ChEBIOverXDeepSMILES(ChEBIOverX): """ @@ -856,12 +629,12 @@ class ChEBIOverXPartial(ChEBIOverX): top_class_id (int): The ID of the top class from which to extract subclasses. """ - def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs): + def __init__(self, top_class_id: str, external_data_ratio: float, **kwargs): """ Initializes the ChEBIOverXPartial dataset. Args: - top_class_id (int): The ID of the top class from which to extract subclasses. + top_class_id (str): The ID of the top class from which to extract subclasses. **kwargs: Additional keyword arguments passed to the superclass initializer. external_data_ratio (float): How much external data (i.e., samples where top_class_id is no positive label) to include in the dataset. 0 means no external data, 1 means @@ -872,7 +645,7 @@ def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs): if "external_data_ratio" not in kwargs: kwargs["external_data_ratio"] = external_data_ratio - self.top_class_id: int = top_class_id + self.top_class_id: str = top_class_id self.external_data_ratio: float = external_data_ratio super().__init__(**kwargs) @@ -891,69 +664,63 @@ def processed_dir_main(self) -> str: "processed", ) - def _extract_class_hierarchy(self, chebi_path: str) -> "nx.DiGraph": + def _graph_to_raw_dataset(self, g: "nx.DiGraph") -> pd.DataFrame: """ - Extracts a subset of ChEBI based on subclasses of the top class ID. + Converts the graph to a raw dataset. + Uses the graph to extract the + raw data in Dataframe format with additional columns corresponding to each multi-label class. - This method calls the superclass method to extract the full class hierarchy, - then extracts the subgraph containing only the descendants of the top class ID, including itself. + Uses :func:`chebi_utils.sdf_extractor.extract_molecules` for SDF parsing. Args: - chebi_path (str): The file path to the ChEBI ontology file. + g (nx.DiGraph): The class hierarchy graph. Returns: - nx.DiGraph: The extracted class hierarchy as a directed graph, limited to the - descendants of the top class ID. + pd.DataFrame: The raw dataset created from the graph. """ - g = super()._extract_class_hierarchy(chebi_path) - top_class_successors = list(g.successors(self.top_class_id)) + [ - self.top_class_id - ] - external_nodes = list(set(n for n in g.nodes if n not in top_class_successors)) + + # Extract mol objects from SDF using chebi-utils + from chebi_utils import ( + build_labeled_dataset, + extract_molecules, + get_hierarchy_subgraph, + ) + import networkx as nx + + sdf_path = os.path.join(self.raw_dir, self.raw_file_names_dict["sdf"]) + mol_df = extract_molecules(sdf_path) + mol_df = mol_df[mol_df["STAR"] == self.subset[0]] if self.subset else mol_df + + # take only molecules that are subclasses of the top class ID, and a certain ratio of external nodes (nodes that are not subclasses of the top class ID) + transitive_closure = nx.transitive_closure_dag(get_hierarchy_subgraph(g)) + top_class_successors = list( + transitive_closure.predecessors(self.top_class_id) + ) + [self.top_class_id] + top_class_molecules = mol_df[mol_df["chebi_id"].isin(top_class_successors)] + external_molecules = mol_df[~mol_df["chebi_id"].isin(top_class_successors)] if 0 < self.external_data_ratio < 1: n_external_nodes = int( - len(top_class_successors) + len(top_class_molecules) * self.external_data_ratio / (1 - self.external_data_ratio) ) - print( - f"Extracting {n_external_nodes} external nodes from the ChEBI dataset (ratio: {self.external_data_ratio:.2f})" + external_molecules = external_molecules.sample( + n=min(n_external_nodes, len(external_molecules)), + random_state=self.dynamic_data_split_seed, ) - external_nodes = external_nodes[: int(n_external_nodes)] elif self.external_data_ratio == 0: - external_nodes = [] + external_molecules = mol_df.iloc[0:0] + mol_df = pd.concat([top_class_molecules, external_molecules], ignore_index=True) - g = g.subgraph(top_class_successors + external_nodes) - print( - f"Subgraph contains {len(g.nodes)} nodes, of which {len(top_class_successors)} are subclasses of the top class ID {self.top_class_id}." - ) - return g + data, labels = build_labeled_dataset(g, mol_df, self.THRESHOLD) + # the dataset might contain classes that are not subclasses of the top class ID + labels_top_class = [label for label in labels if label in top_class_successors] + data = data[["chebi_id", "mol"] + labels_top_class] - def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List: - """Only selects classes that meet the threshold AND are subclasses of the top class ID (including itself).""" - import networkx as nx + with open(os.path.join(self.classes_txt_file_path), "wt") as fout: + fout.writelines(str(label) + "\n" for label in labels_top_class) - smiles = nx.get_node_attributes(g, "smiles") - nodes = list( - sorted( - { - node - for node in g.nodes - if sum( - 1 if smiles[s] is not None else 0 for s in g.successors(node) - ) - >= self.THRESHOLD - and ( - self.top_class_id in g.predecessors(node) - or node == self.top_class_id - ) - } - ) - ) - filename = "classes.txt" - with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: - fout.writelines(str(node) + "\n" for node in nodes) - return nodes + return data class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50): @@ -983,103 +750,12 @@ class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100): pass -def chebi_to_int(s: str) -> int: - """ - Converts a ChEBI term string representation to an integer ID. - - Args: - - s (str): A ChEBI term string, e.g., "CHEBI:12345". - - Returns: - - int: The integer ID extracted from the ChEBI term string. - """ - return int(s[s.index(":") + 1 :]) - - -def term_callback(doc: "fastobo.term.TermFrame") -> Union[Dict, bool]: - """ - Extracts information from a ChEBI term document. - This function takes a ChEBI term document as input and extracts relevant information such as the term ID, parents, - parts, name, and SMILES string. It returns a dictionary containing the extracted information. - - Args: - - doc: A ChEBI term document. - - Returns: - A dictionary containing the following keys: - - "id": The ID of the ChEBI term. - - "parents": A list of parent term IDs. - - "has_part": A set of term IDs representing the parts of the ChEBI term. - - "name": The name of the ChEBI term. - - "smiles": The SMILES string associated with the ChEBI term, if available. - """ - import fastobo - - parts = set() - parents = [] - name = None - smiles = None - subset = None - for clause in doc: - if isinstance(clause, fastobo.term.PropertyValueClause): - t = clause.property_value - # chemrof:smiles_string is the new annotation property, chebi/smiles is the old one (see https://chembl.blogspot.com/2025/07/chebi-20-data-products.html) - if ( - str(t.relation) == "chemrof:smiles_string" - or str(t.relation) == "http://purl.obolibrary.org/obo/chebi/smiles" - ): - assert smiles is None - smiles = t.value - # in older chebi versions, smiles strings are synonyms - # e.g. synonym: "[F-].[Na+]" RELATED SMILES [ChEBI] - elif isinstance(clause, fastobo.term.SynonymClause): - if "SMILES" in clause.raw_value(): - assert smiles is None - smiles = clause.raw_value().split('"')[1] - elif isinstance(clause, fastobo.term.RelationshipClause): - if str(clause.typedef) == "has_part": - parts.add(chebi_to_int(str(clause.term))) - elif isinstance(clause, fastobo.term.IsAClause): - parents.append(chebi_to_int(str(clause.term))) - elif isinstance(clause, fastobo.term.NameClause): - name = str(clause.name) - elif isinstance(clause, fastobo.term.SubsetClause): - subset = str(clause.subset) - - if isinstance(clause, fastobo.term.IsObsoleteClause): - if clause.obsolete: - # if the term document contains clause as obsolete as true, skips this document. - return False - - return { - "id": chebi_to_int(str(doc.id)), - "parents": parents, - "has_part": parts, - "name": name, - "smiles": smiles, - "subset": subset, - } - - -def sanitize_molecule(mol: Chem.Mol) -> Chem.Mol: - # mirror ChEBI molecule processing - from chembl_structure_pipeline.standardizer import update_mol_valences - - mol = update_mol_valences(mol) - Chem.SanitizeMol( - mol, - sanitizeOps=Chem.SanitizeFlags.SANITIZE_FINDRADICALS - | Chem.SanitizeFlags.SANITIZE_KEKULIZE - | Chem.SanitizeFlags.SANITIZE_SETAROMATICITY - | Chem.SanitizeFlags.SANITIZE_SETCONJUGATION - | Chem.SanitizeFlags.SANITIZE_SETHYBRIDIZATION - | Chem.SanitizeFlags.SANITIZE_SYMMRINGS, - catchErrors=True, - ) - return mol - - if __name__ == "__main__": - dataset = ChEBIOver50(chebi_version=248, subset="3_STAR") + dataset = ChEBIOver50Partial( + chebi_version=247, + subset="3_STAR", + top_class_id="36700", + external_data_ratio=0.5, + ) dataset.prepare_data() dataset.setup() diff --git a/pyproject.toml b/pyproject.toml index 715f4555..45c72e01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,13 +35,12 @@ dev = [ "requests", "scikit-learn", "scipy", - "fastobo", "selfies", "jsonargparse[signatures]>=4.17", "omegaconf", "deepsmiles", - "iterative-stratification", "torchmetrics", + "chebi-utils>=0.1.1", ] linters = [ diff --git a/tests/unit/dataset_classes/testChEBIOverX.py b/tests/unit/dataset_classes/testChEBIOverX.py deleted file mode 100644 index 270b868c..00000000 --- a/tests/unit/dataset_classes/testChEBIOverX.py +++ /dev/null @@ -1,125 +0,0 @@ -import unittest -from unittest.mock import PropertyMock, mock_open, patch - -from chebai.preprocessing.datasets.chebi import ChEBIOverX -from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology - - -class TestChEBIOverX(unittest.TestCase): - @classmethod - @patch.multiple(ChEBIOverX, __abstractmethods__=frozenset()) - @patch.object(ChEBIOverX, "processed_dir_main", new_callable=PropertyMock) - @patch("os.makedirs", return_value=None) - def setUpClass(cls, mock_makedirs, mock_processed_dir_main: PropertyMock) -> None: - """ - Set up the ChEBIOverX instance with a mock processed directory path and a test graph. - - Args: - mock_makedirs: This patches os.makedirs to do nothing - mock_processed_dir_main (PropertyMock): Mocked property for the processed directory path. - """ - mock_processed_dir_main.return_value = "/mock/processed_dir" - cls.chebi_extractor = ChEBIOverX(chebi_version=231) - cls.test_graph = ChebiMockOntology.get_transitively_closed_graph() - - @patch("builtins.open", new_callable=mock_open) - def test_select_classes(self, mock_open_file: mock_open) -> None: - """ - Test the select_classes method to ensure it correctly selects nodes based on the threshold. - - Args: - mock_open_file (mock_open): Mocked open function to intercept file operations. - """ - self.chebi_extractor.THRESHOLD = 3 - selected_classes = self.chebi_extractor.select_classes(self.test_graph) - - # Check if the returned selected classes match the expected list - expected_classes = sorted([11111, 22222, 67890]) - self.assertListEqual( - selected_classes, - expected_classes, - "The selected classes do not match the expected output for the given threshold of 3.", - ) - - # Expected data as string - expected_lines = "\n".join(map(str, expected_classes)) + "\n" - - # Extract the generator passed to writelines - written_generator = mock_open_file().writelines.call_args[0][0] - written_lines = "".join(written_generator) - - # Ensure the data matches - self.assertEqual( - written_lines, - expected_lines, - "The written lines do not match the expected lines for the given threshold of 3.", - ) - - @patch("builtins.open", new_callable=mock_open) - def test_no_classes_meet_threshold(self, mock_open_file: mock_open) -> None: - """ - Test the select_classes method when no nodes meet the successor threshold. - - Args: - mock_open_file (mock_open): Mocked open function to intercept file operations. - """ - self.chebi_extractor.THRESHOLD = 5 - selected_classes = self.chebi_extractor.select_classes(self.test_graph) - - # Expected empty result - self.assertEqual( - selected_classes, - [], - "The selected classes list should be empty when no nodes meet the threshold of 5.", - ) - - # Expected data as string - expected_lines = "" - - # Extract the generator passed to writelines - written_generator = mock_open_file().writelines.call_args[0][0] - written_lines = "".join(written_generator) - - # Ensure the data matches - self.assertEqual( - written_lines, - expected_lines, - "The written lines do not match the expected lines when no nodes meet the threshold of 5.", - ) - - @patch("builtins.open", new_callable=mock_open) - def test_all_nodes_meet_threshold(self, mock_open_file: mock_open) -> None: - """ - Test the select_classes method when all nodes meet the successor threshold. - - Args: - mock_open_file (mock_open): Mocked open function to intercept file operations. - """ - self.chebi_extractor.THRESHOLD = 0 - selected_classes = self.chebi_extractor.select_classes(self.test_graph) - - expected_classes = sorted(ChebiMockOntology.get_nodes()) - # Check if the returned selected classes match the expected list - self.assertListEqual( - selected_classes, - expected_classes, - "The selected classes do not match the expected output when all nodes meet the threshold of 0.", - ) - - # Expected data as string - expected_lines = "\n".join(map(str, expected_classes)) + "\n" - - # Extract the generator passed to writelines - written_generator = mock_open_file().writelines.call_args[0][0] - written_lines = "".join(written_generator) - - # Ensure the data matches - self.assertEqual( - written_lines, - expected_lines, - "The written lines do not match the expected lines when all nodes meet the threshold of 0.", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/dataset_classes/testChebiDataExtractor.py b/tests/unit/dataset_classes/testChebiDataExtractor.py index e0b9c4bf..a9825568 100644 --- a/tests/unit/dataset_classes/testChebiDataExtractor.py +++ b/tests/unit/dataset_classes/testChebiDataExtractor.py @@ -1,12 +1,9 @@ import unittest from unittest.mock import MagicMock, PropertyMock, mock_open, patch -import networkx as nx import pandas as pd -from rdkit import Chem from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor -from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology class TestChEBIDataExtractor(unittest.TestCase): @@ -43,85 +40,6 @@ def setUpClass( mock_train_obj.processed_dir_main = "/mock/path/to/train" cls.extractor._chebi_version_train_obj = mock_train_obj - @patch( - "builtins.open", - new_callable=mock_open, - read_data=ChebiMockOntology.get_raw_data(), - ) - def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: - """ - Test the extraction of class hierarchy and validate the structure of the resulting graph. - """ - # Mock the output of fastobo.loads - graph = self.extractor._extract_class_hierarchy("fake_path") - - # Validate the graph structure - self.assertIsInstance( - graph, nx.DiGraph, "The result should be a directed graph." - ) - - # Check nodes - actual_nodes = set(graph.nodes) - self.assertEqual( - set(ChebiMockOntology.get_nodes()), - actual_nodes, - "The graph nodes do not match the expected nodes.", - ) - - # Check edges - actual_edges = set(graph.edges) - self.assertEqual( - ChebiMockOntology.get_edges_of_transitive_closure_graph(), - actual_edges, - "The graph edges do not match the expected edges.", - ) - - # Check number of nodes and edges - self.assertEqual( - ChebiMockOntology.get_number_of_nodes(), - len(actual_nodes), - "The number of nodes should match the actual number of nodes in the graph.", - ) - - self.assertEqual( - ChebiMockOntology.get_number_of_transitive_edges(), - len(actual_edges), - "The number of transitive edges should match the actual number of transitive edges in the graph.", - ) - - @patch("builtins.open", new_callable=mock_open) - @patch.object( - _ChEBIDataExtractor, - "select_classes", - return_value=ChebiMockOntology.get_nodes(), - ) - def test_graph_to_raw_dataset( - self, mock_select_classes: PropertyMock, mock_open_file: mock_open - ) -> None: - """ - Test conversion of a graph to a raw dataset and compare it with the expected DataFrame. - """ - # Mock the OBO file (for _extract_class_hierarchy) and SDF file (for _graph_to_raw_dataset) - mock_obo_data = ChebiMockOntology.get_raw_data() - mock_sdf_data = ChebiMockOntology.get_sdf_data() - - mock_open_file.side_effect = [ - mock_open(read_data=mock_obo_data).return_value, - mock_open(read_data=mock_sdf_data).return_value, - ] - - graph = self.extractor._extract_class_hierarchy("fake_path") - data_df = self.extractor._graph_to_raw_dataset(graph) - for row in data_df.itertuples(): - self.assertIsInstance(row.mol, Chem.Mol, f"No Mol object in row: {row}") - data_df["mol"] = data_df["mol"].apply(lambda x: "mol_placeholder") - - pd.testing.assert_frame_equal( - data_df, - ChebiMockOntology.get_data_in_dataframe(), - obj="The DataFrame should match the expected structure.", - ) - @patch( "builtins.open", new_callable=mock_open, read_data=b"Mocktestdata" ) # Mocking open as a binary file @@ -136,8 +54,8 @@ def test_load_dict( mock_df = pd.DataFrame( { "id": [12345, 67890, 11111, 54321], # Corrected ID - "name": ["A", "B", "C", "D"], - "SMILES": ["C1CCCCC1", "O=C=O", "C1CC=CC1", "C[Mg+]"], + # "name": ["A", "B", "C", "D"], + # "SMILES": ["C1CCCCC1", "O=C=O", "C1CC=CC1", "C[Mg+]"], "mol": ["mol1", "mol2", "mol3", "mol4"], 12345: [True, False, False, True], 67890: [False, True, True, False], diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py index 0f4bad41..568e626b 100644 --- a/tests/unit/dataset_classes/testChebiOverXPartial.py +++ b/tests/unit/dataset_classes/testChebiOverXPartial.py @@ -2,13 +2,64 @@ from unittest.mock import mock_open, patch import networkx as nx +import pandas as pd from rdkit import Chem from chebai.preprocessing.datasets.chebi import ChEBIOverXPartial from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology -class TestChEBIOverX(unittest.TestCase): +def _build_mock_chebi_graph() -> nx.DiGraph: + """ + Build a mock ChEBI graph with is_a relation attributes on edges, + matching the format produced by chebi_utils.build_chebi_graph. + Node IDs are strings. Edge direction is child -> parent (is_a). + + Edges derived from the OBO mock data: + 12345 -> 54321 (12345 is_a 54321) + 12345 -> 67890 (12345 is_a 67890) + 54321 -> 11111 (54321 is_a 11111) + 67890 -> 22222 (67890 is_a 22222) + 99999 -> 12345 (99999 is_a 12345) + 88888 -> 67890 (88888 is_a 67890) + """ + g = nx.DiGraph() + for node in ChebiMockOntology.get_nodes(): + g.add_node(str(node), smiles="test_smiles_placeholder") + # child -> parent (matching build_chebi_graph convention) + is_a_edges = [ + ("12345", "54321"), + ("12345", "67890"), + ("54321", "11111"), + ("67890", "22222"), + ("99999", "12345"), + ("88888", "67890"), + ] + for src, dst in is_a_edges: + g.add_edge(src, dst, relation="is_a") + return g + + +def _build_mock_mol_df() -> pd.DataFrame: + """ + Build a mock molecule DataFrame matching the format returned by + chebi_utils.extract_molecules (columns: chebi_id, mol, ...). + """ + rows = [] + for smiles, chebi_id in [ + ("C1=CC=CC=C1", "12345"), + ("C1=CC=CC=C1O", "54321"), + ("C1=CC=CC=C1N", "67890"), + ("C1=CC=CC=C1F", "11111"), + ("C1=CC=CC=C1Cl", "22222"), + ("C1=CC=CC=C1Br", "99999"), + ("C1=CC=CC=C1[Mg+]", "88888"), + ]: + rows.append({"chebi_id": chebi_id, "mol": Chem.MolFromSmiles(smiles)}) + return pd.DataFrame(rows) + + +class TestChEBIOverXPartial(unittest.TestCase): @classmethod @patch.multiple(ChEBIOverXPartial, __abstractmethods__=frozenset()) @patch("os.makedirs", return_value=None) @@ -17,125 +68,150 @@ def setUpClass(cls, mock_makedirs) -> None: Set up the ChEBIOverXPartial instance with a mock processed directory path and a test graph. """ cls.chebi_extractor = ChEBIOverXPartial( - top_class_id=11111, external_data_ratio=0, chebi_version=231 + top_class_id="11111", external_data_ratio=0, chebi_version=231 ) - cls.test_graph = ChebiMockOntology.get_transitively_closed_graph() - - @patch( - "builtins.open", - new_callable=mock_open, - read_data=ChebiMockOntology.get_raw_data(), - ) - def test_extract_class_hierarchy(self, mock_open: mock_open) -> None: + cls.test_graph = _build_mock_chebi_graph() + + @patch("builtins.open", new_callable=mock_open) + @patch("chebi_utils.extract_molecules") + def test_graph_to_raw_dataset_no_external( + self, mock_extract_molecules, mock_open_file + ) -> None: """ - Test the extraction of class hierarchy and validate the structure of the resulting graph. + Test _graph_to_raw_dataset with external_data_ratio=0. + With child->parent edges, predecessors in the transitive closure are descendants. + For top_class_id="11111", descendants are: 54321, 12345, 99999. + So included IDs: {11111, 54321, 12345, 99999}. """ - # Mock the output of fastobo.loads - self.chebi_extractor.top_class_id = 11111 - graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") + mock_extract_molecules.return_value = _build_mock_mol_df() + self.chebi_extractor.top_class_id = "11111" + self.chebi_extractor.external_data_ratio = 0 + self.chebi_extractor.THRESHOLD = 1 - # Validate the graph structure - self.assertIsInstance( - graph, nx.DiGraph, "The result should be a directed graph." - ) + data_df = self.chebi_extractor._graph_to_raw_dataset(self.test_graph) - # Check nodes - expected_nodes = {11111, 54321, 12345, 99999} - expected_edges = { - (54321, 12345), - (54321, 99999), - (11111, 54321), - (11111, 12345), - (11111, 99999), - (12345, 99999), - } + result_ids = set(data_df["chebi_id"].tolist()) + expected_ids = {"11111", "54321", "12345", "99999"} self.assertEqual( - set(graph.nodes), - expected_nodes, - f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + result_ids, + expected_ids, + f"Expected molecule IDs {expected_ids}, got {result_ids}", ) - # Check edges - self.assertEqual( - expected_edges, - set(graph.edges), - "The graph edges do not match the expected edges.", - ) + # Verify each row has a valid Mol object + for _, row in data_df.iterrows(): + self.assertIsInstance( + row["mol"], + Chem.Mol, + f"Expected Mol object for chebi_id={row['chebi_id']}", + ) - # Check number of nodes and edges - self.assertEqual( - len(graph.nodes), - len(expected_nodes), - "The number of nodes should match the actual number of nodes in the graph.", - ) + @patch("builtins.open", new_callable=mock_open) + @patch("chebi_utils.extract_molecules") + def test_graph_to_raw_dataset_with_external( + self, mock_extract_molecules, mock_open_file + ) -> None: + """ + Test _graph_to_raw_dataset with external_data_ratio=1 (all external nodes included). + """ + mock_extract_molecules.return_value = _build_mock_mol_df() + self.chebi_extractor.top_class_id = "11111" + self.chebi_extractor.external_data_ratio = 1 + self.chebi_extractor.THRESHOLD = 1 + data_df = self.chebi_extractor._graph_to_raw_dataset(self.test_graph) + + # With external_data_ratio=1, all nodes should be included + result_ids = set(data_df["chebi_id"].tolist()) + expected_ids = {"11111", "54321", "12345", "99999", "22222", "67890", "88888"} self.assertEqual( - len(expected_edges), - len(graph.edges), - "The number of transitive edges should match the actual number of transitive edges in the graph.", + result_ids, + expected_ids, + f"Expected all molecule IDs {expected_ids}, got {result_ids}", ) - self.chebi_extractor.top_class_id = 22222 - graph = self.chebi_extractor._extract_class_hierarchy("fake_path") + @patch("builtins.open", new_callable=mock_open) + @patch("chebi_utils.extract_molecules") + def test_graph_to_raw_dataset_leaf_class( + self, mock_extract_molecules, mock_open_file + ) -> None: + """ + Test _graph_to_raw_dataset with a leaf node (no descendants). + For top_class_id="99999", which has no children in the hierarchy, + only the top class itself should be included. + """ + mock_extract_molecules.return_value = _build_mock_mol_df() + self.chebi_extractor.top_class_id = "99999" + self.chebi_extractor.external_data_ratio = 0 + self.chebi_extractor.THRESHOLD = 1 + + data_df = self.chebi_extractor._graph_to_raw_dataset(self.test_graph) - # Check nodes with top class as 22222 + result_ids = set(data_df["chebi_id"].tolist()) self.assertEqual( - set(graph.nodes), - {67890, 88888, 12345, 99999, 22222}, - f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + result_ids, + {"99999"}, + f"Expected only leaf node {{'99999'}}, got {result_ids}", ) - @patch( - "builtins.open", - new_callable=mock_open, - read_data=ChebiMockOntology.get_raw_data(), - ) - def test_extract_class_hierarchy_with_bottom_cls( - self, mock_open: mock_open + @patch("builtins.open", new_callable=mock_open) + @patch("chebi_utils.extract_molecules") + def test_graph_to_raw_dataset_has_label_columns( + self, mock_extract_molecules, mock_open_file ) -> None: """ - Test the extraction of class hierarchy and validate the structure of the resulting graph. + Test that _graph_to_raw_dataset produces label columns from build_labeled_dataset. """ - self.chebi_extractor.top_class_id = 88888 - graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") + mock_extract_molecules.return_value = _build_mock_mol_df() + self.chebi_extractor.top_class_id = "11111" + self.chebi_extractor.external_data_ratio = 0 + self.chebi_extractor.THRESHOLD = 1 - # Check nodes with top class as 88888 - self.assertEqual( - set(graph.nodes), - {self.chebi_extractor.top_class_id}, - f"The graph nodes do not match the expected nodes for top class {self.chebi_extractor.top_class_id} hierarchy.", + data_df = self.chebi_extractor._graph_to_raw_dataset(self.test_graph) + + # The returned DataFrame should have chebi_id, mol, and label columns + self.assertIn("chebi_id", data_df.columns) + self.assertIn("mol", data_df.columns) + # Label columns are string IDs of classes that meet the threshold + label_cols = [c for c in data_df.columns if c not in ("chebi_id", "mol")] + self.assertGreater( + len(label_cols), 0, "Expected at least one label column in the result" ) + # All label values should be boolean + for col in label_cols: + self.assertTrue( + data_df[col].dtype == bool, + f"Label column {col} should be boolean, got {data_df[col].dtype}", + ) @patch("pandas.DataFrame.to_csv") - @patch.object(ChEBIOverXPartial, "_get_data_size", return_value=4.0) + @patch.object(ChEBIOverXPartial, "_get_data_size", return_value=7.0) @patch("torch.load") - @patch( - "builtins.open", - new_callable=mock_open, - ) + @patch("builtins.open", new_callable=mock_open) + @patch("chebi_utils.extract_molecules") def test_single_label_data_split( - self, mock_open_file: mock_open, mock_load, mock_get_data_size, mock_to_csv + self, + mock_extract_molecules, + mock_open_file, + mock_load, + mock_get_data_size, + mock_to_csv, ) -> None: """ - Test the single-label data splitting functionality of the ChebiExtractor class. - - This test mocks several key methods (file operations, torch loading, and pandas functions) - to ensure that the class hierarchy is properly extracted, data is processed into a raw dataset, - and the data splitting logic works as intended without actual file I/O. + Test the single-label data splitting functionality. - It also verifies that there is no overlap between training, validation, and test sets. + Mocks file operations and chebi_utils functions to ensure that data is processed + into a raw dataset and the dynamic data splitting logic produces non-overlapping + train, validation, and test sets. """ - mock_open_file.side_effect = [ - mock_open(read_data=ChebiMockOntology.get_raw_data()).return_value, - mock_open(read_data=ChebiMockOntology.get_sdf_data()).return_value, - mock_open(read_data="").return_value, - ] - self.chebi_extractor.top_class_id = 11111 - self.chebi_extractor.THRESHOLD = 3 + mock_extract_molecules.return_value = _build_mock_mol_df() + self.chebi_extractor.top_class_id = "99999" + self.chebi_extractor.THRESHOLD = 1 + self.chebi_extractor.external_data_ratio = 1 self.chebi_extractor.chebi_version_train = None - graph: nx.DiGraph = self.chebi_extractor._extract_class_hierarchy("fake_path") - data_df = self.chebi_extractor._graph_to_raw_dataset(graph) + data_df = self.chebi_extractor._graph_to_raw_dataset(self.test_graph) + self.assertGreater(len(data_df), 0, "DataFrame should not be empty") self.assertEqual( type([row for _, row in data_df.iterrows()][0]["mol"]), Chem.Mol, @@ -143,28 +219,25 @@ def test_single_label_data_split( ) # Mock _load_dict to return the expected data structure + label_start = 2 # chebi_id=0, mol=1, labels from 2 onwards + def mock_load_dict_generator(path): for _, row in data_df.iterrows(): yield { "features": row["mol"], - "labels": row.iloc[4:].to_numpy( - dtype=bool - ), # Labels start at index 4 - "ident": row["id"], + "labels": row.iloc[label_start:].to_numpy(dtype=bool), + "ident": row["chebi_id"], } with patch.object( self.chebi_extractor, "_load_dict", side_effect=mock_load_dict_generator ): data_pt = self.chebi_extractor._load_data_from_file("fake/path") - # Verify that the data contains only 1 label - self.assertEqual( - type(data_pt), list, f"Data_pt should be a list, got {type(data_pt)}" - ) - self.assertEqual( - len(data_pt), 4, f"Data_pt should contain 4 items, got {len(data_pt)}" + + self.assertIsInstance( + data_pt, list, f"Data_pt should be a list, got {type(data_pt)}" ) - self.assertEqual(len(data_pt[0]["labels"]), 1, f"Data_pt: {data_pt}") + self.assertGreater(len(data_pt), 0, "Data_pt should not be empty") mock_load.return_value = data_pt @@ -177,21 +250,17 @@ def mock_load_dict_generator(path): val_idents = set(validation_split["ident"]) test_idents = set(test_split["ident"]) - # Ensure there is no overlap between train and test sets + # Ensure there is no overlap between any pair of splits self.assertEqual( len(train_idents.intersection(test_idents)), 0, "Train and test sets should not overlap.", ) - - # Ensure there is no overlap between validation and test sets self.assertEqual( len(val_idents.intersection(test_idents)), 0, "Validation and test sets should not overlap.", ) - - # Ensure there is no overlap between train and validation sets self.assertEqual( len(train_idents.intersection(val_idents)), 0, diff --git a/tests/unit/dataset_classes/testChebiTermCallback.py b/tests/unit/dataset_classes/testChebiTermCallback.py deleted file mode 100644 index 9ea77177..00000000 --- a/tests/unit/dataset_classes/testChebiTermCallback.py +++ /dev/null @@ -1,70 +0,0 @@ -import unittest -from typing import Any, Dict - -import fastobo -from fastobo.term import TermFrame - -from chebai.preprocessing.datasets.chebi import term_callback -from tests.unit.mock_data.ontology_mock_data import ChebiMockOntology - - -class TestChebiTermCallback(unittest.TestCase): - """ - Unit tests for the `term_callback` function used in processing ChEBI ontology terms. - """ - - @classmethod - def setUpClass(cls) -> None: - """ - Set up the test class by loading ChEBI term data and storing it in a dictionary - where keys are the term IDs and values are TermFrame instances. - """ - cls.callback_input_data: Dict[int, TermFrame] = { - int(term_doc.id.local): term_doc - for term_doc in fastobo.loads(ChebiMockOntology.get_raw_data()) - if term_doc and ":" in str(term_doc.id) - } - - def test_process_valid_terms(self) -> None: - """ - Test that `term_callback` correctly processes valid ChEBI terms. - """ - - expected_result: Dict[str, Any] = { - "id": 12345, - "parents": [54321, 67890], - "has_part": set(), - "name": "Compound A", - "smiles": "C1=CC=CC=C1", - "subset": "2_STAR", - } - - actual_dict: Dict[str, Any] = term_callback( - self.callback_input_data.get(expected_result["id"]) - ) - self.assertEqual( - expected_result, - actual_dict, - msg="term_callback should correctly extract information from valid ChEBI terms.", - ) - - def test_skip_obsolete_terms(self) -> None: - """ - Test that `term_callback` correctly skips obsolete ChEBI terms. - """ - term_callback_output = [] - for ident in ChebiMockOntology.get_obsolete_nodes_ids(): - raw_term = self.callback_input_data.get(ident) - term_dict = term_callback(raw_term) - if term_dict: - term_callback_output.append(term_dict) - - self.assertEqual( - term_callback_output, - [], - msg="The term_callback function should skip obsolete terms and return an empty list.", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py deleted file mode 100644 index b61ca80c..00000000 --- a/tests/unit/dataset_classes/testDynamicDataset.py +++ /dev/null @@ -1,381 +0,0 @@ -import unittest -from typing import Tuple -from unittest.mock import MagicMock, PropertyMock, patch - -import pandas as pd - -from chebai.preprocessing.datasets.base import _DynamicDataset - - -class TestDynamicDataset(unittest.TestCase): - """ - Test case for _DynamicDataset functionality, ensuring correct data splits and integrity - of train, validation, and test datasets. - """ - - @classmethod - @patch.multiple(_DynamicDataset, __abstractmethods__=frozenset()) - @patch.object(_DynamicDataset, "base_dir", new_callable=PropertyMock) - @patch.object(_DynamicDataset, "_name", new_callable=PropertyMock) - @patch("os.makedirs", return_value=None) - def setUpClass( - cls, - mock_makedirs, - mock_base_dir_property: PropertyMock, - mock_name_property: PropertyMock, - ) -> None: - """ - Set up a base instance of _DynamicDataset for testing with mocked properties. - """ - - # Mocking properties - mock_base_dir_property.return_value = "MockedBaseDirPropertyDynamicDataset" - mock_name_property.return_value = "MockedNamePropertyDynamicDataset" - - # Mock Data Reader - ReaderMock = MagicMock() - ReaderMock.name.return_value = "MockedReader" - _DynamicDataset.READER = ReaderMock - - # Creating an instance of the dataset - cls.dataset: _DynamicDataset = _DynamicDataset() - - # Dataset with a balanced distribution of labels - X = [ - [1, 2], - [3, 4], - [5, 6], - [7, 8], - [9, 10], - [11, 12], - [13, 14], - [15, 16], - [17, 18], - [19, 20], - [21, 22], - [23, 24], - [25, 26], - [27, 28], - [29, 30], - [31, 32], - ] - y = [ - [False, False], - [False, True], - [True, False], - [True, True], - [False, False], - [False, True], - [True, False], - [True, True], - [False, False], - [False, True], - [True, False], - [True, True], - [False, False], - [False, True], - [True, False], - [True, True], - ] - cls.data_df = pd.DataFrame( - {"ident": [f"id{i + 1}" for i in range(len(X))], "features": X, "labels": y} - ) - - def test_get_test_split_valid(self) -> None: - """ - Test splitting the dataset into train and test sets and verify balance and non-overlap. - """ - # self.dataset.train_split = 0.5 - # Test size will be 0.25 * 16 = 4 - self.dataset.test_split = 0.25 - self.dataset.validation_split = 0.25 - train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) - - # Assert the correct number of rows in train and test sets - self.assertEqual(len(train_df), 12, "Train set should contain 12 samples.") - self.assertEqual(len(test_df), 4, "Test set should contain 4 samples.") - - # Check positive and negative label counts in train and test sets - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( - test_df - ) - - # Ensure that the train and test sets have balanced positives and negatives - self.assertEqual( - train_pos_count, train_neg_count, "Train set labels should be balanced." - ) - self.assertEqual( - test_pos_count, test_neg_count, "Test set labels should be balanced." - ) - - # Assert there is no overlap between train and test sets - train_idents = set(train_df["ident"]) - test_idents = set(test_df["ident"]) - self.assertEqual( - len(train_idents.intersection(test_idents)), - 0, - "Train and test sets should not overlap.", - ) - - def test_get_test_split_missing_labels(self) -> None: - """ - Test the behavior when the 'labels' column is missing in the dataset. - """ - df_missing_labels = pd.DataFrame({"ident": ["id1", "id2"]}) - with self.assertRaises( - KeyError, msg="Expected KeyError when 'labels' column is missing." - ): - self.dataset.get_test_split(df_missing_labels) - - def test_get_test_split_seed_consistency(self) -> None: - """ - Test that splitting the dataset with the same seed produces consistent results. - """ - train_df1, test_df1 = self.dataset.get_test_split(self.data_df, seed=42) - train_df2, test_df2 = self.dataset.get_test_split(self.data_df, seed=42) - - pd.testing.assert_frame_equal( - train_df1, - train_df2, - obj="Train sets should be identical for the same seed.", - ) - pd.testing.assert_frame_equal( - test_df1, test_df2, obj="Test sets should be identical for the same seed." - ) - - def test_get_train_val_splits_given_test(self) -> None: - """ - Test splitting the dataset into train and validation sets and verify balance and non-overlap. - """ - self.dataset.use_inner_cross_validation = False - # self.dataset.train_split = 0.5 - self.dataset.test_split = 0.25 - self.dataset.validation_split = 0.25 - df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) - train_df, val_df = self.dataset.get_train_val_splits_given_test( - df_train_main, test_df, seed=42 - ) - - # Ensure there is no overlap between train and test sets - train_idents = set(train_df["ident"]) - test_idents = set(test_df["ident"]) - self.assertEqual( - len(train_idents.intersection(test_idents)), - 0, - "Train and test sets should not overlap.", - ) - - # Ensure there is no overlap between validation and test sets - val_idents = set(val_df["ident"]) - self.assertEqual( - len(val_idents.intersection(test_idents)), - 0, - "Validation and test sets should not overlap.", - ) - - # Ensure there is no overlap between train and validation sets - self.assertEqual( - len(train_idents.intersection(val_idents)), - 0, - "Train and validation sets should not overlap.", - ) - - # Check positive and negative label counts in train and validation sets - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) - - # Ensure that the train and validation sets have balanced positives and negatives - self.assertEqual( - train_pos_count, train_neg_count, "Train set labels should be balanced." - ) - self.assertEqual( - val_pos_count, val_neg_count, "Validation set labels should be balanced." - ) - - def test_get_train_val_splits_given_test_consistency(self) -> None: - """ - Test that splitting the dataset into train and validation sets with the same seed produces consistent results. - """ - test_df = self.data_df.iloc[12:] # Assume rows 12 onward are for testing - train_df1, val_df1 = self.dataset.get_train_val_splits_given_test( - self.data_df, test_df, seed=42 - ) - train_df2, val_df2 = self.dataset.get_train_val_splits_given_test( - self.data_df, test_df, seed=42 - ) - - pd.testing.assert_frame_equal( - train_df1, - train_df2, - obj="Train sets should be identical for the same seed.", - ) - pd.testing.assert_frame_equal( - val_df1, - val_df2, - obj="Validation sets should be identical for the same seed.", - ) - - def test_get_test_split_stratification(self) -> None: - """ - Test that the split into train and test sets maintains the stratification of labels. - """ - # self.dataset.train_split = 0.5 - self.dataset.test_split = 0.25 - self.dataset.validation_split = 0.25 - train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) - - number_of_labels = len(self.data_df["labels"][0]) - - # Check the label distribution in the original dataset - original_pos_count, original_neg_count = ( - self.get_positive_negative_labels_counts(self.data_df) - ) - total_count = len(self.data_df) * number_of_labels - - # Calculate the expected proportions - original_pos_proportion = original_pos_count / total_count - original_neg_proportion = original_neg_count / total_count - - # Check the label distribution in the train set - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - train_total_count = len(train_df) * number_of_labels - - # Calculate the train set proportions - train_pos_proportion = train_pos_count / train_total_count - train_neg_proportion = train_neg_count / train_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - train_pos_proportion, - original_pos_proportion, - places=1, - msg="Train set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - train_neg_proportion, - original_neg_proportion, - places=1, - msg="Train set labels should maintain original negative label proportion.", - ) - - # Check the label distribution in the test set - test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( - test_df - ) - test_total_count = len(test_df) * number_of_labels - - # Calculate the test set proportions - test_pos_proportion = test_pos_count / test_total_count - test_neg_proportion = test_neg_count / test_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - test_pos_proportion, - original_pos_proportion, - places=1, - msg="Test set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - test_neg_proportion, - original_neg_proportion, - places=1, - msg="Test set labels should maintain original negative label proportion.", - ) - - def test_get_train_val_splits_given_test_stratification(self) -> None: - """ - Test that the split into train and validation sets maintains the stratification of labels. - """ - self.dataset.use_inner_cross_validation = False - # self.dataset.train_split = 0.5 - self.dataset.test_split = 0.25 - self.dataset.validation_split = 0.25 - - df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) - train_df, val_df = self.dataset.get_train_val_splits_given_test( - df_train_main, test_df, seed=42 - ) - - number_of_labels = len(self.data_df["labels"][0]) - - # Check the label distribution in the original dataset - original_pos_count, original_neg_count = ( - self.get_positive_negative_labels_counts(self.data_df) - ) - total_count = len(self.data_df) * number_of_labels - - # Calculate the expected proportions - original_pos_proportion = original_pos_count / total_count - original_neg_proportion = original_neg_count / total_count - - # Check the label distribution in the train set - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - train_total_count = len(train_df) * number_of_labels - - # Calculate the train set proportions - train_pos_proportion = train_pos_count / train_total_count - train_neg_proportion = train_neg_count / train_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - train_pos_proportion, - original_pos_proportion, - places=1, - msg="Train set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - train_neg_proportion, - original_neg_proportion, - places=1, - msg="Train set labels should maintain original negative label proportion.", - ) - - # Check the label distribution in the validation set - val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) - val_total_count = len(val_df) * number_of_labels - - # Calculate the validation set proportions - val_pos_proportion = val_pos_count / val_total_count - val_neg_proportion = val_neg_count / val_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - val_pos_proportion, - original_pos_proportion, - places=1, - msg="Validation set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - val_neg_proportion, - original_neg_proportion, - places=1, - msg="Validation set labels should maintain original negative label proportion.", - ) - - @staticmethod - def get_positive_negative_labels_counts(df: pd.DataFrame) -> Tuple[int, int]: - """ - Count the number of True and False values within the labels column. - - Args: - df (pd.DataFrame): The DataFrame containing the 'labels' column. - - Returns: - Tuple[int, int]: A tuple containing the counts of True and False values, respectively. - """ - true_count = sum(sum(label) for label in df["labels"]) - false_count = sum(len(label) - sum(label) for label in df["labels"]) - return true_count, false_count - - -if __name__ == "__main__": - unittest.main()