From 107a865bc72aa673b5cd4e227e8ed64abd13c9a0 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 13:28:06 +0000 Subject: [PATCH 01/51] Enabled persistency with WorkID Query feature Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 893 ++++++++++-------- 1 file changed, 507 insertions(+), 386 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 129d6724ac..e0a1a1a814 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -39,7 +39,8 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_T constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; -constexpr size_t BUFFS_NUM = 2; +constexpr size_t PREFETCH_STAGES = 1; +constexpr size_t BUFFS_NUM = PREFETCH_STAGES + 1; constexpr size_t PACK_SIZE = 4; constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; @@ -261,93 +262,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const size_t block_ID = blockIdx.y * gridDim.x + blockIdx.x; - const size_t block_global_offset = - is_single_tensor ? (blockIdx.y * CHUNK_DIM_Y * last_logical_dim + blockIdx.x * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); - - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, blockIdx.y, - first_logical_dim, last_logical_dim, offsets_ptr); - - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - - const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); - const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); - - // grouped tensor can be treated as continuous tensor for MXFP8 - const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); - // For grouped tensors represented as a single logical tensor, scale swizzle must still be - // computed per tensor (expert) and then concatenated along dim-0. - const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) - ? static_cast(offsets_ptr[tensor_id]) - : tensor_base; - - // In graph-safe paged stashing, the logical shape can include trailing garbage. Skip CTAs that - // map outside the current tensor's valid [rows, cols] region. - if (rows == 0 || cols == 0) { - return; - } - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - return; - } - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - return; - } - } - - const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = - is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = - is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; - const bool leading_thread = (threadIdx.x == 0); - if (leading_thread && (!is_single_tensor)) { - fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&tensor_map_act_input); - } - if constexpr (ROWWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_rowwise); - } - if constexpr (COLWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_colwise); - } - } - - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t block_id_in_current_tensor = - is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); - - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; - - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; - - e8m0_t *const scales_rowwise = - scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); - e8m0_t *const scales_colwise = - scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); - - const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; const size_t tid_X_rowwise = threadIdx.x % THREADS_X; const size_t tid_Y_colwise = 0; @@ -356,11 +272,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -390,374 +301,578 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } + constexpr size_t shmem_buff_size = (IS_DACT ? 2 : 1) * buff_size_aligned_in / BUFFS_NUM; float block_amax = 0.0f; -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; + __shared__ uint64_t workID_mbar; + __shared__ __uint128_t workID_response; + constexpr uint32_t workID_response_size = sizeof(workID_response); + static_assert(workID_response_size == 16); - initialize_barriers(mbar, leading_thread); + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], leading_thread); + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::mbarrier_init(&workID_mbar, 1); + ptx::fence_proxy_async_shared_cta(); } + __syncthreads(); + + int IN_buff_readable_parity[BUFFS_NUM] = {0}; + int ctaid_parity = 0; + int32_t ctaid_X = blockIdx.x; + int32_t ctaid_Y = blockIdx.y; + bool job_finished = false; + int buff_in = 0; + + // Prefetch the first stage of the first job. + { + const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (block_ID * ELTS_PER_CHUNK); + + const size_t tensor_id = + get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + if (rows == 0 || cols == 0) { + return; + } + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); + if (block_global_offset >= tensor_end_offset) { + return; + } + const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; + if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { + return; + } + } + + const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t block_id_in_current_tensor = + is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + } #pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const size_t buff = stage; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - leading_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread); + const size_t buff_offset = buff * BUFF_DIM; + uint64_t *barrier = &IN_buff_readable_mbar[buff]; + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buff_offset]), + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + barrier); + if constexpr (IS_DACT) { + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&act_in_sh[buff_offset]), + reinterpret_cast(&tensor_map_act_input), global_offset_X, + global_offset_Y, barrier); + } } } + } - ptx::fence_proxy_async_shared_cta(); + while (!job_finished) { + const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (block_ID * ELTS_PER_CHUNK); + const size_t tensor_id = + get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + bool current_job_is_valid = (rows != 0) && (cols != 0); + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); + if (block_global_offset >= tensor_end_offset) { + current_job_is_valid = false; + } + if (current_job_is_valid) { + const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; + if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { + current_job_is_valid = false; + } + } + } + if (!current_job_is_valid) { + // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); + break; + } - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); + const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? static_cast(offsets_ptr[tensor_id]) + : tensor_base; + const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); + const size_t block_id_in_current_tensor = + is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = + is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = + is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + + if (leading_thread && (!is_single_tensor)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } + } - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); #pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + for (int stage = 0; stage < STAGES; ++stage) { + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (stage == STAGES - PREFETCH_STAGES) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + ctaid_parity ^= 1; + } - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); + if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { + const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + + const size_t prefetch_block_ID = + static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t prefetch_block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (prefetch_block_ID * ELTS_PER_CHUNK); + const size_t prefetch_tensor_id = + get_current_tensor_id(shape_rep, num_tensors, prefetch_block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t prefetch_tensor_base = + is_single_tensor ? 0 : static_cast(offsets_ptr[prefetch_tensor_id]); + const size_t prefetch_cols = + get_tensor_cols_num(prefetch_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + const size_t prefetch_blocks_X_num_in_current_tensor = + DIVUP(prefetch_cols, static_cast(128)); + const size_t prefetch_block_id_in_current_tensor = + is_single_tensor ? prefetch_block_ID + : (prefetch_block_ID - prefetch_tensor_base / ELTS_PER_CHUNK); + const size_t prefetch_block_id_Y = + prefetch_block_id_in_current_tensor / prefetch_blocks_X_num_in_current_tensor; + const size_t prefetch_block_id_X = + prefetch_block_id_in_current_tensor % prefetch_blocks_X_num_in_current_tensor; + + const size_t global_offset_Y = prefetch_block_id_Y * CHUNK_DIM_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = prefetch_block_id_X * CHUNK_DIM_X; + const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; + + const CUtensorMap &prefetch_tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_tensor_id]; + const CUtensorMap &prefetch_tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_tensor_id]; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + if (leading_thread) { + if ((!is_single_tensor) && (stage == STAGES - PREFETCH_STAGES)) { + fence_acquire_tensormap(&prefetch_tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&prefetch_tensor_map_act_input); + } } + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[next_prefetch_buff_offset]), + reinterpret_cast(&prefetch_tensor_map_input), global_offset_X, + global_offset_Y, barrier); if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&act_in_sh[next_prefetch_buff_offset]), + reinterpret_cast(&prefetch_tensor_map_act_input), global_offset_X, + global_offset_Y, barrier); } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_colwise[i] = elt; } + ptx::fence_proxy_async_shared_cta(); } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - const size_t tensor_base_row = tensor_base_for_scales / cols; - const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; - const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; - const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); - } else { - scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - } - scales_colwise[scale_idx] = biased_exponent; + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + const size_t buff = buff_in; + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; -// 3. Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; + } + } - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = tensor_scales_offset_colwise_base + + gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } + scales_colwise[scale_idx] = biased_exponent; - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries - if constexpr (std::is_same_v) { + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + Vec in_IType[WAVES]; + + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); #pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); } } - } - if constexpr (!std::is_same_v) { thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { + } else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; } } - } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, - DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - } - scales_rowwise[scale_idx] = biased_exponent; + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, + DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; -// 3. Scale elements #pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; + for (int w = 0; w < WAVES; ++w) { + Vec out; #pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - } - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); - // Initiate TMA transfer to copy shared memory to global memory - if (leading_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; + if (leading_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + ptx::cp_async_bulk_commit_group(); } - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); + buff_in = (buff_in + 1) % BUFFS_NUM; } - } - - parity ^= 1; - if constexpr (IS_DBIAS) { - if (is_single_tensor) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + for (int i = 0; i < THREADS_Y; ++i) { + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; } - } - const int dbias_stride = cols; - const int dbias_offset_Y = block_id_Y; - const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; } } } @@ -772,7 +887,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, leading_thread); + if (leading_thread) { +#pragma unroll + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } + ptx::mbarrier_invalid(&workID_mbar); + } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } } // namespace group_quantize_kernel From caf664f7d4bf3068cac2e9587c755c03eca144d3 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 14:03:59 +0000 Subject: [PATCH 02/51] Added a struct with tunable parameters Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index e0a1a1a814..31e8645e07 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -36,17 +36,26 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +struct TunableConfig { + static constexpr size_t CHUNK_DIM_Y = 128; + static constexpr size_t CHUNK_DIM_X = 128; + static constexpr size_t THREADS_PER_CHUNK = 128; + static constexpr size_t PREFETCH_STAGES = 1; + // Set false to run one-CTA-per-block (non-persistent) mode. + static constexpr bool PERSISTENT = true; +}; + constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; -constexpr size_t PREFETCH_STAGES = 1; +constexpr size_t PREFETCH_STAGES = TunableConfig::PREFETCH_STAGES; constexpr size_t BUFFS_NUM = PREFETCH_STAGES + 1; constexpr size_t PACK_SIZE = 4; constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y; +constexpr size_t CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK = TunableConfig::THREADS_PER_CHUNK; constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; @@ -511,9 +520,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - if (leading_thread) { - ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); - ptx::try_cancel_cta(&workID_mbar, &workID_response); + if constexpr (TunableConfig::PERSISTENT) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); + ptx::try_cancel_cta(&workID_mbar, &workID_response); + } } #pragma unroll @@ -521,12 +532,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (stage == STAGES - PREFETCH_STAGES) { - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); - ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + if constexpr (TunableConfig::PERSISTENT) { + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); + ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); + ctaid_parity ^= 1; + } else { + ctaid_X = -1; + ctaid_Y = -1; + } if (ctaid_X == -1 && ctaid_Y == -1) { job_finished = true; } - ctaid_parity ^= 1; } if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { From 68dbc624b316c0494c28b532cff5c71f3bd0823d Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 15:15:14 +0000 Subject: [PATCH 03/51] Added persistency with static scheduling Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 127 +++++++++++++----- 1 file changed, 96 insertions(+), 31 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 31e8645e07..791fe65098 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -19,6 +19,7 @@ #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" +#include "../../util/cuda_runtime.h" #include "../../utils.cuh" #include "../core/common.cuh" #include "swizzle.cuh" @@ -36,15 +37,30 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; +enum class PersistentStrategy : int { + NONE = 0, + DYNAMIC_WORK_STEALING = 1, + STATIC_GRID_STRIDE = 2, +}; + struct TunableConfig { static constexpr size_t CHUNK_DIM_Y = 128; static constexpr size_t CHUNK_DIM_X = 128; static constexpr size_t THREADS_PER_CHUNK = 128; static constexpr size_t PREFETCH_STAGES = 1; - // Set false to run one-CTA-per-block (non-persistent) mode. - static constexpr bool PERSISTENT = true; + static constexpr PersistentStrategy PERSISTENT_STRATEGY = + PersistentStrategy::STATIC_GRID_STRIDE; + // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). + static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 1; }; +constexpr bool DYNAMIC_PERSISTENT = + TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; +constexpr bool STATIC_PERSISTENT = + TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; +static_assert(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0, + "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero."); + constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -251,7 +267,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, - float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) { + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr, + const size_t work_blocks_X, const size_t work_blocks_Y) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -315,8 +332,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel float block_amax = 0.0f; __shared__ uint64_t workID_mbar; - __shared__ __uint128_t workID_response; - constexpr uint32_t workID_response_size = sizeof(workID_response); + [[maybe_unused]] __shared__ __uint128_t workID_response; + [[maybe_unused]] constexpr uint32_t workID_response_size = sizeof(workID_response); static_assert(workID_response_size == 16); __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; @@ -331,16 +348,32 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } __syncthreads(); + const size_t total_work_blocks = work_blocks_X * work_blocks_Y; + const size_t launch_block_id = + static_cast(blockIdx.y) * static_cast(gridDim.x) + static_cast(blockIdx.x); + int IN_buff_readable_parity[BUFFS_NUM] = {0}; - int ctaid_parity = 0; - int32_t ctaid_X = blockIdx.x; - int32_t ctaid_Y = blockIdx.y; + [[maybe_unused]] int ctaid_parity = 0; + int32_t ctaid_X = static_cast(blockIdx.x); + int32_t ctaid_Y = static_cast(blockIdx.y); + [[maybe_unused]] size_t static_next_block_id = 0; + [[maybe_unused]] size_t static_block_stride = 0; + if constexpr (STATIC_PERSISTENT) { + if (launch_block_id >= total_work_blocks) { + return; + } + ctaid_X = static_cast(launch_block_id % work_blocks_X); + ctaid_Y = static_cast(launch_block_id / work_blocks_X); + static_block_stride = static_cast(gridDim.x) * static_cast(gridDim.y); + static_next_block_id = launch_block_id + static_block_stride; + } bool job_finished = false; int buff_in = 0; + bool has_prefetched_current_job = true; // Prefetch the first stage of the first job. { - const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); const size_t block_global_offset = is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + static_cast(ctaid_X) * CHUNK_DIM_X) @@ -352,7 +385,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - if (rows == 0 || cols == 0) { + if (block_ID >= total_work_blocks || rows == 0 || cols == 0) { return; } if (shape_rep != SAME_BOTH_DIMS) { @@ -415,7 +448,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } while (!job_finished) { - const size_t block_ID = static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); const size_t block_global_offset = is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + static_cast(ctaid_X) * CHUNK_DIM_X) @@ -428,7 +461,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - bool current_job_is_valid = (rows != 0) && (cols != 0); + bool current_job_is_valid = (block_ID < total_work_blocks) && (rows != 0) && (cols != 0); if (shape_rep != SAME_BOTH_DIMS) { const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); @@ -445,11 +478,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } if (!current_job_is_valid) { - // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); - IN_buff_readable_parity[buff_in] ^= 1; - ptx::cp_async_bulk_wait_group_read(); + if (has_prefetched_current_job) { + // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); + } break; } @@ -520,38 +555,56 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - if constexpr (TunableConfig::PERSISTENT) { + if constexpr (DYNAMIC_PERSISTENT) { if (leading_thread) { ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); ptx::try_cancel_cta(&workID_mbar, &workID_response); } } + bool prefetched_next_job = false; #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t stage_offset_Y = stage * BUFF_DIM_Y; + bool allow_next_job_prefetch = true; if (stage == STAGES - PREFETCH_STAGES) { - if constexpr (TunableConfig::PERSISTENT) { + if constexpr (DYNAMIC_PERSISTENT) { ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); ctaid_parity ^= 1; + } else if constexpr (STATIC_PERSISTENT) { + if (static_next_block_id < total_work_blocks) { + ctaid_X = static_cast(static_next_block_id % work_blocks_X); + ctaid_Y = static_cast(static_next_block_id / work_blocks_X); + static_next_block_id += static_block_stride; + } else { + // Next loop iteration exits via current_job_is_valid check. + ctaid_X = 0; + ctaid_Y = static_cast(work_blocks_Y); + allow_next_job_prefetch = false; + } } else { ctaid_X = -1; ctaid_Y = -1; } - if (ctaid_X == -1 && ctaid_Y == -1) { - job_finished = true; + if constexpr (!STATIC_PERSISTENT) { + if (ctaid_X == -1 && ctaid_Y == -1) { + job_finished = true; + } } } - if (!job_finished || (stage < STAGES - PREFETCH_STAGES)) { + if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + if (stage >= STAGES - PREFETCH_STAGES) { + prefetched_next_job = true; + } const size_t prefetch_block_ID = - static_cast(ctaid_Y) * gridDim.x + static_cast(ctaid_X); + static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); const size_t prefetch_block_global_offset = is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + static_cast(ctaid_X) * CHUNK_DIM_X) @@ -851,6 +904,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel buff_in = (buff_in + 1) % BUFFS_NUM; } + has_prefetched_current_job = prefetched_next_job; if constexpr (IS_DBIAS) { if (is_single_tensor) { @@ -969,20 +1023,30 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - size_t blocks_X = 0; - size_t blocks_Y = 0; + size_t work_blocks_X = 0; + size_t work_blocks_Y = 0; if (is_single_tensor) { - blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + work_blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + work_blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); } else { NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - blocks_Y = 1; - blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + work_blocks_Y = 1; + work_blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + } + + size_t launch_blocks_X = work_blocks_X; + size_t launch_blocks_Y = work_blocks_Y; + if constexpr (STATIC_PERSISTENT) { + const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); + const size_t static_grid_size = sm_num * TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM; + NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); + launch_blocks_X = static_grid_size; + launch_blocks_Y = 1; } - const dim3 grid(blocks_X, blocks_Y); + const dim3 grid(launch_blocks_X, launch_blocks_Y); const size_t block_size = THREADS_PER_CHUNK; const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; @@ -1138,7 +1202,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, + work_blocks_Y); if constexpr (IS_DBIAS) { common::grouped_reduce_dbias( From 051d925b70b5035a84e4fb901b4ceca271c0bff5 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 16:15:42 +0000 Subject: [PATCH 04/51] Fixed test cases Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 51 +++++++++++-------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index e469ad0845..6cff159d51 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -554,6 +554,11 @@ void performTest(const ProcessingMethod processing_method, const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; + // Compare only allocated contiguous output range. + // In graph-safe mode logical shape may include trailing garbage beyond offsets_h.back(). + const size_t compare_rows = 1; + const size_t compare_cols = elts_num; + if (rowwise) { cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost); cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost); @@ -566,7 +571,8 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("rowwise_output", out_data_rowwise_ref.data(), - out_data_rowwise_h.data(), rows, cols, true, mismatches_elts); + out_data_rowwise_h.data(), compare_rows, compare_cols, + true, mismatches_elts); } if (colwise) { @@ -581,7 +587,8 @@ void performTest(const ProcessingMethod processing_method, const size_t mismatches_elts = 32 * mismatches_scales; compare_scaled_elts("colwise_output", out_data_colwise_ref.data(), - out_data_colwise_h.data(), rows, cols, false, mismatches_elts); + out_data_colwise_h.data(), compare_rows, compare_cols, + false, mismatches_elts); } if (compute_dbias) { @@ -616,15 +623,15 @@ void performTest(const ProcessingMethod processing_method, std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - ProcessingMethod::CAST_DBIAS, - ProcessingMethod::CAST_DBIAS_DACT, - ProcessingMethod::CAST_DACT, - ProcessingMethod::CAST_ACT, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, }; std::vector activation_kinds = { ActivationKind::Identity, - ActivationKind::GeLU, + // ActivationKind::GeLU, // ActivationKind::SiLU, // ActivationKind::ReLU, // ActivationKind::QGeLU, @@ -639,21 +646,23 @@ enum ScalingDirection { std::vector scaling_directions = { ScalingDirection::ROWWISE, - ScalingDirection::COLWISE, - ScalingDirection::BOTH, + // ScalingDirection::COLWISE, + // ScalingDirection::BOTH, }; // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - {SAME_BOTH_DIMS, 1, 128,128}, - {SAME_BOTH_DIMS, 2, 256,128}, - {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, - {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + // {SAME_BOTH_DIMS, 1, 128,128}, + // {SAME_BOTH_DIMS, 2, 256,128}, + // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 4096,4096, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,4096, 128,256,384,1024,2304}, + // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace @@ -815,8 +824,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(activation_kinds), ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kBFloat16), + ::testing::Values(DType::kFloat8E4M3)), + // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); From 2f9a299a770db04e42afbc56cc780a4a5daeac0f Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 16:18:07 +0000 Subject: [PATCH 05/51] Ready for benchmarking Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 3 +- tests/cpp/operator/CMakeLists.txt | 56 +- .../common/activation/activation_template.h | 30 +- .../common/cast/dispatch/dequantize.cuh | 52 +- .../common/cast/dispatch/gated.cuh | 304 ++++---- .../common/cast/dispatch/quantize.cuh | 729 +++++++++--------- 6 files changed, 589 insertions(+), 585 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 6f4f163f08..2092975b2a 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -8,7 +8,8 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + set(CMAKE_CUDA_ARCHITECTURES 100) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 56880a428d..a04cc3c38c 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,35 +3,35 @@ # See LICENSE for license information. add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu + # test_cast.cu + # test_cast_current_scaling.cu + # test_cast_dbias.cu + # test_cast_dbias_dgelu.cu + # test_cast_gated_swiglu.cu + # test_cast_mxfp8_gated_swiglu.cu + # test_qdq.cu + # test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu - test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swizzle.cu - test_swap_first_dims.cu - test_grouped_gemm.cu + # test_cast_nvfp4_transpose.cu + # test_cast_float8blockwise.cu + # test_dequantize_mxfp8.cu + # test_transpose.cu + # test_cast_transpose.cu + # test_cast_transpose_current_scaling.cu + # test_cast_transpose_dbias.cu + # test_cast_transpose_dbias_dgelu.cu + # test_cast_transpose_dgeglu.cu + # test_act.cu + # test_normalization.cu + # test_normalization_mxfp8.cu + # test_memset.cu + # test_multi_cast_transpose.cu + # test_multi_padding.cu + # test_multi_unpadding.cu + # test_causal_softmax.cu + # test_swizzle.cu + # test_swap_first_dims.cu + # test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index ffbffafd1a..caf6cbda65 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -22,36 +22,36 @@ namespace transformer_engine { template void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - using namespace detail; - constexpr bool IS_ACT = true; - dispatch::quantize_fwd_helper(input, output, nullptr, stream); + // using namespace detail; + // constexpr bool IS_ACT = true; + // dispatch::quantize_fwd_helper(input, output, nullptr, stream); } template void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { - using namespace detail; - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = true; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - - dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, - nullptr, stream); + // using namespace detail; + // constexpr bool IS_DBIAS = false; + // constexpr bool IS_DACT = true; + // constexpr NVTETensor dbias = nullptr; + // constexpr NVTETensor workspace = nullptr; + + // dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, + // nullptr, stream); } template void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { - using namespace detail; - dispatch::quantize_gated_fwd_helper(input, output, p, stream); + // using namespace detail; + // dispatch::quantize_gated_fwd_helper(input, output, p, stream); } template void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { - using namespace detail; - dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); + // using namespace detail; + // dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 81304981d3..db2ad285a8 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -22,32 +22,32 @@ namespace transformer_engine { namespace dispatch { inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - switch (input.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); - NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); - fp8::dequantize(input, output, stream); - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { - mxfp8::dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - break; - } - case NVTE_NVFP4_1D_SCALING: { - nvfp4::dequantize(input, output, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } + // CheckInputTensor(input, "cast_input"); + // CheckOutputTensor(*output, "cast_output"); + + // switch (input.scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); + // NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); + // NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); + // fp8::dequantize(input, output, stream); + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // if (is_supported_by_CC_100()) { + // mxfp8::dequantize(input, output, stream); + // } else { + // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + // } + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // nvfp4::dequantize(input, output, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + // } } } // namespace dispatch diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 06e8f0e306..c2087533a6 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -25,164 +25,164 @@ namespace dispatch { template void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - const Tensor input = *convertNVTETensorCheck(nvte_input); - Tensor *output = convertNVTETensorCheck(nvte_output); - - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim() / 2; - - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == cols, - "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", - output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - if (use_tma_kernels) { - Tensor dummy_grad_tensor; - fp8::cast_gated_tma(input, dummy_grad_tensor, - output, p, stream); - } else { - fp8::cast_gated_fwd(input, output, p, stream); - } - if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // FP8 kernel only populates row-wise data, so perform - // transpose separately if needed - Tensor transpose_in, transpose_out, dummy; - transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_in.data.dptr = output->data.dptr; - transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - transpose_in.data.dtype = output->data.dtype; - transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_out.data.dptr = output->columnwise_data.dptr; - transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - transpose_out.data.dtype = output->data.dtype; - detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - NVTE_CHECK(cols % 32 == 0, - "Invalid input shape. Expected the last dimension to be " - "divisible by 32, but got ", - cols, "."); - if (output->has_data()) { - NVTE_CHECK(is_fp8_dtype(output->data.dtype), - "The type of the output tensor should be FP8."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - "The type of the columnwise output tensor should be FP8."); - } - NVTE_CHECK(is_supported_by_CC_100(), - "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - Tensor dummy_grad_tensor; - mxfp8::quantize_gated(input, dummy_grad_tensor, - output, p, stream); - break; - } - default: - NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - } + // const Tensor input = *convertNVTETensorCheck(nvte_input); + // Tensor *output = convertNVTETensorCheck(nvte_output); + + // CheckInputTensor(input, "input"); + // CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + // const size_t rows = input.flat_first_dim(); + // const size_t cols = input.flat_last_dim() / 2; + + // NVTE_CHECK(input.flat_last_dim() % 2 == 0, + // "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + // input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + // NVTE_CHECK(output->flat_last_dim() == cols, + // "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", + // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // switch (output->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // if (use_tma_kernels) { + // Tensor dummy_grad_tensor; + // fp8::cast_gated_tma(input, dummy_grad_tensor, + // output, p, stream); + // } else { + // fp8::cast_gated_fwd(input, output, p, stream); + // } + // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // // FP8 kernel only populates row-wise data, so perform + // // transpose separately if needed + // Tensor transpose_in, transpose_out, dummy; + // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_in.data.dptr = output->data.dptr; + // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + // transpose_in.data.dtype = output->data.dtype; + // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_out.data.dptr = output->columnwise_data.dptr; + // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + // transpose_out.data.dtype = output->data.dtype; + // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // NVTE_CHECK(cols % 32 == 0, + // "Invalid input shape. Expected the last dimension to be " + // "divisible by 32, but got ", + // cols, "."); + // if (output->has_data()) { + // NVTE_CHECK(is_fp8_dtype(output->data.dtype), + // "The type of the output tensor should be FP8."); + // } + // if (output->has_columnwise_data()) { + // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + // "The type of the columnwise output tensor should be FP8."); + // } + // NVTE_CHECK(is_supported_by_CC_100(), + // "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + // Tensor dummy_grad_tensor; + // mxfp8::quantize_gated(input, dummy_grad_tensor, + // output, p, stream); + // break; + // } + // default: + // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + // } } template void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); - const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); - Tensor *output = convertNVTETensorCheck(nvte_output); - - CheckInputTensor(grad, "grad"); - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", - gated_input.flat_last_dim(), "."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - - NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); - NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); - - NVTE_CHECK(grad.flat_first_dim() == rows, - "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", - grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - NVTE_CHECK(grad.flat_last_dim() == cols, - "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", - grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", - rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == cols * 2, - "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", - output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(gated_input.shape() == output->shape(), - "Gated input and output shapes must match. Input shape: ", gated_input.shape(), - ", output shape: ", output->shape(), "."); - - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - if (use_tma_kernels) { - fp8::cast_gated_tma(gated_input, grad, output, p, - stream); - } else { - fp8::cast_gated_bwd(gated_input, grad, output, p, stream); - } - if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // FP8 kernel only populates row-wise data, so perform - // transpose separately if needed - Tensor transpose_in, transpose_out, dummy; - transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_in.data.dptr = output->data.dptr; - transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - transpose_in.data.dtype = output->data.dtype; - transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_out.data.dptr = output->columnwise_data.dptr; - transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - transpose_out.data.dtype = output->data.dtype; - detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - NVTE_CHECK(cols % 32 == 0, - "Invalid input shape. Expected the last dimension to be " - "divisible by 32, but got ", - cols, "."); - if (output->has_data()) { - NVTE_CHECK(is_fp8_dtype(output->data.dtype), - "The type of the output tensor should be FP8."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - "The type of the columnwise output tensor should be FP8."); - } - NVTE_CHECK(is_supported_by_CC_100(), - "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - - mxfp8::quantize_gated(gated_input, grad, output, p, - stream); - break; - } - default: - NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - } + // const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + // const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + // Tensor *output = convertNVTETensorCheck(nvte_output); + + // CheckInputTensor(grad, "grad"); + // CheckInputTensor(gated_input, "gated_input"); + // CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + // NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + // gated_input.flat_last_dim(), "."); + + // const size_t rows = gated_input.flat_first_dim(); + // const size_t cols = gated_input.flat_last_dim() / 2; + + // NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); + // NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); + + // NVTE_CHECK(grad.flat_first_dim() == rows, + // "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", + // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + // NVTE_CHECK(grad.flat_last_dim() == cols, + // "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + // rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + // NVTE_CHECK(output->flat_last_dim() == cols * 2, + // "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + // NVTE_CHECK(gated_input.shape() == output->shape(), + // "Gated input and output shapes must match. Input shape: ", gated_input.shape(), + // ", output shape: ", output->shape(), "."); + + // switch (output->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // if (use_tma_kernels) { + // fp8::cast_gated_tma(gated_input, grad, output, p, + // stream); + // } else { + // fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + // } + // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // // FP8 kernel only populates row-wise data, so perform + // // transpose separately if needed + // Tensor transpose_in, transpose_out, dummy; + // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_in.data.dptr = output->data.dptr; + // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + // transpose_in.data.dtype = output->data.dtype; + // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_out.data.dptr = output->columnwise_data.dptr; + // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + // transpose_out.data.dtype = output->data.dtype; + // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // NVTE_CHECK(cols % 32 == 0, + // "Invalid input shape. Expected the last dimension to be " + // "divisible by 32, but got ", + // cols, "."); + // if (output->has_data()) { + // NVTE_CHECK(is_fp8_dtype(output->data.dtype), + // "The type of the output tensor should be FP8."); + // } + // if (output->has_columnwise_data()) { + // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + // "The type of the columnwise output tensor should be FP8."); + // } + // NVTE_CHECK(is_supported_by_CC_100(), + // "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + // mxfp8::quantize_gated(gated_input, grad, output, p, + // stream); + // break; + // } + // default: + // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + // } } } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index f7823b4c58..0aadffa940 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -30,282 +30,282 @@ namespace dispatch { template void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *input_tensor = convertNVTETensorCheck(input); - Tensor *output_tensor = convertNVTETensorCheck(output); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const Tensor *dummy_input_tensor = nullptr; - Tensor *dummy_dbias_tensor = nullptr; - Tensor *dummy_workspace_tensor = nullptr; - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_ACT) { - cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - const Tensor *dummy_input_tensor = nullptr; - Tensor *dummy_dbias_tensor = nullptr; - Tensor *dummy_workspace_tensor = nullptr; - mxfp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *input_tensor = convertNVTETensorCheck(input); + // Tensor *output_tensor = convertNVTETensorCheck(output); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // // Dispatch to quantization kernel depending on data format + // switch (output_tensor->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const Tensor *dummy_input_tensor = nullptr; + // Tensor *dummy_dbias_tensor = nullptr; + // Tensor *dummy_workspace_tensor = nullptr; + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_ACT) { + // cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // } else if (output_tensor->has_data()) { + // fp8::quantize( + // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // const Tensor *dummy_input_tensor = nullptr; + // Tensor *dummy_dbias_tensor = nullptr; + // Tensor *dummy_workspace_tensor = nullptr; + // mxfp8::quantize( + // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // CheckOutputTensor(*output_tensor, "output", false); + + // // Choose kernel + // int32_t rows = input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // (cols % 32 == 0) && output_tensor->has_data(); + + // // Launch NVFP4 quantize kernel + // if (use_optimized_kernel) { + // if (quant_config_cpp.nvfp4_2d_quantization) { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } else { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } + // } else { + // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // : output_tensor->columnwise_amax; + // quantize_transpose_vector_blockwise_fp4( + // /*input=*/input_tensor->data, /*global_amax=*/global_amax, + // /*scale_inv=*/output_tensor->scale_inv, + // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // /*swizzled_scale=*/false, + // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // /*rng_state=*/quant_config_cpp.rng_state, + // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // } + // break; + // } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // quantize_transpose_square_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // /*noop_tensor=*/noop_tensor->data, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + // } } template void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *grad_tensor = convertNVTETensorCheck(grad); - const Tensor *input_tensor = convertNVTETensor(input); - - Tensor *output_tensor = convertNVTETensorCheck(output); - Tensor *dbias_tensor = convertNVTETensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT) { - cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*grad_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = grad_tensor->flat_first_dim(); - int32_t cols = grad_tensor->flat_last_dim(); - auto dtype = grad_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/grad_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *grad_tensor = convertNVTETensorCheck(grad); + // const Tensor *input_tensor = convertNVTETensor(input); + + // Tensor *output_tensor = convertNVTETensorCheck(output); + // Tensor *dbias_tensor = convertNVTETensor(dbias); + // Tensor *workspace_tensor = convertNVTETensor(workspace); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // // Dispatch to quantization kernel depending on data format + // switch (output_tensor->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_DBIAS && !IS_DACT) { + // cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + // } + // } else if (output_tensor->has_data()) { + // fp8::quantize( + // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::quantize( + // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*grad_tensor, "input"); + // CheckOutputTensor(*output_tensor, "output", false); + + // // Choose kernel + // int32_t rows = grad_tensor->flat_first_dim(); + // int32_t cols = grad_tensor->flat_last_dim(); + // auto dtype = grad_tensor->dtype(); + // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // (cols % 32 == 0) && output_tensor->has_data(); + + // // Launch NVFP4 quantize kernel + // if (use_optimized_kernel) { + // if (quant_config_cpp.nvfp4_2d_quantization) { + // nvfp4::quantize_transpose( + // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } else { + // nvfp4::quantize_transpose( + // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } + // } else { + // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // : output_tensor->columnwise_amax; + // quantize_transpose_vector_blockwise_fp4( + // /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + // /*scale_inv=*/output_tensor->scale_inv, + // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // /*swizzled_scale=*/false, + // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // /*rng_state=*/quant_config_cpp.rng_state, + // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // } + // break; + // } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // quantize_transpose_square_blockwise( + // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // /*noop_tensor=*/noop_tensor->data, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise( + // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + // } } // Host-aware and not graph-safe: group quantization with split section info from the host. @@ -314,64 +314,64 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *input_tensor = convertNVTETensorCheck(input); - std::vector output_tensors; - for (size_t i = 0; i < num_tensors; ++i) { - output_tensors.push_back(convertNVTETensorCheck(outputs[i])); - } - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - // Take the scaling mode of the first output tensor - auto scaling_mode = output_tensors[0]->scaling_mode; - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - // Skip checking output tensor list - // output list here is allowed to have empty tensor - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - - NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "2D quantization is not supported for group quantize."); - - // Launch NVFP4 group quantize kernel - nvfp4::group_quantize_transpose( - *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, - &quant_config_cpp, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *input_tensor = convertNVTETensorCheck(input); + // std::vector output_tensors; + // for (size_t i = 0; i < num_tensors; ++i) { + // output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + // } + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // // Take the scaling mode of the first output tensor + // auto scaling_mode = output_tensors[0]->scaling_mode; + + // // Dispatch to quantization kernel depending on data format + // switch (scaling_mode) { + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // // Skip checking output tensor list + // // output list here is allowed to have empty tensor + + // // Choose kernel + // int32_t rows = input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + + // NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + // "2D quantization is not supported for group quantize."); + + // // Launch NVFP4 group quantize kernel + // nvfp4::group_quantize_transpose( + // *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + // &quant_config_cpp, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + // } } template @@ -407,7 +407,10 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( + // mxfp8::group_quantize( + // IS_ACT is set to false + // OP is set to nullptr + mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; @@ -422,40 +425,40 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); - const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); - GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( - grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } + // using namespace detail; + + // NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + // const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + // const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + // GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + // GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + // Tensor *workspace_tensor = convertNVTETensor(workspace); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Dispatch to quantization kernel depending on data format + // switch (scaling_mode) { + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::group_quantize( + // grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + // } } } // namespace dispatch From c040d59a03a6fb60b4c6f7a1fd5e1da678593aa2 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 17:41:17 +0000 Subject: [PATCH 06/51] Fixed out-of-boundary error Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 791fe65098..2cbcfe8218 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -51,7 +51,7 @@ struct TunableConfig { static constexpr PersistentStrategy PERSISTENT_STRATEGY = PersistentStrategy::STATIC_GRID_STRIDE; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). - static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 1; + static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; }; constexpr bool DYNAMIC_PERSISTENT = @@ -595,6 +595,44 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } + // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. + // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. + if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { + const size_t next_block_ID = + static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); + const size_t next_block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (next_block_ID * ELTS_PER_CHUNK); + const size_t next_tensor_id = + get_current_tensor_id(shape_rep, num_tensors, next_block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t next_rows = + get_tensor_rows_num(next_tensor_id, shape_rep, first_logical_dim, first_dims_ptr, + num_tensors); + const size_t next_cols = + get_tensor_cols_num(next_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + + bool next_job_is_valid = + (next_block_ID < total_work_blocks) && (next_rows != 0) && (next_cols != 0); + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[next_tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[next_tensor_id + 1]); + if (next_block_global_offset >= tensor_end_offset) { + next_job_is_valid = false; + } + if (next_job_is_valid) { + const size_t tensor_offset_from_start = next_block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / next_cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % next_cols; + if (block_offset_Y_in_tensor >= next_rows || block_offset_X_in_tensor >= next_cols) { + next_job_is_valid = false; + } + } + } + allow_next_job_prefetch = next_job_is_valid; + } + if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; From 30c28fbdd49db0b998e344d2f606f065c3c9679d Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 17:59:35 +0000 Subject: [PATCH 07/51] Tuned kernel parameters Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 2cbcfe8218..26baa86ae4 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -48,16 +48,13 @@ struct TunableConfig { static constexpr size_t CHUNK_DIM_X = 128; static constexpr size_t THREADS_PER_CHUNK = 128; static constexpr size_t PREFETCH_STAGES = 1; - static constexpr PersistentStrategy PERSISTENT_STRATEGY = - PersistentStrategy::STATIC_GRID_STRIDE; + static constexpr PersistentStrategy PERSISTENT_STRATEGY = PersistentStrategy::STATIC_GRID_STRIDE; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; }; -constexpr bool DYNAMIC_PERSISTENT = - TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; -constexpr bool STATIC_PERSISTENT = - TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; +constexpr bool DYNAMIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; +constexpr bool STATIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; static_assert(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0, "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero."); From 977168eb8df0ced65188ea4469de4f74f3c17b93 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 18:12:36 +0000 Subject: [PATCH 08/51] Refactoring Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 261 +++++++++--------- 1 file changed, 123 insertions(+), 138 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 26baa86ae4..47f7131e76 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -152,6 +152,84 @@ __device__ __forceinline__ size_t get_tensor_cols_num( return cols_num; } +// Logical work-item decoded from CTA coordinates. +struct JobDescriptor { + size_t block_id = 0; + size_t block_global_offset = 0; + size_t tensor_id = 0; + size_t rows = 0; + size_t cols = 0; +}; + +// Tensor-local coordinates for a work-item. +struct BlockDescriptor { + size_t tensor_base = 0; + size_t block_id_in_current_tensor = 0; + size_t block_id_Y = 0; + size_t block_id_X = 0; + size_t block_offset_Y = 0; + size_t block_offset_X = 0; +}; + +__device__ __forceinline__ JobDescriptor decode_job( + const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, + const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { + JobDescriptor job{}; + job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); + job.block_global_offset = + is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + + static_cast(ctaid_X) * CHUNK_DIM_X) + : (job.block_id * ELTS_PER_CHUNK); + job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + job.rows = + get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + return job; +} + +__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, + const ShapeRepresentation shape_rep, + const size_t total_work_blocks, + const int64_t *const __restrict__ offsets_ptr) { + bool is_valid = (job.block_id < total_work_blocks) && (job.rows != 0) && (job.cols != 0); + if (!is_valid || shape_rep == SAME_BOTH_DIMS) { + return is_valid; + } + + const size_t tensor_start_offset = static_cast(offsets_ptr[job.tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[job.tensor_id + 1]); + if (job.block_global_offset >= tensor_end_offset) { + return false; + } + + const size_t tensor_offset_from_start = job.block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / job.cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % job.cols; + if (block_offset_Y_in_tensor >= job.rows || block_offset_X_in_tensor >= job.cols) { + return false; + } + + return true; +} + +__device__ __forceinline__ BlockDescriptor decode_block(const JobDescriptor &job, + const bool is_single_tensor, + const int64_t *const __restrict__ offsets_ptr) { + BlockDescriptor block{}; + block.tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(128)); + block.block_id_in_current_tensor = + is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK); + block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor; + block.block_id_X = block.block_id_in_current_tensor % blocks_X_num_in_current_tensor; + block.block_offset_Y = block.block_id_Y * CHUNK_DIM_Y; + block.block_offset_X = block.block_id_X * CHUNK_DIM_X; + return block; +} + // Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map, CUtensorMap *global_tensor_map, @@ -335,6 +413,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + // Initialize barriers shared by the entire CTA: + // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. + // - workID_mbar synchronizes WorkID query response in dynamic persistent mode. if (leading_thread) { #pragma unroll for (int buff = 0; buff < BUFFS_NUM; ++buff) { @@ -355,6 +436,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel int32_t ctaid_Y = static_cast(blockIdx.y); [[maybe_unused]] size_t static_next_block_id = 0; [[maybe_unused]] size_t static_block_stride = 0; + // In STATIC_PERSISTENT mode physical CTAs iterate over a virtual work grid via grid-stride. if constexpr (STATIC_PERSISTENT) { if (launch_block_id >= total_work_blocks) { return; @@ -368,50 +450,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel int buff_in = 0; bool has_prefetched_current_job = true; - // Prefetch the first stage of the first job. + // Prime the pipeline with stage-0 of the first job assigned to this CTA. { - const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); - - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - if (block_ID >= total_work_blocks || rows == 0 || cols == 0) { + const JobDescriptor first_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + if (!is_job_valid(first_job, shape_rep, total_work_blocks, offsets_ptr)) { return; } - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - return; - } - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - return; - } - } - - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); - const size_t block_id_in_current_tensor = - is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + const BlockDescriptor first_block = decode_block(first_job, is_single_tensor, offsets_ptr); const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[first_job.tensor_id]; const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[first_job.tensor_id]; if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); @@ -424,8 +476,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { const size_t buff = stage; const size_t stage_offset_Y = stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; + const size_t global_offset_Y = first_block.block_offset_Y + stage_offset_Y; + const size_t global_offset_X = first_block.block_offset_X; const size_t buff_offset = buff * BUFF_DIM; uint64_t *barrier = &IN_buff_readable_mbar[buff]; if (leading_thread) { @@ -444,36 +496,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } + // Main persistent loop: decode current job, run all 32-row stages, schedule/prefetch next job. while (!job_finished) { - const size_t block_ID = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (block_ID * ELTS_PER_CHUNK); - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - - bool current_job_is_valid = (block_ID < total_work_blocks) && (rows != 0) && (cols != 0); - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); - if (block_global_offset >= tensor_end_offset) { - current_job_is_valid = false; - } - if (current_job_is_valid) { - const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; - if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { - current_job_is_valid = false; - } - } - } + // Decode CTA assignment into logical tensor coordinates and validate bounds. + const JobDescriptor current_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const bool current_job_is_valid = + is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr); if (!current_job_is_valid) { if (has_prefetched_current_job) { // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. @@ -485,21 +515,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel break; } + const size_t tensor_id = current_job.tensor_id; + const size_t rows = current_job.rows; + const size_t cols = current_job.cols; + const BlockDescriptor current_block = decode_block(current_job, is_single_tensor, offsets_ptr); + const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); - const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + const size_t tensor_base = current_block.tensor_base; const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) ? static_cast(offsets_ptr[tensor_id]) : tensor_base; - const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast(128)); - const size_t block_id_in_current_tensor = - is_single_tensor ? block_ID : (block_ID - tensor_base / ELTS_PER_CHUNK); - const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; - const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; - - const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; - const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + const size_t block_id_Y = current_block.block_id_Y; + const size_t block_id_X = current_block.block_id_X; + const size_t block_offset_Y = current_block.block_offset_Y; + const size_t block_offset_X = current_block.block_offset_X; e8m0_t *const scales_rowwise = scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); @@ -560,6 +591,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } bool prefetched_next_job = false; + // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t stage_offset_Y = stage * BUFF_DIM_Y; @@ -595,39 +627,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { - const size_t next_block_ID = - static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t next_block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (next_block_ID * ELTS_PER_CHUNK); - const size_t next_tensor_id = - get_current_tensor_id(shape_rep, num_tensors, next_block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - const size_t next_rows = - get_tensor_rows_num(next_tensor_id, shape_rep, first_logical_dim, first_dims_ptr, - num_tensors); - const size_t next_cols = - get_tensor_cols_num(next_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - - bool next_job_is_valid = - (next_block_ID < total_work_blocks) && (next_rows != 0) && (next_cols != 0); - if (shape_rep != SAME_BOTH_DIMS) { - const size_t tensor_start_offset = static_cast(offsets_ptr[next_tensor_id]); - const size_t tensor_end_offset = static_cast(offsets_ptr[next_tensor_id + 1]); - if (next_block_global_offset >= tensor_end_offset) { - next_job_is_valid = false; - } - if (next_job_is_valid) { - const size_t tensor_offset_from_start = next_block_global_offset - tensor_start_offset; - const size_t block_offset_Y_in_tensor = tensor_offset_from_start / next_cols; - const size_t block_offset_X_in_tensor = tensor_offset_from_start % next_cols; - if (block_offset_Y_in_tensor >= next_rows || block_offset_X_in_tensor >= next_cols) { - next_job_is_valid = false; - } - } - } - allow_next_job_prefetch = next_job_is_valid; + const JobDescriptor next_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + allow_next_job_prefetch = is_job_valid(next_job, shape_rep, total_work_blocks, offsets_ptr); } if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { @@ -638,37 +641,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel prefetched_next_job = true; } - const size_t prefetch_block_ID = - static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - const size_t prefetch_block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (prefetch_block_ID * ELTS_PER_CHUNK); - const size_t prefetch_tensor_id = - get_current_tensor_id(shape_rep, num_tensors, prefetch_block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - const size_t prefetch_tensor_base = - is_single_tensor ? 0 : static_cast(offsets_ptr[prefetch_tensor_id]); - const size_t prefetch_cols = - get_tensor_cols_num(prefetch_tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - const size_t prefetch_blocks_X_num_in_current_tensor = - DIVUP(prefetch_cols, static_cast(128)); - const size_t prefetch_block_id_in_current_tensor = - is_single_tensor ? prefetch_block_ID - : (prefetch_block_ID - prefetch_tensor_base / ELTS_PER_CHUNK); - const size_t prefetch_block_id_Y = - prefetch_block_id_in_current_tensor / prefetch_blocks_X_num_in_current_tensor; - const size_t prefetch_block_id_X = - prefetch_block_id_in_current_tensor % prefetch_blocks_X_num_in_current_tensor; - - const size_t global_offset_Y = prefetch_block_id_Y * CHUNK_DIM_Y + next_prefetch_stage_offset_Y; - const size_t global_offset_X = prefetch_block_id_X * CHUNK_DIM_X; + const JobDescriptor prefetch_job = + decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, + work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const BlockDescriptor prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); + + const size_t global_offset_Y = prefetch_block.block_offset_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = prefetch_block.block_offset_X; const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; const CUtensorMap &prefetch_tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_tensor_id]; + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_job.tensor_id]; const CUtensorMap &prefetch_tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_tensor_id]; + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_job.tensor_id]; uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; if (leading_thread) { From 885fcb928c53ac2924d82a9240a725d215abaad9 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 18:40:59 +0000 Subject: [PATCH 09/51] Refactoring 2 Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 121 +++++++++++------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 47f7131e76..fde1bf02c6 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -329,6 +329,51 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +// Issue TMA global->shared transfer for one stage of input (and optional activation input). +template +__device__ __forceinline__ void prefetch_input_stage( + IType *in_sh, IType *act_in_sh, const CUtensorMap &tensor_map_input, + const CUtensorMap &tensor_map_act_input, const size_t global_offset_X, const size_t global_offset_Y, + const size_t buff_offset, const size_t shmem_buff_size, uint64_t *barrier, const bool leading_thread) { + if (leading_thread) { + ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buff_offset]), + reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, + barrier); + if constexpr (IS_DACT) { + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&act_in_sh[buff_offset]), + reinterpret_cast(&tensor_map_act_input), global_offset_X, global_offset_Y, + barrier); + } + } +} + +// Issue TMA shared->global transfer for one stage of outputs. +template +__device__ __forceinline__ void store_output_stage( + OType *out_rowwise_data_sh, OType *out_colwise_data_sh, + const CUtensorMap &tensor_map_output_rowwise, const CUtensorMap &tensor_map_output_colwise, + const int global_offset_X, const int global_offset_Y, const int buff_offset, + const bool leading_thread) { + if (!leading_thread) { + return; + } + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, global_offset_Y, + reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, global_offset_Y, + reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + ptx::cp_async_bulk_commit_group(); +} + template @@ -480,19 +525,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t global_offset_X = first_block.block_offset_X; const size_t buff_offset = buff * BUFF_DIM; uint64_t *barrier = &IN_buff_readable_mbar[buff]; - if (leading_thread) { - ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_sh[buff_offset]), - reinterpret_cast(&tensor_map_input), global_offset_X, global_offset_Y, - barrier); - if constexpr (IS_DACT) { - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&act_in_sh[buff_offset]), - reinterpret_cast(&tensor_map_act_input), global_offset_X, - global_offset_Y, barrier); - } - } + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, + global_offset_X, global_offset_Y, buff_offset, + shmem_buff_size, barrier, leading_thread); } } @@ -596,6 +631,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel for (int stage = 0; stage < STAGES; ++stage) { const size_t stage_offset_Y = stage * BUFF_DIM_Y; bool allow_next_job_prefetch = true; + JobDescriptor prefetch_job = current_job; + BlockDescriptor prefetch_block = current_block; if (stage == STAGES - PREFETCH_STAGES) { if constexpr (DYNAMIC_PERSISTENT) { @@ -627,10 +664,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { - const JobDescriptor next_job = + prefetch_job = decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - allow_next_job_prefetch = is_job_valid(next_job, shape_rep, total_work_blocks, offsets_ptr); + allow_next_job_prefetch = is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); + if (allow_next_job_prefetch) { + prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); + } } if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { @@ -641,11 +681,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel prefetched_next_job = true; } - const JobDescriptor prefetch_job = - decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, - work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - const BlockDescriptor prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); - const size_t global_offset_Y = prefetch_block.block_offset_Y + next_prefetch_stage_offset_Y; const size_t global_offset_X = prefetch_block.block_offset_X; const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; @@ -663,18 +698,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel fence_acquire_tensormap(&prefetch_tensor_map_act_input); } } - ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_sh[next_prefetch_buff_offset]), - reinterpret_cast(&prefetch_tensor_map_input), global_offset_X, - global_offset_Y, barrier); - if constexpr (IS_DACT) { - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&act_in_sh[next_prefetch_buff_offset]), - reinterpret_cast(&prefetch_tensor_map_act_input), global_offset_X, - global_offset_Y, barrier); - } } + prefetch_input_stage( + in_sh, act_in_sh, prefetch_tensor_map_input, prefetch_tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); ptx::fence_proxy_async_shared_cta(); } @@ -686,6 +713,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t buff = buff_in; float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { + // Column-wise path: + // 1) load/compute values for one [32x1] stripe per thread + // 2) compute/write E8M0 scale + // 3) scale and write FP8 values into shared output buffer const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; thread_amax = 0.0f; float in_compute_colwise[BUFF_DIM_Y]; @@ -766,6 +797,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } if constexpr (ROWWISE_SCALING) { + // Row-wise path: + // 1) load/compute values for one [1x32] stripe per thread + // 2) compute/write E8M0 scale + // 3) scale and write FP8 values into shared output buffer const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; thread_amax = 0.0f; @@ -904,23 +939,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ptx::fence_proxy_async_shared_cta(); __syncthreads(); - if (leading_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); - } - ptx::cp_async_bulk_commit_group(); - } + // Publish the stage from shared memory into global outputs via TMA. + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + store_output_stage( + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, tensor_map_output_colwise, + global_offset_X, global_offset_Y, buff_offset, leading_thread); buff_in = (buff_in + 1) % BUFFS_NUM; } From d787847d7503f34d93e4d3fc44edb72bc5cbdd0a Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 4 Mar 2026 19:17:28 +0000 Subject: [PATCH 10/51] Refactoring 3 Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 497 ++++++++++-------- 1 file changed, 276 insertions(+), 221 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index fde1bf02c6..67537e8dfd 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -374,6 +374,271 @@ __device__ __forceinline__ void store_output_stage( ptx::cp_async_bulk_commit_group(); } +template +__device__ __forceinline__ float process_colwise_stage( + const size_t buff, const int stage, const size_t tid_X_colwise, + const size_t scales_offset_Y_colwise, const size_t scales_offset_X_colwise, + const size_t scale_stride_colwise, const size_t tensor_base_for_scales, const size_t rows, + const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, + OType *out_colwise_data_sh, e8m0_t *scales_colwise, float &partial_dbias_colwise) { + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; + if constexpr (!IS_CACHED_ACT_OP) { + (void)cached_act_sh; + } + if constexpr (!IS_DACT) { + (void)act_in_sh; + } + if constexpr (!IS_DBIAS) { + (void)partial_dbias_colwise; + } + if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { + (void)tensor_base_for_scales; + (void)rows; + } + + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + float thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_colwise[i] = elt; + } + } + + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = tensor_scales_offset_colwise_base + + transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + global_scales_offset_X, local_scales_offset_Y, + DIVUP(rows, static_cast(128))); + } else { + scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + } + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + + return thread_amax; +} + +template +__device__ __forceinline__ float process_rowwise_stage( + const size_t buff, const size_t stage_offset_Y, const size_t thread_offset_Y_rowwise, + const size_t thread_offset_X_rowwise, const int bank_group, + const size_t scales_offset_Y_rowwise, const size_t scales_offset_X_rowwise, + const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, + const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, + OType *out_rowwise_data_sh, e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; + if constexpr (!IS_DACT) { + (void)act_in_sh; + } + if constexpr (!IS_CACHED_ACT_OP) { + (void)cached_act_sh; + } + if constexpr (!(IS_DBIAS && (!COLWISE_SCALING))) { + (void)thread_dbias_rowwise; + } + if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { + (void)cols; + } + + const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + float thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + Vec in_IType[WAVES]; + + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; + } + } + } + + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + + return thread_amax; +} + template @@ -393,19 +658,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx; - if constexpr (NO_ACTIVATIONS) { if (noop != nullptr && noop[0] == 1.0f) { return; } } - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); const bool leading_thread = (threadIdx.x == 0); @@ -713,223 +971,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t buff = buff_in; float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { - // Column-wise path: - // 1) load/compute values for one [32x1] stripe per thread - // 2) compute/write E8M0 scale - // 3) scale and write FP8 values into shared output buffer - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_colwise[i] = elt; - } - } - - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - const size_t tensor_base_row = tensor_base_for_scales / cols; - const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; - const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; - const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); - } else { - scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - } - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } + thread_amax = process_colwise_stage( + buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, cached_act_sh, + out_colwise_data_sh, scales_colwise, partial_dbias_colwise); } if constexpr (ROWWISE_SCALING) { - // Row-wise path: - // 1) load/compute values for one [1x32] stripe per thread - // 2) compute/write E8M0 scale - // 3) scale and write FP8 values into shared output buffer - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - Vec in_IType[WAVES]; - - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; - } - } - } - - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X, - DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - } - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } + thread_amax = process_rowwise_stage( + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, + scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, out_rowwise_data_sh, + scales_rowwise, thread_dbias_rowwise); } __builtin_assume(block_amax >= 0); From 79c1ac2ec1aea09726f81e420c29ad4251456c50 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 5 Mar 2026 13:38:12 +0000 Subject: [PATCH 11/51] Removed the dynamic (WorkID Query) persistency Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 92 ++++--------------- 1 file changed, 20 insertions(+), 72 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 67537e8dfd..7a510c1295 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -37,26 +37,21 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; -enum class PersistentStrategy : int { - NONE = 0, - DYNAMIC_WORK_STEALING = 1, - STATIC_GRID_STRIDE = 2, -}; - struct TunableConfig { static constexpr size_t CHUNK_DIM_Y = 128; static constexpr size_t CHUNK_DIM_X = 128; static constexpr size_t THREADS_PER_CHUNK = 128; static constexpr size_t PREFETCH_STAGES = 1; - static constexpr PersistentStrategy PERSISTENT_STRATEGY = PersistentStrategy::STATIC_GRID_STRIDE; + // true -> static persistent grid-stride scheduler + // false -> non-persistent one-job-per-CTA execution + static constexpr bool PERSISTENT = true; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; }; -constexpr bool DYNAMIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::DYNAMIC_WORK_STEALING; -constexpr bool STATIC_PERSISTENT = TunableConfig::PERSISTENT_STRATEGY == PersistentStrategy::STATIC_GRID_STRIDE; -static_assert(TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0, - "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero."); +constexpr bool PERSISTENT = TunableConfig::PERSISTENT; +static_assert(!PERSISTENT || (TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0), + "STATIC_PERSISTENT_BLOCKS_PER_SM must be greater than zero in persistent mode."); constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -177,11 +172,10 @@ __device__ __forceinline__ JobDescriptor decode_job( const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { JobDescriptor job{}; - job.block_id = static_cast(ctaid_Y) * work_blocks_X + static_cast(ctaid_X); - job.block_global_offset = - is_single_tensor ? (static_cast(ctaid_Y) * CHUNK_DIM_Y * last_logical_dim + - static_cast(ctaid_X) * CHUNK_DIM_X) - : (job.block_id * ELTS_PER_CHUNK); + job.block_id = ctaid_Y * work_blocks_X + ctaid_X; + job.block_global_offset = is_single_tensor + ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) + : (job.block_id * ELTS_PER_CHUNK); job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, first_logical_dim, last_logical_dim, offsets_ptr); job.rows = @@ -386,19 +380,6 @@ __device__ __forceinline__ float process_colwise_stage( constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; - if constexpr (!IS_CACHED_ACT_OP) { - (void)cached_act_sh; - } - if constexpr (!IS_DACT) { - (void)act_in_sh; - } - if constexpr (!IS_DBIAS) { - (void)partial_dbias_colwise; - } - if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { - (void)tensor_base_for_scales; - (void)rows; - } const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; float thread_amax = 0.0f; @@ -496,18 +477,6 @@ __device__ __forceinline__ float process_rowwise_stage( constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; - if constexpr (!IS_DACT) { - (void)act_in_sh; - } - if constexpr (!IS_CACHED_ACT_OP) { - (void)cached_act_sh; - } - if constexpr (!(IS_DBIAS && (!COLWISE_SCALING))) { - (void)thread_dbias_rowwise; - } - if constexpr (!WITH_GEMM_SWIZZLED_SCALES) { - (void)cols; - } const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; float thread_amax = 0.0f; @@ -709,44 +678,35 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel float block_amax = 0.0f; - __shared__ uint64_t workID_mbar; - [[maybe_unused]] __shared__ __uint128_t workID_response; - [[maybe_unused]] constexpr uint32_t workID_response_size = sizeof(workID_response); - static_assert(workID_response_size == 16); - __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; // Initialize barriers shared by the entire CTA: // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. - // - workID_mbar synchronizes WorkID query response in dynamic persistent mode. if (leading_thread) { #pragma unroll for (int buff = 0; buff < BUFFS_NUM; ++buff) { ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); } - ptx::mbarrier_init(&workID_mbar, 1); ptx::fence_proxy_async_shared_cta(); } __syncthreads(); const size_t total_work_blocks = work_blocks_X * work_blocks_Y; - const size_t launch_block_id = - static_cast(blockIdx.y) * static_cast(gridDim.x) + static_cast(blockIdx.x); + const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; int IN_buff_readable_parity[BUFFS_NUM] = {0}; - [[maybe_unused]] int ctaid_parity = 0; int32_t ctaid_X = static_cast(blockIdx.x); int32_t ctaid_Y = static_cast(blockIdx.y); - [[maybe_unused]] size_t static_next_block_id = 0; - [[maybe_unused]] size_t static_block_stride = 0; - // In STATIC_PERSISTENT mode physical CTAs iterate over a virtual work grid via grid-stride. - if constexpr (STATIC_PERSISTENT) { + size_t static_next_block_id = 0; + size_t static_block_stride = 0; + // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. + if constexpr (PERSISTENT) { if (launch_block_id >= total_work_blocks) { return; } ctaid_X = static_cast(launch_block_id % work_blocks_X); ctaid_Y = static_cast(launch_block_id / work_blocks_X); - static_block_stride = static_cast(gridDim.x) * static_cast(gridDim.y); + static_block_stride = gridDim.x * gridDim.y; static_next_block_id = launch_block_id + static_block_stride; } bool job_finished = false; @@ -789,7 +749,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - // Main persistent loop: decode current job, run all 32-row stages, schedule/prefetch next job. + // Main work loop: decode current job, run all 32-row stages, schedule/prefetch next job. while (!job_finished) { // Decode CTA assignment into logical tensor coordinates and validate bounds. const JobDescriptor current_job = @@ -876,13 +836,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - if constexpr (DYNAMIC_PERSISTENT) { - if (leading_thread) { - ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size); - ptx::try_cancel_cta(&workID_mbar, &workID_response); - } - } - bool prefetched_next_job = false; // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll @@ -893,11 +846,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel BlockDescriptor prefetch_block = current_block; if (stage == STAGES - PREFETCH_STAGES) { - if constexpr (DYNAMIC_PERSISTENT) { - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity); - ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y); - ctaid_parity ^= 1; - } else if constexpr (STATIC_PERSISTENT) { + if constexpr (PERSISTENT) { if (static_next_block_id < total_work_blocks) { ctaid_X = static_cast(static_next_block_id % work_blocks_X); ctaid_Y = static_cast(static_next_block_id / work_blocks_X); @@ -912,7 +861,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ctaid_X = -1; ctaid_Y = -1; } - if constexpr (!STATIC_PERSISTENT) { + if constexpr (!PERSISTENT) { if (ctaid_X == -1 && ctaid_Y == -1) { job_finished = true; } @@ -1062,7 +1011,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel for (int buff = 0; buff < BUFFS_NUM; ++buff) { ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); } - ptx::mbarrier_invalid(&workID_mbar); } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -1139,7 +1087,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations size_t launch_blocks_X = work_blocks_X; size_t launch_blocks_Y = work_blocks_Y; - if constexpr (STATIC_PERSISTENT) { + if constexpr (PERSISTENT) { const size_t sm_num = static_cast(transformer_engine::cuda::sm_count()); const size_t static_grid_size = sm_num * TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM; NVTE_CHECK(static_grid_size > 0, "Static persistent grid size must be greater than zero."); From 12b8712a94bfa3fc49f1c17815706d06053a45fa Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 5 Mar 2026 15:57:41 +0000 Subject: [PATCH 12/51] Ready for PR Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 3 +- tests/cpp/operator/CMakeLists.txt | 56 +- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 41 +- .../common/activation/activation_template.h | 30 +- .../common/cast/dispatch/dequantize.cuh | 52 +- .../common/cast/dispatch/gated.cuh | 304 ++++---- .../common/cast/dispatch/quantize.cuh | 729 +++++++++--------- 7 files changed, 604 insertions(+), 611 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 2092975b2a..6f4f163f08 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -8,8 +8,7 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) - set(CMAKE_CUDA_ARCHITECTURES 100) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index a04cc3c38c..56880a428d 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,35 +3,35 @@ # See LICENSE for license information. add_executable(test_operator - # test_cast.cu - # test_cast_current_scaling.cu - # test_cast_dbias.cu - # test_cast_dbias_dgelu.cu - # test_cast_gated_swiglu.cu - # test_cast_mxfp8_gated_swiglu.cu - # test_qdq.cu - # test_cast_mxfp8.cu + test_cast.cu + test_cast_current_scaling.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu + test_cast_gated_swiglu.cu + test_cast_mxfp8_gated_swiglu.cu + test_qdq.cu + test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu - # test_cast_nvfp4_transpose.cu - # test_cast_float8blockwise.cu - # test_dequantize_mxfp8.cu - # test_transpose.cu - # test_cast_transpose.cu - # test_cast_transpose_current_scaling.cu - # test_cast_transpose_dbias.cu - # test_cast_transpose_dbias_dgelu.cu - # test_cast_transpose_dgeglu.cu - # test_act.cu - # test_normalization.cu - # test_normalization_mxfp8.cu - # test_memset.cu - # test_multi_cast_transpose.cu - # test_multi_padding.cu - # test_multi_unpadding.cu - # test_causal_softmax.cu - # test_swizzle.cu - # test_swap_first_dims.cu - # test_grouped_gemm.cu + test_cast_nvfp4_transpose.cu + test_cast_float8blockwise.cu + test_dequantize_mxfp8.cu + test_transpose.cu + test_cast_transpose.cu + test_cast_transpose_current_scaling.cu + test_cast_transpose_dbias.cu + test_cast_transpose_dbias_dgelu.cu + test_cast_transpose_dgeglu.cu + test_act.cu + test_normalization.cu + test_normalization_mxfp8.cu + test_memset.cu + test_multi_cast_transpose.cu + test_multi_padding.cu + test_multi_unpadding.cu + test_causal_softmax.cu + test_swizzle.cu + test_swap_first_dims.cu + test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 6cff159d51..647737171a 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -623,15 +623,15 @@ void performTest(const ProcessingMethod processing_method, std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - // ProcessingMethod::CAST_DBIAS, - // ProcessingMethod::CAST_DBIAS_DACT, - // ProcessingMethod::CAST_DACT, - // ProcessingMethod::CAST_ACT, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, }; std::vector activation_kinds = { ActivationKind::Identity, - // ActivationKind::GeLU, + ActivationKind::GeLU, // ActivationKind::SiLU, // ActivationKind::ReLU, // ActivationKind::QGeLU, @@ -646,23 +646,22 @@ enum ScalingDirection { std::vector scaling_directions = { ScalingDirection::ROWWISE, - // ScalingDirection::COLWISE, - // ScalingDirection::BOTH, + ScalingDirection::COLWISE, + ScalingDirection::BOTH, }; // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - // {SAME_BOTH_DIMS, 1, 128,128}, - // {SAME_BOTH_DIMS, 2, 256,128}, - // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, - // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - {VARYING_FIRST_DIM, 5, 4096,4096, 128,256,384,1024,2304}, - {VARYING_FIRST_DIM, 5, 16 * 4096,4096, 128,256,384,1024,2304}, - // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace @@ -824,10 +823,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(activation_kinds), ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), - ::testing::Values(DType::kBFloat16), - ::testing::Values(DType::kFloat8E4M3)), - // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index caf6cbda65..ffbffafd1a 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -22,36 +22,36 @@ namespace transformer_engine { template void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - // using namespace detail; - // constexpr bool IS_ACT = true; - // dispatch::quantize_fwd_helper(input, output, nullptr, stream); + using namespace detail; + constexpr bool IS_ACT = true; + dispatch::quantize_fwd_helper(input, output, nullptr, stream); } template void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { - // using namespace detail; - // constexpr bool IS_DBIAS = false; - // constexpr bool IS_DACT = true; - // constexpr NVTETensor dbias = nullptr; - // constexpr NVTETensor workspace = nullptr; - - // dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, - // nullptr, stream); + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + + dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, + nullptr, stream); } template void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { - // using namespace detail; - // dispatch::quantize_gated_fwd_helper(input, output, p, stream); + using namespace detail; + dispatch::quantize_gated_fwd_helper(input, output, p, stream); } template void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { - // using namespace detail; - // dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); + using namespace detail; + dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index db2ad285a8..81304981d3 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -22,32 +22,32 @@ namespace transformer_engine { namespace dispatch { inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - // CheckInputTensor(input, "cast_input"); - // CheckOutputTensor(*output, "cast_output"); - - // switch (input.scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); - // NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); - // NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); - // fp8::dequantize(input, output, stream); - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // if (is_supported_by_CC_100()) { - // mxfp8::dequantize(input, output, stream); - // } else { - // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - // } - // break; - // } - // case NVTE_NVFP4_1D_SCALING: { - // nvfp4::dequantize(input, output, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - // } + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); + NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); + fp8::dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + mxfp8::dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + nvfp4::dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } } } // namespace dispatch diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index c2087533a6..06e8f0e306 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -25,164 +25,164 @@ namespace dispatch { template void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - // const Tensor input = *convertNVTETensorCheck(nvte_input); - // Tensor *output = convertNVTETensorCheck(nvte_output); - - // CheckInputTensor(input, "input"); - // CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - // const size_t rows = input.flat_first_dim(); - // const size_t cols = input.flat_last_dim() / 2; - - // NVTE_CHECK(input.flat_last_dim() % 2 == 0, - // "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - // input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - // NVTE_CHECK(output->flat_last_dim() == cols, - // "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", - // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // switch (output->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - // if (use_tma_kernels) { - // Tensor dummy_grad_tensor; - // fp8::cast_gated_tma(input, dummy_grad_tensor, - // output, p, stream); - // } else { - // fp8::cast_gated_fwd(input, output, p, stream); - // } - // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // // FP8 kernel only populates row-wise data, so perform - // // transpose separately if needed - // Tensor transpose_in, transpose_out, dummy; - // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_in.data.dptr = output->data.dptr; - // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - // transpose_in.data.dtype = output->data.dtype; - // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_out.data.dptr = output->columnwise_data.dptr; - // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - // transpose_out.data.dtype = output->data.dtype; - // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // NVTE_CHECK(cols % 32 == 0, - // "Invalid input shape. Expected the last dimension to be " - // "divisible by 32, but got ", - // cols, "."); - // if (output->has_data()) { - // NVTE_CHECK(is_fp8_dtype(output->data.dtype), - // "The type of the output tensor should be FP8."); - // } - // if (output->has_columnwise_data()) { - // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - // "The type of the columnwise output tensor should be FP8."); - // } - // NVTE_CHECK(is_supported_by_CC_100(), - // "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - // Tensor dummy_grad_tensor; - // mxfp8::quantize_gated(input, dummy_grad_tensor, - // output, p, stream); - // break; - // } - // default: - // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - // } + const Tensor input = *convertNVTETensorCheck(nvte_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim() / 2; + + NVTE_CHECK(input.flat_last_dim() % 2 == 0, + "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols, + "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + Tensor dummy_grad_tensor; + fp8::cast_gated_tma(input, dummy_grad_tensor, + output, p, stream); + } else { + fp8::cast_gated_fwd(input, output, p, stream); + } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + Tensor dummy_grad_tensor; + mxfp8::quantize_gated(input, dummy_grad_tensor, + output, p, stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } } template void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - // const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); - // const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); - // Tensor *output = convertNVTETensorCheck(nvte_output); - - // CheckInputTensor(grad, "grad"); - // CheckInputTensor(gated_input, "gated_input"); - // CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - // NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", - // gated_input.flat_last_dim(), "."); - - // const size_t rows = gated_input.flat_first_dim(); - // const size_t cols = gated_input.flat_last_dim() / 2; - - // NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); - // NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); - - // NVTE_CHECK(grad.flat_first_dim() == rows, - // "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", - // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - // NVTE_CHECK(grad.flat_last_dim() == cols, - // "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", - // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - - // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", - // rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - // NVTE_CHECK(output->flat_last_dim() == cols * 2, - // "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", - // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - // NVTE_CHECK(gated_input.shape() == output->shape(), - // "Gated input and output shapes must match. Input shape: ", gated_input.shape(), - // ", output shape: ", output->shape(), "."); - - // switch (output->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - // if (use_tma_kernels) { - // fp8::cast_gated_tma(gated_input, grad, output, p, - // stream); - // } else { - // fp8::cast_gated_bwd(gated_input, grad, output, p, stream); - // } - // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // // FP8 kernel only populates row-wise data, so perform - // // transpose separately if needed - // Tensor transpose_in, transpose_out, dummy; - // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_in.data.dptr = output->data.dptr; - // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - // transpose_in.data.dtype = output->data.dtype; - // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_out.data.dptr = output->columnwise_data.dptr; - // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - // transpose_out.data.dtype = output->data.dtype; - // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // NVTE_CHECK(cols % 32 == 0, - // "Invalid input shape. Expected the last dimension to be " - // "divisible by 32, but got ", - // cols, "."); - // if (output->has_data()) { - // NVTE_CHECK(is_fp8_dtype(output->data.dtype), - // "The type of the output tensor should be FP8."); - // } - // if (output->has_columnwise_data()) { - // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - // "The type of the columnwise output tensor should be FP8."); - // } - // NVTE_CHECK(is_supported_by_CC_100(), - // "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - - // mxfp8::quantize_gated(gated_input, grad, output, p, - // stream); - // break; - // } - // default: - // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - // } + const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(grad, "grad"); + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + gated_input.flat_last_dim(), "."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + + NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); + NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); + + NVTE_CHECK(grad.flat_first_dim() == rows, + "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + NVTE_CHECK(grad.flat_last_dim() == cols, + "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols * 2, + "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(gated_input.shape() == output->shape(), + "Gated input and output shapes must match. Input shape: ", gated_input.shape(), + ", output shape: ", output->shape(), "."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + fp8::cast_gated_tma(gated_input, grad, output, p, + stream); + } else { + fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + mxfp8::quantize_gated(gated_input, grad, output, p, + stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } } } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 0aadffa940..f7823b4c58 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -30,282 +30,282 @@ namespace dispatch { template void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // const Tensor *input_tensor = convertNVTETensorCheck(input); - // Tensor *output_tensor = convertNVTETensorCheck(output); - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // // Dispatch to quantization kernel depending on data format - // switch (output_tensor->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const Tensor *dummy_input_tensor = nullptr; - // Tensor *dummy_dbias_tensor = nullptr; - // Tensor *dummy_workspace_tensor = nullptr; - // if (output_tensor->has_columnwise_data()) { - // NVTE_CHECK(output_tensor->has_data(), - // "Quantizing in only the columnwise direction not supported yet!"); - // if constexpr (!IS_ACT) { - // cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - // } else { - // cast_transpose_fused( - // *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // } - // } else if (output_tensor->has_data()) { - // fp8::quantize( - // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // const Tensor *dummy_input_tensor = nullptr; - // Tensor *dummy_dbias_tensor = nullptr; - // Tensor *dummy_workspace_tensor = nullptr; - // mxfp8::quantize( - // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // break; - // } - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*input_tensor, "input"); - // CheckOutputTensor(*output_tensor, "output", false); - - // // Choose kernel - // int32_t rows = input_tensor->flat_first_dim(); - // int32_t cols = input_tensor->flat_last_dim(); - // auto dtype = input_tensor->dtype(); - // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - // (cols % 32 == 0) && output_tensor->has_data(); - - // // Launch NVFP4 quantize kernel - // if (use_optimized_kernel) { - // if (quant_config_cpp.nvfp4_2d_quantization) { - // nvfp4::quantize_transpose( - // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } else { - // nvfp4::quantize_transpose( - // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } - // } else { - // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - // : output_tensor->columnwise_amax; - // quantize_transpose_vector_blockwise_fp4( - // /*input=*/input_tensor->data, /*global_amax=*/global_amax, - // /*scale_inv=*/output_tensor->scale_inv, - // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - // /*swizzled_scale=*/false, - // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - // /*rng_state=*/quant_config_cpp.rng_state, - // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - // } - // break; - // } - // case NVTE_BLOCK_SCALING_2D: { - // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // quantize_transpose_square_blockwise( - // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, - // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - // /*noop_tensor=*/noop_tensor->data, stream); - // break; - // } - // case NVTE_BLOCK_SCALING_1D: { - // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - // if (output_tensor->has_data()) { - // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - // } - // if (output_tensor->has_columnwise_data()) { - // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - // } - // quantize_transpose_vector_blockwise( - // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + Tensor *output_tensor = convertNVTETensorCheck(output); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_ACT) { + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + mxfp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } } template void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // const Tensor *grad_tensor = convertNVTETensorCheck(grad); - // const Tensor *input_tensor = convertNVTETensor(input); - - // Tensor *output_tensor = convertNVTETensorCheck(output); - // Tensor *dbias_tensor = convertNVTETensor(dbias); - // Tensor *workspace_tensor = convertNVTETensor(workspace); - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // // Dispatch to quantization kernel depending on data format - // switch (output_tensor->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // if (output_tensor->has_columnwise_data()) { - // NVTE_CHECK(output_tensor->has_data(), - // "Quantizing in only the columnwise direction not supported yet!"); - // if constexpr (!IS_DBIAS && !IS_DACT) { - // cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); - // } else { - // cast_transpose_fused( - // *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); - // } - // } else if (output_tensor->has_data()) { - // fp8::quantize( - // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // mxfp8::quantize( - // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // stream); - // break; - // } - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*grad_tensor, "input"); - // CheckOutputTensor(*output_tensor, "output", false); - - // // Choose kernel - // int32_t rows = grad_tensor->flat_first_dim(); - // int32_t cols = grad_tensor->flat_last_dim(); - // auto dtype = grad_tensor->dtype(); - // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - // (cols % 32 == 0) && output_tensor->has_data(); - - // // Launch NVFP4 quantize kernel - // if (use_optimized_kernel) { - // if (quant_config_cpp.nvfp4_2d_quantization) { - // nvfp4::quantize_transpose( - // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } else { - // nvfp4::quantize_transpose( - // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } - // } else { - // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - // : output_tensor->columnwise_amax; - // quantize_transpose_vector_blockwise_fp4( - // /*input=*/grad_tensor->data, /*global_amax=*/global_amax, - // /*scale_inv=*/output_tensor->scale_inv, - // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - // /*swizzled_scale=*/false, - // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - // /*rng_state=*/quant_config_cpp.rng_state, - // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - // } - // break; - // } - // case NVTE_BLOCK_SCALING_2D: { - // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // quantize_transpose_square_blockwise( - // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, - // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - // /*noop_tensor=*/noop_tensor->data, stream); - // break; - // } - // case NVTE_BLOCK_SCALING_1D: { - // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - // if (output_tensor->has_data()) { - // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - // } - // if (output_tensor->has_columnwise_data()) { - // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - // } - // quantize_transpose_vector_blockwise( - // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *grad_tensor = convertNVTETensorCheck(grad); + const Tensor *input_tensor = convertNVTETensor(input); + + Tensor *output_tensor = convertNVTETensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT) { + cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*grad_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = grad_tensor->flat_first_dim(); + int32_t cols = grad_tensor->flat_last_dim(); + auto dtype = grad_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } } // Host-aware and not graph-safe: group quantization with split section info from the host. @@ -314,64 +314,64 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // const Tensor *input_tensor = convertNVTETensorCheck(input); - // std::vector output_tensors; - // for (size_t i = 0; i < num_tensors; ++i) { - // output_tensors.push_back(convertNVTETensorCheck(outputs[i])); - // } - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // // Take the scaling mode of the first output tensor - // auto scaling_mode = output_tensors[0]->scaling_mode; - - // // Dispatch to quantization kernel depending on data format - // switch (scaling_mode) { - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*input_tensor, "input"); - // // Skip checking output tensor list - // // output list here is allowed to have empty tensor - - // // Choose kernel - // int32_t rows = input_tensor->flat_first_dim(); - // int32_t cols = input_tensor->flat_last_dim(); - // auto dtype = input_tensor->dtype(); - - // NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - // "2D quantization is not supported for group quantize."); - - // // Launch NVFP4 group quantize kernel - // nvfp4::group_quantize_transpose( - // *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, - // &quant_config_cpp, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_tensors; + for (size_t i = 0; i < num_tensors; ++i) { + output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + } + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + // Take the scaling mode of the first output tensor + auto scaling_mode = output_tensors[0]->scaling_mode; + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + // Skip checking output tensor list + // output list here is allowed to have empty tensor + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "2D quantization is not supported for group quantize."); + + // Launch NVFP4 group quantize kernel + nvfp4::group_quantize_transpose( + *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + &quant_config_cpp, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } } template @@ -407,10 +407,7 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - // mxfp8::group_quantize( - // IS_ACT is set to false - // OP is set to nullptr - mxfp8::group_quantize( + mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; @@ -425,40 +422,40 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - // const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); - // const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); - // GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - // GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); - // Tensor *workspace_tensor = convertNVTETensor(workspace); - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Dispatch to quantization kernel depending on data format - // switch (scaling_mode) { - // case NVTE_MXFP8_1D_SCALING: { - // mxfp8::group_quantize( - // grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - // } + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } } } // namespace dispatch From 2812d55ec4887ab5a49a94127145d5aaae8bcea0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 16:17:04 +0000 Subject: [PATCH 13/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 115 ++++++++++-------- 1 file changed, 63 insertions(+), 52 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 7a510c1295..b28fe1d820 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -17,9 +17,9 @@ #include #include "../../common.h" +#include "../../util/cuda_runtime.h" #include "../../util/math.h" #include "../../util/ptx.cuh" -#include "../../util/cuda_runtime.h" #include "../../utils.cuh" #include "../core/common.cuh" #include "swizzle.cuh" @@ -170,12 +170,13 @@ __device__ __forceinline__ JobDescriptor decode_job( const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, - const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr) { JobDescriptor job{}; job.block_id = ctaid_Y * work_blocks_X + ctaid_X; job.block_global_offset = is_single_tensor - ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) - : (job.block_id * ELTS_PER_CHUNK); + ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) + : (job.block_id * ELTS_PER_CHUNK); job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, first_logical_dim, last_logical_dim, offsets_ptr); job.rows = @@ -209,9 +210,9 @@ __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, return true; } -__device__ __forceinline__ BlockDescriptor decode_block(const JobDescriptor &job, - const bool is_single_tensor, - const int64_t *const __restrict__ offsets_ptr) { +__device__ __forceinline__ BlockDescriptor +decode_block(const JobDescriptor &job, const bool is_single_tensor, + const int64_t *const __restrict__ offsets_ptr) { BlockDescriptor block{}; block.tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(128)); @@ -327,8 +328,9 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso template __device__ __forceinline__ void prefetch_input_stage( IType *in_sh, IType *act_in_sh, const CUtensorMap &tensor_map_input, - const CUtensorMap &tensor_map_act_input, const size_t global_offset_X, const size_t global_offset_Y, - const size_t buff_offset, const size_t shmem_buff_size, uint64_t *barrier, const bool leading_thread) { + const CUtensorMap &tensor_map_act_input, const size_t global_offset_X, + const size_t global_offset_Y, const size_t buff_offset, const size_t shmem_buff_size, + uint64_t *barrier, const bool leading_thread) { if (leading_thread) { ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size); ptx::cp_async_bulk_tensor_2d_global_to_shared( @@ -338,39 +340,41 @@ __device__ __forceinline__ void prefetch_input_stage( if constexpr (IS_DACT) { ptx::cp_async_bulk_tensor_2d_global_to_shared( reinterpret_cast(&act_in_sh[buff_offset]), - reinterpret_cast(&tensor_map_act_input), global_offset_X, global_offset_Y, - barrier); + reinterpret_cast(&tensor_map_act_input), global_offset_X, + global_offset_Y, barrier); } } } // Issue TMA shared->global transfer for one stage of outputs. template -__device__ __forceinline__ void store_output_stage( - OType *out_rowwise_data_sh, OType *out_colwise_data_sh, - const CUtensorMap &tensor_map_output_rowwise, const CUtensorMap &tensor_map_output_colwise, - const int global_offset_X, const int global_offset_Y, const int buff_offset, - const bool leading_thread) { +__device__ __forceinline__ void store_output_stage(OType *out_rowwise_data_sh, + OType *out_colwise_data_sh, + const CUtensorMap &tensor_map_output_rowwise, + const CUtensorMap &tensor_map_output_colwise, + const int global_offset_X, + const int global_offset_Y, const int buff_offset, + const bool leading_thread) { if (!leading_thread) { return; } if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, global_offset_Y, - reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); } if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, global_offset_Y, - reinterpret_cast(&out_colwise_data_sh[buff_offset])); + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); } ptx::cp_async_bulk_commit_group(); } template + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING, + bool WITH_GEMM_SWIZZLED_SCALES> __device__ __forceinline__ float process_colwise_stage( const size_t buff, const int stage, const size_t tid_X_colwise, const size_t scales_offset_Y_colwise, const size_t scales_offset_X_colwise, @@ -434,10 +438,10 @@ __device__ __forceinline__ float process_colwise_stage( const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; - scale_idx = tensor_scales_offset_colwise_base + - transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( - global_scales_offset_X, local_scales_offset_Y, - DIVUP(rows, static_cast(128))); + scale_idx = + tensor_scales_offset_colwise_base + + transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + global_scales_offset_X, local_scales_offset_Y, DIVUP(rows, static_cast(128))); } else { scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; } @@ -463,15 +467,15 @@ __device__ __forceinline__ float process_colwise_stage( } template + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool COLWISE_SCALING, + bool WITH_GEMM_SWIZZLED_SCALES> __device__ __forceinline__ float process_rowwise_stage( const size_t buff, const size_t stage_offset_Y, const size_t thread_offset_Y_rowwise, const size_t thread_offset_X_rowwise, const int bank_group, const size_t scales_offset_Y_rowwise, const size_t scales_offset_X_rowwise, - const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, - const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, - OType *out_rowwise_data_sh, e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { + const size_t scale_stride_rowwise, const bool rowwise_scale_is_within_bounds, const size_t cols, + IType *in_sh, IType *act_in_sh, IType *cached_act_sh, OType *out_rowwise_data_sh, + e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; @@ -725,8 +729,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const CUtensorMap &tensor_map_input = is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[first_job.tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[first_job.tensor_id]; + const CUtensorMap &tensor_map_act_input = is_single_tensor + ? tensor_map_act_input_static + : g_tensor_maps_act_input[first_job.tensor_id]; if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); @@ -809,10 +814,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; const CUtensorMap &tensor_map_act_input = is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = - is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = - is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = is_single_tensor + ? tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = is_single_tensor + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; if (leading_thread && (!is_single_tensor)) { fence_acquire_tensormap(&tensor_map_input); @@ -871,10 +878,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { - prefetch_job = - decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, - work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - allow_next_job_prefetch = is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); + prefetch_job = decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, + last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, + first_dims_ptr, last_dims_ptr); + allow_next_job_prefetch = + is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); if (allow_next_job_prefetch) { prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); } @@ -893,9 +901,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; const CUtensorMap &prefetch_tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[prefetch_job.tensor_id]; + is_single_tensor ? tensor_map_input_static + : g_tensor_maps_input[prefetch_job.tensor_id]; const CUtensorMap &prefetch_tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[prefetch_job.tensor_id]; + is_single_tensor ? tensor_map_act_input_static + : g_tensor_maps_act_input[prefetch_job.tensor_id]; uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; if (leading_thread) { @@ -906,9 +916,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } } - prefetch_input_stage( - in_sh, act_in_sh, prefetch_tensor_map_input, prefetch_tensor_map_act_input, global_offset_X, - global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); + prefetch_input_stage(in_sh, act_in_sh, prefetch_tensor_map_input, + prefetch_tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, + shmem_buff_size, barrier, leading_thread); ptx::fence_proxy_async_shared_cta(); } @@ -923,8 +934,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel thread_amax = process_colwise_stage( buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, - scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, cached_act_sh, - out_colwise_data_sh, scales_colwise, partial_dbias_colwise); + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, + cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); } if constexpr (ROWWISE_SCALING) { @@ -932,8 +943,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel COLWISE_SCALING, WITH_GEMM_SWIZZLED_SCALES>( buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, - rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, out_rowwise_data_sh, - scales_rowwise, thread_dbias_rowwise); + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, + out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); } __builtin_assume(block_amax >= 0); @@ -948,8 +959,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int global_offset_X = block_offset_X; const int buff_offset = buff * BUFF_DIM; store_output_stage( - out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, tensor_map_output_colwise, - global_offset_X, global_offset_Y, buff_offset, leading_thread); + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, + tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, leading_thread); buff_in = (buff_in + 1) % BUFFS_NUM; } From f24afb27b40198489c65a98e04e0e88af28385ce Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 6 Mar 2026 10:37:56 +0000 Subject: [PATCH 14/51] Fixes per the review Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 2 +- transformer_engine/common/cast/core/common.cuh | 4 ++-- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 7 ++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 647737171a..e54ceebaa3 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -371,7 +371,7 @@ void performTest(const ProcessingMethod processing_method, NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); - std::vector dbias_logical_shape_vec= {num_tensors, cols}; + std::vector dbias_logical_shape_vec = {num_tensors, cols}; NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), dbias_logical_shape_vec.size()); diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index a4e033939b..ce9fce6285 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -100,14 +100,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t tensor_id = blockIdx.y; const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) ? (first_logical_dim / num_tensors) - : first_dims_ptr[tensor_id]; + : static_cast(first_dims_ptr[tensor_id]); const size_t rows = tensor_rows / chunk_dim_Y; const size_t cols = last_logical_dim; const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) ? (tensor_id * (tensor_rows / chunk_dim_Y)) - : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index b28fe1d820..8c452f2a7a 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -142,6 +142,10 @@ __device__ __forceinline__ size_t get_tensor_cols_num( case ShapeRepresentation::VARYING_LAST_DIM: case ShapeRepresentation::VARYING_BOTH_DIMS: cols_num = static_cast(last_dims_ptr[tensor_id]); + if (cols_num % 128 != 0) { + NVTE_DEVICE_ERROR("For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); + } break; } return cols_num; @@ -215,7 +219,8 @@ decode_block(const JobDescriptor &job, const bool is_single_tensor, const int64_t *const __restrict__ offsets_ptr) { BlockDescriptor block{}; block.tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); - const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, static_cast(128)); + const size_t CHUNK_DIM_X_ = CHUNK_DIM_X; + const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_); block.block_id_in_current_tensor = is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK); block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor; From aa484a3fff022695f42f0fde84c41c93433f175e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 10:40:30 +0000 Subject: [PATCH 15/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/core/common.cuh | 7 ++++--- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 5 +++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index ce9fce6285..9c16666db0 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -105,9 +105,10 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t rows = tensor_rows / chunk_dim_Y; const size_t cols = last_logical_dim; - const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) - ? (tensor_id * (tensor_rows / chunk_dim_Y)) - : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); + const size_t dbias_in_offset_Y = + (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (static_cast(offsets_ptr[tensor_id]) / cols / chunk_dim_Y); const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 8c452f2a7a..6e9bd3dc5e 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -143,8 +143,9 @@ __device__ __forceinline__ size_t get_tensor_cols_num( case ShapeRepresentation::VARYING_BOTH_DIMS: cols_num = static_cast(last_dims_ptr[tensor_id]); if (cols_num % 128 != 0) { - NVTE_DEVICE_ERROR("For non-single tensors, the last dimension of each tensor in a group " - "must be divisible by 128."); + NVTE_DEVICE_ERROR( + "For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); } break; } From 74722a53b50a25c2bb7a4afaad23b08dca93780f Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 16:48:29 +0000 Subject: [PATCH 16/51] Ready for benchmark Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 3 +- tests/cpp/operator/CMakeLists.txt | 56 +- tests/cpp/operator/test_cast_mxfp8.cu | 47 +- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 43 +- .../common/cast/dispatch/dequantize.cuh | 52 +- .../common/cast/dispatch/gated.cuh | 304 ++++---- .../common/cast/dispatch/quantize.cuh | 648 +++++++++--------- 7 files changed, 582 insertions(+), 571 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 6f4f163f08..2092975b2a 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -8,7 +8,8 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + set(CMAKE_CUDA_ARCHITECTURES 100) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 5e73675f4f..84c682134f 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,36 +3,36 @@ # See LICENSE for license information. add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu + # test_cast.cu + # test_cast_current_scaling.cu + # test_cast_dbias.cu + # test_cast_dbias_dgelu.cu + # test_cast_gated_swiglu.cu + # test_cast_mxfp8_gated_swiglu.cu + # test_qdq.cu test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu - test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_splits_to_offsets.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swizzle.cu - test_swap_first_dims.cu - test_grouped_gemm.cu + # test_cast_nvfp4_transpose.cu + # test_cast_float8blockwise.cu + # test_dequantize_mxfp8.cu + # test_transpose.cu + # test_cast_transpose.cu + # test_cast_transpose_current_scaling.cu + # test_cast_transpose_dbias.cu + # test_cast_transpose_dbias_dgelu.cu + # test_cast_transpose_dgeglu.cu + # test_act.cu + # test_normalization.cu + # test_normalization_mxfp8.cu + # test_memset.cu + # test_splits_to_offsets.cu + # test_multi_cast_transpose.cu + # test_multi_padding.cu + # test_multi_unpadding.cu + # test_causal_softmax.cu + # test_swizzle.cu + # test_swap_first_dims.cu + # test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index b5e11c30e1..e22b76ff89 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -522,24 +522,25 @@ void performTest_x2(const ProcessingMethod processing_method, } std::vector> matrix_sizes = { - {1, 16}, - {16, 48}, - {65, 96}, - {128, 128}, - {256, 256}, - {993, 512}, - {511, 6144}, - {8192, 128}, - {2048, 160}, - {577, 1632}, - {1024}, - {8, 32, 1024}, - {16, 8, 4, 512}, + // {1, 16}, + // {16, 48}, + // {65, 96}, + // {128, 128}, + // {256, 256}, + // {993, 512}, + // {511, 6144}, + // {8192, 128}, + // {2048, 160}, + // {577, 1632}, + // {1024}, + // {8, 32, 1024}, + // {16, 8, 4, 512}, + {8192, 7168}, }; std::vector> block_sizes = { - {1, 32}, - {32, 1}, + // {1, 32}, + // {32, 1}, {32, 32}, }; @@ -553,16 +554,16 @@ std::vector input_scenarios = { std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - ProcessingMethod::CAST_DBIAS, - ProcessingMethod::CAST_DBIAS_DACT, - ProcessingMethod::CAST_DACT, - ProcessingMethod::CAST_ACT, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, }; // Only GeLU activation tests are supported std::vector Activation_types = { ActivationType::Identity, - ActivationType::GeLU, + // ActivationType::GeLU, // ActivationType::SiLU, // ActivationType::ReLU, // ActivationType::QGeLU, @@ -691,8 +692,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), ::testing::ValuesIn(block_sizes), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(DType::kBFloat16), + ::testing::Values(DType::kFloat8E4M3), + // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index e54ceebaa3..b49d68b1c4 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -623,15 +623,15 @@ void performTest(const ProcessingMethod processing_method, std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - ProcessingMethod::CAST_DBIAS, - ProcessingMethod::CAST_DBIAS_DACT, - ProcessingMethod::CAST_DACT, - ProcessingMethod::CAST_ACT, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, }; std::vector activation_kinds = { ActivationKind::Identity, - ActivationKind::GeLU, + // ActivationKind::GeLU, // ActivationKind::SiLU, // ActivationKind::ReLU, // ActivationKind::QGeLU, @@ -645,23 +645,26 @@ enum ScalingDirection { }; std::vector scaling_directions = { - ScalingDirection::ROWWISE, - ScalingDirection::COLWISE, + // ScalingDirection::ROWWISE, + // ScalingDirection::COLWISE, ScalingDirection::BOTH, }; // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - {SAME_BOTH_DIMS, 1, 128,128}, - {SAME_BOTH_DIMS, 2, 256,128}, - {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, - {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, - {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + {SAME_BOTH_DIMS, 1, 8192,7168}, + {VARYING_FIRST_DIM, 6, 8192,7168, 128,256,384,1024,2304,4096}, + {VARYING_FIRST_DIM, 6, 16*8192,7168, 128,256,384,1024,2304,4096}, + // {SAME_BOTH_DIMS, 1, 128,128}, + // {SAME_BOTH_DIMS, 2, 256,128}, + // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + // {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, + // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace @@ -823,8 +826,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(activation_kinds), ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kBFloat16), + ::testing::Values(DType::kFloat8E4M3)), + // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 81304981d3..db2ad285a8 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -22,32 +22,32 @@ namespace transformer_engine { namespace dispatch { inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - switch (input.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); - NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); - fp8::dequantize(input, output, stream); - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { - mxfp8::dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - break; - } - case NVTE_NVFP4_1D_SCALING: { - nvfp4::dequantize(input, output, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } + // CheckInputTensor(input, "cast_input"); + // CheckOutputTensor(*output, "cast_output"); + + // switch (input.scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); + // NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); + // NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); + // fp8::dequantize(input, output, stream); + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // if (is_supported_by_CC_100()) { + // mxfp8::dequantize(input, output, stream); + // } else { + // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + // } + // break; + // } + // case NVTE_NVFP4_1D_SCALING: { + // nvfp4::dequantize(input, output, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + // } } } // namespace dispatch diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 06e8f0e306..c2087533a6 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -25,164 +25,164 @@ namespace dispatch { template void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - const Tensor input = *convertNVTETensorCheck(nvte_input); - Tensor *output = convertNVTETensorCheck(nvte_output); - - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim() / 2; - - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == cols, - "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", - output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - if (use_tma_kernels) { - Tensor dummy_grad_tensor; - fp8::cast_gated_tma(input, dummy_grad_tensor, - output, p, stream); - } else { - fp8::cast_gated_fwd(input, output, p, stream); - } - if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // FP8 kernel only populates row-wise data, so perform - // transpose separately if needed - Tensor transpose_in, transpose_out, dummy; - transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_in.data.dptr = output->data.dptr; - transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - transpose_in.data.dtype = output->data.dtype; - transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_out.data.dptr = output->columnwise_data.dptr; - transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - transpose_out.data.dtype = output->data.dtype; - detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - NVTE_CHECK(cols % 32 == 0, - "Invalid input shape. Expected the last dimension to be " - "divisible by 32, but got ", - cols, "."); - if (output->has_data()) { - NVTE_CHECK(is_fp8_dtype(output->data.dtype), - "The type of the output tensor should be FP8."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - "The type of the columnwise output tensor should be FP8."); - } - NVTE_CHECK(is_supported_by_CC_100(), - "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - Tensor dummy_grad_tensor; - mxfp8::quantize_gated(input, dummy_grad_tensor, - output, p, stream); - break; - } - default: - NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - } + // const Tensor input = *convertNVTETensorCheck(nvte_input); + // Tensor *output = convertNVTETensorCheck(nvte_output); + + // CheckInputTensor(input, "input"); + // CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + // const size_t rows = input.flat_first_dim(); + // const size_t cols = input.flat_last_dim() / 2; + + // NVTE_CHECK(input.flat_last_dim() % 2 == 0, + // "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + // input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + // NVTE_CHECK(output->flat_last_dim() == cols, + // "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", + // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // switch (output->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // if (use_tma_kernels) { + // Tensor dummy_grad_tensor; + // fp8::cast_gated_tma(input, dummy_grad_tensor, + // output, p, stream); + // } else { + // fp8::cast_gated_fwd(input, output, p, stream); + // } + // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // // FP8 kernel only populates row-wise data, so perform + // // transpose separately if needed + // Tensor transpose_in, transpose_out, dummy; + // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_in.data.dptr = output->data.dptr; + // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + // transpose_in.data.dtype = output->data.dtype; + // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_out.data.dptr = output->columnwise_data.dptr; + // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + // transpose_out.data.dtype = output->data.dtype; + // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // NVTE_CHECK(cols % 32 == 0, + // "Invalid input shape. Expected the last dimension to be " + // "divisible by 32, but got ", + // cols, "."); + // if (output->has_data()) { + // NVTE_CHECK(is_fp8_dtype(output->data.dtype), + // "The type of the output tensor should be FP8."); + // } + // if (output->has_columnwise_data()) { + // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + // "The type of the columnwise output tensor should be FP8."); + // } + // NVTE_CHECK(is_supported_by_CC_100(), + // "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + // Tensor dummy_grad_tensor; + // mxfp8::quantize_gated(input, dummy_grad_tensor, + // output, p, stream); + // break; + // } + // default: + // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + // } } template void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); - const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); - Tensor *output = convertNVTETensorCheck(nvte_output); - - CheckInputTensor(grad, "grad"); - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", - gated_input.flat_last_dim(), "."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - - NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); - NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); - - NVTE_CHECK(grad.flat_first_dim() == rows, - "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", - grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - NVTE_CHECK(grad.flat_last_dim() == cols, - "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", - grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", - rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == cols * 2, - "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", - output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(gated_input.shape() == output->shape(), - "Gated input and output shapes must match. Input shape: ", gated_input.shape(), - ", output shape: ", output->shape(), "."); - - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - if (use_tma_kernels) { - fp8::cast_gated_tma(gated_input, grad, output, p, - stream); - } else { - fp8::cast_gated_bwd(gated_input, grad, output, p, stream); - } - if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // FP8 kernel only populates row-wise data, so perform - // transpose separately if needed - Tensor transpose_in, transpose_out, dummy; - transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_in.data.dptr = output->data.dptr; - transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - transpose_in.data.dtype = output->data.dtype; - transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - transpose_out.data.dptr = output->columnwise_data.dptr; - transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - transpose_out.data.dtype = output->data.dtype; - detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - NVTE_CHECK(cols % 32 == 0, - "Invalid input shape. Expected the last dimension to be " - "divisible by 32, but got ", - cols, "."); - if (output->has_data()) { - NVTE_CHECK(is_fp8_dtype(output->data.dtype), - "The type of the output tensor should be FP8."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - "The type of the columnwise output tensor should be FP8."); - } - NVTE_CHECK(is_supported_by_CC_100(), - "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - - mxfp8::quantize_gated(gated_input, grad, output, p, - stream); - break; - } - default: - NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - } + // const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + // const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + // Tensor *output = convertNVTETensorCheck(nvte_output); + + // CheckInputTensor(grad, "grad"); + // CheckInputTensor(gated_input, "gated_input"); + // CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + // NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + // gated_input.flat_last_dim(), "."); + + // const size_t rows = gated_input.flat_first_dim(); + // const size_t cols = gated_input.flat_last_dim() / 2; + + // NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); + // NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); + + // NVTE_CHECK(grad.flat_first_dim() == rows, + // "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", + // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + // NVTE_CHECK(grad.flat_last_dim() == cols, + // "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + // rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + // NVTE_CHECK(output->flat_last_dim() == cols * 2, + // "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + // NVTE_CHECK(gated_input.shape() == output->shape(), + // "Gated input and output shapes must match. Input shape: ", gated_input.shape(), + // ", output shape: ", output->shape(), "."); + + // switch (output->scaling_mode) { + // case NVTE_DELAYED_TENSOR_SCALING: { + // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // if (use_tma_kernels) { + // fp8::cast_gated_tma(gated_input, grad, output, p, + // stream); + // } else { + // fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + // } + // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // // FP8 kernel only populates row-wise data, so perform + // // transpose separately if needed + // Tensor transpose_in, transpose_out, dummy; + // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_in.data.dptr = output->data.dptr; + // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + // transpose_in.data.dtype = output->data.dtype; + // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + // transpose_out.data.dptr = output->columnwise_data.dptr; + // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + // transpose_out.data.dtype = output->data.dtype; + // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + // } + // break; + // } + // case NVTE_MXFP8_1D_SCALING: { + // NVTE_CHECK(cols % 32 == 0, + // "Invalid input shape. Expected the last dimension to be " + // "divisible by 32, but got ", + // cols, "."); + // if (output->has_data()) { + // NVTE_CHECK(is_fp8_dtype(output->data.dtype), + // "The type of the output tensor should be FP8."); + // } + // if (output->has_columnwise_data()) { + // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + // "The type of the columnwise output tensor should be FP8."); + // } + // NVTE_CHECK(is_supported_by_CC_100(), + // "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + // mxfp8::quantize_gated(gated_input, grad, output, p, + // stream); + // break; + // } + // default: + // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + // } } } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index f7823b4c58..835ae149c8 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -59,109 +59,110 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, // Dispatch to quantization kernel depending on data format switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const Tensor *dummy_input_tensor = nullptr; - Tensor *dummy_dbias_tensor = nullptr; - Tensor *dummy_workspace_tensor = nullptr; - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_ACT) { - cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); - } - break; - } + // case NVTE_DELAYED_TENSOR_SCALING: { + // const Tensor *dummy_input_tensor = nullptr; + // Tensor *dummy_dbias_tensor = nullptr; + // Tensor *dummy_workspace_tensor = nullptr; + // if (output_tensor->has_columnwise_data()) { + // NVTE_CHECK(output_tensor->has_data(), + // "Quantizing in only the columnwise direction not supported yet!"); + // if constexpr (!IS_ACT) { + // cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + // } else { + // cast_transpose_fused( + // *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // } else if (output_tensor->has_data()) { + // fp8::quantize( + // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + // dummy_workspace_tensor, stream); + // } + // break; + // } case NVTE_MXFP8_1D_SCALING: { const Tensor *dummy_input_tensor = nullptr; Tensor *dummy_dbias_tensor = nullptr; Tensor *dummy_workspace_tensor = nullptr; - mxfp8::quantize( + // mxfp8::quantize( + mxfp8::quantize( *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, dummy_workspace_tensor, stream); break; } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // CheckOutputTensor(*output_tensor, "output", false); + + // // Choose kernel + // int32_t rows = input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // (cols % 32 == 0) && output_tensor->has_data(); + + // // Launch NVFP4 quantize kernel + // if (use_optimized_kernel) { + // if (quant_config_cpp.nvfp4_2d_quantization) { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } else { + // nvfp4::quantize_transpose( + // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // } + // } else { + // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // : output_tensor->columnwise_amax; + // quantize_transpose_vector_blockwise_fp4( + // /*input=*/input_tensor->data, /*global_amax=*/global_amax, + // /*scale_inv=*/output_tensor->scale_inv, + // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // /*swizzled_scale=*/false, + // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // /*rng_state=*/quant_config_cpp.rng_state, + // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // } + // break; + // } + // case NVTE_BLOCK_SCALING_2D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // quantize_transpose_square_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, + // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // /*noop_tensor=*/noop_tensor->data, stream); + // break; + // } + // case NVTE_BLOCK_SCALING_1D: { + // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // float epsilon = quant_config_cpp.amax_epsilon; + // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // if (output_tensor->has_data()) { + // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // } + // if (output_tensor->has_columnwise_data()) { + // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // } + // quantize_transpose_vector_blockwise( + // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // break; + // } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } @@ -171,141 +172,141 @@ template (quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT) { - cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); - } - } else if (output_tensor->has_data()) { - fp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*grad_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = grad_tensor->flat_first_dim(); - int32_t cols = grad_tensor->flat_last_dim(); - auto dtype = grad_tensor->dtype(); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4::quantize_transpose( - *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - quantize_transpose_vector_blockwise_fp4( - /*input=*/grad_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT), - "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *grad_tensor = convertNVTETensorCheck(grad); + // const Tensor *input_tensor = convertNVTETensor(input); + + // Tensor *output_tensor = convertNVTETensorCheck(output); + // Tensor *dbias_tensor = convertNVTETensor(dbias); + // Tensor *workspace_tensor = convertNVTETensor(workspace); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + // "Either rowwise or columnwise output data need to be allocated."); + + // // Dispatch to quantization kernel depending on data format + // switch (output_tensor->scaling_mode) { + // // case NVTE_DELAYED_TENSOR_SCALING: { + // // if (output_tensor->has_columnwise_data()) { + // // NVTE_CHECK(output_tensor->has_data(), + // // "Quantizing in only the columnwise direction not supported yet!"); + // // if constexpr (!IS_DBIAS && !IS_DACT) { + // // cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + // // } else { + // // cast_transpose_fused( + // // *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + // // } + // // } else if (output_tensor->has_data()) { + // // fp8::quantize( + // // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // // stream); + // // } + // // break; + // // } + // // case NVTE_MXFP8_1D_SCALING: { + // // mxfp8::quantize( + // // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // // stream); + // // break; + // // } + // // case NVTE_NVFP4_1D_SCALING: { + // // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // // "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // // // Check tensors + // // CheckNoopTensor(*noop_tensor, "cast_noop"); + // // CheckInputTensor(*grad_tensor, "input"); + // // CheckOutputTensor(*output_tensor, "output", false); + + // // // Choose kernel + // // int32_t rows = grad_tensor->flat_first_dim(); + // // int32_t cols = grad_tensor->flat_last_dim(); + // // auto dtype = grad_tensor->dtype(); + // // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + // // (cols % 32 == 0) && output_tensor->has_data(); + + // // // Launch NVFP4 quantize kernel + // // if (use_optimized_kernel) { + // // if (quant_config_cpp.nvfp4_2d_quantization) { + // // nvfp4::quantize_transpose( + // // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // // } else { + // // nvfp4::quantize_transpose( + // // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + // // } + // // } else { + // // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + // // : output_tensor->columnwise_amax; + // // quantize_transpose_vector_blockwise_fp4( + // // /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + // // /*scale_inv=*/output_tensor->scale_inv, + // // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + // // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + // // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + // // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + // // /*swizzled_scale=*/false, + // // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + // // /*rng_state=*/quant_config_cpp.rng_state, + // // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + // // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + // // } + // // break; + // // } + // // case NVTE_BLOCK_SCALING_2D: { + // // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + // // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // // float epsilon = quant_config_cpp.amax_epsilon; + // // quantize_transpose_square_blockwise( + // // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // // output_tensor->data, output_tensor->columnwise_data, epsilon, + // // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + // // /*noop_tensor=*/noop_tensor->data, stream); + // // break; + // // } + // // case NVTE_BLOCK_SCALING_1D: { + // // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + // // NVTE_CHECK((!IS_DBIAS && !IS_DACT), + // // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + // // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + // // float epsilon = quant_config_cpp.amax_epsilon; + // // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + // // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + // // if (output_tensor->has_data()) { + // // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + // // } + // // if (output_tensor->has_columnwise_data()) { + // // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + // // } + // // quantize_transpose_vector_blockwise( + // // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + // // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + // // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + // // break; + // // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + // } } // Host-aware and not graph-safe: group quantization with split section info from the host. @@ -314,64 +315,64 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - const Tensor *input_tensor = convertNVTETensorCheck(input); - std::vector output_tensors; - for (size_t i = 0; i < num_tensors; ++i) { - output_tensors.push_back(convertNVTETensorCheck(outputs[i])); - } - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - // Take the scaling mode of the first output tensor - auto scaling_mode = output_tensors[0]->scaling_mode; - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_NVFP4_1D_SCALING: { - NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - // Skip checking output tensor list - // output list here is allowed to have empty tensor - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - - NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "2D quantization is not supported for group quantize."); - - // Launch NVFP4 group quantize kernel - nvfp4::group_quantize_transpose( - *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, - &quant_config_cpp, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } + // using namespace detail; + + // const Tensor *input_tensor = convertNVTETensorCheck(input); + // std::vector output_tensors; + // for (size_t i = 0; i < num_tensors; ++i) { + // output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + // } + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Check for unsupported options + // if (quant_config_cpp.stochastic_rounding) { + // NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + // "Stochastic rounding is only supported for NVFP4 quantization."); + // } + + // // Take the scaling mode of the first output tensor + // auto scaling_mode = output_tensors[0]->scaling_mode; + + // // Dispatch to quantization kernel depending on data format + // switch (scaling_mode) { + // case NVTE_NVFP4_1D_SCALING: { + // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // // Check tensors + // CheckNoopTensor(*noop_tensor, "cast_noop"); + // CheckInputTensor(*input_tensor, "input"); + // // Skip checking output tensor list + // // output list here is allowed to have empty tensor + + // // Choose kernel + // int32_t rows = input_tensor->flat_first_dim(); + // int32_t cols = input_tensor->flat_last_dim(); + // auto dtype = input_tensor->dtype(); + + // NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + // "2D quantization is not supported for group quantize."); + + // // Launch NVFP4 group quantize kernel + // nvfp4::group_quantize_transpose( + // *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + // &quant_config_cpp, stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + // } } template @@ -407,7 +408,8 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( + // mxfp8::group_quantize( + mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; @@ -422,40 +424,40 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - using namespace detail; - - NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); - const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); - GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Dispatch to quantization kernel depending on data format - switch (scaling_mode) { - case NVTE_MXFP8_1D_SCALING: { - mxfp8::group_quantize( - grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - } + // using namespace detail; + + // NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + // const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + // const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + // GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + // GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + // Tensor *workspace_tensor = convertNVTETensor(workspace); + + // // Quantization config + // QuantizationConfig quant_config_cpp; + // if (quant_config != nullptr) { + // quant_config_cpp = *reinterpret_cast(quant_config); + // } + + // // Noop flag + // Tensor dummy_tensor; + // Tensor *noop_tensor = &dummy_tensor; + // if (quant_config_cpp.noop_tensor != nullptr) { + // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + // } + + // // Dispatch to quantization kernel depending on data format + // switch (scaling_mode) { + // case NVTE_MXFP8_1D_SCALING: { + // mxfp8::group_quantize( + // grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + // stream); + // break; + // } + // default: + // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + // } } } // namespace dispatch From 5c570cd6474a7964da971850b704fa670787b442 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 17:23:01 +0000 Subject: [PATCH 17/51] Ready for benchmark - Regular kernel Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/quantize_mxfp8.cuh | 138 +++++++++--------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 70a68132ad..3142b39272 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -641,75 +641,75 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { - switch (scaling_type) { - case ScalingType::ROWWISE: { - using traits = specialized::CastTraits; - auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); - - dim3 block(traits::threadLayout::num, traits::warpLayout::N, - traits::warpLayout::M); - dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN, - (rows + traits::blockDimM - 1) / traits::blockDimM); - kernel<<>>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - - break; - } - case ScalingType::COLWISE: { - NVTE_WARN("Colwise scaling will fallback to original kernel."); - break; - } - case ScalingType::BIDIMENSIONAL: { - using traits = specialized::CastTraits; - auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - traits::smem); - // TMA for loading, so that we don't need STS for transposing - alignas(64) CUtensorMap tensor_map_input{}; - constexpr size_t input_type_bit_size = TypeInfo::size; - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, - traits::blockIterDim::M, traits::blockIterDim::N, - /*stride_elems=*/cols, - /*offset_elems=*/0, input_type_bit_size, - traits::input_swizzle_pattern); - - alignas(64) CUtensorMap tensor_map_rowwise_output{}; - alignas(64) CUtensorMap tensor_map_colwise_output{}; - constexpr size_t output_type_bit_size = TypeInfo::size; - create_2D_tensor_map(tensor_map_rowwise_output, output->data, rows, cols, - traits::blockIterDim::M, traits::blockIterDim::N, - /*stride_elems=*/cols, - /*offset_elems=*/0, output_type_bit_size, - traits::output_swizzle_pattern); - create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data, rows, - cols, traits::blockIterDim::M, traits::blockIterDim::N, - cols, 0, output_type_bit_size, - traits::output_swizzle_pattern); - - dim3 block(traits::rowThreadLayout::num, traits::numWarps); - dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N, - (rows + traits::blockDIM::M - 1) / traits::blockDIM::M); - kernel<<>>( - tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - - break; - } - default: { - NVTE_ERROR("Invalid scaling type."); - } - } - return; - } + // if (specialized::hasSpec() && + // !WITH_GEMM_SWIZZLED_SCALES) { + // switch (scaling_type) { + // case ScalingType::ROWWISE: { + // using traits = specialized::CastTraits; + // auto kernel = specialized::quantize_mxfp8_kernel_cast_only; + + // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + // traits::smem); + + // dim3 block(traits::threadLayout::num, traits::warpLayout::N, + // traits::warpLayout::M); + // dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN, + // (rows + traits::blockDimM - 1) / traits::blockDimM); + // kernel<<>>( + // reinterpret_cast(input.data.dptr), + // reinterpret_cast(output->data.dptr), + // scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + + // break; + // } + // case ScalingType::COLWISE: { + // NVTE_WARN("Colwise scaling will fallback to original kernel."); + // break; + // } + // case ScalingType::BIDIMENSIONAL: { + // using traits = specialized::CastTraits; + // auto kernel = specialized::quantize_mxfp8_kernel_cast_only; + + // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + // traits::smem); + // // TMA for loading, so that we don't need STS for transposing + // alignas(64) CUtensorMap tensor_map_input{}; + // constexpr size_t input_type_bit_size = TypeInfo::size; + // create_2D_tensor_map(tensor_map_input, input.data, rows, cols, + // traits::blockIterDim::M, traits::blockIterDim::N, + // /*stride_elems=*/cols, + // /*offset_elems=*/0, input_type_bit_size, + // traits::input_swizzle_pattern); + + // alignas(64) CUtensorMap tensor_map_rowwise_output{}; + // alignas(64) CUtensorMap tensor_map_colwise_output{}; + // constexpr size_t output_type_bit_size = TypeInfo::size; + // create_2D_tensor_map(tensor_map_rowwise_output, output->data, rows, cols, + // traits::blockIterDim::M, traits::blockIterDim::N, + // /*stride_elems=*/cols, + // /*offset_elems=*/0, output_type_bit_size, + // traits::output_swizzle_pattern); + // create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data, rows, + // cols, traits::blockIterDim::M, traits::blockIterDim::N, + // cols, 0, output_type_bit_size, + // traits::output_swizzle_pattern); + + // dim3 block(traits::rowThreadLayout::num, traits::numWarps); + // dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N, + // (rows + traits::blockDIM::M - 1) / traits::blockDIM::M); + // kernel<<>>( + // tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output, + // scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + // scale_stride_colwise); + + // break; + // } + // default: { + // NVTE_ERROR("Invalid scaling type."); + // } + // } + // return; + // } alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; From c5b1f7db85fa457aad5c8f59b3f518fdc2dd0a8c Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 17:24:35 +0000 Subject: [PATCH 18/51] Added the source code to the profiler Signed-off-by: Oleg Goncharov --- transformer_engine/common/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b3d48f68bd..737b3d4108 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -384,6 +384,7 @@ endforeach() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --generate-line-info") # Add source code mapping into the profiler output # Number of parallel build jobs if($ENV{MAX_JOBS}) From 3edcb5dc1b734bb23089954dbd2647bd39575b2b Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 18:00:18 +0000 Subject: [PATCH 19/51] Added constructors to Job and Block descriptors Signed-off-by: Oleg Goncharov --- tests/cpp/operator/CMakeLists.txt | 2 +- .../cast/mxfp8/group_quantize_mxfp8.cuh | 64 +++++++++++++------ 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 84c682134f..cf3e556556 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -10,7 +10,7 @@ add_executable(test_operator # test_cast_gated_swiglu.cu # test_cast_mxfp8_gated_swiglu.cu # test_qdq.cu - test_cast_mxfp8.cu + # test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu # test_cast_nvfp4_transpose.cu # test_cast_float8blockwise.cu diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 6e9bd3dc5e..044629c6d3 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -159,6 +159,17 @@ struct JobDescriptor { size_t tensor_id = 0; size_t rows = 0; size_t cols = 0; + + __host__ __device__ __forceinline__ constexpr JobDescriptor() = default; + + __host__ __device__ __forceinline__ constexpr JobDescriptor( + const size_t block_id_, const size_t block_global_offset_, const size_t tensor_id_, + const size_t rows_, const size_t cols_) + : block_id(block_id_), + block_global_offset(block_global_offset_), + tensor_id(tensor_id_), + rows(rows_), + cols(cols_) {} }; // Tensor-local coordinates for a work-item. @@ -169,6 +180,19 @@ struct BlockDescriptor { size_t block_id_X = 0; size_t block_offset_Y = 0; size_t block_offset_X = 0; + + __host__ __device__ __forceinline__ constexpr BlockDescriptor() = default; + + __host__ __device__ __forceinline__ constexpr BlockDescriptor( + const size_t tensor_base_, const size_t block_id_in_current_tensor_, + const size_t block_id_Y_, const size_t block_id_X_, const size_t block_offset_Y_, + const size_t block_offset_X_) + : tensor_base(tensor_base_), + block_id_in_current_tensor(block_id_in_current_tensor_), + block_id_Y(block_id_Y_), + block_id_X(block_id_X_), + block_offset_Y(block_offset_Y_), + block_offset_X(block_offset_X_) {} }; __device__ __forceinline__ JobDescriptor decode_job( @@ -177,17 +201,17 @@ __device__ __forceinline__ JobDescriptor decode_job( const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { - JobDescriptor job{}; - job.block_id = ctaid_Y * work_blocks_X + ctaid_X; - job.block_global_offset = is_single_tensor - ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) - : (job.block_id * ELTS_PER_CHUNK); - job.tensor_id = get_current_tensor_id(shape_rep, num_tensors, job.block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); - job.rows = - get_tensor_rows_num(job.tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - job.cols = get_tensor_cols_num(job.tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - return job; + const size_t block_id = ctaid_Y * work_blocks_X + ctaid_X; + const size_t block_global_offset = + is_single_tensor ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) + : (block_id * ELTS_PER_CHUNK); + const size_t tensor_id = + get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, + first_logical_dim, last_logical_dim, offsets_ptr); + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + return JobDescriptor(block_id, block_global_offset, tensor_id, rows, cols); } __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, @@ -218,17 +242,17 @@ __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, __device__ __forceinline__ BlockDescriptor decode_block(const JobDescriptor &job, const bool is_single_tensor, const int64_t *const __restrict__ offsets_ptr) { - BlockDescriptor block{}; - block.tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); const size_t CHUNK_DIM_X_ = CHUNK_DIM_X; const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_); - block.block_id_in_current_tensor = - is_single_tensor ? job.block_id : (job.block_id - block.tensor_base / ELTS_PER_CHUNK); - block.block_id_Y = block.block_id_in_current_tensor / blocks_X_num_in_current_tensor; - block.block_id_X = block.block_id_in_current_tensor % blocks_X_num_in_current_tensor; - block.block_offset_Y = block.block_id_Y * CHUNK_DIM_Y; - block.block_offset_X = block.block_id_X * CHUNK_DIM_X; - return block; + const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); + const size_t block_id_in_current_tensor = + is_single_tensor ? job.block_id : (job.block_id - tensor_base / ELTS_PER_CHUNK); + const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor; + const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor; + const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y; + const size_t block_offset_X = block_id_X * CHUNK_DIM_X; + return BlockDescriptor(tensor_base, block_id_in_current_tensor, block_id_Y, block_id_X, + block_offset_Y, block_offset_X); } // Copies the base tensor map to shmem, modifies the copy, stores the modified tensor map at index From 6e00237b09cc5d3ebd31fee854ec8e203d6b7131 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 18:08:54 +0000 Subject: [PATCH 20/51] Removed the prefetch overlapping between jobs Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 152 +++++------------- 1 file changed, 37 insertions(+), 115 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 044629c6d3..e8837fcde5 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -744,47 +744,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel static_next_block_id = launch_block_id + static_block_stride; } bool job_finished = false; - int buff_in = 0; - bool has_prefetched_current_job = true; - // Prime the pipeline with stage-0 of the first job assigned to this CTA. - { - const JobDescriptor first_job = - decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, - work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - if (!is_job_valid(first_job, shape_rep, total_work_blocks, offsets_ptr)) { - return; - } - const BlockDescriptor first_block = decode_block(first_job, is_single_tensor, offsets_ptr); - - const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[first_job.tensor_id]; - const CUtensorMap &tensor_map_act_input = is_single_tensor - ? tensor_map_act_input_static - : g_tensor_maps_act_input[first_job.tensor_id]; - - if (leading_thread && (!is_single_tensor)) { - fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&tensor_map_act_input); - } - } - -#pragma unroll - for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { - const size_t buff = stage; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - const size_t global_offset_Y = first_block.block_offset_Y + stage_offset_Y; - const size_t global_offset_X = first_block.block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; - uint64_t *barrier = &IN_buff_readable_mbar[buff]; - prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, - global_offset_X, global_offset_Y, buff_offset, - shmem_buff_size, barrier, leading_thread); - } - } - - // Main work loop: decode current job, run all 32-row stages, schedule/prefetch next job. + // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. while (!job_finished) { // Decode CTA assignment into logical tensor coordinates and validate bounds. const JobDescriptor current_job = @@ -793,13 +754,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const bool current_job_is_valid = is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr); if (!current_job_is_valid) { - if (has_prefetched_current_job) { - // A stage-0 prefetch may already be in flight for this CTA. Drain it before exiting. - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); - IN_buff_readable_parity[buff_in] ^= 1; - ptx::cp_async_bulk_wait_group_read(); - } break; } @@ -864,6 +818,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } + int buff_in = 0; + + // Prime the pipeline with the first PREFETCH_STAGES slices of the current block. +#pragma unroll + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const size_t buff = stage; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + uint64_t *barrier = &IN_buff_readable_mbar[buff]; + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, + global_offset_X, global_offset_Y, buff_offset, + shmem_buff_size, barrier, leading_thread); + } + float partial_dbias_colwise = 0.0f; float thread_dbias_rowwise[SCALE_DIM_X]; if constexpr (IS_DBIAS) { @@ -873,83 +843,24 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - bool prefetched_next_job = false; // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { const size_t stage_offset_Y = stage * BUFF_DIM_Y; - bool allow_next_job_prefetch = true; - JobDescriptor prefetch_job = current_job; - BlockDescriptor prefetch_block = current_block; - - if (stage == STAGES - PREFETCH_STAGES) { - if constexpr (PERSISTENT) { - if (static_next_block_id < total_work_blocks) { - ctaid_X = static_cast(static_next_block_id % work_blocks_X); - ctaid_Y = static_cast(static_next_block_id / work_blocks_X); - static_next_block_id += static_block_stride; - } else { - // Next loop iteration exits via current_job_is_valid check. - ctaid_X = 0; - ctaid_Y = static_cast(work_blocks_Y); - allow_next_job_prefetch = false; - } - } else { - ctaid_X = -1; - ctaid_Y = -1; - } - if constexpr (!PERSISTENT) { - if (ctaid_X == -1 && ctaid_Y == -1) { - job_finished = true; - } - } - } - - // Stage-(STAGES - PREFETCH_STAGES) prefetches stage-0 of the next job. - // Validate that job before issuing TMA reads to avoid OOB accesses on graph-safe tails. - if ((stage >= STAGES - PREFETCH_STAGES) && allow_next_job_prefetch && !job_finished) { - prefetch_job = decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, - last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, - first_dims_ptr, last_dims_ptr); - allow_next_job_prefetch = - is_job_valid(prefetch_job, shape_rep, total_work_blocks, offsets_ptr); - if (allow_next_job_prefetch) { - prefetch_block = decode_block(prefetch_job, is_single_tensor, offsets_ptr); - } - } - - if ((stage < STAGES - PREFETCH_STAGES) || (allow_next_job_prefetch && !job_finished)) { + if (stage < STAGES - PREFETCH_STAGES) { const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; - const size_t next_prefetch_stage = (stage + PREFETCH_STAGES) % STAGES; + const size_t next_prefetch_stage = stage + PREFETCH_STAGES; const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; - if (stage >= STAGES - PREFETCH_STAGES) { - prefetched_next_job = true; - } - const size_t global_offset_Y = prefetch_block.block_offset_Y + next_prefetch_stage_offset_Y; - const size_t global_offset_X = prefetch_block.block_offset_X; + const size_t global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = block_offset_X; const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; - const CUtensorMap &prefetch_tensor_map_input = - is_single_tensor ? tensor_map_input_static - : g_tensor_maps_input[prefetch_job.tensor_id]; - const CUtensorMap &prefetch_tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static - : g_tensor_maps_act_input[prefetch_job.tensor_id]; - uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; - if (leading_thread) { - if ((!is_single_tensor) && (stage == STAGES - PREFETCH_STAGES)) { - fence_acquire_tensormap(&prefetch_tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&prefetch_tensor_map_act_input); - } - } - } - prefetch_input_stage(in_sh, act_in_sh, prefetch_tensor_map_input, - prefetch_tensor_map_act_input, global_offset_X, - global_offset_Y, next_prefetch_buff_offset, - shmem_buff_size, barrier, leading_thread); + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, + global_offset_X, global_offset_Y, + next_prefetch_buff_offset, shmem_buff_size, barrier, + leading_thread); ptx::fence_proxy_async_shared_cta(); } @@ -994,7 +905,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel buff_in = (buff_in + 1) % BUFFS_NUM; } - has_prefetched_current_job = prefetched_next_job; if constexpr (IS_DBIAS) { if (is_single_tensor) { @@ -1035,6 +945,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } } + + if constexpr (PERSISTENT) { + if (static_next_block_id < total_work_blocks) { + ctaid_X = static_cast(static_next_block_id % work_blocks_X); + ctaid_Y = static_cast(static_next_block_id / work_blocks_X); + static_next_block_id += static_block_stride; + } else { + job_finished = true; + } + } else { + job_finished = true; + } } if (amax_ptr != nullptr) { From 274f91ebe2f00b05343746438a8ef153c872173e Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 18:18:38 +0000 Subject: [PATCH 21/51] Cache tensor ID Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index e8837fcde5..2af68c61d2 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -744,6 +744,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel static_next_block_id = launch_block_id + static_block_stride; } bool job_finished = false; + size_t last_acquired_tensor_id = num_tensors; // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. while (!job_finished) { @@ -805,7 +806,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id]; - if (leading_thread && (!is_single_tensor)) { + if (leading_thread && (!is_single_tensor) && (last_acquired_tensor_id != tensor_id)) { fence_acquire_tensormap(&tensor_map_input); if constexpr (COMPUTE_ACTIVATIONS) { fence_acquire_tensormap(&tensor_map_act_input); @@ -816,6 +817,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel if constexpr (COLWISE_SCALING) { fence_acquire_tensormap(&tensor_map_output_colwise); } + last_acquired_tensor_id = tensor_id; } int buff_in = 0; From 38b7e4ee663d80286c97907f3abd038ea8aefe0c Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 19:02:50 +0000 Subject: [PATCH 22/51] ShapeRepresentation is not a template parameter Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 192 ++++++++++++------ 1 file changed, 133 insertions(+), 59 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 2af68c61d2..a6de0c7e7f 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -84,11 +84,12 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 +template __device__ __forceinline__ size_t get_current_tensor_id( - const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, - const size_t block_Y, const size_t first_logical_dim, const size_t last_logical_dim, + const size_t num_tensors, const size_t current_offset, const size_t block_Y, + const size_t first_logical_dim, const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr) { - if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) { + if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS) { const size_t current_row = block_Y * CHUNK_DIM_Y; const size_t rows_per_tensor = first_logical_dim / num_tensors; return current_row / rows_per_tensor; @@ -110,46 +111,80 @@ __device__ __forceinline__ size_t get_current_tensor_id( } } +template __device__ __forceinline__ size_t get_tensor_rows_num( - const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const size_t tensor_id, const size_t first_logical_dim, const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { size_t rows_num = 0; + if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_LAST_DIM) { + rows_num = first_logical_dim; + } else { + rows_num = static_cast(first_dims_ptr[tensor_id]); + } + if (rows_num % 128 != 0) { + NVTE_DEVICE_ERROR("First dimension of each tensor in a group must be divisible by 128."); + } + return rows_num; +} + +__device__ __forceinline__ size_t get_tensor_rows_num( + const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: - case ShapeRepresentation::VARYING_LAST_DIM: - rows_num = first_logical_dim; - break; + return get_tensor_rows_num( + tensor_id, first_logical_dim, first_dims_ptr, num_tensors); case ShapeRepresentation::VARYING_FIRST_DIM: + return get_tensor_rows_num( + tensor_id, first_logical_dim, first_dims_ptr, num_tensors); + case ShapeRepresentation::VARYING_LAST_DIM: + return get_tensor_rows_num( + tensor_id, first_logical_dim, first_dims_ptr, num_tensors); case ShapeRepresentation::VARYING_BOTH_DIMS: - rows_num = static_cast(first_dims_ptr[tensor_id]); - break; + return get_tensor_rows_num( + tensor_id, first_logical_dim, first_dims_ptr, num_tensors); } - if (rows_num % 128 != 0) { - NVTE_DEVICE_ERROR("First dimension of each tensor in a group must be divisible by 128."); + return 0; +} + +template +__device__ __forceinline__ size_t get_tensor_cols_num( + const size_t tensor_id, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { + size_t cols_num = 0; + if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM) { + cols_num = last_logical_dim; + } else { + cols_num = static_cast(last_dims_ptr[tensor_id]); + if (cols_num % 128 != 0) { + NVTE_DEVICE_ERROR( + "For non-single tensors, the last dimension of each tensor in a group " + "must be divisible by 128."); + } } - return rows_num; + return cols_num; } __device__ __forceinline__ size_t get_tensor_cols_num( const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim, const int64_t *const __restrict__ last_dims_ptr) { - size_t cols_num = 0; switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: + return get_tensor_cols_num( + tensor_id, last_logical_dim, last_dims_ptr); case ShapeRepresentation::VARYING_FIRST_DIM: - cols_num = last_logical_dim; - break; + return get_tensor_cols_num( + tensor_id, last_logical_dim, last_dims_ptr); case ShapeRepresentation::VARYING_LAST_DIM: + return get_tensor_cols_num( + tensor_id, last_logical_dim, last_dims_ptr); case ShapeRepresentation::VARYING_BOTH_DIMS: - cols_num = static_cast(last_dims_ptr[tensor_id]); - if (cols_num % 128 != 0) { - NVTE_DEVICE_ERROR( - "For non-single tensors, the last dimension of each tensor in a group " - "must be divisible by 128."); - } - break; + return get_tensor_cols_num( + tensor_id, last_logical_dim, last_dims_ptr); } - return cols_num; + return 0; } // Logical work-item decoded from CTA coordinates. @@ -195,33 +230,38 @@ struct BlockDescriptor { block_offset_X(block_offset_X_) {} }; +template __device__ __forceinline__ JobDescriptor decode_job( - const ShapeRepresentation shape_rep, const bool is_single_tensor, const size_t num_tensors, - const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, - const int32_t ctaid_X, const int32_t ctaid_Y, const int64_t *const __restrict__ offsets_ptr, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const size_t work_blocks_X, const int32_t ctaid_X, const int32_t ctaid_Y, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { + constexpr bool is_single_tensor = + (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); const size_t block_id = ctaid_Y * work_blocks_X + ctaid_X; const size_t block_global_offset = is_single_tensor ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) : (block_id * ELTS_PER_CHUNK); - const size_t tensor_id = - get_current_tensor_id(shape_rep, num_tensors, block_global_offset, ctaid_Y, - first_logical_dim, last_logical_dim, offsets_ptr); + const size_t tensor_id = get_current_tensor_id( + num_tensors, block_global_offset, ctaid_Y, first_logical_dim, last_logical_dim, offsets_ptr); const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); - const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); + get_tensor_rows_num(tensor_id, first_logical_dim, first_dims_ptr, num_tensors); + const size_t cols = get_tensor_cols_num(tensor_id, last_logical_dim, last_dims_ptr); return JobDescriptor(block_id, block_global_offset, tensor_id, rows, cols); } -__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, - const ShapeRepresentation shape_rep, - const size_t total_work_blocks, +template +__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, const size_t total_work_blocks, const int64_t *const __restrict__ offsets_ptr) { bool is_valid = (job.block_id < total_work_blocks) && (job.rows != 0) && (job.cols != 0); - if (!is_valid || shape_rep == SAME_BOTH_DIMS) { + if (!is_valid) { return is_valid; } + if constexpr (SHAPE_REP == SAME_BOTH_DIMS) { + return true; + } const size_t tensor_start_offset = static_cast(offsets_ptr[job.tensor_id]); const size_t tensor_end_offset = static_cast(offsets_ptr[job.tensor_id + 1]); @@ -239,9 +279,12 @@ __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, return true; } -__device__ __forceinline__ BlockDescriptor -decode_block(const JobDescriptor &job, const bool is_single_tensor, - const int64_t *const __restrict__ offsets_ptr) { +template +__device__ __forceinline__ BlockDescriptor decode_block( + const JobDescriptor &job, const int64_t *const __restrict__ offsets_ptr) { + constexpr bool is_single_tensor = + (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); const size_t CHUNK_DIM_X_ = CHUNK_DIM_X; const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_); const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); @@ -644,14 +687,15 @@ __device__ __forceinline__ float process_rowwise_stage( template + bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES, + ShapeRepresentation SHAPE_REP> __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static, const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, - const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, - const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, @@ -667,7 +711,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + constexpr ShapeRepresentation shape_rep = SHAPE_REP; + constexpr bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); const bool leading_thread = (threadIdx.x == 0); @@ -749,11 +794,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. while (!job_finished) { // Decode CTA assignment into logical tensor coordinates and validate bounds. - const JobDescriptor current_job = - decode_job(shape_rep, is_single_tensor, num_tensors, first_logical_dim, last_logical_dim, - work_blocks_X, ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const JobDescriptor current_job = decode_job( + num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, + offsets_ptr, first_dims_ptr, last_dims_ptr); const bool current_job_is_valid = - is_job_valid(current_job, shape_rep, total_work_blocks, offsets_ptr); + is_job_valid(current_job, total_work_blocks, offsets_ptr); if (!current_job_is_valid) { break; } @@ -761,7 +806,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t tensor_id = current_job.tensor_id; const size_t rows = current_job.rows; const size_t cols = current_job.cols; - const BlockDescriptor current_block = decode_block(current_job, is_single_tensor, offsets_ptr); + const BlockDescriptor current_block = decode_block(current_job, offsets_ptr); const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); @@ -1165,24 +1210,53 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations auto kernel = group_quantize_mxfp8_kernel; + true, true, WITH_GEMM_SWIZZLED_SCALES, + ShapeRepresentation::SAME_BOTH_DIMS>; + auto assign_kernel_for_shape = [&](auto rowwise_scaling, auto colwise_scaling) { + constexpr bool ROWWISE_SCALING_VALUE = decltype(rowwise_scaling)::value; + constexpr bool COLWISE_SCALING_VALUE = decltype(colwise_scaling)::value; + switch (shape_rep) { + case ShapeRepresentation::SAME_BOTH_DIMS: { + kernel = group_quantize_mxfp8_kernel< + IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, + ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, + ShapeRepresentation::SAME_BOTH_DIMS>; + break; + } + case ShapeRepresentation::VARYING_FIRST_DIM: { + kernel = group_quantize_mxfp8_kernel< + IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, + ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, + ShapeRepresentation::VARYING_FIRST_DIM>; + break; + } + case ShapeRepresentation::VARYING_LAST_DIM: { + kernel = group_quantize_mxfp8_kernel< + IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, + ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, + ShapeRepresentation::VARYING_LAST_DIM>; + break; + } + case ShapeRepresentation::VARYING_BOTH_DIMS: { + kernel = group_quantize_mxfp8_kernel< + IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, + ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, + ShapeRepresentation::VARYING_BOTH_DIMS>; + break; + } + } + }; switch (scaling_type) { case ScalingType::ROWWISE: { - kernel = - group_quantize_mxfp8_kernel; + assign_kernel_for_shape(std::true_type{}, std::false_type{}); break; } case ScalingType::COLWISE: { - kernel = - group_quantize_mxfp8_kernel; + assign_kernel_for_shape(std::false_type{}, std::true_type{}); break; } case ScalingType::BIDIMENSIONAL: { - kernel = - group_quantize_mxfp8_kernel; + assign_kernel_for_shape(std::true_type{}, std::true_type{}); break; } } @@ -1213,8 +1287,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations kernel<<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + tensor_map_output_colwise, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, work_blocks_Y); From 44052551d5fa2b072a5ed3008317d1953f0be684 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Mar 2026 19:33:06 +0000 Subject: [PATCH 23/51] Removed redundant fence_proxy Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index a6de0c7e7f..cebabd7cf3 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -46,7 +46,7 @@ struct TunableConfig { // false -> non-persistent one-job-per-CTA execution static constexpr bool PERSISTENT = true; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). - static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 4; + static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 24; }; constexpr bool PERSISTENT = TunableConfig::PERSISTENT; @@ -908,7 +908,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel global_offset_X, global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); - ptx::fence_proxy_async_shared_cta(); } ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], From 8cad6e69e1eddc9b4c8be666651ec6840d65032b Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Mon, 16 Mar 2026 15:19:01 +0000 Subject: [PATCH 24/51] Refactoring Signed-off-by: Oleg Goncharov --- .../common/cast/core/common.cuh | 24 ++ .../cast/mxfp8/group_quantize_mxfp8.cuh | 278 ++++++++---------- transformer_engine/common/common.h | 19 ++ 3 files changed, 163 insertions(+), 158 deletions(-) diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 9c16666db0..16db7ae856 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -30,6 +30,30 @@ enum ShapeRepresentation { VARYING_BOTH_DIMS = 3 }; + +#define TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(SHAPE_REP, SHAPE, ...) \ + switch (SHAPE_REP) { \ + case ShapeRepresentation::SAME_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::SAME_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_FIRST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_FIRST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_LAST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_LAST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported grouped tensor shape representation."); \ + } \ + } + inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { const size_t N = product(t->data.shape); const bool isFullTile = (N % elems_per_block == 0); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index cebabd7cf3..29915159e5 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -338,7 +338,7 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te } template -__global__ void update_tma_descriptors( +__global__ void __launch_bounds__(1) update_tma_descriptors( const __grid_constant__ CUtensorMap base_tensor_map_input, const __grid_constant__ CUtensorMap base_tensor_map_act_input, const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, @@ -351,16 +351,19 @@ __global__ void update_tma_descriptors( const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, const bool compute_dactivations) { - const bool leading_thread = (threadIdx.x == 0); const size_t tensor_id = blockIdx.x; - - const size_t rows = - get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); - const size_t offset_elts = offsets_ptr[tensor_id]; - if (leading_thread && (tensor_id < num_tensors)) { + // Zero-sized groups: skip TMA descriptor update. The main kernel already returns + // early for rows==0 or cols==0, but creating a TMA descriptor with a zero dimension + // is invalid and causes CUDA_ERROR_ILLEGAL_ADDRESS. + if (rows == 0 || cols == 0) { + return; + } + + if (tensor_id < num_tensors) { { const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], @@ -686,8 +689,8 @@ __device__ __forceinline__ float process_rowwise_stage( } template __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( const __grid_constant__ CUtensorMap tensor_map_input_static, @@ -711,6 +714,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } + constexpr bool ROWWISE_SCALING = (SCALING_TYPE == ScalingType::ROWWISE) + || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool COLWISE_SCALING = (SCALING_TYPE == ScalingType::COLWISE) + || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr ShapeRepresentation shape_rep = SHAPE_REP; constexpr bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); @@ -1151,155 +1159,109 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } } - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input->dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, - first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - auto kernel = - group_quantize_mxfp8_kernel; - auto assign_kernel_for_shape = [&](auto rowwise_scaling, auto colwise_scaling) { - constexpr bool ROWWISE_SCALING_VALUE = decltype(rowwise_scaling)::value; - constexpr bool COLWISE_SCALING_VALUE = decltype(colwise_scaling)::value; - switch (shape_rep) { - case ShapeRepresentation::SAME_BOTH_DIMS: { - kernel = group_quantize_mxfp8_kernel< - IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, - ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, - ShapeRepresentation::SAME_BOTH_DIMS>; - break; - } - case ShapeRepresentation::VARYING_FIRST_DIM: { - kernel = group_quantize_mxfp8_kernel< - IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, - ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, - ShapeRepresentation::VARYING_FIRST_DIM>; - break; - } - case ShapeRepresentation::VARYING_LAST_DIM: { - kernel = group_quantize_mxfp8_kernel< - IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, - ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, - ShapeRepresentation::VARYING_LAST_DIM>; - break; - } - case ShapeRepresentation::VARYING_BOTH_DIMS: { - kernel = group_quantize_mxfp8_kernel< - IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, - ROWWISE_SCALING_VALUE, COLWISE_SCALING_VALUE, WITH_GEMM_SWIZZLED_SCALES, - ShapeRepresentation::VARYING_BOTH_DIMS>; - break; - } - } - }; - switch (scaling_type) { - case ScalingType::ROWWISE: { - assign_kernel_for_shape(std::true_type{}, std::false_type{}); - break; - } - case ScalingType::COLWISE: { - assign_kernel_for_shape(std::false_type{}, std::true_type{}); - break; - } - case ScalingType::BIDIMENSIONAL: { - assign_kernel_for_shape(std::true_type{}, std::true_type{}); - break; - } - } - - // Update tensor descriptors before launching the kernel - if (!is_single_tensor) { - const IType *const input_dptr = reinterpret_cast(input->data.dptr); - - const IType *const act_input_dptr = - IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; - - OType *const output_rowwise_dptr = - use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; - - OType *const output_colwise_dptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) - : nullptr; - update_tma_descriptors<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, - output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, - use_rowwise_scaling, use_colwise_scaling, IS_DACT); - } - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, - work_blocks_Y); - - if constexpr (IS_DBIAS) { - common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, - first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); - } - - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype(), OType, + TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH(scaling_type, SCALING_TYPE, + TRANSFORMER_ENGINE_SWITCH_CONDITION(with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(shape_rep, SHAPE_REP, + { + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, + use_rowwise_scaling, use_colwise_scaling, IS_DACT); + } + + auto kernel = + group_quantize_mxfp8_kernel; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, + work_blocks_Y); + + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); + } + + NVTE_CHECK_CUDA(cudaGetLastError()); + } + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 41a8fd1112..2f561dbfa2 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -877,6 +877,25 @@ struct TypeInfo { { __VA_ARGS__ } \ } +#define TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH(SCALING_TYPE, SCALING_T, ...) \ + switch (SCALING_TYPE) { \ + case ScalingType::ROWWISE: { \ + constexpr ScalingType SCALING_T = ScalingType::ROWWISE; \ + { __VA_ARGS__ } \ + } break; \ + case ScalingType::COLWISE: { \ + constexpr ScalingType SCALING_T = ScalingType::COLWISE; \ + { __VA_ARGS__ } \ + } break; \ + case ScalingType::BIDIMENSIONAL: { \ + constexpr ScalingType SCALING_T = ScalingType::BIDIMENSIONAL; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported scaling type."); \ + } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { From c6622d467f6bbd8ae8745cba6334fc2cea122fd6 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 13:52:37 +0000 Subject: [PATCH 25/51] Used mixed precision FMA Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 71 +++++++++++++------ 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 29915159e5..c60f536bc2 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -657,32 +657,61 @@ __device__ __forceinline__ float process_rowwise_stage( scales_rowwise[scale_idx] = biased_exponent; } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + // const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const bf16 multiplier = static_cast(ptx::exp2f_rcp(biased_exponent)); + // const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; -#pragma unroll + #pragma unroll for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } + uint32_t out = 0; + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(out) + : "l"(reinterpret_cast(in_IType[w].data.elt[0])), + "h"(reinterpret_cast(multiplier))); + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + // out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out)); + + + // Vec out; + // #pragma unroll + // for (int e = 0; e < PACK_SIZE / 2; ++e) { + // IType2 in; + // OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + // if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + // in = in_IType[w].data.elt[e]; + // } else if constexpr (IS_CACHED_ACT_OP) { + // in.x = in_cached[w].data.elt[2 * e]; + // in.y = in_cached[w].data.elt[2 * e + 1]; + // } else { + // const int j = w * PACK_SIZE + 2 * e; + // in.x = in_compute_rowwise[j]; + // in.y = in_compute_rowwise[j + 1]; + // } + // ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + // } + // const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + // const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + // const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + // out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } return thread_amax; From e6a737c6a0cc43c8d103933069e8028e8d4796b8 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 14:56:02 +0000 Subject: [PATCH 26/51] Added Quantize parameters Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 25 ++- transformer_engine/common/cast/cast.cu | 9 + .../common/cast/dispatch/quantize.cuh | 4 +- .../cast/mxfp8/group_quantize_mxfp8.cuh | 178 +++++++++--------- .../common/include/transformer_engine/cast.h | 13 ++ 5 files changed, 136 insertions(+), 93 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index b49d68b1c4..ed58127b68 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -241,7 +241,8 @@ void performTest(const ProcessingMethod processing_method, const std::vector& last_dims_h, const std::vector& offsets_h, const bool rowwise, - const bool colwise) { + const bool colwise, + const bool use_fast_math) { using namespace test; DType itype = TypeInfo::dtype; @@ -499,11 +500,14 @@ void performTest(const ProcessingMethod processing_method, scales_stride_colwise); } + QuantizationConfigWrapper quant_config; + quant_config.set_use_fast_math(use_fast_math); + // GPU Tensor workspace; switch (processing_method) { case ProcessingMethod::CAST_ONLY: { - nvte_group_quantize(in_group_tensor, out_group_tensor, 0); + nvte_group_quantize_v2(in_group_tensor, out_group_tensor, quant_config, 0); break; } case ProcessingMethod::CAST_DBIAS: { @@ -675,7 +679,8 @@ class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam ScalingDirection, std::vector, // Config transformer_engine::DType, // InputType - transformer_engine::DType // OutputType + transformer_engine::DType, // OutputType + bool >> {}; TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { @@ -693,6 +698,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { const std::vector input_config = std::get<3>(GetParam()); const DType input_type = std::get<4>(GetParam()); const DType output_type = std::get<5>(GetParam()); + const bool use_fast_math = std::get<6>(GetParam()); const ShapeRepresentation shape_rep = static_cast(input_config[0]); const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); @@ -756,6 +762,10 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { GTEST_SKIP(); } + // Skip fused tests in fast math is enabled. + if ((processing_method != ProcessingMethod::CAST_ONLY) && use_fast_math) { + GTEST_SKIP(); + } bool rowwise = false; bool colwise = false; @@ -790,7 +800,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, performTest(processing_method, OP, shape_rep, num_tensors, logical_shape, first_dims, last_dims, offsets, - rowwise, colwise); + rowwise, colwise, use_fast_math); ); ); } @@ -827,7 +837,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), ::testing::Values(DType::kBFloat16), - ::testing::Values(DType::kFloat8E4M3)), + ::testing::Values(DType::kFloat8E4M3), + ::testing::Values(true, false)), // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { @@ -858,5 +869,9 @@ INSTANTIATE_TEST_SUITE_P( name += "_" + test::typeName(std::get<4>(info.param)) + "_" + test::typeName(std::get<5>(info.param)); + + if (std::get<6>(info.param)) { + name += "_FASTMATH"; + } return name; }); diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 4f9ddb4fc5..67b7b908e6 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -56,6 +56,15 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, dispatch::quantize_fwd_helper(input, output, quant_config, stream); } +void nvte_group_quantize_v2(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_group_quantize_v2); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::group_quantize_fwd_helper(input, output, quant_config, stream); +} + void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_dbias); diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 835ae149c8..9c17fcd98d 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -411,7 +411,7 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // mxfp8::group_quantize( mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); + workspace_tensor, &quant_config_cpp, stream); break; } default: @@ -452,7 +452,7 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe // case NVTE_MXFP8_1D_SCALING: { // mxfp8::group_quantize( // grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // stream); + // &quant_config_cpp, stream); // break; // } // default: diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index c60f536bc2..663beeb298 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -450,7 +450,7 @@ __device__ __forceinline__ void store_output_stage(OType *out_rowwise_data_sh, template + bool WITH_GEMM_SWIZZLED_SCALES, bool USE_FAST_MATH> __device__ __forceinline__ float process_colwise_stage( const size_t buff, const int stage, const size_t tid_X_colwise, const size_t scales_offset_Y_colwise, const size_t scales_offset_X_colwise, @@ -544,7 +544,7 @@ __device__ __forceinline__ float process_colwise_stage( template + bool WITH_GEMM_SWIZZLED_SCALES, bool USE_FAST_MATH> __device__ __forceinline__ float process_rowwise_stage( const size_t buff, const size_t stage_offset_Y, const size_t thread_offset_Y_rowwise, const size_t thread_offset_X_rowwise, const int bank_group, @@ -720,7 +720,7 @@ __device__ __forceinline__ float process_rowwise_stage( template + ShapeRepresentation SHAPE_REP, bool USE_FAST_MATH> __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static, @@ -956,7 +956,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { thread_amax = process_colwise_stage( + ROWWISE_SCALING, WITH_GEMM_SWIZZLED_SCALES, USE_FAST_MATH>( buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); @@ -964,7 +964,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel if constexpr (ROWWISE_SCALING) { thread_amax = process_rowwise_stage( + COLWISE_SCALING, WITH_GEMM_SWIZZLED_SCALES, USE_FAST_MATH>( buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, @@ -1066,9 +1066,12 @@ template void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, const Tensor *noop, GroupedTensor *output, GroupedTensor *dbias, - Tensor *workspace, cudaStream_t stream) { + Tensor *workspace, const QuantizationConfig *quant_config, + cudaStream_t stream) { using namespace group_quantize_kernel; + const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -1193,99 +1196,102 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH(scaling_type, SCALING_TYPE, TRANSFORMER_ENGINE_SWITCH_CONDITION(with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(shape_rep, SHAPE_REP, - { - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; + TRANSFORMER_ENGINE_SWITCH_CONDITION(use_fast_math, USE_FAST_MATH, + { + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; - create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - output_type_bit_size); - } - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, - first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } - // Update tensor descriptors before launching the kernel - if (!is_single_tensor) { - const IType *const input_dptr = reinterpret_cast(input->data.dptr); + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + : nullptr; + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, + output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, + use_rowwise_scaling, use_colwise_scaling, IS_DACT); + } - const IType *const act_input_dptr = - IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; + auto kernel = + group_quantize_mxfp8_kernel; - OType *const output_rowwise_dptr = - use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - OType *const output_colwise_dptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) - : nullptr; - update_tma_descriptors<<>>( + kernel<<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, - output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, - use_rowwise_scaling, use_colwise_scaling, IS_DACT); - } - - auto kernel = - group_quantize_mxfp8_kernel; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, - work_blocks_Y); - - if constexpr (IS_DBIAS) { - common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, - first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); - } + tensor_map_output_colwise, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, + scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, + work_blocks_Y); + + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); + } - NVTE_CHECK_CUDA(cudaGetLastError()); - } + NVTE_CHECK_CUDA(cudaGetLastError()); + } + ); // NOLINT(*) ); // NOLINT(*) ); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 755052d6dd..02b88bfba6 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -124,6 +124,19 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no void nvte_quantize_v2(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream); +/*! \brief Casts input grouped tensor to MXFP8. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. See file level comments. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. + * + * \param[in] input Input grouped tensor to be cast. + * \param[in,out] output Output grouped MXFP8 tensor. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_group_quantize_v2(const NVTEGroupedTensor input, NVTEGroupedTensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream); + /*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. From 7be113682e6c0cd921813f57d20ac1e9191c3055 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 16:20:34 +0000 Subject: [PATCH 27/51] Added the fast math branch Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 133 ++++++++++-------- transformer_engine/common/util/ptx.cuh | 99 ++++++++++++- 2 files changed, 170 insertions(+), 62 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 663beeb298..1a5e251845 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -457,6 +457,9 @@ __device__ __forceinline__ float process_colwise_stage( const size_t scale_stride_colwise, const size_t tensor_base_for_scales, const size_t rows, const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, OType *out_colwise_data_sh, e8m0_t *scales_colwise, float &partial_dbias_colwise) { + using IType4 = typename ptx::FPx4; + using OType4 = typename ptx::FPx4; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; @@ -524,21 +527,37 @@ __device__ __forceinline__ float process_colwise_stage( scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; + if constexpr (USE_FAST_MATH && NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + #pragma unroll + for (int i = 0; i < SCALE_DIM_Y; i += 4) { + OType4 out; + const IType4& in = *reinterpret_cast(&in_colwise_IType[i]); - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } + ptx::mul_cvt_4x(out, in, block_scale_inverse_f16); + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt + 0 * BUFF_DIM_X] = out.x1; + out_colwise_data_sh[shmem_offset_elt + 1 * BUFF_DIM_X] = out.x2; + out_colwise_data_sh[shmem_offset_elt + 2 * BUFF_DIM_X] = out.x3; + out_colwise_data_sh[shmem_offset_elt + 3 * BUFF_DIM_X] = out.x4; + } + } else { + #pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } return thread_amax; } @@ -553,7 +572,9 @@ __device__ __forceinline__ float process_rowwise_stage( IType *in_sh, IType *act_in_sh, IType *cached_act_sh, OType *out_rowwise_data_sh, e8m0_t *scales_rowwise, float *thread_dbias_rowwise) { using IType2 = typename ptx::FPx2; + using IType4 = typename ptx::FPx4; using OType2 = typename ptx::FPx2; + using OType4 = typename ptx::FPx4; constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; @@ -657,61 +678,45 @@ __device__ __forceinline__ float process_rowwise_stage( scales_rowwise[scale_idx] = biased_exponent; } - // const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const bf16 multiplier = static_cast(ptx::exp2f_rcp(biased_exponent)); - // const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; #pragma unroll for (int w = 0; w < WAVES; ++w) { - uint32_t out = 0; - asm volatile( - "{\n\t" - ".reg.b16 x0,x1,x2,x3; \n\t" - "mov.b64 {x0,x1,x2,x3}, %1; \n\t" - ".reg.f32 y0,y1,y2,y3; \n\t" - "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" - ".reg.b16 z01, z23; \n\t" - "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" - "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" - "mov.b32 %0, {z01, z23}; \n" - "}\n" - : "=r"(out) - : "l"(reinterpret_cast(in_IType[w].data.elt[0])), - "h"(reinterpret_cast(multiplier))); - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - // out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - - const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); - asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out)); - - - // Vec out; - // #pragma unroll - // for (int e = 0; e < PACK_SIZE / 2; ++e) { - // IType2 in; - // OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - // if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - // in = in_IType[w].data.elt[e]; - // } else if constexpr (IS_CACHED_ACT_OP) { - // in.x = in_cached[w].data.elt[2 * e]; - // in.y = in_cached[w].data.elt[2 * e + 1]; - // } else { - // const int j = w * PACK_SIZE + 2 * e; - // in.x = in_compute_rowwise[j]; - // in.y = in_compute_rowwise[j + 1]; - // } - // ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - // } - // const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - // const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - // const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - // out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + + if constexpr (USE_FAST_MATH && NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + uint32_t out_4x = 0; + OType4& out = *reinterpret_cast(&out_4x); + const IType4& in = *reinterpret_cast(&in_IType[w].data.elt[0]); + + ptx::mul_cvt_4x(out, in, block_scale_inverse_f16); + + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); + } else { + Vec out; + #pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } } return thread_amax; @@ -1071,6 +1076,12 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations using namespace group_quantize_kernel; const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + if (use_fast_math) { + NVTE_CHECK(input->dtype() == DType::kBFloat16 || input->dtype() == DType::kFloat16, + "Fast math supports only BF16 and FP16 input types."); + NVTE_CHECK(!IS_DBIAS && !IS_DACT && !IS_ACT, + "Fast math does not support fused casts."); + } checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 5367d7e781..634c2ccaaa 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -491,7 +491,7 @@ struct alignas(2 * sizeof(T)) FPx2 { }; template -struct FPx4 { +struct alignas(4 * sizeof(T)) FPx4 { T x1; T x2; T x3; @@ -1167,6 +1167,103 @@ __device__ __forceinline__ fp16 get_amax(fp16 a, fp16 b) { #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +// Using mixed precision FMA instruction +__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, const bf16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in, const bf16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in, const fp16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const fp16x4 &in, const fp16 scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + asm volatile( + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + __device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, const ptx::floatx2 &scale) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) From 4c2bed5358d30565bd0afab1e4ec00a4a4ba5617 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 16:26:16 +0000 Subject: [PATCH 28/51] Added the fast math to cpp test suite Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index ed58127b68..f86c43e3c2 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -58,7 +58,8 @@ void compute_ref(const ProcessingMethod processing_method, const size_t rows, const size_t cols, const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) + const size_t scales_stride_colwise, + const bool use_fast_math) { const size_t tile_size_Y = 32; const size_t tile_size_X = 32; @@ -129,7 +130,10 @@ void compute_ref(const ProcessingMethod processing_method, const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); const size_t scale_idx = i * scales_stride_rowwise + tile_X; output_scales_rowwise[scale_idx] = biased_exponent; - const float scale_reciprocal = exp2f_rcp(biased_exponent); + float scale_reciprocal = exp2f_rcp(biased_exponent); + if (use_fast_math) { + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } for (size_t j = j_min; j < j_max; ++j) { const size_t idx = i * cols + j; @@ -150,7 +154,10 @@ void compute_ref(const ProcessingMethod processing_method, const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); const size_t scale_idx = tile_Y * scales_stride_colwise + j; output_scales_colwise[scale_idx] = biased_exponent; - const float scale_reciprocal = exp2f_rcp(biased_exponent); + float scale_reciprocal = exp2f_rcp(biased_exponent); + if (use_fast_math) { + scale_reciprocal = static_cast(static_cast(scale_reciprocal)); + } for (size_t i = i_min; i < i_max; ++i) { const size_t idx = i * cols + j; @@ -497,7 +504,8 @@ void performTest(const ProcessingMethod processing_method, out_scales_rowwise_ptr, out_scales_colwise_ptr, ref_output_dbias_ptr, M, K, scales_stride_rowwise, - scales_stride_colwise); + scales_stride_colwise, + use_fast_math); } QuantizationConfigWrapper quant_config; From e296b0baa0544facbf713fd46186550310a05db8 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 16:39:34 +0000 Subject: [PATCH 29/51] Align tests Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index f86c43e3c2..7005e92e18 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -846,7 +846,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(input_config), ::testing::Values(DType::kBFloat16), ::testing::Values(DType::kFloat8E4M3), - ::testing::Values(true, false)), + ::testing::Values(true)), // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), [](const testing::TestParamInfo& info) { From e63eee9fed8f5f12417d222c3ae4d3c4a0a47ea1 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 17:41:53 +0000 Subject: [PATCH 30/51] Use STS instead of generic ST Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 1a5e251845..472e4716c3 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -538,10 +538,29 @@ __device__ __forceinline__ float process_colwise_stage( ptx::mul_cvt_4x(out, in, block_scale_inverse_f16); const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt + 0 * BUFF_DIM_X] = out.x1; - out_colwise_data_sh[shmem_offset_elt + 1 * BUFF_DIM_X] = out.x2; - out_colwise_data_sh[shmem_offset_elt + 2 * BUFF_DIM_X] = out.x3; - out_colwise_data_sh[shmem_offset_elt + 3 * BUFF_DIM_X] = out.x4; + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_colwise_data_sh[shmem_offset_elt]); + + asm volatile( + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %0; \n\t" + "mov.u32 stride, %1; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b8 x0,x1,x2,x3; \n\t" + "mov.b32 {x0,x1,x2,x3}, %2; \n\t" + "st.shared.b8 [ptr0], x0; \n\t" + "st.shared.b8 [ptr1], x1; \n\t" + "st.shared.b8 [ptr2], x2; \n\t" + "st.shared.b8 [ptr3], x3; \n\t" + "}\n" + :: "r"(dst_smem_ptr), + "r"(static_cast(BUFF_DIM_X * sizeof(OType))), + "r"(reinterpret_cast(out)) + ); } } else { #pragma unroll From 68742063c69344c709d9c857b5d6e09f2eda9c03 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 18:43:52 +0000 Subject: [PATCH 31/51] Add zero-tensor cases Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 10 +++- .../cast/mxfp8/group_quantize_mxfp8.cuh | 46 +++++++++++++------ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 7005e92e18..5dbad0e115 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -280,9 +280,13 @@ void performTest(const ProcessingMethod processing_method, const size_t elts = M * K; elts_num += elts; + auto divide_round_up_blocks = [](const size_t N, const size_t M) -> size_t { + return (N == 0) ? 0 : 1 + (N - 1) / M; + }; + const size_t unpadded_rowwise_blocks_Y = M; - const size_t unpadded_rowwise_blocks_X = divide_round_up(K, 32); - const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32); + const size_t unpadded_rowwise_blocks_X = divide_round_up_blocks(K, 32); + const size_t unpadded_colwise_blocks_Y = divide_round_up_blocks(M, 32); const size_t unpadded_colwise_blocks_X = K; rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128); @@ -676,6 +680,8 @@ std::vector> input_config = { // {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + // // Empty tensor in the middle of the group must not terminate the persistent work loop. + // {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 472e4716c3..e0d81cfec0 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -255,9 +255,12 @@ __device__ __forceinline__ JobDescriptor decode_job( template __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, const size_t total_work_blocks, const int64_t *const __restrict__ offsets_ptr) { - bool is_valid = (job.block_id < total_work_blocks) && (job.rows != 0) && (job.cols != 0); + const bool is_valid = (job.block_id < total_work_blocks); if (!is_valid) { - return is_valid; + return false; + } + if (job.rows == 0 || job.cols == 0) { + return true; } if constexpr (SHAPE_REP == SAME_BOTH_DIMS) { return true; @@ -279,6 +282,26 @@ __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, const siz return true; } +__device__ __forceinline__ bool job_has_work(const JobDescriptor &job) { + return job.rows != 0 && job.cols != 0; +} + +__device__ __forceinline__ void advance_to_next_job( + bool &job_finished, int32_t &ctaid_X, int32_t &ctaid_Y, size_t &static_next_block_id, + const size_t static_block_stride, const size_t total_work_blocks, const size_t work_blocks_X) { + if constexpr (PERSISTENT) { + if (static_next_block_id < total_work_blocks) { + ctaid_X = static_cast(static_next_block_id % work_blocks_X); + ctaid_Y = static_cast(static_next_block_id / work_blocks_X); + static_next_block_id += static_block_stride; + } else { + job_finished = true; + } + } else { + job_finished = true; + } +} + template __device__ __forceinline__ BlockDescriptor decode_block( const JobDescriptor &job, const int64_t *const __restrict__ offsets_ptr) { @@ -863,6 +886,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel if (!current_job_is_valid) { break; } + if (!job_has_work(current_job)) { + // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, + static_block_stride, total_work_blocks, work_blocks_X); + continue; + } const size_t tensor_id = current_job.tensor_id; const size_t rows = current_job.rows; @@ -1053,17 +1082,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - if constexpr (PERSISTENT) { - if (static_next_block_id < total_work_blocks) { - ctaid_X = static_cast(static_next_block_id % work_blocks_X); - ctaid_Y = static_cast(static_next_block_id / work_blocks_X); - static_next_block_id += static_block_stride; - } else { - job_finished = true; - } - } else { - job_finished = true; - } + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, + static_block_stride, total_work_blocks, work_blocks_X); } if (amax_ptr != nullptr) { From a02c71ce74f71b725d1930f4436bfbfbaabef077 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 19:43:15 +0000 Subject: [PATCH 32/51] Used LDS instead of generic LD in colwise path Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 90 ++++++++++++++++--- 1 file changed, 80 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index e0d81cfec0..9367457717 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -480,19 +480,90 @@ __device__ __forceinline__ float process_colwise_stage( const size_t scale_stride_colwise, const size_t tensor_base_for_scales, const size_t rows, const size_t cols, IType *in_sh, IType *act_in_sh, IType *cached_act_sh, OType *out_colwise_data_sh, e8m0_t *scales_colwise, float &partial_dbias_colwise) { + using IType2 = typename ptx::FPx2; using IType4 = typename ptx::FPx4; using OType4 = typename ptx::FPx4; + constexpr uint32_t IN_SHMEM_STRIDE = static_cast(BUFF_DIM_X * sizeof(IType)); + constexpr uint32_t OUT_SHMEM_STRIDE = static_cast(BUFF_DIM_X * sizeof(OType)); + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; + constexpr bool NON_FP32_CAST_ONLY = NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; float thread_amax = 0.0f; float in_compute_colwise[BUFF_DIM_Y]; IType in_colwise_IType[BUFF_DIM_Y]; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + #pragma unroll + for (int i = 0; i < BUFF_DIM_Y; i += 4) { + IType4& in = *reinterpret_cast(&in_colwise_IType[i]); + + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_colwise]); + + // Load 4x elts S2R and find amax + if constexpr (std::is_same_v) { + asm volatile( + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %2; \n\t" + "mov.u32 stride, %3; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "ld.shared.b16 x0, [ptr0]; \n\t" + "ld.shared.b16 x1, [ptr1]; \n\t" + "ld.shared.b16 x2, [ptr2]; \n\t" + "ld.shared.b16 x3, [ptr3]; \n\t" + "mov.b64 %0, {x0,x1,x2,x3}; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b32 x01, {x0,x1}; \n\t" + "mov.b32 x23, {x2,x3}; \n\t" + "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.bf16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in)), "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE) + ); + } else { + asm volatile( + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %2; \n\t" + "mov.u32 stride, %3; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "ld.shared.b16 x0, [ptr0]; \n\t" + "ld.shared.b16 x1, [ptr1]; \n\t" + "ld.shared.b16 x2, [ptr2]; \n\t" + "ld.shared.b16 x3, [ptr3]; \n\t" + "mov.b64 %0, {x0,x1,x2,x3}; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b32 x01, {x0,x1}; \n\t" + "mov.b32 x23, {x2,x3}; \n\t" + "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.f16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in)), "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE) + ); + } + } + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + } else if constexpr (NON_FP32_CAST_ONLY) { IType thread_amax_f16 = static_cast(0.0f); #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { @@ -552,7 +623,7 @@ __device__ __forceinline__ float process_colwise_stage( const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); - if constexpr (USE_FAST_MATH && NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { #pragma unroll for (int i = 0; i < SCALE_DIM_Y; i += 4) { OType4 out; @@ -578,18 +649,16 @@ __device__ __forceinline__ float process_colwise_stage( "st.shared.b8 [ptr0], x0; \n\t" "st.shared.b8 [ptr1], x1; \n\t" "st.shared.b8 [ptr2], x2; \n\t" - "st.shared.b8 [ptr3], x3; \n\t" + "st.shared.b8 [ptr3], x3; \n" "}\n" - :: "r"(dst_smem_ptr), - "r"(static_cast(BUFF_DIM_X * sizeof(OType))), - "r"(reinterpret_cast(out)) + :: "r"(dst_smem_ptr), "r"(OUT_SHMEM_STRIDE), "r"(reinterpret_cast(out)) ); } } else { #pragma unroll for (int i = 0; i < SCALE_DIM_Y; ++i) { float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + if constexpr (NON_FP32_CAST_ONLY) { in = static_cast(in_colwise_IType[i]); } else { in = in_compute_colwise[i]; @@ -620,6 +689,7 @@ __device__ __forceinline__ float process_rowwise_stage( constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; + constexpr bool NON_FP32_CAST_ONLY = NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; float thread_amax = 0.0f; @@ -627,7 +697,7 @@ __device__ __forceinline__ float process_rowwise_stage( Vec in_cached[WAVES]; Vec in_IType[WAVES]; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + if constexpr (NON_FP32_CAST_ONLY) { IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int w = 0; w < WAVES; ++w) { @@ -730,7 +800,7 @@ __device__ __forceinline__ float process_rowwise_stage( const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - if constexpr (USE_FAST_MATH && NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { uint32_t out_4x = 0; OType4& out = *reinterpret_cast(&out_4x); const IType4& in = *reinterpret_cast(&in_IType[w].data.elt[0]); @@ -745,7 +815,7 @@ __device__ __forceinline__ float process_rowwise_stage( for (int e = 0; e < PACK_SIZE / 2; ++e) { IType2 in; OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + if constexpr (NON_FP32_CAST_ONLY) { in = in_IType[w].data.elt[e]; } else if constexpr (IS_CACHED_ACT_OP) { in.x = in_cached[w].data.elt[2 * e]; From 4c992b0705c628bc9ce1fad523013f6b2e6b9ea8 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 23:01:10 +0000 Subject: [PATCH 33/51] Used LDS instead of generic LD in rowwise Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 9367457717..f93b915fe1 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -496,13 +496,12 @@ __device__ __forceinline__ float process_colwise_stage( float thread_amax = 0.0f; float in_compute_colwise[BUFF_DIM_Y]; IType in_colwise_IType[BUFF_DIM_Y]; + IType4 in_colwise_IType4[BUFF_DIM_Y/4]; if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int i = 0; i < BUFF_DIM_Y; i += 4) { - IType4& in = *reinterpret_cast(&in_colwise_IType[i]); - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_colwise]); @@ -530,7 +529,8 @@ __device__ __forceinline__ float process_colwise_stage( "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" "max.xorsign.abs.bf16x2 %1, %1, x01; \n" "}\n" - : "=l"(reinterpret_cast(in)), "+r"(reinterpret_cast(thread_amax_2x)) + : "=l"(reinterpret_cast(in_colwise_IType4[i/4])), + "+r"(reinterpret_cast(thread_amax_2x)) : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE) ); } else { @@ -556,7 +556,8 @@ __device__ __forceinline__ float process_colwise_stage( "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" "max.xorsign.abs.f16x2 %1, %1, x01; \n" "}\n" - : "=l"(reinterpret_cast(in)), "+r"(reinterpret_cast(thread_amax_2x)) + : "=l"(reinterpret_cast(in_colwise_IType4[i/4])), + "+r"(reinterpret_cast(thread_amax_2x)) : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE) ); } @@ -627,9 +628,7 @@ __device__ __forceinline__ float process_colwise_stage( #pragma unroll for (int i = 0; i < SCALE_DIM_Y; i += 4) { OType4 out; - const IType4& in = *reinterpret_cast(&in_colwise_IType[i]); - - ptx::mul_cvt_4x(out, in, block_scale_inverse_f16); + ptx::mul_cvt_4x(out, in_colwise_IType4[i/4], block_scale_inverse_f16); const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_colwise_data_sh[shmem_offset_elt]); @@ -696,18 +695,36 @@ __device__ __forceinline__ float process_rowwise_stage( float in_compute_rowwise[SCALE_DIM_X]; Vec in_cached[WAVES]; Vec in_IType[WAVES]; + IType4 in_IType4[WAVES]; if constexpr (NON_FP32_CAST_ONLY) { IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll + #pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + if constexpr (USE_FAST_MATH) { + const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_rowwise]); + // Load 4x elts S2R and find amax + asm volatile( + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.bf16x2 %1, %1, x01; \n" + "}\n" + : "+l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr) + ); + } else { + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + #pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } } } thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); @@ -803,9 +820,7 @@ __device__ __forceinline__ float process_rowwise_stage( if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { uint32_t out_4x = 0; OType4& out = *reinterpret_cast(&out_4x); - const IType4& in = *reinterpret_cast(&in_IType[w].data.elt[0]); - - ptx::mul_cvt_4x(out, in, block_scale_inverse_f16); + ptx::mul_cvt_4x(out, in_IType4[w], block_scale_inverse_f16); const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); From 8ceeed05c686193bce2b2e56a953a6b0c0c477d1 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 17 Mar 2026 23:40:30 +0000 Subject: [PATCH 34/51] Ready for merge Signed-off-by: Oleg Goncharov --- tests/cpp/CMakeLists.txt | 3 +- tests/cpp/operator/CMakeLists.txt | 58 +- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 30 +- transformer_engine/common/CMakeLists.txt | 1 - .../common/cast/dispatch/dequantize.cuh | 52 +- .../common/cast/dispatch/gated.cuh | 304 ++++---- .../common/cast/dispatch/quantize.cuh | 648 +++++++++--------- 7 files changed, 546 insertions(+), 550 deletions(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 2092975b2a..6f4f163f08 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -8,8 +8,7 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) else () - # set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) - set(CMAKE_CUDA_ARCHITECTURES 100) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) endif() endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index cf3e556556..5e73675f4f 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -3,36 +3,36 @@ # See LICENSE for license information. add_executable(test_operator - # test_cast.cu - # test_cast_current_scaling.cu - # test_cast_dbias.cu - # test_cast_dbias_dgelu.cu - # test_cast_gated_swiglu.cu - # test_cast_mxfp8_gated_swiglu.cu - # test_qdq.cu - # test_cast_mxfp8.cu + test_cast.cu + test_cast_current_scaling.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu + test_cast_gated_swiglu.cu + test_cast_mxfp8_gated_swiglu.cu + test_qdq.cu + test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu - # test_cast_nvfp4_transpose.cu - # test_cast_float8blockwise.cu - # test_dequantize_mxfp8.cu - # test_transpose.cu - # test_cast_transpose.cu - # test_cast_transpose_current_scaling.cu - # test_cast_transpose_dbias.cu - # test_cast_transpose_dbias_dgelu.cu - # test_cast_transpose_dgeglu.cu - # test_act.cu - # test_normalization.cu - # test_normalization_mxfp8.cu - # test_memset.cu - # test_splits_to_offsets.cu - # test_multi_cast_transpose.cu - # test_multi_padding.cu - # test_multi_unpadding.cu - # test_causal_softmax.cu - # test_swizzle.cu - # test_swap_first_dims.cu - # test_grouped_gemm.cu + test_cast_nvfp4_transpose.cu + test_cast_float8blockwise.cu + test_dequantize_mxfp8.cu + test_transpose.cu + test_cast_transpose.cu + test_cast_transpose_current_scaling.cu + test_cast_transpose_dbias.cu + test_cast_transpose_dbias_dgelu.cu + test_cast_transpose_dgeglu.cu + test_act.cu + test_normalization.cu + test_normalization_mxfp8.cu + test_memset.cu + test_splits_to_offsets.cu + test_multi_cast_transpose.cu + test_multi_padding.cu + test_multi_unpadding.cu + test_causal_softmax.cu + test_swizzle.cu + test_swap_first_dims.cu + test_grouped_gemm.cu ../test_common.cu) # Find required packages diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 5dbad0e115..0b63c7510a 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -668,21 +668,21 @@ std::vector scaling_directions = { // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - {SAME_BOTH_DIMS, 1, 8192,7168}, - {VARYING_FIRST_DIM, 6, 8192,7168, 128,256,384,1024,2304,4096}, - {VARYING_FIRST_DIM, 6, 16*8192,7168, 128,256,384,1024,2304,4096}, - // {SAME_BOTH_DIMS, 1, 128,128}, - // {SAME_BOTH_DIMS, 2, 256,128}, - // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, - // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - // {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, - // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - // // Empty tensor in the middle of the group must not terminate the persistent work loop. - // {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, - // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + // {SAME_BOTH_DIMS, 1, 8192,7168}, + // {VARYING_FIRST_DIM, 6, 8192,7168, 128,256,384,1024,2304,4096}, + // {VARYING_FIRST_DIM, 6, 16*8192,7168, 128,256,384,1024,2304,4096}, + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + // Empty tensor in the middle of the group must not terminate the persistent work loop. + {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 737b3d4108..b3d48f68bd 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -384,7 +384,6 @@ endforeach() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --generate-line-info") # Add source code mapping into the profiler output # Number of parallel build jobs if($ENV{MAX_JOBS}) diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index db2ad285a8..81304981d3 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -22,32 +22,32 @@ namespace transformer_engine { namespace dispatch { inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - // CheckInputTensor(input, "cast_input"); - // CheckOutputTensor(*output, "cast_output"); - - // switch (input.scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); - // NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); - // NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); - // fp8::dequantize(input, output, stream); - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // if (is_supported_by_CC_100()) { - // mxfp8::dequantize(input, output, stream); - // } else { - // NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - // } - // break; - // } - // case NVTE_NVFP4_1D_SCALING: { - // nvfp4::dequantize(input, output, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - // } + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->dtype()), "Output must be in higher precision."); + NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); + fp8::dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + mxfp8::dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + nvfp4::dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } } } // namespace dispatch diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index c2087533a6..06e8f0e306 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -25,164 +25,164 @@ namespace dispatch { template void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - // const Tensor input = *convertNVTETensorCheck(nvte_input); - // Tensor *output = convertNVTETensorCheck(nvte_output); - - // CheckInputTensor(input, "input"); - // CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - // const size_t rows = input.flat_first_dim(); - // const size_t cols = input.flat_last_dim() / 2; - - // NVTE_CHECK(input.flat_last_dim() % 2 == 0, - // "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - // input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - // NVTE_CHECK(output->flat_last_dim() == cols, - // "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", - // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // switch (output->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - // if (use_tma_kernels) { - // Tensor dummy_grad_tensor; - // fp8::cast_gated_tma(input, dummy_grad_tensor, - // output, p, stream); - // } else { - // fp8::cast_gated_fwd(input, output, p, stream); - // } - // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // // FP8 kernel only populates row-wise data, so perform - // // transpose separately if needed - // Tensor transpose_in, transpose_out, dummy; - // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_in.data.dptr = output->data.dptr; - // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - // transpose_in.data.dtype = output->data.dtype; - // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_out.data.dptr = output->columnwise_data.dptr; - // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - // transpose_out.data.dtype = output->data.dtype; - // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // NVTE_CHECK(cols % 32 == 0, - // "Invalid input shape. Expected the last dimension to be " - // "divisible by 32, but got ", - // cols, "."); - // if (output->has_data()) { - // NVTE_CHECK(is_fp8_dtype(output->data.dtype), - // "The type of the output tensor should be FP8."); - // } - // if (output->has_columnwise_data()) { - // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - // "The type of the columnwise output tensor should be FP8."); - // } - // NVTE_CHECK(is_supported_by_CC_100(), - // "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - // Tensor dummy_grad_tensor; - // mxfp8::quantize_gated(input, dummy_grad_tensor, - // output, p, stream); - // break; - // } - // default: - // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - // } + const Tensor input = *convertNVTETensorCheck(nvte_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim() / 2; + + NVTE_CHECK(input.flat_last_dim() % 2 == 0, + "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols, + "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + Tensor dummy_grad_tensor; + fp8::cast_gated_tma(input, dummy_grad_tensor, + output, p, stream); + } else { + fp8::cast_gated_fwd(input, output, p, stream); + } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + Tensor dummy_grad_tensor; + mxfp8::quantize_gated(input, dummy_grad_tensor, + output, p, stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } } template void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { - // const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); - // const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); - // Tensor *output = convertNVTETensorCheck(nvte_output); - - // CheckInputTensor(grad, "grad"); - // CheckInputTensor(gated_input, "gated_input"); - // CheckOutputTensor(*output, "output", /*allow_empty=*/false); - - // NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", - // gated_input.flat_last_dim(), "."); - - // const size_t rows = gated_input.flat_first_dim(); - // const size_t cols = gated_input.flat_last_dim() / 2; - - // NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); - // NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); - - // NVTE_CHECK(grad.flat_first_dim() == rows, - // "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", - // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - // NVTE_CHECK(grad.flat_last_dim() == cols, - // "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", - // grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); - - // NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", - // rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - // NVTE_CHECK(output->flat_last_dim() == cols * 2, - // "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", - // output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - // NVTE_CHECK(gated_input.shape() == output->shape(), - // "Gated input and output shapes must match. Input shape: ", gated_input.shape(), - // ", output shape: ", output->shape(), "."); - - // switch (output->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); - // if (use_tma_kernels) { - // fp8::cast_gated_tma(gated_input, grad, output, p, - // stream); - // } else { - // fp8::cast_gated_bwd(gated_input, grad, output, p, stream); - // } - // if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { - // // FP8 kernel only populates row-wise data, so perform - // // transpose separately if needed - // Tensor transpose_in, transpose_out, dummy; - // transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_in.data.dptr = output->data.dptr; - // transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; - // transpose_in.data.dtype = output->data.dtype; - // transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; - // transpose_out.data.dptr = output->columnwise_data.dptr; - // transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; - // transpose_out.data.dtype = output->data.dtype; - // detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); - // } - // break; - // } - // case NVTE_MXFP8_1D_SCALING: { - // NVTE_CHECK(cols % 32 == 0, - // "Invalid input shape. Expected the last dimension to be " - // "divisible by 32, but got ", - // cols, "."); - // if (output->has_data()) { - // NVTE_CHECK(is_fp8_dtype(output->data.dtype), - // "The type of the output tensor should be FP8."); - // } - // if (output->has_columnwise_data()) { - // NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), - // "The type of the columnwise output tensor should be FP8."); - // } - // NVTE_CHECK(is_supported_by_CC_100(), - // "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); - - // mxfp8::quantize_gated(gated_input, grad, output, p, - // stream); - // break; - // } - // default: - // NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); - // } + const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(grad, "grad"); + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + gated_input.flat_last_dim(), "."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + + NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision."); + NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match."); + + NVTE_CHECK(grad.flat_first_dim() == rows, + "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + NVTE_CHECK(grad.flat_last_dim() == cols, + "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols * 2, + "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(gated_input.shape() == output->shape(), + "Gated input and output shapes must match. Input shape: ", gated_input.shape(), + ", output shape: ", output->shape(), "."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + fp8::cast_gated_tma(gated_input, grad, output, p, + stream); + } else { + fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + } + if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) { + // FP8 kernel only populates row-wise data, so perform + // transpose separately if needed + Tensor transpose_in, transpose_out, dummy; + transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_in.data.dptr = output->data.dptr; + transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()}; + transpose_in.data.dtype = output->data.dtype; + transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + transpose_out.data.dptr = output->columnwise_data.dptr; + transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()}; + transpose_out.data.dtype = output->data.dtype; + detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + mxfp8::quantize_gated(gated_input, grad, output, p, + stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } } } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9c17fcd98d..8d985f64f3 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -59,110 +59,109 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, // Dispatch to quantization kernel depending on data format switch (output_tensor->scaling_mode) { - // case NVTE_DELAYED_TENSOR_SCALING: { - // const Tensor *dummy_input_tensor = nullptr; - // Tensor *dummy_dbias_tensor = nullptr; - // Tensor *dummy_workspace_tensor = nullptr; - // if (output_tensor->has_columnwise_data()) { - // NVTE_CHECK(output_tensor->has_data(), - // "Quantizing in only the columnwise direction not supported yet!"); - // if constexpr (!IS_ACT) { - // cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - // } else { - // cast_transpose_fused( - // *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // } - // } else if (output_tensor->has_data()) { - // fp8::quantize( - // *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - // dummy_workspace_tensor, stream); - // } - // break; - // } + case NVTE_DELAYED_TENSOR_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_ACT) { + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + break; + } case NVTE_MXFP8_1D_SCALING: { const Tensor *dummy_input_tensor = nullptr; Tensor *dummy_dbias_tensor = nullptr; Tensor *dummy_workspace_tensor = nullptr; - // mxfp8::quantize( - mxfp8::quantize( + mxfp8::quantize( *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, dummy_workspace_tensor, stream); break; } - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*input_tensor, "input"); - // CheckOutputTensor(*output_tensor, "output", false); - - // // Choose kernel - // int32_t rows = input_tensor->flat_first_dim(); - // int32_t cols = input_tensor->flat_last_dim(); - // auto dtype = input_tensor->dtype(); - // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - // (cols % 32 == 0) && output_tensor->has_data(); - - // // Launch NVFP4 quantize kernel - // if (use_optimized_kernel) { - // if (quant_config_cpp.nvfp4_2d_quantization) { - // nvfp4::quantize_transpose( - // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } else { - // nvfp4::quantize_transpose( - // *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // } - // } else { - // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - // : output_tensor->columnwise_amax; - // quantize_transpose_vector_blockwise_fp4( - // /*input=*/input_tensor->data, /*global_amax=*/global_amax, - // /*scale_inv=*/output_tensor->scale_inv, - // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - // /*swizzled_scale=*/false, - // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - // /*rng_state=*/quant_config_cpp.rng_state, - // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - // } - // break; - // } - // case NVTE_BLOCK_SCALING_2D: { - // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // quantize_transpose_square_blockwise( - // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, - // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - // /*noop_tensor=*/noop_tensor->data, stream); - // break; - // } - // case NVTE_BLOCK_SCALING_1D: { - // // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. - // NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); - // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // float epsilon = quant_config_cpp.amax_epsilon; - // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - // if (output_tensor->has_data()) { - // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - // } - // if (output_tensor->has_columnwise_data()) { - // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - // } - // quantize_transpose_vector_blockwise( - // input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - // break; - // } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } @@ -172,141 +171,141 @@ template (quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), - // "Either rowwise or columnwise output data need to be allocated."); - - // // Dispatch to quantization kernel depending on data format - // switch (output_tensor->scaling_mode) { - // // case NVTE_DELAYED_TENSOR_SCALING: { - // // if (output_tensor->has_columnwise_data()) { - // // NVTE_CHECK(output_tensor->has_data(), - // // "Quantizing in only the columnwise direction not supported yet!"); - // // if constexpr (!IS_DBIAS && !IS_DACT) { - // // cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); - // // } else { - // // cast_transpose_fused( - // // *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); - // // } - // // } else if (output_tensor->has_data()) { - // // fp8::quantize( - // // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // // stream); - // // } - // // break; - // // } - // // case NVTE_MXFP8_1D_SCALING: { - // // mxfp8::quantize( - // // *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // // stream); - // // break; - // // } - // // case NVTE_NVFP4_1D_SCALING: { - // // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // // "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); - - // // // Check tensors - // // CheckNoopTensor(*noop_tensor, "cast_noop"); - // // CheckInputTensor(*grad_tensor, "input"); - // // CheckOutputTensor(*output_tensor, "output", false); - - // // // Choose kernel - // // int32_t rows = grad_tensor->flat_first_dim(); - // // int32_t cols = grad_tensor->flat_last_dim(); - // // auto dtype = grad_tensor->dtype(); - // // bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - // // (cols % 32 == 0) && output_tensor->has_data(); - - // // // Launch NVFP4 quantize kernel - // // if (use_optimized_kernel) { - // // if (quant_config_cpp.nvfp4_2d_quantization) { - // // nvfp4::quantize_transpose( - // // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // // } else { - // // nvfp4::quantize_transpose( - // // *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - // // } - // // } else { - // // auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - // // : output_tensor->columnwise_amax; - // // quantize_transpose_vector_blockwise_fp4( - // // /*input=*/grad_tensor->data, /*global_amax=*/global_amax, - // // /*scale_inv=*/output_tensor->scale_inv, - // // /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - // // /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - // // /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - // // /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - // // /*swizzled_scale=*/false, - // // /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - // // /*rng_state=*/quant_config_cpp.rng_state, - // // /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - // // /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - // // } - // // break; - // // } - // // case NVTE_BLOCK_SCALING_2D: { - // // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - // // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); - // // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // // float epsilon = quant_config_cpp.amax_epsilon; - // // quantize_transpose_square_blockwise( - // // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // // output_tensor->data, output_tensor->columnwise_data, epsilon, - // // /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - // // /*noop_tensor=*/noop_tensor->data, stream); - // // break; - // // } - // // case NVTE_BLOCK_SCALING_1D: { - // // // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. - // // NVTE_CHECK((!IS_DBIAS && !IS_DACT), - // // "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); - // // bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - // // float epsilon = quant_config_cpp.amax_epsilon; - // // FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - // // FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - // // if (output_tensor->has_data()) { - // // rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - // // } - // // if (output_tensor->has_columnwise_data()) { - // // columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - // // } - // // quantize_transpose_vector_blockwise( - // // grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - // // output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - // // columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - // // break; - // // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *grad_tensor = convertNVTETensorCheck(grad); + const Tensor *input_tensor = convertNVTETensor(input); + + Tensor *output_tensor = convertNVTETensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT) { + cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*grad_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = grad_tensor->flat_first_dim(); + int32_t cols = grad_tensor->flat_last_dim(); + auto dtype = grad_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } } // Host-aware and not graph-safe: group quantization with split section info from the host. @@ -315,64 +314,64 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // const Tensor *input_tensor = convertNVTETensorCheck(input); - // std::vector output_tensors; - // for (size_t i = 0; i < num_tensors; ++i) { - // output_tensors.push_back(convertNVTETensorCheck(outputs[i])); - // } - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Check for unsupported options - // if (quant_config_cpp.stochastic_rounding) { - // NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, - // "Stochastic rounding is only supported for NVFP4 quantization."); - // } - - // // Take the scaling mode of the first output tensor - // auto scaling_mode = output_tensors[0]->scaling_mode; - - // // Dispatch to quantization kernel depending on data format - // switch (scaling_mode) { - // case NVTE_NVFP4_1D_SCALING: { - // NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); - - // // Check tensors - // CheckNoopTensor(*noop_tensor, "cast_noop"); - // CheckInputTensor(*input_tensor, "input"); - // // Skip checking output tensor list - // // output list here is allowed to have empty tensor - - // // Choose kernel - // int32_t rows = input_tensor->flat_first_dim(); - // int32_t cols = input_tensor->flat_last_dim(); - // auto dtype = input_tensor->dtype(); - - // NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - // "2D quantization is not supported for group quantize."); - - // // Launch NVFP4 group quantize kernel - // nvfp4::group_quantize_transpose( - // *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, - // &quant_config_cpp, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - // } + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + std::vector output_tensors; + for (size_t i = 0; i < num_tensors; ++i) { + output_tensors.push_back(convertNVTETensorCheck(outputs[i])); + } + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + // Take the scaling mode of the first output tensor + auto scaling_mode = output_tensors[0]->scaling_mode; + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + // Skip checking output tensor list + // output list here is allowed to have empty tensor + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "2D quantization is not supported for group quantize."); + + // Launch NVFP4 group quantize kernel + nvfp4::group_quantize_transpose( + *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors, + &quant_config_cpp, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } } template @@ -408,8 +407,7 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor // Dispatch to quantization kernel depending on data format switch (scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - // mxfp8::group_quantize( - mxfp8::group_quantize( + mxfp8::group_quantize( input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, &quant_config_cpp, stream); break; @@ -424,40 +422,40 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - // using namespace detail; - - // NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); - - // const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); - // const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); - // GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - // GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); - // Tensor *workspace_tensor = convertNVTETensor(workspace); - - // // Quantization config - // QuantizationConfig quant_config_cpp; - // if (quant_config != nullptr) { - // quant_config_cpp = *reinterpret_cast(quant_config); - // } - - // // Noop flag - // Tensor dummy_tensor; - // Tensor *noop_tensor = &dummy_tensor; - // if (quant_config_cpp.noop_tensor != nullptr) { - // noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - // } - - // // Dispatch to quantization kernel depending on data format - // switch (scaling_mode) { - // case NVTE_MXFP8_1D_SCALING: { - // mxfp8::group_quantize( - // grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - // &quant_config_cpp, stream); - // break; - // } - // default: - // NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); - // } + using namespace detail; + + NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); + + const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); + const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); + GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Dispatch to quantization kernel depending on data format + switch (scaling_mode) { + case NVTE_MXFP8_1D_SCALING: { + mxfp8::group_quantize( + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + &quant_config_cpp, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + "."); + } } } // namespace dispatch From f119d1f97718d1691c4010e0370bc573f6071dad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 23:46:06 +0000 Subject: [PATCH 35/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/cast/core/common.cuh | 43 +- .../cast/mxfp8/group_quantize_mxfp8.cuh | 585 +++++++++--------- transformer_engine/common/common.h | 34 +- transformer_engine/common/util/ptx.cuh | 128 ++-- 4 files changed, 404 insertions(+), 386 deletions(-) diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 16db7ae856..53e72e42b8 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -30,28 +30,27 @@ enum ShapeRepresentation { VARYING_BOTH_DIMS = 3 }; - -#define TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(SHAPE_REP, SHAPE, ...) \ - switch (SHAPE_REP) { \ - case ShapeRepresentation::SAME_BOTH_DIMS: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::SAME_BOTH_DIMS; \ - { __VA_ARGS__ } \ - } break; \ - case ShapeRepresentation::VARYING_FIRST_DIM: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_FIRST_DIM; \ - { __VA_ARGS__ } \ - } break; \ - case ShapeRepresentation::VARYING_LAST_DIM: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_LAST_DIM; \ - { __VA_ARGS__ } \ - } break; \ - case ShapeRepresentation::VARYING_BOTH_DIMS: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_BOTH_DIMS; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Unsupported grouped tensor shape representation."); \ - } \ +#define TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(SHAPE_REP, SHAPE, ...) \ + switch (SHAPE_REP) { \ + case ShapeRepresentation::SAME_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::SAME_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_FIRST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_FIRST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_LAST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_LAST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported grouped tensor shape representation."); \ + } \ } inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index f93b915fe1..4f7e77c65a 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -85,10 +85,10 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 template -__device__ __forceinline__ size_t get_current_tensor_id( - const size_t num_tensors, const size_t current_offset, const size_t block_Y, - const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr) { +__device__ __forceinline__ size_t +get_current_tensor_id(const size_t num_tensors, const size_t current_offset, const size_t block_Y, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr) { if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS) { const size_t current_row = block_Y * CHUNK_DIM_Y; const size_t rows_per_tensor = first_logical_dim / num_tensors; @@ -112,9 +112,9 @@ __device__ __forceinline__ size_t get_current_tensor_id( } template -__device__ __forceinline__ size_t get_tensor_rows_num( - const size_t tensor_id, const size_t first_logical_dim, - const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { +__device__ __forceinline__ size_t +get_tensor_rows_num(const size_t tensor_id, const size_t first_logical_dim, + const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { size_t rows_num = 0; if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || SHAPE_REP == ShapeRepresentation::VARYING_LAST_DIM) { @@ -133,8 +133,8 @@ __device__ __forceinline__ size_t get_tensor_rows_num( const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) { switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: - return get_tensor_rows_num( - tensor_id, first_logical_dim, first_dims_ptr, num_tensors); + return get_tensor_rows_num(tensor_id, first_logical_dim, + first_dims_ptr, num_tensors); case ShapeRepresentation::VARYING_FIRST_DIM: return get_tensor_rows_num( tensor_id, first_logical_dim, first_dims_ptr, num_tensors); @@ -149,9 +149,9 @@ __device__ __forceinline__ size_t get_tensor_rows_num( } template -__device__ __forceinline__ size_t get_tensor_cols_num( - const size_t tensor_id, const size_t last_logical_dim, - const int64_t *const __restrict__ last_dims_ptr) { +__device__ __forceinline__ size_t +get_tensor_cols_num(const size_t tensor_id, const size_t last_logical_dim, + const int64_t *const __restrict__ last_dims_ptr) { size_t cols_num = 0; if constexpr (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM) { @@ -172,14 +172,14 @@ __device__ __forceinline__ size_t get_tensor_cols_num( const int64_t *const __restrict__ last_dims_ptr) { switch (shape_rep) { case ShapeRepresentation::SAME_BOTH_DIMS: - return get_tensor_cols_num( - tensor_id, last_logical_dim, last_dims_ptr); + return get_tensor_cols_num(tensor_id, last_logical_dim, + last_dims_ptr); case ShapeRepresentation::VARYING_FIRST_DIM: return get_tensor_cols_num( tensor_id, last_logical_dim, last_dims_ptr); case ShapeRepresentation::VARYING_LAST_DIM: - return get_tensor_cols_num( - tensor_id, last_logical_dim, last_dims_ptr); + return get_tensor_cols_num(tensor_id, last_logical_dim, + last_dims_ptr); case ShapeRepresentation::VARYING_BOTH_DIMS: return get_tensor_cols_num( tensor_id, last_logical_dim, last_dims_ptr); @@ -197,9 +197,11 @@ struct JobDescriptor { __host__ __device__ __forceinline__ constexpr JobDescriptor() = default; - __host__ __device__ __forceinline__ constexpr JobDescriptor( - const size_t block_id_, const size_t block_global_offset_, const size_t tensor_id_, - const size_t rows_, const size_t cols_) + __host__ __device__ __forceinline__ constexpr JobDescriptor(const size_t block_id_, + const size_t block_global_offset_, + const size_t tensor_id_, + const size_t rows_, + const size_t cols_) : block_id(block_id_), block_global_offset(block_global_offset_), tensor_id(tensor_id_), @@ -219,9 +221,8 @@ struct BlockDescriptor { __host__ __device__ __forceinline__ constexpr BlockDescriptor() = default; __host__ __device__ __forceinline__ constexpr BlockDescriptor( - const size_t tensor_base_, const size_t block_id_in_current_tensor_, - const size_t block_id_Y_, const size_t block_id_X_, const size_t block_offset_Y_, - const size_t block_offset_X_) + const size_t tensor_base_, const size_t block_id_in_current_tensor_, const size_t block_id_Y_, + const size_t block_id_X_, const size_t block_offset_Y_, const size_t block_offset_X_) : tensor_base(tensor_base_), block_id_in_current_tensor(block_id_in_current_tensor_), block_id_Y(block_id_Y_), @@ -234,12 +235,10 @@ template __device__ __forceinline__ JobDescriptor decode_job( const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, const size_t work_blocks_X, const int32_t ctaid_X, const int32_t ctaid_Y, - const int64_t *const __restrict__ offsets_ptr, - const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr) { - constexpr bool is_single_tensor = - (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || - SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); + constexpr bool is_single_tensor = (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); const size_t block_id = ctaid_Y * work_blocks_X + ctaid_X; const size_t block_global_offset = is_single_tensor ? (ctaid_Y * CHUNK_DIM_Y * last_logical_dim + ctaid_X * CHUNK_DIM_X) @@ -253,7 +252,8 @@ __device__ __forceinline__ JobDescriptor decode_job( } template -__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, const size_t total_work_blocks, +__device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, + const size_t total_work_blocks, const int64_t *const __restrict__ offsets_ptr) { const bool is_valid = (job.block_id < total_work_blocks); if (!is_valid) { @@ -286,9 +286,11 @@ __device__ __forceinline__ bool job_has_work(const JobDescriptor &job) { return job.rows != 0 && job.cols != 0; } -__device__ __forceinline__ void advance_to_next_job( - bool &job_finished, int32_t &ctaid_X, int32_t &ctaid_Y, size_t &static_next_block_id, - const size_t static_block_stride, const size_t total_work_blocks, const size_t work_blocks_X) { +__device__ __forceinline__ void advance_to_next_job(bool &job_finished, int32_t &ctaid_X, + int32_t &ctaid_Y, size_t &static_next_block_id, + const size_t static_block_stride, + const size_t total_work_blocks, + const size_t work_blocks_X) { if constexpr (PERSISTENT) { if (static_next_block_id < total_work_blocks) { ctaid_X = static_cast(static_next_block_id % work_blocks_X); @@ -303,11 +305,10 @@ __device__ __forceinline__ void advance_to_next_job( } template -__device__ __forceinline__ BlockDescriptor decode_block( - const JobDescriptor &job, const int64_t *const __restrict__ offsets_ptr) { - constexpr bool is_single_tensor = - (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || - SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); +__device__ __forceinline__ BlockDescriptor +decode_block(const JobDescriptor &job, const int64_t *const __restrict__ offsets_ptr) { + constexpr bool is_single_tensor = (SHAPE_REP == ShapeRepresentation::SAME_BOTH_DIMS || + SHAPE_REP == ShapeRepresentation::VARYING_FIRST_DIM); const size_t CHUNK_DIM_X_ = CHUNK_DIM_X; const size_t blocks_X_num_in_current_tensor = DIVUP(job.cols, CHUNK_DIM_X_); const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[job.tensor_id]); @@ -361,21 +362,24 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te } template -__global__ void __launch_bounds__(1) update_tma_descriptors( - const __grid_constant__ CUtensorMap base_tensor_map_input, - const __grid_constant__ CUtensorMap base_tensor_map_act_input, - const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, - const IType *const __restrict__ input_data_ptr, - const IType *const __restrict__ act_input_data_ptr, - const OType *const __restrict__ output_rowwise_data_ptr, - const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep, - const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, - const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise, - const bool compute_dactivations) { +__global__ void __launch_bounds__(1) + update_tma_descriptors(const __grid_constant__ CUtensorMap base_tensor_map_input, + const __grid_constant__ CUtensorMap base_tensor_map_act_input, + const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap base_tensor_map_output_colwise, + const IType *const __restrict__ input_data_ptr, + const IType *const __restrict__ act_input_data_ptr, + const OType *const __restrict__ output_rowwise_data_ptr, + const OType *const __restrict__ output_colwise_data_ptr, + const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, + const bool colwise, const bool compute_dactivations) { const size_t tensor_id = blockIdx.x; - const size_t rows = get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); + const size_t rows = + get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors); const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr); const size_t offset_elts = offsets_ptr[tensor_id]; @@ -490,17 +494,18 @@ __device__ __forceinline__ float process_colwise_stage( constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING; - constexpr bool NON_FP32_CAST_ONLY = NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); + constexpr bool NON_FP32_CAST_ONLY = + NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; float thread_amax = 0.0f; float in_compute_colwise[BUFF_DIM_Y]; IType in_colwise_IType[BUFF_DIM_Y]; - IType4 in_colwise_IType4[BUFF_DIM_Y/4]; + IType4 in_colwise_IType4[BUFF_DIM_Y / 4]; if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - #pragma unroll +#pragma unroll for (int i = 0; i < BUFF_DIM_Y; i += 4) { const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_colwise]); @@ -508,58 +513,56 @@ __device__ __forceinline__ float process_colwise_stage( // Load 4x elts S2R and find amax if constexpr (std::is_same_v) { asm volatile( - "{\n" - ".reg.u32 base_offset, stride; \n\t" - "mov.u32 base_offset, %2; \n\t" - "mov.u32 stride, %3; \n\t" - ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" - "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" - "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" - "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" - "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" - ".reg.b16 x0,x1,x2,x3; \n\t" - "ld.shared.b16 x0, [ptr0]; \n\t" - "ld.shared.b16 x1, [ptr1]; \n\t" - "ld.shared.b16 x2, [ptr2]; \n\t" - "ld.shared.b16 x3, [ptr3]; \n\t" - "mov.b64 %0, {x0,x1,x2,x3}; \n\t" - ".reg.b32 x01,x23; \n\t" - "mov.b32 x01, {x0,x1}; \n\t" - "mov.b32 x23, {x2,x3}; \n\t" - "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" - "max.xorsign.abs.bf16x2 %1, %1, x01; \n" - "}\n" - : "=l"(reinterpret_cast(in_colwise_IType4[i/4])), - "+r"(reinterpret_cast(thread_amax_2x)) - : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE) - ); + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %2; \n\t" + "mov.u32 stride, %3; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "ld.shared.b16 x0, [ptr0]; \n\t" + "ld.shared.b16 x1, [ptr1]; \n\t" + "ld.shared.b16 x2, [ptr2]; \n\t" + "ld.shared.b16 x3, [ptr3]; \n\t" + "mov.b64 %0, {x0,x1,x2,x3}; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b32 x01, {x0,x1}; \n\t" + "mov.b32 x23, {x2,x3}; \n\t" + "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.bf16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in_colwise_IType4[i / 4])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE)); } else { asm volatile( - "{\n" - ".reg.u32 base_offset, stride; \n\t" - "mov.u32 base_offset, %2; \n\t" - "mov.u32 stride, %3; \n\t" - ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" - "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" - "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" - "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" - "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" - ".reg.b16 x0,x1,x2,x3; \n\t" - "ld.shared.b16 x0, [ptr0]; \n\t" - "ld.shared.b16 x1, [ptr1]; \n\t" - "ld.shared.b16 x2, [ptr2]; \n\t" - "ld.shared.b16 x3, [ptr3]; \n\t" - "mov.b64 %0, {x0,x1,x2,x3}; \n\t" - ".reg.b32 x01,x23; \n\t" - "mov.b32 x01, {x0,x1}; \n\t" - "mov.b32 x23, {x2,x3}; \n\t" - "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" - "max.xorsign.abs.f16x2 %1, %1, x01; \n" - "}\n" - : "=l"(reinterpret_cast(in_colwise_IType4[i/4])), - "+r"(reinterpret_cast(thread_amax_2x)) - : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE) - ); + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %2; \n\t" + "mov.u32 stride, %3; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "ld.shared.b16 x0, [ptr0]; \n\t" + "ld.shared.b16 x1, [ptr1]; \n\t" + "ld.shared.b16 x2, [ptr2]; \n\t" + "ld.shared.b16 x3, [ptr3]; \n\t" + "mov.b64 %0, {x0,x1,x2,x3}; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b32 x01, {x0,x1}; \n\t" + "mov.b32 x23, {x2,x3}; \n\t" + "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.f16x2 %1, %1, x01; \n" + "}\n" + : "=l"(reinterpret_cast(in_colwise_IType4[i / 4])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr), "r"(IN_SHMEM_STRIDE)); } } thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); @@ -625,36 +628,36 @@ __device__ __forceinline__ float process_colwise_stage( const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { - #pragma unroll +#pragma unroll for (int i = 0; i < SCALE_DIM_Y; i += 4) { OType4 out; - ptx::mul_cvt_4x(out, in_colwise_IType4[i/4], block_scale_inverse_f16); + ptx::mul_cvt_4x(out, in_colwise_IType4[i / 4], block_scale_inverse_f16); const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_colwise_data_sh[shmem_offset_elt]); + const uint32_t dst_smem_ptr = + __cvta_generic_to_shared(&out_colwise_data_sh[shmem_offset_elt]); asm volatile( - "{\n" - ".reg.u32 base_offset, stride; \n\t" - "mov.u32 base_offset, %0; \n\t" - "mov.u32 stride, %1; \n\t" - ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" - "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" - "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" - "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" - "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" - ".reg.b8 x0,x1,x2,x3; \n\t" - "mov.b32 {x0,x1,x2,x3}, %2; \n\t" - "st.shared.b8 [ptr0], x0; \n\t" - "st.shared.b8 [ptr1], x1; \n\t" - "st.shared.b8 [ptr2], x2; \n\t" - "st.shared.b8 [ptr3], x3; \n" - "}\n" - :: "r"(dst_smem_ptr), "r"(OUT_SHMEM_STRIDE), "r"(reinterpret_cast(out)) - ); + "{\n" + ".reg.u32 base_offset, stride; \n\t" + "mov.u32 base_offset, %0; \n\t" + "mov.u32 stride, %1; \n\t" + ".reg.u32 ptr0,ptr1,ptr2,ptr3; \n\t" + "mad.lo.u32 ptr0, 0, stride, base_offset; \n\t" + "mad.lo.u32 ptr1, 1, stride, base_offset; \n\t" + "mad.lo.u32 ptr2, 2, stride, base_offset; \n\t" + "mad.lo.u32 ptr3, 3, stride, base_offset; \n\t" + ".reg.b8 x0,x1,x2,x3; \n\t" + "mov.b32 {x0,x1,x2,x3}, %2; \n\t" + "st.shared.b8 [ptr0], x0; \n\t" + "st.shared.b8 [ptr1], x1; \n\t" + "st.shared.b8 [ptr2], x2; \n\t" + "st.shared.b8 [ptr3], x3; \n" + "}\n" ::"r"(dst_smem_ptr), + "r"(OUT_SHMEM_STRIDE), "r"(reinterpret_cast(out))); } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < SCALE_DIM_Y; ++i) { float in; if constexpr (NON_FP32_CAST_ONLY) { @@ -663,7 +666,7 @@ __device__ __forceinline__ float process_colwise_stage( in = in_compute_colwise[i]; } const float scaled_out = in * block_scale_inverse; - + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); } @@ -688,7 +691,8 @@ __device__ __forceinline__ float process_rowwise_stage( constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && COLWISE_SCALING; - constexpr bool NON_FP32_CAST_ONLY = NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); + constexpr bool NON_FP32_CAST_ONLY = + NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); const size_t shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; float thread_amax = 0.0f; @@ -699,7 +703,7 @@ __device__ __forceinline__ float process_rowwise_stage( if constexpr (NON_FP32_CAST_ONLY) { IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; @@ -708,20 +712,19 @@ __device__ __forceinline__ float process_rowwise_stage( const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_rowwise]); // Load 4x elts S2R and find amax asm volatile( - "{\n" - "ld.shared.b64 %0, [%2]; \n\t" - ".reg.b32 x01,x23; \n\t" - "mov.b64 {x01, x23}, %0; \n\t" - "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" - "max.xorsign.abs.bf16x2 %1, %1, x01; \n" - "}\n" - : "+l"(reinterpret_cast(in_IType4[w])), - "+r"(reinterpret_cast(thread_amax_2x)) - : "r"(src_smem_ptr) - ); + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.bf16x2 %1, %1, x01; \n" + "}\n" + : "+l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr)); } else { in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); } @@ -811,7 +814,7 @@ __device__ __forceinline__ float process_rowwise_stage( const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; @@ -819,14 +822,15 @@ __device__ __forceinline__ float process_rowwise_stage( if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { uint32_t out_4x = 0; - OType4& out = *reinterpret_cast(&out_4x); + OType4 &out = *reinterpret_cast(&out_4x); ptx::mul_cvt_4x(out, in_IType4[w], block_scale_inverse_f16); - const uint32_t dst_smem_ptr = __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); + const uint32_t dst_smem_ptr = + __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); } else { Vec out; - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { IType2 in; OType2 &out_pair = reinterpret_cast(out.data.elt[e]); @@ -851,16 +855,15 @@ __device__ __forceinline__ float process_rowwise_stage( template + ScalingType SCALING_TYPE, bool WITH_GEMM_SWIZZLED_SCALES, ShapeRepresentation SHAPE_REP, + bool USE_FAST_MATH> __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static, const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, - const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, - const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr, - const int64_t *const __restrict__ first_dims_ptr, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr, @@ -875,10 +878,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - constexpr bool ROWWISE_SCALING = (SCALING_TYPE == ScalingType::ROWWISE) - || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); - constexpr bool COLWISE_SCALING = (SCALING_TYPE == ScalingType::COLWISE) - || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool ROWWISE_SCALING = + (SCALING_TYPE == ScalingType::ROWWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool COLWISE_SCALING = + (SCALING_TYPE == ScalingType::COLWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); constexpr ShapeRepresentation shape_rep = SHAPE_REP; constexpr bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); @@ -963,9 +966,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. while (!job_finished) { // Decode CTA assignment into logical tensor coordinates and validate bounds. - const JobDescriptor current_job = decode_job( - num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, ctaid_X, ctaid_Y, - offsets_ptr, first_dims_ptr, last_dims_ptr); + const JobDescriptor current_job = + decode_job(num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, + ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); const bool current_job_is_valid = is_job_valid(current_job, total_work_blocks, offsets_ptr); if (!current_job_is_valid) { @@ -973,8 +976,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } if (!job_has_work(current_job)) { // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. - advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, - static_block_stride, total_work_blocks, work_blocks_X); + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, + total_work_blocks, work_blocks_X); continue; } @@ -1079,10 +1082,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; - prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, - global_offset_X, global_offset_Y, - next_prefetch_buff_offset, shmem_buff_size, barrier, - leading_thread); + prefetch_input_stage( + in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); } ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], @@ -1093,20 +1095,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t buff = buff_in; float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { - thread_amax = process_colwise_stage( - buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, - scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, - cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); + thread_amax = + process_colwise_stage( + buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, + cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); } if constexpr (ROWWISE_SCALING) { - thread_amax = process_rowwise_stage( - buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, - scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, - rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, - out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); + thread_amax = + process_rowwise_stage( + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, + scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, + out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); } __builtin_assume(block_amax >= 0); @@ -1167,8 +1171,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } } - advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, - static_block_stride, total_work_blocks, work_blocks_X); + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, + total_work_blocks, work_blocks_X); } if (amax_ptr != nullptr) { @@ -1203,8 +1207,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations if (use_fast_math) { NVTE_CHECK(input->dtype() == DType::kBFloat16 || input->dtype() == DType::kFloat16, "Fast math supports only BF16 and FP16 input types."); - NVTE_CHECK(!IS_DBIAS && !IS_DACT && !IS_ACT, - "Fast math does not support fused casts."); + NVTE_CHECK(!IS_DBIAS && !IS_DACT && !IS_ACT, "Fast math does not support fused casts."); } checkCuDriverContext(stream); @@ -1326,112 +1329,128 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } } - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input->dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype(), OType, - TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH(scaling_type, SCALING_TYPE, - TRANSFORMER_ENGINE_SWITCH_CONDITION(with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(shape_rep, SHAPE_REP, - TRANSFORMER_ENGINE_SWITCH_CONDITION(use_fast_math, USE_FAST_MATH, - { - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim, - last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, - output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, - first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, - last_logical_dim, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - // Update tensor descriptors before launching the kernel - if (!is_single_tensor) { - const IType *const input_dptr = reinterpret_cast(input->data.dptr); - - const IType *const act_input_dptr = - IS_DACT ? reinterpret_cast(activations->data.dptr) : nullptr; - - OType *const output_rowwise_dptr = - use_rowwise_scaling ? reinterpret_cast(output->data.dptr) : nullptr; - - OType *const output_colwise_dptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_data.dptr) + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input->dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH( + scaling_type, SCALING_TYPE, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH( + shape_rep, SHAPE_REP, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_fast_math, USE_FAST_MATH, + { + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, activations->data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, + BUFF_DIM_X, last_logical_dim, 0, + input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, + first_logical_dim, last_logical_dim, BUFF_DIM_Y, + BUFF_DIM_X, last_logical_dim, 0, + output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, + output->columnwise_data, first_logical_dim, + last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, + last_logical_dim, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = + (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = + (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = + (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = + (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + // Update tensor descriptors before launching the kernel + if (!is_single_tensor) { + const IType *const input_dptr = + reinterpret_cast(input->data.dptr); + + const IType *const act_input_dptr = + IS_DACT ? reinterpret_cast(activations->data.dptr) + : nullptr; + + OType *const output_rowwise_dptr = + use_rowwise_scaling ? reinterpret_cast(output->data.dptr) + : nullptr; + + OType *const output_colwise_dptr = + use_colwise_scaling + ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; - update_tma_descriptors<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, - output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, - last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, - use_rowwise_scaling, use_colwise_scaling, IS_DACT); - } - - auto kernel = - group_quantize_mxfp8_kernel; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - kernel<<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, - scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, work_blocks_X, - work_blocks_Y); - - if constexpr (IS_DBIAS) { - common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, - first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); - } - - NVTE_CHECK_CUDA(cudaGetLastError()); - } - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + update_tma_descriptors<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, input_dptr, act_input_dptr, + output_rowwise_dptr, output_colwise_dptr, shape_rep, num_tensors, + first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr, + last_dims_ptr, use_rowwise_scaling, use_colwise_scaling, IS_DACT); + } + + auto kernel = group_quantize_mxfp8_kernel< + IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, SCALING_TYPE, + WITH_GEMM_SWIZZLED_SCALES, SHAPE_REP, USE_FAST_MATH>; + + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, num_tensors, first_logical_dim, + last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, + scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, + amax_ptr, work_blocks_X, work_blocks_Y); + + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, + CHUNK_DIM_Y, stream); + } + + NVTE_CHECK_CUDA(cudaGetLastError()); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) } } // namespace mxfp8 diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index e2fa3549cc..2187b61d61 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -904,23 +904,23 @@ struct TypeInfo { { __VA_ARGS__ } \ } -#define TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH(SCALING_TYPE, SCALING_T, ...) \ - switch (SCALING_TYPE) { \ - case ScalingType::ROWWISE: { \ - constexpr ScalingType SCALING_T = ScalingType::ROWWISE; \ - { __VA_ARGS__ } \ - } break; \ - case ScalingType::COLWISE: { \ - constexpr ScalingType SCALING_T = ScalingType::COLWISE; \ - { __VA_ARGS__ } \ - } break; \ - case ScalingType::BIDIMENSIONAL: { \ - constexpr ScalingType SCALING_T = ScalingType::BIDIMENSIONAL; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Unsupported scaling type."); \ - } \ +#define TRANSFORMER_ENGINE_SCALING_TYPE_SWITCH(SCALING_TYPE, SCALING_T, ...) \ + switch (SCALING_TYPE) { \ + case ScalingType::ROWWISE: { \ + constexpr ScalingType SCALING_T = ScalingType::ROWWISE; \ + { __VA_ARGS__ } \ + } break; \ + case ScalingType::COLWISE: { \ + constexpr ScalingType SCALING_T = ScalingType::COLWISE; \ + { __VA_ARGS__ } \ + } break; \ + case ScalingType::BIDIMENSIONAL: { \ + constexpr ScalingType SCALING_T = ScalingType::BIDIMENSIONAL; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported scaling type."); \ + } \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index b9fb961345..c29b89c832 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1173,22 +1173,22 @@ __device__ __forceinline__ fp16 get_amax(fp16 a, fp16 b) { __device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, const bf16 scale) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( - "{\n\t" - ".reg.b16 x0,x1,x2,x3; \n\t" - "mov.b64 {x0,x1,x2,x3}, %1; \n\t" - ".reg.f32 y0,y1,y2,y3; \n\t" - "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" - ".reg.b16 z01, z23; \n\t" - "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" - "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" - "mov.b32 %0, {z01, z23}; \n" - "}\n" - : "=r"(reinterpret_cast(out)) - : "l"(reinterpret_cast(in)), - "h"(reinterpret_cast(scale))); + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); #else NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -1197,22 +1197,22 @@ __device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, con __device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in, const bf16 scale) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( - "{\n\t" - ".reg.b16 x0,x1,x2,x3; \n\t" - "mov.b64 {x0,x1,x2,x3}, %1; \n\t" - ".reg.f32 y0,y1,y2,y3; \n\t" - "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" - "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" - ".reg.b16 z01, z23; \n\t" - "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" - "cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" - "mov.b32 %0, {z01, z23}; \n" - "}\n" - : "=r"(reinterpret_cast(out)) - : "l"(reinterpret_cast(in)), - "h"(reinterpret_cast(scale))); + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.bf16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.bf16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); #else NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -1221,22 +1221,22 @@ __device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in, con __device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in, const fp16 scale) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( - "{\n\t" - ".reg.b16 x0,x1,x2,x3; \n\t" - "mov.b64 {x0,x1,x2,x3}, %1; \n\t" - ".reg.f32 y0,y1,y2,y3; \n\t" - "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" - "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" - "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" - "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" - ".reg.b16 z01, z23; \n\t" - "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" - "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" - "mov.b32 %0, {z01, z23}; \n" - "}\n" - : "=r"(reinterpret_cast(out)) - : "l"(reinterpret_cast(in)), - "h"(reinterpret_cast(scale))); + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e4m3x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); #else NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -1245,22 +1245,22 @@ __device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in, con __device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const fp16x4 &in, const fp16 scale) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( - "{\n\t" - ".reg.b16 x0,x1,x2,x3; \n\t" - "mov.b64 {x0,x1,x2,x3}, %1; \n\t" - ".reg.f32 y0,y1,y2,y3; \n\t" - "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" - "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" - "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" - "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" - ".reg.b16 z01, z23; \n\t" - "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" - "cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" - "mov.b32 %0, {z01, z23}; \n" - "}\n" - : "=r"(reinterpret_cast(out)) - : "l"(reinterpret_cast(in)), - "h"(reinterpret_cast(scale))); + "{\n\t" + ".reg.b16 x0,x1,x2,x3; \n\t" + "mov.b64 {x0,x1,x2,x3}, %1; \n\t" + ".reg.f32 y0,y1,y2,y3; \n\t" + "fma.rn.f32.f16 y0, x0, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y1, x1, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y2, x2, %2, 0f00000000; \n\t" + "fma.rn.f32.f16 y3, x3, %2, 0f00000000; \n\t" + ".reg.b16 z01, z23; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z01, y1, y0; \n\t" + "cvt.rn.satfinite.e5m2x2.f32 z23, y3, y2; \n\t" + "mov.b32 %0, {z01, z23}; \n" + "}\n" + : "=r"(reinterpret_cast(out)) + : "l"(reinterpret_cast(in)), + "h"(reinterpret_cast(scale))); #else NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+."); #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) From 6874935b7a91b3e5717a75bf6a937446f580f9c8 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 14:24:14 +0000 Subject: [PATCH 36/51] Uncommented test cases Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 0b63c7510a..e5c62282f5 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -639,15 +639,15 @@ void performTest(const ProcessingMethod processing_method, std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - // ProcessingMethod::CAST_DBIAS, - // ProcessingMethod::CAST_DBIAS_DACT, - // ProcessingMethod::CAST_DACT, - // ProcessingMethod::CAST_ACT, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, }; std::vector activation_kinds = { ActivationKind::Identity, - // ActivationKind::GeLU, + ActivationKind::GeLU, // ActivationKind::SiLU, // ActivationKind::ReLU, // ActivationKind::QGeLU, @@ -661,16 +661,13 @@ enum ScalingDirection { }; std::vector scaling_directions = { - // ScalingDirection::ROWWISE, - // ScalingDirection::COLWISE, + ScalingDirection::ROWWISE, + ScalingDirection::COLWISE, ScalingDirection::BOTH, }; // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - // {SAME_BOTH_DIMS, 1, 8192,7168}, - // {VARYING_FIRST_DIM, 6, 8192,7168, 128,256,384,1024,2304,4096}, - // {VARYING_FIRST_DIM, 6, 16*8192,7168, 128,256,384,1024,2304,4096}, {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, @@ -850,11 +847,9 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(activation_kinds), ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), - ::testing::Values(DType::kBFloat16), - ::testing::Values(DType::kFloat8E4M3), - ::testing::Values(true)), - // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); std::string name = to_string(method); From f985c01fdae852af57afa2c16ee28c4d68c63fe2 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 14:29:37 +0000 Subject: [PATCH 37/51] Added FP16 Fast math path to rowwise processing Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 4f7e77c65a..96b722d8f3 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -711,17 +711,31 @@ __device__ __forceinline__ float process_rowwise_stage( if constexpr (USE_FAST_MATH) { const uint32_t src_smem_ptr = __cvta_generic_to_shared(&in_sh[shmem_offset_rowwise]); // Load 4x elts S2R and find amax - asm volatile( - "{\n" - "ld.shared.b64 %0, [%2]; \n\t" - ".reg.b32 x01,x23; \n\t" - "mov.b64 {x01, x23}, %0; \n\t" - "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" - "max.xorsign.abs.bf16x2 %1, %1, x01; \n" - "}\n" - : "+l"(reinterpret_cast(in_IType4[w])), - "+r"(reinterpret_cast(thread_amax_2x)) - : "r"(src_smem_ptr)); + if constexpr (std::is_same_v) { + asm volatile( + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.bf16x2 %1, %1, x01; \n" + "}\n" + : "+l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr)); + else { + asm volatile( + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.f16x2 %1, %1, x01; \n" + "}\n" + : "+l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr)); + } } else { in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); #pragma unroll From 50685563577ce9f1afccfbae81ef0b7ae5b6b175 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 14:33:47 +0000 Subject: [PATCH 38/51] Refactoring Signed-off-by: Oleg Goncharov --- .../common/cast/core/common.cuh | 30 ------------------- transformer_engine/common/common.h | 23 ++++++++++++++ transformer_engine/common/utils.cuh | 7 +++++ 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 53e72e42b8..f150fa7981 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -23,36 +23,6 @@ namespace transformer_engine { namespace dispatch { namespace common { -enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, - VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, - VARYING_BOTH_DIMS = 3 -}; - -#define TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(SHAPE_REP, SHAPE, ...) \ - switch (SHAPE_REP) { \ - case ShapeRepresentation::SAME_BOTH_DIMS: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::SAME_BOTH_DIMS; \ - { __VA_ARGS__ } \ - } break; \ - case ShapeRepresentation::VARYING_FIRST_DIM: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_FIRST_DIM; \ - { __VA_ARGS__ } \ - } break; \ - case ShapeRepresentation::VARYING_LAST_DIM: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_LAST_DIM; \ - { __VA_ARGS__ } \ - } break; \ - case ShapeRepresentation::VARYING_BOTH_DIMS: { \ - constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_BOTH_DIMS; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Unsupported grouped tensor shape representation."); \ - } \ - } - inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { const size_t N = product(t->data.shape); const bool isFullTile = (N % elems_per_block == 0); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 2187b61d61..8db34b5756 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -923,6 +923,29 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_GROUP_TENSOR_SHAPE_REPRESENTATION_SWITCH(SHAPE_REP, SHAPE, ...) \ + switch (SHAPE_REP) { \ + case ShapeRepresentation::SAME_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::SAME_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_FIRST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_FIRST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_LAST_DIM: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_LAST_DIM; \ + { __VA_ARGS__ } \ + } break; \ + case ShapeRepresentation::VARYING_BOTH_DIMS: { \ + constexpr ShapeRepresentation SHAPE = ShapeRepresentation::VARYING_BOTH_DIMS; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported grouped tensor shape representation."); \ + } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 26549191a3..8c50e83926 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -928,6 +928,13 @@ using e8m0_t = uint8_t; enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 }; +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + template struct Numeric_Traits; From 6c945d6a36aea79dc8de6b6a9cac30284d85e325 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Mar 2026 14:34:56 +0000 Subject: [PATCH 39/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 795 +++++++++--------- 1 file changed, 401 insertions(+), 394 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 96b722d8f3..0f7b86b836 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -723,490 +723,497 @@ __device__ __forceinline__ float process_rowwise_stage( : "+l"(reinterpret_cast(in_IType4[w])), "+r"(reinterpret_cast(thread_amax_2x)) : "r"(src_smem_ptr)); - else { - asm volatile( - "{\n" - "ld.shared.b64 %0, [%2]; \n\t" - ".reg.b32 x01,x23; \n\t" - "mov.b64 {x01, x23}, %0; \n\t" - "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" - "max.xorsign.abs.f16x2 %1, %1, x01; \n" - "}\n" - : "+l"(reinterpret_cast(in_IType4[w])), - "+r"(reinterpret_cast(thread_amax_2x)) - : "r"(src_smem_ptr)); - } - } else { - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + else { + asm volatile( + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.f16x2 %1, %1, x01; \n" + "}\n" + : "+l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr)); + } + } else { + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); #pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } } } + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } - thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (std::is_same_v) { + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { #pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } } } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } } - if constexpr (!std::is_same_v) { - thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { + else { #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - Vec in; - Vec act_in; + Vec in; + Vec act_in; - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } + in.load_from(&in_sh[shmem_offset_rowwise]); if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; } } - } - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( - stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - } - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { - uint32_t out_4x = 0; - OType4 &out = *reinterpret_cast(&out_4x); - ptx::mul_cvt_4x(out, in_IType4[w], block_scale_inverse_f16); + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { + uint32_t out_4x = 0; + OType4 &out = *reinterpret_cast(&out_4x); + ptx::mul_cvt_4x(out, in_IType4[w], block_scale_inverse_f16); - const uint32_t dst_smem_ptr = - __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); - asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); - } else { - Vec out; + const uint32_t dst_smem_ptr = + __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); + } else { + Vec out; #pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NON_FP32_CAST_ONLY) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NON_FP32_CAST_ONLY) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - } - return thread_amax; -} + return thread_amax; + } -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( - const __grid_constant__ CUtensorMap tensor_map_input_static, - const __grid_constant__ CUtensorMap tensor_map_act_input_static, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, - const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, const size_t num_tensors, - const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, - const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, - e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, - float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr, - const size_t work_blocks_X, const size_t work_blocks_Y) { + template + __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, + e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, + const float *__restrict__ noop, float *const __restrict__ dbias_workspace, + float *const __restrict__ amax_ptr, const size_t work_blocks_X, const size_t work_blocks_Y) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } } - } - constexpr bool ROWWISE_SCALING = - (SCALING_TYPE == ScalingType::ROWWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); - constexpr bool COLWISE_SCALING = - (SCALING_TYPE == ScalingType::COLWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool ROWWISE_SCALING = + (SCALING_TYPE == ScalingType::ROWWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool COLWISE_SCALING = + (SCALING_TYPE == ScalingType::COLWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); - constexpr ShapeRepresentation shape_rep = SHAPE_REP; - constexpr bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + constexpr ShapeRepresentation shape_rep = SHAPE_REP; + constexpr bool is_single_tensor = + (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const bool leading_thread = (threadIdx.x == 0); + const bool leading_thread = (threadIdx.x == 0); - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - extern __shared__ unsigned char dynamic_shmem[]; - unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - constexpr size_t shmem_buff_size = (IS_DACT ? 2 : 1) * buff_size_aligned_in / BUFFS_NUM; + constexpr size_t shmem_buff_size = (IS_DACT ? 2 : 1) * buff_size_aligned_in / BUFFS_NUM; - float block_amax = 0.0f; + float block_amax = 0.0f; - __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; - // Initialize barriers shared by the entire CTA: - // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. - if (leading_thread) { + // Initialize barriers shared by the entire CTA: + // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. + if (leading_thread) { #pragma unroll - for (int buff = 0; buff < BUFFS_NUM; ++buff) { - ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); + } + ptx::fence_proxy_async_shared_cta(); } - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); - - const size_t total_work_blocks = work_blocks_X * work_blocks_Y; - const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; + __syncthreads(); - int IN_buff_readable_parity[BUFFS_NUM] = {0}; - int32_t ctaid_X = static_cast(blockIdx.x); - int32_t ctaid_Y = static_cast(blockIdx.y); - size_t static_next_block_id = 0; - size_t static_block_stride = 0; - // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. - if constexpr (PERSISTENT) { - if (launch_block_id >= total_work_blocks) { - return; - } - ctaid_X = static_cast(launch_block_id % work_blocks_X); - ctaid_Y = static_cast(launch_block_id / work_blocks_X); - static_block_stride = gridDim.x * gridDim.y; - static_next_block_id = launch_block_id + static_block_stride; - } - bool job_finished = false; - size_t last_acquired_tensor_id = num_tensors; - - // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. - while (!job_finished) { - // Decode CTA assignment into logical tensor coordinates and validate bounds. - const JobDescriptor current_job = - decode_job(num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, - ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - const bool current_job_is_valid = - is_job_valid(current_job, total_work_blocks, offsets_ptr); - if (!current_job_is_valid) { - break; - } - if (!job_has_work(current_job)) { - // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. - advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, - total_work_blocks, work_blocks_X); - continue; + const size_t total_work_blocks = work_blocks_X * work_blocks_Y; + const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; + + int IN_buff_readable_parity[BUFFS_NUM] = {0}; + int32_t ctaid_X = static_cast(blockIdx.x); + int32_t ctaid_Y = static_cast(blockIdx.y); + size_t static_next_block_id = 0; + size_t static_block_stride = 0; + // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. + if constexpr (PERSISTENT) { + if (launch_block_id >= total_work_blocks) { + return; + } + ctaid_X = static_cast(launch_block_id % work_blocks_X); + ctaid_Y = static_cast(launch_block_id / work_blocks_X); + static_block_stride = gridDim.x * gridDim.y; + static_next_block_id = launch_block_id + static_block_stride; } - - const size_t tensor_id = current_job.tensor_id; - const size_t rows = current_job.rows; - const size_t cols = current_job.cols; - const BlockDescriptor current_block = decode_block(current_job, offsets_ptr); - - const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); - const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); - - const size_t tensor_base = current_block.tensor_base; - const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) - ? static_cast(offsets_ptr[tensor_id]) - : tensor_base; - const size_t block_id_Y = current_block.block_id_Y; - const size_t block_id_X = current_block.block_id_X; - const size_t block_offset_Y = current_block.block_offset_Y; - const size_t block_offset_X = current_block.block_offset_X; - - e8m0_t *const scales_rowwise = - scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); - e8m0_t *const scales_colwise = - scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); - - const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - - const int dbias_offset_Y = block_id_Y; - const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; - - const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = is_single_tensor - ? tensor_map_output_rowwise_static - : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = is_single_tensor - ? tensor_map_output_colwise_static - : g_tensor_maps_output_colwise[tensor_id]; - - if (leading_thread && (!is_single_tensor) && (last_acquired_tensor_id != tensor_id)) { - fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&tensor_map_act_input); + bool job_finished = false; + size_t last_acquired_tensor_id = num_tensors; + + // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. + while (!job_finished) { + // Decode CTA assignment into logical tensor coordinates and validate bounds. + const JobDescriptor current_job = + decode_job(num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, + ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const bool current_job_is_valid = + is_job_valid(current_job, total_work_blocks, offsets_ptr); + if (!current_job_is_valid) { + break; } - if constexpr (ROWWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_rowwise); + if (!job_has_work(current_job)) { + // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, + static_block_stride, total_work_blocks, work_blocks_X); + continue; } - if constexpr (COLWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_colwise); + + const size_t tensor_id = current_job.tensor_id; + const size_t rows = current_job.rows; + const size_t cols = current_job.cols; + const BlockDescriptor current_block = decode_block(current_job, offsets_ptr); + + const size_t scale_stride_rowwise = + DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + + const size_t tensor_base = current_block.tensor_base; + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? static_cast(offsets_ptr[tensor_id]) + : tensor_base; + const size_t block_id_Y = current_block.block_id_Y; + const size_t block_id_X = current_block.block_id_X; + const size_t block_offset_Y = current_block.block_offset_Y; + const size_t block_offset_X = current_block.block_offset_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = is_single_tensor + ? tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = is_single_tensor + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; + + if (leading_thread && (!is_single_tensor) && (last_acquired_tensor_id != tensor_id)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); + } + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); + } + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); + } + last_acquired_tensor_id = tensor_id; } - last_acquired_tensor_id = tensor_id; - } - int buff_in = 0; + int buff_in = 0; - // Prime the pipeline with the first PREFETCH_STAGES slices of the current block. + // Prime the pipeline with the first PREFETCH_STAGES slices of the current block. #pragma unroll - for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { - const size_t buff = stage; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; - uint64_t *barrier = &IN_buff_readable_mbar[buff]; - prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, - global_offset_X, global_offset_Y, buff_offset, - shmem_buff_size, barrier, leading_thread); - } + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const size_t buff = stage; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + uint64_t *barrier = &IN_buff_readable_mbar[buff]; + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, + tensor_map_act_input, global_offset_X, global_offset_Y, + buff_offset, shmem_buff_size, barrier, leading_thread); + } - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { #pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } } - } - // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). + // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - if (stage < STAGES - PREFETCH_STAGES) { - const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; - const size_t next_prefetch_stage = stage + PREFETCH_STAGES; - const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; - - const size_t global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; - - uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; - prefetch_input_stage( - in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, global_offset_X, - global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); - } + for (int stage = 0; stage < STAGES; ++stage) { + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + if (stage < STAGES - PREFETCH_STAGES) { + const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const size_t next_prefetch_stage = stage + PREFETCH_STAGES; + const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + + const size_t global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; + + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + prefetch_input_stage( + in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); + } - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); - IN_buff_readable_parity[buff_in] ^= 1; - ptx::cp_async_bulk_wait_group_read(); + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); - const size_t buff = buff_in; - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_amax = - process_colwise_stage( - buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, - scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, - cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); - } + const size_t buff = buff_in; + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_amax = + process_colwise_stage( + buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, + cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); + } - if constexpr (ROWWISE_SCALING) { - thread_amax = - process_rowwise_stage( - buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, - scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, - rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, - out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); - } + if constexpr (ROWWISE_SCALING) { + thread_amax = + process_rowwise_stage( + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, + bank_group, scales_offset_Y_rowwise, scales_offset_X_rowwise, + scale_stride_rowwise, rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, + cached_act_sh, out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); + } - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); - // Publish the stage from shared memory into global outputs via TMA. - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; - store_output_stage( - out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, - tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, leading_thread); + // Publish the stage from shared memory into global outputs via TMA. + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + store_output_stage( + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, + tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, + leading_thread); - buff_in = (buff_in + 1) % BUFFS_NUM; - } + buff_in = (buff_in + 1) % BUFFS_NUM; + } - if constexpr (IS_DBIAS) { - if (is_single_tensor) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - float *partial_dbias_rowwise = reinterpret_cast(dshmem); + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + for (int i = 0; i < THREADS_Y; ++i) { + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; } - } - const int dbias_stride = cols; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; } } - } - advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, - total_work_blocks, work_blocks_X); - } + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, + total_work_blocks, work_blocks_X); + } - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } - if (leading_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } + if (leading_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } - if (leading_thread) { + if (leading_thread) { #pragma unroll - for (int buff = 0; buff < BUFFS_NUM; ++buff) { - ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); + } } - } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} + } } // namespace group_quantize_kernel template Date: Wed, 18 Mar 2026 15:18:32 +0000 Subject: [PATCH 40/51] Fixed lint Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 808 +++++++++--------- 1 file changed, 402 insertions(+), 406 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 0f7b86b836..1ad6421f61 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -723,497 +723,494 @@ __device__ __forceinline__ float process_rowwise_stage( : "+l"(reinterpret_cast(in_IType4[w])), "+r"(reinterpret_cast(thread_amax_2x)) : "r"(src_smem_ptr)); - else { - asm volatile( - "{\n" - "ld.shared.b64 %0, [%2]; \n\t" - ".reg.b32 x01,x23; \n\t" - "mov.b64 {x01, x23}, %0; \n\t" - "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" - "max.xorsign.abs.f16x2 %1, %1, x01; \n" - "}\n" - : "+l"(reinterpret_cast(in_IType4[w])), - "+r"(reinterpret_cast(thread_amax_2x)) - : "r"(src_smem_ptr)); - } } else { - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } + asm volatile( + "{\n" + "ld.shared.b64 %0, [%2]; \n\t" + ".reg.b32 x01,x23; \n\t" + "mov.b64 {x01, x23}, %0; \n\t" + "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" + "max.xorsign.abs.f16x2 %1, %1, x01; \n" + "}\n" + : "+l"(reinterpret_cast(in_IType4[w])), + "+r"(reinterpret_cast(thread_amax_2x)) + : "r"(src_smem_ptr)); + } + } else { + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + #pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); } } - thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } - else if constexpr (IS_CACHED_ACT_OP) { - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (std::is_same_v) { + #pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { + #pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } } - else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { + #pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - Vec in; - Vec act_in; + Vec in; + Vec act_in; - in.load_from(&in_sh[shmem_offset_rowwise]); + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } + #pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - thread_amax = fmaxf(thread_amax, fabsf(elt)); - in_compute_rowwise[j] = elt; + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + thread_amax = fmaxf(thread_amax, fabsf(elt)); + in_compute_rowwise[j] = elt; } } + } - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; - size_t scale_idx = 0; - if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( - stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); - } else { - scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - } - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } + size_t scale_idx = 0; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + scale_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + stage_scales_offset_Y, stage_scales_offset_X, DIVUP(cols, static_cast(128))); + } else { + scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + } + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const IType block_scale_inverse_f16 = static_cast(block_scale_inverse); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; #pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { - uint32_t out_4x = 0; - OType4 &out = *reinterpret_cast(&out_4x); - ptx::mul_cvt_4x(out, in_IType4[w], block_scale_inverse_f16); + if constexpr (USE_FAST_MATH && NON_FP32_CAST_ONLY) { + uint32_t out_4x = 0; + OType4 &out = *reinterpret_cast(&out_4x); + ptx::mul_cvt_4x(out, in_IType4[w], block_scale_inverse_f16); - const uint32_t dst_smem_ptr = - __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); - asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); - } else { - Vec out; + const uint32_t dst_smem_ptr = + __cvta_generic_to_shared(&out_rowwise_data_sh[shmem_offset_rowwise]); + asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(out_4x)); + } else { + Vec out; #pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NON_FP32_CAST_ONLY) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NON_FP32_CAST_ONLY) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; } - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } - - return thread_amax; } + return thread_amax; +} - template - __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( - const __grid_constant__ CUtensorMap tensor_map_input_static, - const __grid_constant__ CUtensorMap tensor_map_act_input_static, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, - const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, - const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr, - const int64_t *const __restrict__ first_dims_ptr, - const int64_t *const __restrict__ last_dims_ptr, - e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, - const float *__restrict__ noop, float *const __restrict__ dbias_workspace, - float *const __restrict__ amax_ptr, const size_t work_blocks_X, const size_t work_blocks_Y) { +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( + const __grid_constant__ CUtensorMap tensor_map_input_static, + const __grid_constant__ CUtensorMap tensor_map_act_input_static, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, + const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, + const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, + e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, + const float *__restrict__ noop, float *const __restrict__ dbias_workspace, + float *const __restrict__ amax_ptr, const size_t work_blocks_X, const size_t work_blocks_Y) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; } + } - constexpr bool ROWWISE_SCALING = - (SCALING_TYPE == ScalingType::ROWWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); - constexpr bool COLWISE_SCALING = - (SCALING_TYPE == ScalingType::COLWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool ROWWISE_SCALING = + (SCALING_TYPE == ScalingType::ROWWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); + constexpr bool COLWISE_SCALING = + (SCALING_TYPE == ScalingType::COLWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); - constexpr ShapeRepresentation shape_rep = SHAPE_REP; - constexpr bool is_single_tensor = - (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + constexpr ShapeRepresentation shape_rep = SHAPE_REP; + constexpr bool is_single_tensor = + (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); - const bool leading_thread = (threadIdx.x == 0); + const bool leading_thread = (threadIdx.x == 0); - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - extern __shared__ unsigned char dynamic_shmem[]; - unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + extern __shared__ unsigned char dynamic_shmem[]; + unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - constexpr size_t shmem_buff_size = (IS_DACT ? 2 : 1) * buff_size_aligned_in / BUFFS_NUM; + constexpr size_t shmem_buff_size = (IS_DACT ? 2 : 1) * buff_size_aligned_in / BUFFS_NUM; - float block_amax = 0.0f; + float block_amax = 0.0f; - __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; - // Initialize barriers shared by the entire CTA: - // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. - if (leading_thread) { + // Initialize barriers shared by the entire CTA: + // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. + if (leading_thread) { #pragma unroll - for (int buff = 0; buff < BUFFS_NUM; ++buff) { - ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); - } - ptx::fence_proxy_async_shared_cta(); + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); } - __syncthreads(); + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); - const size_t total_work_blocks = work_blocks_X * work_blocks_Y; - const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; - - int IN_buff_readable_parity[BUFFS_NUM] = {0}; - int32_t ctaid_X = static_cast(blockIdx.x); - int32_t ctaid_Y = static_cast(blockIdx.y); - size_t static_next_block_id = 0; - size_t static_block_stride = 0; - // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. - if constexpr (PERSISTENT) { - if (launch_block_id >= total_work_blocks) { - return; - } - ctaid_X = static_cast(launch_block_id % work_blocks_X); - ctaid_Y = static_cast(launch_block_id / work_blocks_X); - static_block_stride = gridDim.x * gridDim.y; - static_next_block_id = launch_block_id + static_block_stride; + const size_t total_work_blocks = work_blocks_X * work_blocks_Y; + const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; + + int IN_buff_readable_parity[BUFFS_NUM] = {0}; + int32_t ctaid_X = static_cast(blockIdx.x); + int32_t ctaid_Y = static_cast(blockIdx.y); + size_t static_next_block_id = 0; + size_t static_block_stride = 0; + // In persistent mode, physical CTAs iterate over a virtual work grid via grid-stride. + if constexpr (PERSISTENT) { + if (launch_block_id >= total_work_blocks) { + return; + } + ctaid_X = static_cast(launch_block_id % work_blocks_X); + ctaid_Y = static_cast(launch_block_id / work_blocks_X); + static_block_stride = gridDim.x * gridDim.y; + static_next_block_id = launch_block_id + static_block_stride; + } + bool job_finished = false; + size_t last_acquired_tensor_id = num_tensors; + + // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. + while (!job_finished) { + // Decode CTA assignment into logical tensor coordinates and validate bounds. + const JobDescriptor current_job = + decode_job(num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, + ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); + const bool current_job_is_valid = + is_job_valid(current_job, total_work_blocks, offsets_ptr); + if (!current_job_is_valid) { + break; } - bool job_finished = false; - size_t last_acquired_tensor_id = num_tensors; - - // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. - while (!job_finished) { - // Decode CTA assignment into logical tensor coordinates and validate bounds. - const JobDescriptor current_job = - decode_job(num_tensors, first_logical_dim, last_logical_dim, work_blocks_X, - ctaid_X, ctaid_Y, offsets_ptr, first_dims_ptr, last_dims_ptr); - const bool current_job_is_valid = - is_job_valid(current_job, total_work_blocks, offsets_ptr); - if (!current_job_is_valid) { - break; + if (!job_has_work(current_job)) { + // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, + static_block_stride, total_work_blocks, work_blocks_X); + continue; + } + + const size_t tensor_id = current_job.tensor_id; + const size_t rows = current_job.rows; + const size_t cols = current_job.cols; + const BlockDescriptor current_block = decode_block(current_job, offsets_ptr); + + const size_t scale_stride_rowwise = + DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); + + const size_t tensor_base = current_block.tensor_base; + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? static_cast(offsets_ptr[tensor_id]) + : tensor_base; + const size_t block_id_Y = current_block.block_id_Y; + const size_t block_id_X = current_block.block_id_X; + const size_t block_offset_Y = current_block.block_offset_Y; + const size_t block_offset_X = current_block.block_offset_X; + + e8m0_t *const scales_rowwise = + scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); + e8m0_t *const scales_colwise = + scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); + + const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + const int dbias_offset_Y = block_id_Y; + const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + + const CUtensorMap &tensor_map_input = + is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; + const CUtensorMap &tensor_map_act_input = + is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; + const CUtensorMap &tensor_map_output_rowwise = is_single_tensor + ? tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; + const CUtensorMap &tensor_map_output_colwise = is_single_tensor + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; + + if (leading_thread && (!is_single_tensor) && (last_acquired_tensor_id != tensor_id)) { + fence_acquire_tensormap(&tensor_map_input); + if constexpr (COMPUTE_ACTIVATIONS) { + fence_acquire_tensormap(&tensor_map_act_input); } - if (!job_has_work(current_job)) { - // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. - advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, - static_block_stride, total_work_blocks, work_blocks_X); - continue; + if constexpr (ROWWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_rowwise); } - - const size_t tensor_id = current_job.tensor_id; - const size_t rows = current_job.rows; - const size_t cols = current_job.cols; - const BlockDescriptor current_block = decode_block(current_job, offsets_ptr); - - const size_t scale_stride_rowwise = - DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); - const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); - - const size_t tensor_base = current_block.tensor_base; - const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) - ? static_cast(offsets_ptr[tensor_id]) - : tensor_base; - const size_t block_id_Y = current_block.block_id_Y; - const size_t block_id_X = current_block.block_id_X; - const size_t block_offset_Y = current_block.block_offset_Y; - const size_t block_offset_X = current_block.block_offset_X; - - e8m0_t *const scales_rowwise = - scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X); - e8m0_t *const scales_colwise = - scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y); - - const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X; - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - - const int dbias_offset_Y = block_id_Y; - const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; - - const CUtensorMap &tensor_map_input = - is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; - const CUtensorMap &tensor_map_act_input = - is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; - const CUtensorMap &tensor_map_output_rowwise = is_single_tensor - ? tensor_map_output_rowwise_static - : g_tensor_maps_output_rowwise[tensor_id]; - const CUtensorMap &tensor_map_output_colwise = is_single_tensor - ? tensor_map_output_colwise_static - : g_tensor_maps_output_colwise[tensor_id]; - - if (leading_thread && (!is_single_tensor) && (last_acquired_tensor_id != tensor_id)) { - fence_acquire_tensormap(&tensor_map_input); - if constexpr (COMPUTE_ACTIVATIONS) { - fence_acquire_tensormap(&tensor_map_act_input); - } - if constexpr (ROWWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_rowwise); - } - if constexpr (COLWISE_SCALING) { - fence_acquire_tensormap(&tensor_map_output_colwise); - } - last_acquired_tensor_id = tensor_id; + if constexpr (COLWISE_SCALING) { + fence_acquire_tensormap(&tensor_map_output_colwise); } + last_acquired_tensor_id = tensor_id; + } - int buff_in = 0; + int buff_in = 0; - // Prime the pipeline with the first PREFETCH_STAGES slices of the current block. + // Prime the pipeline with the first PREFETCH_STAGES slices of the current block. #pragma unroll - for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { - const size_t buff = stage; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; - uint64_t *barrier = &IN_buff_readable_mbar[buff]; - prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, - tensor_map_act_input, global_offset_X, global_offset_Y, - buff_offset, shmem_buff_size, barrier, leading_thread); - } + for (int stage = 0; stage < PREFETCH_STAGES; ++stage) { + const size_t buff = stage; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; + uint64_t *barrier = &IN_buff_readable_mbar[buff]; + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, + tensor_map_act_input, global_offset_X, global_offset_Y, + buff_offset, shmem_buff_size, barrier, leading_thread); + } - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { #pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; } + } - // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). + // Process one [CHUNK_DIM_Y x CHUNK_DIM_X] block in STAGES slices (32 rows each). #pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - if (stage < STAGES - PREFETCH_STAGES) { - const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; - const size_t next_prefetch_stage = stage + PREFETCH_STAGES; - const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; - - const size_t global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; - - uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; - prefetch_input_stage( - in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, global_offset_X, - global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); - } + for (int stage = 0; stage < STAGES; ++stage) { + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + if (stage < STAGES - PREFETCH_STAGES) { + const size_t next_prefetch_buff = (buff_in + PREFETCH_STAGES) % BUFFS_NUM; + const size_t next_prefetch_stage = stage + PREFETCH_STAGES; + const size_t next_prefetch_stage_offset_Y = next_prefetch_stage * BUFF_DIM_Y; + + const size_t global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_prefetch_buff_offset = next_prefetch_buff * BUFF_DIM; - ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); - IN_buff_readable_parity[buff_in] ^= 1; - ptx::cp_async_bulk_wait_group_read(); + uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff]; + prefetch_input_stage( + in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, global_offset_X, + global_offset_Y, next_prefetch_buff_offset, shmem_buff_size, barrier, leading_thread); + } - const size_t buff = buff_in; - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_amax = - process_colwise_stage( - buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, - scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, - cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); - } + ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], + IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in] ^= 1; + ptx::cp_async_bulk_wait_group_read(); - if constexpr (ROWWISE_SCALING) { - thread_amax = - process_rowwise_stage( - buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, - bank_group, scales_offset_Y_rowwise, scales_offset_X_rowwise, - scale_stride_rowwise, rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, - cached_act_sh, out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); - } + const size_t buff = buff_in; + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_amax = + process_colwise_stage( + buff, stage, tid_X_colwise, scales_offset_Y_colwise, scales_offset_X_colwise, + scale_stride_colwise, tensor_base_for_scales, rows, cols, in_sh, act_in_sh, + cached_act_sh, out_colwise_data_sh, scales_colwise, partial_dbias_colwise); + } - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); + if constexpr (ROWWISE_SCALING) { + thread_amax = + process_rowwise_stage( + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, + bank_group, scales_offset_Y_rowwise, scales_offset_X_rowwise, + scale_stride_rowwise, rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, + cached_act_sh, out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); + } - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - // Publish the stage from shared memory into global outputs via TMA. - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; - store_output_stage( - out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, - tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, - leading_thread); + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); - buff_in = (buff_in + 1) % BUFFS_NUM; - } + // Publish the stage from shared memory into global outputs via TMA. + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + store_output_stage( + out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, + tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, + leading_thread); - if constexpr (IS_DBIAS) { - if (is_single_tensor) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - float *partial_dbias_rowwise = reinterpret_cast(dshmem); + buff_in = (buff_in + 1) % BUFFS_NUM; + } + + if constexpr (IS_DBIAS) { + if (is_single_tensor) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; - } - } - __syncthreads(); + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } - const int dbias_stride = cols; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } + const int dbias_stride = cols; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } } - - advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, - total_work_blocks, work_blocks_X); } - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, + total_work_blocks, work_blocks_X); + } - if (leading_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (leading_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } - if (leading_thread) { + if (leading_thread) { #pragma unroll - for (int buff = 0; buff < BUFFS_NUM; ++buff) { - ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); - } + for (int buff = 0; buff < BUFFS_NUM; ++buff) { + ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} } // namespace group_quantize_kernel template Date: Wed, 18 Mar 2026 15:19:31 +0000 Subject: [PATCH 41/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 69 +++++++++---------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 1ad6421f61..dede34f489 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -738,29 +738,29 @@ __device__ __forceinline__ float process_rowwise_stage( } } else { in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE / 2; ++e) { ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); } } } - thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } else if constexpr (IS_CACHED_ACT_OP) { __syncthreads(); IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); if constexpr (std::is_same_v) { - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); } } else { - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE; e += 2) { const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); @@ -768,11 +768,10 @@ __device__ __forceinline__ float process_rowwise_stage( } } if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + thread_amax = static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); } } else { - #pragma unroll +#pragma unroll for (int w = 0; w < WAVES; ++w) { const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; @@ -785,7 +784,7 @@ __device__ __forceinline__ float process_rowwise_stage( if constexpr (IS_DACT) { act_in.load_from(&act_in_sh[shmem_offset_rowwise]); } - #pragma unroll +#pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; float elt = static_cast(in.data.elt[e]); @@ -875,14 +874,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static, const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, - const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, - const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim, - const int64_t *const __restrict__ offsets_ptr, - const int64_t *const __restrict__ first_dims_ptr, - const int64_t *const __restrict__ last_dims_ptr, - e8m0_t *const __restrict__ scales_rowwise_ptr, e8m0_t *const __restrict__ scales_colwise_ptr, - const float *__restrict__ noop, float *const __restrict__ dbias_workspace, - float *const __restrict__ amax_ptr, const size_t work_blocks_X, const size_t work_blocks_Y) { + const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr, + const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr, + e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop, + float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr, + const size_t work_blocks_X, const size_t work_blocks_Y) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; @@ -899,8 +897,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel (SCALING_TYPE == ScalingType::COLWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); constexpr ShapeRepresentation shape_rep = SHAPE_REP; - constexpr bool is_single_tensor = - (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + constexpr bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); const bool leading_thread = (threadIdx.x == 0); @@ -992,8 +989,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } if (!job_has_work(current_job)) { // Zero-sized tensors are valid grouped-tensor entries; skip them and keep scheduling work. - advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, - static_block_stride, total_work_blocks, work_blocks_X); + advance_to_next_job(job_finished, ctaid_X, ctaid_Y, static_next_block_id, static_block_stride, + total_work_blocks, work_blocks_X); continue; } @@ -1002,8 +999,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t cols = current_job.cols; const BlockDescriptor current_block = decode_block(current_job, offsets_ptr); - const size_t scale_stride_rowwise = - DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); + const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast(32)), 4); const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128); const size_t tensor_base = current_block.tensor_base; @@ -1040,11 +1036,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const CUtensorMap &tensor_map_act_input = is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id]; const CUtensorMap &tensor_map_output_rowwise = is_single_tensor - ? tensor_map_output_rowwise_static - : g_tensor_maps_output_rowwise[tensor_id]; + ? tensor_map_output_rowwise_static + : g_tensor_maps_output_rowwise[tensor_id]; const CUtensorMap &tensor_map_output_colwise = is_single_tensor - ? tensor_map_output_colwise_static - : g_tensor_maps_output_colwise[tensor_id]; + ? tensor_map_output_colwise_static + : g_tensor_maps_output_colwise[tensor_id]; if (leading_thread && (!is_single_tensor) && (last_acquired_tensor_id != tensor_id)) { fence_acquire_tensormap(&tensor_map_input); @@ -1071,9 +1067,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const size_t global_offset_X = block_offset_X; const size_t buff_offset = buff * BUFF_DIM; uint64_t *barrier = &IN_buff_readable_mbar[buff]; - prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, - tensor_map_act_input, global_offset_X, global_offset_Y, - buff_offset, shmem_buff_size, barrier, leading_thread); + prefetch_input_stage(in_sh, act_in_sh, tensor_map_input, tensor_map_act_input, + global_offset_X, global_offset_Y, buff_offset, + shmem_buff_size, barrier, leading_thread); } float partial_dbias_colwise = 0.0f; @@ -1105,7 +1101,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in], - IN_buff_readable_parity[buff_in]); + IN_buff_readable_parity[buff_in]); IN_buff_readable_parity[buff_in] ^= 1; ptx::cp_async_bulk_wait_group_read(); @@ -1124,10 +1120,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel thread_amax = process_rowwise_stage( - buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, - bank_group, scales_offset_Y_rowwise, scales_offset_X_rowwise, - scale_stride_rowwise, rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, - cached_act_sh, out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); + buff, stage_offset_Y, thread_offset_Y_rowwise, thread_offset_X_rowwise, bank_group, + scales_offset_Y_rowwise, scales_offset_X_rowwise, scale_stride_rowwise, + rowwise_scale_is_within_bounds, cols, in_sh, act_in_sh, cached_act_sh, + out_rowwise_data_sh, scales_rowwise, thread_dbias_rowwise); } __builtin_assume(block_amax >= 0); @@ -1143,8 +1139,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const int buff_offset = buff * BUFF_DIM; store_output_stage( out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, - tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, - leading_thread); + tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, leading_thread); buff_in = (buff_in + 1) % BUFFS_NUM; } From 3d2d1ba111a66dbf709c353b9dc538ef929e0a2e Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 15:59:30 +0000 Subject: [PATCH 42/51] Fix Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/quantize_mxfp8.cuh | 138 +++++++++--------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 3142b39272..70a68132ad 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -641,75 +641,75 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - // if (specialized::hasSpec() && - // !WITH_GEMM_SWIZZLED_SCALES) { - // switch (scaling_type) { - // case ScalingType::ROWWISE: { - // using traits = specialized::CastTraits; - // auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - - // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - // traits::smem); - - // dim3 block(traits::threadLayout::num, traits::warpLayout::N, - // traits::warpLayout::M); - // dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN, - // (rows + traits::blockDimM - 1) / traits::blockDimM); - // kernel<<>>( - // reinterpret_cast(input.data.dptr), - // reinterpret_cast(output->data.dptr), - // scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); - - // break; - // } - // case ScalingType::COLWISE: { - // NVTE_WARN("Colwise scaling will fallback to original kernel."); - // break; - // } - // case ScalingType::BIDIMENSIONAL: { - // using traits = specialized::CastTraits; - // auto kernel = specialized::quantize_mxfp8_kernel_cast_only; - - // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - // traits::smem); - // // TMA for loading, so that we don't need STS for transposing - // alignas(64) CUtensorMap tensor_map_input{}; - // constexpr size_t input_type_bit_size = TypeInfo::size; - // create_2D_tensor_map(tensor_map_input, input.data, rows, cols, - // traits::blockIterDim::M, traits::blockIterDim::N, - // /*stride_elems=*/cols, - // /*offset_elems=*/0, input_type_bit_size, - // traits::input_swizzle_pattern); - - // alignas(64) CUtensorMap tensor_map_rowwise_output{}; - // alignas(64) CUtensorMap tensor_map_colwise_output{}; - // constexpr size_t output_type_bit_size = TypeInfo::size; - // create_2D_tensor_map(tensor_map_rowwise_output, output->data, rows, cols, - // traits::blockIterDim::M, traits::blockIterDim::N, - // /*stride_elems=*/cols, - // /*offset_elems=*/0, output_type_bit_size, - // traits::output_swizzle_pattern); - // create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data, rows, - // cols, traits::blockIterDim::M, traits::blockIterDim::N, - // cols, 0, output_type_bit_size, - // traits::output_swizzle_pattern); - - // dim3 block(traits::rowThreadLayout::num, traits::numWarps); - // dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N, - // (rows + traits::blockDIM::M - 1) / traits::blockDIM::M); - // kernel<<>>( - // tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output, - // scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - // scale_stride_colwise); - - // break; - // } - // default: { - // NVTE_ERROR("Invalid scaling type."); - // } - // } - // return; - // } + if (specialized::hasSpec() && + !WITH_GEMM_SWIZZLED_SCALES) { + switch (scaling_type) { + case ScalingType::ROWWISE: { + using traits = specialized::CastTraits; + auto kernel = specialized::quantize_mxfp8_kernel_cast_only; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + traits::smem); + + dim3 block(traits::threadLayout::num, traits::warpLayout::N, + traits::warpLayout::M); + dim3 grid((cols + traits::blockDimN - 1) / traits::blockDimN, + (rows + traits::blockDimM - 1) / traits::blockDimM); + kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + scales_rowwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + + break; + } + case ScalingType::COLWISE: { + NVTE_WARN("Colwise scaling will fallback to original kernel."); + break; + } + case ScalingType::BIDIMENSIONAL: { + using traits = specialized::CastTraits; + auto kernel = specialized::quantize_mxfp8_kernel_cast_only; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + traits::smem); + // TMA for loading, so that we don't need STS for transposing + alignas(64) CUtensorMap tensor_map_input{}; + constexpr size_t input_type_bit_size = TypeInfo::size; + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, + traits::blockIterDim::M, traits::blockIterDim::N, + /*stride_elems=*/cols, + /*offset_elems=*/0, input_type_bit_size, + traits::input_swizzle_pattern); + + alignas(64) CUtensorMap tensor_map_rowwise_output{}; + alignas(64) CUtensorMap tensor_map_colwise_output{}; + constexpr size_t output_type_bit_size = TypeInfo::size; + create_2D_tensor_map(tensor_map_rowwise_output, output->data, rows, cols, + traits::blockIterDim::M, traits::blockIterDim::N, + /*stride_elems=*/cols, + /*offset_elems=*/0, output_type_bit_size, + traits::output_swizzle_pattern); + create_2D_tensor_map(tensor_map_colwise_output, output->columnwise_data, rows, + cols, traits::blockIterDim::M, traits::blockIterDim::N, + cols, 0, output_type_bit_size, + traits::output_swizzle_pattern); + + dim3 block(traits::rowThreadLayout::num, traits::numWarps); + dim3 grid((cols + traits::blockDIM::N - 1) / traits::blockDIM::N, + (rows + traits::blockDIM::M - 1) / traits::blockDIM::M); + kernel<<>>( + tensor_map_input, tensor_map_rowwise_output, tensor_map_colwise_output, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + + break; + } + default: { + NVTE_ERROR("Invalid scaling type."); + } + } + return; + } alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; From ac75ea25b09cc3d51affb86ebb715be925f67b79 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 16:51:10 +0000 Subject: [PATCH 43/51] Fixes Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8.cu | 46 +++++++++---------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 1 + 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index e22b76ff89..ccc605c060 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -522,25 +522,25 @@ void performTest_x2(const ProcessingMethod processing_method, } std::vector> matrix_sizes = { - // {1, 16}, - // {16, 48}, - // {65, 96}, - // {128, 128}, - // {256, 256}, - // {993, 512}, - // {511, 6144}, - // {8192, 128}, - // {2048, 160}, - // {577, 1632}, - // {1024}, - // {8, 32, 1024}, - // {16, 8, 4, 512}, + {1, 16}, + {16, 48}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {511, 6144}, + {8192, 128}, + {2048, 160}, + {577, 1632}, + {1024}, + {8, 32, 1024}, + {16, 8, 4, 512}, {8192, 7168}, }; std::vector> block_sizes = { - // {1, 32}, - // {32, 1}, + {1, 32}, + {32, 1}, {32, 32}, }; @@ -554,16 +554,16 @@ std::vector input_scenarios = { std::vector processing_methods = { ProcessingMethod::CAST_ONLY, - // ProcessingMethod::CAST_DBIAS, - // ProcessingMethod::CAST_DBIAS_DACT, - // ProcessingMethod::CAST_DACT, - // ProcessingMethod::CAST_ACT, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, }; // Only GeLU activation tests are supported std::vector Activation_types = { ActivationType::Identity, - // ActivationType::GeLU, + ActivationType::GeLU, // ActivationType::SiLU, // ActivationType::ReLU, // ActivationType::QGeLU, @@ -692,10 +692,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), ::testing::ValuesIn(block_sizes), - ::testing::Values(DType::kBFloat16), - ::testing::Values(DType::kFloat8E4M3), - // ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - // ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index dede34f489..8329a747af 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -1055,6 +1055,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } last_acquired_tensor_id = tensor_id; } + __syncthreads(); int buff_in = 0; From 3fc8a3e93d2924207e1871e5a10e95540f3908b7 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 17:22:24 +0000 Subject: [PATCH 44/51] Fix Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 8329a747af..9c277988f7 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -78,6 +78,9 @@ static_assert(BUFF_DIM_Y == 32); constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; static_assert(STAGES >= 1); +static_assert(CHUNK_DIM_Y % SCALE_DIM_Y == 0); +static_assert(CHUNK_DIM_X % SCALE_DIM_X == 0); + // Number of 1-byte elements that span 32 banks (4-byte each) of shared memory constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 @@ -275,7 +278,7 @@ __device__ __forceinline__ bool is_job_valid(const JobDescriptor &job, const size_t tensor_offset_from_start = job.block_global_offset - tensor_start_offset; const size_t block_offset_Y_in_tensor = tensor_offset_from_start / job.cols; const size_t block_offset_X_in_tensor = tensor_offset_from_start % job.cols; - if (block_offset_Y_in_tensor >= job.rows || block_offset_X_in_tensor >= job.cols) { + if (block_offset_Y_in_tensor >= job.rows) { return false; } @@ -720,7 +723,7 @@ __device__ __forceinline__ float process_rowwise_stage( "max.xorsign.abs.bf16x2 x01, x01, x23; \n\t" "max.xorsign.abs.bf16x2 %1, %1, x01; \n" "}\n" - : "+l"(reinterpret_cast(in_IType4[w])), + : "=l"(reinterpret_cast(in_IType4[w])), "+r"(reinterpret_cast(thread_amax_2x)) : "r"(src_smem_ptr)); } else { @@ -732,7 +735,7 @@ __device__ __forceinline__ float process_rowwise_stage( "max.xorsign.abs.f16x2 x01, x01, x23; \n\t" "max.xorsign.abs.f16x2 %1, %1, x01; \n" "}\n" - : "+l"(reinterpret_cast(in_IType4[w])), + : "=l"(reinterpret_cast(in_IType4[w])), "+r"(reinterpret_cast(thread_amax_2x)) : "r"(src_smem_ptr)); } From 62dfbd48ee9d6158f73004bf433493c24cc36f5c Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 18:29:07 +0000 Subject: [PATCH 45/51] Fixed test suite Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 12 +++++++++--- .../graph_safe_group_hadamard_transform.cu | 7 ------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index e5c62282f5..73952e5292 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -672,6 +672,7 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 512,160, 128,0,0,256}, {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, @@ -774,8 +775,13 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { GTEST_SKIP(); } // Skip fused tests in fast math is enabled. - if ((processing_method != ProcessingMethod::CAST_ONLY) && use_fast_math) { - GTEST_SKIP(); + if (use_fast_math) { + if (processing_method != ProcessingMethod::CAST_ONLY) { + GTEST_SKIP(); + } + if ((input_type != DType::kBFloat16) || (input_type != DType::kFloat16)) { + GTEST_SKIP(); + } } bool rowwise = false; @@ -848,7 +854,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(scaling_directions), ::testing::ValuesIn(input_config), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::Values(true, false)), [](const testing::TestParamInfo& info) { const ProcessingMethod method = std::get<0>(info.param); diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 04e965a9da..0fb73cc439 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -25,13 +25,6 @@ namespace { constexpr int kMaxTensorsPerKernel = 64; constexpr int kThreadsPerWarp = 32; -enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, - VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, - VARYING_BOTH_DIMS = 3 -}; - __device__ __forceinline__ size_t get_current_tensor_id( const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset, const size_t first_logical_dim, const size_t last_logical_dim, From 1b6938a7315994e87ff75ceb543fc819f0b2d4e7 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 19:40:39 +0000 Subject: [PATCH 46/51] Fixed test suite Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 73952e5292..8c950489b7 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -668,19 +668,19 @@ std::vector scaling_directions = { // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - {SAME_BOTH_DIMS, 1, 128,128}, - {SAME_BOTH_DIMS, 2, 256,128}, - {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + // {SAME_BOTH_DIMS, 1, 128,128}, + // {SAME_BOTH_DIMS, 2, 256,128}, + // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, {VARYING_FIRST_DIM, 4, 512,160, 128,0,0,256}, - {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, - {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - // Empty tensor in the middle of the group must not terminate the persistent work loop. - {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, - {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + // {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, + // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + // // Empty tensor in the middle of the group must not terminate the persistent work loop. + // {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, + // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, }; } // namespace @@ -779,7 +779,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { if (processing_method != ProcessingMethod::CAST_ONLY) { GTEST_SKIP(); } - if ((input_type != DType::kBFloat16) || (input_type != DType::kFloat16)) { + if ((input_type != DType::kBFloat16) && (input_type != DType::kFloat16)) { GTEST_SKIP(); } } From add9e9c7ba39719071f241a3fd2ef65f760c3521 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 18 Mar 2026 22:00:31 +0000 Subject: [PATCH 47/51] Fixes per the review Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 93 ++++++++++--------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 9 +- 2 files changed, 55 insertions(+), 47 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 8c950489b7..75e8058a6a 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -668,19 +668,19 @@ std::vector scaling_directions = { // {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]} std::vector> input_config = { - // {SAME_BOTH_DIMS, 1, 128,128}, - // {SAME_BOTH_DIMS, 2, 256,128}, - // {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - // {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {SAME_BOTH_DIMS, 1, 128,128}, + {SAME_BOTH_DIMS, 2, 256,128}, + {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, + {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, + {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, + {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, + {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, + {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + // Empty tensor in the middle of the group must not terminate the persistent work loop. {VARYING_FIRST_DIM, 4, 512,160, 128,0,0,256}, - // {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, - // {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, - // {VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304}, - // {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, - // {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, - // // Empty tensor in the middle of the group must not terminate the persistent work loop. - // {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, - // {VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640}, + {VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128}, }; } // namespace @@ -845,6 +845,40 @@ std::string to_string(const ActivationKind activation) { } } +std::string MakeGroupedFusedCastMXFP8TestName( + const testing::TestParamInfo& info) { + const ProcessingMethod method = std::get<0>(info.param); + std::string name = to_string(method); + name += "X" + to_string(std::get<1>(info.param)); + + switch (std::get<2>(info.param)) { + case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break; + case ScalingDirection::COLWISE: name += "_COLWISE_"; break; + case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break; + } + + const std::vector input = std::get<3>(info.param); + + switch (static_cast(input[0])) { + case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; + case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; + case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; + case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; + } + + name += "_N_" + std::to_string(input[1]); + + name += "_SHAPE_" + std::to_string(input[2]) + "X" + std::to_string(input[3]); + + name += "_" + test::typeName(std::get<4>(info.param)) + + "_" + test::typeName(std::get<5>(info.param)); + + if (std::get<6>(info.param)) { + name += "_FASTMATH"; + } + return name; +} + INSTANTIATE_TEST_SUITE_P( OperatorTest, GroupedFusedCastMXFP8TestSuite, @@ -856,37 +890,4 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::Values(true, false)), - [](const testing::TestParamInfo& info) { - const ProcessingMethod method = std::get<0>(info.param); - std::string name = to_string(method); - name += "X" + to_string(std::get<1>(info.param)); - - switch (std::get<2>(info.param)) { - case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break; - case ScalingDirection::COLWISE: name += "_COLWISE_"; break; - case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break; - } - - const std::vector input = std::get<3>(info.param); - - switch(static_cast(input[0])) { - case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break; - case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break; - case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break; - case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break; - }; - - name += "_N_" + std::to_string(input[1]); - - name += "_SHAPE_" + - std::to_string(input[2]) + - "X" + std::to_string(input[3]); - - name += "_" + test::typeName(std::get<4>(info.param)) + - "_" + test::typeName(std::get<5>(info.param)); - - if (std::get<6>(info.param)) { - name += "_FASTMATH"; - } - return name; - }); + MakeGroupedFusedCastMXFP8TestName); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 9c277988f7..4c42992e42 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -475,7 +475,9 @@ __device__ __forceinline__ void store_output_stage(OType *out_rowwise_data_sh, reinterpret_cast(&tensor_map_output_colwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); } - ptx::cp_async_bulk_commit_group(); + if constexpr (ROWWISE_SCALING || COLWISE_SCALING) { + ptx::cp_async_bulk_commit_group(); + } } template = 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + constexpr bool NON_FP32_CAST_ONLY = + NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); if constexpr (NO_ACTIVATIONS) { if (noop != nullptr && noop[0] == 1.0f) { return; } } + if constexpr (USE_FAST_MATH && !NON_FP32_CAST_ONLY) { + return; + } constexpr bool ROWWISE_SCALING = (SCALING_TYPE == ScalingType::ROWWISE) || (SCALING_TYPE == ScalingType::BIDIMENSIONAL); From 86abab845b52301eb12b48806eaabef564c143bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Mar 2026 22:01:30 +0000 Subject: [PATCH 48/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 4c42992e42..92e2cc0f22 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -889,7 +889,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - constexpr bool NON_FP32_CAST_ONLY = + constexpr bool NON_FP32_CAST_ONLY = NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v); if constexpr (NO_ACTIVATIONS) { From 4e28663a43446e03fa769598da79e219130e47e9 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 19 Mar 2026 15:46:12 +0000 Subject: [PATCH 49/51] Modifications per the review Signed-off-by: Oleg Goncharov --- .../cast/mxfp8/group_quantize_mxfp8.cuh | 107 ++++++++---------- 1 file changed, 47 insertions(+), 60 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 92e2cc0f22..198fc04206 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -38,15 +38,14 @@ __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_T __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; struct TunableConfig { - static constexpr size_t CHUNK_DIM_Y = 128; - static constexpr size_t CHUNK_DIM_X = 128; - static constexpr size_t THREADS_PER_CHUNK = 128; - static constexpr size_t PREFETCH_STAGES = 1; + static constexpr uint CHUNK_DIM_Y = 128; + static constexpr uint CHUNK_DIM_X = 128; + static constexpr uint THREADS_PER_CHUNK = 128; // true -> static persistent grid-stride scheduler // false -> non-persistent one-job-per-CTA execution static constexpr bool PERSISTENT = true; // Launch static persistent grid as (SM_count * STATIC_PERSISTENT_BLOCKS_PER_SM, 1, 1). - static constexpr size_t STATIC_PERSISTENT_BLOCKS_PER_SM = 24; + static constexpr uint STATIC_PERSISTENT_BLOCKS_PER_SM = 24; }; constexpr bool PERSISTENT = TunableConfig::PERSISTENT; @@ -56,36 +55,36 @@ static_assert(!PERSISTENT || (TunableConfig::STATIC_PERSISTENT_BLOCKS_PER_SM > 0 constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; -constexpr size_t PREFETCH_STAGES = TunableConfig::PREFETCH_STAGES; -constexpr size_t BUFFS_NUM = PREFETCH_STAGES + 1; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; +constexpr uint PREFETCH_STAGES = 1; +constexpr uint BUFFS_NUM = PREFETCH_STAGES + 1; +constexpr uint PACK_SIZE = 4; +constexpr uint WAVES = SCALE_DIM_X / PACK_SIZE; -constexpr size_t CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y; -constexpr size_t CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X; -constexpr size_t THREADS_PER_CHUNK = TunableConfig::THREADS_PER_CHUNK; +constexpr uint CHUNK_DIM_Y = TunableConfig::CHUNK_DIM_Y; +constexpr uint CHUNK_DIM_X = TunableConfig::CHUNK_DIM_X; +constexpr uint THREADS_PER_CHUNK = TunableConfig::THREADS_PER_CHUNK; constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X; -constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; -constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; +constexpr uint THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; +constexpr uint THREADS_Y = THREADS_PER_CHUNK / THREADS_X; -constexpr size_t BUFF_DIM_Y = THREADS_Y; -constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +constexpr uint BUFF_DIM_Y = THREADS_Y; +constexpr uint BUFF_DIM_X = CHUNK_DIM_X; +constexpr uint BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; static_assert(BUFF_DIM_Y == 32); -constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; +constexpr uint STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; static_assert(STAGES >= 1); static_assert(CHUNK_DIM_Y % SCALE_DIM_Y == 0); static_assert(CHUNK_DIM_X % SCALE_DIM_X == 0); // Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 +constexpr uint TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 +constexpr uint THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 template __device__ __forceinline__ size_t @@ -458,8 +457,9 @@ __device__ __forceinline__ void store_output_stage(OType *out_rowwise_data_sh, OType *out_colwise_data_sh, const CUtensorMap &tensor_map_output_rowwise, const CUtensorMap &tensor_map_output_colwise, - const int global_offset_X, - const int global_offset_Y, const int buff_offset, + const size_t global_offset_X, + const size_t global_offset_Y, + const size_t buff_offset, const bool leading_thread) { if (!leading_thread) { return; @@ -815,8 +815,8 @@ __device__ __forceinline__ float process_rowwise_stage( const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; + const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const size_t stage_scales_offset_X = scales_offset_X_rowwise; size_t scale_idx = 0; if constexpr (WITH_GEMM_SWIZZLED_SCALES) { @@ -952,19 +952,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel float block_amax = 0.0f; - __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; - - // Initialize barriers shared by the entire CTA: - // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. - if (leading_thread) { -#pragma unroll - for (int buff = 0; buff < BUFFS_NUM; ++buff) { - ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1); - } - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); - const size_t total_work_blocks = work_blocks_X * work_blocks_Y; const size_t launch_block_id = blockIdx.y * gridDim.x + blockIdx.x; @@ -986,6 +973,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel bool job_finished = false; size_t last_acquired_tensor_id = num_tensors; + __shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM]; + // Initialize barriers shared by the entire CTA: + // - IN_buff_readable_mbar tracks per-buffer TMA global->shared completion. + initialize_barriers(IN_buff_readable_mbar, leading_thread); + // Main work loop: decode current job, prime its pipeline, then process all 32-row stages. while (!job_finished) { // Decode CTA assignment into logical tensor coordinates and validate bounds. @@ -1038,8 +1030,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - const int dbias_offset_Y = block_id_Y; - const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; + const size_t dbias_offset_Y = block_id_Y; + const size_t dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x; const CUtensorMap &tensor_map_input = is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; @@ -1145,9 +1137,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel __syncthreads(); // Publish the stage from shared memory into global outputs via TMA. - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; store_output_stage( out_rowwise_data_sh, out_colwise_data_sh, tensor_map_output_rowwise, tensor_map_output_colwise, global_offset_X, global_offset_Y, buff_offset, leading_thread); @@ -1163,18 +1155,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel } else { float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = + const size_t shmem_thread_offset = tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; + const size_t j = w * PACK_SIZE + e; + const size_t shmem_elt_idx = swizzled_group_offset + e; partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } @@ -1186,8 +1178,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } - const int dbias_stride = cols; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const size_t dbias_stride = cols; + const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); if (!col_out_of_bounds_dbias) { dbias_workspace[dbias_idx] = thread_partial_dbias; @@ -1209,12 +1201,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel atomicMaxFloat(amax_ptr, block_amax); } - if (leading_thread) { -#pragma unroll - for (int buff = 0; buff < BUFFS_NUM; ++buff) { - ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]); - } - } + destroy_barriers(IN_buff_readable_mbar, leading_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } } // namespace group_quantize_kernel @@ -1286,14 +1273,14 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations size_t work_blocks_Y = 0; if (is_single_tensor) { - work_blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - work_blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + work_blocks_Y = DIVUP(first_logical_dim, static_cast(CHUNK_DIM_Y)); + work_blocks_X = DIVUP(last_logical_dim, static_cast(CHUNK_DIM_X)); } else { NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); work_blocks_Y = 1; - work_blocks_X = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); + work_blocks_X = DIVUP(elts_total, ELTS_PER_CHUNK); } size_t launch_blocks_X = work_blocks_X; @@ -1344,7 +1331,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias."); NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - const size_t dbias_workspace_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_workspace_rows = DIVUP(first_logical_dim, static_cast(CHUNK_DIM_Y)); const size_t dbias_workspace_cols = last_logical_dim; if (workspace->data.dptr == nullptr) { workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; From b6b86974f72b9b94b5a5078ccd079c6b559df925 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:48:21 +0000 Subject: [PATCH 50/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 198fc04206..e242394215 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -453,14 +453,11 @@ __device__ __forceinline__ void prefetch_input_stage( // Issue TMA shared->global transfer for one stage of outputs. template -__device__ __forceinline__ void store_output_stage(OType *out_rowwise_data_sh, - OType *out_colwise_data_sh, - const CUtensorMap &tensor_map_output_rowwise, - const CUtensorMap &tensor_map_output_colwise, - const size_t global_offset_X, - const size_t global_offset_Y, - const size_t buff_offset, - const bool leading_thread) { +__device__ __forceinline__ void store_output_stage( + OType *out_rowwise_data_sh, OType *out_colwise_data_sh, + const CUtensorMap &tensor_map_output_rowwise, const CUtensorMap &tensor_map_output_colwise, + const size_t global_offset_X, const size_t global_offset_Y, const size_t buff_offset, + const bool leading_thread) { if (!leading_thread) { return; } From 2ae38cb84450967b965a51cc82d6ae3fdb008101 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 19 Mar 2026 17:34:40 +0000 Subject: [PATCH 51/51] Assert the buffer size Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index e242394215..2350837dad 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -77,6 +77,7 @@ static_assert(BUFF_DIM_Y == 32); constexpr uint STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; static_assert(STAGES >= 1); +static_assert(CHUNK_DIM_Y % BUFF_DIM_Y == 0); static_assert(CHUNK_DIM_Y % SCALE_DIM_Y == 0); static_assert(CHUNK_DIM_X % SCALE_DIM_X == 0);