add video saving and uploading support to train_* scripts #524

Conversation
AdamGleave
left a comment
Took a quick look, only skimmed as it's still in draft mode. Seems like a useful feature; a couple of suggestions.
Codecov Report
@@            Coverage Diff             @@
##           master     #524      +/-   ##
==========================================
- Coverage   96.95%   96.93%   -0.03%
==========================================
  Files          84       84
  Lines        7460     7369      -91
==========================================
- Hits         7233     7143      -90
+ Misses        227      226       -1
    )
    callback_objs.append(save_policy_callback)

    if _config["train"]["videos"]:
Here we need to initialize a video_wrapper.SaveVideoCallback instead of using train.save_video like the other scripts do. A bit unsatisfying.
An alternative could be passing a save_video partial function into the callback.
Yes, this is strange, why is that the case? I would advocate for using the callback class everywhere or using a partial / closure+wrapper defined in this file for this specific instance. Currently the existence of the class is confusing and not documented.
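One way to realize the partial/closure approach suggested above. This is a minimal sketch, not the actual imitation API: `save_video`, `make_video_callback`, and the callback signature are all hypothetical stand-ins, and the `saved` list takes the place of actually writing a file.

```python
from functools import partial
from typing import Callable, List

saved: List[str] = []  # stands in for files written to disk

def save_video(policy, venv, save_path: str) -> None:
    # Hypothetical stand-in for train.save_video: record `policy` acting
    # in `venv` and write the result to `save_path`.
    saved.append(save_path)

def make_video_callback(save_fn: Callable[[str], None]) -> Callable[[int], None]:
    # Wrap a pre-bound save function into a simple per-round callback,
    # so the training loop never needs to know about policy/venv details.
    def callback(round_num: int) -> None:
        save_fn(f"videos/round-{round_num:06d}.mp4")
    return callback

# Bind policy and venv up front via functools.partial; the trainer only
# ever sees a plain `Callable[[int], None]`.
video_cb = make_video_callback(partial(save_video, "dummy_policy", "dummy_venv"))
video_cb(3)
```

This keeps all scripts on a single call style without needing a dedicated callback class.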
    rl_algo.set_logger(custom_logger)
    rl_algo.learn(total_timesteps, callback=callback)

    with common.make_venv(num_vec=1, log_dir=None) as eval_venv:
Create an eval_venv:
- with num_vec=1,
- without creating monitors from here, by setting log_dir=None.
I'm still a bit backlogged, @Rocamonde could you review this please?
    total_timesteps = int(1e6)  # total number of environment timesteps
    total_comparisons = 5000  # total number of comparisons to elicit
-   num_iterations = 5  # Arbitrary, should be tuned for the task
+   num_iterations = 50  # Arbitrary, should be tuned for the task
Apologies if this has been discussed, but why are you doing this?
    cross_entropy_loss_kwargs = {}
    reward_trainer_kwargs = {
        "epochs": 3,
        "weight_decay": 0.0,
I'll have to remember to change this, as I have a PR that replaces weight decay with a general regularization API (#481). @AdamGleave what do you think, should we merge my PR or this one first?
Probably best to merge your PR first, though really it depends on which one is ready earlier.
    total_timesteps: int,
    total_comparisons: int,
-   callback: Optional[Callable[[int], None]] = None,
+   callback: Optional[Callable[[int, int], None]] = None,
Probably should add in the docstring what the callback type signature represents.
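For instance, the docstring could spell out what each callback argument means. A sketch only: the function body is a toy loop, and the argument names `iteration` and `timesteps_so_far` are guesses for illustration, not what the real implementation passes.

```python
from typing import Callable, Optional

def train_preference_comparisons_sketch(
    total_timesteps: int,
    total_comparisons: int,
    callback: Optional[Callable[[int, int], None]] = None,
) -> None:
    """Sketch of the trainer entry point with a documented callback.

    Args:
        total_timesteps: total number of environment timesteps.
        total_comparisons: total number of comparisons to elicit.
        callback: if given, called once per training iteration with
            ``(iteration, timesteps_so_far)``. (Hypothetical names; the
            real docstring should state what its two ints represent.)
    """
    timesteps_per_iteration = total_timesteps // 2
    for iteration in range(2):  # toy loop standing in for real training
        if callback is not None:
            callback(iteration, (iteration + 1) * timesteps_per_iteration)

calls = []
train_preference_comparisons_sketch(1000, 10, callback=lambda i, t: calls.append((i, t)))
```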
    @train_ingredient.capture
    def save_video(
When you call this function, it reads as if the video were always saved, but a flag indicating whether this should happen is magically injected through a decorator. I don't have an immediately better alternative, but perhaps a more explanatory function name could help.
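To make the confusion concrete, here is a minimal, self-contained imitation of Sacred-style config injection (this is not the real sacred API; `CONFIG`, `capture`, and the `videos` flag are toy stand-ins):

```python
import functools

CONFIG = {"videos": False}  # hypothetical ingredient config

def capture(fn):
    """Toy stand-in for sacred's @ingredient.capture: fills in missing
    keyword arguments from the config dict at call time."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        kwargs.setdefault("videos", CONFIG["videos"])
        return fn(*args, **kwargs)
    return wrapper

@capture
def save_video(policy_name, videos):
    # Despite the name, this is a no-op unless the injected flag is set.
    return f"saved {policy_name}" if videos else None

# The call site never mentions `videos`, which is why the name misleads:
assert save_video("ppo") is None
CONFIG["videos"] = True
assert save_video("ppo") == "saved ppo"
```

A name like `maybe_save_video` would at least hint that the call can be a no-op.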
        round_str: str,
    ) -> None:
        """Save discriminator and generator."""
        save_path = os.path.join(log_dir, "checkpoints", round_str)
I have a PR replacing os.path with pathlib in most places, but we might as well keep this consistent for now until that's merged.
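For reference, the eventual pathlib form of that line would be something like the following (a sketch; the `log_dir` and `round_str` values are made up, and on POSIX both styles produce the same string):

```python
import os
import pathlib

log_dir, round_str = "output/train", "000042"

# Current style:
save_path = os.path.join(log_dir, "checkpoints", round_str)

# pathlib style after the refactor:
save_path_p = pathlib.Path(log_dir) / "checkpoints" / round_str

assert str(save_path_p) == save_path
```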
        """
        super().__init__(env)
-       self.episode_id = 0
+       self._episode_id = 0
        directory=video_dir,
        **(video_kwargs or dict()),
    )
    sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=1)
I understand where the name of this function is coming from ("make the function called sample_until"), but how it actually reads IMO is "make the sample (until...?)". I think that refactoring this to something like "get_stopping_conditions_callback" or "get_sampling_termination_fn" would be much more readable.
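For context, the function builds a termination predicate over the trajectories collected so far. A self-contained sketch of that behavior (not the actual imitation implementation; trajectories are represented here simply as lists of steps):

```python
from typing import Callable, Optional, Sequence

def make_sample_until(
    min_timesteps: Optional[int], min_episodes: Optional[int]
) -> Callable[[Sequence[list]], bool]:
    """Return a predicate that is True once enough data has been sampled."""
    def sample_until(trajectories: Sequence[list]) -> bool:
        if min_episodes is not None and len(trajectories) < min_episodes:
            return False  # not enough full episodes yet
        if min_timesteps is not None:
            if sum(len(t) for t in trajectories) < min_timesteps:
                return False  # not enough total timesteps yet
        return True
    return sample_until

sample_until = make_sample_until(min_timesteps=None, min_episodes=1)
assert not sample_until([])        # no episodes yet: keep sampling
assert sample_until([[0, 1, 2]])   # one episode collected: stop
```

Reading it this way, a name like `get_sampling_termination_fn` does describe the return value more directly.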
    sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=1)
    # "video.{:06}.mp4".format(VideoWrapper.episode_id) will be saved within
    # rollout.generate_trajectories()
    rollout.generate_trajectories(policy, video_venv, sample_until)
For some reason I was expecting that the saved video would be one of the real training trajectories instead of a newly sampled one.
Closing in favor of #597
Description

Closes #523.

Problem

Video saving and uploading is not supported in scripts.train_rl, scripts.train_preference_comparisons, scripts.train_adversarial and scripts.train_bc.

Solution

- Added a record_and_save_video() function in imitation.util.video_wrapper that takes in a policy, an eval_venv, and a logger, and saves the video of the policy evaluated on an environment to a designated path.
- Modified WandbOutputFormat.write() to also upload the saved videos.

Testing

- tests/scripts/test_scripts.py
- tests/util/test_wb_logger.py
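A self-contained sketch of what a record_and_save_video()-style helper does. This is illustrative only: the real implementation lives in imitation.util.video_wrapper and relies on a video-recording env wrapper; here the env is a toy object and no frames are actually written.

```python
import os
from typing import Callable

def record_and_save_video(policy: Callable, eval_env, log_dir: str) -> str:
    """Roll out `policy` in `eval_env` and return the path the video would
    be saved under. In the real helper, `eval_env` would be a VecEnv wrapped
    in a video recorder that writes frames as a side effect of stepping."""
    video_dir = os.path.join(log_dir, "videos")
    os.makedirs(video_dir, exist_ok=True)
    obs, done = eval_env.reset(), False
    while not done:
        obs, done = eval_env.step(policy(obs))
    # The wrapper would have written the frames here; return the path only.
    return os.path.join(video_dir, "video.000000.mp4")

class _DummyEnv:
    """Minimal env stub: the episode ends after three steps."""
    def reset(self):
        self.t = 0
        return 0
    def step(self, action):
        self.t += 1
        return self.t, self.t >= 3

path = record_and_save_video(lambda obs: 0, _DummyEnv(), log_dir="out")
```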