diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 0b335608c4..c239c19178 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -252,7 +252,7 @@ use_2d_fsdp_sharding: False # deepseek moe base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer. For a fully MoE model, base_mlp_dim must be equal to base_moe_mlp_dim. first_num_dense_layers: 0 # number of initial dense layers in the model -shared_experts: 1 +shared_experts: 0 routed_scaling_factor: 1.0 # scaling factor for routing scores routed_score_func: "" # scoring function for routing routed_bias: False # a flag if a learnable bias is added for routing diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 315758015a..6349a5bc42 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -748,7 +748,7 @@ class DeepSeekMoE(BaseModel): base_moe_mlp_dim: int = Field(7168, description="Intermediate dimension at MoE layer (DeepSeek style).") first_num_dense_layers: NonNegativeInt = Field(0, description="Number of initial dense layers in the model.") - shared_experts: PositiveInt = Field(1, description="Number of shared experts.") + shared_experts: NonNegativeInt = Field(0, description="Number of shared experts.") routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.") routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').") routed_bias: bool = Field(False, description="Whether to add a bias term for routing.") diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index bdec9c1f10..daad142f51 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -217,19 +217,18 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo Calculate training TFLOP for Gemma2 as in Gemma2 we combine [local_attention, global_attention] into one decoder layer and 
we use sliding window attention in local_attention """ - noncausal_attention_flops = ( - # global attention - 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim - + - # local attention + window = min(config.sliding_window_size, config.max_target_length) + global_causal_flops = ( + 2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim + ) + local_causal_flops = ( 4 * config.per_device_batch_size - * config.max_target_length - * min(config.sliding_window_size, config.max_target_length) + * (config.max_target_length * window - 0.5 * window**2) * config.num_query_heads * config.head_dim ) - causal_attention_flops = noncausal_attention_flops / 2 + causal_attention_flops = global_causal_flops + local_causal_flops attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 # multiply num_decoder_layers by 2 because we combine [local_attention, global_attention] into one decoder layer @@ -241,7 +240,7 @@ def calculate_gemma2_tflops_training_per_device(config, total_ffn_flops, qkv_flo def calculate_mixed_attention_model_tflops_training_per_device( - config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length + config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length ): """ Calculate training TFLOPs for models with a mixed attention pattern of local @@ -252,34 +251,125 @@ def calculate_mixed_attention_model_tflops_training_per_device( num_global_layers = num_layers // attention_pattern_length num_local_layers = num_layers - num_global_layers - # FLOPs for a single global attention layer (full attention) - # Formula: 4 * batch_size * seq_len^2 * num_heads * head_dim - global_attention_flops_per_layer = ( - 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim + # Global causal attention uses a multiplier of 2 
(instead of 4 for non-causal) + # since we only compute the lower triangular half of the attention matrix. + global_causal_flops_per_layer = ( + 2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim ) - # FLOPs for a single local attention layer (sliding window) - # Formula: 4 * batch_size * seq_len * window_size * num_heads * head_dim - local_attention_flops_per_layer = ( + # Local sliding window attention directly computes the exact causal interactions + # via the formula `(T * W - 0.5 * W^2)`. Therefore, we use the base multiplier of 4. + window = min(config.sliding_window_size, config.max_target_length) + local_causal_flops_per_layer = ( 4 * config.per_device_batch_size - * config.max_target_length - * min(config.sliding_window_size, config.max_target_length) + * (config.max_target_length * window - 0.5 * window**2) * config.num_query_heads * config.head_dim ) - # Total attention FLOPs = (num_global_layers * FLOPs_per_global) + (num_local_layers * FLOPs_per_local) - noncausal_attention_flops = ( - num_global_layers * global_attention_flops_per_layer + num_local_layers * local_attention_flops_per_layer + causal_attention_flops = ( + num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer + ) + + # Convert to TFLOPs and multiply by 3 for fwd/bwd pass + attention_tflops = causal_attention_flops * 3 / 10**12 + + total_learnable_flops = total_ffn_flops_all_layers + + total_learnable_flops += (qkv_flops + projection_flops) * num_layers + embedding_flops + + learnable_weight_tflops = total_learnable_flops * 3 / 10**12 + + return attention_tflops, learnable_weight_tflops + + +def calculate_gemma4_tflops_training_per_device( + config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length +): + """ + Calculate training TFLOPs for Gemma 4. + Gemma 4 has specific quirks: + - Different QKV projection sizes for local vs. global layers. 
+ - Global-only KV sharing and varying global head dimensions. + """ + num_layers = config.num_decoder_layers + + num_global_layers = num_layers // attention_pattern_length + num_local_layers = num_layers - num_global_layers + + kv_multiplier = 1 if config.share_kv_projections else 2 + global_head_dim = config.global_head_dim or config.head_dim + global_num_kv_heads = config.global_num_kv_heads or config.num_kv_heads + + # Global causal attention uses a multiplier of 2 (instead of 4 for non-causal) + # since we only compute the lower triangular half of the attention matrix. + global_causal_flops_per_layer = ( + 2 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * global_head_dim + ) + + # Local sliding window attention directly computes the exact causal interactions + # via the formula `(T * W - 0.5 * W^2)`. Therefore, we use the base multiplier of 4. + window = min(config.sliding_window_size, config.max_target_length) + local_causal_flops_per_layer = ( + 4 + * config.per_device_batch_size + * (config.max_target_length * window - 0.5 * window**2) + * config.num_query_heads + * config.head_dim + ) + + causal_attention_flops = ( + num_global_layers * global_causal_flops_per_layer + num_local_layers * local_causal_flops_per_layer ) - causal_attention_flops = noncausal_attention_flops / 2 # Convert to TFLOPs and multiply by 3 for fwd/bwd pass attention_tflops = causal_attention_flops * 3 / 10**12 - # Learnable weights (FFN, QKV, Projections) are present in every layer. 
- learnable_weight_tflops = ((total_ffn_flops + qkv_flops + projection_flops) * num_layers + embedding_flops) * 3 / 10**12 + global_qkv_flops_per_layer = ( + 2 + * config.per_device_batch_size + * config.max_target_length + * config.emb_dim + * (config.num_query_heads + kv_multiplier * global_num_kv_heads) + * global_head_dim + ) + global_projection_flops_per_layer = ( + 2 + * config.per_device_batch_size + * config.max_target_length + * config.emb_dim + * config.num_query_heads + * global_head_dim + ) + + # Local layers never share KV projections (kv_multiplier is always 2). + local_qkv_flops_per_layer = ( + 2 + * config.per_device_batch_size + * config.max_target_length + * config.emb_dim + * (config.num_query_heads + 2 * config.num_kv_heads) + * config.head_dim + ) + local_projection_flops_per_layer = ( + 2 + * config.per_device_batch_size + * config.max_target_length + * config.emb_dim + * config.num_query_heads + * config.head_dim + ) + + total_learnable_flops = total_ffn_flops_all_layers + + total_learnable_flops += ( + (local_qkv_flops_per_layer + local_projection_flops_per_layer) * num_local_layers + + (global_qkv_flops_per_layer + global_projection_flops_per_layer) * num_global_layers + + embedding_flops + ) + + learnable_weight_tflops = total_learnable_flops * 3 / 10**12 return attention_tflops, learnable_weight_tflops @@ -496,11 +586,13 @@ def get_dense_moe_layers(config): elif config.decoder_block == DecoderBlockType.LLAMA4: num_moe_layers = config.num_decoder_layers // config.interleave_moe_layer_step num_dense_layers = config.num_decoder_layers - num_moe_layers - elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: - num_moe_layers = config.num_decoder_layers - num_dense_layers = 0 else: - raise ValueError("Currently we only support DeepSeek, Llama4, and Qwen3-Next calculation.") + if config.num_experts > 1: + num_moe_layers = config.num_decoder_layers + num_dense_layers = 0 + else: + num_moe_layers = 0 + num_dense_layers = 
config.num_decoder_layers return num_dense_layers, num_moe_layers @@ -601,6 +693,7 @@ def calculate_gemma3_vision_layers_tflops_per_device(config): learnable_weight_flops += 2 * vision_embedder_flops # only projector is learnable, add fwd+optimizer else: learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer + total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass # Convert to TFLOPs learnable_weight_tflops = learnable_weight_flops / 1e12 @@ -663,6 +756,7 @@ def calculate_llama4_vision_layers_tflops_per_device(config): learnable_weight_flops += 2 * projector_flops # only projector is learnable, add fwd+optimizer else: learnable_weight_flops *= 3 # multiply by 3 for fwd + bwd + optimizer + total_attn_flops *= 3 # multiply by 3 for fwd + bwd pass # Convert to TFLOPs learnable_weight_tflops = learnable_weight_flops / 1e12 @@ -726,28 +820,40 @@ def calculate_vision_encoder_tflops(config): def calculate_tflops_training_per_device(config, log=True): """Calculate training TFLOP""" # MLP flops + is_ffn_flops_already_total = False if config.num_experts > 1: # calculation based on dropless implementation - if config.decoder_block in (DecoderBlockType.DEEPSEEK, DecoderBlockType.LLAMA4, DecoderBlockType.QWEN3_NEXT): + if config.decoder_block in ( + DecoderBlockType.DEEPSEEK, + DecoderBlockType.LLAMA4, + DecoderBlockType.QWEN3_NEXT, + DecoderBlockType.GEMMA4, + ): total_ffn_flops = calculate_routed_and_shared_ffn_tflops_per_device(config) + is_ffn_flops_already_total = True else: gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts total_ffn_flops = ( - gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) * config.num_experts_per_tok + gate_flops + calculate_ffn_mamtul_tflops_per_device(config, config.moe_mlp_dim) * config.num_experts_per_tok ) else: total_ffn_flops = calculate_ffn_mamtul_tflops_per_device(config, config.mlp_dim) + total_ffn_flops_all_layers = ( + total_ffn_flops 
if is_ffn_flops_already_total else total_ffn_flops * config.num_decoder_layers + ) + # Attention flops if config.attention_type == "mla": qkv_flops, causal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config) else: + kv_multiplier = 1 if config.share_kv_projections else 2 qkv_flops = ( 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim - * (config.num_query_heads + 2 * config.num_kv_heads) + * (config.num_query_heads + kv_multiplier * config.num_kv_heads) * config.head_dim ) noncausal_attention_flops = ( @@ -768,7 +874,8 @@ def calculate_tflops_training_per_device(config, log=True): # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272 causal_attention_flops = noncausal_attention_flops / 2 - # Embedding flops + # Embedding flops (counts only the unembedding projection; the embedding lookup is a gather operation + # that performs no dense math, matching standard MFU hardware calculations) embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size # Combine flops with number of decoder layers @@ -778,26 +885,30 @@ def calculate_tflops_training_per_device(config, log=True): ) elif config.decoder_block == DecoderBlockType.GEMMA3: attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device( - config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6 + config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6 ) elif config.decoder_block == DecoderBlockType.GPT_OSS: attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device( - config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=2 + config, total_ffn_flops_all_layers, qkv_flops, projection_flops, embedding_flops, 
attention_pattern_length=2 ) elif config.decoder_block == DecoderBlockType.LLAMA4: # Use the new helper to calculate attention TFLOPs correctly. attention_tflops = calculate_llama4_attention_tflops(config) # The learnable weight calculation remains the same as it correctly handles Llama4's MoE structure. learnable_weight_tflops = ( - (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 + (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) + * 3 + / 10**12 ) elif config.decoder_block == DecoderBlockType.GEMMA4: - attention_tflops, learnable_weight_tflops = calculate_mixed_attention_model_tflops_training_per_device( - config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length=6 + attention_tflops, learnable_weight_tflops = calculate_gemma4_tflops_training_per_device( + config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length=6 ) elif config.decoder_block == DecoderBlockType.DEEPSEEK: learnable_weight_tflops = ( - (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 + (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) + * 3 + / 10**12 ) attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: @@ -808,7 +919,7 @@ def calculate_tflops_training_per_device(config, log=True): # Weights TFLOPs: total_weights = ( - total_ffn_flops + total_ffn_flops_all_layers + embedding_flops + (qkv_flops + projection_flops) * num_full_attn_layers + gdn_weight_flops_per_layer * num_linear_attn_layers @@ -821,7 +932,9 @@ def calculate_tflops_training_per_device(config, log=True): else: # multiply by 3 for both feed forward and back propagation flops learnable_weight_tflops = ( - ((total_ffn_flops + qkv_flops + projection_flops) * 
config.num_decoder_layers + embedding_flops) * 3 / 10**12 + (total_ffn_flops_all_layers + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) + * 3 + / 10**12 ) attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12 diff --git a/tests/unit/flop_calculation_test.py b/tests/unit/flop_calculation_test.py index 04a87460c0..391dd179dc 100644 --- a/tests/unit/flop_calculation_test.py +++ b/tests/unit/flop_calculation_test.py @@ -14,17 +14,47 @@ """ Tests for verifying FLOPs calculation in maxtext_utils.py""" +from typing import Any import unittest import pytest +from unittest.mock import MagicMock +from absl.testing import parameterized from maxtext.configs import pyconfig +from maxtext.utils import maxtext_utils from maxtext.utils.maxtext_utils import calculate_tflops_training_per_device from tests.utils.test_helpers import get_test_config_path -class FlopCalculation(unittest.TestCase): +@pytest.mark.cpu_only +class FlopCalculation(parameterized.TestCase): """Tests for verifying FLOP calculation in MaxText""" + def _get_model_config_args( + self, model_name: str, max_target_length: int | None = None, per_device_batch_size: int | None = None + ): + """Returns the config args for a given model name, target length and batch size.""" + config_args = [None, get_test_config_path(f"models/{model_name}.yml"), "run_name=test"] + if max_target_length is not None: + config_args.append(f"max_target_length={max_target_length}") + if per_device_batch_size is not None: + config_args.append(f"per_device_batch_size={per_device_batch_size}") + config_args.append("skip_jax_distributed_system=True") + return config_args + + def _initialize_model_config( + self, + model_name: str, + max_target_length: int | None = None, + per_device_batch_size: int | None = None, + **overrides: Any, + ): + """Initializes the model config.""" + config_args = self._get_model_config_args( + model_name, max_target_length=max_target_length, 
per_device_batch_size=per_device_batch_size + ) + return pyconfig.initialize(config_args, enable_checkpointing=False, **overrides) + def assertFlopsAlmostEqual(self, flops1, flops2, rel_tol=5e-2): """Assert that two FLOPs values are almost equal, within 5% relative tolerance.""" self.assertTrue( @@ -53,6 +83,7 @@ def compute_regular_attention_flops_per_device(self, kwargs: dict) -> float: return attention_flops / 1e12 # return tflops + # ========== Unit Tests for Direct FLOP Calculation Functions ========== def compute_deepseek_attention_flops_per_device(self, kwargs: dict) -> float: """ Computes the total training TFLOPs per device for a DeepSeek-style model. @@ -134,39 +165,14 @@ def compute_qwen3_next_attention_flops_per_device(self, kwargs: dict) -> float: return (full_attn_flops + linear_attn_flops) / 1e12 - @pytest.mark.cpu_only def test_qwen3_next_flops(self): """Test Qwen3-Next Flops calculation""" - kwargs = { - "model_name": "qwen3-next-80b-a3b", - "override_model_config": True, - "per_device_batch_size": 1, - "max_target_length": 4096, - "decoder_block": "qwen3_next", - "gradient_accumulation_steps": 1, - "skip_jax_distributed_system": True, - # Core Architectural Parameters - "base_emb_dim": 2048, - "base_num_decoder_layers": 48, - "base_num_query_heads": 16, - "base_num_kv_heads": 2, - "head_dim": 256, - "vocab_size": 151936, - # MoE Parameters - "base_mlp_dim": 512, # Note: maxtext_utils uses moe_mlp_dim for calculations - "base_moe_mlp_dim": 512, - "num_experts": 512, - "num_experts_per_tok": 10, - "mlp_activations": ["silu", "linear"], - # Qwen3-Next Specific Parameters - "inhomogeneous_layer_cycle_interval": 4, - "gdn_conv_kernel_dim": 4, - "gdn_key_head_dim": 128, - "gdn_value_head_dim": 128, - "gdn_num_key_heads": 16, - "gdn_num_value_heads": 32, - "gdn_chunk_size": 64, - } + cfg = self._initialize_model_config( + "qwen3-next-80b-a3b", + max_target_length=4096, + per_device_batch_size=1, + ) + kwargs = cfg.get_keys() # 1. 
Calculate Attention TFLOPs attention_tflops = self.compute_qwen3_next_attention_flops_per_device(kwargs) @@ -233,73 +239,38 @@ def test_qwen3_next_flops(self): golden_tflops = weight_tflops + attention_tflops # Run Calculation - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) - @pytest.mark.cpu_only def test_llama2_7b_flops(self): """Test Llama2 7b Flops calculation with default parameters""" - kwargs = { - # Model bases - "model_name": "llama2-7b", - "override_model_config": True, - # Core workload parameters - "per_device_batch_size": 12, - "max_target_length": 2048, - # Model dimensions - "base_emb_dim": 4096, - "base_mlp_dim": 11008, - "base_num_query_heads": 32, - "base_num_kv_heads": 32, - "base_num_decoder_layers": 32, - "head_dim": 128, - "vocab_size": 32_000, - "mlp_activations": ["silu", "linear"], - "skip_jax_distributed_system": True, - } - B = kwargs["per_device_batch_size"] - S = kwargs["max_target_length"] + cfg = self._initialize_model_config( + "llama2-7b", + max_target_length=2048, + per_device_batch_size=12, + ) + kwargs = cfg.get_keys() + B = cfg.per_device_batch_size + S = cfg.max_target_length attention_flops = self.compute_regular_attention_flops_per_device(kwargs) # Llama2-7b has ~6.74B parameters # https://adithyask.medium.com/from-7b-to-8b-parameters-understanding-weight-matrix-changes-in-llama-transformer-models-31ea7ed5fd88 golden_param_size = 6.74e9 golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) - @pytest.mark.cpu_only def test_llama3_8b_flops(self): """Test Llama3 8b Flops calculation with default parameters""" - kwargs = { - # Model bases - 
"model_name": "llama3-8b", - "override_model_config": True, - # Core workload parameters - "per_device_batch_size": 4, - "max_target_length": 2048, - "gradient_accumulation_steps": 1, - # Model dimensions - "base_emb_dim": 4096, - "base_mlp_dim": 14336, - "base_num_query_heads": 32, - "base_num_kv_heads": 8, - "base_num_decoder_layers": 32, - "head_dim": 128, - "vocab_size": 128256, - "mlp_activations": ["silu", "linear"], - "skip_jax_distributed_system": True, - } - B = kwargs["per_device_batch_size"] - S = kwargs["max_target_length"] + cfg = self._initialize_model_config( + "llama3-8b", + max_target_length=2048, + per_device_batch_size=4, + ) + kwargs = cfg.get_keys() + B = cfg.per_device_batch_size + S = cfg.max_target_length attention_flops = self.compute_regular_attention_flops_per_device(kwargs) # LLaMA3-8b has ~8.03B parameters # https://adithyask.medium.com/from-7b-to-8b-parameters-understanding-weight-matrix-changes-in-llama-transformer-models-31ea7ed5fd88 @@ -307,239 +278,99 @@ def test_llama3_8b_flops(self): # Here we consider TIED embedding table, which reduces param count to 7.50B golden_param_size = 7.50e9 golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) - @pytest.mark.cpu_only def test_mixtral_8x7b_flops(self): """Test Mixtral 8x7b Flops calculation""" - kwargs = { - # Model bases - "model_name": "mixtral-8x7b", - "override_model_config": True, - # Core workload parameters - "per_device_batch_size": 4, - "max_target_length": 8192, - "num_experts": 8, - "num_experts_per_tok": 2, - "gradient_accumulation_steps": 1, - # model dimensions - "base_emb_dim": 4096, - "base_mlp_dim": 14336, - "base_num_query_heads": 32, - "base_num_kv_heads": 8, - "head_dim": 128, - "base_num_decoder_layers": 32, - "vocab_size": 32000, - 
"mlp_activations": ["silu", "linear"], - "skip_jax_distributed_system": True, - } - B = kwargs["per_device_batch_size"] - S = kwargs["max_target_length"] + cfg = self._initialize_model_config( + "mixtral-8x7b", + max_target_length=8192, + per_device_batch_size=4, + ) + kwargs = cfg.get_keys() + B = cfg.per_device_batch_size + S = cfg.max_target_length attention_flops = self.compute_regular_attention_flops_per_device(kwargs) # mixtral-8x7b has ~12.9B active parameters # https://mistral.ai/news/mixtral-of-experts golden_param_size = 12.9e9 golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) - @pytest.mark.cpu_only def test_deepseek2_16b_flops(self): """Test DeepSeek2-16b FLops calculation""" - kwargs = { - # Model bases - "model_name": "deepseek2-16b", - "override_model_config": True, - # Core workload parameters - "per_device_batch_size": 4, - "max_target_length": 8192, - "num_experts": 64, - "num_experts_per_tok": 6, - "shared_experts": 2, - # Model dimensions - "base_emb_dim": 2048, - "base_num_query_heads": 16, - "base_num_kv_heads": 16, - "base_mlp_dim": 10944, - "base_moe_mlp_dim": 1408, - "base_num_decoder_layers": 27, - "first_num_dense_layers": 1, - "mlp_activations": ["silu", "linear"], - "vocab_size": 102400, - # MLA - "q_lora_rank": 0, - "kv_lora_rank": 512, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 128, - "skip_jax_distributed_system": True, - } - B = kwargs["per_device_batch_size"] - S = kwargs["max_target_length"] + cfg = self._initialize_model_config( + "deepseek2-16b", + max_target_length=8192, + per_device_batch_size=4, + ) + kwargs = cfg.get_keys() + B = cfg.per_device_batch_size + S = cfg.max_target_length attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs) # deepseek2-16b 
has ~2.4B active parameters # https://arxiv.org/pdf/2405.04434 golden_param_size = 2.4e9 golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) - @pytest.mark.cpu_only def test_gpt_oss_20b_flops(self): """Test GPT OSS 20B Flops calculation""" - kwargs = { - # Model bases - "model_name": "gpt-oss-20b", - "override_model_config": True, - # Core workload parameters - "per_device_batch_size": 4, - "max_target_length": 8192, - "sliding_window_size": 128, - "num_experts": 32, - "num_experts_per_tok": 4, - "gradient_accumulation_steps": 1, - # model dimensions - "base_emb_dim": 2880, - "base_mlp_dim": 2880, - "base_num_query_heads": 64, - "base_num_kv_heads": 8, - "head_dim": 64, - "base_num_decoder_layers": 24, - "vocab_size": 201088, - "mlp_activations": ["silu", "linear"], - "skip_jax_distributed_system": True, - } - B = kwargs["per_device_batch_size"] - S = kwargs["max_target_length"] + cfg = self._initialize_model_config( + "gpt-oss-20b", + max_target_length=8192, + per_device_batch_size=4, + ) + kwargs = cfg.get_keys() + B = cfg.per_device_batch_size + S = cfg.max_target_length attention_flops = self.compute_gpt_attention_flops_per_device(kwargs) # gpt-oss-20b has ~3.6B active parameters # https://openai.com/index/introducing-gpt-oss/ golden_param_size = 3.6e9 golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) - @pytest.mark.cpu_only def test_deepseek32_671b_flops(self): """Test DeepSeek3.2-671b FLops calculation""" - kwargs = { - # Model bases - "model_name": "deepseek3.2-671b", - "override_model_config": True, - # Core 
workload parameters - "per_device_batch_size": 4, - "max_target_length": 4096, - "num_experts": 256, - "num_experts_per_tok": 8, - "shared_experts": 1, - # Model dimensions - "base_emb_dim": 7168, - "base_num_query_heads": 128, - "base_num_kv_heads": 128, - "base_mlp_dim": 18432, - "base_moe_mlp_dim": 2048, - "base_num_decoder_layers": 61, - "first_num_dense_layers": 3, - "mlp_activations": ["silu", "linear"], - "vocab_size": 129280, - # MLA - "q_lora_rank": 1536, - "kv_lora_rank": 512, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 128, - "skip_jax_distributed_system": True, - # Indexer for DeepSeek Sparse Attention - "use_indexer": True, - "indexer_n_heads": 64, - "indexer_head_dim": 128, - "indexer_topk": 2048, - "attention": "flash", - "use_tokamax_splash": True, - } - B = kwargs["per_device_batch_size"] - S = kwargs["max_target_length"] + cfg = self._initialize_model_config( + "deepseek3.2-671b", + max_target_length=4096, + per_device_batch_size=4, + attention="dot_product", + ) + kwargs = cfg.get_keys() + B = cfg.per_device_batch_size + S = cfg.max_target_length attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs) # deepseek3-671b has ~37B active parameters # https://arxiv.org/pdf/2412.19437 golden_param_size = 37e9 golden_tflops = 6 * B * S * golden_param_size / 1e12 + attention_flops - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = calculate_tflops_training_per_device(cfg) self.assertFlopsAlmostEqual(calculated_tflops, golden_tflops) - @pytest.mark.cpu_only def test_custom_engram_flops(self): - """Test model with Engram FLops calculation""" - kwargs = { - # Model bases - "model_name": "deepseek2-16b", - "override_model_config": True, - # Core workload parameters - "per_device_batch_size": 4, - "max_target_length": 8192, - "num_experts": 64, - "num_experts_per_tok": 6, - "shared_experts": 2, - # Model dimensions - "base_emb_dim": 2048, - 
"base_num_query_heads": 16, - "base_num_kv_heads": 16, - "base_mlp_dim": 10944, - "base_moe_mlp_dim": 1408, - "base_num_decoder_layers": 27, - "first_num_dense_layers": 1, - "mlp_activations": ["silu", "linear"], - "vocab_size": 102400, - # MLA - "q_lora_rank": 0, - "kv_lora_rank": 512, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 128, - "skip_jax_distributed_system": True, - # Engram - "mhc_expansion_rate": 1, - "engram_layers": [2, 15], - "engram_num_heads": 8, - "engram_head_dim": 1280, - "engram_kernel_size": 4, - "engram_max_ngram_size": 3, - "engram_vocab_bases": [226240, 226240], - "tokenizer_type": "huggingface", - "tokenizer_path": "deepseek-ai/DeepSeek-V3.2", - "hf_access_token": "fake", - "scan_layers": False, - } - B = kwargs["per_device_batch_size"] - S = kwargs["max_target_length"] - G = kwargs["mhc_expansion_rate"] - D = kwargs["base_emb_dim"] - K = kwargs["engram_kernel_size"] - H = kwargs["engram_num_heads"] - H_D = kwargs["engram_head_dim"] - L = len(kwargs["engram_layers"]) - N = kwargs["engram_max_ngram_size"] + """Test model with Engram Flops calculation""" + cfg = self._initialize_model_config( + "deepseek2-16b", + max_target_length=8192, + per_device_batch_size=4, + ) + kwargs = cfg.get_keys() + B = cfg.per_device_batch_size + S = cfg.max_target_length + G = cfg.mhc_expansion_rate + D = cfg.base_emb_dim + K = cfg.engram_kernel_size + H = cfg.engram_num_heads + H_D = cfg.engram_head_dim + L = len(cfg.engram_layers) + N = cfg.engram_max_ngram_size attention_flops = self.compute_deepseek_attention_flops_per_device(kwargs) # deepseek2-16b has ~2.4B active parameters @@ -554,9 +385,498 @@ def test_custom_engram_flops(self): engram_active_params = L * (key_params + value_params + conv_params) golden_tflops = 6 * B * S * (golden_param_size + engram_active_params) / 1e12 + attention_flops - cfg = pyconfig.initialize( - [None, get_test_config_path()], - **kwargs, - ) calculated_tflops, _, _ = 
# NOTE(review): this span was extracted from a whitespace-mangled unified diff. The first
# statements of the span (the `calculate_tflops_training_per_device(cfg)` call and its
# `assertFlopsAlmostEqual` check) are the tail of `test_custom_engram_flops`, whose
# definition begins outside this chunk; they are intentionally not reproduced here.

def test_calculate_gemma2_tflops_training_per_device(self):
  """Test calculate_gemma2_tflops_training_per_device.

  Gemma2 combines one local (sliding-window) and one global attention pass per decoder
  layer, hence the `* 2` on learnable-weight FLOPs and the summed local+global causal term.
  """
  config = MagicMock()
  config.per_device_batch_size = 2
  config.max_target_length = 8192
  config.sliding_window_size = 4096
  config.num_query_heads = 8
  config.head_dim = 128
  config.num_decoder_layers = 10
  config.share_kv_projections = False
  config.global_head_dim = None
  config.global_num_kv_heads = None

  total_ffn_flops = 100
  qkv_flops = 200
  projection_flops = 150
  embedding_flops = 50

  attention_tflops, learnable_weight_tflops = maxtext_utils.calculate_gemma2_tflops_training_per_device(
      config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops
  )

  B = config.per_device_batch_size
  T = config.max_target_length
  W = config.sliding_window_size
  H = config.num_query_heads
  D = config.head_dim

  # Global attention is full causal; local attention only covers the sliding window.
  expected_global = 2 * B * T * T * H * D
  expected_local = 4 * B * (T * W - 0.5 * W * W) * H * D
  expected_causal = expected_global + expected_local

  # Factor 3 accounts for forward + backward (2x) passes.
  expected_attention_tflops = expected_causal * config.num_decoder_layers * 3 / 10**12

  self.assertAlmostEqual(attention_tflops, expected_attention_tflops, places=5)

  expected_learnable = (
      total_ffn_flops + qkv_flops + projection_flops
  ) * config.num_decoder_layers * 2 + embedding_flops
  expected_learnable_tflops = expected_learnable * 3 / 10**12
  self.assertAlmostEqual(learnable_weight_tflops, expected_learnable_tflops, places=5)

def test_calculate_mixed_attention_model_tflops_training_per_device(self):
  """Test calculate_mixed_attention_model_tflops_training_per_device.

  One global (full causal) layer every `attention_pattern_length` layers; the rest are
  local sliding-window layers.
  """
  config = MagicMock()
  config.per_device_batch_size = 2
  config.max_target_length = 8192
  config.sliding_window_size = 4096
  config.num_query_heads = 8
  config.head_dim = 128
  config.num_decoder_layers = 10
  config.share_kv_projections = False
  config.global_head_dim = None
  config.global_num_kv_heads = None

  config.num_kv_heads = 4
  config.emb_dim = 512

  total_ffn_flops = 100
  embedding_flops = 50
  attention_pattern_length = 5

  qkv_flops = (
      2
      * config.per_device_batch_size
      * config.max_target_length
      * config.emb_dim
      * (config.num_query_heads + 2 * config.num_kv_heads)
      * config.head_dim
  )
  projection_flops = (
      2
      * config.per_device_batch_size
      * config.max_target_length
      * config.emb_dim
      * config.num_query_heads
      * config.head_dim
  )

  attention_tflops, learnable_weight_tflops = maxtext_utils.calculate_mixed_attention_model_tflops_training_per_device(
      config, total_ffn_flops, qkv_flops, projection_flops, embedding_flops, attention_pattern_length
  )

  B = config.per_device_batch_size
  T = config.max_target_length
  W = config.sliding_window_size
  H = config.num_query_heads
  D = config.head_dim

  num_global = 10 // 5
  num_local = 10 - num_global

  expected_global_per_layer = 2 * B * T * T * H * D
  expected_local_per_layer = 4 * B * (T * W - 0.5 * W * W) * H * D
  expected_causal = num_global * expected_global_per_layer + num_local * expected_local_per_layer

  expected_attention_tflops = expected_causal * 3 / 10**12

  self.assertAlmostEqual(attention_tflops, expected_attention_tflops, places=5)

  expected_learnable = (
      total_ffn_flops * config.num_decoder_layers
      + (qkv_flops + projection_flops) * config.num_decoder_layers
      + embedding_flops
  )
  expected_learnable_tflops = expected_learnable * 3 / 10**12
  self.assertAlmostEqual(learnable_weight_tflops, expected_learnable_tflops, places=5)

def test_calculate_gemma4_tflops_training_per_device(self):
  """Test calculate_gemma4_tflops_training_per_device.

  Gemma4 uses distinct head dims / KV-head counts for global vs local layers; global
  layers may share KV projections, local layers never do.
  """
  config = MagicMock()
  config.per_device_batch_size = 2
  config.max_target_length = 8192
  config.sliding_window_size = 4096
  config.num_query_heads = 16
  config.num_kv_heads = 8
  config.head_dim = 256
  config.num_decoder_layers = 12
  config.share_kv_projections = True
  config.global_head_dim = 512
  config.global_num_kv_heads = 2
  config.emb_dim = 2816

  total_ffn_flops = 0
  embedding_flops = 0

  attention_tflops, learnable_weight_tflops = maxtext_utils.calculate_gemma4_tflops_training_per_device(
      config, total_ffn_flops, embedding_flops, attention_pattern_length=6
  )

  B = config.per_device_batch_size
  T = config.max_target_length
  W = config.sliding_window_size
  H = config.num_query_heads
  D = config.head_dim
  GD = config.global_head_dim

  num_global = 12 // 6
  num_local = 12 - num_global

  expected_global_per_layer = 2 * B * T * T * H * GD
  expected_local_per_layer = 4 * B * (T * W - 0.5 * W * W) * H * D
  expected_causal = num_global * expected_global_per_layer + num_local * expected_local_per_layer

  expected_attention_tflops = expected_causal * 3 / 10**12

  self.assertAlmostEqual(attention_tflops, expected_attention_tflops, places=5)

  # Weights checking
  # Local layers always use kv_multiplier = 2: they never share KV projections,
  # regardless of config.share_kv_projections (which is True here and only
  # affects the global layers below, where kv_multiplier = 1).
  # qkv + proj
  expected_local_weights = (2 * B * T * config.emb_dim * (H + 2 * config.num_kv_heads) * D) + (
      2 * B * T * config.emb_dim * H * D
  )
  expected_global_weights = (2 * B * T * config.emb_dim * (H + 1 * config.global_num_kv_heads) * GD) + (
      2 * B * T * config.emb_dim * H * GD
  )

  expected_learnable = num_local * expected_local_weights + num_global * expected_global_weights
  expected_learnable_tflops = expected_learnable * 3 / 10**12

  self.assertAlmostEqual(learnable_weight_tflops, expected_learnable_tflops, places=5)

def test_calculate_llama4_attention_tflops(self):
  """Test calculate_llama4_attention_tflops.

  Llama4 interleaves chunked-window attention layers with a full (NoPE) attention layer
  every `nope_layer_interval` layers.
  """
  config = MagicMock()
  config.num_decoder_layers = 16
  config.max_target_length = 4096
  config.chunk_attn_window_size = 1024
  config.nope_layer_interval = 4
  config.per_device_batch_size = 2
  config.num_query_heads = 16
  config.head_dim = 128

  attention_tflops = maxtext_utils.calculate_llama4_attention_tflops(config)

  # Manual calculation
  num_global_layers = 16 // 4  # 4
  num_chunked_layers = 16 - 4  # 12
  # The literal 2 is per_device_batch_size.
  global_flops = 4 * 2 * config.max_target_length**2 * config.num_query_heads * config.head_dim

  num_chunks = 4096 // 1024  # 4
  chunked_complexity = num_chunks * config.chunk_attn_window_size**2
  chunked_flops = 4 * 2 * chunked_complexity * config.num_query_heads * config.head_dim

  # Divide by 2 for causal masking; multiply by 3 for forward + backward passes.
  noncausal = (num_global_layers * global_flops) + (num_chunked_layers * chunked_flops)
  expected_attention_tflops = (noncausal / 2) * 3 / 10**12

  self.assertAlmostEqual(attention_tflops, expected_attention_tflops, places=5)

def test_calculate_gemma4_tflops_training_per_device_shared_kv(self):
  """Test calculate_gemma4_tflops_training_per_device with shared KV projections."""
  config = MagicMock()
  config.per_device_batch_size = 2
  config.max_target_length = 8192
  config.sliding_window_size = 1024
  config.num_query_heads = 32
  config.num_kv_heads = 8
  config.head_dim = 128
  config.num_decoder_layers = 12
  config.share_kv_projections = True
  config.global_head_dim = 128
  config.global_num_kv_heads = 8
  config.emb_dim = 4096

  total_ffn_flops_all_layers = 123456789
  embedding_flops = 333333333

  attention_tflops, learnable_weight_tflops = maxtext_utils.calculate_gemma4_tflops_training_per_device(
      config, total_ffn_flops_all_layers, embedding_flops, attention_pattern_length=6
  )

  B = config.per_device_batch_size
  T = config.max_target_length
  W = min(config.sliding_window_size, config.max_target_length)
  H = config.num_query_heads
  D = config.head_dim
  GD = config.global_head_dim
  GKH = config.global_num_kv_heads

  num_global = 12 // 6
  num_local = 12 - num_global

  expected_global_per_layer = 2 * B * T * T * H * GD
  expected_local_per_layer = 4 * B * (T * W - 0.5 * W * W) * H * D
  expected_causal = num_global * expected_global_per_layer + num_local * expected_local_per_layer

  expected_attention_tflops = expected_causal * 3 / 10**12

  self.assertAlmostEqual(attention_tflops, expected_attention_tflops, places=5)

  # share_kv_projections only reduces the global layers' KV projection count.
  kv_multiplier = 1 if config.share_kv_projections else 2
  expected_global_qkv_flops_per_layer = 2 * B * T * config.emb_dim * (H + kv_multiplier * GKH) * GD
  expected_global_projection_flops_per_layer = 2 * B * T * config.emb_dim * H * GD

  expected_local_qkv_flops_per_layer = 2 * B * T * config.emb_dim * (H + 2 * config.num_kv_heads) * D
  expected_local_projection_flops_per_layer = 2 * B * T * config.emb_dim * H * D

  expected_learnable = (
      total_ffn_flops_all_layers
      + (expected_local_qkv_flops_per_layer + expected_local_projection_flops_per_layer) * num_local
      + (expected_global_qkv_flops_per_layer + expected_global_projection_flops_per_layer) * num_global
      + embedding_flops
  )
  expected_learnable_tflops = expected_learnable * 3 / 10**12

  self.assertAlmostEqual(learnable_weight_tflops, expected_learnable_tflops, places=5)

def test_calculate_routed_and_shared_ffn_tflops_per_device(self):
  """Test calculate_routed_and_shared_ffn_tflops_per_device (DeepSeek-style MoE)."""
  config = MagicMock()
  config.decoder_block = maxtext_utils.DecoderBlockType.DEEPSEEK
  config.per_device_batch_size = 1
  config.max_target_length = 2048
  config.emb_dim = 1024
  config.first_num_dense_layers = 2
  config.num_decoder_layers = 8
  config.num_experts = 4
  config.mlp_dim = 2048
  config.moe_mlp_dim = 1024
  config.shared_experts = 1
  config.num_experts_per_tok = 2
  config.mlp_activations = ["silu", "linear"]

  ffn_tflops = maxtext_utils.calculate_routed_and_shared_ffn_tflops_per_device(config)

  B = config.per_device_batch_size
  T = config.max_target_length
  E = config.emb_dim
  N = config.num_experts

  gate_flops = 2 * B * T * E * N

  # dense ffn matmul (silu gating doubles the up-projection width: 2 * mlp_dim)
  dense_ffn1 = 2 * B * T * E * (2 * config.mlp_dim)
  dense_ffn2 = 2 * B * T * config.mlp_dim * E
  dense_flops_per_layer = dense_ffn1 + dense_ffn2

  # moe ffn matmul
  moe_ffn1 = 2 * B * T * E * (2 * config.moe_mlp_dim)
  moe_ffn2 = 2 * B * T * config.moe_mlp_dim * E
  moe_flops_per_expert = moe_ffn1 + moe_ffn2

  shared_flops = moe_flops_per_expert * config.shared_experts
  routed_flops = moe_flops_per_expert * config.num_experts_per_tok

  # layers: the first `first_num_dense_layers` are dense, the rest are MoE
  dense_layers = config.first_num_dense_layers
  moe_layers = config.num_decoder_layers - config.first_num_dense_layers

  expected_total = (dense_flops_per_layer * dense_layers) + ((gate_flops + shared_flops + routed_flops) * moe_layers)

  self.assertAlmostEqual(ffn_tflops, expected_total, places=5)

# ========== Parameterized Tests for Multiple Standard Models ==========

def _verify_flops(self, model_name, max_target_length=1):
  """
  Verifies that for a given sequence length, the total compute matches exactly what we
  expect from manual parameter extraction using the `6 * active_params * tokens` estimation rule,
  plus the expected attention flops.
  """
  config_args = [
      None,
      get_test_config_path(f"models/{model_name}.yml"),
      "run_name=test",
      f"max_target_length={max_target_length}",
      "per_device_batch_size=1",
      "skip_jax_distributed_system=True",
  ]
  config = pyconfig.initialize(config_args, enable_checkpointing=False)
  tflops, _, attention_tflops = calculate_tflops_training_per_device(config)

  # 1. Determine layer counts (dense vs MoE)
  num_dense, num_moe = maxtext_utils.get_dense_moe_layers(config)

  # 2. Calculate FFN (Feed-Forward Network) parameters
  dense_ffn_params = (config.emb_dim * config.mlp_dim * 2 + config.mlp_dim * config.emb_dim) * num_dense
  moe_ffn_params = (
      (config.emb_dim * config.num_experts)  # gate (router module)
      + (
          (config.emb_dim * config.moe_mlp_dim * 2 + config.moe_mlp_dim * config.emb_dim) * config.shared_experts
      )  # shared experts
      + (
          (config.emb_dim * config.moe_mlp_dim * 2 + config.moe_mlp_dim * config.emb_dim) * config.num_experts_per_tok
      )  # routed experts
  ) * num_moe
  total_ffn_params = dense_ffn_params + moe_ffn_params

  # 3. Calculate embedding parameters
  embedding_params = config.vocab_size * config.emb_dim
  # If not sharing weights, there is a separate unembedding layer
  if getattr(config, "logits_via_embedding", False) is False:
    embedding_params += config.vocab_size * config.emb_dim

  # 4. Resolve attention pattern lengths based on architecture (local sliding vs global causal)
  attention_pattern_length = getattr(config, "attention_pattern_length", config.num_decoder_layers)
  if not attention_pattern_length:
    attention_pattern_length = config.num_decoder_layers

  if getattr(config, "decoder_block", None) == maxtext_utils.DecoderBlockType.GPT_OSS:
    attention_pattern_length = 2
  elif getattr(config, "decoder_block", None) in (
      maxtext_utils.DecoderBlockType.GEMMA4,
      maxtext_utils.DecoderBlockType.GEMMA3,
  ):
    attention_pattern_length = 6

  num_global_layers = config.num_decoder_layers // attention_pattern_length
  num_local_layers = config.num_decoder_layers - num_global_layers

  # 5. Calculate QKV and Projection parameters based on attention type
  if getattr(config, "attention_type", "") == "mla":
    # Multi-Head Latent Attention (MLA) used in DeepSeek models
    qk_head_dim_sum = config.qk_nope_head_dim + config.qk_rope_head_dim
    if config.q_lora_rank == 0:
      q_params = config.emb_dim * config.num_query_heads * qk_head_dim_sum
    else:
      q_params = config.emb_dim * config.q_lora_rank + config.q_lora_rank * config.num_query_heads * qk_head_dim_sum

    kv_params = config.emb_dim * (
        config.kv_lora_rank + config.qk_rope_head_dim
    ) + config.kv_lora_rank * config.num_query_heads * (config.qk_nope_head_dim + config.v_head_dim)
    proj_params = config.emb_dim * config.num_query_heads * config.v_head_dim
    total_qkv_proj_params = (q_params + kv_params + proj_params) * config.num_decoder_layers
  elif getattr(config, "decoder_block", None) == maxtext_utils.DecoderBlockType.QWEN3_NEXT:
    # Interleaved Full Attention and Gated Delta Net (Linear Attention)
    cycle_interval = config.inhomogeneous_layer_cycle_interval
    num_full_attn_layers = config.num_decoder_layers // cycle_interval
    num_linear_attn_layers = config.num_decoder_layers - num_full_attn_layers

    local_kv_multiplier = 1 if getattr(config, "share_kv_projections", False) else 2
    qkv_params = config.emb_dim * (config.num_query_heads + local_kv_multiplier * config.num_kv_heads) * config.head_dim
    proj_params = config.emb_dim * config.num_query_heads * config.head_dim

    H_k = config.gdn_num_key_heads
    H_v = config.gdn_num_value_heads
    D_k = config.gdn_key_head_dim
    D_v = config.gdn_value_head_dim
    K_conv = config.gdn_conv_kernel_dim
    K_dim = H_k * D_k
    V_dim = H_v * D_v

    # in_proj_qkvz + in_proj_ba + out_proj
    gdn_proj_params = config.emb_dim * (2 * K_dim + 2 * V_dim + 2 * H_v) + config.emb_dim * V_dim
    gdn_conv_params = K_conv * (2 * K_dim + V_dim)

    total_qkv_proj_params = (qkv_params + proj_params) * num_full_attn_layers + (
        gdn_proj_params + gdn_conv_params
    ) * num_linear_attn_layers
  else:
    # Standard Attention (MHA / GQA / MQA) with local window variations
    global_head_dim = getattr(config, "global_head_dim", config.head_dim) or config.head_dim
    global_num_kv_heads = getattr(config, "global_num_kv_heads", config.num_kv_heads) or config.num_kv_heads

    # Local window layer parameters
    # Local layers NEVER share KV projections in Gemma 4
    local_kv_multiplier = 2
    local_qkv_params = (
        config.emb_dim
        * (config.num_query_heads + local_kv_multiplier * config.num_kv_heads)
        * config.head_dim
        * num_local_layers
    )
    local_proj_params = config.emb_dim * config.num_query_heads * config.head_dim * num_local_layers

    # Global full attention layer parameters
    global_kv_multiplier = 1 if getattr(config, "share_kv_projections", False) else 2
    global_qkv_params = (
        config.emb_dim
        * (config.num_query_heads + global_kv_multiplier * global_num_kv_heads)
        * global_head_dim
        * num_global_layers
    )
    global_proj_params = config.emb_dim * config.num_query_heads * global_head_dim * num_global_layers

    total_qkv_proj_params = local_qkv_params + local_proj_params + global_qkv_params + global_proj_params

  active_params = total_ffn_params + total_qkv_proj_params + embedding_params

  expected_flops_from_params = 6 * active_params * config.per_device_batch_size * config.max_target_length / 10**12
  # If not sharing weights, active_params counts both embedding and unembedding matrices.
  # However, embedding lookup is a gather operation and does not use dense math (FLOPs).
  # We must subtract its FLOP equivalent from the expected result so it matches the physical math.
  if getattr(config, "logits_via_embedding", False) is False:
    expected_flops_from_params -= (
        6 * (config.vocab_size * config.emb_dim) * config.per_device_batch_size * config.max_target_length / 10**12
    )

  expected_total_flops = expected_flops_from_params + attention_tflops

  print(
      f"\nActive params for {model_name} (seq_len={max_target_length}): {active_params}, "
      f"Expected TFLOPs: {expected_total_flops} (Computed TFLOPs: {tflops})"
  )

  # 5% margin for approximations and any edge cases
  self.assertAlmostEqual(tflops, expected_total_flops, delta=max(expected_total_flops * 0.05, 0.001))

def _verify_short_sequence_flops(self, model_name):
  """Verifies short sequence flops."""
  self._verify_flops(model_name, max_target_length=1)

def _verify_long_sequence_flops(self, model_name):
  """Verifies long sequence flops."""
  self._verify_flops(model_name, max_target_length=8192)

@parameterized.parameters(
    ("llama3-8b",),
    ("llama4-17b-16e",),
    ("gemma3-4b",),
    ("gemma3-12b",),
    ("gemma3-27b",),
    ("gemma4-26b",),
    ("gemma4-31b",),
    ("gpt-oss-20b",),
    ("gpt-oss-120b",),
    ("qwen3-8b",),
    ("qwen3-next-80b-a3b",),
    ("deepseek3-671b",),
)
def test_short_sequence_flops(self, model_name):
  """
  Validates that the computed TFLOPs match the `6 * active_params * tokens` estimation
  for various standard models when attention FLOPs are isolated (e.g. sequence length = 1)
  """
  self._verify_short_sequence_flops(model_name)

@parameterized.parameters(
    ("llama3-8b",),
    ("llama4-17b-16e",),
    ("gemma3-4b",),
    ("gemma3-12b",),
    ("gemma3-27b",),
    ("gemma4-26b",),
    ("gemma4-31b",),
    ("gpt-oss-20b",),
    ("gpt-oss-120b",),
    ("qwen3-8b",),
    ("qwen3-next-80b-a3b",),
    ("deepseek3-671b",),
)
def test_long_sequence_flops(self, model_name):
  """
  Validates that the computed TFLOPs match the `6 * active_params * tokens` estimation
  plus expected attention flops for various standard models with a long sequence length.
  """
  self._verify_long_sequence_flops(model_name)


if __name__ == "__main__":
  unittest.main()