
Add Gemma 4 FLOPs & fix sliding window flops computations #3592

Open
gagika wants to merge 1 commit into main from agagik-gemma4-flops

Conversation

@gagika
Collaborator

@gagika gagika commented Apr 7, 2026

Description

Adds TFLOPs calculations for the Gemma 4 architecture (including MoE) and fixes several inaccuracies in existing FLOPs math (sliding window overlap, vision encoder scaling, and shared KV projections).

  • Gemma 4 & MoE: Added Gemma 4 support. Fixed fallback MoE calculations to correctly use moe_mlp_dim and generalized MoE layer detection (num_experts > 1).
  • Sliding Window: Corrected local causal FLOPs to account for triangular overlap (max_target_length * window - 0.5 * window**2).
  • Vision Encoders: Fixed backward pass scaling (x3) for Gemma 3 and Llama 4 when parameters are unfrozen.
  • KV Projections: Factored in share_kv_projections for accurate QKV FLOPs.
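The sliding-window correction can be sanity-checked against a brute-force count of attended (query, key) pairs. Below is a minimal sketch, not code from the PR; the function names are hypothetical, and the exact count assumes each query attends to itself plus the previous window - 1 positions:

```python
def attended_pairs_exact(seq_len: int, window: int) -> int:
    """Exact number of (query, key) pairs in causal sliding-window attention.
    Rectangle seq_len * w minus the triangular 'missing' region near position 0."""
    w = min(window, seq_len)
    return seq_len * w - w * (w - 1) // 2

def attended_pairs_brute(seq_len: int, window: int) -> int:
    """Brute force: query i attends to keys max(0, i - window + 1) .. i."""
    return sum(min(window, i + 1) for i in range(seq_len))

def attended_pairs_pr(seq_len: int, window: int) -> float:
    """The PR's closed form: max_target_length * window - 0.5 * window**2."""
    w = min(window, seq_len)
    return seq_len * w - 0.5 * w**2
```

Under this attention convention the PR's formula differs from the exact count only by window / 2 pairs, which is negligible at training sequence lengths.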

Tests

Added maxtext_utils_flops_test.py to validate FLOPs calculations across 12 model architectures.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions

github-actions bot commented Apr 7, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

github-actions bot commented Apr 7, 2026

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

@codecov

codecov bot commented Apr 7, 2026

Codecov Report

❌ Patch coverage is 95.45455% with 2 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/maxtext/utils/maxtext_utils.py | 95.45% | 2 Missing ⚠️


@gagika gagika force-pushed the agagik-gemma4-flops branch from e46459c to 038d9e9 on April 7, 2026 21:12
@gagika gagika changed the title Gemma4 TFLOPs calculations and fix causal attention flops for sliding window attention Add Gemma 4 FLOPs & fix sliding window flops computations Apr 7, 2026
@gagika gagika force-pushed the agagik-gemma4-flops branch from 038d9e9 to ccbcf03 on April 7, 2026 21:29
@github-actions

github-actions bot commented Apr 7, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


## 📋 Review Summary

This Pull Request introduces Gemma 4 FLOPs calculations and significantly improves the accuracy of existing FLOPs math, particularly for sliding window attention and mixed attention architectures. The addition of a comprehensive test suite covering multiple model families is a major highlight and ensures the reliability of these critical metrics.

🔍 General Feedback

  • Great Test Coverage: The new maxtext_utils_flops_test.py is excellent. It uses a robust 6 * params * tokens verification strategy that provides high confidence in the computed TFLOPs across various architectures.
  • Improved Accuracy: The fixes for sliding window area and vision encoder scaling (backward pass) are well-timed and correct.
  • Inconsistency in Shared KV Projections: There is a potential logic error in how share_kv_projections is applied to mixed attention models in the main caller. One unit test specifically assumes local layers do not share KV projections even when the flag is True, but the code currently applies it to both.
  • MoE Fallback Logic: The fallback for MoE layer detection is now more generalized, which is good, but might be too broad for future hybrid architectures.
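The 6 * params * tokens verification strategy praised above can be sketched roughly as follows (a hypothetical helper, not the PR's actual test code):

```python
def expected_training_flops(num_params: int, num_tokens: int) -> int:
    """Rule of thumb: 2 FLOPs per multiply-add, times 3 for the forward pass
    plus the two backward-pass matmuls, per parameter per token."""
    return 6 * num_params * num_tokens

def flops_close_to_rule_of_thumb(computed: float, num_params: int,
                                 num_tokens: int, rel_tol: float = 0.15) -> bool:
    """Attention FLOPs are not proportional to parameter count, so a computed
    value is only expected to land near 6*N*D, not exactly on it."""
    expected = expected_training_flops(num_params, num_tokens)
    return abs(computed - expected) / expected <= rel_tol
```

The loose tolerance is the point: the rule of thumb catches order-of-magnitude bugs while still passing on architectures whose attention term shifts the exact total.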

```python
config.decoder_block = maxtext_utils.DecoderBlockType.DEEPSEEK
config.per_device_batch_size = 1
config.max_target_length = 2048
config.emb_dim = 1024
```


🟢 Minor typo: "mamtul" -> "matmul".

Suggested change:

```diff
 config.emb_dim = 1024
-# dense ffn mamtul (silu: 2 * mlp_dim)
+# dense ffn matmul (silu: 2 * mlp_dim)
```

```python
    total_ffn_flops if is_ffn_flops_already_total else total_ffn_flops * config.num_decoder_layers
)

# Attention flops
```


🟠 The `kv_multiplier` is calculated here based on `config.share_kv_projections` and applied to `qkv_flops`. However, `qkv_flops` is used as the base for local attention layers in models like Gemma 3/4 (via `calculate_mixed_attention_model_tflops_training_per_device`).

According to the field description in types.py, share_kv_projections only applies to global attention. Applying it to qkv_flops here will incorrectly reduce the FLOPs estimate for local attention layers if share_kv_projections is True.

Consider calculating qkv_flops for local layers without this multiplier, or pass a separate multiplier for global layers.

Suggested change:

```python
# Attention flops
qkv_flops = (
    2
    * config.per_device_batch_size
    * config.max_target_length
    * config.emb_dim
    * (config.num_query_heads + 2 * config.num_kv_heads)
    * config.head_dim
)
```
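The distinction the reviewer raises can be made concrete by computing projection FLOPs per layer with and without shared K/V (a sketch with a hypothetical function name, mirroring the config fields in the snippet above):

```python
def qkv_proj_flops(batch: int, seq_len: int, emb_dim: int, num_q_heads: int,
                   num_kv_heads: int, head_dim: int,
                   share_kv_projections: bool = False) -> int:
    """Forward-pass matmul FLOPs for the Q/K/V projections of one layer.
    With shared K/V projections a single matmul serves both K and V, so the
    KV term is counted once instead of twice."""
    kv_multiplier = 1 if share_kv_projections else 2
    return (2 * batch * seq_len * emb_dim
            * (num_q_heads + kv_multiplier * num_kv_heads) * head_dim)
```

For mixed attention models, the reviewer's point is that local layers would call this with share_kv_projections=False even when the global layers pass True, so a single shared multiplier over-discounts the local layers.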

```python
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
    num_moe_layers = config.num_decoder_layers
    num_dense_layers = 0
else:
```


🟡 The fallback here assumes that if `num_experts > 1`, all layers are MoE layers. While this might be true for current Gemma 4 configs, it's a broad assumption that could lead to inaccuracies if future models interleave dense and MoE layers but aren't explicitly handled in the `if/elif` blocks above.

It might be safer to check for specific model families or rely on a more explicit config flag for layer interleaving if available.
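The fallback the comment describes amounts to a layer-count split along these lines (a simplified sketch, not the PR's code; the name is illustrative):

```python
def count_layer_types(num_decoder_layers: int, num_experts: int) -> tuple[int, int]:
    """Fallback split: treat every decoder layer as MoE when num_experts > 1,
    otherwise every layer as dense. Returns (num_dense_layers, num_moe_layers)."""
    if num_experts > 1:
        return 0, num_decoder_layers
    return num_decoder_layers, 0
```

A future hybrid model that interleaves dense and MoE layers mid-stack but is not matched by the explicit if/elif branches would be miscounted by this split, which is exactly the risk the comment flags.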

```python
config.num_experts = 4
config.mlp_dim = 2048
config.moe_mlp_dim = 1024
config.shared_experts = 1
```


🟢 Minor typo: "mamtul" -> "matmul".

Suggested change:

```diff
 config.shared_experts = 1
-# moe ffn mamtul
+# moe ffn matmul
```

@github-actions

github-actions bot commented Apr 7, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


## 📋 Review Summary

This pull request significantly improves the accuracy of FLOPs and MFU (Model Flops Utilization) calculations across multiple architectures, with a focus on Gemma 4 and corrected sliding window logic. The implementation is thorough, including a new comprehensive test suite that validates calculations for 12 different model configurations.

🔍 General Feedback

  • Accuracy Improvements: The switch to a precise triangular overlap formula for sliding window attention and the inclusion of backward pass FLOPs for vision encoders are excellent updates that prevent MFU over-estimation.
  • Architectural Coverage: The addition of Gemma 4 specific logic and the generalization of MoE layer detection make the utilities much more robust for future model support.
  • Testing: The new maxtext_utils_flops_test.py is a great addition, providing clear manual-calculation-based verification for various architectures.
  • Suggestions: I've provided a few suggestions to further generalize the MoE layer detection and ensure consistent dimension usage in MoE FFN calculations.

```python
    learnable_weight_flops += 2 * vision_embedder_flops  # only projector is learnable, add fwd+optimizer
else:
    learnable_weight_flops *= 3  # multiply by 3 for fwd + bwd + optimizer
    total_attn_flops *= 3  # multiply by 3 for fwd + bwd pass
```


🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfreezed. This ensures that the total TFLOPs and MFU are not over-estimated during full fine-tuning.
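The frozen-vs-trainable accounting the comment refers to reduces to the following (an illustrative sketch; the real code tracks attention and weight FLOPs separately):

```python
def encoder_training_flops(fwd_flops: float, frozen: bool) -> float:
    """Training FLOPs for a vision encoder given its forward-pass FLOPs.
    Frozen: only the forward pass runs, since no gradients flow through it.
    Trainable: the backward pass costs roughly 2x the forward, giving 3x total."""
    return fwd_flops if frozen else 3 * fwd_flops
```

Counting only 1x for an unfrozen encoder is what previously inflated MFU, since the denominator (theoretical FLOPs) was too small relative to the work actually done.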

```diff
 * config.per_device_batch_size
 * config.max_target_length
-* min(config.sliding_window_size, config.max_target_length)
+* (config.max_target_length * window - 0.5 * window**2)
```


🟢 Good catch on the causal sliding window FLOPs calculation. The new formula (config.max_target_length * window - 0.5 * window**2) is more accurate than the previous approximation.

```python
    num_dense_layers = 0
else:
    raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.")
if config.num_experts > 1:
```


🟡 Medium - This fallback logic for MoE layer detection can be made more general by respecting config.first_num_dense_layers for any MoE architecture that doesn't match the specific interleaved logic of Llama 4.

Suggested change:

```python
if config.num_experts > 1:
    num_dense_layers = config.first_num_dense_layers
    num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers
else:
```

```diff
 gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
 total_ffn_flops = (
-    gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * config.num_experts_per_tok
+    gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
```


🟠 High - Correcting this to use config.moe_mlp_dim is essential for accurate FLOPs in MoE models, as MoE layers typically use a different intermediate dimension than dense layers.

Suggested change:

```python
    gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
```
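The corrected expression can be sanity-checked in isolation. The sketch below mirrors the quoted lines under the assumption of a silu-gated FFN costing three matmuls; `ffn_matmul_flops` is a stand-in for the PR's `calculate_ffn_mamtul_tflops_per_device`:

```python
def ffn_matmul_flops(batch: int, seq_len: int, emb_dim: int, mlp_dim: int) -> int:
    """Forward FLOPs of a gated FFN block: gate and up projections
    (emb_dim -> mlp_dim each) plus the down projection (mlp_dim -> emb_dim)."""
    return 2 * batch * seq_len * emb_dim * mlp_dim * 3

def moe_ffn_flops(batch: int, seq_len: int, emb_dim: int, moe_mlp_dim: int,
                  num_experts: int, num_experts_per_tok: int) -> int:
    """Router gate plus the experts actually activated per token, each sized
    with the MoE intermediate dim (moe_mlp_dim), not the dense layers' mlp_dim."""
    gate_flops = 2 * batch * seq_len * emb_dim * num_experts
    return gate_flops + ffn_matmul_flops(batch, seq_len, emb_dim, moe_mlp_dim) * num_experts_per_tok
```

Using mlp_dim here, as the old code did, silently over- or under-counts whenever the two intermediate dimensions differ.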

```python
    learnable_weight_flops += 2 * projector_flops  # only projector is learnable, add fwd+optimizer
else:
    learnable_weight_flops *= 3  # multiply by 3 for fwd + bwd + optimizer
    total_attn_flops *= 3  # multiply by 3 for fwd + bwd pass
```


🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfreezed. This ensures that the total TFLOPs and MFU are not over-estimated during full fine-tuning.

```diff
 * config.max_target_length
 * config.emb_dim
-* (config.num_query_heads + 2 * config.num_kv_heads)
+* (config.num_query_heads + kv_multiplier * config.num_kv_heads)
```


🟡 Factoring in share_kv_projections for QKV FLOPs ensures accuracy for models that share key and value projections.

Suggested change:

```python
* (config.num_query_heads + kv_multiplier * config.num_kv_heads)
```

```diff
 base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
 first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
-shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
+shared_experts: NonNegativeInt = Field(0, description="Number of shared experts.")
```


🟢 Allowing shared_experts to be 0 is necessary for models that do not use shared experts.

@gobbleturk
Collaborator

I love the test! I am not sure how we have gotten this far without testing our tflops calculation...

```diff
@@ -0,0 +1,516 @@
+# Copyright 2023–2026 Google LLC
```
Collaborator


In MaxText we have flop_calculation_test.py in https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/unit/flop_calculation_test.py. Seems replicated?
