From 916ee87b614127e717cce4709b3cbb478cdb52be Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 17 Mar 2026 01:16:08 +0000 Subject: [PATCH 1/8] GEMM + Swiglu fused Grouped MLP for MXFP8 Signed-off-by: Kirthi Shankar Sivamani --- tests/cpp/operator/test_grouped_gemm.cu | 367 ++++++++- tests/cpp/operator/test_swizzle.cu | 74 ++ tests/pytorch/test_fusible_ops.py | 353 ++++++++- tests/pytorch/test_numerics.py | 269 +++++++ .../cast/mxfp8/group_quantize_mxfp8.cuh | 6 +- .../include/transformer_engine/swizzle.h | 15 + .../transformer_engine/transformer_engine.h | 17 + .../common/transformer_engine.cpp | 16 + transformer_engine/pytorch/csrc/quantizer.cpp | 14 +- .../pytorch/csrc/type_converters.cpp | 3 + transformer_engine/pytorch/module/base.py | 9 +- .../pytorch/module/grouped_linear.py | 5 + transformer_engine/pytorch/ops/_common.py | 112 +++ .../pytorch/ops/basic/grouped_linear.py | 208 ++++- .../pytorch/ops/fused/__init__.py | 9 + .../pytorch/ops/fused/backward_grouped_mlp.py | 741 ++++++++++++++++++ .../pytorch/ops/fused/forward_grouped_mlp.py | 596 ++++++++++++++ .../pytorch/tensor/grouped_tensor.py | 3 +- .../tensor/storage/grouped_tensor_storage.py | 123 ++- 19 files changed, 2858 insertions(+), 82 deletions(-) create mode 100644 transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py create mode 100644 transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 34bb729b25..58dcb81f0e 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -88,7 +88,6 @@ Tensor make_bf16_operand(const std::string& name, const std::vector& sha return t; } - // Creates an MXFP8 operand with the correct data layout for GEMM. 
// MXFP8 GEMM requirements (scales are along K dimension): // A transposed -> needs rowwise data/scales @@ -352,12 +351,378 @@ void run_grouped_gemm_case(const TestParams& params) { #endif // CUBLAS_VERSION >= 130200 } +void run_grouped_gemm_discrete_out_case(const TestParams& params) { +#if CUBLAS_VERSION < 130200 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? std::vector{N, K} + : std::vector{K, N}; + const std::vector b_shape = params.transb ? 
std::vector{K, M} + : std::vector{M, K}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kMXFP8: { + A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, + /*is_A=*/true, params.transa)); + B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, + /*is_A=*/false, params.transb)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); + + // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = + Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, // grad + workspace_ptrs.data(), + false, // accumulate + false, // 
use_split_accumulator + 0, // sm_count + 0); + + GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode()); + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_list_tensors; + C_tensors.reserve(num_gemms); + D_list_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + if (!params.use_null_c) { + C_tensors.emplace_back( + Tensor("C" + std::to_string(i), std::vector{M, N}, DType::kBFloat16)); + } + D_list_tensors.emplace_back( + Tensor("D_list" + std::to_string(i), std::vector{M, N}, DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_list_tensors.back().rowwise_dptr(), 0, + bytes(D_list_tensors.back().rowwise_shape(), + D_list_tensors.back().dtype()))); + } + + std::vector C_list_ptrs; + std::vector D_list_ptrs; + if (!params.use_null_c) { + C_list_ptrs.reserve(num_gemms); + } + D_list_ptrs.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + if (!params.use_null_c) { + C_list_ptrs.push_back(C_tensors[i].data()); + } + D_list_ptrs.push_back(D_list_tensors[i].data()); + } + + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + 
nvte_grouped_gemm_with_discrete_out(grouped_A.get_handle(), + params.transa, + grouped_B.get_handle(), + params.transb, + params.use_null_c ? nullptr : C_list_ptrs.data(), + params.use_null_c ? 0 : num_gemms, + D_list_ptrs.data(), + num_gemms, + alpha_tensor.data(), + beta_tensor.data(), + setup_ws.data(), + cublas_ws.data(), + nullptr, // config (use defaults) + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Compare results + for (size_t i = 0; i < num_gemms; ++i) { + D_list_tensors[i].to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_list_vs_multi", + D_list_tensors[i], + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +#endif // CUBLAS_VERSION >= 130200 +} + +void run_grouped_gemm_discrete_in_case(const TestParams& params) { +#if CUBLAS_VERSION < 130200 + GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.2+, but compile-time cuBLAS version is " + << CUBLAS_VERSION << "."; +#else + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer."; + } + + const std::vector> shapes = make_shapes(params.shape_case); + + const size_t num_gemms = shapes.size(); + std::vector A_tensors; + std::vector B_tensors; + std::vector D_multi; + + A_tensors.reserve(num_gemms); + B_tensors.reserve(num_gemms); + D_multi.reserve(num_gemms); + + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + const std::vector a_shape = params.transa ? std::vector{N, K} + : std::vector{K, N}; + const std::vector b_shape = params.transb ? 
std::vector{K, M} + : std::vector{M, K}; + switch (params.input_case) { + case InputCase::kFP8Current: { + A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kBF16: { + A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape)); + B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape)); + break; + } + case InputCase::kMXFP8: { + A_tensors.emplace_back(make_mxfp8_operand("A" + std::to_string(i), a_shape, + /*is_A=*/true, params.transa)); + B_tensors.emplace_back(make_mxfp8_operand("B" + std::to_string(i), b_shape, + /*is_A=*/false, params.transb)); + break; + } + } + D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + + std::vector A_ptrs(num_gemms); + std::vector B_ptrs(num_gemms); + std::vector D_ptrs(num_gemms); + std::vector workspaces(num_gemms); + std::vector workspace_ptrs(num_gemms, nullptr); + std::vector A_views; + std::vector B_views; + A_views.reserve(num_gemms); + B_views.reserve(num_gemms); + + // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues) + std::vector bias_ptrs(num_gemms, nullptr); + std::vector gelu_ptrs(num_gemms, nullptr); + + const size_t cublas_ws_bytes = 32ull * 1024 * 1024; + + for (size_t i = 0; i < num_gemms; ++i) { + A_ptrs[i] = A_tensors[i].data(); + B_ptrs[i] = B_tensors[i].data(); + D_ptrs[i] = D_multi[i].data(); + workspaces[i] = + Tensor("workspace" + std::to_string(i), std::vector{cublas_ws_bytes}, DType::kByte); + workspace_ptrs[i] = workspaces[i].data(); + A_views.push_back(&A_tensors[i]); + B_views.push_back(&B_tensors[i]); + } + + nvte_multi_tensor_gemm(A_ptrs.data(), + B_ptrs.data(), + D_ptrs.data(), + bias_ptrs.data(), + gelu_ptrs.data(), + static_cast(num_gemms), + params.transa, + params.transb, + false, // grad + workspace_ptrs.data(), + false, // accumulate + false, // 
use_split_accumulator + 0, // sm_count + 0); + + GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode()); + + std::vector C_tensors; + std::vector D_group_tensors; + C_tensors.reserve(num_gemms); + D_group_tensors.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + const auto [M, N, K] = shapes[i]; + (void)K; + if (!params.use_null_c) { + C_tensors.emplace_back(Tensor("C" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + } + D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i), + std::vector{M, N}, + DType::kBFloat16)); + NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0, + bytes(D_group_tensors.back().rowwise_shape(), + D_group_tensors.back().dtype()))); + } + + std::vector C_views, D_views; + for (size_t i = 0; i < num_gemms; ++i) { + if (!params.use_null_c) { + C_views.push_back(&C_tensors[i]); + } + D_views.push_back(&D_group_tensors[i]); + } + + std::optional grouped_C; + if (!params.use_null_c) { + grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING); + } + GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING); + + // Per-matrix alpha/beta (all 1.0 and 0.0 respectively) + Tensor alpha_tensor("alpha", std::vector{num_gemms}, DType::kFloat32); + Tensor beta_tensor("beta", std::vector{num_gemms}, DType::kFloat32); + std::vector alpha_vals(num_gemms, 1.f); + std::vector beta_vals(num_gemms, 0.f); + NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(), + num_gemms * sizeof(float), cudaMemcpyHostToDevice)); + + const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms); + Tensor setup_ws("setup_ws", std::vector{setup_ws_bytes}, DType::kByte); + Tensor cublas_ws("cublas_ws", std::vector{cublas_ws_bytes}, DType::kByte); + + std::vector A_list_ptrs; + 
A_list_ptrs.reserve(num_gemms); + for (size_t i = 0; i < num_gemms; ++i) { + A_list_ptrs.push_back(A_tensors[i].data()); + } + + nvte_grouped_gemm_with_discrete_inputA(A_list_ptrs.data(), + num_gemms, + params.transa, + grouped_B.get_handle(), + params.transb, + params.use_null_c ? nullptr : grouped_C->get_handle(), + grouped_D.get_handle(), + alpha_tensor.data(), + beta_tensor.data(), + setup_ws.data(), + cublas_ws.data(), + nullptr, // config (use defaults) + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Compare results + for (size_t i = 0; i < num_gemms; ++i) { + Tensor grouped_split("grouped_D" + std::to_string(i), + std::vector{static_cast(std::get<0>(shapes[i])), + static_cast(std::get<1>(shapes[i]))}, + D_multi[i].dtype()); + const size_t offset_bytes = static_cast(grouped_D.offsets_host[i]) * grouped_D.elem_size; + NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(), + static_cast(grouped_D.get_data()) + offset_bytes, + grouped_D.tensor_bytes[i], + cudaMemcpyDeviceToDevice)); + grouped_split.to_cpu(); + D_multi[i].to_cpu(); + auto [atol, rtol] = getTolerances(D_multi[i].dtype()); + compareResults("grouped_discrete_in_vs_multi", + grouped_split, + D_multi[i].rowwise_cpu_dptr(), + true, + atol, + rtol); + } +#endif // CUBLAS_VERSION >= 130200 +} + class GroupedGemmTest : public ::testing::TestWithParam {}; TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) { run_grouped_gemm_case(GetParam()); } +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemmDiscreteOut) { + run_grouped_gemm_discrete_out_case(GetParam()); +} + +TEST_P(GroupedGemmTest, CompareWithMultiTensorGemmDiscreteIn) { + run_grouped_gemm_discrete_in_case(GetParam()); +} + std::string MakeGroupedGemmTestName(const testing::TestParamInfo& info) { constexpr const char* kInputNames[] = {"FP8Current", "BF16", "MXFP8"}; constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"}; diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 
694b348a9b..97f317965c 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -110,6 +110,76 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row } } +void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const size_t K) { + using namespace transformer_engine; + using namespace test; + + std::vector> input_tensors; + std::vector> output_tensors; + std::vector input_ptrs; + std::vector output_ptrs; + input_tensors.reserve(num_tensors); + output_tensors.reserve(num_tensors); + input_ptrs.reserve(num_tensors); + output_ptrs.reserve(num_tensors); + + const std::vector shape{M, K}; + for (int i = 0; i < num_tensors; ++i) { + auto input = std::make_unique("input_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + auto output = std::make_unique("output_" + std::to_string(i), shape, + DType::kFloat8E4M3, true, true, + NVTE_MXFP8_1D_SCALING); + fillUniform(input.get()); + fillUniform(output.get()); + input_ptrs.push_back(input.get()); + output_ptrs.push_back(output.get()); + input_tensors.emplace_back(std::move(input)); + output_tensors.emplace_back(std::move(output)); + } + + GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING); + GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING); + nvte_set_grouped_tensor_swizzled_scales(grouped_input.get_handle(), 0); + nvte_set_grouped_tensor_swizzled_scales(grouped_output.get_handle(), 1); + + const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape(); + const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape(); + const size_t row_numel = row_shape.data[0] * row_shape.data[1]; + const size_t col_numel = col_shape.data[0] * col_shape.data[1]; + + NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0, num_tensors * row_numel)); + NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0, 
num_tensors * col_numel)); + + nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(), + grouped_output.get_handle(), 0); + + std::vector output_row(num_tensors * row_numel); + std::vector output_col(num_tensors * col_numel); + NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(), + output_row.size(), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(), grouped_output.columnwise_scale_inv.get(), + output_col.size(), cudaMemcpyDeviceToHost)); + + std::vector ref_row(num_tensors * row_numel); + std::vector ref_col(num_tensors * col_numel); + for (int i = 0; i < num_tensors; ++i) { + compute_ref_swizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr(), + ref_row.data() + i * row_numel, + row_shape.data[0], row_shape.data[1]); + compute_ref_swizzle<128, 4, false>( + input_tensors[i]->columnwise_cpu_scale_inv_ptr(), + ref_col.data() + i * col_numel, + col_shape.data[1], col_shape.data[0]); + } + + compareResults("grouped_swizzle_rowwise", output_row.data(), ref_row.data(), + num_tensors * row_numel); + compareResults("grouped_swizzle_colwise", output_col.data(), ref_col.data(), + num_tensors * col_numel); +} + class SwizzleTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; @@ -126,6 +196,10 @@ TEST_P(SwizzleTestSuite, TestSwizzle) { transa); } +TEST(SwizzleGroupedTestSuite, TestGroupedSwizzleMXFP8) { + performTestGroupedSwizzleMXFP8(3, 256, 256); +} + namespace { std::vector> num_tiles = { diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index b97afbc191..1de9d666ed 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -6,6 +6,7 @@ from collections.abc import Iterable import functools +import gc import io import math import random @@ -18,6 +19,7 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine.pytorch.ops as te_ops + from transformer_engine.pytorch.ops.fused 
import ( BackwardActivationBias, BackwardAddRMSNorm, @@ -3236,6 +3238,8 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("single_grouped_parameter", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) @pytest.mark.parametrize("glu_interleave_size", (None, 32)) def test_grouped_mlp( self, @@ -3245,6 +3249,8 @@ def test_grouped_mlp( hidden_size: int = 256, dtype: torch.dtype, quantization: Optional[str], + single_grouped_parameter: bool, + accumulate_into_main_grad: bool, device: torch.device = "cuda", split_alignment: int = 256, glu_interleave_size: Optional[int], @@ -3252,7 +3258,7 @@ def test_grouped_mlp( """GroupedLinear + ScaledSwiGLU + GroupedLinear""" # Split sizes - split_sizes = [split_alignment * i for i in range(group_size)] + split_sizes = [split_alignment * (i) for i in range(group_size)] random.shuffle(split_sizes) split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) @@ -3263,6 +3269,8 @@ def test_grouped_mlp( # Skip invalid configurations with_quantization = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if single_grouped_parameter and quantization != "mxfp8": + pytest.skip("single_grouped_parameter is only supported for MXFP8 quantization") if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") @@ -3370,6 +3378,8 @@ def test_grouped_mlp( bias=bias, device=device, dtype=dtype, + single_grouped_parameter=single_grouped_parameter, + accumulate_into_main_grad=accumulate_into_main_grad, ) fc2 = te_ops.GroupedLinear( group_size, @@ -3378,6 +3388,8 @@ def test_grouped_mlp( bias=bias, device=device, dtype=dtype, + single_grouped_parameter=single_grouped_parameter, 
+ accumulate_into_main_grad=accumulate_into_main_grad, ) module = te_ops.Sequential( fc1, @@ -3387,12 +3399,51 @@ def test_grouped_mlp( # Copy weights with torch.no_grad(): + if single_grouped_parameter: + fc1_weights = fc1.weight.quantized_tensors + if fc1_weights is None: + fc1_weights = fc1.weight.split_into_quantized_tensors() + fc2_weights = fc2.weight.quantized_tensors + if fc2_weights is None: + fc2_weights = fc2.weight.split_into_quantized_tensors() for group_idx in range(group_size): - getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) - getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if single_grouped_parameter: + fc1_weights[group_idx].copy_(fc1_ws_test[group_idx]) + fc2_weights[group_idx].copy_(fc2_ws_test[group_idx]) + else: + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) if bias: getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + if accumulate_into_main_grad: + if single_grouped_parameter: + fc1.weight.main_grad = torch.full( + fc1.weight.size(), + 0.5, + device=device, + dtype=torch.float32, + ) + fc2.weight.main_grad = torch.full( + fc2.weight.size(), + 0.5, + device=device, + dtype=torch.float32, + ) + else: + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").main_grad = torch.full( + getattr(fc1, f"weight{group_idx}").size(), + 0.5, + device=device, + dtype=torch.float32, + ) + getattr(fc2, f"weight{group_idx}").main_grad = torch.full( + getattr(fc2, f"weight{group_idx}").size(), + 0.5, + device=device, + dtype=torch.float32, + ) del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test # Fuse ops and perform forward and backward pass @@ -3400,6 +3451,28 @@ def test_grouped_mlp( y_test = module(x_test, split_sizes, probs_test, split_sizes) y_test.backward(dy_test) + # Check for expected fusions + if ( + quantization == "mxfp8" + and 
dtype in (torch.bfloat16, torch.float16) + and not bias + and glu_interleave_size == 32 + ): + assert te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() + assert te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported() + forward_ops = module._module_groups[0]._forward_ops + backward_ops = module._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ) + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + ) + # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} if quantization == "nvfp4": @@ -3410,10 +3483,280 @@ def test_grouped_mlp( assert_close_grads(x_test, x_ref, **tols) assert_close_grads(probs_test, probs_ref, **tols) for group_idx in range(group_size): - assert_close_grads(getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols) assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) - assert_close_grads(getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols) assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + if not single_grouped_parameter and not accumulate_into_main_grad: + assert_close_grads( + getattr(fc2, f"weight{group_idx}"), fc2_ws_ref[group_idx], **tols + ) + assert_close_grads( + getattr(fc1, f"weight{group_idx}"), fc1_ws_ref[group_idx], **tols + ) + fc1_w_ref_grad = torch.stack([w.grad for w in fc1_ws_ref], dim=0) + fc2_w_ref_grad = torch.stack([w.grad for w in fc2_ws_ref], dim=0) + if accumulate_into_main_grad: + if single_grouped_parameter: + fc1_w_test_grad = fc1.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5 + fc2_w_test_grad = fc2.weight.main_grad.to(dtype=torch.float64, device="cpu") - 0.5 + else: + fc1_w_test_grad = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").main_grad.to( + dtype=torch.float64, 
device="cpu" + ) + - 0.5 + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_w_test_grad = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").main_grad.to( + dtype=torch.float64, device="cpu" + ) + - 0.5 + for group_idx in range(group_size) + ], + dim=0, + ) + assert_close(fc1_w_test_grad, fc1_w_ref_grad, **tols) + assert_close(fc2_w_test_grad, fc2_w_ref_grad, **tols) + elif single_grouped_parameter: + assert_close(fc1.weight.grad, fc1_w_ref_grad, **tols) + assert_close(fc2.weight.grad, fc2_w_ref_grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("single_grouped_parameter", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_grouped_mlp_cuda_graph_safe_mxfp8( + self, + *, + dtype: torch.dtype, + single_grouped_parameter: bool, + accumulate_into_main_grad: bool, + device: torch.device = "cuda", + group_size: int = 4, + hidden_size: int = 256, + split_alignment: int = 256, + glu_interleave_size: int = 32, + ) -> None: + """Grouped MLP forward+backward should be CUDA graph capturable (MXFP8).""" + + if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + pytest.skip("MXFP8 fused grouped MLP is not supported on this system") + if dtype not in (torch.bfloat16, torch.float16): + pytest.skip("MXFP8 fused grouped MLP is only supported with BF16/FP16") + + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int64, device=device) + in_shape = (split_sizes.sum().item(), hidden_size) + + recipe = make_recipe("mxfp8") + with te.quantized_model_init(enabled=True, recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_parameter=single_grouped_parameter, + 
accumulate_into_main_grad=accumulate_into_main_grad, + ) + fc2 = te_ops.GroupedLinear( + group_size, + hidden_size, + hidden_size, + bias=False, + device=device, + dtype=dtype, + single_grouped_parameter=single_grouped_parameter, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + module = te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + def _init_main_grads(value: float = 0.0) -> None: + if not accumulate_into_main_grad: + return + with torch.no_grad(): + if single_grouped_parameter: + if getattr(fc1.weight, "main_grad", None) is None: + fc1.weight.main_grad = torch.empty( + fc1.weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2.weight, "main_grad", None) is None: + fc2.weight.main_grad = torch.empty( + fc2.weight.size(), + device=device, + dtype=torch.float32, + ) + fc1.weight.main_grad.fill_(value) + fc2.weight.main_grad.fill_(value) + else: + for group_idx in range(group_size): + fc1_weight = getattr(fc1, f"weight{group_idx}") + fc2_weight = getattr(fc2, f"weight{group_idx}") + if getattr(fc1_weight, "main_grad", None) is None: + fc1_weight.main_grad = torch.empty( + fc1_weight.size(), + device=device, + dtype=torch.float32, + ) + if getattr(fc2_weight, "main_grad", None) is None: + fc2_weight.main_grad = torch.empty( + fc2_weight.size(), + device=device, + dtype=torch.float32, + ) + fc1_weight.main_grad.fill_(value) + fc2_weight.main_grad.fill_(value) + + def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]: + if single_grouped_parameter: + fc1_main_grad = fc1.weight.main_grad.detach().clone() + fc2_main_grad = fc2.weight.main_grad.detach().clone() + else: + fc1_main_grad = torch.stack( + [ + getattr(fc1, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + fc2_main_grad = torch.stack( + [ + getattr(fc2, f"weight{group_idx}").main_grad.detach().clone() + for group_idx in range(group_size) + ], + dim=0, + ) + 
return fc1_main_grad, fc2_main_grad + + static_split_sizes = split_sizes.clone() + + def train_step( + x: torch.Tensor, + probs: torch.Tensor, + dy: torch.Tensor, + out_buf: torch.Tensor, + *, + use_graphed: bool, + ) -> torch.Tensor: + with te.autocast(enabled=True, recipe=recipe): + out = ( + graphed_module(x, static_split_sizes, probs, static_split_sizes) + if use_graphed + else module(x, static_split_sizes, probs, static_split_sizes) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + # Warmup to initialize kernels and allocator state. + _init_main_grads(0.0) + warmup_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + warmup_probs = torch.randn((in_shape[0],), device=device, dtype=dtype, requires_grad=True) + warmup_dy = torch.randn(in_shape, device=device, dtype=dtype) + warmup_out = torch.empty((in_shape[0], hidden_size), device=device, dtype=dtype) + # Single forward+backward to initialize MXFP8 grad cache. + train_step(warmup_x, warmup_probs, warmup_dy, warmup_out, use_graphed=False) + # Clear warmup graph references before capture. 
+ del warmup_out, warmup_x, warmup_probs, warmup_dy + gc.collect() + torch.cuda.synchronize() + + static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + static_probs = torch.randn((in_shape[0],), device=device, dtype=dtype, requires_grad=True) + static_dy = torch.randn(in_shape, device=device, dtype=dtype) + static_out_buf = torch.empty((in_shape[0], hidden_size), device=device, dtype=dtype) + + forward_ops = module._module_groups[0]._forward_ops + backward_ops = module._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert isinstance( + forward_ops[0][0], + te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ) + assert len(backward_ops) == 1 + assert isinstance( + backward_ops[0][0], + te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, + ) + + graphed_module = te.make_graphed_callables( + module, + (static_x, static_split_sizes, static_probs, static_split_sizes), + num_warmup_iters=3, + enabled=True, + recipe=recipe, + ) + + fresh_x = torch.randn_like(static_x) + fresh_probs = torch.randn_like(static_probs) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_probs.copy_(fresh_probs) + static_dy.copy_(fresh_dy) + + for param in module.parameters(): + param.grad = torch.zeros_like(param) + _init_main_grads(0.5) + if static_x.grad is not None: + static_x.grad.zero_() + if static_probs.grad is not None: + static_probs.grad.zero_() + + graph_out = ( + train_step(static_x, static_probs, static_dy, static_out_buf, use_graphed=True) + .detach() + .clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + graph_dprobs = static_probs.grad.detach().clone() + if accumulate_into_main_grad: + graph_fc1_main_grad, graph_fc2_main_grad = _collect_main_grads() + else: + graph_param_grads = [param.grad.detach().clone() for param in module.parameters()] + + for param in module.parameters(): + param.grad.zero_() + _init_main_grads(0.5) + static_x.grad.zero_() 
+ static_probs.grad.zero_() + + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_probs = fresh_probs.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with te.autocast(enabled=True, recipe=recipe): + expected_out = module( + expected_x, + static_split_sizes, + expected_probs, + static_split_sizes, + ) + expected_out.backward(expected_dy) + + tols = dtype_tols(dtype) + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) + assert_close(graph_dprobs, expected_probs.grad, **tols) + if accumulate_into_main_grad: + expected_fc1_main_grad, expected_fc2_main_grad = _collect_main_grads() + assert_close(graph_fc1_main_grad, expected_fc1_main_grad, **tols) + assert_close(graph_fc2_main_grad, expected_fc2_main_grad, **tols) + else: + for graph_grad, param in zip(graph_param_grads, module.parameters()): + assert_close(graph_grad, param.grad, **tols) class TestCustomOps: diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 19b94d3531..cdab6a6bc1 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2709,6 +2709,275 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): ) +def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: + data = grouped_tensor.rowwise_data + if data is None: + data = grouped_tensor.columnwise_data + if data is None: + raise ValueError("GroupedTensor has no data buffers to pack.") + offset = 0 + for tensor in tensors: + numel = tensor.numel() + data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + +def _make_grouped_tensor_from_splits( + m_sizes: List[int], + last_dim: int, + device: torch.device, + dtype: torch.dtype, +) -> GroupedTensor: + first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) + return GroupedTensor.make_grouped_tensor( + num_tensors=len(m_sizes), + first_dims=first_dims, + last_dims=None, + 
logical_first_dim=sum(m_sizes), + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + +def _make_grouped_tensor_uniform( + num_tensors: int, + first_dim: int, + last_dim: int, + device: torch.device, + dtype: torch.dtype, +) -> GroupedTensor: + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=None, + last_dims=None, + logical_first_dim=num_tensors * first_dim, + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + +def _make_mxfp8_quantizer(*, is_a: bool, transposed: bool) -> MXFP8Quantizer: + if is_a: + rowwise = transposed + columnwise = not transposed + else: + rowwise = not transposed + columnwise = transposed + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + quantizer.optimize_for_gemm = True + return quantizer + + +def _make_grouped_tensor_quantized( + tensors: List[torch.Tensor], + quantizer: MXFP8Quantizer, + device: torch.device, + dtype: torch.dtype, +) -> GroupedTensor: + shapes = [tuple(t.size()) for t in tensors] + grouped = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=len(tensors), + shapes=shapes, + quantizer=quantizer, + device=device, + dtype=dtype, + ) + grouped.quantize(tensors) + return grouped + + +@pytest.mark.parametrize("case", ["no_discrete", "discrete_in", "discrete_out"]) +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False, True]) +def test_grouped_gemm_grouped_tensor(case, layout, accumulate) -> None: + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + if case == "discrete_in" and not accumulate: + pytest.xfail("discrete_in accumulate=False not supported yet.") 
+ if case == "discrete_out" and (layout != "NT"): + pytest.skip("discrete_out only covers NT") + if layout == "NT" and case == "discrete_in": + pytest.skip("NT is not supported for discrete_in.") + + torch.manual_seed(0) + + z, m, n, k = (4, 512, 512, 256) + dtype = torch.bfloat16 + + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_sizes = [split_points[0]] + m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_sizes.append(m - split_points[-1]) + assert sum(m_sizes) == m and len(m_sizes) == z + + if layout == "NT": + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [torch.matmul(B[i].transpose(0, 1).float(), A[i].float()) for i in range(z)] + else: + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [ + torch.randn(ms, k if layout == "TN" else n, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> input, NN --> grad_output + out = [ + torch.randn(ms, n if layout == "TN" else k, dtype=dtype, device="cuda") + for ms in m_sizes + ] # TN --> output, NN --> dgrad + if layout == "NN": + out_ref = [torch.matmul(B[i].float(), A[i].float()) for i in range(z)] + else: # layout == "TN" + out_ref = [torch.matmul(B[i].float(), A[i].transpose(0, 1).float()) for i in range(z)] + + if accumulate: + out_ref = [out[i].float() + o for i, o in enumerate(out_ref)] + out_ref = [o.to(dtype) for o in out_ref] + + # Create grouped tensors based on case + device = A[0].device + grouped_A = A + grouped_out = out + if layout == "TN": + grouped_A = ( + _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # + grouped_B = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) # input + grouped_out = ( + 
_make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + if case != "discrete_out" + else grouped_out + ) # output + elif layout == "NN": + grouped_A = ( + _make_grouped_tensor_uniform(z, n, k, device, dtype) if case != "discrete_in" else A + ) # weight + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + grouped_out = ( + _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + if case != "discrete_out" + else grouped_out + ) # dgrad + else: # layout == "NT" + grouped_A = ( + _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + if case != "discrete_in" + else A + ) # input + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) # grad_output + grouped_out = ( + _make_grouped_tensor_uniform(z, n, k, device, dtype) + if case != "discrete_out" + else grouped_out + ) # wgrad + _pack_grouped_tensor(grouped_B, B) + if case != "discrete_out": + _pack_grouped_tensor(grouped_out, out) + if case != "discrete_in": + _pack_grouped_tensor(grouped_A, A) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=accumulate, + ) + out_grouped = ( + grouped_out if isinstance(grouped_out, list) else grouped_out.split_into_quantized_tensors() + ) + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) + + +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +def test_grouped_gemm_grouped_tensor_mxfp8(layout: str) -> None: + torch.manual_seed(0) + z = 3 + m_sizes = [512, 512, 512] + n, k = 512, 512 + dtype = torch.float16 + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output + grad = False + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in 
range(z)] # weight + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad + grad = True + else: # layout == "NT" + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + + out_ref = [o.clone() for o in out] + + transa = layout[0] == "T" + transb = layout[1] == "T" + a_quantizer = _make_mxfp8_quantizer(is_a=True, transposed=transa) + b_quantizer = _make_mxfp8_quantizer(is_a=False, transposed=transb) + + grouped_A = _make_grouped_tensor_quantized(A, a_quantizer, "cuda", dtype) + grouped_B = _make_grouped_tensor_quantized(B, b_quantizer, "cuda", dtype) + A_fp8 = grouped_A.split_into_quantized_tensors() + B_fp8 = grouped_B.split_into_quantized_tensors() + + general_grouped_gemm( + A_fp8, + B_fp8, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=False, + layout=layout, + single_output=False, + ) + + device = A[0].device + + if layout == "TN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n, device, dtype) + elif layout == "NN": + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k, device, dtype) + else: # layout == "NT" + grouped_out = _make_grouped_tensor_uniform(z, n, k, device, dtype) + + _pack_grouped_tensor(grouped_out, out) + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=False, + ) + + out_grouped = grouped_out.split_into_quantized_tensors() + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) + + @pytest.mark.parametrize( "shape", [ diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh 
b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 129d6724ac..5e32b2ac44 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -810,8 +810,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS; } - // Treat a grouped tensor with const last dims as a single tensor - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM); + const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; + const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); NVTE_CHECK(input->num_tensors == output->num_tensors, "Number of input and output tensors must be same."); @@ -848,8 +848,6 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const dim3 grid(blocks_X, blocks_Y); const size_t block_size = THREADS_PER_CHUNK; - const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; - // Logical shape of a tensor with varying all dims is [1, M*K] if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { NVTE_CHECK(first_logical_dim % 128 == 0, diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 904812118c..03d88ef0b6 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -45,6 +45,21 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM (grouped tensor) + * + * \param[in] input Input grouped tensor with non-swizzled scale_inv. 
+ * \param[in,out] output Output grouped tensor which hosts swizzled scale_inv.
+ * \param[in] stream CUDA stream used for the operation.
+ *
+ * Requirements:
+ * - scaling mode must be MXFP8 1D scaling.
+ * - scale_inv is stored in row-major per group.
+ * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
+ * - data is quantized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
+ */
+void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
+                                          cudaStream_t stream);
+
 /*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM
 *
 * \param[in] input Input FP8 block-scaled tensor. diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..5c777f1202 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -558,6 +558,23 @@ NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) */ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor); +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Set whether the grouped tensor has GEMM-swizzled scales. + * + * \param[in] tensor Grouped tensor. + * \param[in] val 1 if scales are swizzled, 0 otherwise. + */ +void nvte_set_grouped_tensor_swizzled_scales(NVTEGroupedTensor tensor, uint8_t val); + +/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ +/*! \brief Get whether the grouped tensor has GEMM-swizzled scales. + * + * \param[in] tensor Grouped tensor. + * + * \return 1 if scales are swizzled, 0 otherwise. 
+ */ +uint8_t nvte_get_grouped_tensor_swizzled_scales(const NVTEGroupedTensor tensor); + #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b97504f2ae..7791c9c82c 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1329,3 +1329,19 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); return t.logical_shape; } + +void nvte_set_grouped_tensor_swizzled_scales(NVTEGroupedTensor tensor, uint8_t val) { + if (tensor == nullptr) { + return; + } + auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + t.with_gemm_swizzled_scales = (val != 0); +} + +uint8_t nvte_get_grouped_tensor_swizzled_scales(const NVTEGroupedTensor tensor) { + if (tensor == nullptr) { + return 0; + } + const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); + return t.with_gemm_swizzled_scales ? 1 : 0; +} diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..1269a94d6b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -414,7 +414,7 @@ std::pair Float8Quantizer::create_grouped_tens kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = py::cast(false); + kwargs["with_gemm_swizzled_scales"] = py::cast(this->optimize_for_gemm); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -737,7 +737,7 @@ std::pair Float8CurrentScalingQuantizer::creat kwargs["first_dims"] = first_dims.has_value() ? 
py::cast(*first_dims) : py::none(); kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = py::cast(false); + kwargs["with_gemm_swizzled_scales"] = py::cast(this->optimize_for_gemm); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -1100,7 +1100,7 @@ std::pair Float8BlockQuantizer::create_grouped kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = py::cast(false); + kwargs["with_gemm_swizzled_scales"] = py::cast(this->optimize_for_gemm); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -1478,6 +1478,8 @@ std::pair MXFP8Quantizer::create_grouped_tenso columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); } + const bool with_gemm_swizzled_scales = this->optimize_for_gemm; + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); if (rowwise_usage) { out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); @@ -1497,6 +1499,8 @@ std::pair MXFP8Quantizer::create_grouped_tenso out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, getTensorShape(*tensor_offsets)); } + nvte_set_grouped_tensor_swizzled_scales(out_cpp.data(), + static_cast(with_gemm_swizzled_scales)); out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); @@ -1521,7 +1525,7 @@ std::pair MXFP8Quantizer::create_grouped_tenso kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + kwargs["with_gemm_swizzled_scales"] = py::cast(this->optimize_for_gemm); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -1954,7 +1958,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + kwargs["with_gemm_swizzled_scales"] = py::cast(this->optimize_for_gemm); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index e9c6ca882e..3e11606ff7 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -285,6 +285,9 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { bool with_gemm_swizzled = false; if (py::hasattr(tensor, "_with_gemm_swizzled_scales")) { with_gemm_swizzled = tensor.attr("_with_gemm_swizzled_scales").cast(); + } else if (py::hasattr(tensor, "with_gemm_swizzled_scales") && + !tensor.attr("with_gemm_swizzled_scales").is_none()) { + with_gemm_swizzled = tensor.attr("with_gemm_swizzled_scales").cast(); } ret.set_with_gemm_swizzled_scales(with_gemm_swizzled); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 28da4873f0..6c708ed397 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -82,17 +82,18 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor """Returns a dummy tensor of 
given shape.""" if len(shape) != 2: raise ValueError(f"Expected 2D shape, got {len(shape)}D: {shape}") + key = (*shape, dtype) global _dummy_wgrads - if (shape[0], shape[1], dtype) not in _dummy_wgrads: - _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( + if key not in _dummy_wgrads: + _dummy_wgrads[key] = torch.empty( shape, dtype=dtype, device="cuda", requires_grad=False, ) if zero: - _dummy_wgrads[(shape[0], shape[1], dtype)].fill_(0) - return _dummy_wgrads[(shape[0], shape[1], dtype)].detach() + _dummy_wgrads[key].fill_(0) + return _dummy_wgrads[key].detach() def initialize_ub( diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 30c1dbf408..6cc0e1c0a1 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -6,6 +6,7 @@ from typing import Union, Optional, Callable, Tuple, List from itertools import chain import warnings +import os import functools import torch @@ -633,6 +634,10 @@ def __init__( ) -> None: super().__init__(name) + # Temporary for quick testing. 
+ if os.getenv("_NVTE_SINGLE_GROUPED_PARAMETER_TMP_VAR", None) == "1": + single_grouped_parameter = True + self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms self.in_features = in_features diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 4520dbc313..f955eabd87 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -15,6 +15,9 @@ from ..tensor.float8_tensor import Float8Tensor from ..quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype +from ..module._common import noop_cat +from ..tensor import Quantizer +from ..tensor.grouped_tensor import GroupedTensor def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: @@ -71,3 +74,112 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device) fp8_meta.scale_inv = tensor._scale_inv return fp8_meta, 0 + + +def make_grouped_tensor_from_buffers( + *, + num_groups: int, + data: torch.Tensor, + split_sizes: torch.Tensor, + columnwise_data: torch.Tensor = None, + scale_inv: torch.Tensor = None, + columnwise_scale_inv: torch.Tensor = None, + tensor_offsets: torch.Tensor = None, + logical_last_dim: int, + dtype: torch.dtype, + quantizer: Quantizer = None, + with_gemm_swizzled_scales: bool = False, +) -> GroupedTensor: + """Build GroupedTensor from FC1+SwiGLU / dSwiGLU kernel outputs. + + Scales are already in GEMM swizzled layout. 
+ """ + if tensor_offsets is None: + tensor_offsets = GroupedTensor.make_tensor_offsets(split_sizes, logical_last_dim) + logical_first_dim = data.shape[0] if data is not None else columnwise_data.shape[0] + ndim = data.ndim if data is not None else columnwise_data.ndim + if ndim == 1: + logical_first_dim = logical_first_dim // logical_last_dim + return GroupedTensor( + shape=(logical_first_dim, logical_last_dim), + dtype=dtype, + quantizer=quantizer, + num_tensors=num_groups, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=None, + columnwise_amax=None, + scale=None, + first_dims=split_sizes, + last_dims=None, + tensor_offsets=tensor_offsets, + offsets=None, + scale_inv_offsets=None, + columnwise_scale_inv_offsets=None, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) + + +def make_grouped_tensor_from_mxfp8_weights( + weights: list, + quantizer: Quantizer, + device: torch.device, + dtype: torch.dtype, + with_gemm_swizzled_scales: bool = False, +) -> GroupedTensor: + """Build a GroupedTensor from MXFP8 weight tensors by packing their buffers (no copy when contiguous).""" + num_groups = len(weights) + weight_shape = weights[0].shape + O, I = weight_shape[0], weight_shape[1] + logical_first_dim = num_groups * O + logical_last_dim = I + + tensor_offsets = None + data = None + scale_inv = None + scale_inv_offsets = None + columnwise_data = None + columnwise_scale_inv = None + columnwise_scale_inv_offsets = None + + # Pack rowwise into data/scale_inv when available. + # GEMM expects scales in swizzled layout (same as FC1 weight scales in grouped_gemm_swiglu). 
+ if weights[0]._rowwise_data is not None: + data = noop_cat([w._rowwise_data.reshape(-1) for w in weights]) + rowwise_scales = noop_cat([w._rowwise_scale_inv for w in weights]) + if with_gemm_swizzled_scales: + rowwise_scales = rowwise_scales.view(num_groups, O // 128, 4, 32, I // 128, 4) + rowwise_scales = rowwise_scales.permute(0, 1, 4, 3, 2, 5).contiguous() + scale_inv = rowwise_scales.reshape(-1) + # Pack columnwise into columnwise_* when available. + # GEMM expects columnwise scales in swizzled layout (same as FC2 weight scales in backward dSwiGLU kernel). + if weights[0]._columnwise_data is not None: + columnwise_data = noop_cat([w._columnwise_data.reshape(-1) for w in weights]) + columnwise_scales = noop_cat([w._columnwise_scale_inv for w in weights]) + if with_gemm_swizzled_scales: + columnwise_scales = columnwise_scales.view(num_groups, O // 128, 4, I // 128, 4, 32) + columnwise_scales = columnwise_scales.permute(0, 3, 1, 5, 4, 2).contiguous() + columnwise_scale_inv = columnwise_scales.reshape(-1) + + return GroupedTensor( + shape=(logical_first_dim, logical_last_dim), + dtype=dtype, + num_tensors=num_groups, + quantizer=quantizer, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=None, + columnwise_amax=None, + scale=None, + first_dims=None, + last_dims=None, + tensor_offsets=tensor_offsets, + offsets=None, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + ) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index b44e77b0c6..45295a2324 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -32,6 +32,7 @@ ) from .._common import is_quantized_tensor, maybe_dequantize from ..op import BasicOperation, OperationContext +from 
...tensor import GroupedTensor class GroupedLinear(BasicOperation): @@ -86,6 +87,7 @@ def __init__( dtype: Optional[torch.dtype] = None, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, + single_grouped_parameter: bool = False, ) -> None: super().__init__() @@ -93,6 +95,7 @@ def __init__( self.num_groups: int = num_groups self.in_features: int = in_features self.out_features: int = out_features + self.single_grouped_parameter: bool = single_grouped_parameter if self.num_groups <= 0: raise ValueError(f"Invalid number of groups ({self.num_groups})") if self.in_features <= 0: @@ -116,12 +119,15 @@ def __init__( self._rng_state_tracker_function = rng_state_tracker_function # Register weights + # TODO(ksivaman): Proper support for meta device. + # We do not want to reset params later as it wipes off + # main_grad and related attributes. self.weight0: torch.nn.Parameter for group_idx in range(self.num_groups): weight_tensor = torch.empty( self.out_features, self.in_features, - device="meta", + device=device, dtype=dtype, ) self.register_parameter( @@ -136,7 +142,7 @@ def __init__( if bias: bias_tensor = torch.empty( self.out_features, - device="meta", + device=device, dtype=dtype, ) bias_tensor = torch.nn.Parameter(bias_tensor) @@ -232,6 +238,46 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(packed_biases[group_idx]) setattr(self, f"bias{group_idx}", bias) + if self.single_grouped_parameter: + self.make_grouped_weights() + + def make_grouped_weights(self) -> None: + """ + Convert parameters into a GroupedTensor and re-register them as parameters. 
+ """ + + weights = [getattr(self, f"weight{idx}") for idx in range(self.num_groups)] + quantizer = self.get_quantizer("forward", 1) + + recipe = None if quantizer is None else quantizer._get_compatible_recipe() + if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()): + return + + grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=self.num_groups, + shapes=[(self.out_features, self.in_features)] * self.num_groups, + quantizer=quantizer, + dtype=self.weight0.dtype, + device=self.weight0.device, + ) + + # Copy existing params into storage. + with torch.no_grad(): + for i in range(self.num_groups): + if self._with_quantized_weight: + grouped_weights.quantized_tensors[i].copy_from_storage(weights[i]) + else: + grouped_weights.quantized_tensors[i].copy_(weights[i]) + + assert isinstance(grouped_weights, torch.Tensor) and ( + quantizer is None or not quantizer.internal + ), "Found internal quantizer with `single_grouped_parameter=True`." + + # Re-register as a single grouped weight parameter. + self.register_parameter("weight", torch.nn.Parameter(grouped_weights)) + for group_idx in range(self.num_groups): + self.register_parameter(f"weight{group_idx}", None) + def _quantize_weights( self, weights: Sequence[torch.Tensor], @@ -328,33 +374,40 @@ def pre_first_fuser_forward(self) -> None: if any(param.device.type == "meta" for param in self.parameters()): self.reset_parameters() - # Check that weights are consistent - dtype = self.weight0.dtype - device = self.weight0.device - weight_requires_grad = self.weight0.requires_grad - weight_tensor_type = type(self.weight0.data) - for group_idx in range(self.num_groups): - weight = getattr(self, f"weight{group_idx}") - if weight.dtype != dtype: - raise RuntimeError( - f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})." 
- ) - if not devices_match(weight.device, device): - raise RuntimeError( - f"Weight {group_idx} has invalid device " - f"(expected {device}, got {weight.device})." - ) - if weight.requires_grad != weight_requires_grad: - raise RuntimeError( - f"Weight {group_idx} has requires_grad={weight.requires_grad}, " - f"but expected requires_grad={weight_requires_grad}." - ) - if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck - raise RuntimeError( - f"Weight {group_idx} has invalid tensor type " - f"(expected {weight_tensor_type.__name__}, " - f"got {type(weight.data).__name__})." - ) + # Check that all weight params are consistent + if not self.single_grouped_parameter: + dtype = self.weight0.dtype + device = self.weight0.device + weight_requires_grad = self.weight0.requires_grad + weight_tensor_type = type(self.weight0.data) + for group_idx in range(self.num_groups): + weight = getattr(self, f"weight{group_idx}") + if weight.dtype != dtype: + raise RuntimeError( + f"Weight {group_idx} has invalid dtype (expected {dtype}, got" + f" {weight.dtype})." + ) + if not devices_match(weight.device, device): + raise RuntimeError( + f"Weight {group_idx} has invalid device " + f"(expected {device}, got {weight.device})." + ) + if weight.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Weight {group_idx} has requires_grad={weight.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck + raise RuntimeError( + f"Weight {group_idx} has invalid tensor type " + f"(expected {weight_tensor_type.__name__}, " + f"got {type(weight.data).__name__})." 
+ ) + else: + dtype = self.weight.dtype + device = self.weight.device + weight_requires_grad = self.weight.requires_grad + weight_tensor_type = type(self.weight.data) # Check that biases are consistent for group_idx in range(self.num_groups): @@ -384,7 +437,12 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: super().pre_fuser_forward(requires_grad=requires_grad) if FP8GlobalStateManager.is_fp8_enabled(): # Assume weights have consistent grad requirement - weight_requires_grad = requires_grad and self.weight0.requires_grad + weight_requires_grad = ( + self.weight.requires_grad + if self.single_grouped_parameter + else self.weight0.requires_grad + ) + weight_requires_grad = requires_grad and weight_requires_grad # Configure quantizer usages # Note: We cache the quantized input for backward pass, @@ -419,13 +477,17 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: # Make sure weight param has correct quantizer weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) weight_quantizer.internal = False - getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) + if self.single_grouped_parameter: + self.weight.quantizer = weight_quantizer.copy() + else: + getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) else: # Use internal tensors if quantized weights will not be # exposed externally weight_quantizer.internal = ( not FP8GlobalStateManager.with_fp8_parameters() and not getattr(self, "_with_quantized_weight", False) + and not self.single_grouped_parameter ) # Recipe-specific configuration @@ -472,12 +534,17 @@ def fuser_forward( ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: num_groups = self.num_groups has_bias = self.has_bias - device = self.weight0.device + weight_param = self.weight if self.single_grouped_parameter else self.weight0 + device = weight_param.device + + if self._accumulate_into_main_grad: + assert hasattr(weight_param, "main_grad"), "MAIN 
GRAD NOT FOUND !!!!" + assert weight_param.main_grad is not None, "MAIN GRAD IS NONE !!!!" # Check which grads are required ctx = basic_op_ctxs[0] input_requires_grad = ctx.requires_grad - weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad + weight_requires_grad = ctx.requires_grad and weight_param.requires_grad # Quantizers input_quantizers = [None] * num_groups @@ -494,7 +561,7 @@ def fuser_forward( if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") else: - dtype = self.weight0.dtype + dtype = weight_param.dtype # Extract split sizes from extra input split_sizes = basic_op_extra_inputs[0][0] @@ -503,7 +570,12 @@ def fuser_forward( raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.") # Extract params - weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] + if self.single_grouped_parameter: + weights = self.weight.quantized_tensors + if weights is None: + weights = self.weight.split_into_quantized_tensors() + else: + weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] bs = None if has_bias: bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)] @@ -589,7 +661,8 @@ def fuser_backward( ]: num_groups = self.num_groups has_bias = self.has_bias - device = self.weight0.device + weight_param = self.weight if self.single_grouped_parameter else self.weight0 + device = weight_param.device # Saved tensors from forward pass ctx = basic_op_ctxs[0] @@ -628,14 +701,42 @@ def fuser_backward( # Megatron-LM wgrad fusion # Note: Get grad tensors from params so we can # accumulate directly into it. 
- for group_idx in range(num_groups): - weight_param = getattr(self, f"weight{group_idx}") + if self.single_grouped_parameter: if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() - grad_weights[group_idx] = weight_param.main_grad - accumulate_into_main_grad = not getattr(self.weight0, "overwrite_main_grad", False) + main_grad = weight_param.main_grad + if isinstance(main_grad, GroupedTensor): + grad_weights = main_grad.quantized_tensors + if grad_weights is None: + grad_weights = main_grad.split_into_quantized_tensors() + else: + # main_grad may be [num_groups, out, in] or a flat buffer. + # Canonicalize to grouped layout before slicing per-group views. + weight_shape = (self.out_features, self.in_features) + grouped_shape = (num_groups, *weight_shape) + if main_grad.shape != grouped_shape: + if main_grad.numel() != math.prod(grouped_shape): + raise RuntimeError( + "GroupedLinear expected grouped weight main_grad to have " + f"shape {grouped_shape} or matching numel, " + f"but got shape {tuple(main_grad.shape)}" + ) + main_grad = main_grad.reshape(grouped_shape) + grad_weights = [main_grad[idx] for idx in range(num_groups)] + accumulate_into_main_grad = not getattr( + weight_param, "overwrite_main_grad", False + ) + else: + for group_idx in range(num_groups): + weight_param = getattr(self, f"weight{group_idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + grad_weights[group_idx] = weight_param.main_grad + accumulate_into_main_grad = not getattr( + self.weight0, "overwrite_main_grad", False + ) else: - weight_shape = ws[0].size() + weight_shape = (self.out_features, self.in_features) for group_idx in range(num_groups): grad_weights[group_idx] = torch.empty( weight_shape, @@ -688,6 +789,20 @@ def fuser_backward( # Note: Return dummy tensor for grad weight if needed. 
if accumulate_into_main_grad: grad_weights = [None] * num_groups + if self.single_grouped_parameter: + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + else: + grad_weight = None + # Parameter registration order with single_grouped_parameter=True is: + # bias0..biasN-1, then weight. Return grads in the same order. + grad_params = grad_biases + [grad_weight] if has_bias else [grad_weight] + return grad_input, [grad_params], [(None,)] for group_idx in range(num_groups): weight_param = getattr(self, f"weight{group_idx}") if hasattr(weight_param, "grad_added_to_main_grad"): @@ -698,5 +813,14 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - grad_params = grad_weights + grad_biases if has_bias else grad_weights + if self.single_grouped_parameter: + grad_weight = None + # TODO:ksivaman change workflow to avoid stack. + if ctx.weight_requires_grad: + grad_weight = torch.stack(grad_weights, dim=0) + # Parameter registration order with single_grouped_parameter=True is: + # bias0..biasN-1, then weight. Return grads in the same order. 
+ grad_params = grad_biases + [grad_weight] if has_bias else [grad_weight] + else: + grad_params = grad_weights + grad_biases if has_bias else grad_weights return grad_input, [grad_params], [(None,)] diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 19608894e0..19a090f121 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -28,3 +28,12 @@ register_backward_fusion(BackwardLinearScale.fuse_backward_ops) register_backward_fusion(BackwardActivationBias.fuse_backward_ops) register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) + +# Import experimental fusions +# Note: Registration logic is non-trivial, so submodule handles it internally. +from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position + ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, +) +from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position + BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, +) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py new file mode 100644 index 0000000000..04ebd6b819 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -0,0 +1,741 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fused operation for MoE grouped MLP.""" + +from __future__ import annotations +from collections.abc import Callable +import os +import functools +import math +from pickle import TRUE +from typing import Optional + +import torch + +import transformer_engine_torch as tex +from cuda.bindings import driver as cuda +from ...cpp_extensions import ( + general_grouped_gemm_for_grouped_tensor, +) +from ...module._common import noop_cat +from ...module.base import get_dummy_wgrad +from ...quantization import Recipe +from ...tensor import Quantizer +from ...tensor.grouped_tensor import GroupedTensor +from ...utils import clear_tensor_data, get_device_compute_capability +from ..basic import GroupedLinear, ScaledSwiGLU +from ..fuser import register_backward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + is_quantized_tensor, + make_grouped_tensor_from_buffers, + maybe_dequantize, +) + +global_alpha_tensor = None + + +class BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(FusedOperation): + """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel from cuDNN front-end. 
+ + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_dswiglu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, SwiGLU backward, and scale grad.""" + from cudnn import grouped_gemm_dswiglu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_dswiglu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if get_device_compute_capability() < (10, 0): + # Kernel requires SM100+ + return False + try: + # Make sure kernel is available + cls.grouped_gemm_dswiglu_kernel() + except ImportError: + return False + return True + + def __init__( + self, + *, + fc1: GroupedLinear, + swiglu: ScaledSwiGLU, + fc2: GroupedLinear, + ) -> None: + super().__init__((fc1, swiglu, fc2)) + self._mxfp8_alpha_tensor: Optional[torch.Tensor] = None + self._mxfp8_norm_const_tensor: Optional[torch.Tensor] = None + + # Check for unsupported configurations + if not self.is_supported(): + self.grouped_gemm_dswiglu_kernel() # Try triggering import error + raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") + if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." + ) + if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups: + raise ValueError( + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." 
+ ) + if fc1.has_bias or fc2.has_bias: + raise ValueError("Fused kernel does not support bias.") + if swiglu.glu_interleave_size != 32: + raise ValueError( + "Fused kernel requires 32-wide GLU interleaving, " + f"but got glu_interleave_size={swiglu.glu_interleave_size}." + ) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, # pylint: disable=unused-argument + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + fc1_op, _, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + + # Tensor properties + out_shape = list(grad_output.size()) + assert len(out_shape) == 2, f"Expected 2D grad output tensor, got shape={out_shape}." + fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) + num_groups = fc1_op.num_groups + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_parameter else fc1_op.weight0 + device = fc1_weight_param.device + dtype = fc1_ctx.dtype + + # Saved tensors from FC1 forward + saved_tensors = fc1_ctx.saved_tensors + split_sizes, split_points, saved_tensors = ( + saved_tensors[0], + saved_tensors[1], + saved_tensors[2:], + ) + + if fc1_op.single_grouped_parameter: + grouped_fc1_weight, saved_tensors = saved_tensors[0], saved_tensors[1:] + else: + grouped_fc1_weight, saved_tensors = ( + saved_tensors[:num_groups], + saved_tensors[num_groups:], + ) + + ( + fc1_x_data, + fc1_x_col_data, + fc1_x_scale, + fc1_x_col_scale, + fc1_x_tensor_offsets, + ), saved_tensors = ( + saved_tensors[:5], + saved_tensors[5:], + ) + + # Saved tensors from scaled SwiGLU forward + swiglu_in, scales = swiglu_ctx.saved_tensors + + # Saved tensors from FC2 forward + saved_tensors = fc2_ctx.saved_tensors + _, saved_tensors = saved_tensors[0], saved_tensors[1:] # Assume same split sizes as FC1 + if fc2_op.single_grouped_parameter: + grouped_fc2_weight, 
saved_tensors = saved_tensors[0], saved_tensors[1:] + else: + grouped_fc2_weight, saved_tensors = ( + saved_tensors[:num_groups], + saved_tensors[num_groups:], + ) + + ( + fc2_x_data, + fc2_x_col_data, + fc2_x_scale, + fc2_x_col_scale, + fc2_x_tensor_offsets, + ), saved_tensors = ( + saved_tensors[:5], + saved_tensors[5:], + ) + + # Group splits + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") + split_sizes = split_sizes.to(dtype=torch.int64, device=device) + split_points = split_points.to(dtype=torch.int, device=device) + + grouped_fc1_x = None + if fc1_ctx.weight_requires_grad: + grouped_fc1_x = make_grouped_tensor_from_buffers( + num_groups=num_groups, + data=fc1_x_data, + columnwise_data=fc1_x_col_data, + scale_inv=fc1_x_scale, + columnwise_scale_inv=fc1_x_col_scale, + split_sizes=split_sizes, + logical_last_dim=fc1_weight_shape[1], + dtype=dtype, + quantizer=fc1_ctx.input_quantizers[0], + with_gemm_swizzled_scales=True, + tensor_offsets=fc1_x_tensor_offsets, + ) + + grouped_fc2_x = None + if fc2_ctx.weight_requires_grad: + grouped_fc2_x = make_grouped_tensor_from_buffers( + num_groups=num_groups, + data=fc2_x_data, + columnwise_data=fc2_x_col_data, + scale_inv=fc2_x_scale, + columnwise_scale_inv=fc2_x_col_scale, + split_sizes=split_sizes, + logical_last_dim=fc2_weight_shape[1], + dtype=dtype, + quantizer=fc2_ctx.input_quantizers[0], + with_gemm_swizzled_scales=True, + tensor_offsets=fc2_x_tensor_offsets, + ) + + # Split grad output tensor and convert dtypes if needed + fc2_dy = maybe_dequantize(grad_output, dtype) + for quantizer in fc2_ctx.grad_output_quantizers: + quantizer.set_usage(rowwise=True, columnwise=fc2_ctx.weight_requires_grad) + quantizer.optimize_for_gemm = True + grouped_fc2_dy = tex.group_quantize( + fc2_dy, fc2_ctx.grad_output_quantizers[0], num_groups, split_sizes + ) + + # Pack data tensors + # Note: Fused kernel expects tensor with non-contiguous + # 
logical dims. + # Data actual shape: (1, sum(m), k) + # Scale actual shape: (1, sum(m)/128, k/128, 32 (block row), + # 4 (block row), 4 (block col)) + # Data logical shape: (sum(m), k, 1) + # Scale logical shape: (32 (block row), 4 (block row), + # sum(m)/128, 4 (block col), k/128, 1) + fc2_dy_data = grouped_fc2_dy.rowwise_data.view(out_shape[0], out_shape[1]) + fc2_dy_data = fc2_dy_data.view(dtype=torch.float8_e4m3fn) + fc2_dy_data = fc2_dy_data.unsqueeze(0).permute(1, 2, 0) + fc2_dy_scales = grouped_fc2_dy.scale_inv + fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu) + fc2_dy_scales = fc2_dy_scales.view( + 1, + out_shape[0] // 128, + out_shape[1] // 128, + 32, + 4, + 4, + ) + fc2_dy_scales = fc2_dy_scales.permute(3, 4, 1, 5, 2, 0) + + # Pack weight tensors + # Note: Fused kernel expects tensor with non-contiguous + # logical dims. + # Data actual shape: (num_groups, k, n) + # Scale actual shape: (num_groups, n/128, k/128, 32 (block col), + # 4 (block col), 4 (block row)) + # Data logical shape: (n, k, num_groups) + # Scale logical shape: (32 (block col), 4 (block col), n/128, + # 4 (block row), k/128, num_groups) + fc2_w_data = ( + grouped_fc2_weight.columnwise_data + if fc2_op.single_grouped_parameter + else noop_cat([w._columnwise_data for w in grouped_fc2_weight]) + ) + fc2_w_data = fc2_w_data.view(dtype=torch.float8_e4m3fn) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1]) + fc2_w_data = fc2_w_data.permute(2, 1, 0) + fc2_w_scales = ( + grouped_fc2_weight.columnwise_scale_inv + if fc2_op.single_grouped_parameter + else noop_cat([w._columnwise_scale_inv for w in grouped_fc2_weight]) + ) + fc2_w_scales = fc2_w_scales.view(dtype=torch.float8_e8m0fnu) + fc2_w_scales = fc2_w_scales.view( + num_groups, fc2_weight_shape[0] // 128, 4, fc2_weight_shape[1] // 128, 4, 32 + ) # Unswizzled layout + fc2_w_scales = fc2_w_scales.permute( + 0, 3, 1, 5, 4, 2 + ).contiguous() # Convert to swizzled layout + fc2_w_scales = 
fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + + # Kernel scaling factors + alpha_tensor, norm_const_tensor = self._get_kernel_constants( + num_groups=num_groups, dtype=dtype, device=device + ) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Fused kernel for FC2 dgrad + dSwiGLU + grad scale + fc2_dgrad_kernel_out = self.grouped_gemm_dswiglu_kernel()( + fc2_dy_data, + fc2_w_data, + swiglu_in.unsqueeze(0).permute(1, 2, 0), + fc2_dy_scales, + fc2_w_scales, + split_points, + alpha_tensor, # alpha_tensor + alpha_tensor, # beta_tensor + scales.detach().to(dtype=torch.float32).reshape(-1, 1, 1), + norm_const_tensor=norm_const_tensor, + d_dtype=torch.float8_e4m3fn, + cd_major="n", + sf_vec_size=32, + current_stream=current_stream, + discrete_col_sfd=True, + ) + + # Unpack kernel outputs + # Note: Fused kernel outputs tensors with non-contiguous + # logical dims. + # Row-wise data logical shape: (sum(m), k, 1) + # Row-wise scale logical shape: (32 (block row), 4 (block row), + # sum(m)/128, 4 (block col), k/128, 1) + # Column-wise data logical shape: (k, sum(m), 1) + # Column-wise scale logical shape: (32 (block col), 4 (block col), + # k/128, 4 (block row), sum(m)/128, 1) + fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] + fc1_dy_row_data = fc1_dy_row_data.permute(2, 0, 1) + fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]).contiguous() + fc1_dy_row_scale = fc2_dgrad_kernel_out["sfd_row_tensor"] + fc1_dy_row_scale = fc1_dy_row_scale.permute(5, 2, 4, 0, 1, 3) + fc1_dy_row_scale = fc1_dy_row_scale.view( + out_shape[0], fc1_weight_shape[0] // 32 + ).contiguous() + fc1_dy_col_data = fc2_dgrad_kernel_out["d_col_tensor"] + fc1_dy_col_data = fc1_dy_col_data.permute(2, 0, 1) + fc1_dy_col_data = fc1_dy_col_data.view(out_shape[0], fc1_weight_shape[0]).contiguous() + fc1_dy_col_scale = fc2_dgrad_kernel_out["sfd_col_tensor"] + fc1_dy_col_scale = fc1_dy_col_scale.permute(5, 2, 4, 0, 1, 3) + fc1_dy_col_scale = 
fc1_dy_col_scale.reshape(-1) + + grad_scales = fc2_dgrad_kernel_out["dprob_tensor"] + grad_scales = grad_scales.view(-1).to(dtype=dtype) + + # FC1 grad output for dgrad and wgrad GEMMs + grouped_fc1_dy = make_grouped_tensor_from_buffers( + num_groups=num_groups, + data=fc1_dy_row_data, + columnwise_data=fc1_dy_col_data, + scale_inv=fc1_dy_row_scale, + columnwise_scale_inv=fc1_dy_col_scale, + split_sizes=split_sizes, + logical_last_dim=fc1_weight_shape[0], + dtype=dtype, + quantizer=fc1_ctx.grad_output_quantizers[0], + with_gemm_swizzled_scales=True, + ) + + # FC2 wgrad GEMM + fc2_packed_wgrad = None + fc2_weight_grads: list[Optional[torch.Tensor]] + if fc2_op.single_grouped_parameter: + fc2_weight_grads = [None] + else: + fc2_weight_grads = [None] * num_groups + if fc2_ctx.weight_requires_grad: + + # Initialize grad buffers + accumulate_into_main_grad = False + if fc2_op.single_grouped_parameter: + grouped_fc2_wgrad = None + weight_param = fc2_op.weight + if fc2_op._accumulate_into_main_grad: + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can + # accumulate directly into it. + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + main_grad = weight_param.main_grad + grouped_shape = (num_groups, *fc2_weight_shape) + if main_grad.shape != grouped_shape: + if main_grad.numel() != math.prod(grouped_shape): + raise RuntimeError( + "Grouped MLP fused backward expected FC2 main_grad to have " + f"shape {grouped_shape} or matching numel, " + f"but got shape {tuple(main_grad.shape)}" + ) + # Keep aliasing with weight.main_grad; do not allow implicit copies. 
+ try: + main_grad = main_grad.view(grouped_shape) + except RuntimeError as e: + raise RuntimeError( + "Grouped MLP fused backward requires FC2 main_grad to be viewable" + f" as {grouped_shape} without copy, but got shape" + f" {tuple(main_grad.shape)} and stride {tuple(main_grad.stride())}" + ) from e + accumulate_into_main_grad = not getattr( + weight_param, "overwrite_main_grad", False + ) + if accumulate_into_main_grad: + grouped_fc2_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data( + num_tensors=num_groups, + tensor_shape=fc2_weight_shape, + rowwise_data=main_grad, + dtype=main_grad.dtype, + ) + + if grouped_fc2_wgrad is None: + # TODO:ksivaman: This is not CUDA Graph safe. + grouped_fc2_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_groups, + shapes=[fc2_weight_shape] * num_groups, + quantizer=None, + device=device, + dtype=dtype, + ) + + # Launch GEMM + # A=grouped_input, B=grouped_fc2_dy; B's scales are GEMM-swizzled (see group_quantize above). 
+ general_grouped_gemm_for_grouped_tensor( + grouped_fc2_x, + grouped_fc2_dy, + grouped_fc2_wgrad, + layout="NT", + accumulate=accumulate_into_main_grad, + ) + fc2_packed_wgrad = grouped_fc2_wgrad.rowwise_data.view( + num_groups, *fc2_weight_shape + ) + if accumulate_into_main_grad and hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + fc2_packed_wgrad = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + else: + if fc2_op._accumulate_into_main_grad: + for idx in range(num_groups): + weight_param = getattr(fc2_op, f"weight{idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + fc2_weight_grads[idx] = weight_param.main_grad + accumulate_into_main_grad = not getattr( + fc2_op.weight0, "overwrite_main_grad", False + ) + else: + for idx in range(num_groups): + fc2_weight_grads[idx] = torch.empty( + fc2_weight_shape, dtype=dtype, device=device + ) + + general_grouped_gemm_for_grouped_tensor( + grouped_fc2_x, + grouped_fc2_dy, + fc2_weight_grads, + layout="NT", + accumulate=accumulate_into_main_grad, + ) + if accumulate_into_main_grad: + for idx in range(num_groups): + weight_param = getattr(fc2_op, f"weight{idx}") + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + fc2_weight_grads[idx] = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + # Clear FC2 input tensor if possible + if grouped_fc2_x is not None: + clear_tensor_data( + grouped_fc2_x.data, + grouped_fc2_x.columnwise_data, + grouped_fc2_x.scale_inv, + grouped_fc2_x.columnwise_scale_inv, + ) + + # FC1 dgrad GEMM + grad_input = None + if fc1_ctx.input_requires_grad: + # Launch GEMM + in_shape = out_shape[:-1] + [fc1_weight_shape[1]] + grad_input = torch.empty(in_shape, dtype=dtype, device=device) + 
grouped_grad_input = make_grouped_tensor_from_buffers( + num_groups=num_groups, + data=grad_input, + split_sizes=split_sizes, + dtype=grad_input.dtype, + logical_last_dim=fc1_weight_shape[1], + ) + + general_grouped_gemm_for_grouped_tensor( + grouped_fc1_weight, + grouped_fc1_dy, + grouped_grad_input, + layout="NN", + accumulate=False, + ) + + # FC1 wgrad GEMM + fc1_packed_wgrad = None + fc1_weight_grads: list[Optional[torch.Tensor]] + if fc1_op.single_grouped_parameter: + fc1_weight_grads = [None] + else: + fc1_weight_grads = [None] * num_groups + if fc1_ctx.weight_requires_grad: + + # Initialize grad buffers + accumulate_into_main_grad = False + if fc1_op.single_grouped_parameter: + grouped_fc1_wgrad = None + weight_param = fc1_op.weight + if fc1_op._accumulate_into_main_grad: + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can + # accumulate directly into it. + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + main_grad = weight_param.main_grad + grouped_shape = (num_groups, *fc1_weight_shape) + if main_grad.shape != grouped_shape: + if main_grad.numel() != math.prod(grouped_shape): + raise RuntimeError( + "Grouped MLP fused backward expected FC1 main_grad to have " + f"shape {grouped_shape} or matching numel, " + f"but got shape {tuple(main_grad.shape)}" + ) + # Keep aliasing with weight.main_grad; do not allow implicit copies. 
+ try: + main_grad = main_grad.view(grouped_shape) + except RuntimeError as e: + raise RuntimeError( + "Grouped MLP fused backward requires FC1 main_grad to be viewable" + f" as {grouped_shape} without copy, but got shape" + f" {tuple(main_grad.shape)} and stride {tuple(main_grad.stride())}" + ) from e + accumulate_into_main_grad = not getattr( + weight_param, "overwrite_main_grad", False + ) + if accumulate_into_main_grad: + grouped_fc1_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data( + num_tensors=num_groups, + tensor_shape=fc1_weight_shape, + rowwise_data=main_grad, + dtype=main_grad.dtype, + ) + + if grouped_fc1_wgrad is None: + # TODO:ksivaman: This is not CUDA Graph safe. + grouped_fc1_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_groups, + shapes=[fc1_weight_shape] * num_groups, + quantizer=None, + device=device, + dtype=dtype, + ) + + # Launch GEMM + general_grouped_gemm_for_grouped_tensor( + grouped_fc1_x, + grouped_fc1_dy, + grouped_fc1_wgrad, + layout="NT", + accumulate=accumulate_into_main_grad, + ) + fc1_packed_wgrad = grouped_fc1_wgrad.rowwise_data.view( + num_groups, *fc1_weight_shape + ) + if accumulate_into_main_grad and hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + fc1_packed_wgrad = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + else: + if fc1_op._accumulate_into_main_grad: + for idx in range(num_groups): + weight_param = getattr(fc1_op, f"weight{idx}") + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + fc1_weight_grads[idx] = weight_param.main_grad + accumulate_into_main_grad = not getattr( + fc1_op.weight0, "overwrite_main_grad", False + ) + else: + for idx in range(num_groups): + fc1_weight_grads[idx] = torch.empty( + fc1_weight_shape, dtype=dtype, device=device + ) + + general_grouped_gemm_for_grouped_tensor( + 
grouped_fc1_x, + grouped_fc1_dy, + fc1_weight_grads, + layout="NT", + accumulate=accumulate_into_main_grad, + ) + if accumulate_into_main_grad: + for idx in range(num_groups): + weight_param = getattr(fc1_op, f"weight{idx}") + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + fc1_weight_grads[idx] = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + # Clear FC1 input tensor if possible + if grouped_fc1_x is not None: + clear_tensor_data( + grouped_fc1_x.data, + grouped_fc1_x.columnwise_data, + grouped_fc1_x.scale_inv, + grouped_fc1_x.columnwise_scale_inv, + ) + + # Construct param grads in parameter registration order. + if fc1_op.single_grouped_parameter: + fc1_weight_grads = [fc1_packed_wgrad] if fc1_packed_wgrad is not None else [None] + if fc2_op.single_grouped_parameter: + fc2_weight_grads = [fc2_packed_wgrad] if fc2_packed_wgrad is not None else [None] + + return ( + grad_input, + [fc1_weight_grads, (), fc2_weight_grads], + [(None,), (grad_scales,), (None,)], + ) + + def _get_kernel_constants( + self, + *, + num_groups: int, + dtype: torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + global global_alpha_tensor + alpha_tensor = self._mxfp8_alpha_tensor + norm_const_tensor = self._mxfp8_norm_const_tensor + if ( + alpha_tensor is None + or alpha_tensor.numel() != num_groups + or alpha_tensor.dtype != dtype + or alpha_tensor.device != device + ): + if global_alpha_tensor is None: + global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device) + alpha_tensor = global_alpha_tensor + norm_const_tensor = alpha_tensor[:1] + self._mxfp8_alpha_tensor = alpha_tensor + self._mxfp8_norm_const_tensor = norm_const_tensor + + return alpha_tensor, norm_const_tensor + + +def fuse_backward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument 
+) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + recipe : Recipe, optional + Quantization recipe. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Return immediately if fused kernel is not supported + if not BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + return ops + + # Check if recipe is supported + if recipe is None: + return ops + if not recipe.mxfp8(): + return ops + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + + # Check if window matches pattern + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[1], ScaledSwiGLU) + and isinstance(window[2], GroupedLinear) + ): + matches_pattern = False + elif window[0].has_bias or window[2].has_bias: + matches_pattern = False + elif window[0].num_groups != window[2].num_groups: + matches_pattern = False + elif ( + window[0].in_features % 256 != 0 + or window[0].out_features % 256 != 0 + or window[2].in_features % 256 != 0 + or window[2].out_features % 256 != 0 + ): + matches_pattern = False + elif window[1].glu_interleave_size != 32: + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8( + fc1=window[0], + swiglu=window[1], + fc2=window[2], + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops + out.extend(window) + return out + + +# Register fusion if available +if BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported(): + register_backward_fusion(fuse_backward_ops, prepend=True) diff 
--git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py new file mode 100644 index 0000000000..7b13e70d4e --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -0,0 +1,596 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for MoE grouped MLP.""" + +from __future__ import annotations +from collections.abc import Callable, Iterable +import os +import functools +from typing import Any, Optional + +import torch + +import transformer_engine_torch as tex +from cuda.bindings import driver as cuda +from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor +from ...module._common import noop_cat +from ...quantization import Recipe +from ...tensor import Quantizer +from ...utils import get_device_compute_capability +from ...tensor.grouped_tensor import GroupedTensor +from ...constants import MXFP8_BLOCK_SCALING_SIZE +from ..basic import GroupedLinear, ScaledSwiGLU +from ..fuser import register_forward_fusion +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + is_quantized_tensor, + make_grouped_tensor_from_buffers, + maybe_dequantize, +) + +global_alpha_tensor = None + + +class ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8(FusedOperation): + """Fused op for MXFP8 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel from cuDNN front-end. 
+ + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_swiglu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, SwiGLU, and post-multiplication.""" + from cudnn import grouped_gemm_swiglu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_swiglu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if get_device_compute_capability() < (10, 0): + # Kernel requires SM100+ + return False + try: + # Make sure kernel is available + cls.grouped_gemm_swiglu_kernel() + except ImportError: + return False + return True + + def __init__( + self, + *, + fc1: GroupedLinear, + swiglu: ScaledSwiGLU, + fc2: GroupedLinear, + ) -> None: + super().__init__((fc1, swiglu, fc2)) + self._mxfp8_alpha_tensor: Optional[torch.Tensor] = None + self._mxfp8_norm_const_tensor: Optional[torch.Tensor] = None + # Check for unsupported configurations + if not self.is_supported(): + self.grouped_gemm_swiglu_kernel() # Try triggering import error + raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.") + if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, " + f"in_features={fc1.in_features}, out_features={fc1.out_features})." + ) + if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0: + raise ValueError( + f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, " + f"in_features={fc2.in_features}, out_features={fc2.out_features})." + ) + if fc1.out_features != 2 * fc2.in_features or fc1.num_groups != fc2.num_groups: + raise ValueError( + f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, " + f"out_features={fc1.out_features}) " + f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, " + f"out_features={fc2.out_features}) do not match." 
+ ) + if fc1.has_bias or fc2.has_bias: + raise ValueError("Fused kernel does not support bias.") + if swiglu.glu_interleave_size != 32: + raise ValueError( + "Fused kernel requires 32-wide GLU interleaving, " + f"but got glu_interleave_size={swiglu.glu_interleave_size}." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + # Get basic operations + fc1_op, _, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + + # Tensor properties + in_shape = list(input_.size()) + assert len(in_shape) == 2, f"Expected 2D input tensor, got shape={in_shape}." + fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) + + num_groups = fc1_op.num_groups + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_parameter else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_parameter else fc2_op.weight0 + device = fc1_weight_param.device + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = fc1_weight_param.dtype + + # Check which grads are required + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + weight_requires_grad = requires_grad and ( + fc1_weight_param.requires_grad or fc2_weight_param.requires_grad + ) + + # Quantizers + fc1_input_quantizers = [None] * num_groups + fc1_weight_quantizer = fc1_op.get_quantizer("forward", 1) + fc1_grad_output_quantizers = [None] * num_groups + fc2_input_quantizers = [None] * num_groups + fc2_weight_quantizer = fc2_op.get_quantizer("forward", 1) + fc2_grad_output_quantizers = [None] * num_groups + for idx in range(num_groups): + 
fc1_input_quantizers[idx] = fc1_op.get_quantizer("forward", 2 * idx) + fc1_grad_output_quantizers[idx] = fc1_op.get_quantizer("backward", idx) + fc2_input_quantizers[idx] = fc2_op.get_quantizer("forward", 2 * idx) + fc2_grad_output_quantizers[idx] = fc2_op.get_quantizer("backward", idx) + + # Extract split sizes from extra input + fc1_split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + fc1_split_sizes.size() != fc2_split_sizes.size() + or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError( + f"{self.__class__.__name__} got different split points for FC1 and FC2." + ) + split_sizes = fc1_split_sizes + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") + split_sizes = split_sizes.to(dtype=torch.int64, device=device) + split_points = torch.cumsum(split_sizes, 0, dtype=torch.int) + fc1_x_tensor_offsets = GroupedTensor.make_tensor_offsets(split_sizes, fc1_weight_shape[1]) + fc2_x_tensor_offsets = GroupedTensor.make_tensor_offsets(split_sizes, fc2_weight_shape[1]) + + # Extract post-scales from extra input + scales = basic_op_extra_inputs[1][0] + + # Prepare FC1 grouped weight tensor for fused kernels. + # Support both: + # - single_grouped_parameter=True: op.weight is already a GroupedTensor + # - single_grouped_parameter=False: pack per-group weights into a GroupedTensor + if fc1_op.single_grouped_parameter: + if not isinstance(fc1_op.weight, GroupedTensor): + raise RuntimeError( + "FC1 expected GroupedTensor weight with single_grouped_parameter=True." 
+ ) + if fc1_op.weight.quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc1_op.weight.quantizer = fc1_weight_quantizer + grouped_fc1_weight = fc1_op.weight + else: + if fc1_op.weight.rowwise_data is None: + raise RuntimeError("FC1 grouped weight has no rowwise_data to quantize.") + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc1_weight = tex.group_quantize( + fc1_op.weight.rowwise_data.view(fc1_op.weight.logical_shape), + fc1_weight_quantizer, + num_groups, + None, + ) + else: + fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc1_weights = [] + for idx, weight in enumerate(fc1_weights): + quantizer = fc1_op.get_quantizer("forward", 2 * idx + 1) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc1_weights.append(quantizer(weight)) + else: + quantized_fc1_weights.append(weight) + grouped_fc1_weight = quantized_fc1_weights + + # Prepare FC2 grouped weight tensor for fused kernels. + if fc2_op.single_grouped_parameter: + if not isinstance(fc2_op.weight, GroupedTensor): + raise RuntimeError( + "FC2 expected GroupedTensor weight with single_grouped_parameter=True." 
+ ) + if fc2_op.weight.quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc2_op.weight.quantizer = fc2_weight_quantizer + grouped_fc2_weight = fc2_op.weight + else: + if fc2_op.weight.rowwise_data is None: + raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc2_weight = tex.group_quantize( + fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), + fc2_weight_quantizer, + num_groups, + None, + ) + else: + fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc2_weights = [] + for idx, weight in enumerate(fc2_weights): + quantizer = fc2_op.get_quantizer("forward", 2 * idx + 1) + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc2_weights.append(quantizer(weight)) + else: + quantized_fc2_weights.append(weight) + grouped_fc2_weight = quantized_fc2_weights + + # Some wrapper-copy paths may drop grouped storage metadata; enforce defaults. 
+ if getattr(grouped_fc1_weight, "with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc1_weight, GroupedTensor + ): + grouped_fc1_weight.with_gemm_swizzled_scales = False + if getattr(grouped_fc2_weight, "with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc2_weight, GroupedTensor + ): + grouped_fc2_weight.with_gemm_swizzled_scales = False + + # Group-quantize input tensor and convert dtypes if needed + fc1_x = maybe_dequantize(input_, dtype) + for quantizer in fc1_input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + quantizer.optimize_for_gemm = True + grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizers[0], num_groups, split_sizes) + + # Pack data tensors + # Note: Fused kernel expects tensor with non-contiguous + # logical dims. + # Data actual shape: (1, sum(m), k) + # Scale actual shape: (1, sum(m)/128, k/128, 32 (block row), + # 4 (block row), 4 (block col)) + # Data logical shape: (sum(m), k, 1) + # Scale logical shape: (32 (block row), 4 (block row), + # sum(m)/128, 4 (block col), k/128, 1) + fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1]) + fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn) + fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) + fc1_x_scales = grouped_fc1_x.scale_inv + fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu) + fc1_x_scales = fc1_x_scales.view( + 1, + in_shape[0] // 128, + in_shape[1] // 128, + 32, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + + # Pack weight tensors + # Note: Fused kernel expects tensor with non-contiguous + # logical dims. 
+ # Data actual shape: (num_groups, n, k) + # Scale actual shape: (num_groups, n/128, k/128, 32 (block row), + # 4 (block row), 4 (block col)) + # Data logical shape: (n, k, num_groups) + # Scale logical shape: (32 (block row), 4 (block row), n/128, + # 4 (block col), k/128, num_groups) + fc1_w_data = ( + grouped_fc1_weight.rowwise_data + if fc1_op.single_grouped_parameter + else noop_cat([w._rowwise_data for w in grouped_fc1_weight]) + ) + fc1_w_data = fc1_w_data.view(dtype=torch.float8_e4m3fn) + fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1]) + fc1_w_data = fc1_w_data.permute(1, 2, 0) + fc1_w_scales = ( + grouped_fc1_weight.scale_inv + if fc1_op.single_grouped_parameter + else noop_cat([w._rowwise_scale_inv for w in grouped_fc1_weight]) + ) + fc1_w_scales = fc1_w_scales.view(dtype=torch.float8_e8m0fnu) + fc1_w_scales = fc1_w_scales.view( + num_groups, fc1_weight_shape[0] // 128, 4, 32, fc1_weight_shape[1] // 128, 4 + ) # Unswizzled layout + fc1_w_scales = fc1_w_scales.permute( + 0, 1, 4, 3, 2, 5 + ).contiguous() # Convert to swizzled layout + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + # Kernel scaling factors + alpha_tensor, norm_const_tensor = self._get_kernel_constants( + num_groups=num_groups, dtype=dtype, device=device + ) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Fused kernel for FC1 + SwiGLU + post-scale + fc1_kernel_out = self.grouped_gemm_swiglu_kernel()( + fc1_x_data, + fc1_w_data, + fc1_x_scales, + fc1_w_scales, + split_points, + alpha_tensor, # alpha_tensor + norm_const_tensor=norm_const_tensor, + prob_tensor=scales.detach().reshape(-1, 1, 1), + acc_dtype=torch.float32, + c_dtype=torch.bfloat16, + d_dtype=torch.float8_e4m3fn, + cd_major="n", + sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, + current_stream=current_stream, + discrete_col_sfd=True, + ) + + # Unpack kernel outputs + # Note: Fused kernel outputs tensors with non-contiguous + # logical dims. 
+ # Row-wise data logical shape: (sum(m_splits), k, 1) + # Row-wise scale logical shape: (32 (block row), 4 (block row), + # sum(m_splits)/128, 4 (block col), k/128, 1) + # Column-wise data logical shape: (sum(m_splits), k, 1) + # Column-wise scale logical shape: (32 (block col), 4 (block col), + # k/128, 4 (block row), sum(m_splits)/128, 1) + swiglu_in = fc1_kernel_out["c_tensor"] + swiglu_in = swiglu_in.permute(2, 0, 1) + swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0]) + fc2_in_row_data = fc1_kernel_out["d_tensor"] + fc2_in_row_data = fc2_in_row_data.permute(2, 0, 1) + fc2_in_row_data = fc2_in_row_data.view(in_shape[0], fc2_weight_shape[1]).contiguous() + fc2_in_row_scale = fc1_kernel_out["sfd_row_tensor"] + fc2_in_row_scale = fc2_in_row_scale.permute(5, 2, 4, 0, 1, 3) + + fc2_in_col_data = fc1_kernel_out["d_col_tensor"] + fc2_in_col_data = fc2_in_col_data.permute(2, 0, 1) + fc2_in_col_data = fc2_in_col_data.view(in_shape[0], fc2_weight_shape[1]).contiguous() + fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] + fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) + # Repack columnwise scales on GPU to preserve group ordering. 
+ + # FC2 inputs scales are already swizzled/optimized for GEMM + grouped_fc2_x = make_grouped_tensor_from_buffers( + num_groups=num_groups, + data=fc2_in_row_data.reshape(-1), + columnwise_data=fc2_in_col_data.reshape(-1), + scale_inv=fc2_in_row_scale.reshape(-1), + columnwise_scale_inv=fc2_in_col_scale.reshape(-1), + split_sizes=split_sizes, + logical_last_dim=fc2_weight_shape[1], + dtype=dtype, + quantizer=fc2_input_quantizers[0], + with_gemm_swizzled_scales=True, + tensor_offsets=fc2_x_tensor_offsets, + ) + + # FC2 GEMM + fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]] + fc2_out = torch.empty(fc2_out_shape, dtype=dtype, device=device) + grouped_fc2_out = make_grouped_tensor_from_buffers( + num_groups=num_groups, + data=fc2_out, + split_sizes=split_sizes, + dtype=fc2_out.dtype, + logical_last_dim=fc2_weight_shape[0], + ) + + general_grouped_gemm_for_grouped_tensor( + grouped_fc2_weight, + grouped_fc2_x, + grouped_fc2_out, + layout="TN", + accumulate=False, + ) + + # Prepare input tensors for backward pass + if not weight_requires_grad: + grouped_fc1_x = None + grouped_fc2_x = None + + # Save state for backward pass + if requires_grad: + if grouped_fc1_x is not None: + grouped_fc1_x.columnwise_data.grouped_name = "fc1_columnwise_data" + grouped_fc1_x.columnwise_data.logical_shape = grouped_fc1_x.logical_shape + grouped_fc1_x.columnwise_scale_inv.grouped_name = "fc1_columnwise_scale_inv" + grouped_fc1_x.columnwise_scale_inv.logical_shape = grouped_fc1_x.logical_shape + fc1_input_tensors = ( + None, # data + grouped_fc1_x.columnwise_data, # columnwise_data + None, # scale_inv + grouped_fc1_x.columnwise_scale_inv, # columnwise_scale_inv + fc1_x_tensor_offsets, # tensor_offsets + ) + else: + fc1_input_tensors = (None, None, None, None, None) + # FC1 + if fc1_op.single_grouped_parameter: + fc1_ctx.save_for_backward( + split_sizes, split_points, grouped_fc1_weight, *fc1_input_tensors + ) + else: + fc1_ctx.save_for_backward( + split_sizes, split_points, 
*grouped_fc1_weight, *fc1_input_tensors + ) + fc1_ctx.with_quantized_compute = True + fc1_ctx.input_quantizers = fc1_input_quantizers + fc1_ctx.weight_quantizer = fc1_weight_quantizer + fc1_ctx.grad_output_quantizers = fc1_grad_output_quantizers + fc1_ctx.grad_input_quantizers = None + fc1_ctx.dtype = dtype + fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = weight_requires_grad + + # Scaled SwiGLU + swiglu_in.grouped_name = "swiglu_in" + scales.grouped_name = "scales" + swiglu_ctx.save_for_backward(swiglu_in, scales) + swiglu_ctx.input_requires_grad = True + swiglu_ctx.extra_input_requires_grad = True + swiglu_ctx.dtype = dtype + + # FC2 state + if grouped_fc2_x is not None: + grouped_fc2_x.columnwise_data.grouped_name = "fc2_columnwise_data" + grouped_fc2_x.columnwise_data.logical_shape = grouped_fc2_x.logical_shape + grouped_fc2_x.columnwise_scale_inv.grouped_name = "fc2_columnwise_scale_inv" + grouped_fc2_x.columnwise_scale_inv.logical_shape = grouped_fc2_x.logical_shape + fc2_input_tensors = ( + None, # data + grouped_fc2_x.columnwise_data, # columnwise_data + None, # scale_inv + grouped_fc2_x.columnwise_scale_inv, # columnwise_scale_inv + fc2_x_tensor_offsets, # tensor_offsets + ) + else: + fc2_input_tensors = (None, None, None, None, None) + + if fc2_op.single_grouped_parameter: + fc2_ctx.save_for_backward(split_sizes, grouped_fc2_weight, *fc2_input_tensors) + else: + fc2_ctx.save_for_backward(split_sizes, *grouped_fc2_weight, *fc2_input_tensors) + + fc2_ctx.with_quantized_compute = True + fc2_ctx.input_quantizers = fc2_input_quantizers + fc2_ctx.weight_quantizer = fc2_weight_quantizer + fc2_ctx.grad_output_quantizers = fc2_grad_output_quantizers + fc2_ctx.grad_input_quantizers = None + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = weight_requires_grad + + return fc2_out, [(), (), ()] + + def _get_kernel_constants( + self, + *, + num_groups: int, + dtype: 
torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + global global_alpha_tensor + alpha_tensor = self._mxfp8_alpha_tensor + norm_const_tensor = self._mxfp8_norm_const_tensor + if ( + alpha_tensor is None + or alpha_tensor.numel() != num_groups + or alpha_tensor.dtype != dtype + or alpha_tensor.device != device + ): + if global_alpha_tensor is None: + global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device) + alpha_tensor = global_alpha_tensor + norm_const_tensor = alpha_tensor[:1] + self._mxfp8_alpha_tensor = alpha_tensor + self._mxfp8_norm_const_tensor = norm_const_tensor + elif ( + norm_const_tensor is None + or norm_const_tensor.numel() != 1 + or norm_const_tensor.dtype != dtype + or norm_const_tensor.device != device + ): + norm_const_tensor = alpha_tensor[:1] + self._mxfp8_norm_const_tensor = norm_const_tensor + return alpha_tensor, norm_const_tensor + + +def fuse_forward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + recipe : Recipe, optional + Quantization recipe. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Return immediately if fused kernel is not supported + if not ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + return ops + + # Check if recipe is supported + if recipe is None: + return ops + if not recipe.mxfp8(): + return ops + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + + # Check if window matches pattern + matches_pattern = True + if not ( + isinstance(window[0], GroupedLinear) + and isinstance(window[1], ScaledSwiGLU) + and isinstance(window[2], GroupedLinear) + ): + matches_pattern = False + elif window[0].has_bias or window[2].has_bias: + matches_pattern = False + elif window[0].num_groups != window[2].num_groups: + matches_pattern = False + elif ( + window[0].in_features % 256 != 0 + or window[0].out_features % 256 != 0 + or window[2].in_features % 256 != 0 + or window[2].out_features % 256 != 0 + ): + matches_pattern = False + elif window[1].glu_interleave_size != 32: + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8( + fc1=window[0], + swiglu=window[1], + fc2=window[2], + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops + out.extend(window) + return out + + +# Register fusion if available +if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): + register_forward_fusion(fuse_forward_ops, prepend=True) diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index 2fce9a38e2..56c70c3c02 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ 
b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -89,9 +89,9 @@ def __new__( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, + with_gemm_swizzled_scales: bool = False, requires_grad: bool = False, stride: Optional[List[int]] = None, - with_gemm_swizzled_scales: bool = False, ): if ( shapes is not None @@ -186,6 +186,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.columnwise_scale_inv_offsets = src.columnwise_scale_inv_offsets dst.logical_shape = src.logical_shape dst.quantized_tensors = src.quantized_tensors + dst.with_gemm_swizzled_scales = src.with_gemm_swizzled_scales def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 68097259c6..a1c02934bd 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -69,9 +69,9 @@ def _initialize_storage_fields( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, + with_gemm_swizzled_scales: bool = False, requires_grad: bool = False, stride: Optional[List[int]] = None, - with_gemm_swizzled_scales: bool = False, ) -> None: """ Initialize a GroupedTensor. 
@@ -147,6 +147,8 @@ def _initialize_storage_fields( instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales + instance.with_gemm_swizzled_scales = with_gemm_swizzled_scales + def __new__( cls, shape: Tuple[int, int], @@ -168,9 +170,9 @@ def __new__( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, + with_gemm_swizzled_scales: bool = False, requires_grad: bool = False, stride: Optional[List[int]] = None, - with_gemm_swizzled_scales: bool = False, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -193,12 +195,14 @@ def __new__( offsets=offsets, scale_inv_offsets=scale_inv_offsets, columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, requires_grad=requires_grad, stride=stride, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, ) return instance + self.with_gemm_swizzled_scales = with_gemm_swizzled_scales + def has_data(self) -> bool: """ Check if the tensor has row-wise data. @@ -383,6 +387,78 @@ def make_grouped_tensor_with_shapes( dtype=dtype, ) + @staticmethod + def make_grouped_tensor_from_rowwise_data( + *, + num_tensors: int, + tensor_shape: Tuple[int, int], + rowwise_data: torch.Tensor, + dtype: Optional[torch.dtype] = None, + internal: bool = False, + ) -> GroupedTensorStorage: + """Wrap pre-existing contiguous rowwise data as a grouped tensor. + + This helper does not allocate storage. It creates grouped metadata over + `rowwise_data`, which is expected to contain `num_tensors` matrices of + shape `tensor_shape` in packed contiguous layout. 
+ """ + if num_tensors <= 0: + raise ValueError(f"num_tensors must be positive, got {num_tensors}") + if rowwise_data is None: + raise ValueError("rowwise_data must not be None") + if not rowwise_data.is_contiguous(): + rowwise_data = rowwise_data.contiguous() + + rows, cols = tensor_shape + expected_numel = num_tensors * rows * cols + if rowwise_data.numel() != expected_numel: + raise ValueError( + "Grouped rowwise buffer size mismatch: expected " + f"{expected_numel} elements for {num_tensors}x{tensor_shape}, " + f"but got {rowwise_data.numel()}" + ) + if dtype is None: + dtype = rowwise_data.dtype + + logical_shape = (num_tensors * rows, cols) + grouped_tensor_class = GroupedTensorStorage + if not internal: + from ..grouped_tensor import GroupedTensor + + grouped_tensor_class = GroupedTensor + + return grouped_tensor_class( + shape=logical_shape, + dtype=dtype, + num_tensors=num_tensors, + shapes=[tensor_shape] * num_tensors, + quantizer=None, + data=rowwise_data.view(-1), + columnwise_data=None, + scale_inv=None, + columnwise_scale_inv=None, + amax=None, + columnwise_amax=None, + scale=None, + first_dims=None, + last_dims=None, + tensor_offsets=None, + offsets=None, + scale_inv_offsets=None, + columnwise_scale_inv_offsets=None, + with_gemm_swizzled_scales=False, + requires_grad=False, + ) + + @staticmethod + def make_tensor_offsets(first_dims: torch.Tensor, logical_last_dim: int) -> torch.Tensor: + return torch.cat( + [ + torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), + torch.cumsum(first_dims * logical_last_dim, dim=0), + ] + ) + @staticmethod def make_grouped_tensor( num_tensors: int, @@ -439,16 +515,20 @@ def make_grouped_tensor( # Kernels need to calculate precise pointers based on size of elements. # TODO(ksivaman): Single kernel + remove the host offset calculation. 
- tensor_offsets = torch.cat( - [ - torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), - torch.cumsum(first_dims * logical_last_dim, dim=0), - ] - ) - offsets = tensor_offsets.tolist() - first_dims_list = first_dims.tolist() - for i in range(num_tensors): - shape.append((first_dims_list[i], logical_last_dim)) + tensor_offsets = GroupedTensorStorage.make_tensor_offsets(first_dims, logical_last_dim) + if ( + first_dims.device.type == "cuda" + and torch.cuda.is_available() + and torch.cuda.is_current_stream_capturing() + ): + # Avoid host sync during CUDA graph capture. + offsets = None + shape = None + else: + offsets = tensor_offsets.tolist() + first_dims_list = first_dims.tolist() + for i in range(num_tensors): + shape.append((first_dims_list[i], logical_last_dim)) else: offsets = [ i * logical_first_dim * logical_last_dim // num_tensors @@ -464,7 +544,7 @@ def make_grouped_tensor( rowwise_usage = quantizer.rowwise_usage if not no_quantization else True columnwise_usage = quantizer.columnwise_usage if not no_quantization else False - + with_gemm_swizzled_scales = quantizer.optimize_for_gemm if not no_quantization else False # Calculate total elements across all tensors total_elements = logical_first_dim * logical_last_dim @@ -477,6 +557,11 @@ def make_grouped_tensor( scale = None scale_inv_offsets = None columnwise_scale_inv_offsets = None + if shape is None and not no_quantization: + raise RuntimeError( + "Cannot materialize quantized GroupedTensor with varying first dims " + "during CUDA graph capture." 
+ ) if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -508,7 +593,7 @@ def make_grouped_tensor( total_columnwise_scale_elements = 0 columnwise_scale_inv_offsets = [0] for i, s in enumerate(shape): - scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_inv_shape = quantizer.get_scale_shape(s, True) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) @@ -649,12 +734,11 @@ def make_grouped_tensor( offsets=offsets, scale_inv_offsets=scale_inv_offsets, columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, - with_gemm_swizzled_scales=( - quantizer.optimize_for_gemm if quantizer is not None else False - ), + with_gemm_swizzled_scales=with_gemm_swizzled_scales, ) - grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() + if grouped_tensor.shape is not None: + grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor def split_into_quantized_tensors( @@ -674,7 +758,6 @@ def split_into_quantized_tensors( TODO(ksivaman): Block cases where any dims are varying. This is needed only to expose the weights as separate parameters. 
""" - result = [] no_quantization = self.quantizer is None From a15481e971bd6a6421dddcd5a2cdf711cb57005d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 17 Mar 2026 03:38:12 +0000 Subject: [PATCH 2/8] cleanup/lint Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/ops/_common.py | 64 ------------------- .../pytorch/ops/basic/grouped_linear.py | 2 +- .../pytorch/ops/fused/backward_grouped_mlp.py | 12 ++-- .../pytorch/ops/fused/forward_grouped_mlp.py | 7 +- .../tensor/storage/grouped_tensor_storage.py | 3 +- 5 files changed, 10 insertions(+), 78 deletions(-) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index f955eabd87..8dab6dd464 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -15,7 +15,6 @@ from ..tensor.float8_tensor import Float8Tensor from ..quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype -from ..module._common import noop_cat from ..tensor import Quantizer from ..tensor.grouped_tensor import GroupedTensor @@ -120,66 +119,3 @@ def make_grouped_tensor_from_buffers( columnwise_scale_inv_offsets=None, with_gemm_swizzled_scales=with_gemm_swizzled_scales, ) - - -def make_grouped_tensor_from_mxfp8_weights( - weights: list, - quantizer: Quantizer, - device: torch.device, - dtype: torch.dtype, - with_gemm_swizzled_scales: bool = False, -) -> GroupedTensor: - """Build a GroupedTensor from MXFP8 weight tensors by packing their buffers (no copy when contiguous).""" - num_groups = len(weights) - weight_shape = weights[0].shape - O, I = weight_shape[0], weight_shape[1] - logical_first_dim = num_groups * O - logical_last_dim = I - - tensor_offsets = None - data = None - scale_inv = None - scale_inv_offsets = None - columnwise_data = None - columnwise_scale_inv = None - columnwise_scale_inv_offsets = None - - # Pack rowwise into data/scale_inv when available. 
- # GEMM expects scales in swizzled layout (same as FC1 weight scales in grouped_gemm_swiglu). - if weights[0]._rowwise_data is not None: - data = noop_cat([w._rowwise_data.reshape(-1) for w in weights]) - rowwise_scales = noop_cat([w._rowwise_scale_inv for w in weights]) - if with_gemm_swizzled_scales: - rowwise_scales = rowwise_scales.view(num_groups, O // 128, 4, 32, I // 128, 4) - rowwise_scales = rowwise_scales.permute(0, 1, 4, 3, 2, 5).contiguous() - scale_inv = rowwise_scales.reshape(-1) - # Pack columnwise into columnwise_* when available. - # GEMM expects columnwise scales in swizzled layout (same as FC2 weight scales in backward dSwiGLU kernel). - if weights[0]._columnwise_data is not None: - columnwise_data = noop_cat([w._columnwise_data.reshape(-1) for w in weights]) - columnwise_scales = noop_cat([w._columnwise_scale_inv for w in weights]) - if with_gemm_swizzled_scales: - columnwise_scales = columnwise_scales.view(num_groups, O // 128, 4, I // 128, 4, 32) - columnwise_scales = columnwise_scales.permute(0, 3, 1, 5, 4, 2).contiguous() - columnwise_scale_inv = columnwise_scales.reshape(-1) - - return GroupedTensor( - shape=(logical_first_dim, logical_last_dim), - dtype=dtype, - num_tensors=num_groups, - quantizer=quantizer, - data=data, - columnwise_data=columnwise_data, - scale_inv=scale_inv, - columnwise_scale_inv=columnwise_scale_inv, - amax=None, - columnwise_amax=None, - scale=None, - first_dims=None, - last_dims=None, - tensor_offsets=tensor_offsets, - offsets=None, - scale_inv_offsets=scale_inv_offsets, - columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, - with_gemm_swizzled_scales=with_gemm_swizzled_scales, - ) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 45295a2324..1884348622 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -815,7 +815,7 @@ def fuser_backward( if 
self.single_grouped_parameter: grad_weight = None - # TODO:ksivaman change workflow to avoid stack. + # Unfused path single param: Can be optimized to remove stack. if ctx.weight_requires_grad: grad_weight = torch.stack(grad_weights, dim=0) # Parameter registration order with single_grouped_parameter=True is: diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 04ebd6b819..a3794f5f3e 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -6,30 +6,26 @@ from __future__ import annotations from collections.abc import Callable -import os import functools import math -from pickle import TRUE from typing import Optional import torch +from cuda.bindings import driver as cuda import transformer_engine_torch as tex -from cuda.bindings import driver as cuda from ...cpp_extensions import ( general_grouped_gemm_for_grouped_tensor, ) from ...module._common import noop_cat from ...module.base import get_dummy_wgrad from ...quantization import Recipe -from ...tensor import Quantizer from ...tensor.grouped_tensor import GroupedTensor from ...utils import clear_tensor_data, get_device_compute_capability from ..basic import GroupedLinear, ScaledSwiGLU from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( - is_quantized_tensor, make_grouped_tensor_from_buffers, maybe_dequantize, ) @@ -289,7 +285,9 @@ def fuser_backward( alpha_tensor, norm_const_tensor = self._get_kernel_constants( num_groups=num_groups, dtype=dtype, device=device ) - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + current_stream = cuda.CUstream( # pylint: disable=c-extension-no-member + torch.cuda.current_stream().cuda_stream + ) # Fused kernel for FC2 dgrad + dSwiGLU + grad scale fc2_dgrad_kernel_out = 
self.grouped_gemm_dswiglu_kernel()( @@ -401,7 +399,6 @@ def fuser_backward( ) if grouped_fc2_wgrad is None: - # TODO:ksivaman: This is not CUDA Graph safe. grouped_fc2_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_groups, shapes=[fc2_weight_shape] * num_groups, @@ -544,7 +541,6 @@ def fuser_backward( ) if grouped_fc1_wgrad is None: - # TODO:ksivaman: This is not CUDA Graph safe. grouped_fc1_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_groups, shapes=[fc1_weight_shape] * num_groups, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 7b13e70d4e..c8550a2a7c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -6,14 +6,13 @@ from __future__ import annotations from collections.abc import Callable, Iterable -import os import functools from typing import Any, Optional import torch +from cuda.bindings import driver as cuda import transformer_engine_torch as tex -from cuda.bindings import driver as cuda from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor from ...module._common import noop_cat from ...quantization import Recipe @@ -315,7 +314,9 @@ def fuser_forward( alpha_tensor, norm_const_tensor = self._get_kernel_constants( num_groups=num_groups, dtype=dtype, device=device ) - current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + current_stream = cuda.CUstream( # pylint: disable=c-extension-no-member + torch.cuda.current_stream().cuda_stream + ) # Fused kernel for FC1 + SwiGLU + post-scale fc1_kernel_out = self.grouped_gemm_swiglu_kernel()( diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index a1c02934bd..d64463d692 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ 
b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -201,8 +201,6 @@ def __new__( ) return instance - self.with_gemm_swizzled_scales = with_gemm_swizzled_scales - def has_data(self) -> bool: """ Check if the tensor has row-wise data. @@ -452,6 +450,7 @@ def make_grouped_tensor_from_rowwise_data( @staticmethod def make_tensor_offsets(first_dims: torch.Tensor, logical_last_dim: int) -> torch.Tensor: + """Calculate GPU offsets from first dim splits.""" return torch.cat( [ torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), From bab8bf7bb9501d3b2bbff331dc23b98949814a00 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 17 Mar 2026 04:15:03 +0000 Subject: [PATCH 3/8] Properly cache the alpha tensor Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/ops/fused/backward_grouped_mlp.py | 7 ++++++- .../pytorch/ops/fused/forward_grouped_mlp.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index a3794f5f3e..544a463b1a 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -638,7 +638,12 @@ def _get_kernel_constants( or alpha_tensor.dtype != dtype or alpha_tensor.device != device ): - if global_alpha_tensor is None: + if ( + global_alpha_tensor is None + or global_alpha_tensor.numel() != num_groups + or global_alpha_tensor.dtype != dtype + or global_alpha_tensor.device != device + ): global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device) alpha_tensor = global_alpha_tensor norm_const_tensor = alpha_tensor[:1] diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index c8550a2a7c..ba8d0fa284 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ 
b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -491,7 +491,12 @@ def _get_kernel_constants( or alpha_tensor.dtype != dtype or alpha_tensor.device != device ): - if global_alpha_tensor is None: + if ( + global_alpha_tensor is None + or global_alpha_tensor.numel() != num_groups + or global_alpha_tensor.dtype != dtype + or global_alpha_tensor.device != device + ): global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device) alpha_tensor = global_alpha_tensor norm_const_tensor = alpha_tensor[:1] From 7d95c175cf2eb08078b45037a76c5d0ef5fb7290 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 16 Mar 2026 21:59:22 -0700 Subject: [PATCH 4/8] nD dummy grad Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6c708ed397..937484451c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -77,11 +77,9 @@ class UserBufferQuantizationMode(Enum): NONE = "none" FP8 = "fp8" - def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: """Returns a dummy tensor of given shape.""" - if len(shape) != 2: - raise ValueError(f"Expected 2D shape, got {len(shape)}D: {shape}") + key = (*shape, dtype) global _dummy_wgrads if key not in _dummy_wgrads: @@ -95,7 +93,6 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor _dummy_wgrads[key].fill_(0) return _dummy_wgrads[key].detach() - def initialize_ub( shape: list, tp_size: int, From 817d6f92f0e9a16bede22e6c909b373bdf68e778 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 05:07:13 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
transformer_engine/pytorch/module/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 937484451c..a96a87bf89 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -77,6 +77,7 @@ class UserBufferQuantizationMode(Enum): NONE = "none" FP8 = "fp8" + def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: """Returns a dummy tensor of given shape.""" @@ -93,6 +94,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor _dummy_wgrads[key].fill_(0) return _dummy_wgrads[key].detach() + def initialize_ub( shape: list, tp_size: int, From 886fc4d0f7e2d9386bff31b5752d7555f746b059 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 16 Mar 2026 22:32:26 -0700 Subject: [PATCH 6/8] 0 tokens in entire rank Signed-off-by: Kirthi Shankar Sivamani --- .../common/gemm/cublaslt_grouped_gemm.cu | 85 ++++++++++++++++--- 1 file changed, 72 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index ccf1e53ba4..62d9dd32b4 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -107,6 +107,27 @@ inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) return static_cast(t->logical_shape.data[1]) / static_cast(t->num_tensors); } +inline bool grouped_tensor_has_zero_work(const transformer_engine::GroupedTensor *t) { + return t != nullptr && t->logical_shape.ndim == 2 && + (t->logical_shape.data[0] == 0 || t->logical_shape.data[1] == 0); +} + +inline bool tensor_has_zero_work(const transformer_engine::Tensor *t) { + const auto shape = t->shape(); + return shape.size() == 2 && (shape[0] == 0 || shape[1] == 0); +} + +inline bool tensor_list_has_zero_work(const NVTETensor 
*tensor_list, size_t list_size) { + if (list_size == 0) return false; + for (size_t i = 0; i < list_size; ++i) { + const auto *t = transformer_engine::convertNVTETensorCheck(tensor_list[i]); + if (!tensor_has_zero_work(t)) { + return false; + } + } + return true; +} + // Constants for grouped GEMM workspace (declared early for use in helpers) static constexpr size_t kGroupedGemmAlignment = 256; static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB @@ -227,13 +248,16 @@ inline size_t validate_grouped_gemm_inputs( }; bool dtype_ok = true; for (const auto *tensor : inputs) { - dtype_ok = dtype_ok && is_supported_input_dtype(tensor->dtype()); - } - NVTE_CHECK(dtype_ok, "Grouped GEMM inputs must be FP8, BF16, or FP16."); - for (const auto *tensor : inputs) { - NVTE_CHECK(tensor->has_data() || tensor->has_columnwise_data(), + const bool has_empty_logical_shape = + tensor->logical_shape.ndim == 2 && + (tensor->logical_shape.data[0] == 0 || tensor->logical_shape.data[1] == 0); + if (!has_empty_logical_shape) { + dtype_ok = dtype_ok && is_supported_input_dtype(tensor->dtype()); + } + NVTE_CHECK(tensor->has_data() || tensor->has_columnwise_data() || has_empty_logical_shape, "Grouped GEMM: input tensor is missing both row-wise and column-wise data"); } + NVTE_CHECK(dtype_ok, "Grouped GEMM inputs must be FP8, BF16, or FP16."); // Cross-operand consistency across all inputs. 
const auto *ref = *inputs.begin(); @@ -303,6 +327,7 @@ struct GroupedOperandSelection { transformer_engine::DType dtype = transformer_engine::DType::kNumTypes; NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; bool with_gemm_swizzled_scales = false; + bool is_empty_logical_shape = false; bool trans = false; }; @@ -551,7 +576,10 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: using namespace transformer_engine; const bool has_row = t->has_data(); const bool has_col = t->has_columnwise_data(); - NVTE_CHECK(has_row || has_col, + const bool has_empty_logical_shape = + t->logical_shape.ndim == 2 && + (t->logical_shape.data[0] == 0 || t->logical_shape.data[1] == 0); + NVTE_CHECK(has_row || has_col || has_empty_logical_shape, "Grouped GEMM operand is missing both row-wise and column-wise data"); const auto sm = t->scaling_mode; @@ -568,6 +596,14 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: sel.scaling_mode = sm; sel.with_gemm_swizzled_scales = t->with_gemm_swizzled_scales; + // Empty logical tensors may not allocate rowwise/columnwise buffers. + if (!has_row && !has_col && has_empty_logical_shape) { + sel.is_empty_logical_shape = true; + sel.dtype = t->dtype(); + sel.shape = create_shape_info(t, /*swap_dims=*/false); + return sel; + } + const DType rep_dtype = has_row ? row_dtype : col_dtype; const bool is_fp8 = is_fp8_dtype(rep_dtype); const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); @@ -756,6 +792,11 @@ inline void execute_grouped_gemm(const GroupedGemmSetupWorkspace &setup_workspac bool use_split_accumulator, bool use_fp8, int64_t avg_m_val, int64_t avg_n_val, int64_t avg_k_val, void *cublas_workspace_ptr, cudaStream_t stream) { + // Zero-work grouped GEMM is a no-op. 
+ if (avg_m_val == 0 || avg_n_val == 0 || avg_k_val == 0) { + return; + } + using cublasHandleManager = transformer_engine::detail::HandleManager; cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); @@ -897,13 +938,16 @@ __global__ void setup_grouped_gemm_kernel( int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx); // Compute data pointers - A_ptrs[idx] = - has_a_multi_tensor ? a_multi_tensor_args.data_ptrs[idx] : (a_base + a_offset * a_elem_size); - B_ptrs[idx] = b_base + b_offset * b_elem_size; - C_ptrs[idx] = - has_c_multi_tensor ? c_multi_tensor_args.data_ptrs[idx] : (c_base + c_offset * c_elem_size); - D_ptrs[idx] = - has_d_multi_tensor ? d_multi_tensor_args.data_ptrs[idx] : (d_base + d_offset * d_elem_size); + A_ptrs[idx] = has_a_multi_tensor + ? a_multi_tensor_args.data_ptrs[idx] + : (a_base ? (a_base + a_offset * a_elem_size) : nullptr); + B_ptrs[idx] = b_base ? (b_base + b_offset * b_elem_size) : nullptr; + C_ptrs[idx] = has_c_multi_tensor + ? c_multi_tensor_args.data_ptrs[idx] + : (c_base ? (c_base + c_offset * c_elem_size) : nullptr); + D_ptrs[idx] = has_d_multi_tensor + ? d_multi_tensor_args.data_ptrs[idx] + : (d_base ? (d_base + d_offset * d_elem_size) : nullptr); // Compute storage dimensions for cuBLAS matrix layouts. // For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS, @@ -1047,6 +1091,11 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT // Parse config (if provided) GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); + // Zero-token edge case: bypass grouped GEMM setup/launch entirely. + if (grouped_tensor_has_zero_work(inputA) || grouped_tensor_has_zero_work(inputB)) { + return; + } + // Validate inputs and outputs. 
const size_t num_tensors = validate_grouped_gemm_inputs(inputA->num_tensors, {inputA, inputB}, alpha_tensor, beta_tensor); @@ -1108,6 +1157,11 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num // Parse config (if provided) GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); + // Zero-token edge case: bypass grouped GEMM setup/launch entirely. + if (tensor_list_has_zero_work(A_list, num_a_tensors) || grouped_tensor_has_zero_work(inputB)) { + return; + } + // Validate inputs and outputs. const size_t num_tensors = validate_grouped_gemm_inputs(num_a_tensors, {inputB}, alpha_tensor, beta_tensor); @@ -1247,6 +1301,11 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, // Parse config (if provided) GroupedMatmulConfig config_ = parse_grouped_gemm_config(config); + // Zero-token edge case: bypass grouped GEMM setup/launch entirely. + if (grouped_tensor_has_zero_work(inputA) || grouped_tensor_has_zero_work(inputB)) { + return; + } + // Select operand storage (row-wise vs column-wise) and adjust transpose flags to // mirror the non-grouped GEMM logic for FP8 layout constraints. 
auto A_sel = select_grouped_operand(inputA, static_cast(transa), /*is_A=*/true); From 6bd3812e9396e04be23678be58da44b1da9d9ba2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 05:33:22 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_grouped_gemm.cu | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 62d9dd32b4..ce9b3579fe 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -938,16 +938,13 @@ __global__ void setup_grouped_gemm_kernel( int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx); // Compute data pointers - A_ptrs[idx] = has_a_multi_tensor - ? a_multi_tensor_args.data_ptrs[idx] - : (a_base ? (a_base + a_offset * a_elem_size) : nullptr); + A_ptrs[idx] = has_a_multi_tensor ? a_multi_tensor_args.data_ptrs[idx] + : (a_base ? (a_base + a_offset * a_elem_size) : nullptr); B_ptrs[idx] = b_base ? (b_base + b_offset * b_elem_size) : nullptr; - C_ptrs[idx] = has_c_multi_tensor - ? c_multi_tensor_args.data_ptrs[idx] - : (c_base ? (c_base + c_offset * c_elem_size) : nullptr); - D_ptrs[idx] = has_d_multi_tensor - ? d_multi_tensor_args.data_ptrs[idx] - : (d_base ? (d_base + d_offset * d_elem_size) : nullptr); + C_ptrs[idx] = has_c_multi_tensor ? c_multi_tensor_args.data_ptrs[idx] + : (c_base ? (c_base + c_offset * c_elem_size) : nullptr); + D_ptrs[idx] = has_d_multi_tensor ? d_multi_tensor_args.data_ptrs[idx] + : (d_base ? (d_base + d_offset * d_elem_size) : nullptr); // Compute storage dimensions for cuBLAS matrix layouts. 
// For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS, From bf7af9fb60ee367cc9056a623a0ab14771f746ca Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 16 Mar 2026 22:58:08 -0700 Subject: [PATCH 8/8] tmp downgrade cublas version check Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 4029c9cb52..446e939b15 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -30,10 +30,10 @@ inline void CreateCublasHandle(cublasLtHandle_t *handle) { } // namespace // MXFP8 support for grouped GEMM requires cuBLAS 13.3+ -#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130300 +#define CUBLAS_MXFP8_GROUPED_GEMM_VERSION 130200 // BF16 support for grouped GEMM requires cuBLAS 13.3+ // cuBLAS 13.2 is mostly functional but contains a bug for wgrad when a group has k=0, the weight gradient will be uninitialized random data instead of zeros. -#define CUBLAS_GROUPED_GEMM_VERSION 130300 +#define CUBLAS_GROUPED_GEMM_VERSION 130200 #if CUBLAS_VERSION >= CUBLAS_GROUPED_GEMM_VERSION