Skip to content

Conversation

@zsotakal
Copy link

@zsotakal zsotakal commented Jan 20, 2026

Proposed changes

Add support for grouped gemm multi ABD fixed NK. MR contains:

  • Device struct for grouped gemm with multiple ABD and fixed NK (DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK).
  • Wmma versions of existing example codes: 59_grouped_gemm_multi_ABD
  • Unit tests for both new wmma implementation and the reference xdl code (previously missing)
    Note: Some Xdl instances were commented out because of unit test failures. As mentioned apparently for xdl this feature was missing tests so our assumption is either there is an implemenetation bug or these instances were not set up correctly. Has the potential for a follow-up issue.
  • Generic ck profiler interface with the purpose of calling unit tests.
  • Gemm instances with specific elementwise operations for gemm bias gelu calculations.
  • Added class for grouped gemm multi ABD reference calculations.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

@afagaj afagaj requested a review from Copilot January 20, 2026 22:09
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements support for grouped GEMM with multiple ABD tensors and fixed NK on RDNA4 architecture, specifically for WMMA implementations. The feature was previously only available for XDL implementations.

Changes:

  • Added WMMA device operator implementations for grouped GEMM multi ABD with fixed NK
  • Unit tests for both new WMMA and existing XDL implementations
  • Reference implementation class for verification
  • Example code demonstrating WMMA usage patterns

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp Unit test framework for validating grouped GEMM multi ABD fixed NK implementations
test/grouped_gemm/CMakeLists.txt Build configuration for new unit test
profiler/include/profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp Generic profiler interface for calling unit tests and benchmarking
profiler/include/profiler/profile_gemm_multi_abd_impl.hpp Refactored to use new reference implementation
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp Commented out failing XDL instances
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp Commented out failing XDL instances
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp WMMA instances for MK-NK-MN layout with bias/gelu operations
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp WMMA instances for MK-KN-MN layout with bias/gelu operations
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp WMMA instances for KM-KN-MN layout with bias/gelu operations
library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt Build configuration for new WMMA instances
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp Factory functions for WMMA instances
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp Reference implementation for grouped GEMM multi ABD verification
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Added EDataType_ alias for type access
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp Added hardware support checks and main K block loop validation
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp New WMMA device operator implementation
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp Example using WMMA with FP16 and bias addition
example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp Example using WMMA with BF16/I8 and bias+GELU
example/59_grouped_gemm_multi_ABD/CMakeLists.txt Build configuration for WMMA examples

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {}
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Empty conditional block serves no purpose. Remove this branch or add a comment explaining why argc == 1 is explicitly handled (e.g., "use default parameters").

Suggested change
if(argc == 1) {}
if(argc == 1)
{
// use default parameters when no extra arguments are provided
}

Copilot uses AI. Check for mistakes.
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference GEMM arguments use variables a_m_k, b_k_n, and c_m_n that are no longer defined in this scope after the refactoring. These variables were computed within the removed reference computation code and need to be generated by the new ReferenceGemmMultiABD class.

Copilot uses AI. Check for mistakes.
Comment on lines +20 to +21
// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add documentation explaining why this duplicate function definition exists and how it differs from the CK version. The comment on line 20-21 is insufficient - it should explain the const-correctness difference and the implications for usage in profile_gemm_multi_impl.

Suggested change
// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const
// NOTE:
// This helper intentionally duplicates `concat_tuple_of_refs` from the core CK utilities,
// but with a different const-correctness contract on its arguments:
//
// - The CK version is defined to operate on (typically) const-qualified tuples of
// references; its parameters are more permissive and can accept `const Tuple<...>&`.
// - This host-side overload is deliberately restricted to *non-const* tuples of
// references: `ck::Tuple<X&...>&` and `ck::Tuple<Y&...>&`.
//
// In `profile_gemm_multi_impl`, we need to concatenate tuples that contain non-const
// references to tensors/buffers so that:
// * The resulting concatenated tuple preserves non-const reference semantics, allowing
// the profiled kernels and host-side utilities to modify the referenced objects, and
// * Overload resolution / SFINAE continues to select APIs that require non-const
// references (these would reject a const-qualified tuple produced by the CK version).
//
// If this function were replaced by the CK version, the arguments in
// `profile_gemm_multi_impl` could become (or be treated as) const, which would either:
// - Prevent intended mutation of the underlying tensors, or
// - Cause subtle compilation or behavior differences due to const propagation.
//
// For that reason, this duplicate, non-const overload must remain local to the host-side
// GEMM multi reference implementation and should not be "simplified" by switching to the
// CK variant without carefully revisiting `profile_gemm_multi_impl` and its call sites.

Copilot uses AI. Check for mistakes.
{
if(arg.grouped_gemm_kernel_args_dev == nullptr)
{
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'nullpr' to 'nullptr'.

Suggested change
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");

Copilot uses AI. Check for mistakes.
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'n0' to 'no'.

Suggested change
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");

Copilot uses AI. Check for mistakes.
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected spelling of 'n0' to 'no'.

Suggested change
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants