
Fix SequenceDescriptor API for TransformerEngine >= 2.12 #3589

Draft

yeandy wants to merge 3 commits into AI-Hypercomputer:main from ROCm:yeandy/te-sequence-descriptor-api

Conversation


@yeandy yeandy commented Apr 7, 2026

Description

  • Fix SequenceDescriptor.from_segment_ids_and_pos() calls in cudnn_flash_attention to pass the two new required keyword arguments (is_thd, is_segment_ids_reordered) added in TransformerEngine v2.12
  • Add regression test test_gpu_packed_attention_hf that exercises the actual SequenceDescriptor codepath with real HF data

TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523, merged Dec 31, 2025) made is_thd and is_segment_ids_reordered required keyword-only arguments on SequenceDescriptor.from_segment_ids_and_pos() to fix incorrect segment position calculation for THD layouts. This is a breaking change documented in the TE v2.12 release notes.
Since MaxText's cuda12-requirements.txt pins transformer-engine>=2.9.0, any fresh install picks up TE >= 2.12, which breaks any configuration using attention="cudnn_flash_te" + packing=True + real data (dataset_type != "synthetic") with:

TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required
keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered'

The ROCm fork of TransformerEngine picked up the same change via its IFU 2.12 merge (ccda1a5, Apr 3, 2026).

Fix

The packing branch in cudnn_flash_attention uses qkv_layout="THD_THD_THD" with standard (non-reordered) segment IDs, so the correct values are:

  • is_thd=True — THD packed layout
  • is_segment_ids_reordered=False — no load-balancing reordering applied
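The before/after can be sketched without a TE install. The function below is a stand-in that mimics the TE >= 2.12 keyword-only signature of `SequenceDescriptor.from_segment_ids_and_pos()` (it is not TE's implementation, and the segment-ID values are illustrative):

```python
# Stand-in mimicking the TE >= 2.12 signature: both new flags are keyword-only.
def from_segment_ids_and_pos(segment_ids, segment_pos, *, is_thd, is_segment_ids_reordered):
    return {"is_thd": is_thd, "is_segment_ids_reordered": is_segment_ids_reordered}

# Old-style call (valid through TE 2.11) now raises the TypeError quoted above:
try:
    from_segment_ids_and_pos([0, 0, 1, 1], [0, 1, 0, 1])
except TypeError as e:
    print(type(e).__name__)  # TypeError

# Fixed call matching MaxText's packing branch (qkv_layout="THD_THD_THD",
# standard non-reordered segment IDs):
desc = from_segment_ids_and_pos([0, 0, 1, 1], [0, 1, 0, 1],
                                is_thd=True, is_segment_ids_reordered=False)
print(desc)  # {'is_thd': True, 'is_segment_ids_reordered': False}
```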

Why this wasn't caught by existing tests

The existing test_gpu_packed_attention uses dataset_type=synthetic, which takes the elif branch in cudnn_flash_attention and sets attn_mask = None — completely bypassing SequenceDescriptor.from_segment_ids_and_pos(). The new test_gpu_packed_attention_hf uses HF parquet data to exercise the actual codepath.
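The coverage gap reduces to the following branch structure (an illustrative sketch, not MaxText's actual code; the function name `attn_mask_for` is made up here):

```python
# Illustrative reduction of the branching in cudnn_flash_attention:
# only the real-data packing path ever reaches from_segment_ids_and_pos().
def attn_mask_for(dataset_type, segment_ids, segment_pos):
    if dataset_type != "synthetic" and segment_ids is not None:
        # Real packed data: this is the path the TE v2.12 change broke.
        return ("SequenceDescriptor", segment_ids, segment_pos)
    elif dataset_type == "synthetic":
        # Synthetic data: no descriptor is built, so the old test kept passing.
        return None

print(attn_mask_for("synthetic", [0, 0, 1], [0, 1, 0]))  # None
print(attn_mask_for("hf", [0, 0, 1], [0, 1, 0])[0])      # SequenceDescriptor
```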

Tests

[ ] Verify the fix with a training run using attention=cudnn_flash_te + packing=True + non-synthetic data on a GPU with TE >= 2.12
[ ] test_gpu_packed_attention (existing, synthetic) still passes
[ ] test_gpu_packed_attention_hf (new, HF data) passes on sm90+ hardware

Convergence Validation

Verified numerical correctness with a 3-way convergence comparison on 8× H200 (sm90), 5000 training steps, streaming allenai/c4 from HuggingFace:

| Run | Container | TE Version | Attention | API |
| --- | --- | --- | --- | --- |
| 1 (reference) | nvcr.io/nvidia/jax:26.03-maxtext-py3 | 2.13.0 | dot_product | N/A |
| 2 (PR fix) | nvcr.io/nvidia/jax:26.03-maxtext-py3 | 2.13.0 | cudnn_flash_te | is_thd=True, is_segment_ids_reordered=False |
| 3 (old TE) | nvcr.io/nvidia/jax:26.01-maxtext-py3 | 2.10.0 | cudnn_flash_te | Old API (no is_thd kwarg) |

Model config: base_emb_dim=1024, base_mlp_dim=4096, base_num_decoder_layers=8, head_dim=128, per_device_batch_size=1, packing=True, max_target_length=512, tokenizer=google-t5/t5-large
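A run with this shape can be launched roughly as follows (illustrative only: the entry point, run name, and dataset/tokenizer flag names are placeholders and may not match MaxText's current CLI exactly; the model-size keys are the ones listed above):

```shell
# Illustrative MaxText invocation (run name and data flags are placeholders):
python3 MaxText/train.py MaxText/configs/base.yml \
  run_name=te_seqdesc_convergence \
  attention=cudnn_flash_te packing=True \
  dataset_type=hf hf_path=allenai/c4 \
  base_emb_dim=1024 base_mlp_dim=4096 base_num_decoder_layers=8 head_dim=128 \
  per_device_batch_size=1 max_target_length=512 \
  tokenizer_path=google-t5/t5-large steps=5000
```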

Loss comparison at milestones

| Step | dot_product (ref) | te+fix (TE 2.13) | te+old (TE 2.10) | fix vs ref | old vs ref |
| --- | --- | --- | --- | --- | --- |
| 0 | 10.9452 | 10.9453 | 10.9453 | 0.00% | 0.00% |
| 100 | 6.7611 | 6.7614 | 6.7615 | 0.01% | 0.01% |
| 200 | 6.4777 | 6.4832 | 6.4967 | 0.08% | 0.29% |
| 500 | 5.8637 | 5.8760 | 5.8821 | 0.21% | 0.32% |
| 1000 | 5.6764 | 5.6769 | 5.6863 | 0.01% | 0.17% |
| 1500 | 5.9349 | 5.9302 | 5.9131 | 0.08% | 0.37% |
| 2000 | 5.5790 | 5.5958 | 5.5812 | 0.30% | 0.04% |
| 2500 | 5.5399 | 5.5718 | 5.5403 | 0.58% | 0.01% |
| 3000 | 5.5533 | 5.5516 | 5.5358 | 0.03% | 0.31% |
| 3500 | 5.5888 | 5.6207 | 5.5903 | 0.57% | 0.03% |
| 4000 | 5.2008 | 5.2163 | 5.2319 | 0.30% | 0.60% |
| 4500 | 5.4246 | 5.4533 | 5.4582 | 0.53% | 0.62% |
| 4999 | 5.3046 | 5.3288 | 5.3191 | 0.46% | 0.27% |

Results:

  • Final loss: 5.305 (ref) / 5.329 (fix, TE 2.13) / 5.319 (old, TE 2.10)
  • Avg per-step divergence vs reference: 0.36% (fix) / 0.26% (old TE)
  • All three runs converge to the same loss range (~5.3). Per-step differences stay < 0.7%, consistent with expected numerical noise from different attention kernel implementations.
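The divergence columns are plain relative differences; for example, the "fix vs ref" values can be recomputed from a few milestone rows of the table above (losses copied verbatim):

```python
# Recompute "fix vs ref" = |fix - ref| / ref * 100 for steps 200, 500, 1000, 2000, 4999.
ref = [6.4777, 5.8637, 5.6764, 5.5790, 5.3046]
fix = [6.4832, 5.8760, 5.6769, 5.5958, 5.3288]
divergence = [round(abs(f - r) / r * 100, 2) for r, f in zip(ref, fix)]
print(divergence)  # [0.08, 0.21, 0.01, 0.3, 0.46]
```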

Conclusion: The PR fix (is_thd=True, is_segment_ids_reordered=False) on TE 2.13 is numerically equivalent to the old TE 2.10 behavior. Correctness is preserved.

Training logs — Runs 1 & 2 (TE 2.13, nvcr.io/nvidia/jax:26.03-maxtext-py3)
======================================================================
GPUs:               8 x NVIDIA H200
Compute cap:        9.0
JAX:                0.9.0.dev20260205+6e29effa6
TransformerEngine:  2.13.0+287770466
======================================================================

Starting: conv_dot_product  attention=dot_product  steps=5000
Done: conv_dot_product  218.5s

Starting: conv_cudnn_te_fix  attention=cudnn_flash_te  steps=5000
Done: conv_cudnn_te_fix  417.1s
Training log — Run 3 (TE 2.10, nvcr.io/nvidia/jax:26.01-maxtext-py3)
======================================================================
GPUs:               8 x NVIDIA H200
Compute cap:        9.0
JAX:                0.8.1.dev20251212+6ab1fef24
TransformerEngine:  2.10.0+769ed7783
======================================================================

Starting: conv_cudnn_te_old  attention=cudnn_flash_te  steps=5000
Done: conv_cudnn_te_old  227.1s

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

yeandy added 3 commits April 7, 2026 09:59
TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523) made `is_thd`
and `is_segment_ids_reordered` required keyword arguments on
`SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect
segment position calculation for THD layouts.

Since the packing branch in `cudnn_flash_attention` uses
`qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs,
the correct values are `is_thd=True, is_segment_ids_reordered=False`.

Without this fix, any configuration using `attention="cudnn_flash_te"`
with `packing=True` and real data (`dataset_type != "synthetic"`) fails
with:
  TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2
  required keyword-only arguments: 'is_thd' and
  'is_segment_ids_reordered'
The existing test_gpu_packed_attention uses dataset_type=synthetic,
which bypasses SequenceDescriptor.from_segment_ids_and_pos() entirely
(takes the elif branch in cudnn_flash_attention). This meant the
TE v2.12 API breakage went undetected.

Add test_gpu_packed_attention_hf that uses HF parquet data with
packing=True + attention=cudnn_flash_te, exercising the actual
SequenceDescriptor codepath. This serves as a regression test for
NVIDIA/TransformerEngine#2523.
