fix(kernel): fix NaN in use_gate_in_kernel=True path (cache_key, tail mask)#43
meinie0826 wants to merge 3 commits into inclusionAI:main
Conversation
@icavanyu Could you take a look at this PR when you have time and give some feedback?
Hi, thank you for your contribution! As the filename suggests, cula/ops/kda_fully_fused_wip.py is still a work in progress, and there may be bugs, accuracy issues, or performance issues in it. The KDA fully fused version requires TF32 sub-chunk MMA and TF32 inverse. Would you be interested in tackling these two issues?
Thanks for the heads-up! I'm definitely interested in tackling both the TF32 sub-chunk MMA and the TF32 inverse, and happy to follow up on those as next steps. That said, this PR fixes two independent correctness bugs (cache_key missing g_dtype, and tail-chunk boundary masking) that were causing NaN outputs across all use_gate_in_kernel=True test cases. All 19 tests now pass on B300, which is a meaningful step forward for the WIP kernel. Would it be possible to merge this NaN-fix PR first? I'll open a follow-up PR to address the TF32 sub-chunk MMA and TF32 inverse improvements.
- add g_dtype to cache_key to prevent reusing the wrong compiled kernel when the g dtype differs between the use_gate_in_kernel=True/False paths
- apply a boundary mask for the non-varlen tail chunk in the subchunk epilogue (valid_len_chunk was hardcoded to C, so the tail's padding was never masked)
- remove xfail markers from the use_gate_in_kernel=True test cases
📌 Description
Fix NaN outputs in the use_gate_in_kernel=True code path of the Blackwell KDA fused forward kernel (KDAChunkwise).
Root causes identified and fixed:
cache_key missing g_dtype (blackwell_fused_fwd.py): When use_gate_in_kernel=False, g is float32; when use_gate_in_kernel=True, kda_gate_chunk_cumsum may return a different dtype. Without g_dtype in the cache key, the first compiled kernel (built for float32) was silently reused for the other dtype, so TMA read SMEM with the wrong element size → NaN.
Missing boundary mask for the non-varlen tail chunk in the subchunk epilogue: valid_len_chunk was hardcoded to C for non-varlen sequences, so the last partial chunk (e.g. T=1500 leaves a 28-token tail) was never masked in apply_qk_kk_mask, leaving invalid MMA accumulator values → NaN.
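The tail-chunk length computation can be sketched as below; `C` and `chunk_idx` are illustrative names, and a chunk size of 64 is assumed only to reproduce the 28-token tail from the description:

```python
# Hypothetical sketch of the boundary-mask fix. For non-varlen sequences
# valid_len_chunk was hardcoded to the chunk size C; the last partial
# chunk must instead be clipped to the tokens that actually remain, so
# padded positions can be masked out of the MMA accumulator.
def valid_len_chunk(T, C, chunk_idx):
    """Number of valid tokens in chunk `chunk_idx` of a length-T sequence."""
    return min(C, T - chunk_idx * C)
```

Assuming C=64, a T=1500 sequence has 23 full chunks and a final chunk of only 28 valid tokens, which is exactly the tail case that was going unmasked.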
All 19 tests passed.
🔍 Related Issues
#16
🚀 Pull Request Checklist
Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Ran pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.
🧪 Tests
⚡ Performance
Reviewer Notes