diff --git a/chebai/models/base.py b/chebai/models/base.py index df060e9a..cea135e0 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -343,8 +343,6 @@ def _execute( for metric_name, metric in metrics.items(): metric.update(pr, tar) self._log_metrics(prefix, metrics, len(batch)) - if isinstance(d, dict) and "loss" not in d: - print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}") return d def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int): diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 03054278..f67a5518 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1132,6 +1132,7 @@ def _retrieve_splits_from_csv(self) -> None: new_labels = old_labels[:, label_mapping] df_data["labels"] = list(new_labels) + splits_df["id"] = splits_df["id"].astype(str) train_ids = splits_df[splits_df["split"] == "train"]["id"] validation_ids = splits_df[splits_df["split"] == "validation"]["id"] test_ids = splits_df[splits_df["split"] == "test"]["id"] diff --git a/tutorials/eval_model_basic.ipynb b/tutorials/eval_model_basic.ipynb index a2c570e1..d5c90223 100644 --- a/tutorials/eval_model_basic.ipynb +++ b/tutorials/eval_model_basic.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "initial_id", "metadata": { "ExecuteTime": { @@ -10,33 +10,12 @@ "start_time": "2024-04-02T13:47:27.181585Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\HP\\anaconda3\\envs\\env_chebai\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cpu\n" - ] - } - ], + "outputs": [], "source": [ - "from chebai.result.utils import (\n", - " evaluate_model,\n", - " load_results_from_buffer,\n", - ")\n", - "from chebai.result.classification import print_metrics\n", - "from chebai.models.electra import Electra\n", - "from chebai.preprocessing.datasets.chebi import ChEBIOver50\n", "import os\n", "import torch\n", + "from chebai.result.prediction import Predictor\n", + "\n", "\n", "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(DEVICE)" @@ -44,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "bdb5fc6919cf72be", "metadata": { "ExecuteTime": { @@ -56,165 +35,87 @@ "outputs_hidden": false } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Check for processed data in data\\chebi_v231\\ChEBI50\\processed\\smiles_token\n", - "Cross-validation enabled: False\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Check for processed data in data\\chebi_v231\\ChEBI50\\processed\n", - "saving 771 tokens to C:\\Users\\HP\\Desktop\\github-aditya0by0\\python-chebai\\chebai\\preprocessing\\bin\\smiles_token\\tokens.txt...\n", - "first 10 tokens: ['[*-]', '[Al-]', '[F-]', '.', '[H]', '[N]', '(', ')', '[Ag+]', 'C']\n", - "Get test data split\n", - "Split dataset into train / val with given test set\n" - ] - } - ], + "outputs": [], "source": [ - "# specify the checkpoint name\n", - "checkpoint_name = \"my_trained_model\"\n", - "checkpoint_path = os.path.join(\"logs\", f\"{checkpoint_name}.ckpt\")\n", - "kind = \"test\" # replace with \"train\" / \"validation\" to run on train / validation sets\n", - "buffer_dir = os.path.join(\"results_buffer\", checkpoint_name, kind)\n", - "# make sure to use the same data module and model class that were used during training\n", - "data_module = ChEBIOver50(\n", - " chebi_version=231,\n", - ")\n", - "# load chebi data if missing and perform dynamic splits\n", - "data_module.prepare_data()\n", - "data_module.setup()\n", + "# specify the checkpoint path\n", + "checkpoint_path = os.path.join(\"path\", \"to\", \"checkpoint.ckpt\")\n", "\n", - "model_class = Electra" + "# initialize the predictor - this class loads the data module parameters and the checkpoint\n", + "predictor = Predictor(checkpoint_path=checkpoint_path)\n", + "# you can override parameters if needed (e.g., if the splits file has a different location)\n", + "predictor._dm.splits_file_path = os.path.join(\"path\", \"to\", \"splits_file.csv\")" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "fa1276b47def696c", - "metadata": { - "ExecuteTime": { - "end_time": "2024-04-02T13:47:38.418564Z", - "start_time": "2024-04-02T13:47:37.861168Z" - }, - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|████████████████████████████████| 10/10 [00:06<00:00, 1.54it/s]\n" - ] - } - ], + "execution_count": null, + "id": "3d057f2f", + "metadata": {}, + "outputs": [], "source": [ - "# evaluates model, stores results in buffer_dir\n", - "model = model_class.load_from_checkpoint(checkpoint_path, pretrained_checkpoint=None)\n", - "if buffer_dir is None:\n", - " preds, labels = evaluate_model(\n", - " model,\n", - " data_module,\n", - " buffer_dir=buffer_dir,\n", - " # No need to provide this parameter for Chebi dataset, \"kind\" parameter should be provided\n", - " # filename=data_module.processed_file_names_dict[kind],\n", - " batch_size=10,\n", - " kind=kind,\n", - " )\n", - "else:\n", - " evaluate_model(\n", - " model,\n", - " data_module,\n", - " buffer_dir=buffer_dir,\n", - " # No need to provide this parameter for Chebi dataset, \"kind\" parameter should be provided\n", - " # filename=data_module.processed_file_names_dict[kind],\n", - " batch_size=10,\n", - " kind=kind,\n", - " )\n", - " # load data from buffer_dir\n", - " preds, labels = load_results_from_buffer(buffer_dir, device=DEVICE)" + "# Predict SMILES strings from file or directly\n", + "# predictor.predict_from_file(os.path.join(\"path\", \"to\", \"smiles_file.txt\"))\n", + "preds_smiles = predictor.predict_smiles(\n", + " [\n", + " \"OC[C@H](COP(=O)([O-])OCC[N+](C)(C)C)O\",\n", + " \"CCCCC[C@H](O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C)(C)COP(=O)(O)OP(=O)(O)OC[C@H]1O[C@@H](n2cnc3c(N)ncnc32)[C@H](O)[C@@H]1OP(=O)(O)O\",\n", + " ]\n", + ")\n", + "for i, pred in enumerate(preds_smiles):\n", + " print(f\"SMILES {i}\")\n", + " for p, label in zip(pred, predictor._classification_labels):\n", + " if p > 0.5:\n", + " print(f\" Predicted CHEBI:{label} with {p:.4f}\")" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "201f750c475b4677", - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, + "execution_count": null, + "id": "372c03f2", + "metadata": {}, "outputs": [], "source": [ - "# Load classes from the classes.txt\n", - "with open(os.path.join(data_module.processed_dir_main, \"classes.txt\"), \"r\") as f:\n", - " classes = [line.strip() for line in f.readlines()]" + "# Predict the whole test set (and also get the labels for evaluation)\n", + "# to get the validation set predictions, just replace test_dataloader with val_dataloader\n", + "test_dl = predictor._dm.test_dataloader()\n", + "preds = []\n", + "labels = []\n", + "for batch_idx, batch in enumerate(test_dl):\n", + " batch_preds = predictor._model.test_step(batch, batch_idx)\n", + " preds.append(batch_preds[\"preds\"])\n", + " labels.append(batch_preds[\"labels\"])\n", + "preds = torch.cat(preds)\n", + "labels = torch.cat(labels)\n", + "print(preds.shape)\n", + "print(labels.shape)" ] }, { "cell_type": "code", - "execution_count": 8, - "id": "e567cd2fb1718baf", - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Macro-F1: 0.290936\n", - "Micro-F1: 0.890380\n", - "Balanced Accuracy: 0.507610\n", - "Macro-Precision: 0.021964\n", - "Micro-Precision: 0.908676\n", - "Macro-Recall: 0.020987\n", - "Micro-Recall: 0.872807\n", - "Top 10 classes (F1-score):\n", - "1. 23367 - F1: 1.000000\n", - "2. 33259 - F1: 1.000000\n", - "3. 36914 - F1: 1.000000\n", - "4. 24431 - F1: 1.000000\n", - "5. 33238 - F1: 1.000000\n", - "6. 36357 - F1: 1.000000\n", - "7. 37577 - F1: 1.000000\n", - "8. 24867 - F1: 1.000000\n", - "9. 33579 - F1: 0.974026\n", - "10. 24866 - F1: 0.973684\n", - "Found 63 classes with F1-score == 0 (and non-zero labels): 17792, 22563, 22632, 22712, 24062, 24834, 25108, 25693, 25697, 25698, 25699, 25806, 26151, 26217, 26218, 26421, 26469, 29347, 32988, 33240, 33256, 33296, 33299, 33304, 33597, 33598, 33635, 33655, 33659, 33661, 33670, 33671, 33836, 33976, 35217, 35273, 35479, 35618, 36364, 36562, 36916, 36962, 36963, 37141, 37143, 37622, 37929, 37960, 38101, 38104, 38166, 38835, 39203, 46850, 47704, 47916, 48592, 50047, 50995, 72544, 79389, 83565, 139358\n" - ] - } - ], + "execution_count": null, + "id": "ff9cfd6a", + "metadata": {}, + "outputs": [], "source": [ - "# output relevant metrics\n", - "print_metrics(\n", - " preds,\n", - " labels.to(torch.int),\n", - " DEVICE,\n", - " classes=classes,\n", - " markdown_output=False,\n", - " top_k=10,\n", - ")" + "# print the metrics\n", + "from torchmetrics import F1Score, Recall, Precision\n", + "\n", + "micro_f1 = F1Score(num_labels=preds.shape[1], task=\"multilabel\", average=\"micro\")\n", + "micro_precision = Precision(\n", + " num_labels=preds.shape[1], task=\"multilabel\", average=\"micro\"\n", + ")\n", + "micro_recall = Recall(num_labels=preds.shape[1], task=\"multilabel\", average=\"micro\")\n", + "print(f\"Micro F1: {micro_f1(preds, labels)}\")\n", + "print(f\"Micro Precision: {micro_precision(preds, labels)}\")\n", + "print(f\"Micro Recall: {micro_recall(preds, labels)}\")\n", + "macro_f1 = F1Score(num_labels=preds.shape[1], task=\"multilabel\", average=\"macro\")\n", + "macro_precision = Precision(\n", + " num_labels=preds.shape[1], task=\"multilabel\", average=\"macro\"\n", + ")\n", + "macro_recall = Recall(num_labels=preds.shape[1], task=\"multilabel\", average=\"macro\")\n", + "print(f\"Macro F1: {macro_f1(preds, labels)}\")\n", + "print(f\"Macro Precision: {macro_precision(preds, labels)}\")\n", + "print(f\"Macro Recall: {macro_recall(preds, labels)}\")" ] } ],