
Conversation

@kabrahamAMD
Contributor

@kabrahamAMD kabrahamAMD commented Jan 16, 2026

This PR adds reflection for wmma fwd kernels, as well as xdl and wmma bwd_weight kernels, to the ck builder.

To allow reflection of bwd weight kernels, several std::optional parameters were added to conv_traits.hpp. These include:

  • gemm_padding: previously existing parameter, made optional as not all kernels have it
  • num_gemm_k_prefetch_stage: added as optional; used by wmma kernels
  • max_transpose_transfer_src_scalar_per_vector/max_transpose_dst_scalar_per_vector: used by bwd_weight kernels
  • num_groups_to_merge: used by two_stage wmma and xdl kernels

Further, several helper functions were introduced to reduce duplication. These can be found in conv_traits_helpers.hpp. For this abstraction to work, several names in the instance_traits had to be changed to be more consistent.
The conv_data_type function previously took an instance_traits object as a template parameter and accessed its ADataType member; it now takes a data type directly (since bwd_weight instances have no ADataType) and is essentially a conversion to a builder::DataType object.

@Snektron Snektron changed the title [CK Builder] Add reflection for wmma and bwd weight instances to ck builder reflection [CK_BUILDER] Add reflection for wmma and bwd weight instances to ck builder reflection Jan 16, 2026
@shumway
Collaborator

shumway commented Jan 16, 2026

It's a lot easier to get comments on a draft if you can explain the what and why in your draft PR description. 😜 I think I can see what you are doing here, but generally introduce what the PR is doing and what you want feedback on.

builder::ElementwiseOperation output_element_op;

- builder::GemmPadding gemm_padding;
+ std::optional<builder::GemmPadding> gemm_padding = std::nullopt;
Collaborator


This is really good, and you should lead your PR description with this change to ConvTraits (the "what" of the PR), as well as why we are making these optional now (the "why"). One question I have is where we should use std::optional versus using std::variant.

That's the design discussion we should focus on: how should ConvTraits be generalized for backward weights. This PR should update code comments and our relect/README.md file so that everyone understands this important generalization.

Contributor Author


On the std::optional vs std::variant part, I would use variant if there is an obvious either-or, like with the loop sched / blockGemmSched. For these fields, std::optional seems to be the obvious choice.


#pragma once

// Fwd instances
Collaborator


I'm trying to decide if we need this instance_to_conv_traits.hpp file. I left it in as part of the refactoring, but it may be that we can always use the specific includes. I worry some about files like this leading to longer compile times. Let me know your thoughts on this, and I'll consider this for my next refactoring and cleanup.

Contributor Author


You mean splitting fwd and bwd?

Collaborator


I just meant that I don't know if we even need a file that includes all the functions to convert conv traits, or if we always have enough context to only include the specific function we need.

@shumway
Collaborator

shumway commented Jan 16, 2026

What I've read so far looks really good. My two suggestions for the PR description:

  1. Lead with a description of adding std::optional fields to ConvTraits to generalize to backward weights. Should we consider std::variant for some of this? Conceptually, how are ConvTraits different for forward and backward convolutions? Does it make sense to have one ConvTraits for forward and backward, or would FwdConvTraits and BwdWeiConvTraits be better? (I think one ConvTraits is better, but you should document that design choice.)
  2. Summarize the changes to the conv traits helpers. We will likely do more cleanup and refactoring there, and capturing some of your ideas in the PR description will help us with future work.

@kabrahamAMD kabrahamAMD force-pushed the kabraham/builder_bwd_reflection branch from f62204c to ee0494b Compare January 21, 2026 11:42
@kabrahamAMD kabrahamAMD marked this pull request as ready for review January 21, 2026 12:28
@kabrahamAMD kabrahamAMD requested review from a team and ddembeckAMD as code owners January 21, 2026 12:28
// Tensor Layouts
// ----------------------------------------------------------------------------

// Helper variable template to check if CK layout enums match
Collaborator


This is a nice use of a templated constexpr variable.

- return std::array<builder::TensorLayout, 3>{in, weight, out};
-};
+// Helper lambda to construct layout array
+auto layouts = [](auto... Ls) { return std::array<builder::TensorLayout, 3>{Ls...}; };
Collaborator


This is great!



4 participants