Remove no exp usage from logical rule Part I #3578

Open
NuojCheng wants to merge 3 commits into main from chengnuojin-no-exp

Conversation

@NuojCheng (Collaborator) commented Apr 6, 2026

Description

This PR deprecates

  • activation_batch_no_exp
  • activation_length_no_exp

from logical names.

After this change

  • activation_batch always includes the "expert" physical axis
  • activation_length does not include "expert"

activation_batch_moe and activation_batch_no_exp_moe are kept unchanged in moe.py.

Other logical names containing "_no_exp" will be deprecated in a following PR.
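
For reviewers less familiar with the sharding config, here is a minimal sketch of the intended end state. The physical-axis lists below are illustrative assumptions, not the exact logical_axis_rules entries in base.yml:

# Illustrative only: simplified logical_axis_rules entries after this PR.
# The physical axis names and their ordering are assumptions for exposition.
logical_axis_rules = [
    # "activation_batch" always carries the "expert" physical mesh axis,
    ("activation_batch", ("data", "fsdp", "expert")),
    # while "activation_length" never does.
    ("activation_length", ("sequence",)),
]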

Tests

CI tests.

Inference run on v5p-8:

NEW_MODEL_DESIGN=1 python src/maxtext/inference/vllm_decode.py src/maxtext/configs/post_train/rl.yml \
  model_name=qwen3-30b-a3b tokenizer_path=Qwen/Qwen3-30B-A3B \
  ici_tensor_parallelism=4 ici_expert_parallelism=1 enable_dp_attention=false hbm_utilization_vllm=0.3 \
  load_parameters_path=gs://parambole-qwen3-moe-verification/unscanned/qwen3-30b-a3b-thinking-2507/14_08_2025/0/items \
  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
  prompt="Suggest some famous landmarks in London." 2>&1 | tee qwen3_moe_vllm_0.log

Output: https://paste.googleplex.com/5103945832857600

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Apr 6, 2026

Codecov Report

❌ Patch coverage is 91.30435% with 2 lines in your changes missing coverage. Please review.

Files with missing lines             Patch %    Lines
src/maxtext/layers/decoders.py       66.66%     1 Missing ⚠️
src/maxtext/layers/nnx_decoders.py   50.00%     1 Missing ⚠️


@NuojCheng changed the title from "Remove no exp usage from logical rule" to "Remove no exp usage from logical rule 1/N" on Apr 6, 2026
@RissyRan (Collaborator) commented Apr 6, 2026

Could you help check if this breaks inference workflow? Potentially this has been deprecated due to vLLM migration.

@NuojCheng (Collaborator, Author) commented:

> Could you help check if this breaks inference workflow? Potentially this has been deprecated due to vLLM migration.

AFAIK inference.yml is deprecated. I will check the inference path in https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/inference/vllm.yml.

@NuojCheng force-pushed the chengnuojin-no-exp branch from d72cdeb to ef90d9b on April 6, 2026 at 22:26
@NuojCheng changed the title from "Remove no exp usage from logical rule 1/N" to "Remove no exp usage from logical rule Part I" on Apr 7, 2026
@NuojCheng marked this pull request as ready for review on April 7, 2026 at 21:33
@NuojCheng requested a review from igorts-git as a code owner on April 7, 2026 at 21:33
  # pre_bias_logits is None for non-DeepSeek v3 models
  pre_bias_logits = self._maybe_shard_with_logical(
-     pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
+     pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None)
@gobbleturk (Collaborator) commented Apr 7, 2026:

Note: this should be activation_batch and activation_length, not activation_batch_moe and activation_length_moe.

It's a bit subtle and I'm not sure how to educate everyone, but the _moe rules should only be used deep inside MoE, after the tokens have moved to their expert shard. At this point in the model/code they have not been routed to their expert shard yet, so we should use the non-_moe names.

This actual tensor, the logits, never needs to get routed, so it never needs the _moe shardings.
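
To make the rule concrete, a minimal sketch of the distinction (the helper function, router/dispatch call, and axis tuples for the dispatched tensor are hypothetical; only _maybe_shard_with_logical and the logical names come from this PR):

# Hypothetical sketch of where _moe vs. non-_moe logical names apply.
def shard_routed_moe_activations(moe_block, inputs, dispatched):
  # Router logits are computed before tokens are sent to their expert shard,
  # so they take the plain logical names, never the "_moe" variants.
  logits = moe_block._maybe_shard_with_logical(
      moe_block.router(inputs), ("activation_batch", "activation_length", None)
  )
  # "dispatched" stands for activations that have already been permuted onto
  # their expert shard; only these use the "_moe" logical names (the axis
  # tuple shown here is illustrative).
  dispatched = moe_block._maybe_shard_with_logical(
      dispatched, ("activation_batch_moe", "activation_length_moe", None)
  )
  return logits, dispatched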

Collaborator commented:

Doesn't have to be changed for this PR I guess

@NuojCheng (Collaborator, Author) commented:

Yeah, I set _moe for all logical names inside RoutedMoE, which is wrong. Now I understand why.

@richjames0 (Collaborator) left a comment:

lgtm
