From 82232c0795347e21f51cbae4de06cf254d70fb26 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 11 Mar 2026 19:18:43 +0100 Subject: [PATCH 1/6] add REMEDIAL implementation --- .../preprocessing/datasets/ml_overbagging.py | 238 ++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 chebai/preprocessing/datasets/ml_overbagging.py diff --git a/chebai/preprocessing/datasets/ml_overbagging.py b/chebai/preprocessing/datasets/ml_overbagging.py new file mode 100644 index 00000000..d4350f93 --- /dev/null +++ b/chebai/preprocessing/datasets/ml_overbagging.py @@ -0,0 +1,238 @@ +import os +from typing import Any + +import pandas as pd +import tqdm + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.datasets.chebi import ChEBIOver50 + + +class _ResampledDynamicDataset(_DynamicDataset): + """ + A dataset class that extends _DynamicDataset with an additional resampled data file. + + This class produces two pickle files during data preparation: + - ``data_standard.pkl``: The standard dataset created by the regular pipeline. + - ``data_resampled.pkl``: A resampled version of the standard dataset, produced by + :meth:`_resample_data`. + + Subclasses must implement :meth:`_resample_data` to define the resampling strategy. + + Args: + **kwargs: Additional keyword arguments passed to :class:`_DynamicDataset`. + """ + + _RESAMPLED_PKL_FILENAME: str = "data_resampled.pkl" + + def __init__(self, **kwargs): + # splits_file_path has to be provided + if "splits_file_path" not in kwargs: + raise ValueError( + "`splits_file_path` must be provided for resampled datasets. To generate a new dataset, use the regular dataset classes" + ) + super().__init__(**kwargs) + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: + """ + Prepares both the standard and resampled data files. 
+ + First runs the regular data preparation pipeline (producing ``data_standard.pkl``), + then generates ``data_resampled.pkl`` by applying :meth:`_resample_data` to the + standard data. + """ + + resampled_path = os.path.join( + self.processed_dir_main, self._RESAMPLED_PKL_FILENAME + ) + if not os.path.isfile(resampled_path): + print( + f"Missing resampled data file (`{self._RESAMPLED_PKL_FILENAME}`). Generating..." + ) + standard_pkl_path = os.path.join( + self.processed_dir_main, self.processed_main_file_names_dict["data"] + ) + if standard_pkl_path is None: + raise FileNotFoundError( + f"Standard data file `{self._STANDARD_PKL_FILENAME}` not found " + f"in {self.processed_dir_main}" + ) + standard_df = pd.read_pickle(standard_pkl_path) + splits_df = pd.read_csv(self.splits_file_path) + splits_df["id"] = splits_df["id"].astype(str) + train_ids = splits_df[splits_df["split"] == "train"]["id"].values + + resampled_df = self._resample_data(standard_df, train_ids) + self.save_processed(resampled_df, self._RESAMPLED_PKL_FILENAME) + + def scumble(self, label_imbalance_ratios): + if len(label_imbalance_ratios) == 0: + return None + geometric_mean_ir = label_imbalance_ratios.prod() ** ( + 1 / len(label_imbalance_ratios) + ) + arithmetic_mean_ir = label_imbalance_ratios.mean() + scumble_score = 1 - geometric_mean_ir / arithmetic_mean_ir + return scumble_score + + def _resample_data( + self, data: pd.DataFrame, train_instances: list[str] + ) -> pd.DataFrame: + """ + Resample the standard ChEBI dataset. + + Subclasses must implement this method to define a resampling strategy + (e.g., oversampling minority classes, undersampling majority classes). + + Args: + data (pd.DataFrame): The standard dataset as produced by the regular + data preparation pipeline. + + Returns: + pd.DataFrame: The resampled dataset. 
+ """ + print("Resampling with REMEDIAL...") + print(data.head()) + labels = data.columns[2:] + print(f"Number of labels: {len(labels)}, first 10 labels: {labels[:10]}") + label_frequencies = data[labels].sum() + print("Label frequencies before resampling:") + print(len(label_frequencies), label_frequencies[:10]) + max_freq = label_frequencies.max() + print(f"Maximum label frequency: {max_freq}") + irlbl = max_freq / label_frequencies + print("Imbalance ratio per label:") + print(len(irlbl), irlbl[:10]) + meanir = irlbl.mean() + print(f"Mean imbalance ratio: {meanir}") + with open( + os.path.join(self.processed_dir_main, "label_imbalance_ratios.csv"), "w" + ) as f: + for label, ir in irlbl.items(): + f.write(f"{label},{ir}\n") + + train_data = data[data["chebi_id"].isin(train_instances)] + if os.path.isfile(os.path.join(self.processed_dir_main, "data_scumble.csv")): + print("Scumble scores already calculated, loading from file...") + scumble_df = pd.read_csv( + os.path.join(self.processed_dir_main, "data_scumble.csv") + ) + scumble_df["chebi_id"] = scumble_df["chebi_id"].astype(str) + scumble_dict = dict(zip(scumble_df["chebi_id"], scumble_df["scumble"])) + train_data["scumble"] = train_data["chebi_id"].map(scumble_dict) + else: + for row in tqdm.tqdm( + train_data.itertuples(), + total=len(train_data), + desc="Calculating scumble scores", + ): + label_values = row[3:] + label_imbalance_ratios = irlbl[[v == 1 for v in label_values]] + scumble_score = self.scumble(label_imbalance_ratios) + train_data.at[row[0], "scumble"] = scumble_score + with open( + os.path.join(self.processed_dir_main, "data_scumble.csv"), "w" + ) as f: + f.write("chebi_id,scumble\n") + for row in train_data.itertuples(): + f.write(f"{row.chebi_id},{row.scumble}\n") + scumble_mean = train_data["scumble"].mean() + print(f"Mean scumble score: {scumble_mean}") + + # split labels into majority labels (irlbl > meanir) and minority labels (irlbl <= meanir) + minority_labels = irlbl[irlbl > 
meanir].index + majority_labels = irlbl[irlbl <= meanir].index + print( + f"Majority labels: {len(majority_labels)}, first 10: {majority_labels[:10]}" + ) + print( + f"Minority labels: {len(minority_labels)}, first 10: {minority_labels[:10]}" + ) + + # split instances where scumble > mean into two copies, one with only majority labels and one with only minority labels + # Drop train instances with NaN scumble (no labels) + nan_scumble_idx = train_data.index[train_data["scumble"].isna()] + # Identify train instances to split + high_scumble = train_data[train_data["scumble"] > scumble_mean] + + # Build majority and minority copies of high-scumble rows with zeroed-out labels + majority_rows = high_scumble[data.columns].copy() + majority_rows[minority_labels] = 0 + + minority_rows = high_scumble[data.columns].copy() + minority_rows[majority_labels] = 0 + + # Indices to remove from the original data: NaN-scumble rows + rows that were split + indices_to_drop = nan_scumble_idx.union(high_scumble.index) + + resampled_data = pd.concat( + [ + data.drop(index=indices_to_drop.intersection(data.index)), + majority_rows, + minority_rows, + ], + ignore_index=True, + ) + print( + "Data resampling completed, dataset size after resampling:", + len(resampled_data), + ) + print(resampled_data.head()) + return resampled_data + + # ------------------------------ Properties ----------------------------------- + @property + def processed_main_file_names_dict(self) -> dict: + """ + Returns a dictionary of all main processed file names, including both the + standard and resampled pickle files. 
+ """ + d = super().processed_main_file_names_dict + d["data_resampled"] = self._RESAMPLED_PKL_FILENAME + return d + + @property + def processed_file_names_dict(self) -> dict: + return { + "data": "data_resampled.pt", + } + + def setup_processed(self) -> None: + """ + Instead of data.pkl, use resampled data as basis for processing + + Returns: + None + """ + os.makedirs(self.processed_dir, exist_ok=True) + transformed_file_name = self.processed_file_names_dict["data"] + print( + f"Missing transformed data (`{transformed_file_name}` file). Transforming data.... " + ) + import torch + + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.processed_main_file_names_dict["data_resampled"], + ) + ), + os.path.join(self.processed_dir, transformed_file_name), + ) + + +class ChEBI50ResampledDataset(_ResampledDynamicDataset, ChEBIOver50): + pass + + +if __name__ == "__main__": + dataset = ChEBI50ResampledDataset( + chebi_version="248", + splits_file_path=os.path.join( + "data", "chebi_v248", "ChEBI50", "processed", "splits_chebi50_v248.csv" + ), + ) + dataset.prepare_data() + dataset.setup() From 8e3ab3ac26ea574c5b6ec31f4f42ea4123d9ba20 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 11 Mar 2026 19:43:37 +0100 Subject: [PATCH 2/6] fix chebi50resampled --- chebai/preprocessing/datasets/ml_overbagging.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/chebai/preprocessing/datasets/ml_overbagging.py b/chebai/preprocessing/datasets/ml_overbagging.py index d4350f93..da439346 100644 --- a/chebai/preprocessing/datasets/ml_overbagging.py +++ b/chebai/preprocessing/datasets/ml_overbagging.py @@ -12,13 +12,6 @@ class _ResampledDynamicDataset(_DynamicDataset): """ A dataset class that extends _DynamicDataset with an additional resampled data file. - This class produces two pickle files during data preparation: - - ``data_standard.pkl``: The standard dataset created by the regular pipeline. 
- - ``data_resampled.pkl``: A resampled version of the standard dataset, produced by - :meth:`_resample_data`. - - Subclasses must implement :meth:`_resample_data` to define the resampling strategy. - Args: **kwargs: Additional keyword arguments passed to :class:`_DynamicDataset`. """ @@ -80,10 +73,7 @@ def _resample_data( self, data: pd.DataFrame, train_instances: list[str] ) -> pd.DataFrame: """ - Resample the standard ChEBI dataset. - - Subclasses must implement this method to define a resampling strategy - (e.g., oversampling minority classes, undersampling majority classes). + Resample the standard ChEBI dataset with REMEDIAL. Args: data (pd.DataFrame): The standard dataset as produced by the regular @@ -223,12 +213,12 @@ def setup_processed(self) -> None: ) -class ChEBI50ResampledDataset(_ResampledDynamicDataset, ChEBIOver50): +class ChEBI50Resampled(ChEBIOver50, _ResampledDynamicDataset): pass if __name__ == "__main__": - dataset = ChEBI50ResampledDataset( + dataset = ChEBI50Resampled( chebi_version="248", splits_file_path=os.path.join( "data", "chebi_v248", "ChEBI50", "processed", "splits_chebi50_v248.csv" From 7f91a254834e1fad82fcfa0590ab4bc2085331c4 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 11 Mar 2026 19:59:47 +0100 Subject: [PATCH 3/6] move chebi50resampled to chebi.py --- chebai/preprocessing/datasets/chebi.py | 25 +++++++++++++++---- .../preprocessing/datasets/ml_overbagging.py | 24 ------------------ 2 files changed, 20 insertions(+), 29 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 64036d4e..03571e4f 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -15,6 +15,7 @@ from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.datasets.ml_overbagging import _ResampledDynamicDataset if TYPE_CHECKING: import networkx as nx @@ -597,6 +598,19 @@ class 
ChEBIOver50(ChEBIOverX): THRESHOLD: int = 50 +class ChEBI50Resampled(ChEBIOver50, _ResampledDynamicDataset): + """ + A class for extracting data from the ChEBI dataset with a threshold of 50 for selecting classes. + + Inherits from ChEBIOverX. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (50). + """ + + THRESHOLD: int = 50 + + class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100): """ A class for extracting data from the ChEBI dataset with DeepChem SMILES reader and a threshold of 100. @@ -751,11 +765,12 @@ class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100): if __name__ == "__main__": - dataset = ChEBIOver50Partial( - chebi_version=247, - subset="3_STAR", - top_class_id="36700", - external_data_ratio=0.5, + dataset = ChEBI50Resampled( + chebi_version="248", + splits_file_path=os.path.join( + "data", "chebi_v248", "ChEBI50", "processed", "splits_chebi50_v248.csv" + ), + batch_size=32, ) dataset.prepare_data() dataset.setup() diff --git a/chebai/preprocessing/datasets/ml_overbagging.py b/chebai/preprocessing/datasets/ml_overbagging.py index da439346..aedfcef3 100644 --- a/chebai/preprocessing/datasets/ml_overbagging.py +++ b/chebai/preprocessing/datasets/ml_overbagging.py @@ -5,7 +5,6 @@ import tqdm from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.chebi import ChEBIOver50 class _ResampledDynamicDataset(_DynamicDataset): @@ -18,14 +17,6 @@ class _ResampledDynamicDataset(_DynamicDataset): _RESAMPLED_PKL_FILENAME: str = "data_resampled.pkl" - def __init__(self, **kwargs): - # splits_file_path has to be provided - if "splits_file_path" not in kwargs: - raise ValueError( - "`splits_file_path` must be provided for resampled datasets. 
To generate a new dataset, use the regular dataset classes" - ) - super().__init__(**kwargs) - # ------------------------------ Phase: Prepare data ----------------------------------- def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: """ @@ -211,18 +202,3 @@ def setup_processed(self) -> None: ), os.path.join(self.processed_dir, transformed_file_name), ) - - -class ChEBI50Resampled(ChEBIOver50, _ResampledDynamicDataset): - pass - - -if __name__ == "__main__": - dataset = ChEBI50Resampled( - chebi_version="248", - splits_file_path=os.path.join( - "data", "chebi_v248", "ChEBI50", "processed", "splits_chebi50_v248.csv" - ), - ) - dataset.prepare_data() - dataset.setup() From 99060dbb9e06f630ee6b1c0324921efcab6b52de Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 11 Mar 2026 20:04:08 +0100 Subject: [PATCH 4/6] fix inheritance --- chebai/preprocessing/datasets/chebi.py | 2 +- chebai/preprocessing/datasets/ml_overbagging.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 03571e4f..c773a00e 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -598,7 +598,7 @@ class ChEBIOver50(ChEBIOverX): THRESHOLD: int = 50 -class ChEBI50Resampled(ChEBIOver50, _ResampledDynamicDataset): +class ChEBI50Resampled(_ResampledDynamicDataset, ChEBIOver50): """ A class for extracting data from the ChEBI dataset with a threshold of 50 for selecting classes. 
diff --git a/chebai/preprocessing/datasets/ml_overbagging.py b/chebai/preprocessing/datasets/ml_overbagging.py index aedfcef3..195cf960 100644 --- a/chebai/preprocessing/datasets/ml_overbagging.py +++ b/chebai/preprocessing/datasets/ml_overbagging.py @@ -17,6 +17,14 @@ class _ResampledDynamicDataset(_DynamicDataset): _RESAMPLED_PKL_FILENAME: str = "data_resampled.pkl" + def __init__(self, **kwargs): + # splits_file_path has to be provided + if "splits_file_path" not in kwargs: + raise ValueError( + "`splits_file_path` must be provided for resampled datasets. To generate a new dataset, use the regular dataset classes" + ) + super().__init__(**kwargs) + # ------------------------------ Phase: Prepare data ----------------------------------- def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: """ From 4cdd3b9b443d850d4a754e262219e797553d5770 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 12 Mar 2026 15:09:42 +0100 Subject: [PATCH 5/6] change label datatype to bool (smaller file size) and set removed labels to None --- chebai/preprocessing/datasets/ml_overbagging.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/ml_overbagging.py b/chebai/preprocessing/datasets/ml_overbagging.py index 195cf960..92d49569 100644 --- a/chebai/preprocessing/datasets/ml_overbagging.py +++ b/chebai/preprocessing/datasets/ml_overbagging.py @@ -147,10 +147,10 @@ def _resample_data( # Build majority and minority copies of high-scumble rows with zeroed-out labels majority_rows = high_scumble[data.columns].copy() - majority_rows[minority_labels] = 0 + majority_rows[minority_labels] = None minority_rows = high_scumble[data.columns].copy() - minority_rows[majority_labels] = 0 + minority_rows[majority_labels] = None # Indices to remove from the original data: NaN-scumble rows + rows that were split indices_to_drop = nan_scumble_idx.union(high_scumble.index) @@ -163,6 +163,9 @@ def _resample_data( ], ignore_index=True, ) 
+ for col in resampled_data.columns[2:]: + resampled_data[col] = resampled_data[col].astype(bool) + print( "Data resampling completed, dataset size after resampling:", len(resampled_data), From 2a4a596411f7b0e82e6aaa915d8e76dfcf7fdc7b Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 12 Mar 2026 16:58:55 +0100 Subject: [PATCH 6/6] add bagging and ml-ros --- chebai/preprocessing/datasets/chebi.py | 35 +-- .../preprocessing/datasets/ml_overbagging.py | 221 +++++++++++++++++- 2 files changed, 240 insertions(+), 16 deletions(-) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index c773a00e..4797c67a 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -15,7 +15,11 @@ from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.ml_overbagging import _ResampledDynamicDataset +from chebai.preprocessing.datasets.ml_overbagging import ( + _BootstrapDynamicDataset, + _MLROSDynamicDataset, + _ResampledDynamicDataset, +) if TYPE_CHECKING: import networkx as nx @@ -598,19 +602,6 @@ class ChEBIOver50(ChEBIOverX): THRESHOLD: int = 50 -class ChEBI50Resampled(_ResampledDynamicDataset, ChEBIOver50): - """ - A class for extracting data from the ChEBI dataset with a threshold of 50 for selecting classes. - - Inherits from ChEBIOverX. - - Attributes: - THRESHOLD (int): The threshold for selecting classes (50). - """ - - THRESHOLD: int = 50 - - class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100): """ A class for extracting data from the ChEBI dataset with DeepChem SMILES reader and a threshold of 100. 
@@ -764,12 +755,26 @@ class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100): pass +class ChEBI50Resampled(_ResampledDynamicDataset, ChEBIOver50): + pass + + +class ChEBI50Boostrapped(_BootstrapDynamicDataset, ChEBIOver50): + pass + + +class ChEBI50MLROS(_MLROSDynamicDataset, ChEBIOver50): + pass + + if __name__ == "__main__": - dataset = ChEBI50Resampled( + dataset = ChEBI50MLROS( chebi_version="248", splits_file_path=os.path.join( "data", "chebi_v248", "ChEBI50", "processed", "splits_chebi50_v248.csv" ), + take_from_file="data_resampled.pkl", + add_to_file="data_bag1_standard.pkl", batch_size=32, ) dataset.prepare_data() diff --git a/chebai/preprocessing/datasets/ml_overbagging.py b/chebai/preprocessing/datasets/ml_overbagging.py index 92d49569..dac9ca62 100644 --- a/chebai/preprocessing/datasets/ml_overbagging.py +++ b/chebai/preprocessing/datasets/ml_overbagging.py @@ -1,4 +1,5 @@ import os +import random from typing import Any import pandas as pd @@ -30,7 +31,7 @@ def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: """ Prepares both the standard and resampled data files. - First runs the regular data preparation pipeline (producing ``data_standard.pkl``), + First runs the regular data preparation pipeline (producing ``data.pkl``), then generates ``data_resampled.pkl`` by applying :meth:`_resample_data` to the standard data. """ @@ -213,3 +214,221 @@ def setup_processed(self) -> None: ), os.path.join(self.processed_dir, transformed_file_name), ) + + +def bootstrap_data(data: pd.DataFrame, train_instances: list[str]) -> pd.DataFrame: + """ + Bootstrap the training instances in the dataset. + + Args: + data (pd.DataFrame): The standard dataset as produced by the regular + data preparation pipeline. + + Returns: + pd.DataFrame: The bootstrapped dataset. 
+    """
+    print("Bootstrapping data...")
+    train_data = data[data["chebi_id"].isin(train_instances)]
+    bootstrapped_data = train_data.sample(
+        n=len(train_data), replace=True, random_state=42
+    )
+    # Add non-train instances back to the bootstrapped data
+    non_train_data = data[~data["chebi_id"].isin(train_instances)]
+    bootstrapped_data = pd.concat(
+        [bootstrapped_data, non_train_data], ignore_index=True
+    )
+    return bootstrapped_data
+
+
+class _BootstrapDynamicDataset(_DynamicDataset):
+    """
+    A dataset class that extends _DynamicDataset by bootstrapping the base dataset.
+
+    Args:
+        **kwargs: Additional keyword arguments passed to :class:`_DynamicDataset`.
+    """
+
+    def __init__(self, bag_name: str, input_data_file: str, **kwargs):
+        # splits_file_path has to be provided
+        if "splits_file_path" not in kwargs:
+            raise ValueError(
+                "`splits_file_path` must be provided for bootstrapping datasets. To generate a new dataset, use the regular dataset classes"
+            )
+        super().__init__(**kwargs)
+        self.bag_name = bag_name
+        self.input_data_file = input_data_file  # filename in processed_dir_main to use as input for bootstrapping
+
+    # ------------------------------ Phase: Prepare data -----------------------------------
+    def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
+        """
+        Prepares both the base data file and a bag.
+
+        First runs the regular data preparation pipeline,
+        then generates a bag by applying :func:`bootstrap_data` to the
+        standard data.
+        """
+
+        bag_path = os.path.join(
+            self.processed_dir_main, self.processed_main_file_names_dict["data"]
+        )
+        if not os.path.isfile(bag_path):
+            print(
+                f"Missing bag file (`{self.processed_main_file_names_dict['data']}`). Generating..."
+ ) + standard_pkl_path = os.path.join( + self.processed_dir_main, self.input_data_file + ) + if standard_pkl_path is None: + raise FileNotFoundError( + f"Standard data file `{standard_pkl_path}` not found " + ) + standard_df = pd.read_pickle(standard_pkl_path) + splits_df = pd.read_csv(self.splits_file_path) + splits_df["id"] = splits_df["id"].astype(str) + train_ids = splits_df[splits_df["split"] == "train"]["id"].values + + bag_df = bootstrap_data(standard_df, train_ids) + self.save_processed(bag_df, self.processed_main_file_names_dict["data"]) + + @property + def processed_main_file_names_dict(self) -> dict: + """ + Returns a dictionary of all main processed file names. + """ + d = {"data": f"data_{self.bag_name}.pkl"} + return d + + @property + def processed_file_names_dict(self) -> dict: + return { + "data": f"data_{self.bag_name}.pt", + } + + +def oversample( + data: pd.DataFrame, train_instances: list[str], sampling_rate: float = 0.1 +) -> pd.DataFrame: + """ + Oversample the training instances in the dataset using ML-ROS. + + Args: + data (pd.DataFrame): The standard dataset as produced by the regular dataset classes. + train_instances (list[str]): A list of instance IDs to oversample. + sampling_rate (float): The rate at which to oversample the training instances. + + Returns: + pd.DataFrame: The oversampled dataset. 
+ """ + data = data.reset_index(drop=True) + # Implementation for oversampling logic + samples_to_add = sampling_rate * len(train_instances) + print(f"Need to add {samples_to_add} samples to data") + # calculate label imbalance ratios + labels = data.columns[2:] + label_frequencies = data[labels].sum() + max_freq = label_frequencies.max() + irlbl = max_freq / label_frequencies + meanir = irlbl.mean() + print(f"Mean imbalance ratio: {meanir:.2f}") + # get bags for all labels where irlbl > meanir + minority_labels = irlbl[irlbl > meanir].index + print(f"Oversampling {len(minority_labels)} minority labels") + minority_bags = dict() + for label in minority_labels: + minority_bags[label] = list(data[data[label] == 1].index) + new_samples = [] + round_idx = 1 + while samples_to_add > 0: + minority_bags_next_round = dict() + for label, bag in minority_bags.items(): + new_sample = bag[random.randint(0, len(bag) - 1)] + bag.append(new_sample) + new_samples.append(new_sample) + samples_to_add -= 1 + irlbl_bag = max_freq / len(bag) + if irlbl_bag > meanir: + minority_bags_next_round[label] = bag + minority_bags = minority_bags_next_round + if round_idx % 5 == 0: + print( + f"Round {round_idx} finished, {samples_to_add} samples to go, {len(minority_bags)} minority bags left" + ) + round_idx += 1 + + new_samples_df = data.iloc[new_samples] + print(f"Adding {len(new_samples_df)} samples to data") + return new_samples_df + + +class _MLROSDynamicDataset(_DynamicDataset): + """ + A dataset class that extends _DynamicDataset by applying ML-ROS to the base dataset. + Takes a dataset from which to oversample and a dataset to which to add the oversampled data as inputs + (might be the same or different, e.g. sample from REMEDIAL dataset, add data to bags). + + Args: + **kwargs: Additional keyword arguments passed to :class:`_DynamicDataset`. 
+ """ + + def __init__( + self, + take_from_file: str, + add_to_file: str, + sampling_rate: float = 0.1, + **kwargs, + ): + # splits_file_path has to be provided + if "splits_file_path" not in kwargs: + raise ValueError( + "`splits_file_path` must be provided for ML-ROS datasets. To generate a new dataset, use the regular dataset classes" + ) + super().__init__(**kwargs) + self.take_from_file = take_from_file + self.add_to_file = add_to_file + self.sampling_rate = sampling_rate + + def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: + """ + Prepares the oversampled dataset. + """ + + oversampled_path = os.path.join( + self.processed_dir_main, self.processed_main_file_names_dict["data"] + ) + if not os.path.isfile(oversampled_path): + print( + f"Missing oversampled file (`{self.processed_main_file_names_dict['data']}`). Generating..." + ) + take_from_pkl_path = os.path.join( + self.processed_dir_main, self.take_from_file + ) + add_to_pkl_path = os.path.join(self.processed_dir_main, self.add_to_file) + if take_from_pkl_path is None: + raise FileNotFoundError(f"File `{take_from_pkl_path}` not found ") + if add_to_pkl_path is None: + raise FileNotFoundError(f"File `{add_to_pkl_path}` not found ") + take_from_df = pd.read_pickle(take_from_pkl_path) + add_to_df = pd.read_pickle(add_to_pkl_path) + splits_df = pd.read_csv(self.splits_file_path) + splits_df["id"] = splits_df["id"].astype(str) + train_ids = splits_df[splits_df["split"] == "train"]["id"].values + extra_samples = oversample(take_from_df, train_ids, self.sampling_rate) + add_to_df = pd.concat([add_to_df, extra_samples], ignore_index=True) + + self.save_processed(add_to_df, self.processed_main_file_names_dict["data"]) + + @property + def processed_main_file_names_dict(self) -> dict: + """ + Returns a dictionary of all main processed file names. 
+ """ + d = { + "data": f"{self.add_to_file[:-4]}_oversampled_with_{self.sampling_rate:.1f}_from_{self.take_from_file[:-4]}.pkl" + } + return d + + @property + def processed_file_names_dict(self) -> dict: + return { + "data": f"{self.add_to_file[:-4]}_oversampled_with_{self.sampling_rate:.1f}_from_{self.take_from_file[:-4]}.pt", + }