From 9a82ee7185474cae0714081709b01ae863846692 Mon Sep 17 00:00:00 2001 From: matthiasbienvenu Date: Fri, 20 Mar 2026 22:37:03 +0100 Subject: [PATCH 1/2] fix: moved ai hyperparameters to config.py --- .../scripts/launch_train_multiprocessing.py | 19 ++++--------------- src/simulation/src/simulation/config.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/simulation/scripts/launch_train_multiprocessing.py b/src/simulation/scripts/launch_train_multiprocessing.py index 527baac..422cf3b 100644 --- a/src/simulation/scripts/launch_train_multiprocessing.py +++ b/src/simulation/scripts/launch_train_multiprocessing.py @@ -39,17 +39,6 @@ net_arch=[512, 512, 512], ) - ppo_args: Dict[str, Any] = dict( - n_steps=4096, - n_epochs=10, - batch_size=256, - learning_rate=3e-4, - gamma=0.99, - verbose=1, - normalize_advantage=True, - device=c.device, - ) - save_path = ( Path("~/.cache/autotech/checkpoints").expanduser() / c.ExtractorClass.__name__ ) @@ -61,12 +50,12 @@ if valid_files: model_path = max(valid_files, key=lambda x: int(x.name.rstrip(".zip"))) print(f"Loading model {model_path.name}") - model = PPO.load(model_path, envs, **ppo_args, policy_kwargs=policy_kwargs) + model = PPO.load(model_path, envs, **c.ppo_args, policy_kwargs=policy_kwargs) i = int(model_path.name.rstrip(".zip")) + 1 print(f"Model found, loading {model_path}") else: - model = PPO("MlpPolicy", envs, **ppo_args, policy_kwargs=policy_kwargs) + model = PPO("MlpPolicy", envs, **c.ppo_args, policy_kwargs=policy_kwargs) i = 0 print("Model not found, creating a new one") @@ -93,12 +82,12 @@ from utils import PlotModelIO model.learn( - total_timesteps=500_000, + total_timesteps=c.total_timesteps, progress_bar=False, callback=PlotModelIO(), ) else: - model.learn(total_timesteps=500_000, progress_bar=True) + model.learn(total_timesteps=c.total_timesteps, progress_bar=True) print("iteration over") # TODO: we could just use a callback to save checkpoints or export the model to 
onnx diff --git a/src/simulation/src/simulation/config.py b/src/simulation/src/simulation/config.py index 1e8864d..05a38d7 100644 --- a/src/simulation/src/simulation/config.py +++ b/src/simulation/src/simulation/config.py @@ -1,5 +1,6 @@ # just a file that lets us define some constants that are used in multiple files the simulation import logging +from typing import Any, Dict from torch.cuda import is_available @@ -9,6 +10,7 @@ TemporalResNetExtractor, ) +# Webots environments config n_map = 2 n_simulations = 1 n_vehicles = 2 @@ -18,11 +20,26 @@ lidar_max_range = 12.0 device = "cuda" if is_available() else "cpu" + +# Training config +total_timesteps = 500_000 +ppo_args: Dict[str, Any] = dict( + n_steps=4096, + n_epochs=10, + batch_size=256, + learning_rate=3e-4, + gamma=0.99, + verbose=1, + normalize_advantage=True, + device=device, +) ExtractorClass = TemporalResNetExtractor context_size = ExtractorClass.context_size lidar_horizontal_resolution = ExtractorClass.lidar_horizontal_resolution camera_horizontal_resolution = ExtractorClass.camera_horizontal_resolution n_sensors = ExtractorClass.n_sensors + +# Logging config LOG_LEVEL = logging.INFO FORMATTER = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") From ac9279799da37e2d4483f464389f1790d0a0c316 Mon Sep 17 00:00:00 2001 From: matthiasbienvenu Date: Tue, 24 Mar 2026 12:23:38 +0100 Subject: [PATCH 2/2] feat: we can choose to go backwards or respawn when crashing I also move some stuff to config.py --- .../scripts/launch_train_multiprocessing.py | 21 ++++--------------- src/simulation/src/simulation/config.py | 19 ++++++++++++++++- src/simulation/src/utils/onnx_utils.py | 3 +-- .../controller_world_supervisor.py | 2 +- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/simulation/scripts/launch_train_multiprocessing.py b/src/simulation/scripts/launch_train_multiprocessing.py index 422cf3b..fc01f97 100644 --- a/src/simulation/scripts/launch_train_multiprocessing.py +++ 
b/src/simulation/scripts/launch_train_multiprocessing.py @@ -30,18 +30,7 @@ ] ) - policy_kwargs: Dict[str, Any] = dict( - features_extractor_class=c.ExtractorClass, - # features_extractor_kwargs=dict( - # device=c.device, - # ), - activation_fn=nn.ReLU, - net_arch=[512, 512, 512], - ) - - save_path = ( - Path("~/.cache/autotech/checkpoints").expanduser() / c.ExtractorClass.__name__ - ) + save_path = c.save_dir / "checkpoints" / c.ExtractorClass.__name__ save_path.mkdir(parents=True, exist_ok=True) @@ -50,12 +39,12 @@ if valid_files: model_path = max(valid_files, key=lambda x: int(x.name.rstrip(".zip"))) print(f"Loading model {model_path.name}") - model = PPO.load(model_path, envs, **c.ppo_args, policy_kwargs=policy_kwargs) + model = PPO.load(model_path, envs, **c.ppo_args, policy_kwargs=c.policy_kwargs) i = int(model_path.name.rstrip(".zip")) + 1 print(f"Model found, loading {model_path}") else: - model = PPO("MlpPolicy", envs, **c.ppo_args, policy_kwargs=policy_kwargs) + model = PPO("MlpPolicy", envs, **c.ppo_args, policy_kwargs=c.policy_kwargs) i = 0 print("Model not found, creating a new one") @@ -72,9 +61,7 @@ while True: onnx_utils.export_onnx( model, - os.path.expanduser( - f"~/.cache/autotech/model_{c.ExtractorClass.__name__}.onnx" - ), + str(c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx"), ) onnx_utils.test_onnx(model) diff --git a/src/simulation/src/simulation/config.py b/src/simulation/src/simulation/config.py index 05a38d7..5b1dd8c 100644 --- a/src/simulation/src/simulation/config.py +++ b/src/simulation/src/simulation/config.py @@ -1,7 +1,9 @@ # just a file that lets us define some constants that are used in multiple files the simulation import logging +from pathlib import Path from typing import Any, Dict +import torch.nn as nn from torch.cuda import is_available from extractors import ( # noqa: F401 @@ -18,10 +20,12 @@ n_actions_steering = 16 n_actions_speed = 16 lidar_max_range = 12.0 -device = "cuda" if is_available() else "cpu" 
+respawn_on_crash = True # whether to go backwards or to respawn when crashing # Training config +device = "cuda" if is_available() else "cpu" +save_dir = Path("~/.cache/autotech").expanduser() total_timesteps = 500_000 ppo_args: Dict[str, Any] = dict( n_steps=4096, @@ -33,6 +37,10 @@ normalize_advantage=True, device=device, ) + + +# Common extractor shared between the policy and value networks +# (cf: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html) ExtractorClass = TemporalResNetExtractor context_size = ExtractorClass.context_size lidar_horizontal_resolution = ExtractorClass.lidar_horizontal_resolution @@ -40,6 +48,15 @@ n_sensors = ExtractorClass.n_sensors +# Architecture of the model +policy_kwargs: Dict[str, Any] = dict( + features_extractor_class=ExtractorClass, + activation_fn=nn.ReLU, + # Architecture of the MLP heads for the Value and Policy networks + net_arch=[512, 512, 512], +) + + # Logging config LOG_LEVEL = logging.INFO FORMATTER = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/src/simulation/src/utils/onnx_utils.py b/src/simulation/src/utils/onnx_utils.py index 9ff3784..1b90adb 100644 --- a/src/simulation/src/utils/onnx_utils.py +++ b/src/simulation/src/utils/onnx_utils.py @@ -54,8 +54,7 @@ def test_onnx(model: OnPolicyAlgorithm): try: class_name = model.policy.features_extractor.__class__.__name__ - model_path = os.path.expanduser(f"~/.cache/autotech/model_{class_name}.onnx") - + model_path = c.save_dir / f"model_{class_name}.onnx" session = ort.InferenceSession(model_path) except Exception as e: print(f"Error loading ONNX model: {e}") diff --git a/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py b/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py index 8647300..7b656f4 100644 --- a/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py +++ 
b/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py @@ -170,7 +170,7 @@ def step(self): done = np.True_ elif b_collided: reward = np.float32(-0.5) - done = np.False_ + done = np.bool_(c.respawn_on_crash) elif b_past_checkpoint: reward = np.float32(1.0) done = np.False_