diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index b5e11c30e1..ccc605c060 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -535,6 +535,7 @@ std::vector> matrix_sizes = { {1024}, {8, 32, 1024}, {16, 8, 4, 512}, + {8192, 7168}, }; std::vector> block_sizes = { diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index e469ad0845..75e8058a6a 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -58,7 +58,8 @@ void compute_ref(const ProcessingMethod processing_method, const size_t rows, const size_t cols, const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) + const size_t scales_stride_colwise, + const bool use_fast_math) { const size_t tile_size_Y = 32; const size_t tile_size_X = 32; @@ -129,7 +130,10 @@ void compute_ref(const ProcessingMethod processing_method, const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); const size_t scale_idx = i * scales_stride_rowwise + tile_X; output_scales_rowwise[scale_idx] = biased_exponent; - const float scale_reciprocal = exp2f_rcp(biased_exponent); + float scale_reciprocal = exp2f_rcp(biased_exponent); + if (use_fast_math) { + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } for (size_t j = j_min; j < j_max; ++j) { const size_t idx = i * cols + j; @@ -150,7 +154,10 @@ void compute_ref(const ProcessingMethod processing_method, const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); const size_t scale_idx = tile_Y * scales_stride_colwise + j; output_scales_colwise[scale_idx] = biased_exponent; - const float scale_reciprocal = exp2f_rcp(biased_exponent); + float scale_reciprocal = exp2f_rcp(biased_exponent); + if (use_fast_math) { + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } for (size_t 
i = i_min; i < i_max; ++i) { const size_t idx = i * cols + j; @@ -241,7 +248,8 @@ void performTest(const ProcessingMethod processing_method, const std::vector& last_dims_h, const std::vector& offsets_h, const bool rowwise, - const bool colwise) { + const bool colwise, + const bool use_fast_math) { using namespace test; DType itype = TypeInfo::dtype; @@ -272,9 +280,13 @@ void performTest(const ProcessingMethod processing_method, const size_t elts = M * K; elts_num += elts; + auto divide_round_up_blocks = [](const size_t N, const size_t M) -> size_t { + return (N == 0) ? 0 : 1 + (N - 1) / M; + }; + const size_t unpadded_rowwise_blocks_Y = M; - const size_t unpadded_rowwise_blocks_X = divide_round_up(K, 32); - const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32); + const size_t unpadded_rowwise_blocks_X = divide_round_up_blocks(K, 32); + const size_t unpadded_colwise_blocks_Y = divide_round_up_blocks(M, 32); const size_t unpadded_colwise_blocks_X = K; rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128); @@ -371,7 +383,7 @@ void performTest(const ProcessingMethod processing_method, NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); - std::vector dbias_logical_shape_vec= {num_tensors, cols}; + std::vector dbias_logical_shape_vec = {num_tensors, cols}; NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), dbias_logical_shape_vec.size()); @@ -496,14 +508,18 @@ void performTest(const ProcessingMethod processing_method, out_scales_rowwise_ptr, out_scales_colwise_ptr, ref_output_dbias_ptr, M, K, scales_stride_rowwise, - scales_stride_colwise); + scales_stride_colwise, + use_fast_math); } + QuantizationConfigWrapper quant_config; + quant_config.set_use_fast_math(use_fast_math); + // GPU Tensor workspace; switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_group_quantize(in_group_tensor, out_group_tensor, 0); + 
nvte_group_quantize_v2(in_group_tensor, out_group_tensor, quant_config, 0); break; } case ProcessingMethod::CAST_DBIAS: { @@ -554,6 +570,11 @@ void performTest(const ProcessingMethod processing_method, const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; + // Compare only allocated contiguous output range. + // In graph-safe mode logical shape may include trailing garbage beyond offsets_h.back(). + const size_t compare_rows = 1; + const size_t compare_cols = elts_num; + if (rowwise) { cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost); @@ -566,7 +587,8 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), - out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); + out_data_rowwise_h.data(), compare_rows, compare_cols, + true, mismatches_elts); } if (colwise) { @@ -581,7 +603,8 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), - out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + out_data_colwise_h.data(), compare_rows, compare_cols, + false, mismatches_elts); } if (compute_dbias) { @@ -651,9 +674,13 @@ std::vector> input_config = { {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + // Empty tensor in the middle of the group must not terminate the 
persistent work loop. + {VARYING_FIRST_DIM, 4, 512,160, 128,0,0,256}, + {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, }; } // namespace @@ -664,7 +691,8 @@ class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam ScalingDirection, std::vector, // Config transformer_engine::DType, // InputType - transformer_engine::DType // OutputType + transformer_engine::DType, // OutputType + bool >> {}; TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { @@ -682,6 +710,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { const std::vector input_config = std::get<3>(GetParam()); const DType input_type = std::get<4>(GetParam()); const DType output_type = std::get<5>(GetParam()); + const bool use_fast_math = std::get<6>(GetParam()); const ShapeRepresentation shape_rep = static_cast(input_config[0]); const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); @@ -745,6 +774,15 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { GTEST_SKIP(); } + // Skip fused tests if fast math is enabled. 
+ if (use_fast_math) { + if (processing_method != ProcessingMethod::CAST_ONLY) { + GTEST_SKIP(); + } + if ((input_type != DType::kBFloat16) && (input_type != DType::kFloat16)) { + GTEST_SKIP(); + } + } bool rowwise = false; bool colwise = false; @@ -779,7 +817,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, performTest(processing_method, OP, shape_rep, num_tensors, logical_shape, first_dims, last_dims, offsets, - rowwise, colwise); + rowwise, colwise, use_fast_math); ); ); } @@ -807,6 +845,40 @@ std::string to_string(const ActivationKind activation) { } } +std::string MakeGroupedFusedCastMXFP8TestName( + const testing::TestParamInfo& info) { + const ProcessingMethod method = std::get<0>(info.param); + std::string name = to_string(method); + name += "X" + to_string(std::get<1>(info.param)); + + switch (std::get<2>(info.param)) { + case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break; + case ScalingDirection::COLWISE: name += "_COLWISE_"; break; + case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break; + } + + const std::vector input = std::get<3>(info.param); + + switch (static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; + case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; + case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; + case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; + } + + name += "_N_" + std::to_string(input[1]); + + name += "_SHAPE_" + std::to_string(input[2]) + "X" + std::to_string(input[3]); + + name += "_" + test::typeName(std::get<4>(info.param)) + + "_" + test::typeName(std::get<5>(info.param)); + + if (std::get<6>(info.param)) { + name += "_FASTMATH"; + } + return name; +} + INSTANTIATE_TEST_SUITE_P( OperatorTest, GroupedFusedCastMXFP8TestSuite, @@ -816,34 +888,6 @@ INSTANTIATE_TEST_SUITE_P( 
::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), - [](const testing::TestParamInfo& info) { - const ProcessingMethod method = std::get<0>(info.param); - std::string name = to_string(method); - name += "X" + to_string(std::get<1>(info.param)); - - switch (std::get<2>(info.param)) { - case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break; - case ScalingDirection::COLWISE: name += "_COLWISE_"; break; - case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break; - } - - const std::vector input = std::get<3>(info.param); - - switch(static_cast(input[0])) { - case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; - case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; - case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; - case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; - }; - - name += "_N_" + std::to_string(input[1]); - - name += "_SHAPE_" + - std::to_string(input[2]) + - "X" + std::to_string(input[3]); - - name += "_" + test::typeName(std::get<4>(info.param)) + - "_" + test::typeName(std::get<5>(info.param)); - return name; - }); + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(true, false)), + MakeGroupedFusedCastMXFP8TestName); diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 4f9ddb4fc5..67b7b908e6 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -56,6 +56,15 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, dispatch::quantize_fwd_helper(input, output, quant_config, stream); } +void nvte_group_quantize_v2(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + 
NVTE_API_CALL(nvte_group_quantize_v2); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::group_quantize_fwd_helper(input, output, quant_config, stream); +} + void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_dbias); diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index a4e033939b..f150fa7981 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -23,13 +23,6 @@ namespace transformer_engine { namespace dispatch { namespace common { -enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, - VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, - VARYING_BOTH_DIMS = 3 -}; - inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { const size_t N = product(t->data.shape); const bool isFullTile = (N % elems_per_block == 0); @@ -100,14 +93,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t tensor_id = blockIdx.y; const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) ? (first_logical_dim / num_tensors) - : first_dims_ptr[tensor_id]; + : static_cast(first_dims_ptr[tensor_id]); const size_t rows = tensor_rows / chunk_dim_Y; const size_t cols = last_logical_dim; - const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) - ? (tensor_id * (tensor_rows / chunk_dim_Y)) - : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + const size_t dbias_in_offset_Y = + (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? 
(tensor_id * (tensor_rows / chunk_dim_Y)) + : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index f7823b4c58..8d985f64f3 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -409,7 +409,7 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor case NVTE_MXFP8_1D_SCALING: { mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); + workspace_tensor, &quant_config_cpp, stream); break; } default: @@ -450,7 +450,7 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe case NVTE_MXFP8_1D_SCALING: { mxfp8::group_quantize( grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); + &quant_config_cpp, stream); break; } default: diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 129d6724ac..2350837dad 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -17,6 +17,7 @@ #include #include "../../common.h" +#include "../../util/cuda_runtime.h" #include "../../util/math.h" #include "../../util/ptx.cuh" #include "../../utils.cuh" @@ -36,41 +37,62 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +struct TunableConfig { + static constexpr uint CHUNK_DIM_Y = 128; + static constexpr uint CHUNK_DIM_X = 128; + static constexpr uint 
THREADS_PER_CHUNK = 128; + // true -> static persistent grid-stride scheduler + // false -> non-persistent one-job-per-CTA execution + static constexpr bool PERSISTENT = true; + // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). + static constexpr uint STATIC_PERSISTENT_BLOCKS_PER_SM = 24; +}; + +constexpr bool PERSISTENT = TunableConfig::PERSISTENT; +static_assert(!PERSISTENT || (TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0), + "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero in persistent mode."); + constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; -constexpr size_t BUFFS_NUM = 2; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; +constexpr uint PREFETCH_STAGES = 1; +constexpr uint BUFFS_NUM = PREFETCH_STAGES + 1; +constexpr uint PACK_SIZE = 4; +constexpr uint WAVES = SCALE_DIM_X / PACK_SIZE; -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 128; +constexpr uint CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y; +constexpr uint CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X; +constexpr uint THREADS_PER_CHUNK = TunableConfig::THREADS_PER_CHUNK; constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; -constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; -constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; +constexpr uint THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; +constexpr uint THREADS_Y = THREADS_PER_CHUNK / THREADS_X; -constexpr size_t BUFF_DIM_Y = THREADS_Y; -constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +constexpr uint BUFF_DIM_Y = THREADS_Y; +constexpr uint BUFF_DIM_X = CHUNK_DIM_X; +constexpr uint BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; static_assert(BUFF_DIM_Y == 32); -constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; +constexpr uint STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; static_assert(STAGES >= 1); +static_assert(CHUNK_DIM_Y % BUFF_DIM_Y == 0); 
+static_assert(CHUNK_DIM_Y % SCALE_DIM_Y == 0); +static_assert(CHUNK_DIM_X % SCALE_DIM_X == 0); + // Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 +constexpr uint TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 - -__device__ __forceinline__ size_t get_current_tensor_id( - const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, - const size_t block_Y, const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr) { - if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { +constexpr uint THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +template +__device__ __forceinline__ size_t +get_current_tensor_id(const size_t num_tensors, const size_t current_offset, const size_t block_Y, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { + if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS) { const size_t current_row = block_Y * CHUNK_DIM_Y; const size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; @@ -92,41 +114,215 @@ __device__ __forceinline__ size_t get_current_tensor_id( } } +template +__device__ __forceinline__ size_t +get_tensor_rows_num(const size_t tensor_id, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { + size_t rows_num = 0; + if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_LAST_DIM) { + rows_num = first_logical_dim; + } else { + rows_num = static_cast(first_dims_ptr[tensor_id]); + } + if (rows_num % 128 != 0) { + NVTE_DEVICE_ERROR("First dimension of each tensor 
in a group must be divisible by 128."); + } + return rows_num; +} + __device__ __forceinline__ size_t get_tensor_rows_num( const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { - size_t rows_num = 0; switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: - case ShapeRepresentation::VARYING_LAST_DIM: - rows_num = first_logical_dim; - break; + return get_tensor_rows_num(tensor_id, first_logical_dim, + first_dims_ptr, num_tensors); case ShapeRepresentation::VARYING_FIRST_DIM: + return get_tensor_rows_num( + tensor_id, first_logical_dim, first_dims_ptr, num_tensors); + case ShapeRepresentation::VARYING_LAST_DIM: + return get_tensor_rows_num( + tensor_id, first_logical_dim, first_dims_ptr, num_tensors); case ShapeRepresentation::VARYING_BOTH_DIMS: - rows_num = static_cast(first_dims_ptr[tensor_id]); - break; + return get_tensor_rows_num( + tensor_id, first_logical_dim, first_dims_ptr, num_tensors); } - if (rows_num % 128 != 0) { - NVTE_DEVICE_ERROR("First dimension of each tensor in a group must be divisible by 128."); + return 0; +} + +template +__device__ __forceinline__ size_t +get_tensor_cols_num(const size_t tensor_id, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { + size_t cols_num = 0; + if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM) { + cols_num = last_logical_dim; + } else { + cols_num = static_cast(last_dims_ptr[tensor_id]); + if (cols_num % 128 != 0) { + NVTE_DEVICE_ERROR( + "For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); + } } - return rows_num; + return cols_num; } __device__ __forceinline__ size_t get_tensor_cols_num( const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, const int64_t *const __restrict__ last_dims_ptr) { - size_t 
cols_num = 0; switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: + return get_tensor_cols_num(tensor_id, last_logical_dim, + last_dims_ptr); case ShapeRepresentation::VARYING_FIRST_DIM: - cols_num = last_logical_dim; - break; + return get_tensor_cols_num( + tensor_id, last_logical_dim, last_dims_ptr); case ShapeRepresentation::VARYING_LAST_DIM: + return get_tensor_cols_num(tensor_id, last_logical_dim, + last_dims_ptr); case ShapeRepresentation::VARYING_BOTH_DIMS: - cols_num = static_cast(last_dims_ptr[tensor_id]); - break; + return get_tensor_cols_num( + tensor_id, last_logical_dim, last_dims_ptr); } - return cols_num; + return 0; +} + +// Logical work-item decoded from CTA coordinates. +struct JobDescriptor { + size_t block_id = 0; + size_t block_global_offset = 0; + size_t tensor_id = 0; + size_t rows = 0; + size_t cols = 0; + + __host__ __device__ __forceinline__ constexpr JobDescriptor() = default; + + __host__ __device__ __forceinline__ constexpr JobDescriptor(const size_t block_id_, + const size_t block_global_offset_, + const size_t tensor_id_, + const size_t rows_, + const size_t cols_) + : block_id(block_id_), + block_global_offset(block_global_offset_), + tensor_id(tensor_id_), + rows(rows_), + cols(cols_) {} +}; + +// Tensor-local coordinates for a work-item. 
+struct BlockDescriptor { + size_t tensor_base = 0; + size_t block_id_in_current_tensor = 0; + size_t block_id_Y = 0; + size_t block_id_X = 0; + size_t block_offset_Y = 0; + size_t block_offset_X = 0; + + __host__ __device__ __forceinline__ constexpr BlockDescriptor() = default; + + __host__ __device__ __forceinline__ constexpr BlockDescriptor( + const size_t tensor_base_, const size_t block_id_in_current_tensor_, const size_t block_id_Y_, + const size_t block_id_X_, const size_t block_offset_Y_, const size_t block_offset_X_) + : tensor_base(tensor_base_), + block_id_in_current_tensor(block_id_in_current_tensor_), + block_id_Y(block_id_Y_), + block_id_X(block_id_X_), + block_offset_Y(block_offset_Y_), + block_offset_X(block_offset_X_) {} +}; + +template +__device__ __forceinline__ JobDescriptor decode_job( + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const size_t work_blocks_X, const int32_t ctaid_X, const int32_t ctaid_Y, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr) { + constexpr bool is_single_tensor = (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); + const size_t block_id = ctaid_Y * work_blocks_X + ctaid_X; + const size_t block_global_offset = + is_single_tensor ? 
(ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) + : (block_id * ELTS_PER_CHUNK); + const size_t tensor_id = get_current_tensor_id( + num_tensors, block_global_offset, ctaid_Y, first_logical_dim, last_logical_dim, offsets_ptr); + const size_t rows = + get_tensor_rows_num(tensor_id, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, last_logical_dim, last_dims_ptr); + return JobDescriptor(block_id, block_global_offset, tensor_id, rows, cols); +} + +template +__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, + const size_t total_work_blocks, + const int64_t *const __restrict__ offsets_ptr) { + const bool is_valid = (job.block_id < total_work_blocks); + if (!is_valid) { + return false; + } + if (job.rows == 0 || job.cols == 0) { + return true; + } + if constexpr (SHAPE_REP == SAME_BOTH_DIMS) { + return true; + } + + const size_t tensor_start_offset = static_cast(offsets_ptr[job.tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[job.tensor_id + 1]); + if (job.block_global_offset >= tensor_end_offset) { + return false; + } + + const size_t tensor_offset_from_start = job.block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / job.cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % job.cols; + if (block_offset_Y_in_tensor >= job.rows) { + return false; + } + + return true; +} + +__device__ __forceinline__ bool job_has_work(const JobDescriptor &job) { + return job.rows != 0 && job.cols != 0; +} + +__device__ __forceinline__ void advance_to_next_job(bool &job_finished, int32_t &ctaid_X, + int32_t &ctaid_Y, size_t &static_next_block_id, + const size_t static_block_stride, + const size_t total_work_blocks, + const size_t work_blocks_X) { + if constexpr (PERSISTENT) { + if (static_next_block_id < total_work_blocks) { + ctaid_X = static_cast(static_next_block_id % work_blocks_X); + ctaid_Y = 
static_cast(static_next_block_id / work_blocks_X); + static_next_block_id += static_block_stride; + } else { + job_finished = true; + } + } else { + job_finished = true; + } +} + +template +__device__ __forceinline__ BlockDescriptor +decode_block(const JobDescriptor &job, const int64_t *const __restrict__ offsets_ptr) { + constexpr bool is_single_tensor = (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); + const size_t CHUNK_DIM_X_ = CHUNK_DIM_X; + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_); + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); + const size_t block_id_in_current_tensor = + is_single_tensor ? job.block_id : (job.block_id - tensor_base / ELTS_PER_CHUNK); + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + return BlockDescriptor(tensor_base, block_id_in_current_tensor, block_id_Y, block_id_X, + block_offset_Y, block_offset_X); } // Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index @@ -169,29 +365,35 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te } template -__global__ void update_tma_descriptors( - const __grid_constant__ CUtensorMap base_tensor_map_input, - const __grid_constant__ CUtensorMap base_tensor_map_act_input, - const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, - const IType *const __restrict__ input_data_ptr, - const IType *const __restrict__ act_input_data_ptr, - const OType *const __restrict__ output_rowwise_data_ptr, - const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation 
shape_rep, - const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, - const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, - const bool compute_dactivations) { - const bool leading_thread = (threadIdx.x == 0); +__global__ void __launch_bounds__(1) + update_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, + const IType *const __restrict__ input_data_ptr, + const IType *const __restrict__ act_input_data_ptr, + const OType *const __restrict__ output_rowwise_data_ptr, + const OType *const __restrict__ output_colwise_data_ptr, + const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, + const bool colwise, const bool compute_dactivations) { const size_t tensor_id = blockIdx.x; - const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - const size_t offset_elts = offsets_ptr[tensor_id]; - if (leading_thread && (tensor_id < num_tensors)) { + // Zero-sized groups: skip TMA descriptor update. The main kernel already returns + // early for rows==0 or cols==0, but creating a TMA descriptor with a zero dimension + // is invalid and causes CUDA_ERROR_ILLEGAL_ADDRESS. 
+ if (rows == 0 || cols == 0) { + return; + } + + if (tensor_id < num_tensors) { { const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], @@ -228,125 +430,484 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +// Issue TMA global->shared transfer for one stage of input (and optional activation input). +template +__device__ __forceinline__ void prefetch_input_stage( + IType *in_sh, IType *act_in_sh, const CUtensorMap &tensor_map_input, + const CUtensorMap &tensor_map_act_input, const size_t global_offset_X, + const size_t global_offset_Y, const size_t buff_offset, const size_t shmem_buff_size, + uint64_t *barrier, const bool leading_thread) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buff_offset]), + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + barrier); + if constexpr (IS_DACT) { + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&act_in_sh[buff_offset]), + reinterpret_cast(&tensor_map_act_input), global_offset_X, + global_offset_Y, barrier); + } + } +} + +// Issue TMA shared->global transfer for one stage of outputs. 
+template +__device__ __forceinline__ void store_output_stage( + OType *out_rowwise_data_sh, OType *out_colwise_data_sh, + const CUtensorMap &tensor_map_output_rowwise, const CUtensorMap &tensor_map_output_colwise, + const size_t global_offset_X, const size_t global_offset_Y, const size_t buff_offset, + const bool leading_thread) { + if (!leading_thread) { + return; + } + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + if constexpr (ROWWISE_SCALING || COLWISE_SCALING) { + ptx::cp_async_bulk_commit_group(); + } +} + template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( - const __grid_constant__ CUtensorMap tensor_map_input_static, - const __grid_constant__ CUtensorMap tensor_map_act_input_static, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, - const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, - const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, - const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, - const int64_t *const __restrict__ first_dims_ptr, - const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, - e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, - float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + bool WITH_GEMM_SWIZZLED_SCALES, bool USE_FAST_MATH> +__device__ __forceinline__ float process_colwise_stage( + const size_t buff, const int stage, const 
size_t tid_X_colwise, + const size_t scales_offset_Y_colwise, const size_t scales_offset_X_colwise, + const size_t scale_stride_colwise, const size_t tensor_base_for_scales, const size_t rows, + const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, + OType *out_colwise_data_sh, e8m0_t *scales_colwise, float &partial_dbias_colwise) { + using IType2 = typename ptx::FPx2; + using IType4 = typename ptx::FPx4; + using OType4 = typename ptx::FPx4; + + constexpr uint32_t IN_SHMEM_STRIDE = static_cast(BUFF_DIM_X * sizeof(IType)); + constexpr uint32_t OUT_SHMEM_STRIDE = static_cast(BUFF_DIM_X * sizeof(OType)); + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; + constexpr bool NON_FP32_CAST_ONLY = + NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); + + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + float thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + IType4 in_colwise_IType4[BUFF_DIM_Y / 4]; + + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; i += 4) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_colwise]); + + // Load 4x elts S2R and find amax + if constexpr (std::is_same_v) { + asm volatile( + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %2; \n\t" + "mov.u32 stride, %3; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "ld.shared.b16 x0, [ptr0]; \n\t" + 
"ld.shared.b16 x1, [ptr1]; \n\t" + "ld.shared.b16 x2, [ptr2]; \n\t" + "ld.shared.b16 x3, [ptr3]; \n\t" + "mov.b64 %0, {x0,x1,x2,x3}; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b32 x01, {x0,x1}; \n\t" + "mov.b32 x23, {x2,x3}; \n\t" + "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.bf16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in_colwise_IType4[i / 4])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE)); + } else { + asm volatile( + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %2; \n\t" + "mov.u32 stride, %3; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "ld.shared.b16 x0, [ptr0]; \n\t" + "ld.shared.b16 x1, [ptr1]; \n\t" + "ld.shared.b16 x2, [ptr2]; \n\t" + "ld.shared.b16 x3, [ptr3]; \n\t" + "mov.b64 %0, {x0,x1,x2,x3}; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b32 x01, {x0,x1}; \n\t" + "mov.b32 x23, {x2,x3}; \n\t" + "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.f16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in_colwise_IType4[i / 4])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE)); + } + } + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx; + } else if constexpr (NON_FP32_CAST_ONLY) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + 
thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; } } - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = + tensor_scales_offset_colwise_base + + transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + global_scales_offset_X, local_scales_offset_Y, DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } + scales_colwise[scale_idx] = biased_exponent; - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == 
VARYING_FIRST_DIM); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); - const size_t block_ID = blockIdx.y * gridDim.x + blockIdx.x; - const size_t block_global_offset = - is_single_tensor ? (blockIdx.y * CHUNK_DIM_Y * last_logical_dim + blockIdx.x * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; i += 4) { + OType4 out; + ptx::mul_cvt_4x(out, in_colwise_IType4[i / 4], block_scale_inverse_f16); + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + const uint32_t dst_smem_ptr = + __cvta_generic_to_shared(&out_colwise_data_sh[shmem_offset_elt]); + + asm volatile( + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %0; \n\t" + "mov.u32 stride, %1; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b8 x0,x1,x2,x3; \n\t" + "mov.b32 {x0,x1,x2,x3}, %2; \n\t" + "st.shared.b8 [ptr0], x0; \n\t" + "st.shared.b8 [ptr1], x1; \n\t" + "st.shared.b8 [ptr2], x2; \n\t" + "st.shared.b8 [ptr3], x3; \n" + "}\n" ::"r"(dst_smem_ptr), + "r"(OUT_SHMEM_STRIDE), "r"(reinterpret_cast(out))); + } + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NON_FP32_CAST_ONLY) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, blockIdx.y, - first_logical_dim, last_logical_dim, offsets_ptr); + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + 
} + return thread_amax; +} - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); +template +__device__ __forceinline__ float process_rowwise_stage( + const size_t buff, const size_t stage_offset_Y, const size_t thread_offset_Y_rowwise, + const size_t thread_offset_X_rowwise, const int bank_group, + const size_t scales_offset_Y_rowwise, const size_t scales_offset_X_rowwise, + const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, const size_t cols, + IType *in_sh, IType *act_in_sh, IType *cached_act_sh, OType *out_rowwise_data_sh, + e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { + using IType2 = typename ptx::FPx2; + using IType4 = typename ptx::FPx4; + using OType2 = typename ptx::FPx2; + using OType4 = typename ptx::FPx4; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; + constexpr bool NON_FP32_CAST_ONLY = + NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); + + const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + float thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + Vec in_IType[WAVES]; + IType4 in_IType4[WAVES]; + + if constexpr (NON_FP32_CAST_ONLY) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + if constexpr (USE_FAST_MATH) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_rowwise]); + // Load 4x elts 
S2R and find amax + if constexpr (std::is_same_v) { + asm volatile( + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.bf16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr)); + } else { + asm volatile( + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.f16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr)); + } + } else { + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + } + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = 
static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); - const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + Vec in; + Vec act_in; - // grouped tensor can be treated as continuous tensor for MXFP8 - const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); - // For grouped tensors represented as a single logical tensor, scale swizzle must still be - // computed per tensor (expert) and then concatenated along dim-0. - const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) - ? static_cast(offsets_ptr[tensor_id]) - : tensor_base; + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } - // In graph-safe paged stashing, the logical shape can include trailing garbage. Skip CTAs that - // map outside the current tensor's valid [rows, cols] region. 
- if (rows == 0 || cols == 0) { - return; - } - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - return; - } - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - return; + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; + } } } - const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = - is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = - is_single_tensor ? 
tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const bool leading_thread = (threadIdx.x == 0); + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } - if (leading_thread && (!is_single_tensor)) { - fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&tensor_map_act_input); - } - if constexpr (ROWWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_rowwise); - } - if constexpr (COLWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_colwise); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { + uint32_t out_4x = 0; + OType4 &out = *reinterpret_cast(&out_4x); + ptx::mul_cvt_4x(out, in_IType4[w], block_scale_inverse_f16); + + const uint32_t dst_smem_ptr = + __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); + asm 
volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); + } else { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NON_FP32_CAST_ONLY) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } } + return thread_amax; +} - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t block_id_in_current_tensor = - is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr, + const size_t work_blocks_X, const size_t work_blocks_Y) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool NON_FP32_CAST_ONLY = + NO_ACTIVATIONS && 
(!IS_DBIAS) && (!std::is_same_v); - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + if constexpr (USE_FAST_MATH && !NON_FP32_CAST_ONLY) { + return; + } - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + constexpr bool ROWWISE_SCALING = + (SCALING_TYPE == ScalingType::ROWWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool COLWISE_SCALING = + (SCALING_TYPE == ScalingType::COLWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); - e8m0_t *const scales_rowwise = - scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); - e8m0_t *const scales_colwise = - scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + constexpr ShapeRepresentation shape_rep = SHAPE_REP; + constexpr bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + const bool leading_thread = (threadIdx.x == 0); const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; const size_t tid_X_rowwise = threadIdx.x % THREADS_X; @@ -356,11 +917,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - 
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -390,376 +946,247 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } + constexpr size_t shmem_buff_size = (IS_DACT ? 2 : 1) * buff_size_aligned_in / BUFFS_NUM; float block_amax = 0.0f; -// Initialize shared memory barrier with the number of threads participating in the barrier. 
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, leading_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); + const size_t total_work_blocks = work_blocks_X * work_blocks_Y; + const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; + + int IN_buff_readable_parity[BUFFS_NUM] = {0}; + int32_t ctaid_X = static_cast(blockIdx.x); + int32_t ctaid_Y = static_cast(blockIdx.y); + size_t static_next_block_id = 0; + size_t static_block_stride = 0; + // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. + if constexpr (PERSISTENT) { + if (launch_block_id >= total_work_blocks) { + return; + } + ctaid_X = static_cast(launch_block_id % work_blocks_X); + ctaid_Y = static_cast(launch_block_id / work_blocks_X); + static_block_stride = gridDim.x * gridDim.y; + static_next_block_id = launch_block_id + static_block_stride; } - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - leading_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); - } + bool job_finished = false; + size_t last_acquired_tensor_id = num_tensors; + + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + // Initialize barriers shared by the entire CTA: + // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. + initialize_barriers(IN_buff_readable_mbar, leading_thread); + + // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. + while (!job_finished) { + // Decode CTA assignment into logical tensor coordinates and validate bounds. + const JobDescriptor current_job = + decode_job(num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, + ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const bool current_job_is_valid = + is_job_valid(current_job, total_work_blocks, offsets_ptr); + if (!current_job_is_valid) { + break; + } + if (!job_has_work(current_job)) { + // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. 
+ advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, + total_work_blocks, work_blocks_X); + continue; } - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_colwise[i] = elt; - } + const size_t tensor_id = current_job.tensor_id; + const size_t rows = current_job.rows; + 
const size_t cols = current_job.cols; + const BlockDescriptor current_block = decode_block(current_job, offsets_ptr); + + const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + + const size_t tensor_base = current_block.tensor_base; + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? static_cast(offsets_ptr[tensor_id]) + : tensor_base; + const size_t block_id_Y = current_block.block_id_Y; + const size_t block_id_X = current_block.block_id_X; + const size_t block_offset_Y = current_block.block_offset_Y; + const size_t block_offset_X = current_block.block_offset_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + const size_t dbias_offset_Y = block_id_Y; + const size_t dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? 
tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = is_single_tensor + ? tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = is_single_tensor + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; + + if (leading_thread && (!is_single_tensor) && (last_acquired_tensor_id != tensor_id)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - const size_t tensor_base_row = tensor_base_for_scales / cols; - const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; - const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; - const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); - } else { - scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); } - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -// 3. 
Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); } + last_acquired_tensor_id = tensor_id; } + __syncthreads(); - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; + int buff_in = 0; - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { 
- const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { + // Prime the pipeline with the first PREFETCH_STAGES slices of the current block. 
#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const size_t buff = stage; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + uint64_t *barrier = &IN_buff_readable_mbar[buff]; + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, + global_offset_X, global_offset_Y, buff_offset, + shmem_buff_size, barrier, leading_thread); + } - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; - } - } + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; } + } - // 2. 
Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, - DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + if (stage < STAGES - PREFETCH_STAGES) { + const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const size_t next_prefetch_stage = stage + PREFETCH_STAGES; + const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + + const size_t global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + prefetch_input_stage( + in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); } - scales_rowwise[scale_idx] = biased_exponent; - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); -// 3. 
Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + const size_t buff = buff_in; + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_amax = + process_colwise_stage( + buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, + cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); } - } - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); + if constexpr (ROWWISE_SCALING) { + thread_amax = + process_rowwise_stage( + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, + scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, + out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); + } - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. 
+ __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - // Initiate TMA transfer to copy shared memory to global memory - if (leading_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); - } + // Publish the stage from shared memory into global outputs via TMA. + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + store_output_stage( + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, + tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, leading_thread); - // Create a "bulk async-group" out of the previous bulk copy operation. 
- ptx::cp_async_bulk_commit_group(); + buff_in = (buff_in + 1) % BUFFS_NUM; } - } - - parity ^= 1; - if constexpr (IS_DBIAS) { - if (is_single_tensor) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); + const size_t shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + const size_t shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } } - } - __syncthreads(); + 
__syncthreads(); #pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + for (int i = 0; i < THREADS_Y; ++i) { + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const size_t dbias_stride = cols; + const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; } - } - const int dbias_stride = cols; - const int dbias_offset_Y = block_id_Y; - const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; } } + + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, + total_work_blocks, work_blocks_X); } if (amax_ptr != nullptr) { @@ -772,7 +1199,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, leading_thread); + destroy_barriers(IN_buff_readable_mbar, leading_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } } // namespace group_quantize_kernel @@ -781,9 +1208,17 @@ template void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, const Tensor *noop, GroupedTensor *output, GroupedTensor *dbias, - Tensor *workspace, cudaStream_t stream) { + Tensor *workspace, const QuantizationConfig *quant_config, + cudaStream_t stream) { using namespace group_quantize_kernel; + 
const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + if (use_fast_math) { + NVTE_CHECK(input->dtype() == DType::kBFloat16 || input->dtype() == DType::kFloat16, + "Fast math supports only BF16 and FP16 input types."); + NVTE_CHECK(!IS_DBIAS && !IS_DACT && !IS_ACT, "Fast math does not support fused casts."); + } + checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -832,20 +1267,30 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - size_t blocks_X = 0; - size_t blocks_Y = 0; + size_t work_blocks_X = 0; + size_t work_blocks_Y = 0; if (is_single_tensor) { - blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + work_blocks_Y = DIVUP(first_logical_dim, static_cast(CHUNK_DIM_Y)); + work_blocks_X = DIVUP(last_logical_dim, static_cast(CHUNK_DIM_X)); } else { NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - blocks_Y = 1; - blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + work_blocks_Y = 1; + work_blocks_X = DIVUP(elts_total, ELTS_PER_CHUNK); + } + + size_t launch_blocks_X = work_blocks_X; + size_t launch_blocks_Y = work_blocks_Y; + if constexpr (PERSISTENT) { + const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); + const size_t static_grid_size = sm_num * TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM; + NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); + launch_blocks_X = static_grid_size; + launch_blocks_Y = 1; } - const dim3 grid(blocks_X, blocks_Y); + const dim3 grid(launch_blocks_X, launch_blocks_Y); const size_t block_size = THREADS_PER_CHUNK; const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; @@ -884,7 +1329,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations 
NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias."); NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - const size_t dbias_workspace_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_workspace_rows = DIVUP(first_logical_dim, static_cast(CHUNK_DIM_Y)); const size_t dbias_workspace_cols = last_logical_dim; if (workspace->data.dptr == nullptr) { workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; @@ -897,125 +1342,127 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations input->dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, - first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * 
buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - auto kernel = - group_quantize_mxfp8_kernel; - switch (scaling_type) { - case ScalingType::ROWWISE: { - kernel = - group_quantize_mxfp8_kernel; - break; - } - case ScalingType::COLWISE: { - kernel = - group_quantize_mxfp8_kernel; - break; - } - case ScalingType::BIDIMENSIONAL: { - kernel = - group_quantize_mxfp8_kernel; - break; - } - } - - // Update tensor descriptors before launching the kernel - if (!is_single_tensor) { - const IType *const input_dptr = reinterpret_cast(input->data.dptr); - - const IType *const act_input_dptr = - IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; - - OType *const output_rowwise_dptr = - use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; - - OType *const output_colwise_dptr = - use_colwise_scaling ? 
reinterpret_cast(output->columnwise_data.dptr) - : nullptr; - update_tma_descriptors<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, - output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, - use_rowwise_scaling, use_colwise_scaling, IS_DACT); - } - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); - - if constexpr (IS_DBIAS) { - common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, - first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); - } - - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH( + scaling_type, SCALING_TYPE, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH( + shape_rep, SHAPE_REP, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + { + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + 
create_2D_tensor_map(tensor_map_act_input, activations->data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, + BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, + BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, + output->columnwise_data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = + (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = + (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = + (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = + (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = + reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) + : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? 
reinterpret_cast(output->data.dptr) + : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling + ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, + output_rowwise_dptr, output_colwise_dptr, shape_rep, num_tensors, + first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, + last_dims_ptr, use_rowwise_scaling, use_colwise_scaling, IS_DACT); + } + + auto kernel = group_quantize_mxfp8_kernel< + IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, SCALING_TYPE, + WITH_GEMM_SWIZZLED_SCALES, SHAPE_REP, USE_FAST_MATH>; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, + scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, + amax_ptr, work_blocks_X, work_blocks_Y); + + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, + CHUNK_DIM_Y, stream); + } + + NVTE_CHECK_CUDA(cudaGetLastError()); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 } // namespace dispatch } // namespace transformer_engine - #endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index a98668d058..8db34b5756 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -904,6 +904,48 @@ struct TypeInfo { { __VA_ARGS__ } \ } +#define TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH(SCALING_TYPE, SCALING_T, ...) 
\ + switch (SCALING_TYPE) { \ + case ScalingType::ROWWISE: { \ + constexpr ScalingType SCALING_T = ScalingType::ROWWISE; \ + { __VA_ARGS__ } \ + } break; \ + case ScalingType::COLWISE: { \ + constexpr ScalingType SCALING_T = ScalingType::COLWISE; \ + { __VA_ARGS__ } \ + } break; \ + case ScalingType::BIDIMENSIONAL: { \ + constexpr ScalingType SCALING_T = ScalingType::BIDIMENSIONAL; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported scaling type."); \ + } \ + } + +#define TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(SHAPE_REP, SHAPE, ...) \ + switch (SHAPE_REP) { \ + case ShapeRepresentation::SAME_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::SAME_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_FIRST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_FIRST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_LAST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_LAST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported grouped tensor shape representation."); \ + } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..0fb73cc439 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -25,13 +25,6 @@ namespace { constexpr int kMaxTensorsPerKernel = 64; constexpr int kThreadsPerWarp = 
32; -enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, - VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, - VARYING_BOTH_DIMS = 3 -}; - __device__ __forceinline__ size_t get_current_tensor_id( const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, const size_t first_logical_dim, const size_t last_logical_dim, diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 755052d6dd..02b88bfba6 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -124,6 +124,19 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no void nvte_quantize_v2(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped MXFP8 tensor. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_v2(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream); + /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. 
diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index f7611e60c5..c29b89c832 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -493,7 +493,7 @@ struct alignas(2 * sizeof(T)) FPx2 { }; template -struct FPx4 { +struct alignas(4 * sizeof(T)) FPx4 { T x1; T x2; T x3; @@ -1169,6 +1169,103 @@ __device__ __forceinline__ fp16 get_amax(fp16 a, fp16 b) { #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +// Using mixed precision FMA instruction +__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, const bf16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in, const bf16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" + 
"cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in, const fp16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const fp16x4 &in, const fp16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 
10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + __device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, const ptx::floatx2 &scale) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 26549191a3..8c50e83926 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -928,6 +928,13 @@ using e8m0_t = uint8_t; enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 }; +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + template struct Numeric_Traits;