-
Notifications
You must be signed in to change notification settings - Fork 500
Add Gemma 4 FLOPs & fix sliding window flops computations #3592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -214,19 +214,18 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo | |||||||||||||||||||
| Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder | ||||||||||||||||||||
| layer and we use sliding window attention in local_attention | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| noncausal_attention_flops = ( | ||||||||||||||||||||
| # global attention | ||||||||||||||||||||
| 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim | ||||||||||||||||||||
| + | ||||||||||||||||||||
| # local attention | ||||||||||||||||||||
| window = min(config.sliding_window_size, config.max_target_length) | ||||||||||||||||||||
| global_causal_flops = ( | ||||||||||||||||||||
| 2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| local_causal_flops = ( | ||||||||||||||||||||
| 4 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * config.max_target_length | ||||||||||||||||||||
| * min(config.sliding_window_size, config.max_target_length) | ||||||||||||||||||||
| * (config.max_target_length * window - 0.5 * window**2) | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟢 Good catch on the causal sliding window FLOPs calculation. The new formula |
||||||||||||||||||||
| * config.num_query_heads | ||||||||||||||||||||
| * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| causal_attention_flops = noncausal_attention_flops / 2 | ||||||||||||||||||||
| causal_attention_flops = global_causal_flops + local_causal_flops | ||||||||||||||||||||
| attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer | ||||||||||||||||||||
|
|
@@ -238,7 +237,7 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo | |||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def calculate_mixed_attention_model_tflops_training_per_device( | ||||||||||||||||||||
| config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length | ||||||||||||||||||||
| config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Calculate training TFLOPs for models with a mixed attention pattern of local | ||||||||||||||||||||
|
|
@@ -249,34 +248,125 @@ def calculate_mixed_attention_model_tflops_training_per_device( | |||||||||||||||||||
| num_global_layers = num_layers // attention_pattern_length | ||||||||||||||||||||
| num_local_layers = num_layers - num_global_layers | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # FLOPs for a single global attention layer (full attention) | ||||||||||||||||||||
| # Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim | ||||||||||||||||||||
| global_attention_flops_per_layer = ( | ||||||||||||||||||||
| 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim | ||||||||||||||||||||
| # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal) | ||||||||||||||||||||
| # since we only compute the lower triangular half of the attention matrix. | ||||||||||||||||||||
| global_causal_flops_per_layer = ( | ||||||||||||||||||||
| 2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # FLOPs for a single local attention layer (sliding window) | ||||||||||||||||||||
| # Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim | ||||||||||||||||||||
| local_attention_flops_per_layer = ( | ||||||||||||||||||||
| # Local sliding window attention directly computes the exact causal interactions | ||||||||||||||||||||
| # via the formula `(T * W - 0.5 * W^2)`. Therefore, we use the base multiplier of 4. | ||||||||||||||||||||
| window = min(config.sliding_window_size, config.max_target_length) | ||||||||||||||||||||
| local_causal_flops_per_layer = ( | ||||||||||||||||||||
| 4 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * config.max_target_length | ||||||||||||||||||||
| * min(config.sliding_window_size, config.max_target_length) | ||||||||||||||||||||
| * (config.max_target_length * window - 0.5 * window**2) | ||||||||||||||||||||
| * config.num_query_heads | ||||||||||||||||||||
| * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local) | ||||||||||||||||||||
| noncausal_attention_flops = ( | ||||||||||||||||||||
| num_global_layers * global_attention_flops_per_layer + num_local_layers * local_attention_flops_per_layer | ||||||||||||||||||||
| causal_attention_flops = ( | ||||||||||||||||||||
| num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Convert to TFLOPs and multiply by 3 for fwd/bwd pass | ||||||||||||||||||||
| attention_tflops = causal_attention_flops * 3 / 10**12 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| total_learnable_flops = total_ffn_flops_all_layers | ||||||||||||||||||||
|
|
||||||||||||||||||||
| total_learnable_flops += (qkv_flops + projection_flops) * num_layers + embedding_flops | ||||||||||||||||||||
|
|
||||||||||||||||||||
| learnable_weight_tflops = total_learnable_flops * 3 / 10**12 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return attention_tflops, learnable_weight_tflops | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def calculate_gemma4_tflops_training_per_device( | ||||||||||||||||||||
| config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Calculate training TFLOPs for Gemma 4. | ||||||||||||||||||||
| Gemma 4 has specific quirks: | ||||||||||||||||||||
| - Different QKV projection sizes for local vs. global layers. | ||||||||||||||||||||
| - Global-only KV sharing and varying global head dimensions. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| num_layers = config.num_decoder_layers | ||||||||||||||||||||
|
|
||||||||||||||||||||
| num_global_layers = num_layers // attention_pattern_length | ||||||||||||||||||||
| num_local_layers = num_layers - num_global_layers | ||||||||||||||||||||
|
|
||||||||||||||||||||
| kv_multiplier = 1 if config.share_kv_projections else 2 | ||||||||||||||||||||
| global_head_dim = config.global_head_dim or config.head_dim | ||||||||||||||||||||
| global_num_kv_heads = config.global_num_kv_heads or config.num_kv_heads | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal) | ||||||||||||||||||||
| # since we only compute the lower triangular half of the attention matrix. | ||||||||||||||||||||
| global_causal_flops_per_layer = ( | ||||||||||||||||||||
| 2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * global_head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Local sliding window attention directly computes the exact causal interactions | ||||||||||||||||||||
| # via the formula `(T * W - 0.5 * W^2)`. Therefore, we use the base multiplier of 4. | ||||||||||||||||||||
| window = min(config.sliding_window_size, config.max_target_length) | ||||||||||||||||||||
| local_causal_flops_per_layer = ( | ||||||||||||||||||||
| 4 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * (config.max_target_length * window - 0.5 * window**2) | ||||||||||||||||||||
| * config.num_query_heads | ||||||||||||||||||||
| * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| causal_attention_flops = ( | ||||||||||||||||||||
| num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| causal_attention_flops = noncausal_attention_flops / 2 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Convert to TFLOPs and multiply by 3 for fwd/bwd pass | ||||||||||||||||||||
| attention_tflops = causal_attention_flops * 3 / 10**12 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Learnable weights (FFN, QKV, Projections) are present in every layer. | ||||||||||||||||||||
| learnable_weight_tflops = ((total_ffn_flops + qkv_flops + projection_flops) * num_layers + embedding_flops) * 3 / 10**12 | ||||||||||||||||||||
| global_qkv_flops_per_layer = ( | ||||||||||||||||||||
| 2 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * config.max_target_length | ||||||||||||||||||||
| * config.emb_dim | ||||||||||||||||||||
| * (config.num_query_heads + kv_multiplier * global_num_kv_heads) | ||||||||||||||||||||
| * global_head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| global_projection_flops_per_layer = ( | ||||||||||||||||||||
| 2 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * config.max_target_length | ||||||||||||||||||||
| * config.emb_dim | ||||||||||||||||||||
| * config.num_query_heads | ||||||||||||||||||||
| * global_head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Local layers never share KV projections (kv_multiplier is always 2). | ||||||||||||||||||||
| local_qkv_flops_per_layer = ( | ||||||||||||||||||||
| 2 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * config.max_target_length | ||||||||||||||||||||
| * config.emb_dim | ||||||||||||||||||||
| * (config.num_query_heads + 2 * config.num_kv_heads) | ||||||||||||||||||||
| * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| local_projection_flops_per_layer = ( | ||||||||||||||||||||
| 2 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * config.max_target_length | ||||||||||||||||||||
| * config.emb_dim | ||||||||||||||||||||
| * config.num_query_heads | ||||||||||||||||||||
| * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| total_learnable_flops = total_ffn_flops_all_layers | ||||||||||||||||||||
|
|
||||||||||||||||||||
| total_learnable_flops += ( | ||||||||||||||||||||
| (local_qkv_flops_per_layer + local_projection_flops_per_layer) * num_local_layers | ||||||||||||||||||||
| + (global_qkv_flops_per_layer + global_projection_flops_per_layer) * num_global_layers | ||||||||||||||||||||
| + embedding_flops | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| learnable_weight_tflops = total_learnable_flops * 3 / 10**12 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return attention_tflops, learnable_weight_tflops | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -493,11 +583,13 @@ def get_dense_moe_layers(config): | |||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.LLAMA4: | ||||||||||||||||||||
| num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step | ||||||||||||||||||||
| num_dense_layers = config.num_decoder_layers - num_moe_layers | ||||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: | ||||||||||||||||||||
| num_moe_layers = config.num_decoder_layers | ||||||||||||||||||||
| num_dense_layers = 0 | ||||||||||||||||||||
| else: | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟡 The fallback here assumes that if `num_experts > 1`, all layers are MoE layers. While this might be true for current Gemma 4 configs, it's a broad assumption that could lead to inaccuracies if future models interleave dense and MoE layers but aren't explicitly handled in the `if/elif` blocks above.
It might be safer to check for specific model families or rely on a more explicit config flag for layer interleaving if available. |
||||||||||||||||||||
| raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.") | ||||||||||||||||||||
| if config.num_experts > 1: | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Medium - This fallback logic for MoE layer detection can be made more general by respecting
Suggested change
|
||||||||||||||||||||
| num_moe_layers = config.num_decoder_layers | ||||||||||||||||||||
| num_dense_layers = 0 | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| num_moe_layers = 0 | ||||||||||||||||||||
| num_dense_layers = config.num_decoder_layers | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return num_dense_layers, num_moe_layers | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -598,6 +690,7 @@ def calculate_gemma3_vision_layers_tflops_per_device(config): | |||||||||||||||||||
| learnable_weight_flops += 2 * vision_embedder_flops # only projector is learnable, add fwd+optimizer | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer | ||||||||||||||||||||
| total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfreezed. This ensures that the total TFLOPs and MFU are not over-estimated during full fine-tuning. |
||||||||||||||||||||
|
|
||||||||||||||||||||
| # Convert to TFLOPs | ||||||||||||||||||||
| learnable_weight_tflops = learnable_weight_flops / 1e12 | ||||||||||||||||||||
|
|
@@ -660,6 +753,7 @@ def calculate_llama4_vision_layers_tflops_per_device(config): | |||||||||||||||||||
| learnable_weight_flops += 2 * projector_flops # only projector is learnable, add fwd+optimizer | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer | ||||||||||||||||||||
| total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟠 High - Including the attention FLOPs in the vision encoder backward pass is correct when the encoder is unfreezed. This ensures that the total TFLOPs and MFU are not over-estimated during full fine-tuning. |
||||||||||||||||||||
|
|
||||||||||||||||||||
| # Convert to TFLOPs | ||||||||||||||||||||
| learnable_weight_tflops = learnable_weight_flops / 1e12 | ||||||||||||||||||||
|
|
@@ -723,28 +817,40 @@ def calculate_vision_encoder_tflops(config): | |||||||||||||||||||
| def calculate_tflops_training_per_device(config, log=True): | ||||||||||||||||||||
| """Calculate training TFLOP""" | ||||||||||||||||||||
| # MLP flops | ||||||||||||||||||||
| is_ffn_flops_already_total = False | ||||||||||||||||||||
| if config.num_experts > 1: | ||||||||||||||||||||
| # calculation based on dropless implementation | ||||||||||||||||||||
| if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT): | ||||||||||||||||||||
| if config.decoder_block in ( | ||||||||||||||||||||
| DecoderBlockType.DEEPSEEK, | ||||||||||||||||||||
| DecoderBlockType.LLAMA4, | ||||||||||||||||||||
| DecoderBlockType.QWEN3_NEXT, | ||||||||||||||||||||
| DecoderBlockType.GEMMA4, | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config) | ||||||||||||||||||||
| is_ffn_flops_already_total = True | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts | ||||||||||||||||||||
| total_ffn_flops = ( | ||||||||||||||||||||
| gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * config.num_experts_per_tok | ||||||||||||||||||||
| gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟠 High - Correcting this to use
Suggested change
|
||||||||||||||||||||
| ) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| total_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| total_ffn_flops_all_layers = ( | ||||||||||||||||||||
| total_ffn_flops if is_ffn_flops_already_total else total_ffn_flops * config.num_decoder_layers | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Attention flops | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟠 The `kv_multiplier` is calculated here based on `config.share_kv_projections` and applied to `qkv_flops`. However, `qkv_flops` is used as the base for local attention layers in models like Gemma 3/4 (via `calculate_mixed_attention_model_tflops_training_per_device`).
According to the field description in Consider calculating
Suggested change
|
||||||||||||||||||||
| if config.attention_type == "mla": | ||||||||||||||||||||
| qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| kv_multiplier = 1 if config.share_kv_projections else 2 | ||||||||||||||||||||
| qkv_flops = ( | ||||||||||||||||||||
| 2 | ||||||||||||||||||||
| * config.per_device_batch_size | ||||||||||||||||||||
| * config.max_target_length | ||||||||||||||||||||
| * config.emb_dim | ||||||||||||||||||||
| * (config.num_query_heads + 2 * config.num_kv_heads) | ||||||||||||||||||||
| * (config.num_query_heads + kv_multiplier * config.num_kv_heads) | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Factoring in
Suggested change
|
||||||||||||||||||||
| * config.head_dim | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| noncausal_attention_flops = ( | ||||||||||||||||||||
|
|
@@ -765,7 +871,8 @@ def calculate_tflops_training_per_device(config, log=True): | |||||||||||||||||||
| # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272 | ||||||||||||||||||||
| causal_attention_flops = noncausal_attention_flops / 2 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Embedding flops | ||||||||||||||||||||
| # Embedding flops (counts only the unembedding projection; the embedding lookup is a gather operation | ||||||||||||||||||||
| # that performs no dense math, matching standard MFU hardware calculations) | ||||||||||||||||||||
| embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Combine flops with number of decoder layers | ||||||||||||||||||||
|
|
@@ -775,26 +882,30 @@ def calculate_tflops_training_per_device(config, log=True): | |||||||||||||||||||
| ) | ||||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.GEMMA3: | ||||||||||||||||||||
| attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device( | ||||||||||||||||||||
| config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6 | ||||||||||||||||||||
| config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6 | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.GPT_OSS: | ||||||||||||||||||||
| attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device( | ||||||||||||||||||||
| config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2 | ||||||||||||||||||||
| config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2 | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.LLAMA4: | ||||||||||||||||||||
| # Use the new helper to calculate attention TFLOPs correctly. | ||||||||||||||||||||
| attention_tflops = calculate_llama4_attention_tflops(config) | ||||||||||||||||||||
| # The learnable weight calculation remains the same as it correctly handles Llama4's MoE structure. | ||||||||||||||||||||
| learnable_weight_tflops = ( | ||||||||||||||||||||
| (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 | ||||||||||||||||||||
| (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) | ||||||||||||||||||||
| * 3 | ||||||||||||||||||||
| / 10**12 | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.GEMMA4: | ||||||||||||||||||||
| attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device( | ||||||||||||||||||||
| config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6 | ||||||||||||||||||||
| attention_tflops, learnable_weight_tflops = calculate_gemma4_tflops_training_per_device( | ||||||||||||||||||||
| config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length=6 | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.DEEPSEEK: | ||||||||||||||||||||
| learnable_weight_tflops = ( | ||||||||||||||||||||
| (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 | ||||||||||||||||||||
| (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) | ||||||||||||||||||||
| * 3 | ||||||||||||||||||||
| / 10**12 | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 | ||||||||||||||||||||
| elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: | ||||||||||||||||||||
|
|
@@ -805,7 +916,7 @@ def calculate_tflops_training_per_device(config, log=True): | |||||||||||||||||||
|
|
||||||||||||||||||||
| # Weights TFLOPs: | ||||||||||||||||||||
| total_weights = ( | ||||||||||||||||||||
| total_ffn_flops | ||||||||||||||||||||
| total_ffn_flops_all_layers | ||||||||||||||||||||
| + embedding_flops | ||||||||||||||||||||
| + (qkv_flops + projection_flops) * num_full_attn_layers | ||||||||||||||||||||
| + gdn_weight_flops_per_layer * num_linear_attn_layers | ||||||||||||||||||||
|
|
@@ -818,7 +929,9 @@ def calculate_tflops_training_per_device(config, log=True): | |||||||||||||||||||
| else: | ||||||||||||||||||||
| # multiply by 3 for both feed forward and back propagation flops | ||||||||||||||||||||
| learnable_weight_tflops = ( | ||||||||||||||||||||
| ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 | ||||||||||||||||||||
| (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) | ||||||||||||||||||||
| * 3 | ||||||||||||||||||||
| / 10**12 | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟢 Allowing
shared_expertsto be 0 is necessary for models that do not use shared experts.