
NVFP4 learn scale_after_dequant PTQ/QAD #864

Open
realAsma wants to merge 6 commits into main from asma/scale_learning

Conversation

@realAsma
Contributor

@realAsma realAsma commented Feb 6, 2026

What does this PR do?

Type of change: New feature (quantization algorithm) + new tests + bug fix

Overview:

  • Adds a new quantization calibration/finetuning mode, scale_after_dequant, for NVFP4 static weight quantizers:
    • Runs MSE calibration with FP8 scale sweep (required for static NVFP4).
    • Converts NVFP4 weight quantizers into a scale-after-dequant setup where per-block scale is learnable (nn.Parameter) and per-tensor scale is frozen.
    • Uses an STE-based FP4 cast to keep gradients flowing to the learnable scale (a minimal sketch follows below).
  • Adds a CUDA test validating:
    • Output parity vs “MSE + FP8 sweep only” before training
    • Gradients exist for the new per-block learnable scale parameter
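
For intuition, here is a minimal sketch of the STE-based FP4 cast and scale-after-dequant fake quantization mentioned above. Helper names are hypothetical and the exact semantics may differ from the library code (the real implementation is FP4CastSTEFunction in tensor_quant.py plus the Triton FP4 kernels); it only illustrates how gradients reach the learnable per-block scale.

import torch

# FP4 (E2M1) magnitude grid: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
_E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def _round_to_e2m1(x: torch.Tensor) -> torch.Tensor:
    # Round |x| to the nearest representable E2M1 magnitude, keep the sign.
    grid = _E2M1_GRID.to(device=x.device, dtype=x.dtype)
    idx = (x.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return x.sign() * grid[idx]

class FP4CastSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return _round_to_e2m1(x)

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out  # straight-through estimator: identity gradient

def scale_after_dequant_fake_quant(w, per_block_scale, per_tensor_scale, block_size=16):
    # per_block_scale: learnable, broadcastable to (num_blocks, 1);
    # per_tensor_scale: frozen scalar. Both assumed positive, and
    # w.numel() is assumed divisible by block_size.
    wb = w.reshape(-1, block_size)
    scale = per_block_scale * per_tensor_scale
    q = FP4CastSTE.apply(wb / scale)      # cast to the FP4 grid
    return (q * scale).reshape_as(w)      # dequantize with the learnable scale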

Usage

import modelopt.torch.quantization as mtq

NVFP4_WEIGHT_SCALE_LEARN_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {"enable": False},
    },
    "algorithm": {"method": "scale_after_dequant"},
}

def forward_loop(model):
    # run a few calibration batches through model(...)
    ...

mtq.quantize(model, NVFP4_WEIGHT_SCALE_LEARN_CFG, forward_loop)
# Model now has NVFP4 weight quantizers in scale-after-dequant mode,
# with learnable per-block scale.

Testing

  • Ran: pytest -q tests/gpu/torch/quantization/test_quantize_cuda.py -k scale_after_dequant

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No - Will do later after experiment verification.

Additional Information

  • New mode registered as ScaleAfterDequantModeDescriptor and config added as ScaleAfterDequantConfig.
  • Core implementation: modelopt/torch/quantization/model_calib.py (scale_after_dequant).
  • NVFP4 scale-after-dequant behavior: modelopt/torch/quantization/nn/modules/tensor_quantizer.py.
  • STE FP4 cast: modelopt/torch/quantization/tensor_quant.py + modelopt/torch/quantization/triton/fp4_kernel.py.

Summary by CodeRabbit

  • New Features

    • Added "scale-after-dequant" quantization algorithm for NVFP4 weight quantization workflows
    • Enabled accelerated FP4 quantization on Hopper GPUs
    • Introduced two-level scaling support for advanced quantizer configurations
  • Improvements

    • Enhanced MSE-based calibration with flexible candidate generation
    • Extended CUDA quantization capabilities with per-block and global-level scaling controls

…antizer, NVFP4MSECalibrator

Signed-off-by: realAsma <akuriparambi@nvidia.com>

fp4 static kernel fix, test fixes, minor clean ups

Signed-off-by: realAsma <akuriparambi@nvidia.com>

minor

Signed-off-by: realAsma <akuriparambi@nvidia.com>

minor

Signed-off-by: realAsma <akuriparambi@nvidia.com>

minor

Signed-off-by: realAsma <akuriparambi@nvidia.com>

minor

Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma realAsma requested a review from a team as a code owner February 6, 2026 19:06
@realAsma realAsma requested a review from sugunav14 February 6, 2026 19:06
@coderabbitai
Contributor

coderabbitai bot commented Feb 6, 2026

Caution

Review failed

Failed to post review comments

📝 Walkthrough

This PR introduces a new FP4 static quantization path with two-level scaling support. Key additions include NVFP4StaticQuantizer for per-block and global amax scaling, NVFP4MSECalibrator for FP4-aware calibration, and a new scale_after_dequant algorithm. The tensor quantization layer is refactored with new FP4 casting and static blockwise quantization functions, while the Triton backend is extended with fp4_fake_quant_block for Hopper+ GPUs.

Changes

Cohort / File(s) / Summary

  • Core Quantizer Classes
    modelopt/torch/quantization/nn/modules/tensor_quantizer.py, modelopt/torch/quantization/tensor_quant.py
    Introduced NVFP4StaticQuantizer for two-level scaling with from_tensor_quantizer conversion, per-block and global amax support, and scale-after-dequant workflow. Added FP4CastSTEFunction for FP4 casting with STE backward. Refactored StaticBlockwiseFP4FakeQuantFunction signature from scale-based to amax-based parameters (amax, global_amax, quantize_block_scales).
  • Calibration Infrastructure
    modelopt/torch/quantization/calib/mse.py, modelopt/torch/quantization/model_calib.py
    Introduced NVFP4MSECalibrator with FP4-specific candidate generation (126 E4M3 scales) and per-block global amax support. Refactored MseCalibrator to use a generic candidate mechanism and removed the fp8_scale_sweep parameter. Added the scale_after_dequant function for MSE-based calibration with per-block scale application. Enhanced NVFP4 detection and conversion in mse_calibrate.
  • Configuration & Mode Registration
    modelopt/torch/quantization/config.py, modelopt/torch/quantization/mode.py, modelopt/torch/quantization/conversion.py
    Added ScaleAfterDequantConfig with a scale_algorithm field. Registered ScaleAfterDequantModeDescriptor in CalibrateModeRegistry. Added NVFP4StaticQuantizer import and conversion logic in restore_quantizer_state for seamless quantizer type migration.
  • Triton Backend
    modelopt/torch/quantization/triton/__init__.py, modelopt/torch/quantization/triton/fp4_kernel.py, modelopt/torch/quantization/triton/fp4_kernel_hopper.py
    Relaxed the IS_AVAILABLE guard to depend only on CUDA (the Hopper kernel is conditional on compute capability ≥ 8.9). Refactored static_blockwise_fp4_fake_quant to use amax/global_amax semantics with a quantize_block_scales flag. Added static_blockwise_fp4_cast for FP4 rounding. Introduced fp4_fake_quant_block for Hopper+ tiled FP4 quantization with block pointers and per-block max computation.
  • Test Suite
    tests/_test_utils/torch/quantization/quantize_common.py, tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py, tests/gpu/torch/quantization/test_quantize_cuda.py, tests/gpu/torch/quantization/test_tensor_quant_cuda.py, tests/unit/torch/quantization/test_mse_calibrator.py
    Added a test_cpu_restore parameter to save_restore_test for quantizer type preservation verification. Created comprehensive NVFP4StaticQuantizer and NVFP4MSECalibrator test coverage. Added test_scale_after_dequant_grad for gradient tracking validation. Updated Triton path tests to use the fp4_fake_quant_block attribute check and amax-based invocations. Removed the FP8 scale sweep test.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant scale_after_dequant as scale_after_dequant Function
    participant mse_calibrate as mse_calibrate
    participant NVFP4MSECalibrator
    participant Module as Module Quantizers
    participant NVFP4SQ as NVFP4StaticQuantizer
    
    User->>scale_after_dequant: Run with model, forward_loop, scale_algorithm
    scale_after_dequant->>mse_calibrate: Call with scale_algorithm config
    mse_calibrate->>Module: Detect NVFP4 quantizers
    mse_calibrate->>NVFP4MSECalibrator: Create with global_amax, FP4 candidates
    NVFP4MSECalibrator->>NVFP4MSECalibrator: Generate 126 FP8 E4M3 scales
    User->>NVFP4MSECalibrator: Collect calibration data via forward_loop
    NVFP4MSECalibrator->>NVFP4MSECalibrator: Aggregate losses per candidate
    NVFP4MSECalibrator->>NVFP4MSECalibrator: Compute optimal amax (per-block)
    scale_after_dequant->>Module: Extract per-block/per-tensor scales from quantizers
    scale_after_dequant->>NVFP4SQ: Apply scales via enable_scale_after_dequant
    NVFP4SQ->>NVFP4SQ: Store per_block_scale, per_tensor_scale (learnable)
    scale_after_dequant->>User: Return model in scale-after-dequant mode
sequenceDiagram
    participant Input as Input Tensor
    participant NVFP4SQ as NVFP4StaticQuantizer
    participant TritonKernel as Triton Kernel
    participant Output as Output (FP4)
    
    Input->>NVFP4SQ: _fake_quantize(x)
    alt scale_after_dequant enabled
        NVFP4SQ->>NVFP4SQ: Apply per_block_scale * per_tensor_scale
        NVFP4SQ->>TritonKernel: static_blockwise_fp4_fake_quant (amax, global_amax)
    else standard quantization
        NVFP4SQ->>TritonKernel: static_blockwise_fp4_fake_quant (amax, global_amax)
    end
    TritonKernel->>TritonKernel: Compute per-block max, quantize, descale
    TritonKernel->>Output: Return FP8 (stored as input dtype)
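
For context, a rough sketch of the two-level amax-to-scale arithmetic the diagrams refer to, following the common NVFP4 convention (E2M1 max = 6, E4M3 max = 448). The helper below is hypothetical; the actual logic lives in the Triton FP4 kernels and NVFP4StaticQuantizer.

import torch

E2M1_MAX = 6.0    # largest representable FP4 (E2M1) magnitude
E4M3_MAX = 448.0  # largest representable FP8 (E4M3) magnitude

def nvfp4_two_level_scales(w: torch.Tensor, block_size: int = 16):
    """Sketch: derive per-tensor and per-block scales from amax statistics."""
    # Assumes w.numel() is divisible by block_size and w is not all zeros.
    wb = w.reshape(-1, block_size)
    global_amax = wb.abs().max()
    block_amax = wb.abs().amax(dim=-1, keepdim=True)

    # Per-tensor scale covers the combined FP4 x FP8 dynamic range.
    per_tensor_scale = global_amax / (E2M1_MAX * E4M3_MAX)
    # Per-block scale in units of the per-tensor scale, stored as E4M3;
    # this is the quantity the PR makes learnable.
    per_block_scale = (
        block_amax / (E2M1_MAX * per_tensor_scale)
    ).to(torch.float8_e4m3fn).float()
    return per_block_scale, per_tensor_scale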

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 69.70%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title clearly and concisely describes the main feature, enabling learnable scale-after-dequant quantization for NVFP4 in post-training quantization (PTQ) and quantization-aware distillation (QAD) scenarios.



@realAsma realAsma changed the title from "NVFP4 Learn dequantize scale" to "NVFP4 learn scale_after_dequant" Feb 6, 2026
@realAsma realAsma changed the title from "NVFP4 learn scale_after_dequant" to "NVFP4 learn scale_after_dequant PTQ/QAD" Feb 6, 2026
@realAsma realAsma changed the base branch from main to asma/refactor-scale-sweep February 6, 2026 19:07
@realAsma
Contributor Author

realAsma commented Feb 6, 2026

Experiment Script:

"""Train learnable FP4 per-block scales for an NVFP4-quantized HuggingFace model."""

import argparse
import os
import warnings

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.distill.losses import LogitsDistillationLoss
from modelopt.torch.quantization.nn import NVFP4StaticQuantizer
from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader

mto.enable_huggingface_checkpointing()

NVFP4_WEIGHT_SCALE_LEARN_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {"enable": False},
    },
    "algorithm": {"method": "scale_after_dequant"},
}

DEFAULT_DATASET = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model", type=str, default="Qwen/Qwen3-8B")
    p.add_argument(
        "--calib_size",
        type=str,
        default="32",
        help="Calibration samples. Comma-separated list for multiple datasets.",
    )
    p.add_argument("--train_size", type=int, default=1024)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--seq_len", type=int, default=512)
    p.add_argument("--log_interval", type=int, default=1)
    p.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Dataset name or comma-separated list. Defaults to cnn_dailymail,nemotron-post-training-dataset-v2.",
    )
    p.add_argument("--eval_size", type=int, default=128, help="Number of eval samples.")
    p.add_argument(
        "--eval_steps",
        type=int,
        default=20,
        help="Run eval every N training steps. 0 = eval only at end of each epoch.",
    )
    p.add_argument("--min_lr", type=float, default=1e-7, help="Minimum LR for cosine annealing.")
    p.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.05,
        help="Fraction of total training steps used for linear LR warmup (default 5%%).",
    )
    p.add_argument("--seed", type=int, default=42, help="Random seed for train/eval split.")
    p.add_argument("--output_dir", type=str, default=None)
    return p.parse_args()


@torch.no_grad()
def evaluate(model, teacher_model, eval_dataloader, kd_loss_fn):
    """Run evaluation and return average KL-divergence loss."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    for batch in eval_dataloader:
        student_logits = model(**batch).logits
        teacher_logits = teacher_model(**batch).logits
        loss = kd_loss_fn(student_logits, teacher_logits)
        total_loss += loss.item()
        num_batches += 1
    model.train()
    return total_loss / max(num_batches, 1)


def main():
    args = parse_args()
    device = torch.device("cuda")

    # Parse dataset and calib_size (same logic as hf_ptq.py)
    if args.dataset is not None:
        dataset_name = args.dataset.split(",")
    else:
        dataset_name = DEFAULT_DATASET
        warnings.warn(
            f"No dataset specified. Defaulting to {','.join(DEFAULT_DATASET)}."
        )
    calib_size = [int(s) for s in args.calib_size.split(",")]
    # Extend calib_size to match number of datasets
    calib_size = (calib_size + [calib_size[-1]] * len(dataset_name))[: len(dataset_name)]

    ptq_dir = os.path.join(args.output_dir, "ptq") if args.output_dir else None

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # --- Phase 1: Quantize with NVFP4 + scale_after_dequant (or load cached PTQ) ---
    if ptq_dir and os.path.isdir(ptq_dir):
        print(f"Loading cached PTQ model from {ptq_dir}...")
        model = AutoModelForCausalLM.from_pretrained(
            ptq_dir, torch_dtype=torch.bfloat16, device_map="auto"
        )
        mtq.print_quant_summary(model)
    else:
        print(f"Loading model: {args.model}")
        model = AutoModelForCausalLM.from_pretrained(
            args.model, torch_dtype=torch.bfloat16, device_map="auto"
        )

        print(f"Building calibration dataloader (datasets={dataset_name}, calib_size={calib_size})...")
        calib_dataloader = get_dataset_dataloader(
            dataset_name=dataset_name,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_samples=calib_size,
            max_sample_length=args.seq_len,
            device=device,
        )
        calib_loop = create_forward_loop(dataloader=calib_dataloader)

        print("Quantizing model with NVFP4 + scale_after_dequant...")
        mtq.quantize(model, NVFP4_WEIGHT_SCALE_LEARN_CFG, calib_loop)

        mtq.print_quant_summary(model)

        if ptq_dir:
            print(f"Saving PTQ model to {ptq_dir}...")
            model.save_pretrained(ptq_dir)
            tokenizer.save_pretrained(ptq_dir)

    # --- Phase 2: Freeze all params, enable grad only for per_block_scale ---
    for p in model.parameters():
        p.requires_grad_(False)

    scale_params = []
    scale_param_names = []
    for name, module in model.named_modules():
        if isinstance(module, NVFP4StaticQuantizer) and module._scale_after_dequant:
            module._per_block_scale.requires_grad_(True)
            scale_params.append(module._per_block_scale)
            scale_param_names.append(name)

    print(f"Learnable scale parameters: {len(scale_params)}")
    total_scale_elements = sum(p.numel() for p in scale_params)
    print(f"Total learnable elements: {total_scale_elements}")

    optimizer = torch.optim.Adam(scale_params, lr=args.lr)

    # --- Phase 3: Load unquantized teacher model ---
    print("Loading reference (unquantized) teacher model...")
    teacher_model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.bfloat16, device_map="auto"
    )
    teacher_model.eval()
    for p in teacher_model.parameters():
        p.requires_grad_(False)

    kd_loss_fn = LogitsDistillationLoss()

    # --- Phase 4: Train ---
    total_samples = args.train_size + args.eval_size
    per_ds = [total_samples // len(dataset_name)] * len(dataset_name)
    per_ds[-1] += total_samples - sum(per_ds)  # distribute remainder to the last dataset

    print(
        f"Building combined dataloader ({total_samples} samples = "
        f"{args.train_size} train + {args.eval_size} eval, seed={args.seed})..."
    )
    combined_dataloader = get_dataset_dataloader(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_samples=per_ds,
        max_sample_length=args.seq_len,
        device=device,
        include_labels=False,  # Labels are not needed since we use KL Loss
    )

    # Split the underlying dataset into train / eval with a deterministic seed
    combined_dataset = combined_dataloader.dataset
    generator = torch.Generator().manual_seed(args.seed)
    train_dataset, eval_dataset = random_split(
        combined_dataset, [args.train_size, args.eval_size], generator=generator
    )
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False)

    # --- LR schedule: linear warmup + cosine annealing ---
    steps_per_epoch = len(train_dataloader)
    total_steps = steps_per_epoch * args.epochs
    warmup_steps = max(int(args.warmup_ratio * total_steps), 1)
    cosine_steps = total_steps - warmup_steps
    print(
        f"LR schedule: warmup {warmup_steps} steps -> cosine {cosine_steps} steps "
        f"(lr={args.lr}, min_lr={args.min_lr})"
    )

    warmup_scheduler = LinearLR(
        optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_steps
    )
    cosine_scheduler = CosineAnnealingLR(
        optimizer, T_max=max(cosine_steps, 1), eta_min=args.min_lr
    )
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_steps],
    )

    # --- TensorBoard ---
    tb_dir = os.path.join(args.output_dir, "tb_logs") if args.output_dir else "tb_logs"
    writer = SummaryWriter(log_dir=tb_dir)
    print(f"TensorBoard logs: {tb_dir}")

    # Enable gradient checkpointing to reduce memory usage during training.
    model.gradient_checkpointing_enable()
    # With every base weight frozen, checkpointed segments would see no inputs that
    # require grad and backward through them would be skipped; marking the embedding
    # weight as requiring grad keeps gradients flowing to the quantizer scale params.
    model.model.embed_tokens.weight.requires_grad_(True)

    # Run initial eval before training
    print("Running initial evaluation...")
    eval_loss = evaluate(model, teacher_model, eval_dataloader, kd_loss_fn)
    print(f"[Eval] Before training: eval_kl_loss={eval_loss:.6f}")
    writer.add_scalar("eval/kl_loss", eval_loss, 0)

    # Track the best eval loss and corresponding scale params
    best_eval_loss = eval_loss
    best_scale_state = {name: p.clone() for name, p in zip(scale_param_names, scale_params)}

    global_step = 0
    model.train()
    torch.set_grad_enabled(True)
    for epoch in range(args.epochs):
        total_loss = 0.0
        total_grad_norm = 0.0
        for step, batch in enumerate(train_dataloader):

            student_logits = model(**batch).logits

            with torch.no_grad():
                teacher_logits = teacher_model(**batch).logits

            loss = kd_loss_fn(student_logits, teacher_logits)

            loss.backward()

            # Compute grad norm before optimizer step clears gradients
            grad_norm = (
                sum(p.grad.norm().item() ** 2 for p in scale_params if p.grad is not None)
                ** 0.5
            )

            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            total_loss += loss.item()
            total_grad_norm += grad_norm
            global_step += 1

            # Per-step TensorBoard scalars
            current_lr = scheduler.get_last_lr()[0]
            writer.add_scalar("train/kl_loss", loss.item(), global_step)
            writer.add_scalar("train/grad_norm", grad_norm, global_step)
            writer.add_scalar("train/lr", current_lr, global_step)

            if (step + 1) % args.log_interval == 0:
                avg_loss = total_loss / args.log_interval
                avg_grad_norm = total_grad_norm / args.log_interval
                print(
                    f"Epoch {epoch+1} Step {step+1}: "
                    f"kl_loss={avg_loss:.6f}  grad_norm={avg_grad_norm:.3e}  lr={current_lr:.3e}"
                )
                total_loss = 0.0
                total_grad_norm = 0.0

            # Periodic eval within epoch
            if args.eval_steps > 0 and global_step % args.eval_steps == 0:
                eval_loss = evaluate(model, teacher_model, eval_dataloader, kd_loss_fn)
                is_best = eval_loss < best_eval_loss
                if is_best:
                    best_eval_loss = eval_loss
                    best_scale_state = {
                        n: p.clone() for n, p in zip(scale_param_names, scale_params)
                    }
                writer.add_scalar("eval/kl_loss", eval_loss, global_step)
                writer.add_scalar("eval/best_kl_loss", best_eval_loss, global_step)
                print(
                    f"[Eval] Epoch {epoch+1} Step {step+1} "
                    f"(global_step={global_step}): eval_kl_loss={eval_loss:.6f}"
                    f"{'  *best*' if is_best else ''}"
                )

        # Restore best scale parameters and evaluate end-of-epoch
        print(f"Restoring best scale params (eval_kl_loss={best_eval_loss:.6f})...")
        for name, module in model.named_modules():
            if name in best_scale_state:
                module._per_block_scale.data.copy_(best_scale_state[name])

        eval_loss = evaluate(model, teacher_model, eval_dataloader, kd_loss_fn)
        writer.add_scalar("eval/kl_loss_best_restored", eval_loss, epoch + 1)
        print(f"[Eval] End of epoch {epoch+1} (best restored): eval_kl_loss={eval_loss:.6f}")

    writer.close()

    # --- Phase 5: Save ---
    if args.output_dir:
        qad_dir = os.path.join(args.output_dir, "qad_scale")
        print(f"Saving trained model to {qad_dir}...")
        model.save_pretrained(qad_dir)
        tokenizer.save_pretrained(qad_dir)

    print("Done.")


if __name__ == "__main__":
    main()

@codecov

codecov bot commented Feb 6, 2026

Codecov Report

❌ Patch coverage is 29.34783% with 65 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.25%. Comparing base (d17bad3) to head (d215b70).

Files with missing lines Patch % Lines
modelopt/torch/quantization/model_calib.py 11.42% 31 Missing ⚠️
.../torch/quantization/nn/modules/tensor_quantizer.py 20.00% 28 Missing ⚠️
modelopt/torch/quantization/tensor_quant.py 50.00% 6 Missing ⚠️
Additional details and impacted files
@@                      Coverage Diff                      @@
##           asma/refactor-scale-sweep     #864      +/-   ##
=============================================================
- Coverage                      73.45%   73.25%   -0.20%     
=============================================================
  Files                            197      197              
  Lines                          20651    20743      +92     
=============================================================
+ Hits                           15169    15196      +27     
- Misses                          5482     5547      +65     


Base automatically changed from asma/refactor-scale-sweep to main February 6, 2026 19:47
@cjluo-nv
Collaborator

cjluo-nv commented Feb 9, 2026

This is cool. @realAsma do you have some results to share as well?
