Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions chebai/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,6 @@ def _execute(
for metric_name, metric in metrics.items():
metric.update(pr, tar)
self._log_metrics(prefix, metrics, len(batch))
if isinstance(d, dict) and "loss" not in d:
print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}")
return d

def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
Expand Down
1 change: 1 addition & 0 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ def _retrieve_splits_from_csv(self) -> None:
new_labels = old_labels[:, label_mapping]
df_data["labels"] = list(new_labels)

splits_df["id"] = splits_df["id"].astype(str)
train_ids = splits_df[splits_df["split"] == "train"]["id"]
validation_ids = splits_df[splits_df["split"] == "validation"]["id"]
test_ids = splits_df[splits_df["split"] == "test"]["id"]
Expand Down
235 changes: 68 additions & 167 deletions tutorials/eval_model_basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,28 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-02T13:47:31.150545Z",
"start_time": "2024-04-02T13:47:27.181585Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\HP\\anaconda3\\envs\\env_chebai\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"cpu\n"
]
}
],
"outputs": [],
"source": [
"from chebai.result.utils import (\n",
" evaluate_model,\n",
" load_results_from_buffer,\n",
")\n",
"from chebai.result.classification import print_metrics\n",
"from chebai.models.electra import Electra\n",
"from chebai.preprocessing.datasets.chebi import ChEBIOver50\n",
"import os\n",
"import torch\n",
"from chebai.result.prediction import Predictor\n",
"\n",
"\n",
"DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"print(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "bdb5fc6919cf72be",
"metadata": {
"ExecuteTime": {
Expand All @@ -56,165 +35,87 @@
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Check for processed data in data\\chebi_v231\\ChEBI50\\processed\\smiles_token\n",
"Cross-validation enabled: False\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Check for processed data in data\\chebi_v231\\ChEBI50\\processed\n",
"saving 771 tokens to C:\\Users\\HP\\Desktop\\github-aditya0by0\\python-chebai\\chebai\\preprocessing\\bin\\smiles_token\\tokens.txt...\n",
"first 10 tokens: ['[*-]', '[Al-]', '[F-]', '.', '[H]', '[N]', '(', ')', '[Ag+]', 'C']\n",
"Get test data split\n",
"Split dataset into train / val with given test set\n"
]
}
],
"outputs": [],
"source": [
"# specify the checkpoint name\n",
"checkpoint_name = \"my_trained_model\"\n",
"checkpoint_path = os.path.join(\"logs\", f\"{checkpoint_name}.ckpt\")\n",
"kind = \"test\" # replace with \"train\" / \"validation\" to run on train / validation sets\n",
"buffer_dir = os.path.join(\"results_buffer\", checkpoint_name, kind)\n",
"# make sure to use the same data module and model class that were used during training\n",
"data_module = ChEBIOver50(\n",
" chebi_version=231,\n",
")\n",
"# load chebi data if missing and perform dynamic splits\n",
"data_module.prepare_data()\n",
"data_module.setup()\n",
"# specify the checkpoint path\n",
"checkpoint_path = os.path.join(\"path\", \"to\", \"checkpoint.ckpt\")\n",
"\n",
"model_class = Electra"
"# initialize the predictor - this class loads the data module parameters and the checkpoint\n",
"predictor = Predictor(checkpoint_path=checkpoint_path)\n",
"# you can override parameters if needed (e.g., if the splits file has a different location)\n",
"predictor._dm.splits_file_path = os.path.join(\"path\", \"to\", \"splits_file.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fa1276b47def696c",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-02T13:47:38.418564Z",
"start_time": "2024-04-02T13:47:37.861168Z"
},
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████| 10/10 [00:06<00:00, 1.54it/s]\n"
]
}
],
"execution_count": null,
"id": "3d057f2f",
"metadata": {},
"outputs": [],
"source": [
"# evaluates model, stores results in buffer_dir\n",
"model = model_class.load_from_checkpoint(checkpoint_path, pretrained_checkpoint=None)\n",
"if buffer_dir is None:\n",
" preds, labels = evaluate_model(\n",
" model,\n",
" data_module,\n",
" buffer_dir=buffer_dir,\n",
" # No need to provide this parameter for Chebi dataset, \"kind\" parameter should be provided\n",
" # filename=data_module.processed_file_names_dict[kind],\n",
" batch_size=10,\n",
" kind=kind,\n",
" )\n",
"else:\n",
" evaluate_model(\n",
" model,\n",
" data_module,\n",
" buffer_dir=buffer_dir,\n",
" # No need to provide this parameter for Chebi dataset, \"kind\" parameter should be provided\n",
" # filename=data_module.processed_file_names_dict[kind],\n",
" batch_size=10,\n",
" kind=kind,\n",
" )\n",
" # load data from buffer_dir\n",
" preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)"
"# Predict SMILES strings from file or directly\n",
"# predictor.predict_from_file(os.path.join(\"path\", \"to\", \"smiles_file.txt\"))\n",
"preds_smiles = predictor.predict_smiles(\n",
" [\n",
" \"OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)O\",\n",
" \"CCCCC[C@H](O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C)(C)COP(=O)(O)OP(=O)(O)OC[C@H]1O[C@@H](n2cnc3c(N)ncnc32)[C@H](O)[C@@H]1OP(=O)(O)O\",\n",
" ]\n",
")\n",
"for i, pred in enumerate(preds_smiles):\n",
" print(f\"SMILES {i}\")\n",
" for p, label in zip(pred, predictor._classification_labels):\n",
" if p > 0.5:\n",
" print(f\" Predicted CHEBI:{label} with {p:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "201f750c475b4677",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"execution_count": null,
"id": "372c03f2",
"metadata": {},
"outputs": [],
"source": [
"# Load classes from the classes.txt\n",
"with open(os.path.join(data_module.processed_dir_main, \"classes.txt\"), \"r\") as f:\n",
" classes = [line.strip() for line in f.readlines()]"
"# Predict the whole test set (and also get the labels for evaluation)\n",
"# to get the validation set predictions, just replace test_dataloader with val_dataloader\n",
"test_dl = predictor._dm.test_dataloader()\n",
"preds = []\n",
"labels = []\n",
"for batch_idx, batch in enumerate(test_dl):\n",
" batch_preds = predictor._model.test_step(batch, batch_idx)\n",
" preds.append(batch_preds[\"preds\"])\n",
" labels.append(batch_preds[\"labels\"])\n",
"preds = torch.cat(preds)\n",
"labels = torch.cat(labels)\n",
"print(preds.shape)\n",
"print(labels.shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e567cd2fb1718baf",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Macro-F1: 0.290936\n",
"Micro-F1: 0.890380\n",
"Balanced Accuracy: 0.507610\n",
"Macro-Precision: 0.021964\n",
"Micro-Precision: 0.908676\n",
"Macro-Recall: 0.020987\n",
"Micro-Recall: 0.872807\n",
"Top 10 classes (F1-score):\n",
"1. 23367 - F1: 1.000000\n",
"2. 33259 - F1: 1.000000\n",
"3. 36914 - F1: 1.000000\n",
"4. 24431 - F1: 1.000000\n",
"5. 33238 - F1: 1.000000\n",
"6. 36357 - F1: 1.000000\n",
"7. 37577 - F1: 1.000000\n",
"8. 24867 - F1: 1.000000\n",
"9. 33579 - F1: 0.974026\n",
"10. 24866 - F1: 0.973684\n",
"Found 63 classes with F1-score == 0 (and non-zero labels): 17792, 22563, 22632, 22712, 24062, 24834, 25108, 25693, 25697, 25698, 25699, 25806, 26151, 26217, 26218, 26421, 26469, 29347, 32988, 33240, 33256, 33296, 33299, 33304, 33597, 33598, 33635, 33655, 33659, 33661, 33670, 33671, 33836, 33976, 35217, 35273, 35479, 35618, 36364, 36562, 36916, 36962, 36963, 37141, 37143, 37622, 37929, 37960, 38101, 38104, 38166, 38835, 39203, 46850, 47704, 47916, 48592, 50047, 50995, 72544, 79389, 83565, 139358\n"
]
}
],
"execution_count": null,
"id": "ff9cfd6a",
"metadata": {},
"outputs": [],
"source": [
"# output relevant metrics\n",
"print_metrics(\n",
" preds,\n",
" labels.to(torch.int),\n",
" DEVICE,\n",
" classes=classes,\n",
" markdown_output=False,\n",
" top_k=10,\n",
")"
"# print the metrics\n",
"from torchmetrics import F1Score, Recall, Precision\n",
"\n",
"micro_f1 = F1Score(num_labels=preds.shape[1], task=\"multilabel\", average=\"micro\")\n",
"micro_precision = Precision(\n",
" num_labels=preds.shape[1], task=\"multilabel\", average=\"micro\"\n",
")\n",
"micro_recall = Recall(num_labels=preds.shape[1], task=\"multilabel\", average=\"micro\")\n",
"print(f\"Micro F1: {micro_f1(preds, labels)}\")\n",
"print(f\"Micro Precision: {micro_precision(preds, labels)}\")\n",
"print(f\"Micro Recall: {micro_recall(preds, labels)}\")\n",
"macro_f1 = F1Score(num_labels=preds.shape[1], task=\"multilabel\", average=\"macro\")\n",
"macro_precision = Precision(\n",
" num_labels=preds.shape[1], task=\"multilabel\", average=\"macro\"\n",
")\n",
"macro_recall = Recall(num_labels=preds.shape[1], task=\"multilabel\", average=\"macro\")\n",
"print(f\"Macro F1: {macro_f1(preds, labels)}\")\n",
"print(f\"Macro Precision: {macro_precision(preds, labels)}\")\n",
"print(f\"Macro Recall: {macro_recall(preds, labels)}\")"
]
}
],
Expand Down
Loading