diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp index 45ec3a2065f..6b1144047f8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -3,77 +3,21 @@ #pragma once -#include #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" #include "ck/utility/tuple.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - Tuple<>{}, // p_d0s_grid - arg.p_b1_grid + b1_batch_offset, - Tuple<>{}, // p_d1s_grid - arg.p_c_grid + c_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.b1_grid_desc, - Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.c_element_op, - arg.block_2_ctile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes C = A * B0 * B1 // MN = MK * KL * LN // ^^^^^^ (Acc0) @@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - 
MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC) - : BatchStrideA_(BatchStrideA), - BatchStrideB0_(BatchStrideB0), - BatchStrideB1_(BatchStrideB1), - BatchStrideC_(BatchStrideC) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB0_; - index_t BatchStrideB1_; - index_t BatchStrideC_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType + Tuple<>, // D1sDataType + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, // Ds0GridDesc - B1GridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, Tuple<>, // Ds1GridDesc - CGridDesc_M_N, + typename DeviceGemmGemmCommonBase::CGridDesc_M_N, // Tiling Family MPerBlock, LPerBlock, @@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm; - struct RawArg : public BaseArgument - { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - const 
B1DataType* p_b1_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - index_t StrideB1, - index_t StrideC, - index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CElementwiseOperation c_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_b1_grid{p_b1_grid_}, - p_c_grid{p_c_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - c_element_op{c_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - c_g_m_o_lengths = arr3{batch_count, M, O}; - c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides); - c_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); - block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1); - } - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - const B1DataType* p_b1_grid; - CDataType* p_c_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - arr3 c_g_m_o_lengths; - arr3 c_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CElementwiseOperation c_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - B1GridDesc b1_grid_desc; - CGridDesc_M_N c_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map; + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + Tuple<>, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType, + Tuple<>, // D1sDataType, + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + 
CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD + // Invoker + using Invoker = typename DeviceGemmGemmCommon::Invoker; - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; + // Argument + using Argument = typename DeviceGemmGemmCommon::Argument; - static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) + static bool IsSupportedArgument(const Argument& arg) { - // Print lambda with env check and printf() style formmating. - const char* curFunc = __func__; - auto print = [&curFunc](const char* format, ...) -> void { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#endif - va_list args; - va_start(args, format); - std::vfprintf(stdout, format, args); - va_end(args); -#if defined(__clang__) -#pragma clang diagnostic pop -#endif - std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; - } - }; - - if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - print("DeviceOp: Arch err\n"); - return false; - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - print("DeviceOp: gfx 11 does not support fp8\n"); - return false; - } - } - - if constexpr(!(is_same_v || is_same_v)) - { - print("DeviceOp: Acc0 Type err\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: A layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: B layout must be Column\n"); - return false; - } - - if constexpr(!(is_same_v || - is_same_v)) - { - print("DeviceOp: B1 layout must be Column or Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: C layout must be Row\n"); - return false; - } - - // Other padding modes have not been tested and do not get checked individually. - if constexpr(GemmSpec != GemmSpecialization::Default && - GemmSpec != GemmSpecialization::MNKOPadding) - { - print("Padding mode must be default or MNKO\n"); - return false; - } - - // Per wmma dimensions not equal to 16 are very untested. - if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) - { - print("M, L, N per Wmma must be 16\n"); - return false; - } - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, - arg.b1_grid_desc, - Tuple<>{}, - arg.c_grid_desc_m_n, - arg.block_2_ctile_map)) - { - return false; - } - - // Check scalar per vector requirement - const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; - const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; - const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? 
arg.N : arg.O; - const auto c_extent_lowest = arg.O; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - print("DeviceOp: Data Transfer Vector scalar err\n"); - return false; - } - - // Check vector load/store requirement - const auto a_stride_lowest = - ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; - const auto b0_stride_lowest = - B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; - const auto b1_stride_lowest = - B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; - const auto c_stride_lowest = arg.c_g_m_o_strides[2]; - - if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || - c_stride_lowest == 1)) - { - print("DeviceOp: Data Vectorize transfer err\n"); - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) - { - return false; - } - - return true; + return DeviceGemmGemmCommon::IsSupportedArgument(arg); } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::RawArg; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); - const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); - - const index_t grid_size = arg.batch_count * M0 * N0; - - auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { - constexpr bool has_loop = decltype(has_main_k_block_loop)::value; - constexpr TailNumber tn = tail_number; - - const auto kernel = - kernel_batched_gemm_gemm_wmma_cshuffle_v3; - - return launch_and_time_kernel( - stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); - }; - - bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); - TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); - - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); - return 0.0f; - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); - return 0.0f; - } - } - else - { - printf("Invalid pipeline version!\n"); - return 0.0f; - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& 
stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - // polymorphic std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b0, @@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm(static_cast(p_a), - static_cast(p_b0), - static_cast(p_b1), - static_cast(p_c), - M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideB1, - StrideC, - BatchStrideA, - BatchStrideB0, - BatchStrideB1, - BatchStrideC, - a_element_op, - b0_element_op, - acc_element_op, - b1_element_op, - c_element_op); + + std::array p_d0_grid{}; + std::array p_d1_grid{}; + std::array StrideD0s{}, BatchStrideD0s{}; + std::array StrideD1s, BatchStrideD1s{}; + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0_grid, + static_cast(p_b1), + p_d1_grid, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideC, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); } static auto MakeInvoker() { return Invoker{}; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 00000000000..a739af898f7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,902 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/integral_constant.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_e1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx))); + + auto [p_d0s_grid, p_d1s_grid] = [&]() { + if constexpr(IsMultiD) + { + auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) { + 
static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) { + const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In)); + grid_pointer(In) = arg_grid(In) + batch_offset; + }); + return std::move(grid_pointer); + }; + auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx); + }; + auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx); + }; + auto d0s_grid = create_grid(ck::integral_constant{}, + get_d0_base_ptr, + arg.p_d0s_grid, + GridwiseOp::MakeD0sGridPointer()); + auto d1s_grid = create_grid(ck::integral_constant{}, + get_d1_base_ptr, + arg.p_d1s_grid, + GridwiseOp::MakeD1sGridPointer()); + return std::make_pair(d0s_grid, d1s_grid); + } + else + { + return std::make_pair(Tuple<>{}, Tuple<>{}); + } + }(); + + GridwiseOp::template Run( + arg.p_a_grid + a_batch_offset, + arg.p_b0_grid + b0_batch_offset, + p_d0s_grid, + arg.p_b1_grid + b1_batch_offset, + p_d1s_grid, + arg.p_c_e1_grid + c_e1_batch_offset, + p_shared, + arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, + arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op, + arg.b0_element_op, + arg.acc_element_op, + arg.b1_element_op, + arg.cde1_element_op, + arg.block_2_etile_map); +#else + ignore = arg; +#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) +} + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common +{ + static constexpr ck::index_t NumD0Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD0Tensor; + } + return 0; + }(); + static constexpr ck::index_t NumD1Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD1Tensor; + } + return 0; + }(); + + struct GridDescriptorCreator + { + // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler + // Transform operator or just not use one at all. 
+ using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence<1, 1, 1, 1, 1>, + Sequence, + GemmSpec, + TensorSpecialization::Default, // ASpec + TensorSpecialization::Default, // B0Spec + TensorSpecialization::Default, // B1Spec + TensorSpecialization::Default>; // CSpec + + __host__ __device__ static auto + MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, + const std::array& a_g_m_k_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, + const std::array& b0_g_l_k_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, + const std::array& b1_g_n_l_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, + const std::array& d0_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); + } + + __host__ __device__ static auto MakeD0sGridDescriptor( + const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, + const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto MakeD1sGridDescriptor( + const std::array, NumD1Tensor>& d1_g_m_o_lengths_vec, + const std::array, NumD1Tensor>& d1_g_m_o_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto + MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, + const std::array& e1_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); + } + }; + + using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {})); + using D0sGridDesc = + remove_cvref_t; + using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {})); + using D1sGridDesc = + remove_cvref_t; + using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {})); + using CGridDesc_M_N = + decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {})); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB0, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB0_(BatchStrideB0), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_E1_(BatchStrideC) + { + } + + ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1) + : BatchStrideA_(BatchStrideA0), + BatchStrideB0_(BatchStrideB0), + BatchStrideD0s_(BatchStrideD0s), + BatchStrideB1_(BatchStrideB1), + BatchStrideD1s_(BatchStrideD1s), + BatchStrideC_E1_(BatchStrideE1) + { + } + + __host__ 
__device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_E1_); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d0_idx) const + { + return g_idx * static_cast(BatchStrideD0s_[d0_idx]); + } + + template + __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx, + Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD1s_[d1_idx]); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB0_; + std::array BatchStrideD0s_; + index_t BatchStrideB1_; + std::array BatchStrideD1s_; + index_t BatchStrideC_E1_; + }; +}; + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg +{ + using GridwiseGemm = typename DeviceOp::GridwiseOp; + using Common = + DeviceGemmGemm_Wmma_CShuffleV3_Common; + + static constexpr auto NumD0Tensor = Common::NumD0Tensor; + static constexpr auto NumD1Tensor = Common::NumD1Tensor; + + struct Argument : public BaseArgument + { + using arr3 = std::array; + + Argument(const ADataType* p_a_grid_, + const B0DataType* p_b0_grid_, + std::array p_d0s_grid_, + const B1DataType* p_b1_grid_, + std::array p_d1s_grid_, + CE1DataType* p_e1_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t O_, + index_t Batch, + index_t StrideA, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + AElementwiseOperation a_element_op_, + B0ElementwiseOperation b0_element_op_, + AccElementwiseOperation acc_element_op_, + B1ElementwiseOperation b1_element_op_, + CDE1ElementwiseOperation cde1_element_op_) + : p_a_grid{p_a_grid_}, + p_b0_grid{p_b0_grid_}, + p_d0s_grid{}, + p_b1_grid{p_b1_grid_}, + p_d1s_grid{}, + p_c_e1_grid{p_e1_grid_}, + M{M_}, + N{N_}, + K{K_}, + O{O_}, + batch_count{Batch}, + a_element_op{a_element_op_}, + b0_element_op{b0_element_op_}, + acc_element_op{acc_element_op_}, + b1_element_op{b1_element_op_}, + cde1_element_op{cde1_element_op_}, + compute_base_ptr_of_batch{BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1} + { + + a_g_m_k_lengths = arr3{batch_count, M, K}; + a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] + + b0_g_n_k_lengths = arr3{batch_count, N, K}; + b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] + + b1_g_o_n_lengths = arr3{batch_count, O, N}; + b1_g_o_n_strides = + is_same_v + ? 
arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] + : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] + + e1_g_m_o_lengths = arr3{batch_count, M, O}; + e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] + + a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths, + a_g_m_k_strides); + b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths, + b0_g_n_k_strides); + b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths, + b1_g_o_n_strides); + c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor( + e1_g_m_o_lengths, e1_g_m_o_strides); + c_e1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_e1_grid_desc_m_n); + block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1); + + if constexpr(IsMultiD) + { + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0DataType = remove_cvref_t>; + + // D0s layout [batch_count, M, N] + d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; + d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; + + // D0 pointer + p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); + }); + // D0 desc + d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor( + d0s_g_m_n_lengths, d0s_g_m_n_strides); + + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + using D1DataType = remove_cvref_t>; + + // D1s layout [batch_count, M, O] + d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; + d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; + + // D1 pointer + p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); + }); + // D1 desc + d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor( + d1s_g_m_o_lengths, d1s_g_m_o_strides); + + d1s_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d1s_grid_desc); + } + } + + // Pointers + const ADataType* p_a_grid; + const B0DataType* p_b0_grid; + typename GridwiseGemm::D0sGridPointer p_d0s_grid; + const B1DataType* p_b1_grid; + typename GridwiseGemm::D1sGridPointer p_d1s_grid; + CE1DataType* p_c_e1_grid; + + // Raw Problem Size + index_t M; + index_t N; + index_t K; + index_t O; + index_t batch_count; + + arr3 a_g_m_k_lengths; + arr3 a_g_m_k_strides; + arr3 b0_g_n_k_lengths; + arr3 b0_g_n_k_strides; + std::array d0s_g_m_n_lengths; + std::array d0s_g_m_n_strides; + arr3 b1_g_o_n_lengths; + arr3 b1_g_o_n_strides; + std::array d1s_g_m_o_lengths; + std::array d1s_g_m_o_strides; + arr3 e1_g_m_o_lengths; + arr3 e1_g_m_o_strides; + + AElementwiseOperation a_element_op; + B0ElementwiseOperation b0_element_op; + AccElementwiseOperation acc_element_op; + B1ElementwiseOperation b1_element_op; + CDE1ElementwiseOperation cde1_element_op; + + // Grid descriptors and other mem calculators + typename Common::AGridDesc a_grid_desc; + typename Common::B0GridDesc b0_grid_desc; + std::conditional_t> d0s_grid_desc; + typename Common::B1GridDesc b1_grid_desc; + typename Common::D1sGridDesc d1s_grid_desc; + std::conditional_t< + IsMultiD, + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + Tuple<>> + d1s_grid_desc_mblock_mperblock_nblock_nperblock; + + std::conditional_t + c_e1_grid_desc_m_n; + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_e1_grid_desc_mblock_mperblock_nblock_nperblock; + + typename GridwiseGemm::DefaultBlock2ETileMap 
block_2_etile_map;
+
+        typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
+    };
+
+    /// @brief Helper structure responsible for kernel invocation.
+    ///
+    /// @paragraph The `Invoker` class is responsible for the preparation and invocation of the
+    ///            actual GPU kernel function. It usually determines the launch grid size,
+    ///            prepares the kernel arguments, and performs kernel-configuration selection
+    ///            based on runtime arguments.
+    ///
+    struct Invoker : public BaseInvoker
+    {
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
+            const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
+
+            const index_t grid_size = arg.batch_count * M0 * N0;
+
+            auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
+                constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
+                constexpr TailNumber tail_num = decltype(tail_number)::value;
+                const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3;
+                return launch_and_time_kernel(
+                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
+            };
+
+            bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K);
+            TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K);
+
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+            {
+                if(HasMainKBlockLoop && TailNum == TailNumber::Full)
+                {
+                    return launch_kernel(std::integral_constant{},
+                                         std::integral_constant{});
+                }
+                else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
+                {
+                    return launch_kernel(std::integral_constant{},
+                                         std::integral_constant{});
+                }
+                else
+                {
+                    printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
+                    return 0.0f;
+                }
+            }
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+            {
+                if(HasMainKBlockLoop && TailNum == TailNumber::Full)
+                {
+                    return launch_kernel(std::integral_constant{},
+                                         std::integral_constant{});
+                }
+                else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
+                {
+                    return launch_kernel(std::integral_constant{},
+                                         std::integral_constant{});
+                }
+                else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
+                {
+                    return launch_kernel(std::integral_constant{},
+                                         std::integral_constant{});
+                }
+                else
+                {
+                    printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
+                    return 0.0f;
+                }
+            }
+            else
+            {
+                printf("Invalid pipeline version!\n");
+                return 0.0f;
+            }
+        }
+
+        // polymorphic
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast(p_arg), stream_config);
+        }
+    };
+
+    static constexpr bool IsValidCompilationParameter()
+    {
+        // TODO: properly implement this check
+        return true;
+    }
+
+    // check if DsLayout is supported
+    template
+    static constexpr bool CheckDLayout()
+    {
+        bool valid = true;
+        // iterate over DLayout tuple
+        static_for<0, NumDTensor, 1>{}([&](auto i) {
+            using DLayout = remove_cvref_t>;
+            // if RefLayout and DLayout are same, keep valid true, otherwise false
+            valid = valid && is_same_v;
+        });
+        return valid;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        // Print lambda with env check and printf() style formatting.
+        const char* curFunc = __func__;
+        auto print = [&curFunc](const char* format, ...)
-> void { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + va_list args; + va_start(args, format); + std::vfprintf(stdout, format, args); + va_end(args); +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; + } + }; + + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + print("DeviceOp: Arch err\n"); + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + print("DeviceOp: gfx 11 does not support fp8\n"); + return false; + } + } + + if constexpr(!(is_same_v || is_same_v)) + { + print("DeviceOp: Acc0 Type err\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: A layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v || + is_same_v)) + { + print("DeviceOp: B1 layout must be Column or Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: C layout must be Row\n"); + return false; + } + + // Other padding modes have not been tested and do not get checked individually. + if constexpr(GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MNKOPadding) + { + print("Padding mode must be default or MNKO\n"); + return false; + } + + // Per wmma dimensions not equal to 16 are very untested. + if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16) + { + print("M, L, N per Wmma must be 16\n"); + return false; + } + + if constexpr(IsMultiD) + { + if constexpr(!(is_same_v)) + { + print("DeviceOp: B0 layout must be Column\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D0s layout must be Row\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D1s layout must be Row\n"); + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto cde1_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? 
arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; + + // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major + // and the lowest dimension stride is hardcoded to 1 + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + e1_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + else + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + Tuple<>{}, + arg.b1_grid_desc, + Tuple<>{}, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto c_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto c_stride_lowest = arg.e1_g_m_o_strides[2]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) + { + return false; + } + + return true; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp index 06651c0c0ee..83fec9c95f1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -3,91 +3,20 @@ #pragma once -#include #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" 
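A minimal sketch of the batch indexing shared by the kernels in the common header above: each workgroup derives its batch index from the 1-D block id, and ComputeBasePtrOfStridedBatch turns that index into a per-tensor element offset. The standalone program below only models that arithmetic on the host; the concrete sizes are illustrative assumptions, not values from this PR.

#include <cassert>
#include <cstdint>

// Illustrative host-side model of the device-side indexing; names follow the kernel above.
int main()
{
    const std::int64_t batch_count          = 4;
    const std::int64_t num_blocks_per_batch = 6;         // grid_size / batch_count
    const std::int64_t batch_stride_a       = 64 * 1024; // elements between consecutive A batches

    for(std::int64_t block_id = 0; block_id < batch_count * num_blocks_per_batch; ++block_id)
    {
        // g_idx = get_block_1d_id() / num_blocks_per_batch in the kernel
        const std::int64_t g_idx = block_id / num_blocks_per_batch;
        // a_batch_offset = ComputeBasePtrOfStridedBatch::GetABasePtr(g_idx)
        const std::int64_t a_batch_offset = g_idx * batch_stride_a;

        assert(g_idx < batch_count);
        // p_a_grid + a_batch_offset is then used as the batch-local base pointer for A.
        (void)a_batch_offset;
    }
    return 0;
}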
namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t e1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx))); - - auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer(); - auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer(); - - static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset; - }); - - static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) { - const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); - p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset; - }); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - p_d0s_grid, - arg.p_b1_grid + b1_batch_offset, - p_d1s_grid, - arg.p_e1_grid + e1_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - arg.d0s_grid_desc, - arg.b1_grid_desc, - arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e1_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.cde1_element_op, - arg.block_2_etile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes: // Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...) // E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...) @@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); - static constexpr auto I0 = Number<0>{}; - // To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal // to LPerWmma (A.k.a Gemm0 NPerWmma). static constexpr index_t NPerWmma = LPerWmma; - // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler - // Transform operator or just not use one at all. 
- using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< - Sequence<1, 1, 1, 1, 1>, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, - const std::array& d0_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); - } - - __host__ __device__ static auto MakeD0sGridDescriptor( - const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, - const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto MakeD1sGridDescriptor( - const std::array, NumD0Tensor>& d1_g_m_o_lengths_vec, - const std::array, NumD0Tensor>& d1_g_m_o_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto - MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, - const std::array& e1_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using D0sGridDesc = remove_cvref_t; - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using D1sGridDesc = remove_cvref_t; - using E1GridDesc = decltype(MakeE1GridDescriptor({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1) - : BatchStrideA0_(BatchStrideA0), - BatchStrideB0_(BatchStrideB0), - BatchStrideD0s_(BatchStrideD0s), - BatchStrideB1_(BatchStrideB1), - BatchStrideD1s_(BatchStrideD1s), - BatchStrideE1_(BatchStrideE1) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA0_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - template - __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, - Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD0s_[d1_idx]); - } - - __host__ __device__ constexpr long_index_t 
GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE1_); - } - - template - __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD1s_[d1_idx]); - } - - private: - index_t BatchStrideA0_; - index_t BatchStrideB0_; - std::array BatchStrideD0s_; - index_t BatchStrideB1_; - std::array BatchStrideD1s_; - index_t BatchStrideE1_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, // InMemory Data Descriptor - AGridDesc, - B0GridDesc, - D0sGridDesc, - B1GridDesc, - D1sGridDesc, - E1GridDesc, + typename DeviceGemmGemmCommonBase::AGridDesc, + typename DeviceGemmGemmCommonBase::B0GridDesc, + typename DeviceGemmGemmCommonBase::D0sGridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, + typename DeviceGemmGemmCommonBase::D1sGridDesc, + typename DeviceGemmGemmCommonBase::E1GridDesc, // Tiling Family MPerBlock, LPerBlock, @@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, - Transform::matrix_padder.PadN, + DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer>; - struct RawArg : public BaseArgument - { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - std::array p_d0s_grid_, - const B1DataType* p_b1_grid_, - std::array p_d1s_grid_, - E1DataType* p_e1_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - std::array StrideD0s, - index_t StrideB1, - std::array StrideD1s, - index_t StrideE1, - index_t BatchStrideA, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CDE1ElementwiseOperation cde1_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_d0s_grid{}, - p_b1_grid{p_b1_grid_}, - p_d1s_grid{}, - p_e1_grid{p_e1_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - cde1_element_op{cde1_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, - BatchStrideB0, - BatchStrideD0s, - BatchStrideB1, - BatchStrideD1s, - BatchStrideE1} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? 
arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - e1_g_m_o_lengths = arr3{batch_count, M, O}; - e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides); - e1_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e1_grid_desc_m_n); - block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1); - - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - using D0DataType = remove_cvref_t>; - - // D0s layout [batch_count, M, N] - d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; - d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; - - // D0 pointer - p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); - - // D0 desc - d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]); - }); - - static_for<0, NumD1Tensor, 1>{}([&](auto i) { - using D1DataType = remove_cvref_t>; - - // D1s layout [batch_count, M, O] - d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; - d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; - - // D1 pointer - p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); - - // D1 desc - d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]); - }); - - d1s_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc); - } - - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - typename GridwiseOp::D0sGridPointer p_d0s_grid; - const B1DataType* p_b1_grid; - typename GridwiseOp::D1sGridPointer p_d1s_grid; - E1DataType* p_e1_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - std::array d0s_g_m_n_lengths; - std::array d0s_g_m_n_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - std::array d1s_g_m_o_lengths; - std::array d1s_g_m_o_strides; - arr3 e1_g_m_o_lengths; - arr3 e1_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CDE1ElementwiseOperation cde1_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - D0sGridDesc d0s_grid_desc; - B1GridDesc b1_grid_desc; - D1sGridDesc d1s_grid_desc; - typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - d1s_grid_desc_mblock_mperblock_nblock_nperblock; - - E1GridDesc e1_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e1_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map; - - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + D0sLayout, + B1Layout, + D1sLayout, + E1Layout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + 
+    using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
+        DeviceOp,
+        GemmSpec,
+        ALayout,
+        B0layout,
+        D0sLayout,
+        B1Layout,
+        D1sLayout,
+        E1Layout,
+        BlockSize,
+        MPerBlock,
+        LPerBlock,
+        KPerBlock,
+        NPerBlock,
+        ADataType,
+        B0DataType,
+        B1DataType,
+        AccDataType,
+        E1DataType,
+        D0sDataType,
+        D1sDataType,
+        AElementwiseOperation,
+        B0ElementwiseOperation,
+        AccElementwiseOperation,
+        B1ElementwiseOperation,
+        CDE1ElementwiseOperation,
+        AK1,
+        BK1,
+        L1,
+        MPerWmma,
+        LPerWmma,
+        BlkGemmPipelineVer,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        B0BlockTransferSrcVectorDim,
+        B0BlockTransferSrcScalarPerVector,
+        B1BlockTransferSrcVectorDim,
+        B1BlockTransferSrcScalarPerVector,
+        CDE0BlockTransferSrcScalarPerVector,
+        CShuffleBlockTransferScalarPerVector_NPerBlock,
+        true>; // IsMultiD
+    // Invoker
+    using Invoker = typename DeviceGemmGemmCommon::Invoker;

-    // check if DsLayout is supported
-    template
-    static constexpr bool CheckDLayout()
-    {
-        bool valid = true;
-        // iterate over DLayout tuple
-        static_for<0, NumDTensor, 1>{}([&](auto i) {
-            using DLayout = remove_cvref_t>;
-            // if RefLayout and DLayout are same, keep valid true, otherwise false
-            valid = valid && is_same_v;
-        });
-        return valid;
-    }
+    // Argument
+    using Argument = typename DeviceGemmGemmCommon::Argument;
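// [Editor's note, not part of the patch] The "All D0s/D1s layout must be Row" checks below
// rely on the removed CheckDLayout helper, which folds "is every layout in this tuple the
// same as a reference layout?" over the Ds tuple at compile time. A minimal standalone
// equivalent using std::tuple and an index_sequence fold instead of CK's static_for
// (all names here are hypothetical):
#include <tuple>
#include <type_traits>
#include <utility>

template <typename RefLayout, typename DsLayoutTuple, std::size_t... Is>
constexpr bool all_layouts_match_impl(std::index_sequence<Is...>)
{
    return (std::is_same_v<RefLayout,
                           std::remove_cv_t<std::remove_reference_t<
                               std::tuple_element_t<Is, DsLayoutTuple>>>> &&
            ...);
}

template <typename RefLayout, typename DsLayoutTuple>
constexpr bool all_layouts_match()
{
    return all_layouts_match_impl<RefLayout, DsLayoutTuple>(
        std::make_index_sequence<std::tuple_size_v<DsLayoutTuple>>{});
}

struct RowMajor {};
struct ColumnMajor {};
static_assert(all_layouts_match<RowMajor, std::tuple<>>()); // empty Ds tuple passes
static_assert(all_layouts_match<RowMajor, std::tuple<RowMajor, RowMajor>>());
static_assert(!all_layouts_match<RowMajor, std::tuple<RowMajor, ColumnMajor>>());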
-    static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
+    static bool IsSupportedArgument(const Argument& arg)
     {
-        // Print lambda with env check and printf() style formmating.
-        const char* curFunc = __func__;
-        auto print = [&curFunc](const char* format, ...) -> void {
-            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
-            {
-#if defined(__clang__)
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wformat-nonliteral"
-#endif
-                va_list args;
-                va_start(args, format);
-                std::vfprintf(stdout, format, args);
-                va_end(args);
-#if defined(__clang__)
-#pragma clang diagnostic pop
-#endif
-                std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
-            }
-        };
-
-        if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
-        {
-            print("DeviceOp: Arch err\n");
-            return false;
-        }
-
-        if constexpr(std::is_same_v || std::is_same_v ||
-                     std::is_same_v || std::is_same_v ||
-                     std::is_same_v || std::is_same_v)
-        {
-            if(ck::is_gfx11_supported())
-            {
-                print("DeviceOp: gfx 11 does not support fp8\n");
-                return false;
-            }
-        }
-
-        if constexpr(!(is_same_v || is_same_v))
-        {
-            print("DeviceOp: Acc0 Type err\n");
-            return false;
-        }
-
-        if constexpr(!(is_same_v))
-        {
-            print("DeviceOp: A layout must be Row\n");
-            return false;
-        }
-
-        if constexpr(!(is_same_v))
-        {
-            print("DeviceOp: B0 layout must be Column\n");
-            return false;
-        }
-
-        if constexpr(!(CheckDLayout()))
-        {
-            print("DeviceOp: All D0s layout must be Row\n");
-            return false;
-        }
-
-        if constexpr(!(is_same_v ||
-                       is_same_v))
-        {
-            print("DeviceOp: B1 layout must be Column or Row\n");
-            return false;
-        }
-
-        if constexpr(!(CheckDLayout()))
-        {
-            print("DeviceOp: All D1s layout must be Row\n");
-            return false;
-        }
-
-        if constexpr(!(is_same_v))
-        {
-            print("DeviceOp: C layout must be Row\n");
-            return false;
-        }
-
-        // Other padding modes have not been tested and do not get checked individually.
-        if constexpr(GemmSpec != GemmSpecialization::Default &&
-                     GemmSpec != GemmSpecialization::MNKOPadding)
-        {
-            print("Padding mode must be default or MNKO\n");
-            return false;
-        }
-
-        // Per wmma dimensions not equal to 16 are very untested.
-        if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
-        {
-            print("M, L, N per Wmma must be 16\n");
-            return false;
-        }
-
-        if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
-                                      arg.b0_grid_desc,
-                                      arg.d0s_grid_desc,
-                                      arg.b1_grid_desc,
-                                      arg.d1s_grid_desc,
-                                      arg.e1_grid_desc_m_n,
-                                      arg.block_2_etile_map))
-        {
-            return false;
-        }
-
-        // Check scalar per vector requirement
-        const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
-        const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
-        const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
-        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
-        const auto cde1_extent_lowest = arg.O;
-
-        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
-             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
-             cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
-             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
-             cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
-        {
-            print("DeviceOp: Data Transfer Vector scalar err\n");
-            return false;
-        }
-
-        // Check vector load/store requirement
-        const auto a_stride_lowest =
-            ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
-        const auto b0_stride_lowest =
-            B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
-        const auto b1_stride_lowest =
-            B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
-        const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
-
-        // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
-        // and the lowest dimension stride is hardcoded to 1
-
-        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
-             e1_stride_lowest == 1))
-        {
-            print("DeviceOp: Data Vectorize transfer err\n");
-            return false;
-        }
-
-        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
-        {
-            return false;
-        }
-
-        return true;
+        return DeviceGemmGemmCommon::IsSupportedArgument(arg);
     }

-    // polymorphic
     bool IsSupportedArgument(const BaseArgument* p_arg) override
     {
-        return IsSupportedArgument(*dynamic_cast(p_arg));
+        return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg));
     }
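// [Editor's note, not part of the patch] The removed IsSupportedArgument reduces the
// vectorized-transfer requirements to two host-side tests: every vectorized extent must be
// divisible by that transfer's ScalarPerVector, and the lowest-dimension strides must allow
// contiguous access (note the || in the removed stride check, which only rejects when no
// tensor has a unit lowest stride). A standalone sketch of both tests, with hypothetical
// names:
#include <cstddef>

// Divisibility: extent of the vectorized dimension % ScalarPerVector == 0 for every transfer.
bool scalar_per_vector_ok(const int* extents, const int* scalars_per_vector, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        if(extents[i] % scalars_per_vector[i] != 0)
            return false;
    return true;
}

// Contiguity, as the removed code wrote it: pass if any lowest-dimension stride equals 1.
bool some_lowest_stride_is_unit(const int* lowest_strides, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
        if(lowest_strides[i] == 1)
            return true;
    return false;
}
// Example: for a row-major A of size MxK read along K with 8-wide loads, the relevant pair
// is (extent = K, stride = 1), and K must be a multiple of 8 or padding must be enabled.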
-    struct Invoker : public BaseInvoker
-    {
-        using Argument = DeviceOp::RawArg;
-
-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
-        {
-            const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
-
-            const index_t grid_size = arg.batch_count * M0 * N0;
-
-            auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
-                constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
-                constexpr TailNumber tn = tail_number;
-
-                const auto kernel =
-                    kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3;
-
-                return launch_and_time_kernel(
-                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
-            };
-
-            bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
-            TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
-
-            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
-            {
-                if(HasMainKBlockLoop && TailNum == TailNumber::Full)
-                {
-                    return launch_kernel(std::integral_constant{},
-                                         std::integral_constant{});
-                }
-                else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
-                {
-                    return launch_kernel(std::integral_constant{},
-                                         std::integral_constant{});
-                }
-                else
-                {
-                    printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
-                    return 0.0f;
-                }
-            }
-            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
-            {
-                if(HasMainKBlockLoop && TailNum == TailNumber::Full)
-                {
-                    return launch_kernel(std::integral_constant{},
-                                         std::integral_constant{});
-                }
-                else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
-                {
-                    return launch_kernel(std::integral_constant{},
-                                         std::integral_constant{});
-                }
-                else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
-                {
-                    return launch_kernel(std::integral_constant{},
-                                         std::integral_constant{});
-                }
-                else
-                {
-                    printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
-                    return 0.0f;
-                }
-            }
-            else
-            {
-                printf("Invalid pipeline version!\n");
-                return 0.0f;
-            }
-        }
-
-        // polymorphic
-        float Run(const BaseArgument* p_arg,
-                  const StreamConfig& stream_config = StreamConfig{}) override
-        {
-            return Run(*dynamic_cast(p_arg), stream_config);
-        }
-    };
-
     static auto MakeArgument(const ADataType* p_a0,
                              const B0DataType* p_b0,
                              std::array p_d0s,
@@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
                              B1ElementwiseOperation b1_element_op,
                              CDE1ElementwiseOperation cde1_element_op)
     {
-        return RawArg{p_a0, p_b0,
-                      p_d0s, p_b1,
-                      p_d1s, p_e1,
-                      MRaw, NRaw,
-                      KRaw, Gemm1NRaw,
-                      Batch, StrideA0,
-                      StrideB0, StrideD0s,
-                      StrideB1, StrideD1s,
-                      StrideE1, BatchStrideA0,
-                      BatchStrideB0, BatchStrideD0s,
-                      BatchStrideB1, BatchStrideD1s,
-                      BatchStrideE1, a0_element_op,
-                      b0_element_op, cde0_element_op,
-                      b1_element_op, cde1_element_op};
+        return Argument{p_a0, p_b0,
+                        p_d0s, p_b1,
+                        p_d1s, p_e1,
+                        MRaw, NRaw,
+                        KRaw, Gemm1NRaw,
+                        Batch, StrideA0,
+                        StrideB0, StrideD0s,
+                        StrideB1, StrideD1s,
+                        StrideE1, BatchStrideA0,
+                        BatchStrideB0, BatchStrideD0s,
+                        BatchStrideB1, BatchStrideD1s,
+                        BatchStrideE1, a0_element_op,
+                        b0_element_op, cde0_element_op,
+                        b1_element_op, cde1_element_op};
     }

     // polymorphic
@@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
                              B1ElementwiseOperation b1_element_op,
                              CDE1ElementwiseOperation c_element_op) override
     {
-        return std::make_unique(static_cast(p_a),
-                                static_cast(p_b0),
-                                p_d0s,
-                                static_cast(p_b1),
-                                p_d1s,
-                                static_cast(p_c),
-                                M,
-                                N,
-                                K,
-                                O,
-                                Batch,
-                                StrideA,
-                                StrideB0,
-                                StrideD0s,
-                                StrideB1,
-                                StrideD1s,
-                                StrideE1,
-                                BatchStrideA,
-                                BatchStrideB0,
-                                BatchStrideD0s,
-                                BatchStrideB1,
-                                BatchStrideD1s,
-                                BatchStrideE1,
-                                a_element_op,
-                                b0_element_op,
-                                acc_element_op,
-                                b1_element_op,
-                                c_element_op);
+        return std::make_unique(static_cast(p_a),
+                                static_cast(p_b0),
+                                p_d0s,
+                                static_cast(p_b1),
+                                p_d1s,
+                                static_cast(p_c),
+                                M,
+                                N,
+                                K,
+                                O,
+                                Batch,
+                                StrideA,
+                                StrideB0,
+                                StrideD0s,
+                                StrideB1,
+                                StrideD1s,
+                                StrideE1,
+                                BatchStrideA,
+                                BatchStrideB0,
+                                BatchStrideD0s,
+                                BatchStrideB1,
+                                BatchStrideD1s,
+                                BatchStrideE1,
+                                a_element_op,
+                                b0_element_op,
+                                acc_element_op,
+                                b1_element_op,
+                                c_element_op);
     }

     static auto MakeInvoker() { return Invoker{}; }
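// [Editor's note, not part of the patch] The removed Invoker::Run (now supplied by
// DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg) picks a kernel instantiation from two
// runtime properties of K: whether a main K-block loop executes at all, and which tail case
// the final iterations fall into; only a fixed set of combinations is valid per pipeline
// version. A standalone sketch of the v3 selection logic, with hypothetical names:
#include <cstdio>

enum class TailNumber { Full, Even, Odd };

const char* select_v3_kernel(bool has_main_k_block_loop, TailNumber tail)
{
    if(has_main_k_block_loop && tail == TailNumber::Full)
        return "kernel<HasMainKBlockLoop=true, Tail=Full>";
    if(!has_main_k_block_loop && tail == TailNumber::Even)
        return "kernel<HasMainKBlockLoop=false, Tail=Even>";
    if(!has_main_k_block_loop && tail == TailNumber::Odd)
        return "kernel<HasMainKBlockLoop=false, Tail=Odd>";
    return nullptr; // invalid HasMainKBlockLoop / TailNum combination for the v3 pipeline
}

int main()
{
    const char* k = select_v3_kernel(true, TailNumber::Full);
    std::printf("%s\n", k ? k : "invalid combination");
    return 0;
}
// The calling convention around the device op is unchanged by this patch:
// MakeArgument/MakeArgumentPointer build an Argument, IsSupportedArgument validates it, and
// MakeInvoker().Run(arg) performs the dispatch sketched above.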