From 930e4dba42621b1d7708cac6bbad33de4ec96625 Mon Sep 17 00:00:00 2001 From: Timm Ruland Date: Sat, 4 Apr 2026 12:55:25 +0000 Subject: [PATCH] feat: added enitity parameter for w&b logging --- src/modalities/config/config.py | 1 + .../logging_broker/subscriber_impl/results_subscriber.py | 6 ++++-- .../logging_broker/subscriber_impl/subscriber_factory.py | 8 +++++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 46696aa3b..13a37103b 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -492,6 +492,7 @@ class EvaluationResultToDiscSubscriberConfig(BaseModel): class WandBEvaluationResultSubscriberConfig(BaseModel): global_rank: int + entity: Optional[str] = None project: str experiment_id: str mode: WandbMode diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index 08e2cdc56..d924a2e78 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -65,14 +65,16 @@ def __init__( project: str, experiment_id: str, mode: WandbMode, - logging_directory: Path, + logging_directory: Path | None, config_file_path: Path, + entity: str | None = None, ) -> None: super().__init__() with open(config_file_path, "r", encoding="utf-8") as file: config = yaml.safe_load(file) self.run = wandb.init( + entity=entity, project=project, name=experiment_id, mode=mode.value.lower(), @@ -81,7 +83,7 @@ def __init__( settings=wandb.Settings(init_timeout=120), ) - self.run.log_artifact(config_file_path, name=f"config_{wandb.run.id}", type="config") + self.run.log_artifact(config_file_path, name=f"config_{self.run.id}", type="config") def consume_dict(self, message_dict: dict[str, Any]): for k, v in message_dict.items(): diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py index 51ee8f984..f463c0448 100644 --- a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -68,6 +68,7 @@ def get_wandb_result_subscriber( mode: WandbMode, config_file_path: Path, directory: Optional[Path] = None, + entity: Optional[str] = None, ) -> WandBEvaluationResultSubscriber: if global_rank == 0 and (mode != WandbMode.DISABLED): if directory is not None: @@ -88,7 +89,12 @@ def get_wandb_result_subscriber( absolute_dir = None result_subscriber = WandBEvaluationResultSubscriber( - project, experiment_id, mode, absolute_dir, config_file_path + project=project, + experiment_id=experiment_id, + mode=mode, + logging_directory=absolute_dir, + config_file_path=config_file_path, + entity=entity, ) else: result_subscriber = ResultsSubscriberFactory.get_dummy_result_subscriber()