NOTE(review): line breaks and indentation below were reconstructed from a newline-collapsed copy of this patch
(2-space indent, 4-space continuation assumed) -- confirm against the target files before applying. The post-image
blob id on the tests `index` line predates the edits made to the added test.

diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py
index d265459e69..ee91098623 100644
--- a/src/maxtext/layers/attention_op.py
+++ b/src/maxtext/layers/attention_op.py
@@ -1570,10 +1570,14 @@ def cudnn_flash_attention(
       qkv_layout = "THD_THD_THD"  # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD'
       if decoder_segment_ids is None:
         decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
-      attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None)
+      attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
+          segment_ids=decoder_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False
+      )
       # Create dummy SequenceDescriptor for lazy_init
       dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
-      dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None)
+      dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
+          segment_ids=dummy_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False
+      )
       max_segments_per_seq = self.config.max_segments_per_seq
     elif using_context_parallelism:
       if self.attention_type == AttentionType.LOCAL_SLIDING:
diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py
index cc55f07844..5c81f059cb 100644
--- a/tests/integration/train_tests.py
+++ b/tests/integration/train_tests.py
@@ -582,6 +582,55 @@ def test_gpu_packed_attention(self):
     ]
     train_main(packed_attention)
 
+  @pytest.mark.integration_test
+  @pytest.mark.gpu_only
+  def test_gpu_packed_attention_hf(self):
+    """Packed (THD) cuDNN TE attention with real HF data.
+
+    Unlike test_gpu_packed_attention (which uses synthetic data and bypasses
+    SequenceDescriptor), this test exercises the SequenceDescriptor.from_segment_ids_and_pos()
+    codepath that is only reached when packing=True AND dataset_type != 'synthetic'.
+    Regression test for NVIDIA/TransformerEngine#2523.
+    """
+    gpu_device = jax.devices("gpu")[0]
+    compute_capability = getattr(gpu_device, "compute_capability", None)
+    try:
+      supports_thd = float(compute_capability) >= 9.0
+    except (TypeError, ValueError):
+      # compute_capability is None (attribute missing) or not numeric: treat the device as unsupported.
+      supports_thd = False
+    if not supports_thd:
+      pytest.skip("Packed (THD) attention is only supported on sm90+!")
+    packed_attention_hf = [
+        None,
+        get_test_config_path(),
+        f"base_output_directory={self._base_output_directory}",
+        "run_name=runner_test",
+        f"dataset_path={self.dataset_path}",
+        "dataset_type=hf",
+        "hf_path=parquet",
+        f"hf_train_files={self.dataset_path}/hf/c4/c4-train-00000-of-01637.parquet",
+        "tokenizer_path=google-t5/t5-large",
+        "steps=2",
+        "enable_checkpointing=False",
+        "enable_goodput_recording=False",
+        "attention=cudnn_flash_te",
+        "ici_fsdp_parallelism=-1",
+        "packing=True",
+        "max_segments_per_seq=32",
+    ]
+    # Set NVTE_FUSED_ATTN for this run only and restore the prior value afterwards,
+    # so the setting does not leak into tests that run later in the same process.
+    previous_fused_attn = os.environ.get("NVTE_FUSED_ATTN")
+    os.environ["NVTE_FUSED_ATTN"] = "1"
+    try:
+      train_main(packed_attention_hf)
+    finally:
+      if previous_fused_attn is None:
+        os.environ.pop("NVTE_FUSED_ATTN", None)
+      else:
+        os.environ["NVTE_FUSED_ATTN"] = previous_fused_attn
+
   @pytest.mark.integration_test
   @pytest.mark.gpu_only
   @pytest.mark.skip(reason="b/489133823. Previously transient in b/462548581.")