⚠ This page is served via a proxy. Original site: https://github.com
This service does not collect credentials or authentication data.
Skip to content

Conversation

@DDEle
Copy link
Contributor

@DDEle DDEle commented Jan 20, 2026

Proposed changes

This PR fixes an integer overflow issue in the deterministic FMHA backward implementation by changing batch/nhead stride variables from index_t (32-bit) to long_index_t (64-bit) for the dq_acc. See also ROCm/aiter#1873.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

@DDEle DDEle requested a review from Copilot January 20, 2026 05:35
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes an integer overflow issue in the deterministic FMHA (Fused Multi-Head Attention) backward pass implementation by changing stride variables from index_t (32-bit) to long_index_t (64-bit) for the dq_acc accumulator tensor.

Changes:

  • Changed nhead_stride_dq_acc and batch_stride_dq_acc from ck_tile::index_t to ck_tile::long_index_t across kernel definitions and argument structures
  • Restructured the dq_acc tensor layout to be [batch, nhead, nsplits, seqlen_q, hdim_q] instead of conditional layouts
  • Updated stride calculations to use proper casting and correct dimension ordering

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp Changed type of nhead_stride_dq_acc and batch_stride_dq_acc from index_t to long_index_t in kernel structs and function signatures
example/ck_tile/01_fmha/fmha_bwd_runner.hpp Simplified dq_acc tensor layout, corrected stride calculations with proper casting, and reordered computation of stride variables
example/ck_tile/01_fmha/fmha_bwd.hpp Updated argument structure to use long_index_t for nhead_stride_dq_acc and batch_stride_dq_acc

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@DDEle DDEle merged commit fcc9372 into develop Jan 21, 2026
26 checks passed
@DDEle DDEle deleted the fix-fmha-det-overflow branch January 21, 2026 01:54
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this dq_acc layout change from (nsplit, B, H, S, D) to (B, H, nsplit, S, D) due to the fact that only batch_stride_dq_acc and nhead_stride_dq_acc are promoted to long_index_t?

In fact, in order to satisfy some aiter asm bwd requirement, our TE bwd had to make dq_acc layout as (nsplits, B, H, S, D). Can we also promote the nsplit_stride_dq_acc as well?

Copy link
Contributor Author

@DDEle DDEle Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Promoting nsplits_stride_dq_acc would introduce extra overhead as a thread block of the reduction & convert kernel needs access multiple splits in its hotloop (thus it is part of the ck tile window rather than simple pointer arithmetic at the beginning of the kernel).

The layout of dq_acc is kind of in the scope of kernel implementation and I guess it's better to change it accordingly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants