From 24b6bd7f6227a63b74ee4bcb899f37d5ec1c3660 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski Date: Fri, 16 Jan 2026 12:02:48 +0000 Subject: [PATCH 1/4] Added bias_bnorm_clamp for WMMA conv fwd large tensor. Following operations are added for FP16/BF16 data type and NHWGCxGKYXC layout. - grouped_conv2d_fwd_bias_bnorm_clamp - grouped_conv3d_fwd_bias_bnorm_clamp --- ...ltiple_d_wmma_cshuffle_v3_large_tensor.hpp | 81 ++++++++++--------- ...d_convolution_forward_bias_bnorm_clamp.hpp | 12 +++ ...rward_bias_bnorm_clamp_wmma_cshufflev3.inc | 60 ++++++++++++++ .../CMakeLists.txt | 2 + ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 44 ++++++++++ ..._tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 44 ++++++++++ .../CMakeLists.txt | 2 + ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 44 ++++++++++ ...nsor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 44 ++++++++++ 9 files changed, 296 insertions(+), 37 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp index 08d0f296f03..7cadc5e19e6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -617,32 +618,29 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m); const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n); - GemmArgs new_args{}; - new_args.a_ptrs_ = p_as_grid; - new_args.b_ptrs_ = p_bs_grid; - new_args.ds_ptrs_ = p_ds_grid; - new_args.e_ptr_ = p_e_grid; - - new_args.a_element_op_ = a_element_op_; - new_args.b_element_op_ = b_element_op_; - new_args.cde_element_op_ = cde_element_op_; - - new_args.M_ = gemm_m; - new_args.N_ = gemm_n; - - new_args.a_grid_desc_ = a_grid_desc; - new_args.b_grid_desc_ = b_grid_desc; - new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + const auto ds_desc_mblock_mperblock_nblock_nperblock = GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, m_block, n_block); - new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ = + const auto e_desc_mblock_mperblock_nblock_nperblock = GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, m_block, n_block); - new_args.BlockStart_ = BlockStart; - new_args.BlockEnd_ = BlockEnd; - - gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args; + auto* gemm_args = &gemm_desc_kernel_args_.At(valid_gemms_count_); + new(gemm_args) GemmArgs{p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + a_element_op_, + b_element_op_, + cde_element_op_, + gemm_m, + gemm_n, + a_grid_desc, + b_grid_desc, + ds_desc_mblock_mperblock_nblock_nperblock, + e_desc_mblock_mperblock_nblock_nperblock, + BlockStart, + BlockEnd}; valid_gemms_count_++; } @@ -789,11 +787,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - static_for<0, NumDTensor, 1>{}([&](auto i) { - compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; - compute_ptr_offset_of_n_.BatchStrideDs_(i) = - ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; - }); + if constexpr(NumDTensor > 0) + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + }); + } } void Print() const @@ -807,12 +808,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor << ", is_split_valid=" << std::boolalpha << is_split_valid_ << std::noboolalpha << ", grid_size=" << grid_size_ << std::endl; - static_for<0, NumDTensor, 1>{}([&](auto i) { - std::cout << " Ds[" << i.value - << "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i) - << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i) - << std::endl; - }); + if constexpr(NumDTensor > 0) + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << " Ds[" << i.value << "] group stride=" + << compute_ptr_offset_of_groups_.BatchStrideDs_.At(i) + << ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i) + << std::endl; + }); + } std::cout << "===== GEMM splits =====" << std::endl; for(index_t i = 0; i < valid_gemms_count_; ++i) @@ -836,11 +840,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: " << gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl; - static_for<0, NumDTensor, 1>{}([&](auto d_idx) { - std::cout << " D" << d_idx.value << " descriptor: " - << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx) - << std::endl; - }); + if constexpr(NumDTensor > 0) + { + static_for<0, NumDTensor, 1>{}([&](auto d_idx) { + std::cout << " D" << d_idx.value << " descriptor: " + << gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx) + << std::endl; + }); + } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp index 295b2c21b58..e42a3f2045b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_bnorm_clamp.hpp @@ -297,6 +297,9 @@ struct DeviceOperationInstanceFactory>>& instances); +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -56,6 +86,21 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhw PassThrough, BiasNormalizeInInferClamp>>>& instances); +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); + void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index d089663f37d..1f381f5f7d3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -328,6 +328,8 @@ generate_sharded_instantiations( add_instance_library(device_grouped_conv2d_fwd_bias_bnorm_clamp_instance wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp ${GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP} ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..6bd58617aaf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector< + std::unique_ptr, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 00000000000..5eebe7f3862 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector< + std::unique_ptr, + NHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index dc759cbb549..f54588991f5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -309,6 +309,8 @@ generate_sharded_instantiations( add_instance_library(device_grouped_conv3d_fwd_bias_bnorm_clamp_instance wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp ${GROUPED_CONV3D_FWD_BIAS_BNORM_CLAMP} ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..6d7ede939a2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_bf16_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 00000000000..0a6dcf2e754 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BiasNormalizeInInferClamp = ck::tensor_operation::element_wise::BiasNormalizeInInferClamp; + +void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_large_tensor_f16_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 2052f4baf4f84e344e69294197a282dc04f19206 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski Date: Tue, 20 Jan 2026 09:44:49 +0000 Subject: [PATCH 2/4] changed strategy to handle GemmArgs array --- ...ltiple_d_wmma_cshuffle_v3_large_tensor.hpp | 36 ++++++++++--------- include/ck/utility/array.hpp | 11 ++++++ 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp index 7cadc5e19e6..ed0ead42d19 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle_v3_large_tensor.hpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include @@ -625,22 +624,25 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, m_block, n_block); - auto* gemm_args = &gemm_desc_kernel_args_.At(valid_gemms_count_); - new(gemm_args) GemmArgs{p_as_grid, - p_bs_grid, - p_ds_grid, - p_e_grid, - a_element_op_, - b_element_op_, - cde_element_op_, - gemm_m, - gemm_n, - a_grid_desc, - b_grid_desc, - ds_desc_mblock_mperblock_nblock_nperblock, - e_desc_mblock_mperblock_nblock_nperblock, - BlockStart, - BlockEnd}; + gemm_desc_kernel_args_.Emplace( + valid_gemms_count_, + GemmArgs{.a_ptrs_ = p_as_grid, + .b_ptrs_ = p_bs_grid, + .ds_ptrs_ = p_ds_grid, + .e_ptr_ = p_e_grid, + .a_element_op_ = a_element_op_, + .b_element_op_ = b_element_op_, + .cde_element_op_ = cde_element_op_, + .M_ = gemm_m, + .N_ = gemm_n, + .a_grid_desc_ = a_grid_desc, + .b_grid_desc_ = b_grid_desc, + .ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + ds_desc_mblock_mperblock_nblock_nperblock, + .e_grid_desc_mblock_mperblock_nblock_nperblock_ = + e_desc_mblock_mperblock_nblock_nperblock, + .BlockStart_ = BlockStart, + .BlockEnd_ = BlockEnd}); valid_gemms_count_++; } diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 2b249884b66..73eb18fe166 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -6,6 +6,8 @@ #include "functional2.hpp" #include "sequence.hpp" +#include +#include namespace ck { @@ -27,6 +29,15 @@ struct Array __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); } + template + __host__ constexpr auto Emplace(index_t i, Args&&... args) + -> std::enable_if_t> + { + assert(i >= 0 && i < NSize); + mData[i].~TData(); + new(mData + i) TData(ck::forward(args)...); + } + template __host__ __device__ constexpr auto operator=(const T& a) { From 1aef63f644c413d5978eb5ed8cf9ff928399cfc1 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski Date: Tue, 20 Jan 2026 21:06:36 +0000 Subject: [PATCH 3/4] Adding generic instance --- ..._wmma_cshufflev3_large_tensor_instance.hpp | 34 +++++++++++++++++++ ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 2 +- ..._tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 2 +- ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 2 +- ...nsor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 2 +- 5 files changed, 38 insertions(+), 4 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp index c3769fbfd06..07c3e82194e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp @@ -35,6 +35,23 @@ static constexpr auto ConvFwdDefault = static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +template +using device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances = std::tuple< + // clang-format off + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template ; +template +using device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances = std::tuple< + // clang-format off + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template >>& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_large_tensor_bf16_instances< + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 5eebe7f3862..7be4be2f1e0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -27,7 +27,7 @@ void add_device_grouped_conv2d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nh BiasNormalizeInInferClamp>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_large_tensor_f16_instances< + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 6d7ede939a2..4a9c68b2d39 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nd BiasNormalizeInInferClamp>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_large_tensor_bf16_instances< + device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances< 3, NDHWGC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index 0a6dcf2e754..92c86b8df0a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/wmma/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_bias_bn_clamp_wmma_cshufflev3_large_tensor_nd BiasNormalizeInInferClamp>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_large_tensor_f16_instances< + device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances< 3, NDHWGC, GKZYXC, From 41ca771d9443743f2c42d99f787a123a379f0da9 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski Date: Wed, 21 Jan 2026 12:04:47 +0000 Subject: [PATCH 4/4] fixed last nits from reviewers and copilot --- ..._wmma_cshufflev3_large_tensor_instance.hpp | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp index 07c3e82194e..199d9d91ec6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp @@ -45,9 +45,10 @@ template using device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances = std::tuple< // clang-format off - //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -62,9 +63,10 @@ template using device_grouped_conv_fwd_wmma_large_tensor_f16_instances = std::tuple< // clang-format off - //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, @@ -83,9 +85,10 @@ template using device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances = std::tuple< // clang-format off - //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -100,9 +103,10 @@ template using device_grouped_conv_fwd_wmma_large_tensor_bf16_instances = std::tuple< // clang-format off - //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,