Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -617,32 +617,32 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m);
const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n);

GemmArgs new_args{};
new_args.a_ptrs_ = p_as_grid;
new_args.b_ptrs_ = p_bs_grid;
new_args.ds_ptrs_ = p_ds_grid;
new_args.e_ptr_ = p_e_grid;

new_args.a_element_op_ = a_element_op_;
new_args.b_element_op_ = b_element_op_;
new_args.cde_element_op_ = cde_element_op_;

new_args.M_ = gemm_m;
new_args.N_ = gemm_n;

new_args.a_grid_desc_ = a_grid_desc;
new_args.b_grid_desc_ = b_grid_desc;
new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
const auto ds_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, m_block, n_block);
new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
const auto e_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, m_block, n_block);

new_args.BlockStart_ = BlockStart;
new_args.BlockEnd_ = BlockEnd;

gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
gemm_desc_kernel_args_.Emplace(
valid_gemms_count_,
GemmArgs{.a_ptrs_ = p_as_grid,
.b_ptrs_ = p_bs_grid,
.ds_ptrs_ = p_ds_grid,
.e_ptr_ = p_e_grid,
.a_element_op_ = a_element_op_,
.b_element_op_ = b_element_op_,
.cde_element_op_ = cde_element_op_,
.M_ = gemm_m,
.N_ = gemm_n,
.a_grid_desc_ = a_grid_desc,
.b_grid_desc_ = b_grid_desc,
.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
ds_desc_mblock_mperblock_nblock_nperblock,
.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
e_desc_mblock_mperblock_nblock_nperblock,
.BlockStart_ = BlockStart,
.BlockEnd_ = BlockEnd});

valid_gemms_count_++;
}
Expand Down Expand Up @@ -789,11 +789,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;

static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
});
}
}

void Print() const
Expand All @@ -807,12 +810,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
<< ", is_split_valid=" << std::boolalpha << is_split_valid_
<< std::noboolalpha << ", grid_size=" << grid_size_ << std::endl;

static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << " Ds[" << i.value
<< "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i)
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i)
<< std::endl;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << " Ds[" << i.value << "] group stride="
<< compute_ptr_offset_of_groups_.BatchStrideDs_.At(i)
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i)
<< std::endl;
});
}

std::cout << "===== GEMM splits =====" << std::endl;
for(index_t i = 0; i < valid_gemms_count_; ++i)
Expand All @@ -836,11 +842,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: "
<< gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl;

static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
std::cout << " D" << d_idx.value << " descriptor: "
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx)
<< std::endl;
});
if constexpr(NumDTensor > 0)
{
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
std::cout << " D" << d_idx.value << " descriptor: "
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx)
<< std::endl;
});
}
}
}

Expand Down
11 changes: 11 additions & 0 deletions include/ck/utility/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "functional2.hpp"
#include "sequence.hpp"
#include <type_traits>
#include <cassert>

namespace ck {

Expand All @@ -27,6 +29,15 @@ struct Array

__host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }

template <typename... Args>
__host__ constexpr auto Emplace(index_t i, Args&&... args)
-> std::enable_if_t<std::is_nothrow_constructible_v<TData, Args&&...>>
{
assert(i >= 0 && i < NSize);
mData[i].~TData();
new(mData + i) TData(ck::forward<Args>(args)...);
}

template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
Expand Down
Loading