Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dependencies/extra_deps/qwix
Submodule qwix added at 3d30c0
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ python-dotenv>=1.2.1
pytype>=2024.10.11
pytz>=2025.2
pyyaml>=6.0.3
qwix>=0.1.4
regex>=2025.11.3
requests-oauthlib>=2.0.0
requests>=2.32.5
Expand Down Expand Up @@ -223,7 +222,7 @@ tensorflow>=2.19.1
tensorstore>=0.1.79
termcolor>=3.2.0
tiktoken>=0.12.0
tokamax>=0.0.8
tokamax>=0.0.11
tokenizers>=0.22.1
toml>=0.10.2
tomlkit>=0.13.3
Expand Down
3 changes: 1 addition & 2 deletions src/dependencies/requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@ pyink
pylint
pytest
pytype
qwix
sentencepiece
tensorboard-plugin-profile
tensorboardx
tensorflow-datasets
tensorflow-text
tensorflow
tiktoken
tokamax>=0.0.4
tokamax>=0.0.11
transformers
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
4 changes: 4 additions & 0 deletions src/dependencies/scripts/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ install_maxtext_with_deps() {
fi
echo "Installing requirements from $dep_name"
python3 -m uv pip install --resolution=lowest -r "$dep_name"
echo "Installing local qwix from extra_deps"
python3 -m uv pip install -e src/dependencies/extra_deps/qwix
python3 -m src.dependencies.scripts.install_pre_train_extra_deps

install_maxtext_package_without_deps
Expand All @@ -230,6 +232,8 @@ install_post_training_deps() {
dep_name='src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt'
echo "Installing requirements from $dep_name"
python3 -m uv pip install --resolution=lowest -r "$dep_name"
echo "Installing local qwix from extra_deps"
python3 -m uv pip install -e src/dependencies/extra_deps/qwix
python3 -m src.dependencies.scripts.install_post_train_extra_deps
}

Expand Down
4 changes: 3 additions & 1 deletion src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,9 @@ def map_to_pspec(data):
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
checkpoint_args = ocp.args.PyTreeRestore(
item=abstract_unboxed_pre_state, restore_args=restore_args, partial_restore=True
)

match (checkpoint_manager, dataset_type, data_iterator):
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
# shard the range finding operation for quantization. by default this is set to number of slices.
quantization_local_shard_count: -1
# The 'N' in N:M sparsity, representing the maximum number of non-zero values in each block.
weight_sparsity_n: null
# The 'M' in N:M sparsity, representing the number of values in each block.
weight_sparsity_m: null
# The interval, in training steps, between updates of the sparsity masks.
weight_sparsity_update_step: 10
# The number of initial training steps to wait before the sparsity masks are first updated.
weight_sparsity_start_step: 50

decoder_block: "llama2" # which style of decoderblock to use.
# global parameter scale needs to be a power of 2. if you want finer grained control of the model sizes
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,14 @@ class Quantization(BaseModel):
"absmax",
description="Quantization calibration method used for gradients.",
)
weight_sparsity_n: int | None = Field(
None, description="The 'N' in N:M sparsity, representing the maximum number of non-zero values in each block."
)
weight_sparsity_m: int | None = Field(
None, description="The 'M' in N:M sparsity, representing the number of values in each block."
)
weight_sparsity_update_step: int = Field(10, description="The step size for updating weight sparsity masks.")
weight_sparsity_start_step: int = Field(50, description="The first number of steps before updating the sparsity masks.")


class ModelArchitecture(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
"cache": cache_spec,
"intermediates": 0,
"aqt": 0,
"batch_stats": 0,
"_overwrite_with_gradient": 0,
},
split_rngs={
Expand Down
40 changes: 39 additions & 1 deletion src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from maxtext.utils.sharding import maybe_shard_with_logical, create_sharding
from maxtext.utils.sharding import logical_to_mesh_axes
import numpy as np
import qwix
from qwix.contrib.sparsity import sparsity_module
import qwix.pallas as qpl
import tokamax

Expand Down Expand Up @@ -389,6 +391,36 @@ def __init__(
shard_mode=config.shard_mode,
rngs=self.rngs,
)
rule = qpl.get_current_rule("gmm")
sparsity_rule = None
if rule is not None:
if not isinstance(rule, qwix.QtRule):
raise ValueError("Expect a QtRule for quantized training.")
if rule.additional_qt_config and "sparsity_rule" in rule.additional_qt_config:
q_s_rule = rule.additional_qt_config["sparsity_rule"]
if q_s_rule and q_s_rule.weight_sparsity_n and q_s_rule.weight_sparsity_m:
sparsity_rule = q_s_rule

if sparsity_rule is not None:
self.wi_0_sparsity_module = sparsity_module.SparsityModule(
shape=(self.num_experts, self.config.emb_dim, self.intermediate_dim),
sharding_axes=self.wi_kernel_axes,
sparsity_rule=sparsity_rule,
)
self.wi_1_sparsity_module = sparsity_module.SparsityModule(
shape=(self.num_experts, self.config.emb_dim, self.intermediate_dim),
sharding_axes=self.wi_kernel_axes,
sparsity_rule=sparsity_rule,
)
self.wo_sparsity_module = sparsity_module.SparsityModule(
shape=(self.num_experts, self.intermediate_dim, self.config.emb_dim),
sharding_axes=self.wo_kernel_axes,
sparsity_rule=sparsity_rule,
)
else:
self.wi_0_sparsity_module = None
self.wi_1_sparsity_module = None
self.wo_sparsity_module = None

# pylint: disable=protected-access
self.activation_fn = linears._convert_to_activation_function(self.config.mlp_activations[0])
Expand Down Expand Up @@ -909,7 +941,9 @@ def gmm(
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
):
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
if self.config.use_qwix_quantization or (
self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat
):
tokamax_group_sizes = group_sizes
elif self.config.attention == "vllm_rpa":
tokamax_group_sizes = group_sizes
Expand Down Expand Up @@ -2033,6 +2067,10 @@ def __call__(
if self.per_expert_scale is not None:
wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]

if self.wi_0_sparsity_module is not None:
_, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel)
_, w1_kernel = self.wi_1_sparsity_module(jnp.zeros_like(w1_kernel), w1_kernel)
_, wo_kernel = self.wo_sparsity_module(jnp.zeros_like(wo_kernel), wo_kernel)
if cfg.mlp_bias:
w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype)
w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype)
Expand Down
138 changes: 80 additions & 58 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import qwix
from qwix._src.core import dot_general_qt
from qwix._src.core import sparsity

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -730,68 +731,89 @@ def dot_general(self, *args, **kwargs):
return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)


def get_fp8_full_qwix_rule(config: Config):
return qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
weight_calibration_method=config.weight_quantization_calibration_method,
act_calibration_method=config.act_quantization_calibration_method,
bwd_calibration_method=config.bwd_quantization_calibration_method,
op_names=("dot_general", "gmm", "ragged_dot"),
)
def get_fp8_full_qwix_rule_w_sparsity(config: Config):
  """Builds the "fp8_full" Qwix QT rule list, optionally with N:M weight sparsity.

  Args:
    config: Model config. Reads the three *_quantization_calibration_method
      fields and the weight_sparsity_* fields.

  Returns:
    A single-element list holding a qwix.QtRule that quantizes weights and
    activations of all decoder layers to float8_e4m3fn (float8_e5m2 for the
    backward pass) for the "dot_general", "gmm" and "ragged_dot" ops. When
    both config.weight_sparsity_n and config.weight_sparsity_m are set, a
    SparsityRule is attached under additional_qt_config["sparsity_rule"];
    otherwise that entry is None.
  """
  sparsity_rule = None
  # N:M sparsity is only enabled when both N and M are configured (truthy).
  if config.weight_sparsity_n and config.weight_sparsity_m:
    sparsity_rule = sparsity.SparsityRule(
        weight_sparsity_n=config.weight_sparsity_n,
        weight_sparsity_m=config.weight_sparsity_m,
        weight_sparsity_update_step=config.weight_sparsity_update_step,
        weight_sparsity_start_step=config.weight_sparsity_start_step,
    )
  return [
      qwix.QtRule(
          module_path="decoder/.*layers.*",
          weight_qtype=jnp.float8_e4m3fn,
          act_qtype=jnp.float8_e4m3fn,
          bwd_qtype=jnp.float8_e5m2,
          weight_calibration_method=config.weight_quantization_calibration_method,
          act_calibration_method=config.act_quantization_calibration_method,
          bwd_calibration_method=config.bwd_quantization_calibration_method,
          additional_qt_config={"sparsity_rule": sparsity_rule},
          op_names=("dot_general", "gmm", "ragged_dot"),
      ),
  ]


def get_quantization_rule(config: Config):
match config.quantization:
case "int4":
return qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.int4,
act_qtype=jnp.int4,
bwd_qtype=jnp.int4,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
return [
qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.int4,
act_qtype=jnp.int4,
bwd_qtype=jnp.int4,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
case "int8":
return qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.int8,
act_qtype=jnp.int8,
bwd_qtype=jnp.int8,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
return [
qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.int8,
act_qtype=jnp.int8,
bwd_qtype=jnp.int8,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
case "fp8":
return qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
return [
qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
case "fp8_full":
return get_fp8_full_qwix_rule(config)
return get_fp8_full_qwix_rule_w_sparsity(config)
case "fp8_gpu":
return qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
return [
qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
case "fp8_nanoo":
return qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
return [
qwix.QtRule(
module_path="decoder/.*layers.*",
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
op_names=("dot_general",),
)
]
case "":
return None

Expand All @@ -800,17 +822,17 @@ def get_qt_provider(config):
"""Get quantization rules based on the config."""
match config.quantization:
case "int8":
return qwix.QtProvider([get_quantization_rule(config)])
return qwix.QtProvider(get_quantization_rule(config))
case "int4":
return qwix.QtProvider([get_quantization_rule(config)])
return qwix.QtProvider(get_quantization_rule(config))
case "fp8":
return qwix.QtProvider([get_quantization_rule(config)])
return qwix.QtProvider(get_quantization_rule(config))
case "fp8_full":
return qwix.QtProvider([get_quantization_rule(config)])
return qwix.QtProvider(get_quantization_rule(config))
case "fp8_gpu":
return NvidaFp8Provider([get_quantization_rule(config)])
return NvidaFp8Provider(get_quantization_rule(config))
case "fp8_nanoo":
return NANOOFp8Provider([get_quantization_rule(config)])
return NANOOFp8Provider(get_quantization_rule(config))
return None


Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/deepseek_batchsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def gmm(
weight_gather_axes=weight_gather_axes,
input_buffer_count=input_buffer_count,
combine_scopes=combine_scopes,
qwix_rule=quantizations.get_fp8_full_qwix_rule(config),
qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config),
)
else:
output = tokamax.ragged_dot(
Expand Down
Loading
Loading