-
Notifications
You must be signed in to change notification settings - Fork 268
[CK TILE] Add gemm basic v1 interwave pipeline #3616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
94ea368 to
4010341
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds a new interwave pipeline implementation for GEMM operations in the CK Tile library. The changes refactor the existing pipeline code to support multiple scheduling strategies (Intrawave and Interwave) through template specialization.
Changes:
- Introduces
GemmPipelineScheduler::Interwavespecialization forGemmPipelineAGmemBGmemCRegV1 - Refactors pipeline implementations to use a common base class (
GemmPipelineAgBgCrImplBase) - Adds validation utilities with configurable tolerance values and max error tracking
- Updates profiler output to include maximum error information
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp | Adds max error reporting to validation output |
| include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp | Refactors V2 pipeline to use base class and extract window creation logic |
| include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp | Adds Interwave scheduler specialization and refactors V1 pipeline structure |
| experimental/builder/include/ck_tile/builder/testing/validation.hpp | Adds tolerance getter functions and max error tracking to validation |
| example/ck_tile/20_grouped_convolution/conv_configs.hpp | Adds pipeline type trait definitions for BASIC_V1 and BASIC_V2 |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| else if constexpr(DataType::FP16== DT) | ||
|
|
||
| { |
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space before == operator and unnecessary blank line. Should be DataType::FP16 == DT on line 41, and remove the blank line 42.
| else if constexpr(DataType::BF16 == DT) | ||
|
|
||
| { |
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unnecessary blank line between the condition and opening brace. Remove line 47 for consistency.
| // { | ||
| // return 1e-6; | ||
| // } | ||
| else if constexpr(DataType::FP16== DT) |
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space before == operator. Should be DataType::FP16 == DT.
|
|
||
| auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0]; | ||
| auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1]; | ||
| auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2]; |
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect pointer arithmetic. The buffer contains two uint64_t values followed by a double, so the offset calculation is wrong. Index [2] treats the buffer as an array of doubles, but the first 16 bytes are uint64_t values. Should calculate the byte offset correctly or cast to the appropriate type at the correct position.
| auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2]; | |
| auto d_max_error = | |
| reinterpret_cast<double*>(static_cast<char*>(d_counters.get()) + 2 * sizeof(uint64_t)); |
| .wrong_elements = error_count, | ||
| .total_elements = descriptor.get_element_size(), | ||
| .zero_elements = zero_count, | ||
| .max_error= max_error, |
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing space before = operator. Should be .max_error = max_error,.
Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered