Remove code duplications in batched gemm (multi D) gemm (multi D) wmma #3617
Conversation
Hey guys, please create branches in the ROCm repo directly, otherwise the CI builds won't run automatically.
Pull request overview
This PR refactors batched GEMM-GEMM implementations to eliminate code duplication by extracting common functionality into a shared header file (device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp). The refactoring consolidates kernel launch logic, argument structures, and validation code that was previously duplicated across multiple device implementations.
Changes:
- Created a new common header file containing shared kernel functions, argument structures, and validation logic
- Refactored `device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp` to use the common implementation
- Refactored `device_batched_gemm_gemm_wmma_cshuffle_v3.hpp` to use the common implementation
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp` | New file containing extracted common kernel, argument structures, invoker, and validation logic with support for both MultiD and non-MultiD variants |
| `device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp` | Refactored to remove duplicated code by using common base classes and replacing `RawArg` with `Argument` from the common implementation |
| `device_batched_gemm_gemm_wmma_cshuffle_v3.hpp` | Refactored to remove duplicated code by using common base classes and replacing `RawArg` with `Argument`, with empty arrays for D0/D1 tensors |
Copilot (AI) commented on Jan 20, 2026:

```cpp
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_lengths_vec,
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_strides_vec)
```

The parameter names suggest these are D1 tensors (`d1_g_m_o_lengths_vec`, `d1_g_m_o_strides_vec`), but the template parameter specifies `NumD0Tensor`. This should be `NumD1Tensor` to match the actual tensor type being described.

Suggested change:

```cpp
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
```
Copilot (AI) commented on Jan 20, 2026:

```cpp
                Number<I> d0_idx) const
{
    return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
```

The parameter is named `d0_idx` but should be `d1_idx` in the `GetD0BasePtr` method to correctly reflect that it's indexing into D0 tensors, not D1.

Suggested change:

```cpp
                Number<I> d1_idx) const
{
    return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
```
Copilot (AI) commented on Jan 20, 2026:

```cpp
constexpr TailNumber tn = decltype(tail_number)::value;

const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
                                                              GridwiseGemm,
                                                              has_loop,
                                                              tn,
```

The variable name `tn` is ambiguous. It should be renamed to `tail_num` or `tail_number_value` for clarity.

Suggested change:

```cpp
constexpr TailNumber tail_num = decltype(tail_number)::value;

const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
                                                              GridwiseGemm,
                                                              has_loop,
                                                              tail_num,
```
Commits:
- …lti_d gemm multi_d wmma implementation
  This file includes all components shared between the two implementations: the kernel, the pointer offset computation struct, the grid descriptor creator and definitions, the invoker struct, and the argument struct.
  Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
- …ementation
  Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
- …shuffle v3 implementation
  Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
- …ltiple D wmma cshuffle v3 implementations
  Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- Ran `clang-format` on all changed files

Discussion
The duplication removed is not large in raw LoC, but that is mainly because clang-format splits template parameter lists across many lines. The actual value of this PR is that the implementation code is shared and the details are nicely hidden from the device structs.