-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathnaf2.py
More file actions
94 lines (76 loc) · 2.57 KB
/
naf2.py
File metadata and controls
94 lines (76 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import sys
import tensorflow as tf
import numpy as np
import random
import logging
from pernaf.pernaf.agent import NAF
from simulated_environment import AwakeElectronEnv
def main():
    """Train a NAF agent on the simulated AWAKE electron environment.

    Reads an optional integer random seed from ``argv[1]`` (default 123),
    seeds all RNG sources, builds the environment and agent, runs training
    for ``max_episodes * max_steps`` timesteps, and saves the final target
    model under ``checkpoints/naf_seed_<seed>/``.

    Exits with status 1 if the environment's required data file
    ('electron_tt43.out') is missing.
    """
    # Setup logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Seed defaults to 123; an integer first CLI argument overrides it.
    random_seed = 123
    try:
        if len(sys.argv) > 1:
            random_seed = int(sys.argv[1])
    except ValueError:
        # Non-integer argument: silently keep the default seed.
        pass
    # Lazy %-style args avoid formatting when the level is disabled.
    logger.info("Starting NAF training with seed %s", random_seed)

    # Seed every RNG source used by the stack for reproducibility.
    tf.random.set_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)

    # Initialize Environment
    # Note: Environment expects 'electron_tt43.out' to be available
    try:
        env = AwakeElectronEnv()
    except FileNotFoundError as e:
        logger.error(e)
        logger.error("Please ensure 'electron_tt43.out' is present.")
        sys.exit(1)

    # Training Configuration
    directory = f'checkpoints/naf_seed_{random_seed}'
    # exist_ok avoids the exists-check/create race of the original LBYL form.
    os.makedirs(directory, exist_ok=True)

    discount = 0.999
    batch_size = 10
    learning_rate = 1e-3
    max_steps = 200  # Shortened for testing, original was higher
    update_repeat = 3
    max_episodes = 50  # Shortened for testing

    # Prioritized Experience Replay (PER) Configuration
    # prio_info = dict(alpha=.5, beta_start=.9, beta_decay=lambda nr: max(1e-16, 0.25*(1 - nr / 100)))
    prio_info = dict()  # Default to no PER for base run, uncomment above to enable
    noise_info = dict(noise_function=lambda nr: max(0, 2*(1 - nr / 500)))

    # BUG FIX: the original also put 'learning_rate' into nafnet_kwargs, so
    # calling NAF(..., learning_rate=learning_rate, **nafnet_kwargs) raised
    # "TypeError: got multiple values for keyword argument 'learning_rate'".
    # learning_rate is now passed only once, explicitly, below.
    nafnet_kwargs = dict(
        hidden_sizes=[100, 100],
        activation=tf.nn.tanh,
    )

    # Initialize Agent
    # Refactored to SB3 Style
    agent = NAF(
        env=env,
        discount=discount,
        batch_size=batch_size,
        learning_rate=learning_rate,
        max_steps=max_steps,
        update_repeat=update_repeat,
        # max_episodes was removed in favor of total_timesteps in learn()
        prio_info=prio_info,
        noise_info=noise_info,
        directory=directory,
        **nafnet_kwargs
    )

    # Run Training
    logger.info("Starting run...")
    # Convert max_episodes to approximate total_timesteps for the learn method
    total_timesteps = max_episodes * max_steps
    agent.learn(total_timesteps=total_timesteps)
    logger.info("Run finished.")

    # Save final model (persist the trained target network, Q-net 1).
    agent.q_target_model_1.save_model(os.path.join(directory, 'final_model'))


if __name__ == '__main__':
    main()