[DO NOT SUBMIT] EP-TP proof of concept (only supports EP + DP)#3600
[DO NOT SUBMIT] EP-TP proof of concept (only supports EP + DP)#3600gobbleturk wants to merge 4 commits intomainfrom
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
| else: | ||
| input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) | ||
| # This is terrible =( | ||
| input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed")) |
There was a problem hiding this comment.
Can you explain why this is terrible?
There was a problem hiding this comment.
there are a lot of things wrong with our shardings any _moe rule should only be used deep inside the moe layer after tokens have been routed. At this point in the model/code things should still be sharded like attention (not routed yet), so we should not use any _moe rule. Additionally the weights below use "activation" logical axis rules when "activation" should only be use for activations
There was a problem hiding this comment.
the resultant physical specs are probably what we want but we got there in a very poor way that is very hard to read / maintain
Description
Support EP-TP (EP acts like TP for attn/shared expert) for AG-RS (ring_of_experts) path
This is helpful for small token counts e.g. autoregressive inference and small prefills. This has been hacked together.
There are huge problems this PR does not solve - the token sorting and some other ops happen on the fully all gathered tokens (worst case size), which are by far the longest ops for EP>=4 (b/496676734). We need significant code changes and kernels to support ragged performance (e.g. ops grow as O(routed tokens) as opposed to our current O(worst case))
example command on my v6e-8 devbox
profile
Tests
ran above command, generated xprof which looks expected (not great due to sorting + elementwise ops on worst case size tensors)
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.