Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import Mesh
from maxtext.common.common_types import (
Array,
AttentionType,
Expand Down Expand Up @@ -78,7 +78,7 @@
from maxtext.layers.initializers import variable_to_logically_partitioned
from maxtext.layers.quantizations import AqtQuantization as Quant
from maxtext.utils import max_utils
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_name
from maxtext.utils.sharding import logical_to_mesh_axes, maybe_shard_with_pspec
import numpy as np
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_kernel as tokamax_splash_kernel
from tokamax._src.ops.experimental.tpu.splash_attention import splash_attention_mask as tokamax_splash_mask
Expand Down Expand Up @@ -1484,26 +1484,19 @@ def kernel_fn(q, k, v, d, s):

return attention_output, None

def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
# decoder_segment_ids can be None
if pspec is None:
return None
sharding = NamedSharding(self.mesh, pspec)
return maybe_shard_with_name(
inputs,
sharding,
shard_mode=self.config.shard_mode,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
)

query = _maybe_shard_with_pspec(query, axis_names_q)
key = _maybe_shard_with_pspec(key, axis_names_kv)
value = _maybe_shard_with_pspec(value, axis_names_kv)
decoder_segment_ids_q = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_q)
decoder_segment_ids_kv = _maybe_shard_with_pspec(decoder_segment_ids, segment_axis_names_kv)
sinks = _maybe_shard_with_pspec(sinks, sink_axis_names)
indexer_mask = _maybe_shard_with_pspec(indexer_mask, indexer_mask_axis_names)
query = maybe_shard_with_pspec(query, self.mesh, self.config.shard_mode, axis_names_q, self.config.debug_sharding)
key = maybe_shard_with_pspec(key, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
value = maybe_shard_with_pspec(value, self.mesh, self.config.shard_mode, axis_names_kv, self.config.debug_sharding)
decoder_segment_ids_q = maybe_shard_with_pspec(
decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_q, self.config.debug_sharding
)
decoder_segment_ids_kv = maybe_shard_with_pspec(
decoder_segment_ids, self.mesh, self.config.shard_mode, segment_axis_names_kv, self.config.debug_sharding
)
sinks = maybe_shard_with_pspec(sinks, self.mesh, self.config.shard_mode, sink_axis_names, self.config.debug_sharding)
indexer_mask = maybe_shard_with_pspec(
indexer_mask, self.mesh, self.config.shard_mode, indexer_mask_axis_names, self.config.debug_sharding
)

ret = wrap_flash_attention(
query,
Expand Down
12 changes: 11 additions & 1 deletion src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from maxtext.kernels import megablox as mblx
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding, maybe_shard_with_pspec
from maxtext.utils.sharding import logical_to_mesh_axes
import numpy as np
import qwix.pallas as qpl
Expand Down Expand Up @@ -1439,6 +1439,16 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
gate_logits = self._maybe_shard_with_logical(gate_logits, gate_logits_axes)
pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, pre_bias_logits_axes)

w0_kernel = maybe_shard_with_pspec(w0_kernel, self.mesh, self.config.shard_mode, w0_pspec)
w1_kernel = maybe_shard_with_pspec(w1_kernel, self.mesh, self.config.shard_mode, w1_pspec)
wo_kernel = maybe_shard_with_pspec(wo_kernel, self.mesh, self.config.shard_mode, wo_pspec)
if w0_bias is not None:
w0_bias = maybe_shard_with_pspec(w0_bias, self.mesh, self.config.shard_mode, w0_bias_pspec)
if w1_bias is not None:
w1_bias = maybe_shard_with_pspec(w1_bias, self.mesh, self.config.shard_mode, w1_bias_pspec)
if wo_bias is not None:
wo_bias = maybe_shard_with_pspec(wo_bias, self.mesh, self.config.shard_mode, wo_bias_pspec)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI we did similar things but in a different style in attention_op.py, see

def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
# decoder_segment_ids can be None
if pspec is None:
return None
sharding = NamedSharding(self.mesh, pspec)
return maybe_shard_with_name(
inputs,
sharding,
shard_mode=self.config.shard_mode,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
)
. It is optional to make the change

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

return wrapper(
inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias, self.rngs
)
Expand Down
13 changes: 13 additions & 0 deletions src/maxtext/utils/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ def maybe_shard_with_name(
return jax.lax.with_sharding_constraint(inputs, named_sharding)


def maybe_shard_with_pspec(inputs, mesh, shard_mode, pspec: jax.sharding.PartitionSpec | None, debug_sharding=False):
  """Constrain ``inputs`` to the sharding described by ``pspec`` over ``mesh``.

  Args:
    inputs: Array (or pytree) to apply the sharding constraint to.
    mesh: Device mesh the partition spec refers to.
    shard_mode: Sharding mode forwarded to ``maybe_shard_with_name``.
    pspec: Partition spec, or None when there is nothing to shard
      (callers pass optional tensors such as ``decoder_segment_ids``).
    debug_sharding: Whether to emit sharding debug information.

  Returns:
    The constrained inputs, or None when ``pspec`` is None.
  """
  if pspec is None:
    # NOTE: deliberately returns None (not `inputs`): a missing pspec means the
    # optional tensor is absent/unused at the call sites.
    return None
  return maybe_shard_with_name(
      inputs,
      NamedSharding(mesh, pspec),
      shard_mode=shard_mode,
      debug_sharding=debug_sharding,
      extra_stack_level=1,
  )


def maybe_shard_with_logical(
inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc=""
):
Expand Down
98 changes: 92 additions & 6 deletions tests/integration/smoke/train_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@

""" Smoke test """
import os
import unittest

from absl.testing import absltest
from absl.testing import absltest, parameterized
from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory

from maxtext.common.gcloud_stub import is_decoupled
from maxtext.trainers.pre_train.train import main as train_main
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT


class Train(unittest.TestCase):
class Train(parameterized.TestCase):
"""Smoke test G3 only"""

def setUp(self):
Expand All @@ -46,7 +45,7 @@ def test_tiny_config(self):
# pylint: disable=f-string-without-interpolation
f"base_output_directory={self.base_output_directory}",
"run_name=runner_test",
r"dataset_path={self.dataset_path}",
f"dataset_path={self.dataset_path}",
"base_emb_dim=8",
"base_num_query_heads=4",
"base_num_kv_heads=4",
Expand Down Expand Up @@ -74,7 +73,7 @@ def test_tiny_config_no_scan(self):
# pylint: disable=f-string-without-interpolation
f"base_output_directory={self.base_output_directory}",
"run_name=runner_test",
r"dataset_path={self.dataset_path}",
f"dataset_path={self.dataset_path}",
"base_emb_dim=8",
"base_num_query_heads=4",
"base_num_kv_heads=4",
Expand Down Expand Up @@ -103,7 +102,7 @@ def test_tiny_config_explicit_shardmode(self):
# pylint: disable=f-string-without-interpolation
f"base_output_directory={self.base_output_directory}",
"run_name=runner_test",
r"dataset_path={self.dataset_path}",
f"dataset_path={self.dataset_path}",
"base_emb_dim=8",
"base_num_query_heads=4",
"base_num_kv_heads=4",
Expand All @@ -123,6 +122,93 @@ def test_tiny_config_explicit_shardmode(self):
]
)

def test_tiny_config_explicit_shardmode_deepseek(self):
test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
train_main(
[
None,
get_test_config_path(),
# pylint: disable=f-string-without-interpolation
f"base_output_directory={self.base_output_directory}",
"run_name=runner_test_deepseek",
f"dataset_path={self.dataset_path}",
"model_name=deepseek3-test",
"base_emb_dim=32",
"base_num_query_heads=4",
"base_num_kv_heads=4",
"base_mlp_dim=64",
"base_moe_mlp_dim=64",
"base_num_decoder_layers=2",
"first_num_dense_layers=1",
"head_dim=32",
"v_head_dim=32",
"qk_nope_head_dim=32",
"qk_rope_head_dim=16",
"q_lora_rank=16",
"kv_lora_rank=16",
"per_device_batch_size=1",
"max_target_length=64",
"dataset_type=synthetic",
"steps=2",
"enable_checkpointing=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"sparse_matmul=True",
"capacity_factor=-1",
"shard_mode=explicit",
"enable_goodput_recording=False",
"enable_checkpoint_cloud_logger=False",
"monitor_goodput=False",
"abort_on_nan_loss=False",
"abort_on_inf_loss=False",
]
)

@parameterized.named_parameters(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NuojCheng these tests depend on the hardware type but should work for all topologies with >= 2 devices. wdyt?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I think we have v6e-4 for CI tests

("fsdp_expert", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=True"]),
("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]),
("fsdp", ["ici_fsdp_parallelism=-1"]),
)
def test_parallelism_configs(self, parallelism_args):
test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable

base_args = [
None,
get_test_config_path(),
f"base_output_directory={self.base_output_directory}",
"run_name=runner_test_parallelism",
f"dataset_path={self.dataset_path}",
"model_name=deepseek3-test",
"base_emb_dim=32",
"base_num_query_heads=4",
"base_num_kv_heads=4",
"base_mlp_dim=64",
"base_moe_mlp_dim=64",
"base_num_decoder_layers=2",
"first_num_dense_layers=1",
"head_dim=32",
"v_head_dim=32",
"qk_nope_head_dim=32",
"qk_rope_head_dim=16",
"q_lora_rank=16",
"kv_lora_rank=16",
"per_device_batch_size=1",
"max_target_length=64",
"dataset_type=synthetic",
"steps=2",
"enable_checkpointing=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"capacity_factor=-1",
"shard_mode=explicit",
"enable_goodput_recording=False",
"monitor_goodput=False",
"abort_on_nan_loss=False",
"abort_on_inf_loss=False",
]

full_args = base_args + parallelism_args

train_main(full_args)


if __name__ == "__main__":
absltest.main()
Loading