Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
107a865
Enabled persistency with WorkID Query feature
Oleg-Goncharov Mar 4, 2026
caf664f
Added a struct with tunable parameters
Oleg-Goncharov Mar 4, 2026
68dbc62
Added persistency with static scheduling
Oleg-Goncharov Mar 4, 2026
051d925
Fixed test cases
Oleg-Goncharov Mar 4, 2026
2f9a299
Ready for benchmarking
Oleg-Goncharov Mar 4, 2026
c040d59
Fixed out-of-boundary error
Oleg-Goncharov Mar 4, 2026
30c28fb
Tuned kernel parameters
Oleg-Goncharov Mar 4, 2026
977168e
Refactoring
Oleg-Goncharov Mar 4, 2026
885fcb9
Refactoring 2
Oleg-Goncharov Mar 4, 2026
d787847
Refactoring 3
Oleg-Goncharov Mar 4, 2026
79c1ac2
Removed the dynamic (WorkID Query) persistency
Oleg-Goncharov Mar 5, 2026
12b8712
Ready for PR
Oleg-Goncharov Mar 5, 2026
2812d55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
f24afb2
Fixes per the review
Oleg-Goncharov Mar 6, 2026
aa484a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2026
f066851
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 13, 2026
74722a5
Ready for benchmark
Oleg-Goncharov Mar 13, 2026
5c570cd
Ready for benchmark - Regular kernel
Oleg-Goncharov Mar 13, 2026
c5b1f7d
Added the source code to the profiler
Oleg-Goncharov Mar 13, 2026
3edcb5d
Added constructors to Job and Block descriptors
Oleg-Goncharov Mar 13, 2026
6e00237
Removed the prefetch overlapping between jobs
Oleg-Goncharov Mar 13, 2026
274f91e
Cache tensor ID
Oleg-Goncharov Mar 13, 2026
38b7e4e
ShapeRepresentation is not a template parameter
Oleg-Goncharov Mar 13, 2026
4405255
Removed redundant fence_proxy
Oleg-Goncharov Mar 13, 2026
8cad6e6
Refactoring
Oleg-Goncharov Mar 16, 2026
c6622d4
Used mixed precision FMA
Oleg-Goncharov Mar 17, 2026
e6a737c
Added Quantize parameters
Oleg-Goncharov Mar 17, 2026
7be1136
Added the fast math branch
Oleg-Goncharov Mar 17, 2026
4c2bed5
Added the fast math to cpp test suite
Oleg-Goncharov Mar 17, 2026
e296b0b
Align tests
Oleg-Goncharov Mar 17, 2026
e63eee9
Use STS instead of generic ST
Oleg-Goncharov Mar 17, 2026
6874206
Add zero-tensor cases
Oleg-Goncharov Mar 17, 2026
a02c71c
Used LDS instead of generic LD in colwise path
Oleg-Goncharov Mar 17, 2026
4c992b0
Used LDS instead of generic LD in rowwise
Oleg-Goncharov Mar 17, 2026
8ceeed0
Ready for merge
Oleg-Goncharov Mar 17, 2026
2c20675
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 17, 2026
ef973d7
Merge branch 'moe_mxfp8_benchmark' into pr_persistent_grouped_mxfp8_k…
Oleg-Goncharov Mar 17, 2026
f119d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2026
6874935
Uncommented test cases
Oleg-Goncharov Mar 18, 2026
f3e07e5
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 18, 2026
f985c01
Added FP16 Fast math path to rowwise processing
Oleg-Goncharov Mar 18, 2026
5068556
Refactoring
Oleg-Goncharov Mar 18, 2026
6c945d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
a38888c
Fixed lint
Oleg-Goncharov Mar 18, 2026
20e354a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
3d2d1ba
Fix
Oleg-Goncharov Mar 18, 2026
ac75ea2
Fixes
Oleg-Goncharov Mar 18, 2026
3fc8a3e
Fix
Oleg-Goncharov Mar 18, 2026
62dfbd4
Fixed test suite
Oleg-Goncharov Mar 18, 2026
1b6938a
Fixed test suite
Oleg-Goncharov Mar 18, 2026
c319671
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 18, 2026
add9e9c
Fixes per the review
Oleg-Goncharov Mar 18, 2026
86abab8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2026
4e28663
Modifications per the review
Oleg-Goncharov Mar 19, 2026
a6a9bb6
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 19, 2026
b6b8697
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2026
2ae38cb
Assert the buffer size
Oleg-Goncharov Mar 19, 2026
d87c5e1
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
Oleg-Goncharov Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/test_cast_mxfp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ std::vector<std::vector<size_t>> matrix_sizes = {
{1024},
{8, 32, 1024},
{16, 8, 4, 512},
{8192, 7168},
};

std::vector<std::pair<size_t, size_t>> block_sizes = {
Expand Down
132 changes: 88 additions & 44 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ void compute_ref(const ProcessingMethod processing_method,
const size_t rows,
const size_t cols,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise)
const size_t scales_stride_colwise,
const bool use_fast_math)
{
const size_t tile_size_Y = 32;
const size_t tile_size_X = 32;
Expand Down Expand Up @@ -129,7 +130,10 @@ void compute_ref(const ProcessingMethod processing_method,
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
float scale_reciprocal = exp2f_rcp(biased_exponent);
if (use_fast_math) {
scale_reciprocal = static_cast<float>(static_cast<InputType>(scale_reciprocal));
}

for (size_t j = j_min; j < j_max; ++j) {
const size_t idx = i * cols + j;
Expand All @@ -150,7 +154,10 @@ void compute_ref(const ProcessingMethod processing_method,
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent);
float scale_reciprocal = exp2f_rcp(biased_exponent);
if (use_fast_math) {
scale_reciprocal = static_cast<float>(static_cast<InputType>(scale_reciprocal));
}

for (size_t i = i_min; i < i_max; ++i) {
const size_t idx = i * cols + j;
Expand Down Expand Up @@ -241,7 +248,8 @@ void performTest(const ProcessingMethod processing_method,
const std::vector<size_t>& last_dims_h,
const std::vector<size_t>& offsets_h,
const bool rowwise,
const bool colwise) {
const bool colwise,
const bool use_fast_math) {
using namespace test;

DType itype = TypeInfo<InputType>::dtype;
Expand Down Expand Up @@ -272,9 +280,13 @@ void performTest(const ProcessingMethod processing_method,
const size_t elts = M * K;
elts_num += elts;

auto divide_round_up_blocks = [](const size_t N, const size_t M) -> size_t {
return (N == 0) ? 0 : 1 + (N - 1) / M;
};

const size_t unpadded_rowwise_blocks_Y = M;
const size_t unpadded_rowwise_blocks_X = divide_round_up(K, 32);
const size_t unpadded_colwise_blocks_Y = divide_round_up(M, 32);
const size_t unpadded_rowwise_blocks_X = divide_round_up_blocks(K, 32);
const size_t unpadded_colwise_blocks_Y = divide_round_up_blocks(M, 32);
const size_t unpadded_colwise_blocks_X = K;

rowwise_scales_first_dim[t] = round_up_to_nearest_multiple(unpadded_rowwise_blocks_Y, 128);
Expand Down Expand Up @@ -371,7 +383,7 @@ void performTest(const ProcessingMethod processing_method,

NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());

std::vector<size_t> dbias_logical_shape_vec= {num_tensors, cols};
std::vector<size_t> dbias_logical_shape_vec = {num_tensors, cols};
NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(),
dbias_logical_shape_vec.size());

Expand Down Expand Up @@ -496,14 +508,18 @@ void performTest(const ProcessingMethod processing_method,
out_scales_rowwise_ptr, out_scales_colwise_ptr,
ref_output_dbias_ptr, M, K,
scales_stride_rowwise,
scales_stride_colwise);
scales_stride_colwise,
use_fast_math);
}

QuantizationConfigWrapper quant_config;
quant_config.set_use_fast_math(use_fast_math);

// GPU
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_group_quantize(in_group_tensor, out_group_tensor, 0);
nvte_group_quantize_v2(in_group_tensor, out_group_tensor, quant_config, 0);
break;
}
case ProcessingMethod::CAST_DBIAS: {
Expand Down Expand Up @@ -554,6 +570,11 @@ void performTest(const ProcessingMethod processing_method,
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;

// Compare only allocated contiguous output range.
// In graph-safe mode logical shape may include trailing garbage beyond offsets_h.back().
const size_t compare_rows = 1;
const size_t compare_cols = elts_num;

if (rowwise) {
cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost);
Expand All @@ -566,7 +587,8 @@ void performTest(const ProcessingMethod processing_method,
const size_t mismatches_elts = 32 * mismatches_scales;

compare_scaled_elts<OutputType>("rowwise_output", out_data_rowwise_ref.data(),
out_data_rowwise_h.data(), rows, cols, true, mismatches_elts);
out_data_rowwise_h.data(), compare_rows, compare_cols,
true, mismatches_elts);
}

if (colwise) {
Expand All @@ -581,7 +603,8 @@ void performTest(const ProcessingMethod processing_method,
const size_t mismatches_elts = 32 * mismatches_scales;

compare_scaled_elts<OutputType>("colwise_output", out_data_colwise_ref.data(),
out_data_colwise_h.data(), rows, cols, false, mismatches_elts);
out_data_colwise_h.data(), compare_rows, compare_cols,
false, mismatches_elts);
}

if (compute_dbias) {
Expand Down Expand Up @@ -651,9 +674,13 @@ std::vector<std::vector<size_t>> input_config = {
{VARYING_FIRST_DIM, 3, 1024,144, 128,384,512},
{VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512},
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
{VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304},
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
{VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640},
// Empty tensor in the middle of the group must not terminate the persistent work loop.
{VARYING_FIRST_DIM, 4, 512,160, 128,0,0,256},
{VARYING_BOTH_DIMS, 3, 1,(128*128)+(128*128), 128,0,128, 128,0,128},
};

} // namespace
Expand All @@ -664,7 +691,8 @@ class GroupedFusedCastMXFP8TestSuite : public ::testing::TestWithParam
ScalingDirection,
std::vector<size_t>, // Config
transformer_engine::DType, // InputType
transformer_engine::DType // OutputType
transformer_engine::DType, // OutputType
bool
>> {};

TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
Expand All @@ -682,6 +710,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
const std::vector<size_t> input_config = std::get<3>(GetParam());
const DType input_type = std::get<4>(GetParam());
const DType output_type = std::get<5>(GetParam());
const bool use_fast_math = std::get<6>(GetParam());

const ShapeRepresentation shape_rep = static_cast<ShapeRepresentation>(input_config[0]);
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);
Expand Down Expand Up @@ -745,6 +774,15 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
|| processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) {
GTEST_SKIP();
}
// Skip fused tests if fast math is enabled.
if (use_fast_math) {
if (processing_method != ProcessingMethod::CAST_ONLY) {
GTEST_SKIP();
}
if ((input_type != DType::kBFloat16) && (input_type != DType::kFloat16)) {
GTEST_SKIP();
}
}

bool rowwise = false;
bool colwise = false;
Expand Down Expand Up @@ -779,7 +817,7 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
performTest<InputType, OutputType>(processing_method, OP, shape_rep, num_tensors,
logical_shape, first_dims, last_dims, offsets,
rowwise, colwise);
rowwise, colwise, use_fast_math);
);
);
}
Expand Down Expand Up @@ -807,6 +845,40 @@ std::string to_string(const ActivationKind activation) {
}
}

// Builds a human-readable gtest instance name for one parameter combination,
// e.g. "CAST_ONLYXIdentity_ROWWISE_VARYING_FIRST_DIM_N_3_SHAPE_1024X144_..._FASTMATH".
// Encodes: processing method, activation, scaling direction, shape representation,
// tensor count, logical shape, input/output dtypes, and the fast-math flag.
std::string MakeGroupedFusedCastMXFP8TestName(
    const testing::TestParamInfo<GroupedFusedCastMXFP8TestSuite::ParamType>& info) {
  const ProcessingMethod method = std::get<0>(info.param);
  std::string name = to_string(method);
  name += "X" + to_string(std::get<1>(info.param));

  switch (std::get<2>(info.param)) {
    case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break;
    case ScalingDirection::COLWISE: name += "_COLWISE_"; break;
    case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break;
  }

  // Bind by const reference: no need to copy the config vector just to read it.
  const std::vector<size_t>& input = std::get<3>(info.param);

  switch (static_cast<ShapeRepresentation>(input[0])) {
    case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break;
    case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break;
    case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break;
    case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break;
  }

  name += "_N_" + std::to_string(input[1]);

  name += "_SHAPE_" + std::to_string(input[2]) + "X" + std::to_string(input[3]);

  name += "_" + test::typeName(std::get<4>(info.param)) +
          "_" + test::typeName(std::get<5>(info.param));

  // Fast-math variants get a distinguishing suffix so both instances coexist.
  if (std::get<6>(info.param)) {
    name += "_FASTMATH";
  }
  return name;
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
GroupedFusedCastMXFP8TestSuite,
Expand All @@ -816,34 +888,6 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(scaling_directions),
::testing::ValuesIn(input_config),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)),
[](const testing::TestParamInfo<GroupedFusedCastMXFP8TestSuite::ParamType>& info) {
const ProcessingMethod method = std::get<0>(info.param);
std::string name = to_string(method);
name += "X" + to_string(std::get<1>(info.param));

switch (std::get<2>(info.param)) {
case ScalingDirection::ROWWISE: name += "_ROWWISE_"; break;
case ScalingDirection::COLWISE: name += "_COLWISE_"; break;
case ScalingDirection::BOTH: name += "_BIDIMENSIONAL_"; break;
}

const std::vector<size_t> input = std::get<3>(info.param);

switch(static_cast<ShapeRepresentation>(input[0])) {
case ShapeRepresentation::SAME_BOTH_DIMS: name += "SAME_BOTH_DIMS"; break;
case ShapeRepresentation::VARYING_FIRST_DIM: name += "VARYING_FIRST_DIM"; break;
case ShapeRepresentation::VARYING_LAST_DIM: name += "VARYING_LAST_DIM"; break;
case ShapeRepresentation::VARYING_BOTH_DIMS: name += "VARYING_BOTH_DIMS"; break;
};

name += "_N_" + std::to_string(input[1]);

name += "_SHAPE_" +
std::to_string(input[2]) +
"X" + std::to_string(input[3]);

name += "_" + test::typeName(std::get<4>(info.param)) +
"_" + test::typeName(std::get<5>(info.param));
return name;
});
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::Values(true, false)),
MakeGroupedFusedCastMXFP8TestName);
9 changes: 9 additions & 0 deletions transformer_engine/common/cast/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
}

// Quantizes a grouped tensor using the configuration in `quant_config`
// (e.g. the fast-math toggle). Pure quantization entry point: no fused
// activation, so the forward helper is instantiated with activation disabled.
void nvte_group_quantize_v2(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                            const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_v2);
  using namespace transformer_engine;

  // No activation is fused into this code path.
  constexpr bool kIsAct = false;
  dispatch::group_quantize_fwd_helper<kIsAct, Empty, nullptr>(input, output, quant_config, stream);
}

void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias);
Expand Down
16 changes: 5 additions & 11 deletions transformer_engine/common/cast/core/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,6 @@ namespace transformer_engine {
namespace dispatch {
namespace common {

enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};

inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
const size_t N = product(t->data.shape);
const bool isFullTile = (N % elems_per_block == 0);
Expand Down Expand Up @@ -100,14 +93,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
const size_t tensor_id = blockIdx.y;
const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (first_logical_dim / num_tensors)
: first_dims_ptr[tensor_id];
: static_cast<size_t>(first_dims_ptr[tensor_id]);

const size_t rows = tensor_rows / chunk_dim_Y;
const size_t cols = last_logical_dim;

const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (tensor_id * (tensor_rows / chunk_dim_Y))
: (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
const size_t dbias_in_offset_Y =
(shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (tensor_id * (tensor_rows / chunk_dim_Y))
: (static_cast<size_t>(offsets_ptr[tensor_id]) / cols / chunk_dim_Y);

const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/cast/dispatch/quantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor
case NVTE_MXFP8_1D_SCALING: {
mxfp8::group_quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
workspace_tensor, &quant_config_cpp, stream);
break;
}
default:
Expand Down Expand Up @@ -450,7 +450,7 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe
case NVTE_MXFP8_1D_SCALING: {
mxfp8::group_quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
&quant_config_cpp, stream);
break;
}
default:
Expand Down
Loading
Loading