Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 6 additions & 30 deletions src/simulation/scripts/launch_train_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,7 @@
]
)

policy_kwargs: Dict[str, Any] = dict(
features_extractor_class=c.ExtractorClass,
# features_extractor_kwargs=dict(
# device=c.device,
# ),
activation_fn=nn.ReLU,
net_arch=[512, 512, 512],
)

ppo_args: Dict[str, Any] = dict(
n_steps=4096,
n_epochs=10,
batch_size=256,
learning_rate=3e-4,
gamma=0.99,
verbose=1,
normalize_advantage=True,
device=c.device,
)

save_path = (
Path("~/.cache/autotech/checkpoints").expanduser() / c.ExtractorClass.__name__
)
save_path = c.save_dir / "checkpoints" / c.ExtractorClass.__name__

save_path.mkdir(parents=True, exist_ok=True)

Expand All @@ -61,12 +39,12 @@
if valid_files:
model_path = max(valid_files, key=lambda x: int(x.name.rstrip(".zip")))
print(f"Loading model {model_path.name}")
model = PPO.load(model_path, envs, **ppo_args, policy_kwargs=policy_kwargs)
model = PPO.load(model_path, envs, **c.ppo_args, policy_kwargs=c.policy_kwargs)
i = int(model_path.name.rstrip(".zip")) + 1
print(f"Model found, loading {model_path}")

else:
model = PPO("MlpPolicy", envs, **ppo_args, policy_kwargs=policy_kwargs)
model = PPO("MlpPolicy", envs, **c.ppo_args, policy_kwargs=c.policy_kwargs)

i = 0
print("Model not found, creating a new one")
Expand All @@ -83,22 +61,20 @@
while True:
onnx_utils.export_onnx(
model,
os.path.expanduser(
f"~/.cache/autotech/model_{c.ExtractorClass.__name__}.onnx"
),
str(c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx"),
)
onnx_utils.test_onnx(model)

if c.LOG_LEVEL <= DEBUG:
from utils import PlotModelIO

model.learn(
total_timesteps=500_000,
total_timesteps=c.total_timesteps,
progress_bar=False,
callback=PlotModelIO(),
)
else:
model.learn(total_timesteps=500_000, progress_bar=True)
model.learn(total_timesteps=c.total_timesteps, progress_bar=True)

print("iteration over")
# TODO: we could just use a callback to save checkpoints or export the model to onnx
Expand Down
34 changes: 34 additions & 0 deletions src/simulation/src/simulation/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Just a file that lets us define some constants that are used in multiple files of the simulation
import logging
from pathlib import Path
from typing import Any, Dict

import torch.nn as nn
from torch.cuda import is_available

from extractors import ( # noqa: F401
Expand All @@ -9,20 +12,51 @@
TemporalResNetExtractor,
)

# Webots environments config
# NOTE(review): the per-constant notes below are inferred from the names only;
# confirm them against the environment code that reads these values.
# presumably the number of maps/tracks available
n_map = 2
# presumably the number of simulator instances
n_simulations = 1
# presumably the number of learning vehicles
n_vehicles = 2
# presumably the number of scripted (non-learning) vehicles
n_stupid_vehicles = 0
# presumably the number of discrete steering actions
n_actions_steering = 16
# presumably the number of discrete speed actions
n_actions_speed = 16
# presumably the lidar cut-off distance — TODO confirm the unit
lidar_max_range = 12.0
# True: respawn when crashing (the collision branch of the environment `step`
# derives `done` from this flag); False: go backwards instead.
respawn_on_crash = True


# Training config
# Use the GPU whenever CUDA reports itself as available, otherwise the CPU.
device = "cuda" if is_available() else "cpu"
# Root directory for saved artefacts; the launcher stores checkpoints and the
# exported ONNX model underneath it.
save_dir = Path("~/.cache/autotech").expanduser()
# Value handed to `model.learn(total_timesteps=...)` by the training launcher.
total_timesteps = 500_000
# Keyword arguments unpacked into `PPO(...)` / `PPO.load(...)` by the launcher.
ppo_args: Dict[str, Any] = {
    "n_steps": 4096,
    "n_epochs": 10,
    "batch_size": 256,
    "learning_rate": 3e-4,
    "gamma": 0.99,
    "verbose": 1,
    "normalize_advantage": True,
    "device": device,
}


# Common extractor shared between the policy and value networks
# (cf: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html)
# Change the class assigned here to select another feature extractor; the
# aliases below are read from that class, so they follow the selection.
ExtractorClass = TemporalResNetExtractor
# Module-level aliases of the selected extractor's class attributes.
# NOTE(review): the meaning and units of each attribute are defined by the
# extractor class itself — confirm there.
context_size = ExtractorClass.context_size
lidar_horizontal_resolution = ExtractorClass.lidar_horizontal_resolution
camera_horizontal_resolution = ExtractorClass.camera_horizontal_resolution
n_sensors = ExtractorClass.n_sensors


# Architecture of the model
# Passed as `policy_kwargs` when the training launcher builds or reloads PPO.
policy_kwargs: Dict[str, Any] = {
    "features_extractor_class": ExtractorClass,
    "activation_fn": nn.ReLU,
    # Architecture of the MLP heads for the Value and Policy networks
    "net_arch": [512, 512, 512],
}


# Logging config
# Threshold compared against by callers (e.g. `c.LOG_LEVEL <= DEBUG` in the launcher).
LOG_LEVEL = logging.INFO
# Shared record layout: timestamp, logger name, level name, then the message.
FORMATTER = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
3 changes: 1 addition & 2 deletions src/simulation/src/utils/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def test_onnx(model: OnPolicyAlgorithm):

try:
class_name = model.policy.features_extractor.__class__.__name__
model_path = os.path.expanduser(f"~/.cache/autotech/model_{class_name}.onnx")

# Keep the path a plain string, as it was before the move to `c.save_dir`:
# some onnxruntime releases only accept `str`/`bytes` in `InferenceSession`
# and reject `pathlib.Path` objects. This also matches the `str(...)` used by
# the launcher when exporting the model.
model_path = str(c.save_dir / f"model_{class_name}.onnx")
session = ort.InferenceSession(model_path)
except Exception as e:
print(f"Error loading ONNX model: {e}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def step(self):
done = np.True_
elif b_collided:
reward = np.float32(-0.5)
done = np.False_
# `np.bool` is not available in every NumPy release (removed in 1.24, only
# reintroduced in 2.0). `np.bool_` exists in all of them and yields the same
# scalar type as the `np.True_` / `np.False_` used by the other branches.
done = np.bool_(c.respawn_on_crash)
elif b_past_checkpoint:
reward = np.float32(1.0)
done = np.False_
Expand Down