From 06c0c76f8f4e1fafa31ea16c610c3a36b20e294f Mon Sep 17 00:00:00 2001 From: Johan Larsson Date: Wed, 2 Nov 2022 16:44:08 +0100 Subject: [PATCH] feat: add standardization as parameter for dataset --- datasets/breheny.py | 7 +++++-- datasets/libsvm.py | 7 +++++-- datasets/simulated.py | 13 +++++++++++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/datasets/breheny.py b/datasets/breheny.py index 8bacbd6..21a8110 100644 --- a/datasets/breheny.py +++ b/datasets/breheny.py @@ -44,17 +44,20 @@ class Dataset(BaseDataset): parameters = { "dataset": ["Scheetz2006", "Rhee2006", "bcTCGA"], + "standardize" : [True, False] } install_cmd = "conda" requirements = ["rpy2", "numpy", "scipy", "appdirs", "r"] - def __init__(self, dataset="bcTCGA"): + def __init__(self, dataset="bcTCGA", standardize=True): self.dataset = dataset self.X, self.y = None, None + self.standardize = standardize def get_data(self): X, y = fetch_breheny(self.dataset) - X, y = preprocess_data(X, y) + if self.standardize: + X, y = preprocess_data(X, y) return dict(X=X, y=y) diff --git a/datasets/libsvm.py b/datasets/libsvm.py index 578dc52..f71b372 100644 --- a/datasets/libsvm.py +++ b/datasets/libsvm.py @@ -12,17 +12,20 @@ class Dataset(BaseDataset): parameters = { "dataset": ["finance", "finance-tf-idf", "YearPredictionMSD"], + "standardize" : [True, False] } install_cmd = "conda" requirements = ["pip:libsvmdata"] - def __init__(self, dataset="bodyfat"): + def __init__(self, dataset="bodyfat", standardize=True): self.dataset = dataset self.X, self.y = None, None + self.standardize = standardize def get_data(self): X, y = fetch_libsvm(self.dataset) - X, y = preprocess_data(X, y) + if self.standardize: + X, y = preprocess_data(X, y) return dict(X=X, y=y) diff --git a/datasets/simulated.py b/datasets/simulated.py index f2eba28..da9921d 100644 --- a/datasets/simulated.py +++ b/datasets/simulated.py @@ -15,16 +15,24 @@ class Dataset(BaseDataset): (200, 10_000, 20), ], "rho": [0, 0.5], + "standardize": [True, False], } def __init__( - self, n_samples=10, n_features=50, n_signals=5, rho=0, random_state=27 + self, + n_samples=10, + n_features=50, + n_signals=5, + rho=0, + random_state=27, + standardize=True, ): self.n_samples = n_samples self.n_features = n_features self.n_signals = n_signals self.random_state = random_state self.rho = rho + self.standardize = standardize def get_data(self): X, y, _ = make_correlated_data( @@ -35,6 +43,7 @@ def get_data(self): random_state=self.random_state, ) - X, y = preprocess_data(X, y) + if self.standardize: + X, y = preprocess_data(X, y) return dict(X=X, y=y)