[CK_BUILDER] Add reflection for wmma and bwd weight instances to ck builder reflection #3592
base: develop
It's a lot easier to get comments on a draft if you can explain the what and why in your draft PR description. 😜 I think I can see what you are doing here, but generally introduce what the PR is doing and what you want feedback on.
builder::ElementwiseOperation output_element_op;
- builder::GemmPadding gemm_padding;
+ std::optional<builder::GemmPadding> gemm_padding = std::nullopt;
This is really good, and you should lead your PR description with this change to ConvTraits (the "what" of the PR), as well as why we are making these optional now (the "why"). One question I have is where we should use std::optional versus using std::variant.
That's the design discussion we should focus on: how should ConvTraits be generalized for backward weights? This PR should update code comments and our reflect/README.md file so that everyone understands this important generalization.
On the std::optional vs std::variant part, I would use variant if there is an obvious either-or, like with the loop sched / blockGemmSched. For these fields, std::optional seems to be the obvious choice.
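To make the distinction in this thread concrete, here is a minimal sketch. All names in it (`ConvTraitsSketch`, the enum values, `describe`) are hypothetical stand-ins and are not taken from conv_traits.hpp. It assumes only what the comments above say: a field that may be absent for some kernels is a candidate for `std::optional`, while two mutually exclusive scheduler kinds are a candidate for `std::variant`.

```cpp
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-ins for the builder enums; the real definitions may differ.
enum class GemmPadding { Default, MNKPadding };
enum class LoopScheduler { Default, Interwave };
enum class BlockGemmScheduler { Intrawave, Interwave };

// Hypothetical, heavily reduced traits struct for illustration only.
struct ConvTraitsSketch
{
    // May be absent for some instances, so "no value" is a valid state.
    std::optional<GemmPadding> gemm_padding = std::nullopt;

    // Exactly one of the two scheduler kinds applies: an either-or.
    std::variant<LoopScheduler, BlockGemmScheduler> scheduler = LoopScheduler::Default;
};

// Summarize which fields are populated.
std::string describe(const ConvTraitsSketch& t)
{
    std::string out = t.gemm_padding.has_value() ? "padding=set" : "padding=absent";
    out += std::holds_alternative<LoopScheduler>(t.scheduler) ? ", sched=loop"
                                                               : ", sched=block_gemm";
    return out;
}
```

With `std::optional`, a consumer has to handle the empty case explicitly. With `std::variant`, the type itself rules out "both" and "neither".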
experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp
#pragma once

// Fwd instances
I'm trying to decide if we need this instance_to_conv_traits.hpp file. I left it in as part of the refactoring, but it may be that we can always use the specific includes. I worry some about files like this leading to longer compile times. Let me know your thoughts on this, and I'll consider this for my next refactoring and cleanup.
You mean splitting fwd and bwd?
I just meant that I don't know if we even need a file that includes all the functions to convert conv traits, or if we always have enough context to only include the specific function we need.
What I've read so far looks really good. My two suggestions for the PR description:
8936a15 to 71df5c5
…, _v3, grouped_conv_two_stage_wmma_cshuffle_v3,
f62204c to ee0494b
// Tensor Layouts
// ----------------------------------------------------------------------------

// Helper variable template to check if CK layout enums match
This is a nice use of a templated constexpr variable.
return std::array<builder::TensorLayout, 3>{in, weight, out};
};
// Helper lambda to construct layout array
auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
This is great!
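For readers without the full diff, the two helpers praised in the threads above could fit together roughly as follows. This is a sketch under stated assumptions: `CkLayout`, `TensorLayout`, `layouts_match_v`, and `conv_layouts_sketch` are hypothetical names and enum values, not the ones in conv_traits_helpers.hpp. Only the `layouts` lambda mirrors the line shown in the diff.

```cpp
#include <array>

// Hypothetical stand-ins for the CK layout enum and builder::TensorLayout.
enum class CkLayout { NHWGC, GKYXC, NHWGK, Unknown };
enum class TensorLayout { NHWGC, GKYXC, NHWGK, Unknown };

// Variable template: true only when all three layouts equal the expected triple.
template <CkLayout In, CkLayout Wei, CkLayout Out,
          CkLayout ExpIn, CkLayout ExpWei, CkLayout ExpOut>
inline constexpr bool layouts_match_v =
    (In == ExpIn) && (Wei == ExpWei) && (Out == ExpOut);

// Map a compile-time layout triple to an array of builder-style layouts.
template <CkLayout In, CkLayout Wei, CkLayout Out>
constexpr std::array<TensorLayout, 3> conv_layouts_sketch()
{
    // Helper lambda to construct the layout array (same shape as in the diff).
    auto layouts = [](auto... Ls) { return std::array<TensorLayout, 3>{Ls...}; };

    if constexpr (layouts_match_v<In, Wei, Out,
                                  CkLayout::NHWGC, CkLayout::GKYXC, CkLayout::NHWGK>)
        return layouts(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK);
    else
        return layouts(TensorLayout::Unknown, TensorLayout::Unknown, TensorLayout::Unknown);
}
```

The variable template keeps each `if constexpr` condition to a single readable expression, and the variadic lambda avoids repeating the `std::array<..., 3>` type in every branch.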
This PR adds reflection for wmma fwd kernels, as well as xdl and wmma bwd_weight kernels, to the ck builder.
To allow reflection of bwd weight kernels, several std::optional parameters were added to conv_traits.hpp. These include:
Further, several helper functions were introduced to reduce duplication. These can be found in conv_traits_helpers.hpp. For this abstraction to work, several names in the instance_traits had to be changed to be more consistent.
The conv_data_type function previously took an instance_traits object as a template parameter and accessed its ADataType member. It now takes a data type directly, because bwd_weight instances do not have ADataType, and is essentially a conversion to a builder::DataType value.
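The conv_data_type change could look roughly like the sketch below. Everything here is assumed for illustration: the `DataType` enum values, the function name `conv_data_type_sketch`, and the two traits structs (including the `InDataType` member name) are hypothetical, and the sketch uses standard C++ types instead of CK's own type aliases.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for builder::DataType.
enum class DataType { FP32, FP64, I8, Undefined };

// Takes the element type directly instead of reading ADataType from instance traits.
template <typename T>
constexpr DataType conv_data_type_sketch()
{
    if constexpr (std::is_same_v<T, float>)
        return DataType::FP32;
    else if constexpr (std::is_same_v<T, double>)
        return DataType::FP64;
    else if constexpr (std::is_same_v<T, std::int8_t>)
        return DataType::I8;
    else
        return DataType::Undefined;
}

// Hypothetical traits shapes: the caller chooses which member type to pass in.
struct FwdInstanceTraitsSketch { using ADataType = float; };
struct BwdWeightInstanceTraitsSketch { using InDataType = std::int8_t; };

static_assert(conv_data_type_sketch<FwdInstanceTraitsSketch::ADataType>() == DataType::FP32);
static_assert(conv_data_type_sketch<BwdWeightInstanceTraitsSketch::InDataType>() == DataType::I8);
```

Because the function no longer depends on a particular member name, the same conversion works for both forward and backward-weight traits.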