
[DO NOT SUBMIT] EP-TP proof of concept (only supports EP + DP) #3600

Draft

gobbleturk wants to merge 4 commits into main from mattdavidow-ep-tp

Conversation

@gobbleturk (Collaborator) commented Apr 8, 2026

Description

Support EP-TP (EP acts like TP for attention/shared expert) for the AG-RS (ring_of_experts) path.

This is helpful for small token counts, e.g. autoregressive inference and small prefills. This has been hacked together as a proof of concept.

There are huge problems this PR does not solve: the token sorting and some other ops still happen on the fully all-gathered tokens (worst-case size), which are by far the longest ops for EP>=4 (b/496676734). We need significant code changes and kernels to support ragged performance (i.e. ops that grow as O(routed tokens), as opposed to our current O(worst case)).
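The padded-vs-ragged gap above can be illustrated with some back-of-the-envelope arithmetic. This is a hedged sketch, not MaxText code: the function names and the balanced-routing assumption are mine, and real routing is rarely perfectly balanced.

```python
# Hypothetical illustration (names are not from MaxText): compare the
# worst-case padded token count the AG-RS path operates on with the
# number of tokens a shard would touch under balanced ragged routing.

def padded_tokens(per_device_batch, seq_len, ep_degree):
    # The all-gather collects every shard's tokens, so downstream
    # sorting and elementwise ops see the full gathered size on every
    # device, regardless of how many tokens are actually routed here.
    return per_device_batch * seq_len * ep_degree

def routed_tokens(per_device_batch, seq_len, ep_degree, top_k):
    # Assuming perfectly balanced routing: each token hits top_k
    # experts, and experts are sharded ep_degree ways, so a shard's
    # local expert work is a 1/ep_degree slice of all token-expert pairs.
    total_pairs = per_device_batch * seq_len * ep_degree * top_k
    return total_pairs // ep_degree

# Numbers loosely based on the smoke command below (ep=4, top_k=2):
worst = padded_tokens(per_device_batch=2, seq_len=2048, ep_degree=4)
ragged = routed_tokens(per_device_batch=2, seq_len=2048, ep_degree=4, top_k=2)
print(worst, ragged)  # 16384 8192: padded work is ep_degree/top_k = 2x larger
```

With top_k fixed, the padded-to-ragged ratio grows linearly in the EP degree, which is consistent with the observation that the O(worst case) ops dominate for EP>=4.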

Example command on my v6e-8 devbox:

alias smoke_train='python3 -m MaxText.train maxtext/configs/base.yml run_name=mattdavidow-train-base base_output_directory=gs://maxtext-experiments-multipod dataset_path=gs://max-datasets-rogue dataset_type=synthetic steps=5 enable_checkpointing=False enable_goodput_recording=False'

alias smoke_moe='smoke_train decoder_block=mixtral num_experts=4 num_experts_per_tok=2 sparse_matmul=True megablox=True per_device_batch_size=4 base_num_decoder_layers=4'

smoke_moe ici_data_parallelism=2 ici_expert_parallelism=4 use_ring_of_experts=True custom_mesh_and_rule=ep-tp num_experts=8 per_device_batch_size=2 expert_shard_attention_option=tp profiler=xplane

profile

Tests

Ran the above command and generated an xprof profile, which looks as expected (not great, due to sorting + elementwise ops on worst-case-size tensors).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [ ] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [ ] I have run end-to-end tests and provided workload links above if applicable.
  • [ ] I have made, or will make, corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov codecov bot commented Apr 8, 2026

Codecov Report

❌ Patch coverage is 61.53846% with 5 lines in your changes missing coverage. Please review.

File with missing lines     Patch %   Lines
src/maxtext/layers/moe.py   58.33%    3 Missing and 2 partials ⚠️


src/maxtext/layers/moe.py:

  else:
-   input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
+   # This is terrible =(
+   input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed"))
Collaborator:
Can you explain why this is terrible?

@gobbleturk (Collaborator, Author) commented Apr 8, 2026:
There are a lot of things wrong with our shardings here. Any _moe rule should only be used deep inside the MoE layer, after tokens have been routed. At this point in the model/code, things should still be sharded like attention (tokens are not routed yet), so we should not use any _moe rule. Additionally, the weights below use "activation" logical axis rules, when "activation" should only be used for activations.

@gobbleturk (Collaborator, Author):

The resultant physical specs are probably what we want, but we got there in a very poor way that is hard to read and maintain.
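The maintainability concern above can be illustrated with a toy rule table. This is a hedged sketch: the rule table, axis names, and `logical_to_mesh_axes` helper below are stand-ins of my own, not MaxText's actual implementation, and the tuple result stands in for a real `PartitionSpec`.

```python
# Illustrative only: logical axis names resolve to physical mesh axes
# through a rule table. Reusing a "_moe" rule before routing means
# pre-routing activations silently pick up the expert mesh axis, and
# the physical outcome depends on a table defined far from this code.

# Toy logical->mesh axis rules, keyed by logical axis name (hypothetical).
RULES = {
    "activation_batch": "data",
    "activation_norm_length_moe": "expert",  # intended for post-routing tensors
    "activation_embed": None,
}

def logical_to_mesh_axes(logical_axes):
    """Resolve logical axis names to mesh axes (stand-in for a PartitionSpec)."""
    return tuple(RULES.get(axis) for axis in logical_axes)

# Pre-routing activations annotated with the _moe length rule end up
# split over the "expert" mesh axis even though no routing has happened:
spec = logical_to_mesh_axes(
    ("activation_batch", "activation_norm_length_moe", "activation_embed")
)
print(spec)  # ('data', 'expert', None)
```

The physical spec may be the desired one, but nothing at the call site says so; the correctness lives entirely in the rule table, which is the readability problem the comment describes.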
