Change callback for AdversarialTrainer #626
gunnxx wants to merge 2 commits into HumanCompatibleAI:master
    callback: Optional[List[BaseCallback]] = None
) -> None:
    """Alternates between training the generator and discriminator.
The last part of the docstring description, "and finally a call to callback(round)", is probably misleading now.
    if self.gen_callback is None:
        self.gen_callback = callback
    else:
        self.gen_callback = callback + [self.gen_callback]
Can someone abuse the API by calling train() multiple times? If so, self.gen_callback would end up holding a nested list, which is not correct. More generally, gen_callback is currently typed as Optional[BaseCallback], and we shouldn't change its type to a list at runtime.
Perhaps it would be better to add an optional callback argument to train_gen(), merge callbacks there, and avoid the stateful change here?
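A minimal sketch of that suggestion, merging callbacks per call instead of reassigning self.gen_callback. The BaseCallback class here is a stand-in for the stable-baselines3 one, and Trainer and merge_callbacks are hypothetical names used only for illustration, not the actual imitation code:

```python
from typing import List, Optional, Sequence


class BaseCallback:
    """Stand-in for stable_baselines3's BaseCallback (hypothetical, for illustration)."""

    def __init__(self, name: str) -> None:
        self.name = name


def merge_callbacks(
    gen_callback: Optional[BaseCallback],
    extra: Optional[Sequence[BaseCallback]] = None,
) -> List[BaseCallback]:
    """Return a new flat list; neither argument is modified."""
    merged: List[BaseCallback] = list(extra) if extra is not None else []
    if gen_callback is not None:
        # Same ordering as the diff above: extra callbacks first, then the trainer's own.
        merged.append(gen_callback)
    return merged


class Trainer:
    """Hypothetical trainer showing only the callback handling."""

    def __init__(self, gen_callback: Optional[BaseCallback] = None) -> None:
        # The attribute keeps its Optional[BaseCallback] type for the trainer's lifetime.
        self.gen_callback = gen_callback

    def train_gen(
        self, callback: Optional[Sequence[BaseCallback]] = None
    ) -> List[BaseCallback]:
        # Merge on every call; self.gen_callback is never reassigned.
        callbacks = merge_callbacks(self.gen_callback, callback)
        # A real implementation would forward `callbacks` to the generator's learn() call.
        return callbacks
```

Because the merge result is local to train_gen(), calling it (or train()) repeatedly cannot produce a nested list.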
Also, can the learn_kwargs argument of train_gen() be removed, as discussed in the original issue #607?
      self,
      total_timesteps: int,
-     callback: Optional[Callable[[int], None]] = None,
+     callback: Optional[List[BaseCallback]] = None
Do we want to change the semantics of the argument here, or should we rather deprecate the feature (and introduce a different parameter for additional gen_callback)?
I think the suggestion in the original issue was to add a new gen_callback argument. (Btw, stable-baselines accepts both a CallbackList and a plain list of callbacks, if we wanted to be fancy.)
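A rough sketch of the non-breaking alternative: callback keeps its existing Callable[[int], None] semantics and a separate gen_callback argument accepts either a single callback or a list. BaseCallback is a stand-in class, and as_callback_list, GenCallback, and the n_rounds parameter are hypothetical simplifications for illustration (the real train() takes total_timesteps):

```python
from typing import Callable, List, Optional, Sequence, Union


class BaseCallback:
    """Stand-in for stable_baselines3's BaseCallback (hypothetical, for illustration)."""


# Hypothetical alias: one callback, several callbacks, or nothing.
GenCallback = Union[BaseCallback, Sequence[BaseCallback], None]


def as_callback_list(gen_callback: GenCallback) -> List[BaseCallback]:
    """Normalize the accepted forms into a flat list."""
    if gen_callback is None:
        return []
    if isinstance(gen_callback, BaseCallback):
        return [gen_callback]
    return list(gen_callback)


def train(
    n_rounds: int,
    callback: Optional[Callable[[int], None]] = None,
    gen_callback: GenCallback = None,
) -> List[BaseCallback]:
    """Simplified training loop showing only how the two arguments are used."""
    gen_callbacks = as_callback_list(gen_callback)
    for r in range(n_rounds):
        # Generator and discriminator updates would run here, with
        # `gen_callbacks` forwarded to the generator's learn() call.
        if callback is not None:
            callback(r)  # unchanged per-round semantics
    return gen_callbacks
```

With this shape, existing callers that pass callback keep working, and an evaluation callback can be supplied through gen_callback.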
@@ -421,7 +422,7 @@ def train_gen(
  def train(
One more thing: if you change the arguments, training_adversarial.py will also need to be updated.
Changing the callback mechanism of AdversarialTrainer such that we can insert sb3.EvalCallback. See #607.