Add Gemma 4 FLOPs & fix sliding window flops computations#3592
Conversation
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.
Force-pushed from e46459c to 038d9e9.
Force-pushed from 038d9e9 to ccbcf03.
This Pull Request introduces Gemma 4 FLOPs calculations and significantly improves the accuracy of existing FLOPs math, particularly for sliding window attention and mixed attention architectures. The addition of a comprehensive test suite covering multiple model families is a major highlight and ensures the reliability of these critical metrics.
🔍 General Feedback
- Great Test Coverage: The new `maxtext_utils_flops_test.py` is excellent. It uses a robust `6 * params * tokens` verification strategy that provides high confidence in the computed TFLOPs across various architectures.
- Improved Accuracy: The fixes for sliding window area and vision encoder scaling (backward pass) are well-timed and correct.
- Inconsistency in Shared KV Projections: There is a potential logic error in how `share_kv_projections` is applied to mixed attention models in the main caller. One unit test specifically assumes local layers do not share KV projections even when the flag is True, but the code currently applies it to both.
- MoE Fallback Logic: The fallback for MoE layer detection is now more generalized, which is good, but might be too broad for future hybrid architectures.
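The `6 * params * tokens` verification strategy mentioned above can be sketched as follows (our own illustration, not code from the PR; the function name and example numbers are hypothetical):

```python
def approx_train_flops(num_params: int, tokens_per_step: int) -> int:
    """Classic 6*N*D estimate: ~2*N*D for the forward pass and ~4*N*D
    for the backward pass per training step of a dense transformer."""
    return 6 * num_params * tokens_per_step

# Example: a 1B-parameter model processing 8 * 2048 tokens per step.
print(approx_train_flops(1_000_000_000, 8 * 2048) / 1e12)  # ~98.3 TFLOPs/step
```

A test can then compare the utility's computed per-step TFLOPs against this estimate within a tolerance, since attention FLOPs add a sequence-length-dependent term on top of 6ND.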
```python
config.decoder_block = maxtext_utils.DecoderBlockType.DEEPSEEK
config.per_device_batch_size = 1
config.max_target_length = 2048
config.emb_dim = 1024
```
```suggestion
config.emb_dim = 1024
# dense ffn matmul (silu: 2 * mlp_dim)
```
```python
    total_ffn_flops if is_ffn_flops_already_total else total_ffn_flops * config.num_decoder_layers
)

# Attention flops
```
According to the field description in types.py, `share_kv_projections` only applies to global attention. Applying it to `qkv_flops` here will incorrectly reduce the FLOPs estimate for local attention layers if `share_kv_projections` is True.
Consider calculating `qkv_flops` for local layers without this multiplier, or pass a separate multiplier for global layers.

```python
# Attention flops
qkv_flops = (
    2
    * config.per_device_batch_size
    * config.max_target_length
    * config.emb_dim
    * (config.num_query_heads + 2 * config.num_kv_heads)
    * config.head_dim
)
```
```python
elif config.decoder_block == DecoderBlockType.QWEN3_NEXT:
    num_moe_layers = config.num_decoder_layers
    num_dense_layers = 0
else:
```

It might be safer to check for specific model families or rely on a more explicit config flag for layer interleaving if available.
```python
config.num_experts = 4
config.mlp_dim = 2048
config.moe_mlp_dim = 1024
config.shared_experts = 1
```

```suggestion
config.shared_experts = 1
# moe ffn matmul
```
Force-pushed from ccbcf03 to ffd741b.
This pull request significantly improves the accuracy of FLOPs and MFU (Model Flops Utilization) calculations across multiple architectures, with a focus on Gemma 4 and corrected sliding window logic. The implementation is thorough, including a new comprehensive test suite that validates calculations for 12 different model configurations.
🔍 General Feedback
- Accuracy Improvements: The switch to a precise triangular overlap formula for sliding window attention and the inclusion of backward pass FLOPs for vision encoders are excellent updates that prevent MFU over-estimation.
- Architectural Coverage: The addition of Gemma 4 specific logic and the generalization of MoE layer detection make the utilities much more robust for future model support.
- Testing: The new `maxtext_utils_flops_test.py` is a great addition, providing clear manual-calculation-based verification for various architectures.
- Suggestions: I've provided a few suggestions to further generalize the MoE layer detection and ensure consistent dimension usage in MoE FFN calculations.
```python
    learnable_weight_flops += 2 * vision_embedder_flops  # only projector is learnable, add fwd+optimizer
else:
    learnable_weight_flops *= 3  # multiply by 3 for fwd + bwd + optimizer
    total_attn_flops *= 3  # multiply by 3 for fwd + bwd pass
```

🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfrozen. This ensures that the total TFLOPs and MFU are not under-estimated during full fine-tuning.
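The frozen-versus-unfrozen scaling discussed here can be sketched as a small helper (our own illustration; the function and parameter names are assumptions, not the PR's API):

```python
def scale_for_training(fwd_flops: float, projector_flops: float,
                       encoder_frozen: bool) -> float:
    """Scale vision-encoder forward FLOPs for a training step."""
    if encoder_frozen:
        # Encoder runs forward only; just the projector incurs backward +
        # optimizer work, hence the extra 2x on its forward cost.
        return fwd_flops + 2 * projector_flops
    # Full fine-tuning: every forward FLOP is roughly tripled (fwd + bwd).
    return 3 * fwd_flops
```

With `fwd_flops=10.0` and `projector_flops=2.0`, the frozen case yields 14.0 while full fine-tuning yields 30.0, which is why omitting the backward multiplier would understate TFLOPs.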
```diff
 * config.per_device_batch_size
 * config.max_target_length
-* min(config.sliding_window_size, config.max_target_length)
+* (config.max_target_length * window - 0.5 * window**2)
```

🟢 Good catch on the causal sliding window FLOPs calculation. The new formula `(config.max_target_length * window - 0.5 * window**2)` is more accurate than the previous approximation.
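As a quick numeric sanity check of the triangular-overlap idea (our own check, not from the PR): counting exact (query, key) pairs for causal attention with window `w` over length `T` agrees with the closed form `T*w - 0.5*w**2` up to an O(w) boundary term, whose exact value depends on whether the window is counted as including the current token.

```python
def exact_pairs(T: int, w: int) -> int:
    """Exact attended (query, key) pairs: query i sees keys max(0, i-w+1)..i."""
    return sum(min(i + 1, w) for i in range(T))

T, w = 2048, 512
approx = T * w - 0.5 * w**2
# Closed form matches the exact count up to a small boundary term (<= w).
print(exact_pairs(T, w), approx)
```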
```python
    num_dense_layers = 0
else:
    raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.")
if config.num_experts > 1:
```

🟡 Medium - This fallback logic for MoE layer detection can be made more general by respecting `config.first_num_dense_layers` for any MoE architecture that doesn't match the specific interleaved logic of Llama 4.

```suggestion
if config.num_experts > 1:
    num_dense_layers = config.first_num_dense_layers
    num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers
else:
```
```diff
 gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts
 total_ffn_flops = (
-    gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * config.num_experts_per_tok
+    gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok
```

🟠 High - Correcting this to use `config.moe_mlp_dim` is essential for accurate FLOPs in MoE models, as MoE layers typically use a different intermediate dimension than dense layers.
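The MoE FFN accounting discussed here can be sketched as follows (a hedged illustration under our own assumptions: a gated-MLP/SwiGLU-style expert with three matmuls of shape `emb_dim x mlp_dim`, and per-token top-k routing; function names are ours, not MaxText's):

```python
def ffn_matmul_flops(tokens: int, emb_dim: int, mlp_dim: int) -> int:
    # 2 FLOPs per multiply-accumulate, 3 matmuls (gate, up, down projections).
    return 2 * tokens * 3 * emb_dim * mlp_dim

def moe_ffn_flops(tokens, emb_dim, moe_mlp_dim, num_experts, experts_per_tok):
    gate = 2 * tokens * emb_dim * num_experts  # router matmul over all experts
    # Each token only executes its top-k experts, sized by moe_mlp_dim,
    # which is why using the dense mlp_dim here would misestimate FLOPs.
    return gate + experts_per_tok * ffn_matmul_flops(tokens, emb_dim, moe_mlp_dim)
```

For example, with 2048 tokens, `emb_dim=1024`, `moe_mlp_dim=1024`, 4 experts, and top-2 routing, the router contributes roughly 16.8 MFLOPs while the expert matmuls dominate at about 25.8 GFLOPs.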
```python
    learnable_weight_flops += 2 * projector_flops  # only projector is learnable, add fwd+optimizer
else:
    learnable_weight_flops *= 3  # multiply by 3 for fwd + bwd + optimizer
    total_attn_flops *= 3  # multiply by 3 for fwd + bwd pass
```

🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfrozen. This ensures that the total TFLOPs and MFU are not under-estimated during full fine-tuning.
```diff
 * config.max_target_length
 * config.emb_dim
-* (config.num_query_heads + 2 * config.num_kv_heads)
+* (config.num_query_heads + kv_multiplier * config.num_kv_heads)
```

🟡 Factoring in `share_kv_projections` for QKV FLOPs ensures accuracy for models that share key and value projections.
```diff
 base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).")
 first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.")
-shared_experts: PositiveInt = Field(1, description="Number of shared experts.")
+shared_experts: NonNegativeInt = Field(0, description="Number of shared experts.")
```

🟢 Allowing `shared_experts` to be 0 is necessary for models that do not use shared experts.
I love the test! I am not sure how we have gotten this far without testing our tflops calculation...
```diff
@@ -0,0 +1,516 @@
+# Copyright 2023–2026 Google LLC
```

In MaxText we have flop_calculation_test.py in https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/unit/flop_calculation_test.py. Seems replicated?
Description
Adds TFLOPs calculations for the Gemma 4 architecture (including MoE) and fixes several inaccuracies in existing FLOPs math (sliding window overlap, vision encoder scaling, and shared KV projections).
- MoE FFN now uses `moe_mlp_dim`, with generalized MoE layer detection (`num_experts > 1`).
- Corrected causal sliding window FLOPs with the triangular overlap formula (`max_target_length * window - 0.5 * window**2`).
- Factored in `share_kv_projections` for accurate QKV FLOPs.

Tests
Added `maxtext_utils_flops_test.py` to validate FLOPs calculations across 12 model architectures.

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.