[train] Multi-modal inputs support in FSDP2 #1331
Conversation
```python
model_config = AutoConfig.from_pretrained(pretrain_or_model, trust_remote_code=True, **model_config_kwargs)
```
```python
# Fall back to AutoModelForVision2Seq for VLMs (e.g. Qwen3-VL)
```
This seems very ad hoc (e.g., why do we always fall back to `AutoModelForVision2Seq` when the model config has no `AutoModelForCausalLM` mapping? There might, for example, be other classes we need to fall back to). Is there a better approach to make this work?
Also, the fact that `use_liger_kernel` and `type(model_config) not in AutoModelForCausalLM._model_mapping` together determine whether `self.is_vlm` is true or false seems off.
I removed the `use_liger_kernel` and `_model_mapping` check and instead use `"vision_config" in model_config`, which should be a better indicator of a VLM. `AutoModelForImageTextToText` works for VLMs, but in transformers 5.0.0+ all multi-modal models can use `AutoModelForMultimodalLM`.

Alternatively, we could fall back to the base `AutoModel`, but that feels too broad.
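As a rough sketch of the detection the reply describes (the helper name and the config argument are illustrative, not the PR's actual code): treat the presence of a nested `vision_config` as the VLM signal.

```python
# Illustrative helper (not the PR's actual code): a transformers config for a
# VLM typically carries a nested `vision_config`, while a text-only config
# (e.g. a plain Llama config) does not.
def looks_like_vlm(model_config) -> bool:
    """Heuristic VLM check based on the presence of a vision sub-config."""
    return getattr(model_config, "vision_config", None) is not None
```

With a real checkpoint this would be called as `looks_like_vlm(AutoConfig.from_pretrained(...))`; `getattr` is just a defensive spelling of the `"vision_config" in model_config` check mentioned above.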
```diff
     return False
 for k, v in self.items():
-    if k not in other or not torch.equal(v, other[k]):
+    if k not in other:
```
We already tested this above with `if set(self.keys()) != set(other.keys()): return False`.
```diff
-if k not in other or not torch.equal(v, other[k]):
+if k not in other:
     return False
+other_v = other[k]
```
You should check this, but it seems to me this can be simplified:

```python
if isinstance(v, torch.Tensor) and isinstance(other[k], torch.Tensor) and not torch.equal(v, other[k]):
    return False
if v != other[k]:
    return False
```

I agree. I think the one other edge case is: if `torch.equal(v, other[k])` is True, the `v != other[k]` check still runs and can error (comparing tensors with `!=` yields an elementwise tensor, which is ambiguous in a boolean context). To fix this, I made the `torch.equal(v, other[k])` check a nested if statement.
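A minimal sketch of that nested-if fix (a hypothetical helper, not the PR's actual `__eq__`): tensor pairs are fully decided by `torch.equal`, so the ambiguous `v != other_v` comparison never runs on tensors.

```python
import torch

def values_equal(v, other_v) -> bool:
    # Hypothetical helper mirroring the nested-if fix described above.
    if isinstance(v, torch.Tensor) and isinstance(other_v, torch.Tensor):
        # Decide tensor pairs here and never fall through to `!=`,
        # which would return an elementwise tensor, not a bool.
        return torch.equal(v, other_v)
    return v == other_v
```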
```python
assert (action_log_probs <= 0).all()
```

```python
def test_vlm_vision_affects_output(vlm_model, processor):
```
I feel like `test_vlm_forward_with_vision_data`, `test_vlm_forward_text_only`, and `test_vlm_vision_affects_output` can be consolidated into a single test; there is really no reason to have three separate ones that basically do the same thing.
That makes sense. All 3 (plus `test_vlm_different_images_diverge`) are actually redundant, since `test_vlm_semantic_color_recognition` exercises the forward-pass logic and checks that the vision inputs affect log probs. I updated the tests to only:

- `test_vlm_log_probs_match_manual`: manual log-prob calculation with the HF model
- `test_vlm_semantic_color_recognition`: log-prob calculation of different colors
- `test_vlm_forward_batched_vision`: batched forward-pass logic
```python
if len(tensors) == 0:
    raise ValueError("Cannot create a TensorList with no tensors.")
self.tensors = tensors
```
It would be good to also assert that the tensors are all on the same device. It looks like you assume so, based on the `device` property implementation.
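A stripped-down sketch of a `TensorList.__init__` with the suggested same-device assertion (illustrative only, not the PR's actual class):

```python
import torch

class TensorList:
    """Minimal illustration of the same-device check suggested above."""

    def __init__(self, tensors: list[torch.Tensor]):
        if len(tensors) == 0:
            raise ValueError("Cannot create a TensorList with no tensors.")
        # Suggested addition: fail fast on mixed devices instead of
        # silently reporting the first tensor's device.
        devices = {t.device for t in tensors}
        if len(devices) > 1:
            raise ValueError(f"TensorList tensors span multiple devices: {devices}")
        self.tensors = tensors

    @property
    def device(self) -> torch.device:
        # Safe: __init__ guarantees all tensors share one device.
        return self.tensors[0].device
```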
This PR adds support for variable-shape vision inputs (e.g. `pixel_values`, `image_grid_thw`) to the FSDP2 training backend, item 2 in #1200.

Summary

- `TensorList`: new container in `training_batch.py` for lists of tensors with variable shapes per batch element (e.g., each image may produce a different number of patches). Supports the same batch operations as regular tensors: slicing, chunking, concatenation, repeat, device transfer, and pickle serialization.
- `TensorBatch` extended: `TensorBatch` now accepts `TensorList` fields alongside regular tensors. All operations (chunk, slice, cat, repeat, pickle, equality, `__setitem__`) handle both types transparently.
- `TrainingInput` schema and `Experience` dataclass: added optional `pixel_values` and `image_grid_thw` fields.
- `BatchIterator.batch_to_experience`: propagates vision fields from `TrainingInputBatch` to `Experience`.
- `HFModelWrapper.forward`: accepts `pixel_values` and `image_grid_thw` kwargs, converts `TensorList` to concatenated tensors for the HF model, and routes through a VLM-specific forward path (no position IDs; passes `pixel_values`/`image_grid_thw` directly).
- `HFModelWrapper.__init__`: auto-detects VLMs via the `AutoModelForCausalLM` config mapping; falls back to `AutoModelForVision2Seq` and sets `is_vlm=True`.
- `PolicyWorkerBase`: both the training step and the initial policy log-prob computation now pass vision fields through to the model.

Test plan
CPU unit tests:

- `test_train_batch.py` — `TensorList` unit tests + `TensorBatch` integration tests + bfloat16 pickle roundtrip.
- `test_vlm_data_plumbing.py` — verifies `pixel_values`/`image_grid_thw` flow through `batch_to_experience` and `Experience.to_device`.

GPU integration tests:

- New `test_vlm_model_wrapper.py` with the following tests: (… `is_vlm` flag).