
[train] Multi-modal inputs support in FSDP2#1331

Merged
SumanthRH merged 21 commits into NovaSky-AI:main from nithinvc:nithinc/mm-training-backend
Mar 27, 2026

Conversation

@nithinvc
Contributor

@nithinvc nithinvc commented Mar 17, 2026

This PR adds support for variable-shape vision inputs (e.g. pixel_values, image_grid_thw) to the FSDP2 training backend, item 2 in #1200 .

Summary

  • TensorList: New container in training_batch.py for lists of tensors with variable shapes per batch element (e.g., each image may produce a different number of patches). Supports the same batch operations as regular tensors: slicing, chunking, concatenation, repeat, device transfer, and pickle serialization.
  • TensorBatch extended: TensorBatch now accepts TensorList fields alongside regular tensors. All operations (chunk, slice, cat, repeat, pickle, equality, __setitem__) handle both types transparently.
  • TrainingInput schema and Experience dataclass: Added optional pixel_values and image_grid_thw fields.
  • BatchIterator.batch_to_experience: Propagates vision fields from TrainingInputBatch to Experience.
  • HFModelWrapper.forward: Accepts pixel_values and image_grid_thw kwargs, converts TensorList to concatenated tensors for the HF model, and routes through a VLM-specific forward path (no position IDs, passes pixel_values/image_grid_thw directly).
  • HFModelWrapper.__init__: Auto-detects VLMs via AutoModelForCausalLM config mapping; falls back to AutoModelForVision2Seq and sets is_vlm=True.
  • PolicyWorkerBase: Both the training step and the initial policy log-prob computation now pass vision fields through to the model.
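
For concreteness, the TensorList container described above might look roughly like this. This is a minimal sketch, not the actual implementation: only the class name and the batch operations listed above come from this PR, and helper names like concat_tensor are assumptions.

```python
from typing import List

import torch


class TensorList:
    """Sketch of a container for per-element tensors with varying shapes
    (e.g. a different number of image patches per batch element)."""

    def __init__(self, tensors: List[torch.Tensor]):
        if len(tensors) == 0:
            raise ValueError("Cannot create a TensorList with no tensors.")
        self.tensors = list(tensors)

    def __len__(self) -> int:
        return len(self.tensors)

    def __getitem__(self, idx):
        # Slicing returns another TensorList; integer indexing returns a tensor.
        if isinstance(idx, slice):
            return TensorList(self.tensors[idx])
        return self.tensors[idx]

    def chunk(self, n: int) -> List["TensorList"]:
        # Split the batch elements into n roughly equal groups.
        size = (len(self.tensors) + n - 1) // n
        return [
            TensorList(self.tensors[i : i + size])
            for i in range(0, len(self.tensors), size)
        ]

    @staticmethod
    def cat(lists: List["TensorList"]) -> "TensorList":
        # Concatenate at the batch level (list of lists -> one list).
        return TensorList([t for tl in lists for t in tl.tensors])

    def to(self, device) -> "TensorList":
        return TensorList([t.to(device) for t in self.tensors])

    def concat_tensor(self) -> torch.Tensor:
        # Flatten to a single tensor along dim 0 (e.g. for pixel_values).
        return torch.cat(self.tensors, dim=0)
```

Pickle serialization works out of the box here because the class holds only a plain list of tensors.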

Test plan

CPU unit tests:

  • test_train_batch.py — TensorList unit tests, TensorBatch integration tests, and a bfloat16 pickle roundtrip.
  • test_vlm_data_plumbing.py — Verifies pixel_values/image_grid_thw flow through batch_to_experience and Experience.to_device.

GPU integration tests:
New test_vlm_model_wrapper.py with the following tests:

  • Tests VLM model loading (is_vlm flag).
  • Tests the forward pass with vision data and the text-only fallback.
  • Tests that vision inputs affect output log probabilities.
  • Verifies returned log probs against manually computed log probs.
  • Tests that different images produce different log probs.
  • Tests semantic color recognition: given a solid-color image, asserts the correct color has the highest log probability.
  • Tests the batched forward pass and verifies that a batch permutation yields the same log probs.
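
The VLM-specific forward routing described in the summary (no explicit position IDs on the vision path, TensorList fields concatenated before the HF call) could be sketched like this. All names here are illustrative assumptions, not the PR's actual code:

```python
import torch


class HFModelWrapperSketch:
    """Hypothetical sketch: TensorList vision fields are concatenated into
    single tensors before the underlying HF model call."""

    def __init__(self, model, is_vlm: bool):
        self.model = model
        self.is_vlm = is_vlm

    @staticmethod
    def _flatten(maybe_tensor_list):
        # A TensorList carries one tensor per batch element; HF VLMs expect
        # them concatenated along dim 0 (one row per image patch).
        if hasattr(maybe_tensor_list, "tensors"):
            return torch.cat(maybe_tensor_list.tensors, dim=0)
        return maybe_tensor_list

    def forward(self, input_ids, attention_mask, pixel_values=None, image_grid_thw=None):
        if self.is_vlm and pixel_values is not None:
            # VLM path: pass vision tensors through directly, no position IDs.
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=self._flatten(pixel_values),
                image_grid_thw=self._flatten(image_grid_thw),
            )
        # Text-only inputs fall back to the regular causal-LM call.
        return self.model(input_ids=input_ids, attention_mask=attention_mask)
```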


@nithinvc nithinvc marked this pull request as ready for review March 17, 2026 02:53

@SumanthRH SumanthRH self-assigned this Mar 17, 2026

model_config = AutoConfig.from_pretrained(pretrain_or_model, trust_remote_code=True, **model_config_kwargs)

# Fall back to AutoModelForVision2Seq for VLMs (e.g. Qwen3-VL)
Collaborator

This seems very ad hoc (e.g. why are we always falling back to AutoModelForVision2Seq if the model config has no AutoModelForCausalLM mapping -- there might for example be other classes we need to fall back to), is there a better approach to make this work?

Collaborator

Also the fact that use_liger_kernel and type(model_config) not in AutoModelForCausalLM._model_mapping determine if self.is_vlm is true or false seems off

Contributor Author

I removed the use_liger_kernel and _model_mapping check and instead use "vision_config" in model_config which should be a better indicator of a VLM. AutoModelForImageTextToText works for VLMs but in transformers 5.0.0+ all multi-modal models can use AutoModelForMultimodalLM.

Alternatively, we can fall back to the base AutoModel but that feels too broad.
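
The vision_config heuristic mentioned here can be isolated into a small predicate. This is a sketch of the idea only (the real check lives in HFModelWrapper.__init__), and with it the loader can branch between the two Auto classes:

```python
from types import SimpleNamespace


def is_vlm_config(config) -> bool:
    # Multi-modal model configs (e.g. Qwen-VL family) typically nest a
    # vision_config; text-only causal-LM configs do not.
    return getattr(config, "vision_config", None) is not None


# Sketch of the resulting branch (not run here, requires transformers):
#   cls = AutoModelForImageTextToText if is_vlm_config(cfg) else AutoModelForCausalLM
#   model = cls.from_pretrained(pretrain_or_model, trust_remote_code=True)
```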

return False
for k, v in self.items():
if k not in other or not torch.equal(v, other[k]):
if k not in other:
Collaborator

We already tested this above with if set(self.keys()) != set(other.keys()): return False

Contributor Author

Removed

if k not in other or not torch.equal(v, other[k]):
if k not in other:
return False
other_v = other[k]
Collaborator

@pcmoritz pcmoritz Mar 19, 2026

You should check this, but it seems to me this can be simplified:

if isinstance(v, torch.Tensor) and isinstance(other[k], torch.Tensor) and not torch.equal(v, other[k]):
    return False
if v != other[k]:
    return False

Contributor Author

I agree. One other edge case: if torch.equal(v, other[k]) is True, execution falls through to v != other[k], which can error on tensors. To fix this, I made torch.equal(v, other[k]) a nested if statement.
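
The resulting nested-if comparison might look like the following sketch. The helper name values_equal is hypothetical; in the PR the logic lives inside the batch's __eq__:

```python
import torch


def values_equal(v, other_v) -> bool:
    """Compare two batch values that may be tensors or plain objects.

    Nesting the torch.equal check prevents falling through to
    `v != other_v`, which for tensor operands returns a tensor (and
    raises when used in a boolean context with multiple elements).
    """
    if isinstance(v, torch.Tensor) and isinstance(other_v, torch.Tensor):
        if not torch.equal(v, other_v):
            return False
    elif v != other_v:
        return False
    return True
```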

assert (action_log_probs <= 0).all()


def test_vlm_vision_affects_output(vlm_model, processor):
Collaborator

I feel like test_vlm_forward_with_vision_data, test_vlm_forward_text_only and test_vlm_vision_affects_output can be consolidated into a single test, there is really no reason to have three separate ones that basically do the same thing.

Contributor Author

That makes sense. All 3 (+ test_vlm_different_images_diverge) are actually redundant since test_vlm_semantic_color_recognition will test the forward pass logic and checks the vision inputs affects log probs. I updated the tests to only:

  1. test_vlm_log_probs_match_manual - manual log prob calculation with the HF model
  2. test_vlm_semantic_color_recognition - log prob calculation of different colors
  3. test_vlm_forward_batched_vision - batched forward pass logic

Comment on lines +58 to +60
if len(tensors) == 0:
raise ValueError("Cannot create a TensorList with no tensors.")
self.tensors = tensors
Member

It would be good to also assert that the tensors are all on the same device.

It looks like you assume so based on device property implementation

Contributor Author

Done
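
A sketch of the same-device check added to the TensorList constructor (the standalone helper name is hypothetical; in the PR this validation sits inside __init__):

```python
from typing import List

import torch


def check_same_device(tensors: List[torch.Tensor]) -> torch.device:
    """Validate that every tensor lives on one device, so that a single
    .device property on the container is well-defined."""
    if len(tensors) == 0:
        raise ValueError("Cannot create a TensorList with no tensors.")
    device = tensors[0].device
    if any(t.device != device for t in tensors):
        raise ValueError("All tensors in a TensorList must be on the same device.")
    return device
```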

Comment thread: skyrl/backends/skyrl_train/training_batch.py
@SumanthRH SumanthRH merged commit 92764a7 into NovaSky-AI:main Mar 27, 2026
4 of 7 checks passed
@nithinvc nithinvc deleted the nithinc/mm-training-backend branch March 30, 2026 18:06