"""Run a single Webots simulation in inference mode using an exported ONNX policy."""

import sys
from pathlib import Path

import numpy as np
import onnxruntime as ort

import simulation.config as c
from extractors import (  # noqa: F401  # presumably imported for side effects so c.ExtractorClass resolves — TODO confirm
    CNN1DResNetExtractor,
    TemporalResNetExtractor,
)
from simulation import VehicleEnv
from utils import run_onnx_model

# Model exported by the training script (launch_train_multiprocessing.py).
ONNX_MODEL_PATH = c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx"

# Scratch/IPC directory shared with the Webots controllers.
TMP_DIR = Path("/tmp/autotech")


def init_onnx_runtime_session(onnx_path: Path) -> ort.InferenceSession:
    """Create an ONNX Runtime inference session for the model at *onnx_path*.

    Raises:
        FileNotFoundError: if the model file does not exist (export it first).
    """
    if not onnx_path.exists():
        raise FileNotFoundError(
            f"The ONNX file could not be found at: {onnx_path}. Please export it first."
        )
    # str() keeps compatibility with onnxruntime versions that only accept str paths.
    return ort.InferenceSession(str(onnx_path))


if __name__ == "__main__":
    # Ensure a clean scratch directory. mkdir(exist_ok=True) is race-free, and the
    # unlink loop replaces the previous `os.system("rm ...")` shell invocation
    # (same effect: removes plain files, leaves subdirectories alone).
    TMP_DIR.mkdir(exist_ok=True)
    for stale in TMP_DIR.glob("*"):
        if stale.is_file():
            stale.unlink()

    # Start the ONNX session.
    try:
        ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH)
        input_name = ort_session.get_inputs()[0].name
        output_name = ort_session.get_outputs()[0].name
        print(f"ONNX model loaded from {ONNX_MODEL_PATH}")
        print(f"Input Name: {input_name}, Output Name: {output_name}")
    except FileNotFoundError as e:
        print(f"ERROR: {e}")
        sys.exit(1)

    env = VehicleEnv(0, 0)
    obs, _ = env.reset()  # gymnasium-style reset: (observation, info)

    print("Starting simulation in inference mode...")

    step_count = 0

    while True:
        # The policy outputs one flat logit vector; obs[None] adds the batch axis.
        raw_action = run_onnx_model(ort_session, obs[None])
        logits = np.asarray(raw_action).flatten()

        # The head concatenates steering logits followed by speed logits; take the
        # greedy (argmax) action for each sub-space.
        steer_logits = logits[: c.n_actions_steering]
        speed_logits = logits[c.n_actions_steering :]
        action = np.array(
            [np.argmax(steer_logits), np.argmax(speed_logits)], dtype=np.int64
        )

        next_obs, reward, done, truncated, info = env.step(action)
        step_count += 1

        # NOTE(review): `truncated` is ignored here — confirm the env never truncates,
        # otherwise this loop should also handle `truncated` like `done`.
        if done:
            print(f"Episode(s) finished after {step_count} steps.")
            step_count = 0

            # Carry the freshest observed frame into the new episode's context
            # window. Assumes obs layout is (n_sensors, context_size) — TODO confirm.
            fresh_frame = next_obs[:, -1:]
            obs, _ = env.reset()
            env.context[:, -1:] = fresh_frame
            obs = env.context
        else:
            obs = next_obs
learning_rate=3e-4, - gamma=0.99, - verbose=1, - normalize_advantage=True, - device=c.device, - ) - - save_path = ( - Path("~/.cache/autotech/checkpoints").expanduser() / c.ExtractorClass.__name__ - ) + save_path = c.save_dir / "checkpoints" / c.ExtractorClass.__name__ save_path.mkdir(parents=True, exist_ok=True) @@ -61,12 +39,12 @@ if valid_files: model_path = max(valid_files, key=lambda x: int(x.name.rstrip(".zip"))) print(f"Loading model {model_path.name}") - model = PPO.load(model_path, envs, **ppo_args, policy_kwargs=policy_kwargs) + model = PPO.load(model_path, envs, **c.ppo_args, policy_kwargs=c.policy_kwargs) i = int(model_path.name.rstrip(".zip")) + 1 print(f"Model found, loading {model_path}") else: - model = PPO("MlpPolicy", envs, **ppo_args, policy_kwargs=policy_kwargs) + model = PPO("MlpPolicy", envs, **c.ppo_args, policy_kwargs=c.policy_kwargs) i = 0 print("Model not found, creating a new one") @@ -83,9 +61,7 @@ while True: onnx_utils.export_onnx( model, - os.path.expanduser( - f"~/.cache/autotech/model_{c.ExtractorClass.__name__}.onnx" - ), + str(c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx"), ) onnx_utils.test_onnx(model) @@ -93,12 +69,12 @@ from utils import PlotModelIO model.learn( - total_timesteps=500_000, + total_timesteps=c.total_timesteps, progress_bar=False, callback=PlotModelIO(), ) else: - model.learn(total_timesteps=500_000, progress_bar=True) + model.learn(total_timesteps=c.total_timesteps, progress_bar=True) print("iteration over") # TODO: we could just use a callback to save checkpoints or export the model to onnx diff --git a/src/simulation/src/simulation/config.py b/src/simulation/src/simulation/config.py index 1e8864d..5fb064f 100644 --- a/src/simulation/src/simulation/config.py +++ b/src/simulation/src/simulation/config.py @@ -1,6 +1,9 @@ # just a file that lets us define some constants that are used in multiple files the simulation import logging +from pathlib import Path +from typing import Any, Dict +import torch.nn as 
nn from torch.cuda import is_available from extractors import ( # noqa: F401 @@ -9,20 +12,51 @@ TemporalResNetExtractor, ) +# Webots environments config n_map = 2 n_simulations = 1 -n_vehicles = 2 +n_vehicles = 1 n_stupid_vehicles = 0 n_actions_steering = 16 n_actions_speed = 16 lidar_max_range = 12.0 +respawn_on_crash = True # whether to go backwards or to respawn when crashing + + +# Training config device = "cuda" if is_available() else "cpu" +save_dir = Path("~/.cache/autotech").expanduser() +total_timesteps = 500_000 +ppo_args: Dict[str, Any] = dict( + n_steps=4096, + n_epochs=10, + batch_size=256, + learning_rate=3e-4, + gamma=0.99, + verbose=1, + normalize_advantage=True, + device=device, +) + +# Common extractor shared between the policy and value networks +# (cf: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html) ExtractorClass = TemporalResNetExtractor context_size = ExtractorClass.context_size lidar_horizontal_resolution = ExtractorClass.lidar_horizontal_resolution camera_horizontal_resolution = ExtractorClass.camera_horizontal_resolution n_sensors = ExtractorClass.n_sensors + +# Architecture of the model +policy_kwargs: Dict[str, Any] = dict( + features_extractor_class=ExtractorClass, + activation_fn=nn.ReLU, + # Architecture of the MLP heads for the Value and Policy networks + net_arch=[512, 512, 512], +) + + +# Logging config LOG_LEVEL = logging.INFO FORMATTER = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/src/simulation/src/utils/__init__.py b/src/simulation/src/utils/__init__.py index 4b7ee72..7e56893 100644 --- a/src/simulation/src/utils/__init__.py +++ b/src/simulation/src/utils/__init__.py @@ -1,3 +1,8 @@ from .plot_model_io import PlotModelIO +import onnxruntime as ort +import numpy as np __all__ = ["PlotModelIO"] + +def run_onnx_model(session: ort.InferenceSession, x: np.ndarray): + return session.run(None, {"input": x})[0] \ No newline at end of file diff --git 
a/src/simulation/src/utils/onnx_utils.py b/src/simulation/src/utils/onnx_utils.py index 9ff3784..1b90adb 100644 --- a/src/simulation/src/utils/onnx_utils.py +++ b/src/simulation/src/utils/onnx_utils.py @@ -54,8 +54,7 @@ def test_onnx(model: OnPolicyAlgorithm): try: class_name = model.policy.features_extractor.__class__.__name__ - model_path = os.path.expanduser(f"~/.cache/autotech/model_{class_name}.onnx") - + model_path = c.save_dir / f"model_{class_name}.onnx" session = ort.InferenceSession(model_path) except Exception as e: print(f"Error loading ONNX model: {e}") diff --git a/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py b/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py index 8647300..7b656f4 100644 --- a/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py +++ b/src/simulation/src/webots/controllers/controller_world_supervisor/controller_world_supervisor.py @@ -170,7 +170,7 @@ def step(self): done = np.True_ elif b_collided: reward = np.float32(-0.5) - done = np.False_ + done = np.bool(c.respawn_on_crash) elif b_past_checkpoint: reward = np.float32(1.0) done = np.False_ diff --git a/uv.lock b/uv.lock index 7effdcb..7ac586b 100644 --- a/uv.lock +++ b/uv.lock @@ -852,6 +852,11 @@ dependencies = [ { name = "zmq" }, ] +[package.optional-dependencies] +controller = [ + { name = "pygame" }, +] + [package.metadata] requires-dist = [ { name = "adafruit-blinka", specifier = ">=8.0.0" }, @@ -869,6 +874,7 @@ requires-dist = [ { name = "onnxruntime", specifier = ">=1.8.0" }, { name = "opencv-python", specifier = ">=4.12.0.88" }, { name = "picamera2", specifier = ">=0.3.0" }, + { name = "pygame", marker = "extra == 'controller'", specifier = ">=2.6.1" }, { name = "pyps4controller", specifier = ">=1.2.0" }, { name = "rpi-gpio", specifier = ">=0.7.1" }, { name = "rpi-hardware-pwm", specifier = ">=0.1.0" }, @@ -879,6 +885,7 @@ 
requires-dist = [ { name = "websockets", specifier = ">=16.0" }, { name = "zmq", specifier = ">=0.0.0" }, ] +provides-extras = ["controller"] [[package]] name = "humanfriendly" @@ -2392,6 +2399,9 @@ dependencies = [ wheels = [ { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, { url = "https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" },