-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcallbacks.py
More file actions
24 lines (20 loc) · 765 Bytes
/
callbacks.py
File metadata and controls
24 lines (20 loc) · 765 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import wandb
from stable_baselines3.common.callbacks import BaseCallback
class WandbCustomCallback(BaseCallback):
    """Callback that sends the current step's rewards to Weights & Biases.

    On every callback step the value found under ``self.locals["rewards"]``
    is passed to ``wandb.log`` under the ``"reward"`` key.
    """

    def __init__(self, verbose=0):
        """Initialise the base callback.

        Args:
            verbose: Verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose=verbose)

    def _on_step(self) -> bool:
        """Log the rewards of the current step.

        The value is logged exactly as stored in ``self.locals`` with no
        reduction applied.
        NOTE(review): presumably this is an array with one entry per
        environment rather than a scalar -- confirm against the algorithm
        and vectorised environment in use.

        Returns:
            Always ``True``.
        """
        step_rewards = self.locals["rewards"]
        wandb.log({"reward": step_rewards})
        return True
class RewardCallback(BaseCallback):
    """Callback that accumulates the rewards seen at each step in memory.

    Every callback step appends ``self.locals["rewards"]`` to
    ``self.rewards``; ``reset`` starts a fresh history.
    """

    def __init__(self, verbose=0):
        """Initialise the base callback and an empty reward history.

        Args:
            verbose: Verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose)
        # One entry per callback step, in the order the steps occurred.
        self.rewards = []

    def _on_step(self):
        """Record the rewards of the current step.

        Returns:
            Always ``True``.
        """
        latest = self.locals["rewards"]
        self.rewards.append(latest)
        return True

    def reset(self):
        """Drop the recorded history.

        ``self.rewards`` is rebound to a new empty list rather than cleared
        in place, so a reference to the previous list obtained earlier keeps
        its contents.
        """
        self.rewards = []