Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ use_2d_fsdp_sharding: False
# deepseek moe
base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim.
first_num_dense_layers: 0 # number of initial dense layers in the model
shared_experts: 1
shared_experts: 0
routed_scaling_factor: 1.0 # scaling factor for routing scores
routed_score_func: "" # scoring function for routing
routed_bias: False # a flag if a learnable bias is added for routing
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ class DeepSeekMoE(BaseModel):

base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
shared_experts: NonNegativeInt = Field(0, description="Number of shared experts.")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Allowing shared_experts to be 0 is necessary for models that do not use shared experts.

routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
routed_bias: bool = Field(False, description="Whether to add a bias term for routing.")
Expand Down
193 changes: 153 additions & 40 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,18 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo
Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder
layer and we use sliding window attention in local_attention
"""
noncausal_attention_flops = (
# global attention
4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
+
# local attention
window = min(config.sliding_window_size, config.max_target_length)
global_causal_flops = (
2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
)
local_causal_flops = (
4
* config.per_device_batch_size
* config.max_target_length
* min(config.sliding_window_size, config.max_target_length)
* (config.max_target_length * window - 0.5 * window**2)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Good catch on the causal sliding window FLOPs calculation. The new formula (config.max_target_length * window - 0.5 * window**2) is more accurate than the previous approximation.

* config.num_query_heads
* config.head_dim
)
causal_attention_flops = noncausal_attention_flops / 2
causal_attention_flops = global_causal_flops + local_causal_flops
attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

# multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer
Expand All @@ -238,7 +237,7 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo


def calculate_mixed_attention_model_tflops_training_per_device(
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length
config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length
):
"""
Calculate training TFLOPs for models with a mixed attention pattern of local
Expand All @@ -249,34 +248,125 @@ def calculate_mixed_attention_model_tflops_training_per_device(
num_global_layers = num_layers // attention_pattern_length
num_local_layers = num_layers - num_global_layers

# FLOPs for a single global attention layer (full attention)
# Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim
global_attention_flops_per_layer = (
4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
# Global causal attention uses a multiplier of 2 (instead of 4 for non-causal)
# since we only compute the lower triangular half of the attention matrix.
global_causal_flops_per_layer = (
2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
)

# FLOPs for a single local attention layer (sliding window)
# Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim
local_attention_flops_per_layer = (
# Local sliding window attention directly computes the exact causal interactions
# via the formula `(T * W - 0.5 * W^2)`. Therefore, we use the base multiplier of 4.
window = min(config.sliding_window_size, config.max_target_length)
local_causal_flops_per_layer = (
4
* config.per_device_batch_size
* config.max_target_length
* min(config.sliding_window_size, config.max_target_length)
* (config.max_target_length * window - 0.5 * window**2)
* config.num_query_heads
* config.head_dim
)

# Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local)
noncausal_attention_flops = (
num_global_layers * global_attention_flops_per_layer + num_local_layers * local_attention_flops_per_layer
causal_attention_flops = (
num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer
)

# Convert to TFLOPs and multiply by 3 for fwd/bwd pass
attention_tflops = causal_attention_flops * 3 / 10**12

total_learnable_flops = total_ffn_flops_all_layers

total_learnable_flops += (qkv_flops + projection_flops) * num_layers + embedding_flops

learnable_weight_tflops = total_learnable_flops * 3 / 10**12

return attention_tflops, learnable_weight_tflops


def calculate_gemma4_tflops_training_per_device(
    config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length
):
  """
  Calculate training TFLOPs for Gemma 4.

  Gemma 4 has specific quirks:
  - Different QKV projection sizes for local vs. global layers.
  - Global-only KV sharing and varying global head dimensions.

  Args:
    config: model config. This function reads `num_decoder_layers`,
      `per_device_batch_size`, `max_target_length`, `sliding_window_size`,
      `emb_dim`, `num_query_heads`, `num_kv_heads`, `head_dim`,
      `share_kv_projections`, `global_head_dim` and `global_num_kv_heads`.
    total_ffn_flops_all_layers: forward FFN FLOPs already summed over every
      decoder layer (it is added as-is, not multiplied by the layer count).
    embedding_flops: forward FLOPs attributed to the embedding/unembedding.
    attention_pattern_length: one layer out of every `attention_pattern_length`
      layers is counted as a global-attention layer; the rest are local.

  Returns:
    A tuple `(attention_tflops, learnable_weight_tflops)`, each already
    multiplied by 3 for the forward and backward passes and divided by 10**12.
  """
  num_layers = config.num_decoder_layers

  num_global_layers = num_layers // attention_pattern_length
  num_local_layers = num_layers - num_global_layers

  # `share_kv_projections` only affects the global layers below.
  kv_multiplier = 1 if config.share_kv_projections else 2
  # Falsy overrides (None or 0) fall back to the regular (local) settings.
  global_head_dim = config.global_head_dim or config.head_dim
  global_num_kv_heads = config.global_num_kv_heads or config.num_kv_heads

  # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal)
  # since we only compute the lower triangular half of the attention matrix.
  global_causal_flops_per_layer = (
      2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * global_head_dim
  )

  # Local sliding window attention counts the causal query/key interactions
  # directly via the closed form `(T * W - 0.5 * W^2)`, so the base multiplier
  # of 4 is used without a further halving. When W == T this reduces to the
  # global causal count above.
  window = min(config.sliding_window_size, config.max_target_length)
  local_causal_flops_per_layer = (
      4
      * config.per_device_batch_size
      * (config.max_target_length * window - 0.5 * window**2)
      * config.num_query_heads
      * config.head_dim
  )

  causal_attention_flops = (
      num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer
  )

  # Convert to TFLOPs and multiply by 3 for fwd/bwd pass
  attention_tflops = causal_attention_flops * 3 / 10**12

  global_qkv_flops_per_layer = (
      2
      * config.per_device_batch_size
      * config.max_target_length
      * config.emb_dim
      * (config.num_query_heads + kv_multiplier * global_num_kv_heads)
      * global_head_dim
  )
  global_projection_flops_per_layer = (
      2
      * config.per_device_batch_size
      * config.max_target_length
      * config.emb_dim
      * config.num_query_heads
      * global_head_dim
  )

  # Local layers never share KV projections (kv_multiplier is always 2).
  local_qkv_flops_per_layer = (
      2
      * config.per_device_batch_size
      * config.max_target_length
      * config.emb_dim
      * (config.num_query_heads + 2 * config.num_kv_heads)
      * config.head_dim
  )
  local_projection_flops_per_layer = (
      2
      * config.per_device_batch_size
      * config.max_target_length
      * config.emb_dim
      * config.num_query_heads
      * config.head_dim
  )

  # The FFN total already spans all layers; attention projections are added
  # per layer type because local and global layers have different shapes.
  total_learnable_flops = total_ffn_flops_all_layers

  total_learnable_flops += (
      (local_qkv_flops_per_layer + local_projection_flops_per_layer) * num_local_layers
      + (global_qkv_flops_per_layer + global_projection_flops_per_layer) * num_global_layers
      + embedding_flops
  )

  learnable_weight_tflops = total_learnable_flops * 3 / 10**12

  return attention_tflops, learnable_weight_tflops

Expand Down Expand Up @@ -493,11 +583,13 @@ def get_dense_moe_layers(config):
elif config.decoder_block == DecoderBlockType.LLAMA4:
num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step
num_dense_layers = config.num_decoder_layers - num_moe_layers
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
num_moe_layers = config.num_decoder_layers
num_dense_layers = 0
else:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 The fallback here assumes that if `num_experts > 1`, all layers are MoE layers. While this might be true for current Gemma 4 configs, it's a broad assumption that could lead to inaccuracies if future models interleave dense and MoE layers but aren't explicitly handled in the `if/elif` blocks above.

It might be safer to check for specific model families or rely on a more explicit config flag for layer interleaving if available.

raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.")
if config.num_experts > 1:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Medium - This fallback logic for MoE layer detection can be made more general by respecting config.first_num_dense_layers for any MoE architecture that doesn't match the specific interleaved logic of Llama 4.

Suggested change
if config.num_experts > 1:
if config.num_experts > 1:
num_dense_layers = config.first_num_dense_layers
num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers
else:

num_moe_layers = config.num_decoder_layers
num_dense_layers = 0
else:
num_moe_layers = 0
num_dense_layers = config.num_decoder_layers

return num_dense_layers, num_moe_layers

Expand Down Expand Up @@ -598,6 +690,7 @@ def calculate_gemma3_vision_layers_tflops_per_device(config):
learnable_weight_flops += 2 * vision_embedder_flops # only projector is learnable, add fwd+optimizer
else:
learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer
total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfrozen. This ensures that the total TFLOPs and MFU are not under-estimated during full fine-tuning.


# Convert to TFLOPs
learnable_weight_tflops = learnable_weight_flops / 1e12
Expand Down Expand Up @@ -660,6 +753,7 @@ def calculate_llama4_vision_layers_tflops_per_device(config):
learnable_weight_flops += 2 * projector_flops # only projector is learnable, add fwd+optimizer
else:
learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer
total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfrozen. This ensures that the total TFLOPs and MFU are not under-estimated during full fine-tuning.


# Convert to TFLOPs
learnable_weight_tflops = learnable_weight_flops / 1e12
Expand Down Expand Up @@ -723,28 +817,40 @@ def calculate_vision_encoder_tflops(config):
def calculate_tflops_training_per_device(config, log=True):
"""Calculate training TFLOP"""
# MLP flops
is_ffn_flops_already_total = False
if config.num_experts > 1:
# calculation based on dropless implementation
if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT):
if config.decoder_block in (
DecoderBlockType.DEEPSEEK,
DecoderBlockType.LLAMA4,
DecoderBlockType.QWEN3_NEXT,
DecoderBlockType.GEMMA4,
):
total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config)
is_ffn_flops_already_total = True
else:
gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
total_ffn_flops = (
gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * config.num_experts_per_tok
gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 High - Correcting this to use config.moe_mlp_dim is essential for accurate FLOPs in MoE models, as MoE layers typically use a different intermediate dimension than dense layers.

Suggested change
gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok

)
else:
total_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim)

total_ffn_flops_all_layers = (
total_ffn_flops if is_ffn_flops_already_total else total_ffn_flops * config.num_decoder_layers
)

# Attention flops
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 The `kv_multiplier` is calculated here based on `config.share_kv_projections` and applied to `qkv_flops`. However, `qkv_flops` is used as the base for local attention layers in models like Gemma 3/4 (via `calculate_mixed_attention_model_tflops_training_per_device`).

According to the field description in types.py, share_kv_projections only applies to global attention. Applying it to qkv_flops here will incorrectly reduce the FLOPs estimate for local attention layers if share_kv_projections is True.

Consider calculating qkv_flops for local layers without this multiplier, or pass a separate multiplier for global layers.

Suggested change
# Attention flops
qkv_flops = (
2
* config.per_device_batch_size
* config.max_target_length
* config.emb_dim
* (config.num_query_heads + 2 * config.num_kv_heads)
* config.head_dim
)

if config.attention_type == "mla":
qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
else:
kv_multiplier = 1 if config.share_kv_projections else 2
qkv_flops = (
2
* config.per_device_batch_size
* config.max_target_length
* config.emb_dim
* (config.num_query_heads + 2 * config.num_kv_heads)
* (config.num_query_heads + kv_multiplier * config.num_kv_heads)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Factoring in share_kv_projections for QKV FLOPs ensures accuracy for models that share key and value projections.

Suggested change
* (config.num_query_heads + kv_multiplier * config.num_kv_heads)
* (config.num_query_heads + kv_multiplier * config.num_kv_heads)

* config.head_dim
)
noncausal_attention_flops = (
Expand All @@ -765,7 +871,8 @@ def calculate_tflops_training_per_device(config, log=True):
# NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
causal_attention_flops = noncausal_attention_flops / 2

# Embedding flops
# Embedding flops (counts only the unembedding projection; the embedding lookup is a gather operation
# that performs no dense math, matching standard MFU hardware calculations)
embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size

# Combine flops with number of decoder layers
Expand All @@ -775,26 +882,30 @@ def calculate_tflops_training_per_device(config, log=True):
)
elif config.decoder_block == DecoderBlockType.GEMMA3:
attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6
config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6
)
elif config.decoder_block == DecoderBlockType.GPT_OSS:
attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2
config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2
)
elif config.decoder_block == DecoderBlockType.LLAMA4:
# Use the new helper to calculate attention TFLOPs correctly.
attention_tflops = calculate_llama4_attention_tflops(config)
# The learnable weight calculation remains the same as it correctly handles Llama4's MoE structure.
learnable_weight_tflops = (
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
(total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops)
* 3
/ 10**12
)
elif config.decoder_block == DecoderBlockType.GEMMA4:
attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device(
config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6
attention_tflops, learnable_weight_tflops = calculate_gemma4_tflops_training_per_device(
config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length=6
)
elif config.decoder_block == DecoderBlockType.DEEPSEEK:
learnable_weight_tflops = (
(total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
(total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops)
* 3
/ 10**12
)
attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
Expand All @@ -805,7 +916,7 @@ def calculate_tflops_training_per_device(config, log=True):

# Weights TFLOPs:
total_weights = (
total_ffn_flops
total_ffn_flops_all_layers
+ embedding_flops
+ (qkv_flops + projection_flops) * num_full_attn_layers
+ gdn_weight_flops_per_layer * num_linear_attn_layers
Expand All @@ -818,7 +929,9 @@ def calculate_tflops_training_per_device(config, log=True):
else:
# multiply by 3 for both feed forward and back propagation flops
learnable_weight_tflops = (
((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
(total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops)
* 3
/ 10**12
)
attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12

Expand Down
Loading
Loading