Made a new branch to add necessary files for batched_gemm_a16w8_block…#2297

Open
nidal567 wants to merge 2 commits into ROCm:main from nidal567:batched_gemm_a16w8_blockscale

Conversation

@nidal567
Contributor

Motivation

This PR introduces a Triton implementation of a batched GEMM A16W8 blockscale kernel to support mixed-precision matrix multiplication with FP16/BF16 activations and FP8 block-scaled weights. The goal is to execute quantized workloads efficiently while maintaining high compute utilization.

Technical Details

This PR adds a 2D-tiled Triton GEMM kernel implementing batched matrix multiplication with block-scale dequantization. Main components include:

  • Kernel implementation
    • Performs on-the-fly FP8 weight dequantization with block scaling
    • Uses 2D tiling (MxN tiles with K-loop) and MFMA tensor instructions
  • Wrapper
    • Python wrapper that exposes the kernel through the operator interface
  • Configuration
    • GEMM config JSON file used to tune tile sizes, balancing MFMA utilization and memory bandwidth
  • Benchmark
    • Benchmark script for measuring TFLOPs throughput across given shapes
  • Testing
    • Unit test file to validate kernel correctness

Performance observations:

  • On AMD Instinct MI350X, the kernel achieves ~645 TFLOPs (~60% of theoretical peak) for large workloads (B=16, M=N=K=8192).
  • Roofline analysis shows the kernel is compute-bound, primarily due to the additional FP8 conversion and blockscale operations required in the A16W8 path.
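For reference, the TFLOPs figures quoted here and in the benchmark table below follow the standard 2·B·M·N·K FLOP count for batched GEMM (the helper name is illustrative, not from the PR):

```python
def gemm_tflops(B, M, N, K, time_ms):
    """TFLOPs from the standard 2*B*M*N*K FLOP count for a batched GEMM
    taking `time_ms` milliseconds."""
    flops = 2.0 * B * M * N * K
    return flops / (time_ms * 1e-3) / 1e12
```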

Future optimization opportunities:

  • Exploring split-K strategies for the reduction loop to increase parallelism and improve MFMA utilization
  • Potential improvements in register pressure and instruction scheduling
  • Additional tuning of tile sizes and pipeline structure

Test Plan

The following tests were executed to validate correctness and functionality:

  • Unit tests using the existing pytest framework
  • Kernel outputs compared against reference GEMM implementations
    • Tracing and profiling the kernel across representative shapes to find optimal tuning parameters
  • Benchmarks executed to verify runtime behavior and performance
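Because FP16/BF16 accumulation error grows with K, the reference comparison uses a relative tolerance rather than exact equality. A minimal sketch of such a check (the helper name and thresholds are assumptions, not the PR's test code):

```python
import numpy as np

def assert_close(out, ref, rtol=1e-2, atol=1e-2):
    """Hypothetical tolerance check for mixed-precision GEMM output:
    element-wise |out - ref| must stay within atol + rtol * |ref|."""
    err = np.abs(out - ref)
    bound = atol + rtol * np.abs(ref)
    assert np.all(err <= bound), f"max abs err {err.max():.3e}"
```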

Test Result

  • All 173 pytest tests passed successfully
  • Kernel outputs match expected results for tested configurations
  • Benchmarks confirm stable execution, though performance is lower than expected because the kernel is compute-bound

@nidal567 nidal567 requested a review from a team March 16, 2026 13:43
@nidal567 nidal567 force-pushed the batched_gemm_a16w8_blockscale branch from 77a11c0 to 3f2e822 Compare March 16, 2026 13:56
@nidal567
Contributor Author

I ran benchmarks using the DeepSeek-R1 model shapes.

  • Base shapes: (B=128, N=512, K=128) and (B=128, N=128, K=512)
  • Batch sizes are scaled by the TP degree (1, 2, 4, and 8), which divides the batch dimension B: 128, 64, 32, 16
  • M dimension: 1, 32, 128, covering small, medium, and large row dimensions to span a realistic range of matrix shapes without exploding the number of runs
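The shape grid described above can be enumerated in a few lines (a sketch of the sweep, not the benchmark script itself):

```python
# Two DeepSeek-R1 base shapes (N, K), B divided by the TP degree,
# three M values -> 2 * 4 * 3 = 24 benchmark shapes.
base_shapes = [(512, 128), (128, 512)]
tp_degrees = [1, 2, 4, 8]
m_values = [1, 32, 128]

shapes = [(128 // tp, m, n, k)
          for n, k in base_shapes
          for tp in tp_degrees
          for m in m_values]
```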

Here are the values:

| B | M | N | K | TFLOPs |
|-----|-----|-----|-----|--------|
| 16 | 1 | 512 | 128 | 0.069 |
| 16 | 1 | 128 | 512 | 0.071 |
| 16 | 32 | 512 | 128 | 2.251 |
| 16 | 32 | 128 | 512 | 2.313 |
| 16 | 128 | 512 | 128 | 8.771 |
| 16 | 128 | 128 | 512 | 8.857 |
| 32 | 1 | 512 | 128 | 0.139 |
| 32 | 1 | 128 | 512 | 0.140 |
| 32 | 32 | 512 | 128 | 4.455 |
| 32 | 32 | 128 | 512 | 4.564 |
| 32 | 128 | 512 | 128 | 18.331 |
| 32 | 128 | 128 | 512 | 18.339 |
| 64 | 1 | 512 | 128 | 0.285 |
| 64 | 1 | 128 | 512 | 0.280 |
| 64 | 32 | 512 | 128 | 8.920 |
| 64 | 32 | 128 | 512 | 8.941 |
| 64 | 128 | 512 | 128 | 36.204 |
| 64 | 128 | 128 | 512 | 35.727 |
| 128 | 1 | 512 | 128 | 0.409 |
| 128 | 1 | 128 | 512 | 0.508 |
| 128 | 32 | 512 | 128 | 12.700 |
| 128 | 32 | 128 | 512 | 15.891 |
| 128 | 128 | 512 | 128 | 43.844 |
| 128 | 128 | 128 | 512 | 58.690 |

@nidal567 nidal567 force-pushed the batched_gemm_a16w8_blockscale branch from 2127150 to 98fad3a Compare March 18, 2026 00:06