Cleanup and refactoring related to tile loading#3505
Cleanup and refactoring related to tile loading#3505SamiAario-AMD wants to merge 37 commits intodevelopfrom
Conversation
47bbab6 to
0308965
Compare
| else if constexpr(LoadTranspose) | ||
| { | ||
| dst = load_tile_transpose(src); | ||
| load_tile_transpose(dst, src); |
There was a problem hiding this comment.
Copilot seems to agree with me that assignment cannot be optimized by the compiler in such a way that the creation of temporaries is avoided, when assigning to a complex type. It then ought to make sense to avoid assignment when possible, and pass the object to be assigned to as a reference instead.
There was a problem hiding this comment.
@SamiAario-AMD Have you checked the assembly for this change? Does the register usage is different?
There was a problem hiding this comment.
I checked the assembly, and unfortunately there is no difference between using dst = load_tile_transpose(src) and load_tile_transpose(dst, src): the generated assembly is identical for my toy example. Adding the latter does have the benefit of making the API more similar to that of load_tile however. This PR does not remove the existing load_tile_transpose but just adds the new overload.
There was a problem hiding this comment.
Pull request overview
This PR performs cleanup and refactoring in preparation for mixed precision fp16/bf16 x fp8 implementation. The changes focus on standardizing the tile loading infrastructure and improving code organization.
Key changes:
- Renamed
load_interleaved_pk_type.hpptoload_and_convert_tile.hppand refactored the API to use consistent naming conventions - Updated
load_tile_transposefunctions to use output parameters instead of return values for consistency - Removed unused variable declarations and simplified type deduction logic
Reviewed changes
Copilot reviewed 41 out of 41 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| include/ck_tile/ops/topk_softmax.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/topk.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/softmax.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/smoothquant.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/rmsnorm2d.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/reduce.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/pooling.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/permute.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/norm_reduce.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/layernorm2d.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/image_to_column.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/grouped_convolution.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | Replaced load_int4_tile calls with simplified load_and_convert_tile API |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp | Simplified tile loading by removing explicit type parameters |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp | Simplified tile loading by removing explicit type parameters |
| include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp | Simplified tile loading by removing explicit type parameters |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp | Updated to use load_and_convert_tile API |
| include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp | Updated to use load_and_convert_tile API |
| include/ck_tile/ops/gemm_quant.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp | Replaced load_int4_tile calls with load_and_convert_tile maintaining explicit template parameters |
| include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp | Moved KPack variable declaration to narrower scope where it's actually used |
| include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp | Updated to use output parameter for load_tile_transpose and simplified type deduction |
| include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp | Updated to use load_and_convert_tile API and changed tile distribution to use decltype |
| include/ck_tile/ops/gemm.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/fused_moe.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp | Changed load_tile_transpose to use output parameter instead of assignment |
| include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp | Changed load_tile_transpose to use output parameter instead of assignment |
| include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | Changed all load_tile_transpose calls to use output parameter instead of assignment |
| include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | Changed all load_tile_transpose calls to use output parameter instead of assignment |
| include/ck_tile/ops/fmha.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/flatmm.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/epilogue.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/elementwise.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/common/load_and_convert_tile.hpp | Refactored and renamed from load_interleaved_pk_type.hpp with simplified API |
| include/ck_tile/ops/common.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/batched_transpose.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/batched_contraction.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | Updated include from load_interleaved_pk_type.hpp to load_and_convert_tile.hpp |
| include/ck_tile/core/tensor/tile_window.hpp | Refactored load function signature to accept tuple of tile windows |
| include/ck_tile/core/tensor/load_tile_transpose.hpp | Changed functions to use output parameters and updated documentation |
| include/ck_tile/core/tensor/load_tile.hpp | Changed return types from auto to void for functions that populate output parameters |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp
Outdated
Show resolved
Hide resolved
c811e2a to
211b784
Compare
| static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); | ||
| constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; | ||
| const auto in_dstr_tensors = load_tile(warp_window); | ||
| const auto tmp = load_tile(src); |
There was a problem hiding this comment.
tmp is not the best name. What about naming it src and the window src_window ?
There was a problem hiding this comment.
I will do this in a separate PR as discussed.
| using BLdsDataType = | ||
| std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>, | ||
| typename Problem::ADataType, | ||
| typename Problem::BDataType>; | ||
|
|
||
| auto b_lds_load_tile_distr = []() { | ||
| if constexpr(is_b_load_tr) | ||
| { | ||
| return make_static_tile_distribution( | ||
| typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode, | ||
| BLdsDataType>::TransposedDstrEncode{}); | ||
|
|
||
| typename InputTileDistributionTraits< | ||
| typename BLdsLoadTileDistr::DstrEncode, | ||
| typename BLdsTensorView::DataType>::TransposedDstrEncode{}); | ||
| } |
There was a problem hiding this comment.
You can't just use BLdsTensorView::DataType here, because of above conditional type selection. See here:
Problem::BDataType then you get different logic here.
There was a problem hiding this comment.
OverrideBDataType in GetABLdsTensorViews should be the same as BLdsDataType in MakeBLdsWindows. The idea is that once we have determined OverrideBDataType and used it to define the B LDS tensor view, then we can reuse the B LDS tensor view's data type later in MakeBLdsWindows instead of re-determining it with the conditional type selection that I removed.
There was a problem hiding this comment.
Ok - in general I like your idea to use data type from tensor views. So if you'd make sure that in all places this change is properly handled it's fine.
I think this should be mentioned in changelog, along to other load tile API changes.
There was a problem hiding this comment.
Two changelog entries added.
87bb39b to
fc1b683
Compare
|
Imported to ROCm/rocm-libraries |
Proposed changes
Cleanup and refactoring done while implementing mixed precision for fp16/bf16 x fp8
Key changes:
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered