From cde6b3aadb90521fcd4cda39cd856ad92bd5e9af Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Tue, 7 Apr 2026 09:59:48 -0400 Subject: [PATCH 1/3] Fix SequenceDescriptor.from_segment_ids_and_pos() for TE >= 2.12 TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523) made `is_thd` and `is_segment_ids_reordered` required keyword arguments on `SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect segment position calculation for THD layouts. Since the packing branch in `cudnn_flash_attention` uses `qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs, the correct values are `is_thd=True, is_segment_ids_reordered=False`. Without this fix, any configuration using `attention="cudnn_flash_te"` with `packing=True` and real data (`dataset_type != "synthetic"`) fails with: TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered' --- src/maxtext/layers/attention_op.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index d265459e69..ee91098623 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1570,10 +1570,14 @@ def cudnn_flash_attention( qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' if decoder_segment_ids is None: decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None) + attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=decoder_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False + ) # Create dummy SequenceDescriptor for lazy_init dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) + dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=dummy_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False + ) max_segments_per_seq = self.config.max_segments_per_seq elif using_context_parallelism: if self.attention_type == AttentionType.LOCAL_SLIDING: From 4778b3e2b9e24a716322f8758698e9dba8e462ce Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Tue, 7 Apr 2026 10:18:18 -0400 Subject: [PATCH 2/3] Add integration test for packed cuDNN TE attention with real data The existing test_gpu_packed_attention uses dataset_type=synthetic, which bypasses SequenceDescriptor.from_segment_ids_and_pos() entirely (takes the elif branch in cudnn_flash_attention). This meant the TE v2.12 API breakage went undetected. Add test_gpu_packed_attention_hf that uses HF parquet data with packing=True + attention=cudnn_flash_te, exercising the actual SequenceDescriptor codepath. This serves as a regression test for NVIDIA/TransformerEngine#2523. --- tests/integration/train_tests.py | 40 ++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index cc55f07844..b7cff1ddf1 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -582,6 +582,46 @@ def test_gpu_packed_attention(self): ] train_main(packed_attention) + @pytest.mark.integration_test + @pytest.mark.gpu_only + @pytest.mark.external_serving + def test_gpu_packed_attention_hf(self): + """Packed (THD) cuDNN TE attention with real HF data. + + Unlike test_gpu_packed_attention (which uses synthetic data and bypasses + SequenceDescriptor), this test exercises the SequenceDescriptor.from_segment_ids_and_pos() + codepath that is only reached when packing=True AND dataset_type != 'synthetic'. + Regression test for NVIDIA/TransformerEngine#2523. + """ + gpu_device = jax.devices("gpu")[0] + compute_capability = getattr(gpu_device, "compute_capability", None) + try: + if float(compute_capability) < 9.0: + pytest.skip("Packed (THD) attention is only supported on sm90+!") + except Exception: # pylint: disable=broad-exception-caught + print("checking if Packed THD attention is supported on this host...") + pytest.skip("Packed (THD) attention is only supported on sm90+!") + os.environ["NVTE_FUSED_ATTN"] = "1" + packed_attention_hf = [ + None, + get_test_config_path(), + f"base_output_directory={self._base_output_directory}", + "run_name=runner_test", + f"dataset_path={self.dataset_path}", + "dataset_type=hf", + "hf_path=parquet", + f"hf_train_files={self.dataset_path}/hf/c4/c4-train-00000-of-01637.parquet", + "tokenizer_path=google-t5/t5-large", + "steps=2", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "attention=cudnn_flash_te", + "ici_fsdp_parallelism=-1", + "packing=True", + "max_segments_per_seq=32", + ] + train_main(packed_attention_hf) + @pytest.mark.integration_test @pytest.mark.gpu_only @pytest.mark.skip(reason="b/489133823. Previously transient in b/462548581.") From 948834e556e75c2833f64c1fa785981913fbd7d5 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Wed, 8 Apr 2026 09:45:07 -0400 Subject: [PATCH 3/3] Remove @pytest.mark.external_serving --- tests/integration/train_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index b7cff1ddf1..5c81f059cb 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -584,7 +584,6 @@ def test_gpu_packed_attention(self): @pytest.mark.integration_test @pytest.mark.gpu_only - @pytest.mark.external_serving def test_gpu_packed_attention_hf(self): """Packed (THD) cuDNN TE attention with real HF data.