Conversation
…iex/dev/warm-and-frozen-teachers
…to be (at least here) identical
clessig
left a comment
Overall this looks fine; I pushed some minor changes. config_jepa.yml has a 2D RoPE param, but it's not in here, so it should be removed (it was also one of the things that caused problems for me).
    cf, target_and_aux_calc_params.get("model_param_overrides", {})
)
prepare_encoder_teacher(
    meta_ema_model, cf.training_config, cf_overridden.ae_global_dim_embed
It would be more generic to pass cf_overridden to prepare_encoder_teacher(); there might be more parameters from the config in the future that are useful beyond cf_overridden.ae_global_dim_embed.
will fix that
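The suggested change above can be sketched as follows. This is a hypothetical illustration, not the project's actual code: the `Config` dataclass here is a stand-in for the real config object, and the body of `prepare_encoder_teacher` is invented; only the names `cf_overridden` and `ae_global_dim_embed` come from the diff.

```python
# Sketch: pass the whole overridden config instead of only the embedding
# dimension, so future config-driven parameters need no signature change.
from dataclasses import dataclass, field


@dataclass
class Config:  # stand-in for the project's real config class
    ae_global_dim_embed: int = 512
    extra: dict = field(default_factory=dict)


def prepare_encoder_teacher(model, training_config, cf_overridden: Config):
    # Any field of the overridden config is reachable here,
    # not just cf_overridden.ae_global_dim_embed.
    return {"dim_embed": cf_overridden.ae_global_dim_embed, **cf_overridden.extra}


cf_overridden = Config(ae_global_dim_embed=256, extra={"rope_2d": False})
out = prepare_encoder_teacher(model=None, training_config=None, cf_overridden=cf_overridden)
```

The design point is simply that widening the parameter to the config object keeps the call site stable as new teacher-relevant fields are added.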
self.batch_size = batch_size
self.reset()

def _forward_teacher(self, model_params, batch):
I don't think it's a "private" function, since it's called from the base class. We also usually don't use the '_' convention, so I would remove it.
I can rename it
class FrozenTeacher(EncoderTeacher):
    """SSL teacher using a frozen pre-trained encoder.

    The encoder is loaded from a checkpoint and never updated. Non-encoder
The teacher_model is assumed to have the non-encoder parts discarded, no?
The code should do the discarding; the original model, as specified in the config associated with its run id, may have more than just an encoder.
self.teacher_model.eval()

@classmethod
def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher:
This function is inconsistent with what is done for EMATeacher in model_interface. Either we have from_pretrained() for both classes, or we have the functionality in model_interface.py.
But they conceptually and functionally do different things, so I don't follow
Ok, can you then maybe briefly explain what, for you, the difference is between this and load_encoder_from_checkpoint()?
logger = logging.getLogger(__name__)


def _create_head(name: str, head_type: str, dim_embed: int, loss_conf, cf=None) -> nn.Module:
If this is for teacher_heads, then the function name should say so.
model.pred_heads = nn.ModuleDict()

# Ensure latent_pre_norm exists (teacher may not have had SSL training)
if model.latent_pre_norm is None:
When/why wouldn't this exist?
I don't understand, what makes you think we can assume this layer norm exists?
Ok, I assumed it always exists because it's used in the output.
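The guard discussed in this thread can be illustrated with a minimal sketch. Everything here is a stand-in: `DummyNorm` and `TeacherModel` are invented placeholders for the real torch modules (which are not reproduced), and `ensure_latent_pre_norm` is a hypothetical helper name; only `latent_pre_norm` itself comes from the diff.

```python
# Sketch of the guard: a teacher checkpoint from a run without SSL training
# may have latent_pre_norm set to None, so a fresh norm layer is created.
class DummyNorm:  # stand-in for nn.LayerNorm
    def __init__(self, dim):
        self.dim = dim


class TeacherModel:  # stand-in for the real teacher model
    def __init__(self, dim_embed, latent_pre_norm=None):
        self.dim_embed = dim_embed
        self.latent_pre_norm = latent_pre_norm


def ensure_latent_pre_norm(model: TeacherModel) -> TeacherModel:
    # Create the layer only if the checkpoint did not provide one,
    # so student and teacher latents are normalized consistently.
    if model.latent_pre_norm is None:
        model.latent_pre_norm = DummyNorm(model.dim_embed)
    return model


m = ensure_latent_pre_norm(TeacherModel(dim_embed=128))
```

This mirrors the review question: if the layer norm is always present in models used as teachers, the guard is dead code; if teachers can come from runs without SSL training, it is load-bearing.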
3. Creates fresh latent_heads based on the student's SSL loss config
"""
# Strip non-encoder components
model.forecast_engine = None
Can we formulate it as "is not encoder", so that we are robust to changes in the model design? E.g., we discussed having a decoder-type model for the stream-specific prediction heads, and we will most likely forget this hidden dependency here. Otherwise, we might have a function on the model that reduces it to the encoder, which is called here.
Something similar to:
encoder_params = {
    k: v for k, v in params.items() if k.startswith(("encoder.", "latent_pre_norm"))
}
okay, will change this
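The reviewer's snippet above can be run end to end on a toy state dict. The parameter names below are invented examples (only the `encoder.`, `latent_pre_norm`, and `forecast_engine` prefixes are taken from the discussion); the values are dummies standing in for tensors.

```python
# Toy state dict: two encoder params, the pre-norm, and one non-encoder
# param that the prefix filter should drop.
params = {
    "encoder.layer0.weight": [1.0],
    "encoder.layer0.bias": [0.0],
    "latent_pre_norm.weight": [1.0],
    "forecast_engine.head.weight": [2.0],  # non-encoder: should be dropped
}

# Keep only parameters belonging to the encoder (plus latent_pre_norm),
# so the teacher stays robust to new non-encoder components being added.
encoder_params = {
    k: v for k, v in params.items() if k.startswith(("encoder.", "latent_pre_norm"))
}
```

Filtering by what to *keep* (encoder prefixes) rather than what to *discard* (forecast_engine, pred_heads, ...) is exactly the robustness point made above: new non-encoder components are excluded by default instead of silently leaking in.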
logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.")


def load_encoder_from_checkpoint(
Why do we need this as well as the first part of prepare_encoder_teacher()? It seems to be the same functionality.
@@ -0,0 +1,16 @@
training_config:
How is this config used? Maybe we can give an example at the top of what pretraining it can be used for. The copyright header is also missing.
It is for testing purposes; I will remove it at the end.
@@ -0,0 +1,7 @@
training_config:
How is this config used? Maybe we can give an example at the top of what pretraining it can be used for. The copyright header is also missing.
Description
Allow for warm starts with EMA and Frozen Teachers.
Issue Number
Closes #1881
Is this PR a draft? Mark it as draft.
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60