Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
20ad279
megatron: integrate lora grad sync with finalize_model_grads
FurtherAI Mar 10, 2026
112e97c
megatron: harden sharded lora merge validation
FurtherAI Mar 10, 2026
4d5c345
tests: add megatron lora oracle correctness harness
FurtherAI Mar 10, 2026
fde2ff3
Minor typing changes
FurtherAI Mar 10, 2026
d2c1161
megatron: extend LoRA grad-sync semantics across tp/expert-tp
FurtherAI Mar 12, 2026
e418018
megatron: add MoE routing replay core and unit tests
FurtherAI Mar 12, 2026
bc5e7a4
megatron runtime/service: wire routing replay into training jobs
FurtherAI Mar 12, 2026
c5e06d9
oracle worker/trace: capture forward traces and emit replay bundles
FurtherAI Mar 12, 2026
a73ca1a
oracle harness/tests: refactor suite and add oracle-replay parity flow
FurtherAI Mar 12, 2026
ec83716
typing: clear blocking ty errors in oracle replay and LoRA paths
FurtherAI Mar 12, 2026
83d871b
megatron: reduce oracle variance with sequence grad accumulation
FurtherAI Mar 14, 2026
84e2ea7
megatron lora: fix TP/EP export participation rules
FurtherAI Mar 14, 2026
0bc9919
oracle trace: canonicalize MoE outputs across arbitrary topologies
FurtherAI Mar 14, 2026
8370c7d
oracle harness: stabilize scoring and expand sensitivity mutations
FurtherAI Mar 14, 2026
d396bfd
oracle tests: write suite output tables to log files
FurtherAI Mar 14, 2026
5385fbb
Add correct data parallelism.
FurtherAI Mar 16, 2026
7525567
Fix per-token DP normalization in Megatron training
FurtherAI Mar 17, 2026
7eb96e5
Expand the oracle harness for DP correctness checks
FurtherAI Mar 17, 2026
204e580
Merge origin/main into austin/megatron_lora_correctness_oracle_tests
FurtherAI Mar 17, 2026
9cde0d4
Clean up type errors in Megatron correctness changes
FurtherAI Mar 17, 2026
b2494ea
Testing harness was working, but real training surfaced a few errors,…
FurtherAI Mar 20, 2026
a98fafc
Cut over Megatron LoRA to QuACK
FurtherAI Mar 20, 2026
45e32f5
Del held packed tensors so dir can be removed.
FurtherAI Mar 20, 2026
a77bd7c
Fuse LoRA scale into QuACK grouped GEMM
FurtherAI Mar 21, 2026
8b83fb2
Avoid grad_out copy in QuACK LoRA backward
FurtherAI Mar 21, 2026
f39a5b2
Fuse MoE FC1 gate and up LoRA paths
FurtherAI Mar 23, 2026
92858a9
Tune QuACK low-rank tiles and rank contract
FurtherAI Mar 23, 2026
8cc45b8
Inline FC1 QuACK dual call
FurtherAI Mar 23, 2026
ed671b1
Merge remote-tracking branch 'origin/main' into austin/megatron_lora_…
FurtherAI Mar 24, 2026
6494108
Revert unnecessary python 3.12 requirement.
FurtherAI Mar 24, 2026
c26c00b
Merge branch 'main' into austin/megatron_lora_correctness_oracle_tests
FurtherAI Mar 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
539 changes: 539 additions & 0 deletions dev/bench_cute_grouped_lora.py

Large diffs are not rendered by default.

184 changes: 184 additions & 0 deletions dev/tune_quack_lora_tiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""Offline tuner for QuACK grouped LoRA tile heuristics on the ART layer bench."""

from __future__ import annotations

import argparse
from collections.abc import Iterator
from contextlib import contextmanager
import gc
import itertools
import json
import os
from pathlib import Path
import sys
from typing import Any

import torch

REPO_ROOT = Path(__file__).resolve().parents[1]  # repository checkout root (this file lives in dev/)
ART_SRC_ROOT = REPO_ROOT / "src"  # ART sources of this worktree; handed to the layer bench later


def _resolve_art_harness_root() -> Path:
    """Locate the sibling ``projects/art_harness`` checkout.

    Scans the ancestors of the repository root and returns the first one
    that contains a ``projects/art_harness`` directory.

    Raises:
        RuntimeError: if no ancestor holds the harness checkout.
    """
    candidates = (
        ancestor / "projects" / "art_harness" for ancestor in REPO_ROOT.parents
    )
    found = next((path for path in candidates if path.is_dir()), None)
    if found is None:
        raise RuntimeError(
            "Unable to locate projects/art_harness from the current worktree."
        )
    return found


# Resolved once at import time; failure here aborts the tuner immediately.
ART_HARNESS_ROOT = _resolve_art_harness_root()

# Make the harness importable without requiring an editable install.
if str(ART_HARNESS_ROOT) not in sys.path:
    sys.path.insert(0, str(ART_HARNESS_ROOT))

import art_harness.layer_benches.bench_moe_lora as bench  # noqa: E402  (needs sys.path set up above)

# Maps the tuner's short config keys to the environment variables that the
# QuACK grouped-LoRA kernels read for their tile-size heuristics.
ENV_NAMES = {
    "proj_tile_n": "ART_QUACK_PROJ_TILE_N",
    "matmul_tile_n": "ART_QUACK_MATMUL_TILE_N",
    "grad_a_tile_m": "ART_QUACK_GRAD_A_TILE_M",
    "grad_b_tile_m": "ART_QUACK_GRAD_B_TILE_M",
}


def _parse_csv_ints(raw: str) -> list[int]:
values = [int(part.strip()) for part in raw.split(",") if part.strip()]
if not values:
raise ValueError(f"Expected at least one integer in '{raw}'")
for value in values:
if value <= 0:
raise ValueError(f"Expected positive integers, got {value}")
return values


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Tune QuACK grouped LoRA tile heuristics against the ART layer bench."
)
parser.add_argument("--batch", type=int, default=1)
parser.add_argument("--seq-len", type=int, default=65536)
parser.add_argument("--hidden-size", type=int, default=2048)
parser.add_argument("--ffn-hidden-size", type=int, default=768)
parser.add_argument("--num-experts", type=int, default=128)
parser.add_argument("--top-k", type=int, default=8)
parser.add_argument("--lora-rank", type=int, default=1)
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--warmup", type=int, default=6)
parser.add_argument("--iters", type=int, default=12)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--proj-tile-n", type=str, default="32,64,128")
parser.add_argument("--matmul-tile-n", type=str, default="64,128")
parser.add_argument("--grad-a-tile-m", type=str, default="64,128")
parser.add_argument("--grad-b-tile-m", type=str, default="64,128")
parser.add_argument("--top-results", type=int, default=5)
parser.add_argument("--output-json", type=Path, default=None)
return parser.parse_args()


@contextmanager
def _tile_env(config: dict[str, int]) -> Iterator[None]:
    """Temporarily export tile-size overrides as environment variables.

    Snapshots the prior values of every ART_QUACK_* variable, applies the
    overrides in ``config`` (keyed by the short names in ``ENV_NAMES``),
    and restores the originals on exit — even if the body raises.
    """
    snapshot = {env_name: os.environ.get(env_name) for env_name in ENV_NAMES.values()}
    try:
        for short_name, tile in config.items():
            os.environ[ENV_NAMES[short_name]] = str(tile)
        yield
    finally:
        for env_name, prior in snapshot.items():
            # Drop whatever we set, then re-establish the saved value if any.
            os.environ.pop(env_name, None)
            if prior is not None:
                os.environ[env_name] = prior


def _run_config(args: argparse.Namespace, config: dict[str, int]) -> dict[str, Any]:
    """Benchmark one tile configuration and collect timing + memory stats.

    Points the bench at this worktree's ART sources, applies the tile
    overrides through the environment, runs the layer bench once, and
    records CUDA peak-memory figures observed during the run.
    """
    bench.ART_WORKTREE_SRC = ART_SRC_ROOT
    with _tile_env(config):
        # Start from a clean allocator state so the peak-memory numbers are
        # attributable to this configuration alone.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        bench_result = bench.benchmark(args)
        gib = 1024**3
        return {
            "config": config,
            "timing_ms": bench_result["timing_ms"],
            "timed_module_breakdown_ms": bench_result["timed_module_breakdown_ms"],
            "flops": bench_result["flops"],
            "peak_memory_gib": {
                "allocated": torch.cuda.max_memory_allocated() / gib,
                "reserved": torch.cuda.max_memory_reserved() / gib,
            },
        }


def main() -> None:
    """Grid-search QuACK tile configurations and report the fastest ones."""
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required for QuACK tile tuning.")

    cli = _parse_args()
    # Workload description handed to every bench run (peak_tflops is unused here).
    bench_args = argparse.Namespace(
        batch=cli.batch,
        seq_len=cli.seq_len,
        hidden_size=cli.hidden_size,
        ffn_hidden_size=cli.ffn_hidden_size,
        num_experts=cli.num_experts,
        top_k=cli.top_k,
        lora_rank=cli.lora_rank,
        dtype=cli.dtype,
        warmup=cli.warmup,
        iters=cli.iters,
        peak_tflops=None,
        seed=cli.seed,
    )

    # Parse each sweep axis once; the cartesian product defines the search.
    search_space = {
        "proj_tile_n": _parse_csv_ints(cli.proj_tile_n),
        "matmul_tile_n": _parse_csv_ints(cli.matmul_tile_n),
        "grad_a_tile_m": _parse_csv_ints(cli.grad_a_tile_m),
        "grad_b_tile_m": _parse_csv_ints(cli.grad_b_tile_m),
    }
    axis_names = list(search_space)
    configs: list[dict[str, int]] = [
        dict(zip(axis_names, combo))
        for combo in itertools.product(*(search_space[name] for name in axis_names))
    ]

    results: list[dict[str, Any]] = []
    for config in configs:
        try:
            payload = _run_config(bench_args, config)
        except Exception as exc:  # best-effort sweep: record the failure and keep going
            payload = {"config": config, "error": repr(exc)}
        results.append(payload)
        # Stream each result as a JSON line so partial progress survives a crash.
        print(json.dumps(payload, sort_keys=True), flush=True)

    # Rank the runs that produced timings, fastest total mean first.
    successful = sorted(
        (item for item in results if "timing_ms" in item),
        key=lambda item: float(item["timing_ms"]["total_mean"]),
    )
    summary = {
        "search_space": search_space,
        "benchmark_config": vars(bench_args),
        "top_results": successful[: cli.top_results],
        "num_successful": len(successful),
        "num_total": len(results),
    }
    if cli.output_json is not None:
        cli.output_json.parent.mkdir(parents=True, exist_ok=True)
        cli.output_json.write_text(json.dumps(summary, indent=2, sort_keys=True))
    print(json.dumps(summary, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()
31 changes: 24 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ backend = [
]
megatron = [
"torch>=2.8.0",
"quack-kernels==0.2.5",
"apex",
"transformer-engine==2.11.0",
"transformer-engine-cu12==2.11.0",
Expand Down Expand Up @@ -126,26 +127,40 @@ markers = [

[tool.uv]
required-version = ">=0.6.15"
# Override numpy to <2.0 for compatibility with megatron-core in the training
# environment. vLLM pulls opencv-python-headless>=4.13 which wants numpy>=2 on
# Python 3.9+, but megatron-core requires numpy<2.
override-dependencies = ["transformer-engine>=2.11.0", "numpy<2"]
# Keep apex build isolation enabled so uv can inject torch from
# `extra-build-dependencies` during lock/sync on non-GPU client machines.
no-build-isolation-package = ["transformer-engine", "transformer-engine-cu12", "transformer-engine-torch", "megatron-core", "megatron-bridge", "nv-grouped-gemm", "mamba-ssm", "causal-conv1d"]
override-dependencies = [
"transformer-engine==2.11.0",
"numpy<2",
"torch==2.10.0",
"quack-kernels==0.2.5",
]
no-build-isolation-package = ["transformer-engine-torch", "megatron-core", "megatron-bridge", "nv-grouped-gemm", "mamba-ssm", "causal-conv1d"]

[tool.uv.extra-build-dependencies]
apex = ["torch>=2.8.0"]
transformer-engine-torch = ["torch>=2.8.0"]

[tool.uv.extra-build-variables]
apex = { APEX_CPP_EXT = "1", APEX_CUDA_EXT = "1", APEX_FAST_LAYER_NORM = "1", APEX_PARALLEL_BUILD = "16", NVCC_APPEND_FLAGS = "--threads 4" }
transformer-engine-torch = { NVTE_NO_LOCAL_VERSION = "1" }

[[tool.uv.dependency-metadata]]
name = "apex"
version = "0.1"
requires-dist = ["packaging"]

[[tool.uv.dependency-metadata]]
name = "transformer-engine-torch"
version = "2.11.0"
requires-dist = [
"einops",
"onnx",
"onnxscript",
"packaging",
"pydantic",
"torch",
"transformer-engine-cu12",
]

[tool.ty.environment]
python-version = "3.11"

Expand Down Expand Up @@ -194,6 +209,7 @@ allowed-unresolved-imports = [
"seaborn.**",
# megatron deps
"megatron.**",
"quack.**",
]

[dependency-groups]
Expand All @@ -217,3 +233,4 @@ dev = [
[tool.uv.sources]
panza = { git = "https://github.com/corbt/panza.git" }
apex = { git = "https://github.com/NVIDIA/apex.git", branch = "25.09" }
transformer-engine-torch = { git = "https://github.com/NVIDIA/TransformerEngine.git", tag = "v2.11", subdirectory = "transformer_engine/pytorch" }
8 changes: 7 additions & 1 deletion src/art/dev/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Literal
from typing import TYPE_CHECKING, Literal

from typing_extensions import TypedDict

if TYPE_CHECKING:
from art.megatron.routing_replay import MoeRoutingReplayBundle


class TrainConfig(TypedDict, total=False):
advantage_balance: float
Expand All @@ -22,6 +25,9 @@ class TrainConfig(TypedDict, total=False):
logprob_calculation_chunk_size: int
mask_prob_ratio: bool
max_negative_advantage_importance_sampling_weight: float
moe_routing_replay_bundle: "MoeRoutingReplayBundle | None"
moe_routing_replay_path: str | None
moe_routing_replay_strict: bool
num_trajectories_learning_rate_multiplier_power: float
plot_tensors: bool
ppo: bool
Expand Down
5 changes: 4 additions & 1 deletion src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,10 @@ async def _train_model(
packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors"
)
# Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train())
estimated_gradient_steps = disk_packed_tensors["num_sequences"]
grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences))
estimated_gradient_steps = math.ceil(
disk_packed_tensors["num_sequences"] / grad_accumulation_sequences
)
pbar = tqdm.tqdm(total=estimated_gradient_steps, desc="train")
async for result in service.train(
disk_packed_tensors, config, dev_config, verbose
Expand Down
34 changes: 23 additions & 11 deletions src/art/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, ConfigDict
import torch
Expand All @@ -13,8 +13,10 @@

class Loss(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
mean_policy_loss: torch.Tensor
mean_entropy: torch.Tensor | None
reduction: Literal["mean", "sum"]
policy_loss: torch.Tensor
kl: torch.Tensor
entropy: torch.Tensor | None
policy_loss_sum: torch.Tensor
probs_corr: torch.Tensor
kl_policy_ref: torch.Tensor | None = None
Expand All @@ -26,6 +28,7 @@ def loss_fn(
ref_logprobs: torch.Tensor | None,
entropies: torch.Tensor | None,
experimental_config: dev.TrainConfig,
reduction: Literal["mean", "sum"] = "mean",
) -> Loss:
old_logprobs = shift_tensor(inputs["logprobs"], float("nan"))
advantages = shift_tensor(inputs["advantages"], 0.0)
Expand Down Expand Up @@ -123,19 +126,28 @@ def loss_fn(
logprob_diff = old_logprobs - original_logprobs
prob_ratio = torch.exp(logprob_diff)
policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
if ref_logprobs is not None:
kl_div = (
torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0
)
else:
kl_div = torch.zeros_like(policy_loss)
policy_loss = policy_loss * weights * assistant_mask
mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
# Compute mean entropy for the current step
kl_div = kl_div * weights * assistant_mask
denominator = assistant_mask.sum() + 1e-6 if reduction == "mean" else 1.0
reduced_policy_loss = policy_loss.sum() / denominator
kl = kl_div.sum() / denominator
# Compute reduced entropy for the current step.
if entropies is not None:
shifted_entropies = shift_tensor(entropies, 0.0)
mean_entropy = (shifted_entropies * weights * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
entropy = (shifted_entropies * weights * assistant_mask).sum() / denominator
else:
mean_entropy = None
entropy = None
return Loss(
mean_policy_loss=mean_policy_loss,
mean_entropy=mean_entropy,
reduction=reduction,
policy_loss=reduced_policy_loss,
kl=kl,
entropy=entropy,
policy_loss_sum=policy_loss.sum(),
probs_corr=probs_corr,
kl_policy_ref=kl_policy_ref,
Expand Down
Loading
Loading