diff --git a/src/dependencies/extra_deps/qwix b/src/dependencies/extra_deps/qwix new file mode 160000 index 0000000000..3d30c04d49 --- /dev/null +++ b/src/dependencies/extra_deps/qwix @@ -0,0 +1 @@ +Subproject commit 3d30c04d49160bb87be64576c3820c6be5480b88 diff --git a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt index 08da4a3ab7..f4dc969001 100644 --- a/src/dependencies/requirements/generated_requirements/tpu-requirements.txt +++ b/src/dependencies/requirements/generated_requirements/tpu-requirements.txt @@ -189,7 +189,6 @@ python-dotenv>=1.2.1 pytype>=2024.10.11 pytz>=2025.2 pyyaml>=6.0.3 -qwix>=0.1.4 regex>=2025.11.3 requests-oauthlib>=2.0.0 requests>=2.32.5 @@ -223,7 +222,7 @@ tensorflow>=2.19.1 tensorstore>=0.1.79 termcolor>=3.2.0 tiktoken>=0.12.0 -tokamax>=0.0.8 +tokamax>=0.0.11 tokenizers>=0.22.1 toml>=0.10.2 tomlkit>=0.13.3 diff --git a/src/dependencies/requirements/requirements.txt b/src/dependencies/requirements/requirements.txt index 33241acff3..87f2c5d31f 100644 --- a/src/dependencies/requirements/requirements.txt +++ b/src/dependencies/requirements/requirements.txt @@ -32,7 +32,6 @@ pyink pylint pytest pytype -qwix sentencepiece tensorboard-plugin-profile tensorboardx @@ -40,7 +39,7 @@ tensorflow-datasets tensorflow-text tensorflow tiktoken -tokamax>=0.0.4 +tokamax>=0.0.11 transformers google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/src/dependencies/scripts/setup.sh b/src/dependencies/scripts/setup.sh index 10f7d262a5..6008662392 100644 --- a/src/dependencies/scripts/setup.sh +++ b/src/dependencies/scripts/setup.sh @@ -216,6 +216,8 @@ install_maxtext_with_deps() { fi echo "Installing requirements from $dep_name" python3 -m uv pip install --resolution=lowest -r "$dep_name" + echo "Installing local qwix from extra_deps" + python3 -m uv pip install -e src/dependencies/extra_deps/qwix python3 -m src.dependencies.scripts.install_pre_train_extra_deps install_maxtext_package_without_deps @@ -230,6 +232,8 @@ install_post_training_deps() { dep_name='src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt' echo "Installing requirements from $dep_name" python3 -m uv pip install --resolution=lowest -r "$dep_name" + echo "Installing local qwix from extra_deps" + python3 -m uv pip install -e src/dependencies/extra_deps/qwix python3 -m src.dependencies.scripts.install_post_train_extra_deps } diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index cdfde92d50..aec30c42cd 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -601,7 +601,9 @@ def map_to_pspec(data): ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) - checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + checkpoint_args = ocp.args.PyTreeRestore( + item=abstract_unboxed_pre_state, restore_args=restore_args, partial_restore=True + ) match (checkpoint_manager, dataset_type, data_iterator): # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index fb58aa79b4..898a68a9d0 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -145,6 +145,14 @@ act_quantization_calibration_method: "absmax" bwd_quantization_calibration_method: "absmax" # shard the range finding operation for quantization. by default this is set to number of slices. quantization_local_shard_count: -1 +# The 'N' in N:M sparsity, representing the maximum number of non-zero values in each block. +weight_sparsity_n: null +# The 'M' in N:M sparsity, representing the number of values in each block. +weight_sparsity_m: null +# The step size to update the sparsity masks. +weight_sparsity_update_step: 10 +# The first number of steps before updating the sparsity masks. +weight_sparsity_start_step: 50 decoder_block: "llama2" # which style of decoderblock to use. # global parameter scale needs to be a power of 2. if you want finer grained control of the model sizes diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 4d26a0a309..1961a5645b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -428,6 +428,14 @@ class Quantization(BaseModel): "absmax", description="Quantization calibration method used for gradients.", ) + weight_sparsity_n: int | None = Field( + None, description="The 'N' in N:M sparsity, representing the maximum number of non-zero values in each block." + ) + weight_sparsity_m: int | None = Field( + None, description="The 'M' in N:M sparsity, representing the number of values in each block." + ) + weight_sparsity_update_step: int = Field(10, description="The step size for updating weight sparsity masks.") + weight_sparsity_start_step: int = Field(50, description="The first number of steps before updating the sparsity masks.") class ModelArchitecture(BaseModel): diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 6922bd0016..f03aebf583 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -563,6 +563,7 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me "cache": cache_spec, "intermediates": 0, "aqt": 0, + "batch_stats": 0, "_overwrite_with_gradient": 0, }, split_rngs={ diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 5f6570bc88..e3cfa7b14e 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -39,6 +39,8 @@ from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding from maxtext.utils.sharding import logical_to_mesh_axes import numpy as np +import qwix +from qwix.contrib.sparsity import sparsity_module import qwix.pallas as qpl import tokamax @@ -389,6 +391,36 @@ def __init__( shard_mode=config.shard_mode, rngs=self.rngs, ) + rule = qpl.get_current_rule("gmm") + sparsity_rule = None + if rule is not None: + if not isinstance(rule, qwix.QtRule): + raise ValueError("Expect a QtRule for quantized training.") + if rule.additional_qt_config and "sparsity_rule" in rule.additional_qt_config: + q_s_rule = rule.additional_qt_config["sparsity_rule"] + if q_s_rule and q_s_rule.weight_sparsity_n and q_s_rule.weight_sparsity_m: + sparsity_rule = q_s_rule + + if sparsity_rule is not None: + self.wi_0_sparsity_module = sparsity_module.SparsityModule( + shape=(self.num_experts, self.config.emb_dim, self.intermediate_dim), + sharding_axes=self.wi_kernel_axes, + sparsity_rule=sparsity_rule, + ) + self.wi_1_sparsity_module = sparsity_module.SparsityModule( + shape=(self.num_experts, self.config.emb_dim, self.intermediate_dim), + sharding_axes=self.wi_kernel_axes, + sparsity_rule=sparsity_rule, + ) + self.wo_sparsity_module = sparsity_module.SparsityModule( + shape=(self.num_experts, self.intermediate_dim, self.config.emb_dim), + sharding_axes=self.wo_kernel_axes, + sparsity_rule=sparsity_rule, + ) + else: + self.wi_0_sparsity_module = None + self.wi_1_sparsity_module = None + self.wo_sparsity_module = None # pylint: disable=protected-access self.activation_fn = linears._convert_to_activation_function(self.config.mlp_activations[0]) @@ -909,7 +941,9 @@ def gmm( inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes ): # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm - if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat: + if self.config.use_qwix_quantization or ( + self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat + ): tokamax_group_sizes = group_sizes elif self.config.attention == "vllm_rpa": tokamax_group_sizes = group_sizes @@ -2033,6 +2067,10 @@ def __call__( if self.per_expert_scale is not None: wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] + if self.wi_0_sparsity_module is not None: + _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel) + _, w1_kernel = self.wi_1_sparsity_module(jnp.zeros_like(w1_kernel), w1_kernel) + _, wo_kernel = self.wo_sparsity_module(jnp.zeros_like(wo_kernel), wo_kernel) if cfg.mlp_bias: w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype) w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype) diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py index 503e7e0b04..b8007197d1 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -28,6 +28,7 @@ import qwix from qwix._src.core import dot_general_qt +from qwix._src.core import sparsity import jax import jax.numpy as jnp @@ -730,68 +731,89 @@ def dot_general(self, *args, **kwargs): return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs) -def get_fp8_full_qwix_rule(config: Config): - return qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, - weight_calibration_method=config.weight_quantization_calibration_method, - act_calibration_method=config.act_quantization_calibration_method, - bwd_calibration_method=config.bwd_quantization_calibration_method, - op_names=("dot_general", "gmm", "ragged_dot"), - ) +def get_fp8_full_qwix_rule_w_sparsity(config: Config): + sparsity_rule = None + if config.weight_sparsity_n and config.weight_sparsity_m: + sparsity_rule = sparsity.SparsityRule( + weight_sparsity_n=config.weight_sparsity_n, + weight_sparsity_m=config.weight_sparsity_m, + weight_sparsity_update_step=config.weight_sparsity_update_step, + weight_sparsity_start_step=config.weight_sparsity_start_step, + ) + return [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, + weight_calibration_method=config.weight_quantization_calibration_method, + act_calibration_method=config.act_quantization_calibration_method, + bwd_calibration_method=config.bwd_quantization_calibration_method, + additional_qt_config={"sparsity_rule": sparsity_rule}, + op_names=("dot_general", "gmm", "ragged_dot"), + ), + ] def get_quantization_rule(config: Config): match config.quantization: case "int4": - return qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.int4, - act_qtype=jnp.int4, - bwd_qtype=jnp.int4, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) + return [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=jnp.int4, + act_qtype=jnp.int4, + bwd_qtype=jnp.int4, + bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, + op_names=("dot_general",), + ) + ] case "int8": - return qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.int8, - act_qtype=jnp.int8, - bwd_qtype=jnp.int8, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) + return [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=jnp.int8, + act_qtype=jnp.int8, + bwd_qtype=jnp.int8, + bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, + op_names=("dot_general",), + ) + ] case "fp8": - return qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) + return [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e4m3fn, + bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, + op_names=("dot_general",), + ) + ] case "fp8_full": - return get_fp8_full_qwix_rule(config) + return get_fp8_full_qwix_rule_w_sparsity(config) case "fp8_gpu": - return qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) + return [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e4m3fn, + bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, + op_names=("dot_general",), + ) + ] case "fp8_nanoo": - return qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - ) + return [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e4m3fn, + bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, + op_names=("dot_general",), + ) + ] case "": return None @@ -800,17 +822,17 @@ def get_qt_provider(config): """Get quantization rules based on the config.""" match config.quantization: case "int8": - return qwix.QtProvider([get_quantization_rule(config)]) + return qwix.QtProvider(get_quantization_rule(config)) case "int4": - return qwix.QtProvider([get_quantization_rule(config)]) + return qwix.QtProvider(get_quantization_rule(config)) case "fp8": - return qwix.QtProvider([get_quantization_rule(config)]) + return qwix.QtProvider(get_quantization_rule(config)) case "fp8_full": - return qwix.QtProvider([get_quantization_rule(config)]) + return qwix.QtProvider(get_quantization_rule(config)) case "fp8_gpu": - return NvidaFp8Provider([get_quantization_rule(config)]) + return NvidaFp8Provider(get_quantization_rule(config)) case "fp8_nanoo": - return NANOOFp8Provider([get_quantization_rule(config)]) + return NANOOFp8Provider(get_quantization_rule(config)) return None diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py index 24ccf1c7b5..01257a9e89 100644 --- a/src/maxtext/models/deepseek_batchsplit.py +++ b/src/maxtext/models/deepseek_batchsplit.py @@ -964,7 +964,7 @@ def gmm( weight_gather_axes=weight_gather_axes, input_buffer_count=input_buffer_count, combine_scopes=combine_scopes, - qwix_rule=quantizations.get_fp8_full_qwix_rule(config), + qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config), ) else: output = tokamax.ragged_dot( diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index a3c39acb9f..060dadf637 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -84,7 +84,7 @@ def get_first_step(state): # ----------------------------------------------------------------------------- -def loss_fn(model, config, data, dropout_rng, params, is_train=True): +def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_train=True): """loss_fn for both train and eval. Args: @@ -117,13 +117,18 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): if config.mtp_eval_target_module > 0 and not is_train: mutable_collections.append("mtp_acceptance") + if is_train and config.weight_sparsity_n and config.weight_sparsity_m: + mutable_collections.append("batch_stats") if isinstance(model, nn.Module): # inputs, targets, segments, positions = apply_args rng1, aqt_rng = jax.random.split(dropout_rng) # Flax Linen model + model_vars = {"params": params} + if sparsity_state: + model_vars["batch_stats"] = sparsity_state logits, intermediate_outputs = model.apply( - params, + model_vars, data["inputs"], data["inputs_position"], decoder_segment_ids=data["inputs_segmentation"], @@ -144,7 +149,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): elif config.num_vocab_tiling > 1: hidden_state_key = ("intermediates", "decoder", "hidden_states") hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] - total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train) + total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, {"params": params}, is_train) else: one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) @@ -290,6 +295,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): "indexer_loss": indexer_loss, "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, + "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } return loss, aux @@ -322,7 +328,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat _loss_fn = dpo_loss_fn params = state.params - if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( _loss_fn, @@ -348,8 +353,20 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat params, params_shardings, ) + pure_params = params["params"] if "params" in params else params + batch_stats = params.get("batch_stats", {}) + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + (loss, aux), raw_grads = grad_func( + model, + config, + data, + dropout_rng, + pure_params, + *extra_dpo_args, + sparsity_state=batch_stats, + is_train=True, + ) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -396,8 +413,22 @@ def move(path, value): jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), ) ) - new_state = state.apply_gradients(grads=grads) + # Re-wrap grads to match state.params structure if it's a dict of collections + if isinstance(state.params, dict) and "params" in state.params: + full_grads = {"params": grads} + if "batch_stats" in state.params: + batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) + full_grads["batch_stats"] = batch_stats_grads + full_grads = max_utils.unbox_logicallypartioned(full_grads) + else: + full_grads = grads + new_state = state.apply_gradients(grads=full_grads) + + if "batch_stats" in aux and isinstance(state.params, dict) and "batch_stats" in state.params: + new_params = dict(new_state.params) + new_params["batch_stats"] = max_utils.unbox_logicallypartioned(aux["batch_stats"]) + new_state = new_state.replace(params=new_params) # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") @@ -452,8 +483,11 @@ def eval_step(model, config, state, data, dropout_rng): extra_dpo_args = [reference_params] _loss_fn = dpo_loss_fn + pure_params = state.params["params"] if "params" in state.params else state.params + batch_stats = state.params.get("batch_stats", {}) + eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn(state.params, *extra_dpo_args) + loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 675f920357..6f73a48eec 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1183,7 +1183,14 @@ def setup_initial_state( out_shardings=state_mesh_shardings, )(rng) if raw_params: # If we loaded a partial state, we need to merge it. - state = state.replace(params=raw_params) + + def _merge_params(p_raw, p_init): + if isinstance(p_raw, jax.ShapeDtypeStruct): + return p_init + return p_raw + + merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) + state = state.replace(params=merged_params) state = max_utils.unbox_logicallypartioned(state)