From 9e28a26bd1b3f1041f8d660c4970bf73d1fb0115 Mon Sep 17 00:00:00 2001
From: Jimmy Tsai
Date: Tue, 7 Apr 2026 10:07:02 +0000
Subject: [PATCH] Fix qwen2 vllm_mapping for tunix

---
 .../integration/tunix/weight_mapping/qwen2.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/maxtext/integration/tunix/weight_mapping/qwen2.py b/src/maxtext/integration/tunix/weight_mapping/qwen2.py
index 6a93ffd402..16ffae58e2 100644
--- a/src/maxtext/integration/tunix/weight_mapping/qwen2.py
+++ b/src/maxtext/integration/tunix/weight_mapping/qwen2.py
@@ -67,12 +67,12 @@ def to_hf_mapping():
   return {
       # Token embeddings - shard vocab dimension
       "base.token_embedder.embedding": (
-          "model.embed.embedding",
+          "model.embed_tokens.weight",
           ("model", None),
       ),
       # Final layer norm - no sharding needed
       "base.decoder.decoder_norm.scale": (
-          "model.norm.scale",
+          "model.norm.weight",
           (None,),
       ),
       # LM head (logits projection) - shard vocab dimension
@@ -83,41 +83,41 @@ def to_hf_mapping():
       # Layer-specific mappings (scanned -> unscanned)
       # MLP components - shard hidden dimensions
       "base.decoder.layers.mlp.wi_0.kernel": (
-          "model.layers.*.mlp.gate_proj.kernel",
+          "model.layers.*.mlp.gate_proj.weight",
           (None, "layer", "model"),
       ),
       "base.decoder.layers.mlp.wi_1.kernel": (
-          "model.layers.*.mlp.up_proj.kernel",
+          "model.layers.*.mlp.up_proj.weight",
           (None, "layer", "model"),
       ),
       "base.decoder.layers.mlp.wo.kernel": (
-          "model.layers.*.mlp.down_proj.kernel",
+          "model.layers.*.mlp.down_proj.weight",
           ("model", "layer", None),
       ),
       # Layer norms - no sharding needed
       "base.decoder.layers.pre_self_attention_layer_norm.scale": (
-          "model.layers.*.input_layernorm.scale",
+          "model.layers.*.input_layernorm.weight",
           (None, "layer"),
       ),
       "base.decoder.layers.post_self_attention_layer_norm.scale": (
-          "model.layers.*.post_attention_layernorm.scale",
+          "model.layers.*.post_attention_layernorm.weight",
           (None, "layer"),
       ),
       # Attention components - shard head dimensions
       "base.decoder.layers.self_attention.query.kernel": (
-          "model.layers.*.self_attn.q_proj.kernel",
+          "model.layers.*.self_attn.q_proj.weight",
           (None, "layer", "model", None),
       ),
       "base.decoder.layers.self_attention.key.kernel": (
-          "model.layers.*.self_attn.k_proj.kernel",
+          "model.layers.*.self_attn.k_proj.weight",
           (None, "layer", "model", None),
       ),
       "base.decoder.layers.self_attention.value.kernel": (
-          "model.layers.*.self_attn.v_proj.kernel",
+          "model.layers.*.self_attn.v_proj.weight",
           (None, "layer", "model", None),
       ),
       "base.decoder.layers.self_attention.out.kernel": (
-          "model.layers.*.self_attn.o_proj.kernel",
+          "model.layers.*.self_attn.o_proj.weight",
           ("model", "layer", None, None),
       ),
       # Attention biases
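
For readers unfamiliar with this mapping format, the sketch below (entirely hypothetical, not part of the patch or the MaxText/tunix API) illustrates how an entry might be consumed: the "*" wildcard in the HF key, together with the "layer" axis in the sharding tuple, marks a scanned (stacked-per-layer) MaxText parameter that gets split into one per-layer HuggingFace weight. All helper names, shapes, and the split logic are assumptions for illustration.

# Hypothetical sketch: expanding one scanned mapping entry into per-layer
# HF-named weights. Helper names do not come from the patch.
import numpy as np

def expand_mapping_entry(hf_key_template, sharding, stacked_param):
  """Split a scanned MaxText parameter into per-layer HF-named weights."""
  if "*" not in hf_key_template:
    # Unscanned parameter: single tensor, single HF name.
    return {hf_key_template: stacked_param}
  layer_axis = sharding.index("layer")  # axis along which layers are stacked
  per_layer = np.split(stacked_param, stacked_param.shape[layer_axis], axis=layer_axis)
  return {
      hf_key_template.replace("*", str(i)): np.squeeze(w, axis=layer_axis)
      for i, w in enumerate(per_layer)
  }

# Example: a scanned gate_proj kernel for a toy 2-layer model, with axes
# (embed, layer, mlp) matching the (None, "layer", "model") sharding tuple.
stacked = np.zeros((16, 2, 64))
expanded = expand_mapping_entry(
    "model.layers.*.mlp.gate_proj.weight",
    (None, "layer", "model"),
    stacked,
)
assert sorted(expanded) == [
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.1.mlp.gate_proj.weight",
]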