From d4f726de712438518665b70cf282e9c6bbe8227d Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Wed, 11 Feb 2026 16:05:21 +0800 Subject: [PATCH 1/3] issue/972 - feat: add scaled_mm with muDNN BatchMatMul for moore gpu --- src/infiniop/ops/scaled_mm/info.h | 2 +- src/infiniop/ops/scaled_mm/int8_gemm.h | 4 +- .../ops/scaled_mm/moore/int8_gemm_moore.h | 7 + .../ops/scaled_mm/moore/int8_gemm_moore.mu | 238 ++++++++++++++++++ src/infiniop/ops/scaled_mm/operator.cc | 16 ++ test/infiniop/scaled_mm_int8.py | 77 ++++-- 6 files changed, 316 insertions(+), 28 deletions(-) create mode 100644 src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.h create mode 100644 src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.mu diff --git a/src/infiniop/ops/scaled_mm/info.h b/src/infiniop/ops/scaled_mm/info.h index f29a8198a..2bce44ca6 100644 --- a/src/infiniop/ops/scaled_mm/info.h +++ b/src/infiniop/ops/scaled_mm/info.h @@ -1,4 +1,4 @@ -#ifndef __GEMM_INFO_H__ +#ifndef __I8GEMM_INFO_H__ #define __I8GEMM_INFO_H__ #include "../../../utils.h" diff --git a/src/infiniop/ops/scaled_mm/int8_gemm.h b/src/infiniop/ops/scaled_mm/int8_gemm.h index d5a250e66..87c506f95 100644 --- a/src/infiniop/ops/scaled_mm/int8_gemm.h +++ b/src/infiniop/ops/scaled_mm/int8_gemm.h @@ -18,8 +18,8 @@ size_t workspace_size, \ infiniDtype_t out_dtype, \ infiniDevice_t device_type, int device_id) \ - : InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype), \ - _opaque(opaque), _info(info), _workspace_size(workspace_size) {} \ + : InfiniopDescriptor{device_type, device_id}, _opaque(opaque), \ + _workspace_size(workspace_size), _info(info), _out_dtype(out_dtype) {} \ \ public: \ ~Descriptor(); \ diff --git a/src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.h b/src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.h new file mode 100644 index 000000000..08aebff12 --- /dev/null +++ b/src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.h @@ -0,0 +1,7 @@ +#ifndef __INT8_GEMM_MOORE_API_H__ +#define __INT8_GEMM_MOORE_API_H__ +#include "../int8_gemm.h" + +DESCRIPTOR(moore) + +#endif // __INT8_GEMM_MOORE_API_H__ diff --git a/src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.mu b/src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.mu new file mode 100644 index 000000000..848ae1e9e --- /dev/null +++ b/src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.mu @@ -0,0 +1,238 @@ +#include "../../../devices/moore/moore_common.h" +#include "../../../devices/moore/moore_handle.h" +#include "int8_gemm_moore.h" + +namespace op::i8gemm::moore { + +static void moore_i8gemm_launch( + const I8GemmInfo &info, + std::shared_ptr &internal, + void* out, + const int8_t* A, + const int8_t* B, + const float* A_scale, + const float* B_scale, + const void* bias, + infiniDtype_t out_dtype, + musaStream_t stream) +{ + internal->useMudnn(stream, + [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t { + + // 1. Operator + auto matmul = std::make_unique<::musa::dnn::BatchMatMul>(); + matmul->SetComputeMode(::musa::dnn::BatchMatMul::ComputeMode::TENSOR); + + // 2. Tensors + ::musa::dnn::Tensor out_t, a_t, b_t, bias_t; + ::musa::dnn::Tensor scale_a_t, scale_b_t; + + // 3. Output dtype + if (out_dtype == INFINI_DTYPE_F16) { + out_t.SetType(::musa::dnn::Tensor::Type::HALF); + bias_t.SetType(::musa::dnn::Tensor::Type::HALF); + } else { + out_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + bias_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16); + } + + // 4. Input INT8 + a_t.SetType(::musa::dnn::Tensor::Type::INT8); + b_t.SetType(::musa::dnn::Tensor::Type::INT8); + + // 5. 
Scale (per-tensor) + scale_a_t.SetType(::musa::dnn::Tensor::Type::FLOAT); + scale_b_t.SetType(::musa::dnn::Tensor::Type::FLOAT); + + // 6. Bind memory + out_t.SetAddr(out); + a_t.SetAddr(const_cast(A)); + b_t.SetAddr(const_cast(B)); + scale_a_t.SetAddr(const_cast(A_scale)); + scale_b_t.SetAddr(const_cast(B_scale)); + + if (bias) + bias_t.SetAddr(const_cast(bias)); + + // 7. A NdInfo + { + std::array dims; + std::array strides; + + if (info.a_matrix.col_stride != 1) { + dims = {info.batch, info.k, info.m}; + } else { + dims = {info.batch, info.m, info.k}; + } + strides = { + info.a_matrix.stride, + info.a_matrix.ld(), + 1 + }; + a_t.SetNdInfo(3, dims.data(), strides.data()); + } + + // 8. B NdInfo + { + std::array dims; + std::array strides; + + if (info.b_matrix.col_stride != 1) { + dims = {info.batch, info.n, info.k}; + } else { + dims = {info.batch, info.k, info.n}; + } + strides = { + info.b_matrix.stride, + info.b_matrix.ld(), + 1 + }; + b_t.SetNdInfo(3, dims.data(), strides.data()); + } + + // 9. out NdInfo + { + std::array dims = { + info.batch, + info.m, + info.n + }; + + std::array strides = { + info.m * info.n, + info.n, + 1 + }; + + out_t.SetNdInfo(3, dims.data(), strides.data()); + } + + + // 10. Bias & scale NdInfo + if (bias) { + std::array dims = { info.n }; + std::array strides = { 1 }; + bias_t.SetNdInfo(1, dims.data(), strides.data()); + } + + { + std::array a_scale_dims = { info.batch, info.m, 1 }; + std::array a_scale_strides = { info.m, 1, 1 }; + scale_a_t.SetNdInfo(3, a_scale_dims.data(), a_scale_strides.data()); + + std::array b_scale_dims = { info.batch, 1, info.n }; + std::array b_scale_strides = { info.n, 1, 1 }; + scale_b_t.SetNdInfo(3, b_scale_dims.data(), b_scale_strides.data()); + + } + + // 11. Transpose + matmul->SetTranspose( + info.a_matrix.col_stride != 1, + info.b_matrix.col_stride != 1); + + // 12. Lt param (no epilogue enum) + ::musa::dnn::MatMulLtParam lt_param; + lt_param.SetScale( + scale_a_t, + scale_b_t, + ::musa::dnn::Tensor(), + ::musa::dnn::Tensor()); + + // 13. Alpha / Beta + matmul->SetAlpha(1.0); + matmul->SetBeta(0.0); + matmul->SetGamma(1.0); + + // 14. Workspace + ::musa::dnn::MemoryMaintainer maintainer = + [](size_t size) { + void* ptr = nullptr; + musaMalloc(&ptr, size); + return ::musa::dnn::MemoryHandler( + ptr, + [](void* p) { if (p) musaFree(p); }); + }; + + // 15. Run + matmul->RunLt( + mudnn_handle, + out_t, + a_t, + b_t, + ::musa::dnn::Tensor(), + bias ? 
bias_t : ::musa::dnn::Tensor(), + lt_param, + maintainer); + + return INFINI_STATUS_SUCCESS; + }); +} + +/* ================= Descriptor ================= */ + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t bias_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t a_scale_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scale_desc) +{ + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + + auto result = I8GemmInfo::create( + out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR); + CHECK_RESULT(result); + + *desc_ptr = new Descriptor( + new Opaque{handle->internal()}, + result.take(), + 0, + dtype, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, + const void *bias, + const void *a, + const void *a_scale, + const void *b, + const void *b_scale, + void *stream_) const +{ + moore_i8gemm_launch( + _info, + _opaque->internal, + out, + static_cast(a), + static_cast(b), + static_cast(a_scale), + static_cast(b_scale), + bias, + _out_dtype, + reinterpret_cast(stream_)); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::i8gemm::moore diff --git a/src/infiniop/ops/scaled_mm/operator.cc b/src/infiniop/ops/scaled_mm/operator.cc index 9d51708c8..99ccbe7ee 100644 --- a/src/infiniop/ops/scaled_mm/operator.cc +++ b/src/infiniop/ops/scaled_mm/operator.cc @@ -6,6 +6,10 @@ #include "nvidia/int8_gemm_nvidia.cuh" #endif +#if defined(ENABLE_MOORE_API) +#include "moore/int8_gemm_moore.h" +#endif + __C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle, infiniopI8GemmDescriptor_t *desc_ptr, infiniopTensorDescriptor_t out_desc, @@ -31,6 +35,9 @@ __C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle, #endif #if defined(ENABLE_QY_API) CREATE(INFINI_DEVICE_QY, nvidia) +#endif +#if defined(ENABLE_MOORE_API) + CREATE(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -49,6 +56,9 @@ __C infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t des #endif #if defined(ENABLE_QY_API) GET(INFINI_DEVICE_QY, nvidia) +#endif +#if defined(ENABLE_MOORE_API) + GET(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -76,6 +86,9 @@ __C infiniStatus_t infiniopI8Gemm(infiniopI8GemmDescriptor_t desc, #endif #if defined(ENABLE_QY_API) CACULATE(INFINI_DEVICE_QY, nvidia) +#endif +#if defined(ENABLE_MOORE_API) + CACULATE(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -94,6 +107,9 @@ __C infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t de #endif #if defined(ENABLE_QY_API) DESTROY(INFINI_DEVICE_QY, nvidia) +#endif +#if defined(ENABLE_MOORE_API) + DESTROY(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/test/infiniop/scaled_mm_int8.py b/test/infiniop/scaled_mm_int8.py index 0712bfdf4..a5cfcff78 100644 --- a/test/infiniop/scaled_mm_int8.py +++ b/test/infiniop/scaled_mm_int8.py @@ -25,6 +25,7 @@ # These are not meant to be imported from other modules _TEST_CASES_ = [ # x_shape, w_shape, y_shape, alpha, 
beta + ((2, 4), (4, 2), (2, 2)), ((128, 512), (512, 1024), (128, 1024)), ((256, 1024), (1024, 2048), (256, 2048)), ((1024, 2048), (2048, 1024), (1024, 1024)), @@ -59,10 +60,8 @@ class Inplace(Enum): DEBUG = False PROFILE = False NUM_PRERUN = 10 -NUM_ITERATIONS = 1000 - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) +NUM_ITERATIONS = 100 + def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) @@ -72,6 +71,7 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) return o.to(out_dtype) + def test( handle, device, @@ -87,30 +87,38 @@ def test( ) M, K = x_shape N = w_shape[1] - - x_packed = to_int8(torch.randn((M, K), device="cuda") * 5) - weights = to_int8(torch.randn((N, K), device="cuda").t() * 5) - - x_scale = torch.randn((M,), device="cuda", dtype=torch.float32) - weights_scale = torch.randn((N,), device="cuda", dtype=torch.float32) - bias = torch.randn((N,), device="cuda", dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16) * 10 - - ans = torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias) - + x_packed = TestTensor( - (M, K), x_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed - ) - x_scale = TestTensor( - (M,), x_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale + (M, K), + None, + InfiniDtype.I8, + device, + mode="randint", + randint_low=-128, + randint_high=127, ) weights = TestTensor( - (K, N), weights.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights + (K, N), + None, + InfiniDtype.I8, + device, + mode="randint", + randint_low=-128, + randint_high=127, ) - weights_scale = TestTensor( - (N,), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale + x_scale = TestTensor((M,), None, InfiniDtype.F32, device, mode="random") + weights_scale = TestTensor((N,), None, InfiniDtype.F32, device, mode="random") + bias = TestTensor((N,), None, dtype, device, mode="random") + y = TestTensor(y_shape, None, dtype, device, mode="zeros") + + ans = torch_scaled_mm( + x_packed.torch_tensor(), + weights.torch_tensor(), + x_scale.torch_tensor(), + weights_scale.torch_tensor(), + out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, + bias=bias.torch_tensor(), ) - y = TestTensor(y_shape, None, dtype, device) - bias = TestTensor((N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias) descriptor = infiniopOperatorDescriptor_t() check_error( @@ -164,7 +172,20 @@ def lib_linear(): # Profiling workflow if PROFILE: # fmt: off - profile_operation("PyTorch", lambda: torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation( + "PyTorch", + lambda: torch_scaled_mm( + x_packed.torch_tensor(), + weights.torch_tensor(), + x_scale.torch_tensor(), + weights_scale.torch_tensor(), + out_dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, + bias=bias.torch_tensor() + ), + device, + NUM_PRERUN, + NUM_ITERATIONS + ) profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS) # fmt: on @@ -181,6 +202,12 @@ def lib_linear(): NUM_ITERATIONS = args.num_iterations for device in get_test_devices(args): - 
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+        # muDNN (v3101) INT8 quantized matmul only produces BF16 output,
+        # so the Moore backend is tested with BF16 only.
+        if args.moore:
+            _TENSOR_DTYPES_MOORE = [InfiniDtype.BF16]
+            test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES_MOORE)
+        else:
+            test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
 
     print("\033[92mTest passed!\033[0m")

From e1974c6b7c9fbabc46a81ab8e824b6e26013f640 Mon Sep 17 00:00:00 2001
From: zhushuang <974198603@qq.com>
Date: Wed, 11 Feb 2026 17:19:04 +0800
Subject: [PATCH 2/3] issue/972 - feat: add per_channel_quant_int8 for moore gpu referencing nvidia

---
 .../moore/per_channel_quant_int8_moore.h      |   7 ++
 .../moore/per_channel_quant_int8_moore.mu     | 116 ++++++++++++++++++
 .../quant/per_channel_quant_int8/operator.cc  |  15 +++
 xmake/moore.lua                               |   3 +
 4 files changed, 141 insertions(+)
 create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.h
 create mode 100644 src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.mu

diff --git a/src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.h b/src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.h
new file mode 100644
index 000000000..757613410
--- /dev/null
+++ b/src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.h
@@ -0,0 +1,7 @@
+#ifndef __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
+#define __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
+#include "../per_channel_quant_int8.h"
+
+DESCRIPTOR(moore)
+
+#endif // __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
diff --git a/src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.mu b/src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.mu
new file mode 100644
index 000000000..429504b35
--- /dev/null
+++ b/src/infiniop/ops/quant/per_channel_quant_int8/moore/per_channel_quant_int8_moore.mu
@@ -0,0 +1,116 @@
+#include "../../../../devices/moore/moore_common.h"
+#include "per_channel_quant_int8_moore.h"
+
+#include "../../../../devices/moore/moore_kernel_common.h"
+#include "../../../../reduce/cuda/reduce.cuh"
+#include 
+
+#include "../cuda/kernel.cuh"
+
+template 
+INFINIOP_MOORE_KERNEL blockPerChannelQuantI8(
+    int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
+    blockPerChannelQuantI8Kernel(x_packed, x_scale, x_zero, x, M, K);
+}
+template 
+INFINIOP_MOORE_KERNEL blockPerChannelQuantI8Sym(
+    int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
+    blockPerChannelQuantI8SymKernel(x_packed, x_scale, x, M, K);
+}
+
+template 
+INFINIOP_MOORE_KERNEL warpPerChannelQuantI8(
+    int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
+    warpPerChannelQuantI8Kernel(x_packed, x_scale, x_zero, x, M, K);
+}
+template 
+INFINIOP_MOORE_KERNEL warpPerChannelQuantI8Sym(
+    int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
+    warpPerChannelQuantI8SymKernel(x_packed, x_scale, x, M, K);
+}
+
+namespace op::per_channel_quant_int8::moore {
+
+struct Descriptor::Opaque {
+    std::shared_ptr internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle, Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t x_packed_desc,
+    infiniopTensorDescriptor_t x_scale_desc,
+    infiniopTensorDescriptor_t x_zero_desc,
+    infiniopTensorDescriptor_t x_desc) {
+    auto info = PerChannelQuantI8Info::createPerChannelQuantI8Info(x_packed_desc, x_scale_desc,
x_zero_desc, x_desc); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t per_channel_quant_int8Kernel(const PerChannelQuantI8Info &info, int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, musaStream_t stream) { + int M = (int)info.M; + int K = (int)info.K; + + if (K >= 1024) { + if (x_zero == nullptr) { + blockPerChannelQuantI8Sym + <<>>(x_packed, x_scale, x, M, K); + } else { + blockPerChannelQuantI8 + <<>>(x_packed, x_scale, x_zero, x, M, K); + } + + } else { + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + int num_block_x = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + if (x_zero == nullptr) { + warpPerChannelQuantI8Sym + <<>>(x_packed, x_scale, x, M, K); + } else { + warpPerChannelQuantI8 + <<>>(x_packed, x_scale, x_zero, x, M, K); + } + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *x_packed, void *x_scale, void *x_zero, const void *x, + void *stream_) const { + musaStream_t stream = (musaStream_t)stream_; +#define QUANT(BLOCK_SIZE, TDATA) \ + per_channel_quant_int8Kernel(_info, (int8_t *)x_packed, (float *)x_scale, (float *)x_zero, (const TDATA *)x, stream) +#define QUANT_WITH_BLOCK_SIZE(BLOCK_SIZE) \ + { \ + if (_info.dtype == INFINI_DTYPE_F16) \ + return QUANT(BLOCK_SIZE, half); \ + else if (_info.dtype == INFINI_DTYPE_F32) \ + return QUANT(BLOCK_SIZE, float); \ + else if (_info.dtype == INFINI_DTYPE_BF16) \ + return QUANT(BLOCK_SIZE, __mt_bfloat16); \ + else \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) { + QUANT_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_1024) + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) { + QUANT_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_512) + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::per_channel_quant_int8::moore diff --git a/src/infiniop/ops/quant/per_channel_quant_int8/operator.cc b/src/infiniop/ops/quant/per_channel_quant_int8/operator.cc index dade91c88..262d6f10a 100644 --- a/src/infiniop/ops/quant/per_channel_quant_int8/operator.cc +++ b/src/infiniop/ops/quant/per_channel_quant_int8/operator.cc @@ -5,6 +5,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #include "nvidia/per_channel_quant_int8_nvidia.cuh" #endif +#if defined(ENABLE_MOORE_API) +#include "moore/per_channel_quant_int8_moore.h" +#endif __C infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle, infiniopPerChannelQuantI8Descriptor_t *desc_ptr, @@ -27,6 +30,9 @@ __C infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t ha #endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia) +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -45,6 +51,9 @@ __C infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQ #endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia) +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -71,6 +80,9 @@ __C infiniStatus_t 
infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor #endif #ifdef ENABLE_QY_API QUANT(INFINI_DEVICE_QY, nvidia) +#endif +#ifdef ENABLE_MOORE_API + QUANT(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -90,6 +102,9 @@ __C infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannel #endif #ifdef ENABLE_QY_API DESTROY(INFINI_DEVICE_QY, nvidia) +#endif +#ifdef ENABLE_MOORE_API + DESTROY(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/xmake/moore.lua b/xmake/moore.lua index fdcad9564..2cd55e8b2 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -48,6 +48,9 @@ target("infiniop-moore") -- Add source files for Moore muBLAS/muDNN GEMM backends. add_files("../src/infiniop/ops/gemm/moore/*/*.mu", {rule = "mu"}) + + -- Add source files for Moore per_channel_quant_int8 backends. + add_files("../src/infiniop/ops/quant/per_channel_quant_int8/moore/*.mu", {rule = "mu"}) target_end() target("infinirt-moore") From 6841663b1ba5d5aead6157dc55c767a839b932d4 Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Wed, 11 Feb 2026 20:37:24 +0800 Subject: [PATCH 3/3] issue/972 - feat: adjust scaled_mm_int8 python test --- test/infiniop/scaled_mm_int8.py | 64 ++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/test/infiniop/scaled_mm_int8.py b/test/infiniop/scaled_mm_int8.py index a5cfcff78..9528c7ea3 100644 --- a/test/infiniop/scaled_mm_int8.py +++ b/test/infiniop/scaled_mm_int8.py @@ -25,7 +25,6 @@ # These are not meant to be imported from other modules _TEST_CASES_ = [ # x_shape, w_shape, y_shape, alpha, beta - ((2, 4), (4, 2), (2, 2)), ((128, 512), (512, 1024), (128, 1024)), ((256, 1024), (1024, 2048), (256, 2048)), ((1024, 2048), (2048, 1024), (1024, 1024)), @@ -83,12 +82,16 @@ def test( sync=None, ): print( - f"Testing Linear on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}" + f"Testing scaled_mm_int8 on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}" ) M, K = x_shape N = w_shape[1] - x_packed = TestTensor( + # --- Tensor Descriptor --- + # orig: create a random int8 tensor as the reference data source + # torch: extract the torch view to adjust layout/stride + # final: wrap it back as TestTensor with explicit stride for device execution + x_packed_orig = TestTensor( (M, K), None, InfiniDtype.I8, @@ -97,8 +100,18 @@ def test( randint_low=-128, randint_high=127, ) - weights = TestTensor( - (K, N), + x_packed_torch = x_packed_orig.torch_tensor() + x_packed = TestTensor( + (M, K), + x_packed_torch.stride(), + InfiniDtype.I8, + device, + mode="manual", + set_tensor=x_packed_torch, + ) + + weights_orig = TestTensor( + (N, K), None, InfiniDtype.I8, device, @@ -106,9 +119,44 @@ def test( randint_low=-128, randint_high=127, ) - x_scale = TestTensor((M,), None, InfiniDtype.F32, device, mode="random") - weights_scale = TestTensor((N,), None, InfiniDtype.F32, device, mode="random") - bias = TestTensor((N,), None, dtype, device, mode="random") + weights_torch = weights_orig.torch_tensor().t() + weights = TestTensor( + (K, N), + weights_torch.stride(), + InfiniDtype.I8, + device, + mode="manual", + set_tensor=weights_torch, + ) + + x_scale_orig = TestTensor((M,), None, InfiniDtype.F32, device, mode="random") + x_scale_torch = x_scale_orig.torch_tensor() + x_scale = TestTensor( + (M,), + 
x_scale_torch.stride(), + InfiniDtype.F32, + device, + mode="manual", + set_tensor=x_scale_torch, + ) + + weights_scale_orig = TestTensor((N,), None, InfiniDtype.F32, device, mode="random") + weights_scale_torch = weights_scale_orig.torch_tensor() + weights_scale = TestTensor( + (N,), + weights_scale_torch.stride(), + InfiniDtype.F32, + device, + mode="manual", + set_tensor=weights_scale_torch, + ) + + bias_orig = TestTensor((N,), None, dtype, device, mode="random") + bias_torch = bias_orig.torch_tensor() + bias = TestTensor( + (N,), bias_torch.stride(), dtype, device, mode="manual", set_tensor=bias_torch + ) + y = TestTensor(y_shape, None, dtype, device, mode="zeros") ans = torch_scaled_mm(