Remove `no_exp` usage from logical rules, Part I (#3578)
Conversation
Could you help check if this breaks inference workflow? Potentially this has been deprecated due to vLLM migration. |
AFAIK inference.yml is deprecated. I will check inference path in https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/inference/vllm.yml. |
Force-pushed from d72cdeb to ef90d9b.
```diff
 # pre_bias_logits is None for non-DeepSeek v3 models
 pre_bias_logits = self._maybe_shard_with_logical(
-    pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
+    pre_bias_logits, ("activation_batch_moe", "activation_length_moe", None)
```
Note this should be `activation_batch` and `activation_length`, not `activation_batch_moe` and `activation_length_moe`.

It's a bit subtle and I'm not sure how to educate everyone, but the `_moe` rules should only be used deep inside MoE, after the tokens have moved to their expert shard. At this point in the model/code they have not been routed to their expert shard yet, so we should use the non-`_moe` rules.

This actual tensor, the logits, never needs to get routed, so it never needs to use `_moe` shardings.
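The distinction above can be sketched with a tiny logical-to-physical rule table. The mappings below are illustrative only (they are not the actual MaxText rule table, which resolves names into `jax.sharding.PartitionSpec` via the config's sharding rules); the point is that pre-routing tensors like the router logits should resolve through the non-`_moe` names:

```python
# Illustrative logical-axis -> physical-mesh-axis rules (hypothetical values,
# not MaxText's real config). The _moe names are reserved for tensors that
# exist after tokens have been dispatched to their expert shard.
LOGICAL_RULES = {
    # Pre-routing activations (e.g. router logits): per-token layout.
    "activation_batch": ("data", "expert"),
    "activation_length": ("sequence",),
    # Post-dispatch activations, used only deep inside the MoE block.
    "activation_batch_moe": ("data", "expert"),
    "activation_length_moe": ("sequence",),
    None: None,  # unsharded axis
}

def logical_to_spec(logical_axes):
    """Resolve a tuple of logical axis names into physical mesh axes."""
    return tuple(LOGICAL_RULES[name] for name in logical_axes)

# The logits in the diff above are annotated before dispatch, so they
# take the non-_moe names:
spec = logical_to_spec(("activation_batch", "activation_length", None))
```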
Doesn't have to be changed for this PR, I guess.
Yes, I set `_moe` for all logical names inside RouteMoE, which is wrong. Now I understand why.
Description
This PR deprecates `activation_batch_no_exp` and `activation_length_no_exp` from logical names.

After this change:

- `activation_batch` always includes the "expert" physical axis
- `activation_length` does not include "expert"
- `activation_batch_moe` and `activation_batch_no_exp_moe` are kept unchanged in `moe.py`
- Other logical names containing "_no_exp" will be deprecated in a following PR.
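The rename described above can be sketched as a small migration helper (the helper itself is hypothetical, not MaxText code; only the name pairs come from this PR's description):

```python
# Map deprecated _no_exp logical names onto their replacements.
# The name pairs are from the PR description; the helper is illustrative.
DEPRECATED = {
    "activation_batch_no_exp": "activation_batch",
    "activation_length_no_exp": "activation_length",
}

def migrate(logical_axes):
    """Rewrite a tuple of logical axis names, replacing deprecated ones."""
    return tuple(DEPRECATED.get(name, name) for name in logical_axes)

migrated = migrate(("activation_batch_no_exp", "activation_length_no_exp", None))
# ('activation_batch', 'activation_length', None)
```

The `_moe` variants are deliberately absent from the map, since this PR leaves them unchanged in `moe.py`.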
Tests
CI tests.
Inference run on v5p-8.
Output: https://paste.googleplex.com/5103945832857600
Checklist
Before submitting this PR, please make sure (put X in square brackets):
[ ] `gemini-review` label.