Conversation
Commits (each Signed-off-by: realAsma <akuriparambi@nvidia.com>):
- …antizer, NVFP4MSECalibrator
- fp4 static kernel fix, test fixes, minor clean ups
- minor
- minor
- minor
- minor
Caution: Review failed. Failed to post review comments.

📝 Walkthrough

This PR introduces a new FP4 static quantization path with two-level scaling support. Key additions include NVFP4StaticQuantizer for per-block and global amax scaling, NVFP4MSECalibrator for FP4-aware calibration, and a new scale_after_dequant algorithm. The tensor quantization layer is refactored with new FP4 casting and static blockwise quantization functions, while the Triton backend is extended with fp4_fake_quant_block for Hopper+ GPUs.
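To make the two-level scaling concrete, here is a minimal PyTorch sketch of how NVFP4-style per-block plus per-tensor scales can be derived. This is an illustration only, not the PR's kernel; the block size of 16 and the E2M1/E4M3 range constants are assumptions based on the NVFP4 format:

```python
import torch

E2M1_MAX = 6.0    # largest magnitude representable in FP4 (E2M1)
E4M3_MAX = 448.0  # largest finite magnitude in FP8 (E4M3)

def two_level_scales(w: torch.Tensor, block_size: int = 16):
    """Derive NVFP4-style two-level scales for a weight tensor.

    Level 1: one FP8 (E4M3) scale per block of `block_size` elements.
    Level 2: one FP32 per-tensor scale that keeps the level-1 scales
    inside the representable E4M3 range.
    """
    blocks = w.reshape(-1, block_size)
    block_amax = blocks.abs().amax(dim=-1)   # per-block max magnitude
    global_amax = w.abs().amax()             # per-tensor max magnitude

    # Chosen so the largest per-block scale lands at E4M3_MAX.
    per_tensor_scale = global_amax / (E2M1_MAX * E4M3_MAX)
    # Stored in E4M3; maps each block's values onto the E2M1 grid.
    per_block_scale = (block_amax / E2M1_MAX) / per_tensor_scale
    per_block_scale = per_block_scale.to(torch.float8_e4m3fn).float()
    return per_block_scale, per_tensor_scale
```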
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant scale_after_dequant as scale_after_dequant Function
    participant mse_calibrate as mse_calibrate
    participant NVFP4MSECalibrator
    participant Module as Module Quantizers
    participant NVFP4SQ as NVFP4StaticQuantizer
    User->>scale_after_dequant: Run with model, forward_loop, scale_algorithm
    scale_after_dequant->>mse_calibrate: Call with scale_algorithm config
    mse_calibrate->>Module: Detect NVFP4 quantizers
    mse_calibrate->>NVFP4MSECalibrator: Create with global_amax, FP4 candidates
    NVFP4MSECalibrator->>NVFP4MSECalibrator: Generate 126 FP8 E4M3 scales
    User->>NVFP4MSECalibrator: Collect calibration data via forward_loop
    NVFP4MSECalibrator->>NVFP4MSECalibrator: Aggregate losses per candidate
    NVFP4MSECalibrator->>NVFP4MSECalibrator: Compute optimal amax (per-block)
    scale_after_dequant->>Module: Extract per-block/per-tensor scales from quantizers
    scale_after_dequant->>NVFP4SQ: Apply scales via enable_scale_after_dequant
    NVFP4SQ->>NVFP4SQ: Store per_block_scale, per_tensor_scale (learnable)
    scale_after_dequant->>User: Return model in scale-after-dequant mode
```
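The "126 FP8 E4M3 scales" step in the diagram above likely corresponds to sweeping every positive, finite, nonzero E4M3 value as a per-block scale candidate (E4M3 has exactly 126 of them). Below is a hedged sketch of such an MSE search; fp4_fake_quant is a plain-PyTorch stand-in for the real quantization kernel, not the PR's implementation:

```python
import torch

# The 8 non-negative magnitudes representable in FP4 (E2M1).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_fake_quant(x: torch.Tensor) -> torch.Tensor:
    """Round each element to the nearest E2M1 value (stand-in kernel)."""
    mags = E2M1_VALUES.to(x.dtype)
    idx = (x.abs().unsqueeze(-1) - mags).abs().argmin(dim=-1)
    return mags[idx] * torch.sign(x)

def e4m3_candidates() -> torch.Tensor:
    """All 126 positive, finite, nonzero FP8 E4M3 values."""
    codes = torch.arange(1, 127, dtype=torch.uint8)  # skip 0 and the NaN code
    return codes.view(torch.float8_e4m3fn).float()

def mse_search(blocks: torch.Tensor, per_tensor_scale: float) -> torch.Tensor:
    """Per block, pick the candidate scale minimizing round-trip MSE."""
    best_loss = torch.full((blocks.shape[0],), float("inf"))
    best_scale = torch.ones(blocks.shape[0])
    for s in e4m3_candidates():
        scale = s * per_tensor_scale                  # effective block scale
        deq = fp4_fake_quant(blocks / scale) * scale  # quantize, dequantize
        loss = ((deq - blocks) ** 2).mean(dim=-1)     # MSE per block
        win = loss < best_loss
        best_loss = torch.where(win, loss, best_loss)
        best_scale = torch.where(win, scale, best_scale)
    return best_scale
```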
```mermaid
sequenceDiagram
    participant Input as Input Tensor
    participant NVFP4SQ as NVFP4StaticQuantizer
    participant TritonKernel as Triton Kernel
    participant Output as Output (FP4)
    Input->>NVFP4SQ: _fake_quantize(x)
    alt scale_after_dequant enabled
        NVFP4SQ->>NVFP4SQ: Apply per_block_scale * per_tensor_scale
        NVFP4SQ->>TritonKernel: static_blockwise_fp4_fake_quant (amax, global_amax)
    else standard quantization
        NVFP4SQ->>TritonKernel: static_blockwise_fp4_fake_quant (amax, global_amax)
    end
    TritonKernel->>TritonKernel: Compute per-block max, quantize, descale
    TritonKernel->>Output: Return FP8 (stored as input dtype)
```
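In plain PyTorch terms, the scale-after-dequant branch of this diagram might look like the sketch below. This is an approximation only; the actual path is the Triton static_blockwise_fp4_fake_quant kernel, and fp4_fake_quant is the stand-in defined in the previous sketch:

```python
import torch
import torch.nn as nn

class ScaleAfterDequantFakeQuant(nn.Module):
    """Static blockwise FP4 fake quant with scaling applied after dequant.

    Per the PR description: the per-block scale is a learnable
    nn.Parameter, while the per-tensor scale stays frozen (a buffer).
    """

    def __init__(self, per_block_scale, per_tensor_scale, block_size=16):
        super().__init__()
        self.per_block_scale = nn.Parameter(per_block_scale)        # learnable
        self.register_buffer("per_tensor_scale", per_tensor_scale)  # frozen
        self.block_size = block_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        blocks = x.reshape(-1, self.block_size)
        scale = (self.per_block_scale * self.per_tensor_scale).unsqueeze(-1)
        q = fp4_fake_quant(blocks / scale)   # values on the E2M1 grid
        return (q * scale).reshape(x.shape)  # descale after dequant
```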
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
scale_after_dequant PTQ/QAD

Experiment Script:
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```
@@               Coverage Diff                @@
##    asma/refactor-scale-sweep     #864      +/-   ##
=============================================================
- Coverage                  73.45%   73.25%   -0.20%
=============================================================
  Files                        197      197
  Lines                      20651    20743      +92
=============================================================
+ Hits                       15169    15196      +27
- Misses                      5482     5547      +65
```

☔ View full report in Codecov by Sentry.
This is cool. @realAsma do you have some results to share as well?
What does this PR do?
Type of change: New feature (quantization algorithm) + new tests + bug fix
Overview:
Adds a new quantization algorithm, scale_after_dequant, for NVFP4 static weight quantizers: the per-block scale is learnable (nn.Parameter) and the per-tensor scale is frozen.

Usage
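The usage snippet did not survive extraction. As a rough stand-in, a call along the lines of modelopt's usual mtq.quantize flow might look like the following; the config and algorithm spellings here are assumptions, not the PR's confirmed API:

```python
import modelopt.torch.quantization as mtq

# Assumed invocation: NVFP4 quantization with the new algorithm.
# `model` and `calibration_forward_loop` are user-provided.
config = dict(mtq.NVFP4_DEFAULT_CFG)
config["algorithm"] = "scale_after_dequant"  # hypothetical algorithm key

model = mtq.quantize(model, config, forward_loop=calibration_forward_loop)
```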
Testing
pytest -q tests/gpu/torch/quantization/test_quantize_cuda.py -k scale_after_dequant

Before your PR is "Ready for review"
Additional Information
- ScaleAfterDequantModeDescriptor and config added as ScaleAfterDequantConfig.
- modelopt/torch/quantization/model_calib.py (scale_after_dequant).
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py.
- modelopt/torch/quantization/tensor_quant.py + modelopt/torch/quantization/triton/fp4_kernel.py.

Summary by CodeRabbit
New Features
Improvements