[CK_TILE] Fix Int32 Overflow in Deterministic FMHA BWD #3615
Pull request overview
This PR fixes an integer overflow issue in the deterministic FMHA (Fused Multi-Head Attention) backward pass implementation by changing stride variables from index_t (32-bit) to long_index_t (64-bit) for the dq_acc accumulator tensor.
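To make the overflow concrete, here is a minimal sketch with hypothetical sizes that are not taken from the PR. The aliases `index_t` and `long_index_t` are local stand-ins for the 32-bit and 64-bit CK Tile types named above, and the helper names `batch_offset` and `fits_in_index_t` are illustrative only.

```cpp
#include <cstdint>
#include <limits>

// Local stand-ins for the 32-bit and 64-bit index types
// (assumption: these mirror ck_tile::index_t / ck_tile::long_index_t).
using index_t      = std::int32_t;
using long_index_t = std::int64_t;

// Element offset of batch `i_batch` in a contiguous dq_acc buffer.
// The first operand is widened before any multiplication, so every
// intermediate product is computed in 64-bit.
constexpr long_index_t batch_offset(index_t i_batch,
                                    index_t nhead,
                                    index_t nsplits,
                                    index_t seqlen_q,
                                    index_t hdim_q)
{
    return static_cast<long_index_t>(i_batch) * nhead * nsplits * seqlen_q * hdim_q;
}

// True when a 64-bit value can be stored in index_t without wrapping.
constexpr bool fits_in_index_t(long_index_t v)
{
    return v >= std::numeric_limits<index_t>::min() &&
           v <= std::numeric_limits<index_t>::max();
}
```

With the hypothetical sizes nhead = 32, nsplits = 8, seqlen_q = 16384, and hdim_q = 128, the stride of one batch is 536,870,912 elements, which still fits in 32 bits, but the offset of batch 4 is 2,147,483,648, one past the largest `int32` value.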
Changes:
- Changed `nhead_stride_dq_acc` and `batch_stride_dq_acc` from `ck_tile::index_t` to `ck_tile::long_index_t` across kernel definitions and argument structures
- Restructured the `dq_acc` tensor layout to be `[batch, nhead, nsplits, seqlen_q, hdim_q]` instead of conditional layouts
- Updated stride calculations to use proper casting and correct dimension ordering
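The stride computation for the `[batch, nhead, nsplits, seqlen_q, hdim_q]` layout can be sketched as follows. This is not the code from `fmha_bwd_runner.hpp`: the struct `DqAccStrides`, the function `make_dq_acc_strides`, and the field names `stride_dq_acc` and `split_stride_dq_acc` are hypothetical, and the type aliases are local stand-ins for the CK Tile types. It assumes a contiguous buffer and that `seqlen_q * hdim_q` fits in 32 bits.

```cpp
#include <cstdint>

using index_t      = std::int32_t; // stand-in for ck_tile::index_t
using long_index_t = std::int64_t; // stand-in for ck_tile::long_index_t

// Hypothetical container for the dq_acc strides, in elements.
struct DqAccStrides
{
    index_t stride_dq_acc;            // step between rows of seqlen_q
    index_t split_stride_dq_acc;      // step between splits, kept 32-bit
    long_index_t nhead_stride_dq_acc; // step between heads, 64-bit
    long_index_t batch_stride_dq_acc; // step between batches, 64-bit
};

// Build the strides innermost-first, widening to 64-bit before the
// multiplications that may exceed the 32-bit range.
inline DqAccStrides make_dq_acc_strides(index_t nhead,
                                        index_t nsplits,
                                        index_t seqlen_q,
                                        index_t hdim_q)
{
    DqAccStrides s{};
    s.stride_dq_acc       = hdim_q;
    s.split_stride_dq_acc = seqlen_q * hdim_q; // assumed to fit in 32 bits
    s.nhead_stride_dq_acc = static_cast<long_index_t>(nsplits) * s.split_stride_dq_acc;
    s.batch_stride_dq_acc = static_cast<long_index_t>(nhead) * s.nhead_stride_dq_acc;
    return s;
}
```

Casting the first operand, rather than the finished product, is what matters here: a cast applied after a 32-bit multiplication would widen a value that has already wrapped.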
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp | Changed type of nhead_stride_dq_acc and batch_stride_dq_acc from index_t to long_index_t in kernel structs and function signatures |
| example/ck_tile/01_fmha/fmha_bwd_runner.hpp | Simplified dq_acc tensor layout, corrected stride calculations with proper casting, and reordered computation of stride variables |
| example/ck_tile/01_fmha/fmha_bwd.hpp | Updated argument structure to use long_index_t for nhead_stride_dq_acc and batch_stride_dq_acc |
f157fc9 to d321f3a
    - i_perm
    - ? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
    - : std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
    + std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
Is the dq_acc layout change from (nsplit, B, H, S, D) to (B, H, nsplit, S, D) needed because only batch_stride_dq_acc and nhead_stride_dq_acc are promoted to long_index_t?
To satisfy an aiter asm bwd requirement, our TE bwd had to use the (nsplits, B, H, S, D) layout for dq_acc. Can we promote nsplit_stride_dq_acc as well?
Yes. Promoting nsplits_stride_dq_acc would introduce extra overhead, because a thread block of the reduction & convert kernel needs to access multiple splits in its hot loop (so that stride is part of the ck tile window rather than simple pointer arithmetic at the beginning of the kernel).
The layout of dq_acc is more or less within the scope of the kernel implementation, so I think it's better to change it accordingly.
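The trade-off described in this reply can be illustrated with a host-side sketch. It is not the CK Tile reduction & convert kernel: `reduce_splits` and its parameters are hypothetical, the aliases are local stand-ins, and it assumes a contiguous `[batch, nhead, nsplits, seqlen_q, hdim_q]` buffer in which `nsplits * seqlen_q * hdim_q` fits in 32 bits. The 64-bit batch and head offsets are applied once to form a base pointer, and the loop over splits then uses only 32-bit offsets relative to that base.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using index_t      = std::int32_t; // stand-in for ck_tile::index_t
using long_index_t = std::int64_t; // stand-in for ck_tile::long_index_t

// Sum dq_acc over all splits for one (batch, head) pair.
inline std::vector<float> reduce_splits(const std::vector<float>& dq_acc,
                                        index_t i_batch,
                                        index_t i_nhead,
                                        index_t nhead,
                                        index_t nsplits,
                                        index_t seqlen_q,
                                        index_t hdim_q)
{
    const index_t split_stride      = seqlen_q * hdim_q; // assumed to fit in 32 bits
    const long_index_t nhead_stride = static_cast<long_index_t>(nsplits) * split_stride;
    const long_index_t batch_stride = static_cast<long_index_t>(nhead) * nhead_stride;

    // 64-bit pointer arithmetic happens once, before the loop.
    const float* base = dq_acc.data() + i_batch * batch_stride + i_nhead * nhead_stride;

    std::vector<float> dq(static_cast<std::size_t>(split_stride), 0.0f);

    // Inside the loop, every offset relative to `base` stays in 32-bit.
    for(index_t i_split = 0; i_split < nsplits; ++i_split)
    {
        for(index_t i = 0; i < split_stride; ++i)
        {
            dq[i] += base[i_split * split_stride + i];
        }
    }
    return dq;
}
```

Under the earlier (nsplits, B, H, S, D) layout, the split stride would span the whole batch-by-head extent, so the in-loop offset is the one that could exceed 32 bits; placing nsplits inside the head dimension keeps it small.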
Proposed changes
This PR fixes an integer overflow issue in the deterministic FMHA backward implementation by changing batch/nhead stride variables from index_t (32-bit) to long_index_t (64-bit) for the dq_acc. See also ROCm/aiter#1873.
Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- `clang-format` on all changed files