diff --git a/docs/reference/core_concepts/moe_configuration.md b/docs/reference/core_concepts/moe_configuration.md index 7ce7d63110..c150022b15 100644 --- a/docs/reference/core_concepts/moe_configuration.md +++ b/docs/reference/core_concepts/moe_configuration.md @@ -96,11 +96,6 @@ Dropping: ## 2. Sharding -`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include: - -- `fsdp`: Treats the expert axis as a FSDP axis. -- `context`: Treats the expert axis as a context parallelism axis, useful for long context. - `use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication. `moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable. diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 8ab7182779..ec2e96333f 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -66,10 +66,6 @@ MODEL_MODE_PREFILL = "prefill" MODEL_MODE_TRAIN = "train" -# expert_shard_attention_option -EP_AS_CONTEXT = "context" -EP_AS_FSDP = "fsdp" - DECODING_ACTIVE_SEQUENCE_INDICATOR = 1 # A large negative mask value is used for masking to ensure that the diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 99d1afe436..738628ea9f 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -237,11 +237,6 @@ merge_gating_gmm: False norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights. -# how the expert axis is used to shard attention weights and activations -# "fsdp" (ep acts as fsdp parallelism) -# "context" (ep acts as context parallelism, training only) -expert_shard_attention_option: "fsdp" - # when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls moe_fsdp_use_two_stage_all_gather: false # Shard the expert dimension of the MLP weights on the FSDP axis. @@ -453,92 +448,119 @@ compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_s shard_mode: "auto" # can be either auto or explicit custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/. mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] -logical_axis_rules: [ - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], +logical_axis_rules: [ + # ========================================== + # Vocabulary Embedding + # ========================================== + # Vocab Activations ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], - ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], - ['activation_length', ['sequence', 'context']], - ['activation_length', ['context']], + ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'tensor_transpose']], + ['activation_vocab', 'tensor_sequence'], + ['activation_vocab', ['sequence', 'context']], + # Vocab Weights + ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], + # ========================================== + # Attention + # ========================================== + # Attention Activations + ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive']], + ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], ['activation_attn_length', ['sequence', 'context']], - ['activation_attn_length', ['context']], - ['activation_length_moe', ['sequence', 'context']], - ['activation_length_moe', ['context']], - ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']], - ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']], + # ['activation_attn_length', ['context']], ['activation_q_length', ['context']], - ['prefill_activation_length', ['sequence', 'context']], - ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']], ['activation_kv_length', []], ['activation_attn_embed', ['tensor', 'tensor_transpose']], - ['activation_embed', ['tensor', 'tensor_transpose']], - ['activation_embed_moe', ['tensor', 'tensor_transpose']], - ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']], ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose']], - ['activation_vocab', 'tensor_sequence'], - ['activation_vocab', ['sequence','context']], - ['activation_stage', 'stage'], - ['activation_exp', ['expert']], - ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['decode_length', ['sequence']], - ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], - ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], - ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], - ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + # Attention Weights ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], - ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['embed', ['fsdp', 'sequence', 'context', 'expert']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], - ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_moe', ['fsdp', 'sequence', 'context']], - ['embed_tensor_transpose', ['tensor_transpose']], + ['qkv', []], + ['kv', []], + ['kv_head_dim', []], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'expert']], - ["q_lora_up_proj",[]], + ["q_lora_up_proj", []], ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], ['kv_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['kv_lora', ['fsdp', 'sequence', 'context', 'expert']], - ["kv_lora_up_proj",[]], + ["kv_lora_up_proj", []], + # ========================================== + # Mixture of Experts (MoE) + # ========================================== + # MoE Activations + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_length_moe', ['sequence', 'context']], + # ['activation_length_moe', ['context']], + ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']], + ['activation_embed_moe', ['tensor', 'tensor_transpose']], + ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_exp', ['expert']], + # MoE Weights + ['exp', 'expert'], + ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'context']], + # ========================================== + # Standard MLP / Dense Layers / Model Structure + # ========================================== + # Dense Activations + ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_length', ['sequence', 'context']], + # ['activation_length', ['context']], + ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']], + ['activation_embed', ['tensor', 'tensor_transpose']], + ['activation_stage', 'stage'], + # General Weights + ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], + ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], + ['embed', ['fsdp', 'sequence', 'context', 'expert']], ['norm', ['tensor', 'tensor_transpose']], ['layers', 'stage'], - ['qkv', []], - ['kv', []], - ['kv_head_dim', []], + ['diloco', 'diloco'], + ['engram_dim', ['tensor']], + ['dense_layers', []], + ['moe_layers', []], + ['mhc', []], + # ========================================== + # Inference(Prefill, Decode, Cache) + # ========================================== + ['prefill_activation_length', ['sequence', 'context']], + ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']], + ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['decode_length', ['sequence']], + ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], + ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], + ['paged_kv_heads', ['tensor']], ['cache_batch_prefill', []], ['cache_batch', []], ['cache_heads_none', []], - ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], - ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], ['cache_kv', []], ['cache_sequence', []], - ['exp', 'expert'], - ['exp_with_fsdp', 'fsdp'], - ['paged_kv_heads', ['tensor']], ['num_pages', []], ['tokens_per_page', []], ['paged_kv_head_dim_size', []], - ['dense_layers', []], - ['moe_layers', []], - ['engram_dim', ['tensor']], - ['mhc', []], - ['diloco', 'diloco'], - ] + # ========================================== + # Deprecated / Scheduled for Removal + # ========================================== + ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], + ['embed_tensor_transpose', ['tensor_transpose']], + ['exp_with_fsdp', 'fsdp'], + ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length'] diff --git a/src/maxtext/configs/inference/inference.yml b/src/maxtext/configs/inference/inference.yml index 55407b3edc..fa972a0343 100644 --- a/src/maxtext/configs/inference/inference.yml +++ b/src/maxtext/configs/inference/inference.yml @@ -28,6 +28,7 @@ logical_axis_rules: [ ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']], + ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], diff --git a/src/maxtext/configs/post_train/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml index 4383b1c4ac..715f02962d 100644 --- a/src/maxtext/configs/post_train/rl_mt_jt.yml +++ b/src/maxtext/configs/post_train/rl_mt_jt.yml @@ -42,6 +42,7 @@ logical_axis_rules: [ ['decode_length', []], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']], + ['embed_vocab', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 39e5030543..ed0ad0709b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -661,10 +661,6 @@ class MoEGeneral(BaseModel): ) use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.") interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.") - expert_shard_attention_option: Literal["fsdp", "context"] = Field( - "fsdp", - description="How the expert axis is used to shard attention weights and activations.", - ) moe_fsdp_use_two_stage_all_gather: bool = Field( False, description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.", @@ -2393,8 +2389,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.tensors_to_offload = [t for t in tensors if getattr(self, t) == "offload"] cp_size = self.ici_context_parallelism * self.dcn_context_parallelism - if self.expert_shard_attention_option == "context": - cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism self.context_parallel_size = cp_size if self.pipeline_parallel_layers == -1: if self.decoder_block == DecoderBlockType.DEEPSEEK: diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index 13bec20f6c..71a013b2ec 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -48,7 +48,6 @@ D_KV, DType, EMBED, - EP_AS_CONTEXT, HEAD, Q_LORA_UP_PROJ, KV_BATCH, @@ -901,9 +900,6 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode): if model_mode == MODEL_MODE_PREFILL: key_logical_name = self.prefill_key_axis_names value_logical_name = self.prefill_value_axis_names - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - key_logical_name = self.ep_key_axis_names - value_logical_name = self.ep_value_axis_names else: key_logical_name = self.key_axis_names value_logical_name = self.value_axis_names @@ -1224,10 +1220,7 @@ def __call__( ) out = jax.ad_checkpoint.checkpoint_name(out, "attention_out") - if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: - out = self._maybe_shard_with_logical(out, self.ep_out_axis_names) - else: - out = self._maybe_shard_with_logical(out, self.out_axis_names) + out = self._maybe_shard_with_logical(out, self.out_axis_names) out_sharding = create_sharding(self.mesh, out_logical_name) out = self.out_projection(out, out_sharding=out_sharding) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index 3bcd07e3e1..3df9b80d95 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -55,7 +55,6 @@ DEFAULT_MASK_VALUE, DType, D_KV, - EP_AS_FSDP, HEAD, KV_LENGTH, LENGTH, @@ -1270,7 +1269,7 @@ def wrap_splash_kernel(single_head_mask): splash_kernel = wrap_splash_kernel(single_head_mask) segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,)) - elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP: + elif self.config.use_jax_splash: if self.config.use_max_logit_estimate > 0: sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate) segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,)) diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index f5608e6465..26c2f8a470 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -736,7 +736,7 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi out_features_shape=cfg.vocab_size, weight_dtype=cfg.weight_dtype, dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=("embed", "vocab"), + kernel_axes=("embed_vocab", "vocab"), shard_mode=cfg.shard_mode, name="logits_dense", matmul_precision=self.config.matmul_precision, diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index da7634960e..36b843e1b5 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -132,7 +132,7 @@ def __init__( (self.num_embeddings, self.num_features), self.config.weight_dtype, ), - sharding=("vocab", "embed"), + sharding=("vocab", "embed_vocab"), ) def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 3b0de8e0da..3c8a601201 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -287,7 +287,7 @@ def __init__( out_features_shape=config.vocab_size, weight_dtype=config.weight_dtype, dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype, - kernel_axes=("embed", "vocab"), + kernel_axes=("embed_vocab", "vocab"), shard_mode=config.shard_mode, matmul_precision=self.config.matmul_precision, parameter_memory_host_offload=config.parameter_memory_host_offload, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json index 5cfea0ee37..8d30b919f8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json @@ -133,7 +133,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -318,7 +318,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -459,7 +459,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -644,7 +644,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -781,7 +781,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -966,7 +966,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json index 5cfea0ee37..8d30b919f8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json @@ -133,7 +133,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -318,7 +318,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -459,7 +459,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -644,7 +644,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -781,7 +781,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -966,7 +966,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json index 5cfea0ee37..8d30b919f8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json @@ -133,7 +133,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -318,7 +318,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -459,7 +459,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -644,7 +644,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -781,7 +781,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -966,7 +966,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json index 5cfea0ee37..8d30b919f8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json @@ -133,7 +133,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -318,7 +318,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -459,7 +459,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -644,7 +644,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -781,7 +781,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -966,7 +966,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json index 5cfea0ee37..8d30b919f8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json @@ -133,7 +133,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -318,7 +318,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -459,7 +459,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -644,7 +644,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -781,7 +781,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -966,7 +966,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json index 5cfea0ee37..8d30b919f8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json @@ -133,7 +133,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -318,7 +318,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -459,7 +459,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -644,7 +644,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, @@ -781,7 +781,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -966,7 +966,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 102400, diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json index c944c8e273..119ddf8c82 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json @@ -477,7 +477,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -488,7 +488,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -973,7 +973,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -984,7 +984,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -1465,7 +1465,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -1476,7 +1476,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json index c944c8e273..119ddf8c82 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json @@ -477,7 +477,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -488,7 +488,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -973,7 +973,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -984,7 +984,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -1465,7 +1465,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -1476,7 +1476,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json index c944c8e273..119ddf8c82 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json @@ -477,7 +477,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -488,7 +488,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -973,7 +973,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -984,7 +984,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -1465,7 +1465,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -1476,7 +1476,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json index c944c8e273..119ddf8c82 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json @@ -477,7 +477,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -488,7 +488,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -973,7 +973,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -984,7 +984,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -1465,7 +1465,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -1476,7 +1476,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json index c944c8e273..119ddf8c82 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json @@ -477,7 +477,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -488,7 +488,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -973,7 +973,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -984,7 +984,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -1465,7 +1465,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -1476,7 +1476,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json index c944c8e273..119ddf8c82 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json @@ -477,7 +477,7 @@ }, ".params/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -488,7 +488,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -973,7 +973,7 @@ }, ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -984,7 +984,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, @@ -1465,7 +1465,7 @@ }, ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { "partition_spec": [ - "embed", + "embed_vocab", "vocab" ], "shape": [ @@ -1476,7 +1476,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 201088, diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json index 487e9bb959..0530ce7dce 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json @@ -146,7 +146,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -300,7 +300,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -450,7 +450,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json index 487e9bb959..0530ce7dce 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json @@ -146,7 +146,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -300,7 +300,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -450,7 +450,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json index 487e9bb959..0530ce7dce 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json @@ -146,7 +146,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -300,7 +300,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -450,7 +450,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json index 487e9bb959..0530ce7dce 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json @@ -146,7 +146,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -300,7 +300,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -450,7 +450,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json index 487e9bb959..0530ce7dce 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json @@ -146,7 +146,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -300,7 +300,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -450,7 +450,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json index 487e9bb959..0530ce7dce 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json @@ -146,7 +146,7 @@ ".params/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -300,7 +300,7 @@ ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936, @@ -450,7 +450,7 @@ ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { "partition_spec": [ "vocab", - "embed" + "embed_vocab" ], "shape": [ 151936,