-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathenvironment.py
More file actions
67 lines (60 loc) · 2.4 KB
/
environment.py
File metadata and controls
67 lines (60 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gym
import numpy as np
class TSCEnv(gym.Env):
"""
Environment for Traffic Signal Control task.
Parameters
----------
world: World object
agents: list of agents, corresponding to each intersection in world.intersections
metric: Metric object, used to calculate evaluation metric
"""
def __init__(self, world, agents, metric):
"""
:param world: one world object to interact with agents. Support multi world
objects in different TSCEnvs.
:param agents: single agents, each control all intersections. Or multi agents,
each control one intersection.
actions is a list of actions, agents is a list of agents.
:param metric: metrics to evaluate policy.
"""
self.world = world
self.eng = self.world.eng
self.n_agents = len(agents) * agents[0].sub_agents
# test agents number == intersection number
assert len(world.intersection_ids) == self.n_agents
self.agents = agents
action_dims = [agent.action_space.n * agent.sub_agents for agent in agents]
# total action space of all agents.
self.action_space = gym.spaces.MultiDiscrete(action_dims)
self.metric = metric
def step(self, actions):
"""
:param actions: keep action as N_agents * 1
"""
if not actions.shape:
assert(self.n_agents == 1)
actions = actions[np.newaxis]
else:
assert len(actions) == self.n_agents
self.world.step(actions)
if not len(self.agents) == 1:
obs = [agent.get_ob() for agent in self.agents]
# obs = np.expand_dims(np.array(obs),axis=1)
rewards = [agent.get_reward() for agent in self.agents]
# rewards = np.expand_dims(np.array(rewards),axis=1)
else:
obs = [self.agents[0].get_ob()]
rewards = [self.agents[0].get_reward()]
dones = [False] * self.n_agents
# infos = {"metric": self.metric.update()}
infos = {}
return obs, rewards, dones, infos
def reset(self):
self.world.reset()
if not len(self.agents) == 1:
obs = [agent.get_ob() for agent in self.agents] # [agent, sub_agent==1, feature]
# obs = np.expand_dims(np.array(obs),axis=1)
else:
obs = [self.agents[0].get_ob()] # [agent==1, sub_agent, feature]
return obs