diff --git a/src/simulation/scripts/launch_train_multiprocessing.py b/src/simulation/scripts/launch_train_multiprocessing.py index 527baac..fc01f97 100644 --- a/src/simulation/scripts/launch_train_multiprocessing.py +++ b/src/simulation/scripts/launch_train_multiprocessing.py @@ -30,29 +30,7 @@ ] ) - policy_kwargs: Dict[str, Any] = dict( - features_extractor_class=c.ExtractorClass, - # features_extractor_kwargs=dict( - # device=c.device, - # ), - activation_fn=nn.ReLU, - net_arch=[512, 512, 512], - ) - - ppo_args: Dict[str, Any] = dict( - n_steps=4096, - n_epochs=10, - batch_size=256, - learning_rate=3e-4, - gamma=0.99, - verbose=1, - normalize_advantage=True, - device=c.device, - ) - - save_path = ( - Path("~/.cache/autotech/checkpoints").expanduser() / c.ExtractorClass.__name__ - ) + save_path = c.save_dir / "checkpoints" / c.ExtractorClass.__name__ save_path.mkdir(parents=True, exist_ok=True) @@ -61,12 +39,12 @@ if valid_files: model_path = max(valid_files, key=lambda x: int(x.name.rstrip(".zip"))) print(f"Loading model {model_path.name}") - model = PPO.load(model_path, envs, **ppo_args, policy_kwargs=policy_kwargs) + model = PPO.load(model_path, envs, **c.ppo_args, policy_kwargs=c.policy_kwargs) i = int(model_path.name.rstrip(".zip")) + 1 print(f"Model found, loading {model_path}") else: - model = PPO("MlpPolicy", envs, **ppo_args, policy_kwargs=policy_kwargs) + model = PPO("MlpPolicy", envs, **c.ppo_args, policy_kwargs=c.policy_kwargs) i = 0 print("Model not found, creating a new one") @@ -83,9 +61,7 @@ while True: onnx_utils.export_onnx( model, - os.path.expanduser( - f"~/.cache/autotech/model_{c.ExtractorClass.__name__}.onnx" - ), + str(c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx"), ) onnx_utils.test_onnx(model) @@ -93,12 +69,12 @@ from utils import PlotModelIO model.learn( - total_timesteps=500_000, + total_timesteps=c.total_timesteps, progress_bar=False, callback=PlotModelIO(), ) else: - model.learn(total_timesteps=500_000, progress_bar=True) + 
model.learn(total_timesteps=c.total_timesteps, progress_bar=True) print("iteration over") # TODO: we could just use a callback to save checkpoints or export the model to onnx diff --git a/src/simulation/src/simulation/config.py b/src/simulation/src/simulation/config.py index 1e8864d..5b1dd8c 100644 --- a/src/simulation/src/simulation/config.py +++ b/src/simulation/src/simulation/config.py @@ -1,6 +1,9 @@ # just a file that lets us define some constants that are used in multiple files the simulation import logging +from pathlib import Path +from typing import Any, Dict +import torch.nn as nn from torch.cuda import is_available from extractors import ( # noqa: F401 @@ -9,6 +12,7 @@ TemporalResNetExtractor, ) +# Webots environments config n_map = 2 n_simulations = 1 n_vehicles = 2 @@ -16,13 +20,43 @@ n_actions_steering = 16 n_actions_speed = 16 lidar_max_range = 12.0 +respawn_on_crash = True # whether to go backwards or to respawn when crashing + + +# Training config device = "cuda" if is_available() else "cpu" +save_dir = Path("~/.cache/autotech").expanduser() +total_timesteps = 500_000 +ppo_args: Dict[str, Any] = dict( + n_steps=4096, + n_epochs=10, + batch_size=256, + learning_rate=3e-4, + gamma=0.99, + verbose=1, + normalize_advantage=True, + device=device, +) + +# Common extractor shared between the policy and value networks +# (cf: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html) ExtractorClass = TemporalResNetExtractor context_size = ExtractorClass.context_size lidar_horizontal_resolution = ExtractorClass.lidar_horizontal_resolution camera_horizontal_resolution = ExtractorClass.camera_horizontal_resolution n_sensors = ExtractorClass.n_sensors + +# Architecture of the model +policy_kwargs: Dict[str, Any] = dict( + features_extractor_class=ExtractorClass, + activation_fn=nn.ReLU, + # Architecture of the MLP heads for the Value and Policy networks + net_arch=[512, 512, 512], +) + + +# Logging config LOG_LEVEL = logging.INFO FORMATTER = 
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/src/simulation/src/utils/onnx_utils.py b/src/simulation/src/utils/onnx_utils.py index 9ff3784..1b90adb 100644 --- a/src/simulation/src/utils/onnx_utils.py +++ b/src/simulation/src/utils/onnx_utils.py @@ -54,8 +54,7 @@ def test_onnx(model: OnPolicyAlgorithm): try: class_name = model.policy.features_extractor.__class__.__name__ - model_path = os.path.expanduser(f"~/.cache/autotech/model_{class_name}.onnx") - + model_path = str(c.save_dir / f"model_{class_name}.onnx") session = ort.InferenceSession(model_path) except Exception as e: print(f"Error loading ONNX model: {e}") diff --git a/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py b/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py index 8647300..7b656f4 100644 --- a/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py +++ b/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py @@ -170,7 +170,7 @@ def step(self): done = np.True_ elif b_collided: reward = np.float32(-0.5) - done = np.False_ + done = np.bool_(c.respawn_on_crash) elif b_past_checkpoint: reward = np.float32(1.0) done = np.False_