Skip to content

Commit 1bc3bbb

Browse files
author
SamoraHunter
committed
Refactor code to comply with Black formatting standards
This commit applies automated formatting changes to improve code readability and consistency. Specifically, it addresses line length violations and argument wrapping across several modules, including H2O classifier wrappers, the grid search pipeline, and utility classes. Changes include: Wrapping long function calls and dictionary definitions in H2OBaseClassifier.py. Breaking long lines in grid_search_cross_validate.py and hyperparameter_search.py. Formatting dictionary access in project_score_save.py. Updating test assertions in test_h2o_base_classifier.py to match the new style.
1 parent 63e4154 commit 1bc3bbb

5 files changed

Lines changed: 30 additions & 14 deletions

File tree

ml_grid/model_classes/H2OBaseClassifier.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,7 @@ def _prepare_fit(
354354
train_df = pd.concat([X, y_series], axis=1)
355355
# Optimization: Provide destination_frame to avoid expensive gc.get_referrers() name search
356356
train_h2o = h2o.H2OFrame(
357-
train_df,
358-
destination_frame=f"train_{uuid.uuid4().hex}"
357+
train_df, destination_frame=f"train_{uuid.uuid4().hex}"
359358
)
360359

361360
# Explicitly convert the outcome column to factor
@@ -412,7 +411,9 @@ def _get_model_params(self) -> Dict[str, Any]:
412411
self._estimator_signature_cache[self.estimator_class] = inspect.signature(
413412
self.estimator_class
414413
).parameters
415-
valid_param_keys = set(self._estimator_signature_cache[self.estimator_class].keys())
414+
valid_param_keys = set(
415+
self._estimator_signature_cache[self.estimator_class].keys()
416+
)
416417

417418
model_params = {
418419
key: value for key, value in all_params.items() if key in valid_param_keys
@@ -633,7 +634,7 @@ def predict(self, X: pd.DataFrame) -> np.ndarray:
633634
X,
634635
column_names=self.feature_names_,
635636
column_types=col_types,
636-
destination_frame=f"pred_{uuid.uuid4().hex}"
637+
destination_frame=f"pred_{uuid.uuid4().hex}",
637638
)
638639

639640
# Optimization: Use the temporary frame directly.
@@ -761,7 +762,7 @@ def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
761762
X,
762763
column_names=self.feature_names_,
763764
column_types=col_types,
764-
destination_frame=f"prob_{uuid.uuid4().hex}"
765+
destination_frame=f"prob_{uuid.uuid4().hex}",
765766
)
766767
except Exception as e:
767768
raise RuntimeError(f"Failed to create H2O frame for prediction: {e}")
@@ -926,7 +927,7 @@ def _get_param_names(self):
926927
for p in init_signature.parameters.values()
927928
if p.name not in ("self", "args", "kwargs")
928929
]
929-
930+
930931
init_params = self._init_param_names_cache[cls]
931932

932933
# Optimization: Use sets for O(1) lookup

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,13 @@ def __init__(
231231
self.logger.debug("Dropping 'client_idcode' from training data.")
232232
self.X_train = self.X_train.drop(columns=["client_idcode"], errors="ignore")
233233
if isinstance(self.X_test, pd.DataFrame):
234-
self.X_test = self.X_test.drop(columns=["client_idcode"], errors="ignore")
234+
self.X_test = self.X_test.drop(
235+
columns=["client_idcode"], errors="ignore"
236+
)
235237
if isinstance(self.X_test_orig, pd.DataFrame):
236-
self.X_test_orig = self.X_test_orig.drop(columns=["client_idcode"], errors="ignore")
238+
self.X_test_orig = self.X_test_orig.drop(
239+
columns=["client_idcode"], errors="ignore"
240+
)
237241

238242
max_param_space_iter_value = (
239243
self.global_params.max_param_space_iter_value
@@ -292,7 +296,9 @@ def __init__(
292296
if "catboost" in method_name.lower() and hasattr(
293297
current_algorithm, "set_params"
294298
):
295-
ml_grid_object.logger.info("Silencing CatBoost verbose output and file writing.")
299+
ml_grid_object.logger.info(
300+
"Silencing CatBoost verbose output and file writing."
301+
)
296302
current_algorithm.set_params(verbose=0, allow_writing_files=False)
297303

298304
# Check for GPU availability and set device for torch-based models
@@ -563,7 +569,9 @@ def __init__(
563569
# --- OPTIMIZATION: Force threading backend for search ---
564570
# Prevents 'loky' overhead (abort_everything ~273s) which occurs even with n_jobs=1
565571
with joblib.parallel_backend("threading"):
566-
current_algorithm = search.run_search(X_train_reset, y_train_search)
572+
current_algorithm = search.run_search(
573+
X_train_reset, y_train_search
574+
)
567575

568576
except TimeoutError:
569577
self.logger.warning("Timeout occurred during hyperparameter search.")

ml_grid/pipeline/hyperparameter_search.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ def run_search(self, X_train: pd.DataFrame, y_train: pd.Series) -> BaseEstimator
250250
y_train_reset = y_train_reset.values
251251

252252
# Force integer encoding if possible to speed up unique() calls
253-
if hasattr(y_train_reset, "dtype") and not pd.api.types.is_integer_dtype(y_train_reset):
253+
if hasattr(y_train_reset, "dtype") and not pd.api.types.is_integer_dtype(
254+
y_train_reset
255+
):
254256
try:
255257
y_train_reset = y_train_reset.astype(int)
256258
except (ValueError, TypeError):

ml_grid/util/project_score_save.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def update_score_log(
185185
logger = logging.getLogger("ml_grid")
186186
logger.info("Writing grid permutation to log")
187187
# write line to best grid scores---------------------
188-
188+
189189
# --- OPTIMIZATION: Construct dictionary first to avoid slow DataFrame element-wise setting ---
190190
row_data = {}
191191
column_list = _get_score_log_columns(list(global_params.metric_list.keys()))
@@ -253,7 +253,9 @@ def update_score_log(
253253
for key_1 in ml_grid_object.local_param_dict.get("data"):
254254
# print(key_1)
255255
if key_1 in column_list:
256-
row_data[key_1] = ml_grid_object.local_param_dict.get("data").get(key_1)
256+
row_data[key_1] = ml_grid_object.local_param_dict.get(
257+
"data"
258+
).get(key_1)
257259

258260
current_f = ml_grid_object.final_column_list
259261
# current_f = list(self.X_test.columns)

tests/test_h2o_base_classifier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,10 @@ def test_predict_successful(
213213

214214
# 2. Check that the new frame creation logic was called
215215
mock_h2o_frame.assert_called_once_with(
216-
X, column_names=list(X.columns), column_types=classifier_instance.feature_types_, destination_frame=ANY
216+
X,
217+
column_names=list(X.columns),
218+
column_types=classifier_instance.feature_types_,
219+
destination_frame=ANY,
217220
)
218221

219222
# Optimization: h2o.assign and h2o.get_frame should NO LONGER be called

0 commit comments

Comments
 (0)