Fix SequenceDescriptor API for TransformerEngine >= 2.12#3589
Draft
yeandy wants to merge 3 commits into AI-Hypercomputer:main from
TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523) made `is_thd` and `is_segment_ids_reordered` required keyword arguments on `SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect segment position calculation for THD layouts. Since the packing branch in `cudnn_flash_attention` uses `qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs, the correct values are `is_thd=True, is_segment_ids_reordered=False`. Without this fix, any configuration using `attention="cudnn_flash_te"` with `packing=True` and real data (`dataset_type != "synthetic"`) fails with: `TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered'`
The existing `test_gpu_packed_attention` uses `dataset_type=synthetic`, which bypasses `SequenceDescriptor.from_segment_ids_and_pos()` entirely (takes the `elif` branch in `cudnn_flash_attention`). This meant the TE v2.12 API breakage went undetected. Add `test_gpu_packed_attention_hf`, which uses HF parquet data with `packing=True` + `attention=cudnn_flash_te`, exercising the actual `SequenceDescriptor` codepath. This serves as a regression test for NVIDIA/TransformerEngine#2523.
Description
This PR fixes the `SequenceDescriptor.from_segment_ids_and_pos()` calls in `cudnn_flash_attention` to pass the two new required keyword arguments (`is_thd`, `is_segment_ids_reordered`) added in TransformerEngine v2.12, and adds `test_gpu_packed_attention_hf`, which exercises the actual `SequenceDescriptor` codepath with real HF data.

TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523, merged Dec 31, 2025) made `is_thd` and `is_segment_ids_reordered` required keyword-only arguments on `SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect segment position calculation for THD layouts. This is a breaking change documented in the TE v2.12 release notes.

Since MaxText's cuda12-requirements.txt pins `transformer-engine>=2.9.0`, any fresh install picks up TE >= 2.12, which breaks any configuration using `attention="cudnn_flash_te"` + `packing=True` + real data (`dataset_type != "synthetic"`) with:

```
TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered'
```

The ROCm fork of TransformerEngine picked up the same change via its IFU 2.12 merge (ccda1a5, Apr 3, 2026).
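The failure mode can be reproduced without TransformerEngine installed. Below is a minimal stand-in (not the real TE class) whose factory method mirrors the v2.12 signature change: the two new arguments are keyword-only with no defaults, so any pre-2.12-style call raises the same `TypeError`.

```python
# Stand-in (NOT the real TransformerEngine class) illustrating the v2.12
# breaking change: the new arguments are keyword-only and have no defaults.
class SequenceDescriptor:
    @classmethod
    def from_segment_ids_and_pos(cls, segment_ids, segment_pos, *,
                                 is_thd, is_segment_ids_reordered):
        return {"is_thd": is_thd,
                "is_segment_ids_reordered": is_segment_ids_reordered}

# A pre-2.12-style call omits the kwargs and now fails:
try:
    SequenceDescriptor.from_segment_ids_and_pos([1, 1, 2], [0, 1, 0])
    raised = False
except TypeError as e:
    raised = True
    message = str(e)

print(raised)  # True: same failure mode as the reported TypeError
```

The raised message names both missing keyword-only arguments, matching the error quoted above.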
Fix
The packing branch in `cudnn_flash_attention` uses `qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs, so the correct values are:

- `is_thd=True`: THD packed layout
- `is_segment_ids_reordered=False`: no load-balancing reordering applied

Why this wasn't caught by existing tests
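As a sketch of the patched call (against a local stub rather than the real TE import, and with illustrative segment-ID data; only the two keyword arguments are the actual fix):

```python
# Local stub mirroring the TE v2.12 signature; NOT the real class.
class SequenceDescriptor:
    @classmethod
    def from_segment_ids_and_pos(cls, segment_ids, segment_pos, *,
                                 is_thd, is_segment_ids_reordered):
        return {"is_thd": is_thd,
                "is_segment_ids_reordered": is_segment_ids_reordered}

segment_ids = [1, 1, 2, 2, 0]  # two packed sequences plus padding (example data)
segment_pos = [0, 1, 0, 1, 0]  # standard, non-reordered positions

desc = SequenceDescriptor.from_segment_ids_and_pos(
    segment_ids, segment_pos,
    is_thd=True,                     # THD packed layout
    is_segment_ids_reordered=False,  # no load-balancing reordering applied
)
print(desc["is_thd"], desc["is_segment_ids_reordered"])  # True False
```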
The existing `test_gpu_packed_attention` uses `dataset_type=synthetic`, which takes the `elif` branch in `cudnn_flash_attention` and sets `attn_mask = None`, completely bypassing `SequenceDescriptor.from_segment_ids_and_pos()`. The new `test_gpu_packed_attention_hf` uses HF parquet data to exercise the actual codepath.

Tests
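The coverage gap can be summarized with a toy model of the branch selection (function name and return values are illustrative, not MaxText's actual code):

```python
def packed_attention_path(dataset_type: str, packing: bool) -> str:
    """Toy model of the branching in cudnn_flash_attention."""
    if packing and dataset_type != "synthetic":
        return "sequence_descriptor"  # builds SequenceDescriptor from segment IDs
    elif packing:
        return "attn_mask_none"       # synthetic elif branch: attn_mask = None
    return "dense"

# Old test: synthetic data never reaches the SequenceDescriptor API.
print(packed_attention_path("synthetic", True))  # attn_mask_none
# New test: real HF data exercises the codepath that broke under TE >= 2.12.
print(packed_attention_path("hf", True))         # sequence_descriptor
```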
[ ] Verify the fix with a training run using `attention=cudnn_flash_te` + `packing=True` + non-synthetic data on a GPU with TE >= 2.12
[ ] `test_gpu_packed_attention` (existing, synthetic) still passes
[ ] `test_gpu_packed_attention_hf` (new, HF data) passes on sm90+ hardware
Convergence Validation
Verified numerical correctness with a 3-way convergence comparison on 8× H200 (sm90), 5000 training steps, streaming `allenai/c4` from HuggingFace:

| Run | Container | Attention | Notes |
| --- | --- | --- | --- |
| 1 | `nvcr.io/nvidia/jax:26.03-maxtext-py3` | `dot_product` | |
| 2 | `nvcr.io/nvidia/jax:26.03-maxtext-py3` | `cudnn_flash_te` | `is_thd=True, is_segment_ids_reordered=False` (this PR) |
| 3 | `nvcr.io/nvidia/jax:26.01-maxtext-py3` | `cudnn_flash_te` | TE before the `is_thd` kwarg |

Model config: `base_emb_dim=1024`, `base_mlp_dim=4096`, `base_num_decoder_layers=8`, `head_dim=128`, `per_device_batch_size=1`, `packing=True`, `max_target_length=512`, `tokenizer=google-t5/t5-large`

Loss comparison at milestones:
Results:
Conclusion: The PR fix (`is_thd=True, is_segment_ids_reordered=False`) on TE 2.13 is numerically equivalent to the old TE 2.10 behavior. Correctness is preserved.

Training logs, Runs 1 & 2 (TE 2.13, `nvcr.io/nvidia/jax:26.03-maxtext-py3`)

Training log, Run 3 (TE 2.10, `nvcr.io/nvidia/jax:26.01-maxtext-py3`)

Checklist
Before submitting this PR, please make sure (put X in square brackets):
[ ] Added the `gemini-review` label.