
Conversation

@chris-tsiaousis-hpc commented Jan 20, 2026

Proposed changes

Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

The actual code duplication, measured in lines of code, is not significant, but that is mainly because clang-format splits template parameters across multiple lines. The real value of this PR is that the implementation code is shared and the details are neatly hidden from the device structs.
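For illustration, here is a minimal sketch of the sharing pattern described above. All names (common::Argument, common::Invoker, DeviceBatchedGemmGemm) are hypothetical stand-ins, not the actual CK types: the common header owns the argument and invoker machinery, and each device struct only wires it together.

```cpp
#include <cstddef>

namespace common {

// Shared argument struct; NumD0Tensor/NumD1Tensor are 0 for the non-MultiD variant.
template <std::size_t NumD0Tensor, std::size_t NumD1Tensor>
struct Argument
{
    // pointers, lengths, and strides shared by both device implementations
};

// Shared invoker; the kernel-launch logic lives here, out of the device structs.
template <typename Arg>
struct Invoker
{
    float Run(const Arg& /*arg*/) { return 0.f; } // placeholder for the launch
};

} // namespace common

// A device struct now only re-exports the shared pieces.
struct DeviceBatchedGemmGemm
{
    using Argument = common::Argument<0, 0>; // non-MultiD: no D0/D1 tensors
    using Invoker  = common::Invoker<Argument>;
};

int main()
{
    DeviceBatchedGemmGemm::Argument arg{};
    DeviceBatchedGemmGemm::Invoker inv{};
    (void)inv.Run(arg);
}
```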

@illsilin
Collaborator

Hey guys, please create branches in the ROCm repo directly, otherwise the CI builds won't run automatically.
I'll kick off the CI build manually for now.

@afagaj requested a review from Copilot on January 20, 2026 22:10
Contributor

Copilot AI left a comment


Pull request overview

This PR refactors batched GEMM-GEMM implementations to eliminate code duplication by extracting common functionality into a shared header file (device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp). The refactoring consolidates kernel launch logic, argument structures, and validation code that was previously duplicated across multiple device implementations.

Changes:

  • Created a new common header file containing shared kernel functions, argument structures, and validation logic
  • Refactored device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp to use the common implementation
  • Refactored device_batched_gemm_gemm_wmma_cshuffle_v3.hpp to use the common implementation

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File descriptions:

  • device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp: new file containing the extracted common kernel, argument structures, invoker, and validation logic, with support for both the MultiD and non-MultiD variants
  • device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp: refactored to remove duplicated code by using the common base classes and replacing RawArg with the Argument from the common implementation
  • device_batched_gemm_gemm_wmma_cshuffle_v3.hpp: refactored in the same way, with empty arrays for the D0/D1 tensors
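As a rough illustration of the "empty arrays for D0/D1 tensors" point for the last file above, assuming hypothetical names (SharedArgument, p_d0s): when the non-MultiD device instantiates the shared argument struct with zero D tensors, the member arrays are simply empty and nothing extra needs to be filled in.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

using index_t = std::int32_t;

// Hypothetical stand-in for the shared argument struct from the common header.
template <std::size_t NumD0Tensor>
struct SharedArgument
{
    std::array<const void*, NumD0Tensor> p_d0s;
    std::array<std::array<index_t, 3>, NumD0Tensor> d0_g_m_o_lengths;
};

int main()
{
    // Non-MultiD variant: NumD0Tensor == 0, so both arrays are zero-sized.
    SharedArgument<0> arg{};
    (void)arg;
}
```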


Comment on lines 195 to 226
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_lengths_vec,
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_strides_vec)

Copilot AI Jan 20, 2026


The parameter names suggest these are D1 tensors (d1_g_m_o_lengths_vec, d1_g_m_o_strides_vec), but the template parameter specifies NumD0Tensor. This should be NumD1Tensor to match the actual tensor type being described.

Suggested change:

- const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_lengths_vec,
- const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_strides_vec)
+ const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
+ const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)

Comment on lines +273 to +306
Number<I> d0_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);

Copilot AI Jan 20, 2026


The parameter is named 'd0_idx' but should be 'd1_idx' in the GetD0BasePtr method to correctly reflect that it's indexing into D0 tensors, not D1.

Suggested change:

- Number<I> d0_idx) const
+ Number<I> d1_idx) const
  {
- return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
+ return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);

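For context, here is a standalone sketch of the offset-computation pattern in the snippet quoted above. Number<I> is replaced with std::integral_constant, and the struct name is a hypothetical stand-in for the CK type: the compile-time index selects which batch stride to scale by the batch index.

```cpp
#include <array>
#include <cstdint>
#include <type_traits>

using index_t      = std::int32_t;
using long_index_t = std::int64_t;

// Hypothetical stand-in for the pointer-offset computation struct.
template <std::size_t N>
struct ComputeBasePtrOfStridedBatch
{
    std::array<index_t, N> BatchStrideD0s_;

    // Returns the element offset of batch g_idx inside the I-th D0 tensor.
    template <std::size_t I>
    constexpr long_index_t GetD0BasePtr(index_t g_idx,
                                        std::integral_constant<std::size_t, I>) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideD0s_[I]);
    }
};

int main()
{
    constexpr ComputeBasePtrOfStridedBatch<2> calc{{128, 256}};
    // Batch 3 of the second D0 tensor starts 3 * 256 = 768 elements in.
    static_assert(calc.GetD0BasePtr(3, std::integral_constant<std::size_t, 1>{}) == 768);
}
```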
Comment on lines 548 to 615
constexpr TailNumber tn = decltype(tail_number)::value;

const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
GridwiseGemm,
has_loop,
tn,

Copilot AI Jan 20, 2026


The variable name 'tn' is ambiguous. It should be renamed to 'tail_num' or 'tail_number_value' for clarity.

Suggested change:

- constexpr TailNumber tn = decltype(tail_number)::value;
+ constexpr TailNumber tail_num = decltype(tail_number)::value;
  const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
                                                                GridwiseGemm,
                                                                has_loop,
-                                                               tn,
+                                                               tail_num,

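The snippet under review follows a common runtime-to-compile-time dispatch pattern: a runtime tail number is mapped onto a template parameter before the kernel is instantiated. A minimal sketch with hypothetical stand-in names (TailNumber values, kernel_stub, dispatch_tail), not the actual CK definitions:

```cpp
#include <cstdio>

enum class TailNumber { Full, Odd }; // hypothetical subset of the real enum

// Stands in for the kernel template instantiation in the snippet above.
template <bool HasLoop, TailNumber Tn>
void kernel_stub()
{
    std::printf("has_loop=%d tail=%d\n", int(HasLoop), int(Tn));
}

// Each branch fixes the runtime value into a compile-time template argument,
// so every (HasLoop, TailNumber) combination gets its own instantiation.
template <bool HasLoop>
void dispatch_tail(TailNumber tn_runtime)
{
    if(tn_runtime == TailNumber::Full)
        kernel_stub<HasLoop, TailNumber::Full>();
    else
        kernel_stub<HasLoop, TailNumber::Odd>();
}

int main()
{
    dispatch_tail<true>(TailNumber::Odd);
}
```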
@ErwinTerpstra deleted the branch ROCm:develop on January 21, 2026 09:03
@ErwinTerpstra reopened this on Jan 21, 2026
@ErwinTerpstra changed the base branch from eterpstr/97-implement-device_batched_gemm_add_relu_gemm_add-for-rdna4 to develop on January 21, 2026 09:36
@ErwinTerpstra requested a review from a team as a code owner on January 21, 2026 09:36
@chris-tsiaousis-hpc force-pushed the 219-gemm-gemm-multi-d-duplication-removal branch from 9c0098b to d85089a on January 21, 2026 09:40
@chris-tsiaousis-hpc changed the title from "219 gemm gemm multi d duplication removal" to "Remove code duplications in batched gemm (multi D) gemm (multi D) wmma" on Jan 21, 2026
@chris-tsiaousis-hpc force-pushed the 219-gemm-gemm-multi-d-duplication-removal branch from d85089a to 18667f4 on January 21, 2026 13:38
…lti_d gemm multi_d wmma implementation

This file includes all shared components. The (shared between the two implementations) kernel, the pointer offset computation struct, the grid descriptor creator and definitions, the invoker struct and the argument struct.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
…ementation

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
…shuffle v3 implementation

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
…ltiple D wmma cshuffle v3 implementations

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
@chris-tsiaousis-hpc force-pushed the 219-gemm-gemm-multi-d-duplication-removal branch from 18667f4 to 0d8c690 on January 22, 2026 09:47