From 7c9029e113b6b6aea108bf40a5f3d8c73e0a548f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 21 Jan 2026 12:45:29 +0000 Subject: [PATCH 1/5] Add `max_pool2d` operator --- src/ntops/kernels/__init__.py | 2 + src/ntops/kernels/max_pool2d.py | 118 ++++++++++++++++++++++++++++++++ src/ntops/torch/__init__.py | 2 + src/ntops/torch/max_pool2d.py | 74 ++++++++++++++++++++ tests/test_max_pool2d.py | 48 +++++++++++++ 5 files changed, 244 insertions(+) create mode 100644 src/ntops/kernels/max_pool2d.py create mode 100644 src/ntops/torch/max_pool2d.py create mode 100644 tests/test_max_pool2d.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 084e52c..98cd699 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -20,6 +20,7 @@ layer_norm, le, lt, + max_pool2d, mm, mul, ne, @@ -60,6 +61,7 @@ "layer_norm", "le", "lt", + "max_pool2d", "mm", "mul", "ne", diff --git a/src/ntops/kernels/max_pool2d.py b/src/ntops/kernels/max_pool2d.py new file mode 100644 index 0000000..19a2d98 --- /dev/null +++ b/src/ntops/kernels/max_pool2d.py @@ -0,0 +1,118 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = ninetoothed.block_size() + +KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16) +KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16) +STRIDE_H = Symbol("stride_h", constexpr=True) +STRIDE_W = Symbol("stride_w", constexpr=True) +PADDING_H = Symbol("padding_h", constexpr=True) +PADDING_W = Symbol("padding_w", constexpr=True) +DILATION_H = Symbol("dilation_h", constexpr=True) +DILATION_W = Symbol("dilation_w", constexpr=True) + + +def arrangement( + input, + output, + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + block_size=None, +): + if kernel_size_h is None: + kernel_size_h = KERNEL_SIZE_H + + if kernel_size_w is None: + kernel_size_w = KERNEL_SIZE_W + + if stride_h is None: + stride_h = STRIDE_H + + if stride_w is None: + stride_w = STRIDE_W + + if padding_h is None: + padding_h = PADDING_H + + if padding_w is None: + padding_w = PADDING_W + + if dilation_h is None: + dilation_h = DILATION_H + + if dilation_w is None: + dilation_w = DILATION_W + + if ceil_mode is None: + ceil_mode = False + + if block_size is None: + block_size = BLOCK_SIZE + + input_arranged = input.pad( + ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) + ) + input_arranged = input_arranged.tile( + (1, 1, kernel_size_h, kernel_size_w), + strides=(-1, -1, stride_h, stride_w), + dilation=(1, 1, dilation_h, dilation_w), + floor_mode=not ceil_mode, + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + output_arranged = output.tile((1, 1, 1, 1)) + output_arranged = output_arranged.ravel() + output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) + output_arranged = output_arranged.tile((block_size, -1)) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + + return input_arranged, output_arranged + + +def application(input, output): + output = ntl.max(input, axis=-1) # noqa: F841 + + +def premake( + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + 
dtype=None, + block_size=None, +): + arrangement_ = functools.partial( + arrangement, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_h=stride_h, + stride_w=stride_w, + padding_h=padding_h, + padding_w=padding_w, + dilation_h=dilation_h, + dilation_w=dilation_w, + ceil_mode=ceil_mode, + block_size=block_size, + ) + + tensors = (Tensor(4, dtype=dtype, other=float("-inf")), Tensor(4, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 702877e..efaef78 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -20,6 +20,7 @@ from ntops.torch.le import le from ntops.torch.lt import lt from ntops.torch.matmul import matmul +from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm from ntops.torch.mul import mul from ntops.torch.ne import ne @@ -60,6 +61,7 @@ "le", "lt", "matmul", + "max_pool2d", "mm", "mul", "ne", diff --git a/src/ntops/torch/max_pool2d.py b/src/ntops/torch/max_pool2d.py new file mode 100644 index 0000000..1e79609 --- /dev/null +++ b/src/ntops/torch/max_pool2d.py @@ -0,0 +1,74 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def max_pool2d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + if stride is None: + stride = kernel_size + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(padding, int): + padding = (padding, padding) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + assert not ceil_mode, "`ceil_mode` is not supported yet." + + assert not return_indices, "`return_indices` is not supported yet." + + n, c, h, w = input.shape + + def _calculate_output_size( + input_size, kernel_size, stride, padding, dilation, ceil_mode + ): + int_ = math.ceil if ceil_mode else math.floor + + result = int_( + (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + ) + + if ceil_mode and (result - 1) * stride >= input_size + padding: + result -= 1 + + return result + + h_ = _calculate_output_size( + h, kernel_size[0], stride[0], padding[0], dilation[0], ceil_mode + ) + w_ = _calculate_output_size( + w, kernel_size[1], stride[1], padding[1], dilation[1], ceil_mode + ) + + output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device) + + kernel = _cached_make(ntops.kernels.max_pool2d.premake, ceil_mode=ceil_mode) + + kernel( + input, + output, + kernel_size_h=kernel_size[0], + kernel_size_w=kernel_size[1], + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + ) + + return output diff --git a/tests/test_max_pool2d.py b/tests/test_max_pool2d.py new file mode 100644 index 0000000..9f29485 --- /dev/null +++ b/tests/test_max_pool2d.py @@ -0,0 +1,48 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +@pytest.mark.parametrize("ceil_mode", (False,)) +@pytest.mark.parametrize("dilation", (1, 2, (2, 3))) +@pytest.mark.parametrize("padding", (0, 1, (2, 3))) +@pytest.mark.parametrize("stride", (None, 1, (2, 3))) +@pytest.mark.parametrize("kernel_size", ((1, 1), (3, 3))) +@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),)) +def 
test_max_pool2d( + n, c, h, w, kernel_size, stride, padding, dilation, ceil_mode, dtype, device +): + padding_ = padding + + if isinstance(padding_, int): + padding_ = (padding_, padding_) + + if padding_[0] > kernel_size[0] / 2 or padding_[1] > kernel_size[1] / 2: + pytest.skip(reason="Invalid padding.") + + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.max_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + reference_output = F.max_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + assert torch.allclose(ninetoothed_output, reference_output) From dea2c85a05ae0bc9e8dd418292a7b95a55dcaaad Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 22 Jan 2026 16:31:07 +0800 Subject: [PATCH 2/5] Add `avg_pool2d` operator --- src/ntops/kernels/__init__.py | 2 + src/ntops/kernels/avg_pool2d.py | 118 ++++++++++++++++++++++++++++++++ src/ntops/torch/__init__.py | 2 + src/ntops/torch/avg_pool2d.py | 68 ++++++++++++++++++ tests/test_avg_pool2d.py | 47 +++++++++++++ 5 files changed, 237 insertions(+) create mode 100644 src/ntops/kernels/avg_pool2d.py create mode 100644 src/ntops/torch/avg_pool2d.py create mode 100644 tests/test_avg_pool2d.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 98cd699..28bf2f4 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -2,6 +2,7 @@ abs, add, addmm, + avg_pool2d, bitwise_and, bitwise_not, bitwise_or, @@ -43,6 +44,7 @@ "abs", "add", "addmm", + "avg_pool2d", "bitwise_and", "bitwise_not", "bitwise_or", diff --git a/src/ntops/kernels/avg_pool2d.py b/src/ntops/kernels/avg_pool2d.py new file mode 100644 index 0000000..bbf0bb3 --- /dev/null +++ b/src/ntops/kernels/avg_pool2d.py @@ -0,0 +1,118 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = ninetoothed.block_size() + +KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16) +KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16) +STRIDE_H = Symbol("stride_h", constexpr=True) +STRIDE_W = Symbol("stride_w", constexpr=True) +PADDING_H = Symbol("padding_h", constexpr=True) +PADDING_W = Symbol("padding_w", constexpr=True) +DILATION_H = Symbol("dilation_h", constexpr=True) +DILATION_W = Symbol("dilation_w", constexpr=True) + + +def arrangement( + input, + output, + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + block_size=None, +): + if kernel_size_h is None: + kernel_size_h = KERNEL_SIZE_H + + if kernel_size_w is None: + kernel_size_w = KERNEL_SIZE_W + + if stride_h is None: + stride_h = STRIDE_H + + if stride_w is None: + stride_w = STRIDE_W + + if padding_h is None: + padding_h = PADDING_H + + if padding_w is None: + padding_w = PADDING_W + + if dilation_h is None: + dilation_h = DILATION_H + + if dilation_w is None: + dilation_w = DILATION_W + + if ceil_mode is None: + ceil_mode = False + + if block_size is None: + block_size = BLOCK_SIZE + + input_arranged = input.pad( + ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) + ) + input_arranged = input_arranged.tile( + (1, 1, kernel_size_h, kernel_size_w), + strides=(-1, -1, stride_h, stride_w), + dilation=(1, 1, dilation_h, dilation_w), + floor_mode=not 
ceil_mode, + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + output_arranged = output.tile((1, 1, 1, 1)) + output_arranged = output_arranged.ravel() + output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) + output_arranged = output_arranged.tile((block_size, -1)) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + + return input_arranged, output_arranged + + +def application(input, output): + output = ntl.sum(input, axis=-1) / input.shape[-1] # noqa: F841 + + +def premake( + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + dtype=None, + block_size=None, +): + arrangement_ = functools.partial( + arrangement, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_h=stride_h, + stride_w=stride_w, + padding_h=padding_h, + padding_w=padding_w, + dilation_h=dilation_h, + dilation_w=dilation_w, + ceil_mode=ceil_mode, + block_size=block_size, + ) + + tensors = (Tensor(4, dtype=dtype), Tensor(4, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index efaef78..7474c8c 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -1,6 +1,7 @@ from ntops.torch.abs import abs from ntops.torch.add import add from ntops.torch.addmm import addmm +from ntops.torch.avg_pool2d import avg_pool2d from ntops.torch.bitwise_and import bitwise_and from ntops.torch.bitwise_not import bitwise_not from ntops.torch.bitwise_or import bitwise_or @@ -42,6 +43,7 @@ "abs", "add", "addmm", + "avg_pool2d", "bitwise_and", "bitwise_not", "bitwise_or", diff --git a/src/ntops/torch/avg_pool2d.py b/src/ntops/torch/avg_pool2d.py new file mode 100644 index 0000000..d04186c --- /dev/null +++ b/src/ntops/torch/avg_pool2d.py @@ -0,0 +1,68 @@ +import math + +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def avg_pool2d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + if stride is None: + stride = kernel_size + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(padding, int): + padding = (padding, padding) + + assert not ceil_mode, "`ceil_mode` is not supported yet." + + assert count_include_pad, "`count_include_pad` is not supported yet." + + assert divisor_override is None, "`divisor_override` is not supported yet." 
+ + n, c, h, w = input.shape + + def _calculate_output_size(input_size, kernel_size, stride, padding, ceil_mode): + int_ = math.ceil if ceil_mode else math.floor + + result = int_((input_size + 2 * padding - kernel_size) / stride + 1) + + if ceil_mode and (result - 1) * stride >= input_size + padding: + result -= 1 + + return result + + h_ = _calculate_output_size(h, kernel_size[0], stride[0], padding[0], ceil_mode) + w_ = _calculate_output_size(w, kernel_size[1], stride[1], padding[1], ceil_mode) + + output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device) + + kernel = _cached_make( + ntops.kernels.avg_pool2d.premake, + dilation_h=1, + dilation_w=1, + ceil_mode=ceil_mode, + ) + + kernel( + input, + output, + kernel_size_h=kernel_size[0], + kernel_size_w=kernel_size[1], + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + ) + + return output diff --git a/tests/test_avg_pool2d.py b/tests/test_avg_pool2d.py new file mode 100644 index 0000000..fe36390 --- /dev/null +++ b/tests/test_avg_pool2d.py @@ -0,0 +1,47 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize( + "dtype, rtol, atol", ((torch.float32, 1e-5, 1e-5), (torch.float16, 1e-3, 1e-3)) +) +@pytest.mark.parametrize("ceil_mode", (False,)) +@pytest.mark.parametrize("padding", (0, 1, (2, 3))) +@pytest.mark.parametrize("stride", (None, 1, (2, 3))) +@pytest.mark.parametrize("kernel_size", ((1, 1), (3, 3))) +@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),)) +def test_avg_pool2d( + n, c, h, w, kernel_size, stride, padding, ceil_mode, dtype, device, rtol, atol +): + padding_ = padding + + if isinstance(padding_, int): + padding_ = (padding_, padding_) + + if padding_[0] > kernel_size[0] / 2 or padding_[1] > kernel_size[1] / 2: + pytest.skip(reason="Invalid padding.") + + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.avg_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + ) + reference_output = F.avg_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + ) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) From d8235d606ef7dfdf41caf0c6cc1c9a2280bdcb3c Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 22 Jan 2026 17:11:37 +0800 Subject: [PATCH 3/5] Add `pooling.arrangement` --- src/ntops/kernels/avg_pool2d.py | 80 +-------------------------------- src/ntops/kernels/max_pool2d.py | 80 +-------------------------------- src/ntops/kernels/pooling.py | 68 ++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 156 deletions(-) create mode 100644 src/ntops/kernels/pooling.py diff --git a/src/ntops/kernels/avg_pool2d.py b/src/ntops/kernels/avg_pool2d.py index bbf0bb3..55588f0 100644 --- a/src/ntops/kernels/avg_pool2d.py +++ b/src/ntops/kernels/avg_pool2d.py @@ -1,85 +1,9 @@ import functools -import ninetoothed import ninetoothed.language as ntl -from ninetoothed import Symbol, Tensor +from ninetoothed import Tensor -BLOCK_SIZE = ninetoothed.block_size() - -KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16) -KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16) -STRIDE_H = Symbol("stride_h", constexpr=True) -STRIDE_W = Symbol("stride_w", constexpr=True) 
-PADDING_H = Symbol("padding_h", constexpr=True) -PADDING_W = Symbol("padding_w", constexpr=True) -DILATION_H = Symbol("dilation_h", constexpr=True) -DILATION_W = Symbol("dilation_w", constexpr=True) - - -def arrangement( - input, - output, - kernel_size_h=None, - kernel_size_w=None, - stride_h=None, - stride_w=None, - padding_h=None, - padding_w=None, - dilation_h=None, - dilation_w=None, - ceil_mode=None, - block_size=None, -): - if kernel_size_h is None: - kernel_size_h = KERNEL_SIZE_H - - if kernel_size_w is None: - kernel_size_w = KERNEL_SIZE_W - - if stride_h is None: - stride_h = STRIDE_H - - if stride_w is None: - stride_w = STRIDE_W - - if padding_h is None: - padding_h = PADDING_H - - if padding_w is None: - padding_w = PADDING_W - - if dilation_h is None: - dilation_h = DILATION_H - - if dilation_w is None: - dilation_w = DILATION_W - - if ceil_mode is None: - ceil_mode = False - - if block_size is None: - block_size = BLOCK_SIZE - - input_arranged = input.pad( - ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) - ) - input_arranged = input_arranged.tile( - (1, 1, kernel_size_h, kernel_size_w), - strides=(-1, -1, stride_h, stride_w), - dilation=(1, 1, dilation_h, dilation_w), - floor_mode=not ceil_mode, - ) - input_arranged = input_arranged.ravel() - input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) - input_arranged = input_arranged.tile((block_size, -1)) - - output_arranged = output.tile((1, 1, 1, 1)) - output_arranged = output_arranged.ravel() - output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) - output_arranged = output_arranged.tile((block_size, -1)) - output_arranged.dtype = output_arranged.dtype.squeeze(1) - - return input_arranged, output_arranged +from ntops.kernels.pooling import arrangement def application(input, output): diff --git a/src/ntops/kernels/max_pool2d.py b/src/ntops/kernels/max_pool2d.py index 19a2d98..c01edbc 100644 --- a/src/ntops/kernels/max_pool2d.py +++ b/src/ntops/kernels/max_pool2d.py @@ -1,85 +1,9 @@ import functools -import ninetoothed import ninetoothed.language as ntl -from ninetoothed import Symbol, Tensor +from ninetoothed import Tensor -BLOCK_SIZE = ninetoothed.block_size() - -KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16) -KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16) -STRIDE_H = Symbol("stride_h", constexpr=True) -STRIDE_W = Symbol("stride_w", constexpr=True) -PADDING_H = Symbol("padding_h", constexpr=True) -PADDING_W = Symbol("padding_w", constexpr=True) -DILATION_H = Symbol("dilation_h", constexpr=True) -DILATION_W = Symbol("dilation_w", constexpr=True) - - -def arrangement( - input, - output, - kernel_size_h=None, - kernel_size_w=None, - stride_h=None, - stride_w=None, - padding_h=None, - padding_w=None, - dilation_h=None, - dilation_w=None, - ceil_mode=None, - block_size=None, -): - if kernel_size_h is None: - kernel_size_h = KERNEL_SIZE_H - - if kernel_size_w is None: - kernel_size_w = KERNEL_SIZE_W - - if stride_h is None: - stride_h = STRIDE_H - - if stride_w is None: - stride_w = STRIDE_W - - if padding_h is None: - padding_h = PADDING_H - - if padding_w is None: - padding_w = PADDING_W - - if dilation_h is None: - dilation_h = DILATION_H - - if dilation_w is None: - dilation_w = DILATION_W - - if ceil_mode is None: - ceil_mode = False - - if block_size is None: - block_size = BLOCK_SIZE - - input_arranged = input.pad( - ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) - ) - input_arranged = 
input_arranged.tile( - (1, 1, kernel_size_h, kernel_size_w), - strides=(-1, -1, stride_h, stride_w), - dilation=(1, 1, dilation_h, dilation_w), - floor_mode=not ceil_mode, - ) - input_arranged = input_arranged.ravel() - input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) - input_arranged = input_arranged.tile((block_size, -1)) - - output_arranged = output.tile((1, 1, 1, 1)) - output_arranged = output_arranged.ravel() - output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) - output_arranged = output_arranged.tile((block_size, -1)) - output_arranged.dtype = output_arranged.dtype.squeeze(1) - - return input_arranged, output_arranged +from ntops.kernels.pooling import arrangement def application(input, output): diff --git a/src/ntops/kernels/pooling.py b/src/ntops/kernels/pooling.py new file mode 100644 index 0000000..fcce19f --- /dev/null +++ b/src/ntops/kernels/pooling.py @@ -0,0 +1,68 @@ +import ninetoothed +from ninetoothed import Symbol + + +def arrangement( + input, + output, + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + block_size=None, +): + if kernel_size_h is None: + kernel_size_h = Symbol("kernel_size_h", constexpr=True, upper_bound=16) + + if kernel_size_w is None: + kernel_size_w = Symbol("kernel_size_w", constexpr=True, upper_bound=16) + + if stride_h is None: + stride_h = Symbol("stride_h", constexpr=True) + + if stride_w is None: + stride_w = Symbol("stride_w", constexpr=True) + + if padding_h is None: + padding_h = Symbol("padding_h", constexpr=True) + + if padding_w is None: + padding_w = Symbol("padding_w", constexpr=True) + + if dilation_h is None: + dilation_h = Symbol("dilation_h", constexpr=True) + + if dilation_w is None: + dilation_w = Symbol("dilation_w", constexpr=True) + + if ceil_mode is None: + ceil_mode = False + + if block_size is None: + block_size = ninetoothed.block_size() + + input_arranged = input.pad( + ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) + ) + input_arranged = input_arranged.tile( + (1, 1, kernel_size_h, kernel_size_w), + strides=(-1, -1, stride_h, stride_w), + dilation=(1, 1, dilation_h, dilation_w), + floor_mode=not ceil_mode, + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + output_arranged = output.tile((1, 1, 1, 1)) + output_arranged = output_arranged.ravel() + output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) + output_arranged = output_arranged.tile((block_size, -1)) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + + return input_arranged, output_arranged From a55bb0574b28e1f1974c49ffde97759b192c19ba Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 22 Jan 2026 17:18:39 +0800 Subject: [PATCH 4/5] Add `ntops.torch.pooling._calculate_output_size` --- src/ntops/torch/avg_pool2d.py | 21 +++++++-------------- src/ntops/torch/max_pool2d.py | 31 +++++++++++++------------------ src/ntops/torch/pooling.py | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 32 deletions(-) create mode 100644 src/ntops/torch/pooling.py diff --git a/src/ntops/torch/avg_pool2d.py b/src/ntops/torch/avg_pool2d.py index d04186c..eeaf2d7 100644 --- a/src/ntops/torch/avg_pool2d.py +++ b/src/ntops/torch/avg_pool2d.py @@ -1,8 +1,7 @@ -import math - import torch import ntops +from ntops.torch.pooling import 
_calculate_output_size from ntops.torch.utils import _cached_make @@ -32,18 +31,12 @@ def avg_pool2d( n, c, h, w = input.shape - def _calculate_output_size(input_size, kernel_size, stride, padding, ceil_mode): - int_ = math.ceil if ceil_mode else math.floor - - result = int_((input_size + 2 * padding - kernel_size) / stride + 1) - - if ceil_mode and (result - 1) * stride >= input_size + padding: - result -= 1 - - return result - - h_ = _calculate_output_size(h, kernel_size[0], stride[0], padding[0], ceil_mode) - w_ = _calculate_output_size(w, kernel_size[1], stride[1], padding[1], ceil_mode) + h_ = _calculate_output_size( + h, kernel_size[0], stride=stride[0], padding=padding[0], ceil_mode=ceil_mode + ) + w_ = _calculate_output_size( + w, kernel_size[1], stride=stride[1], padding=padding[1], ceil_mode=ceil_mode + ) output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device) diff --git a/src/ntops/torch/max_pool2d.py b/src/ntops/torch/max_pool2d.py index 1e79609..f4eb55c 100644 --- a/src/ntops/torch/max_pool2d.py +++ b/src/ntops/torch/max_pool2d.py @@ -1,8 +1,7 @@ -import math - import torch import ntops +from ntops.torch.pooling import _calculate_output_size from ntops.torch.utils import _cached_make @@ -33,25 +32,21 @@ def max_pool2d( n, c, h, w = input.shape - def _calculate_output_size( - input_size, kernel_size, stride, padding, dilation, ceil_mode - ): - int_ = math.ceil if ceil_mode else math.floor - - result = int_( - (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 - ) - - if ceil_mode and (result - 1) * stride >= input_size + padding: - result -= 1 - - return result - h_ = _calculate_output_size( - h, kernel_size[0], stride[0], padding[0], dilation[0], ceil_mode + h, + kernel_size[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + ceil_mode=ceil_mode, ) w_ = _calculate_output_size( - w, kernel_size[1], stride[1], padding[1], dilation[1], ceil_mode + w, + kernel_size[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + ceil_mode=ceil_mode, ) output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device) diff --git a/src/ntops/torch/pooling.py b/src/ntops/torch/pooling.py new file mode 100644 index 0000000..9820938 --- /dev/null +++ b/src/ntops/torch/pooling.py @@ -0,0 +1,19 @@ +import math + + +def _calculate_output_size( + input_size, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False +): + if stride is None: + stride = kernel_size + + int_ = math.ceil if ceil_mode else math.floor + + result = int_( + (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + ) + + if ceil_mode and (result - 1) * stride >= input_size + padding: + result -= 1 + + return result From 9440ce92c1ec2ef6a983dd166f992269487164ba Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 22 Jan 2026 19:27:12 +0800 Subject: [PATCH 5/5] Add `conv2d` operator --- src/ntops/kernels/__init__.py | 2 + src/ntops/kernels/conv2d.py | 144 ++++++++++++++++++++++++++++++++++ src/ntops/torch/__init__.py | 2 + src/ntops/torch/conv2d.py | 51 ++++++++++++ tests/test_conv2d.py | 33 ++++++++ 5 files changed, 232 insertions(+) create mode 100644 src/ntops/kernels/conv2d.py create mode 100644 src/ntops/torch/conv2d.py create mode 100644 tests/test_conv2d.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 28bf2f4..f6934ef 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -8,6 +8,7 @@ bitwise_or, bmm, clamp, + conv2d, cos, div, dropout, @@ 
-50,6 +51,7 @@ "bitwise_or", "bmm", "clamp", + "conv2d", "cos", "div", "dropout", diff --git a/src/ntops/kernels/conv2d.py b/src/ntops/kernels/conv2d.py new file mode 100644 index 0000000..96f24d8 --- /dev/null +++ b/src/ntops/kernels/conv2d.py @@ -0,0 +1,144 @@ +import copy +import functools + +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +from ntops.kernels import mm + + +def arrangement( + input, + weight, + bias, + output, + input_precision, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + block_size_m=None, + block_size_n=None, + block_size_k=None, +): + if stride_h is None: + stride_h = Symbol("stride_h", constexpr=True) + + if stride_w is None: + stride_w = Symbol("stride_w", constexpr=True) + + if padding_h is None: + padding_h = Symbol("padding_h", constexpr=True) + + if padding_w is None: + padding_w = Symbol("padding_w", constexpr=True) + + if dilation_h is None: + dilation_h = Symbol("dilation_h", constexpr=True) + + if dilation_w is None: + dilation_w = Symbol("dilation_w", constexpr=True) + + if block_size_m is None: + block_size_m = mm.BLOCK_SIZE_M + + if block_size_n is None: + block_size_n = mm.BLOCK_SIZE_N + + if block_size_k is None: + block_size_k = mm.BLOCK_SIZE_K + + mm_arrangement = functools.partial( + mm.arrangement, + block_size_m=block_size_m, + block_size_n=block_size_n, + block_size_k=block_size_k, + ) + + input_arranged = input.pad( + ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) + ) + input_arranged = input_arranged.tile( + (1, *weight.shape[1:]), + strides=(-1, -1, stride_h, stride_w), + dilation=(1, 1, dilation_h, dilation_w), + floor_mode=True, + ) + input_arranged = input_arranged.squeeze(1) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1) + + weight_arranged = weight.flatten(start_dim=1) + weight_arranged = weight_arranged.permute((1, 0)) + + bias_arranged = bias[None, :, None, None].expand( + (output.shape[0], -1, output.shape[2], output.shape[3]) + ) + bias_arranged = bias_arranged.permute((0, 2, 3, 1)).flatten(end_dim=3) + + output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3) + + _, _, bias_arranged, _ = mm_arrangement( + copy.deepcopy(input_arranged), + copy.deepcopy(weight_arranged), + bias_arranged, + copy.deepcopy(input_precision), + ) + + input_arranged, weight_arranged, output_arranged, input_precision_arranged = ( + mm_arrangement( + input_arranged, weight_arranged, output_arranged, input_precision + ) + ) + + return ( + input_arranged, + weight_arranged, + bias_arranged, + output_arranged, + input_precision_arranged, + ) + + +def application(input, weight, bias, output, input_precision): + mm_output = ntl.zeros(output.shape, dtype=ntl.float32) + mm.application(input, weight, mm_output, input_precision) + output = mm_output + bias + + +def premake( + input_precision=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + dtype=None, + block_size_m=None, + block_size_n=None, + block_size_k=None, +): + arrangement_ = functools.partial( + arrangement, + stride_h=stride_h, + stride_w=stride_w, + padding_h=padding_h, + padding_w=padding_w, + dilation_h=dilation_h, + dilation_w=dilation_w, + block_size_m=block_size_m, + block_size_n=block_size_n, + block_size_k=block_size_k, + ) + + input, weight, output = (Tensor(4, dtype=dtype) for _ in 
range(3)) + bias = Tensor(1, dtype=dtype) + input_precision = Tensor(0, dtype=dtype, constexpr=True, value=input_precision) + + tensors = (input, weight, bias, output, input_precision) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 7474c8c..82fc596 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -7,6 +7,7 @@ from ntops.torch.bitwise_or import bitwise_or from ntops.torch.bmm import bmm from ntops.torch.clamp import clamp +from ntops.torch.conv2d import conv2d from ntops.torch.cos import cos from ntops.torch.div import div from ntops.torch.dropout import dropout @@ -49,6 +50,7 @@ "bitwise_or", "bmm", "clamp", + "conv2d", "cos", "div", "dropout", diff --git a/src/ntops/torch/conv2d.py b/src/ntops/torch/conv2d.py new file mode 100644 index 0000000..5dadd39 --- /dev/null +++ b/src/ntops/torch/conv2d.py @@ -0,0 +1,51 @@ +import torch + +import ntops +from ntops.torch.pooling import _calculate_output_size +from ntops.torch.utils import _cached_make, _get_matmul_input_precision + + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(padding, str): + if padding == "valid": + padding = 0 + + if isinstance(padding, int): + padding = (padding, padding) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + assert groups == 1, "`groups` is not supported yet." + + n, c, h, w = input.shape + k, _, r, s = weight.shape + + p = _calculate_output_size( + h, r, stride=stride[0], padding=padding[0], dilation=dilation[0] + ) + q = _calculate_output_size( + w, s, stride=stride[1], padding=padding[1], dilation=dilation[1] + ) + + output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + + if bias is None: + bias = torch.zeros((k,), dtype=output.dtype, device=output.device) + + kernel = _cached_make( + ntops.kernels.conv2d.premake, + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + ) + + kernel(input, weight, bias, output, _get_matmul_input_precision()) + + return output diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py new file mode 100644 index 0000000..279d66b --- /dev/null +++ b/tests/test_conv2d.py @@ -0,0 +1,33 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize( + "dtype, rtol, atol", ((torch.float32, 1e-5, 1e-5), (torch.float16, 1e-3, 1e-3)) +) +@pytest.mark.parametrize("dilation", (1, 2, (2, 3))) +@pytest.mark.parametrize("padding", (0, 1, (2, 3))) +@pytest.mark.parametrize("stride", (1, 2, (2, 3))) +@pytest.mark.parametrize("r, s", ((1, 1), (3, 3))) +@pytest.mark.parametrize("n, c, h, w, k", ((2, 3, 112, 112, 4),)) +def test_conv2d( + n, c, h, w, k, r, s, stride, padding, dilation, dtype, device, rtol, atol +): + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + weight = torch.randn((k, c, r, s), dtype=dtype, device=device) + bias = torch.randn((k,), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.conv2d( + input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation + ) + reference_output = F.conv2d( + input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation + ) + + assert torch.allclose(ninetoothed_output, 
reference_output, rtol=rtol, atol=atol)
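
A note on the padding semantics of the two pooling kernels: in `max_pool2d.premake` the input `Tensor` is declared with `other=float("-inf")`, so out-of-bounds (padding) positions load as -inf and can never win the max, while `avg_pool2d.premake` keeps the default fill, so padded positions contribute zeros to the sum, consistent with the `count_include_pad=True` restriction asserted in the wrapper. Below is a plain-PyTorch sketch of the same window reduction for the max case (shapes are illustrative; `F.unfold` zero-pads, so the input is padded with -inf explicitly before unfolding):

import torch
import torch.nn.functional as F

n, c, h, w = 2, 3, 8, 8
k, stride, padding = 3, 2, 1

x = torch.randn(n, c, h, w)

# Pad with -inf so padded positions can never be selected by the max.
x_padded = F.pad(x, (padding, padding, padding, padding), value=float("-inf"))

windows = F.unfold(x_padded, kernel_size=k, stride=stride)  # (N, C*k*k, P*Q)
windows = windows.reshape(n, c, k * k, -1)

p = (h + 2 * padding - (k - 1) - 1) // stride + 1
out = windows.max(dim=2).values.reshape(n, c, p, p)

torch.testing.assert_close(out, F.max_pool2d(x, kernel_size=k, stride=stride, padding=padding))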
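
The shared helper `_calculate_output_size` added in patch 4 implements the standard sliding-window formula floor((H + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1), plus the usual ceil-mode correction that drops a window that would start entirely inside the right/bottom padding. A minimal standalone check of the same arithmetic against PyTorch's shape inference (the concrete sizes below are illustrative):

import math

import torch
import torch.nn.functional as F


def output_size(input_size, kernel_size, stride, padding=0, dilation=1, ceil_mode=False):
    # floor((H + 2P - D*(K - 1) - 1) / S + 1), or ceil(...) in ceil mode.
    round_ = math.ceil if ceil_mode else math.floor
    result = round_((input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
    # In ceil mode, the last window must still start within the input or the left padding.
    if ceil_mode and (result - 1) * stride >= input_size + padding:
        result -= 1
    return result


# H = 112, kernel 3, stride 2, padding 1, dilation 1: floor(111 / 2 + 1) = 56.
assert output_size(112, 3, stride=2, padding=1) == 56

x = torch.randn(1, 1, 112, 112)
assert F.max_pool2d(x, kernel_size=3, stride=2, padding=1).shape[-1] == 56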
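
The `conv2d` kernel in patch 5 reuses `mm.arrangement`, i.e. it lowers the convolution to an implicit GEMM: the padded input is tiled into (1, C, R, S) windows and flattened into rows of an (N*P*Q) x (C*R*S) matrix, the weight is flattened to (K, C*R*S) and transposed, and the bias is broadcast over the output rows before being added to the matmul result. A plain-PyTorch reference for the same reduction using `F.unfold` (an illustrative sketch of the math, not the ninetoothed code path; shapes are arbitrary):

import torch
import torch.nn.functional as F

n, c, h, w, k, r, s = 2, 3, 8, 8, 4, 3, 3
stride, padding, dilation = 2, 1, 1

x = torch.randn(n, c, h, w)
weight = torch.randn(k, c, r, s)
bias = torch.randn(k)

p = (h + 2 * padding - dilation * (r - 1) - 1) // stride + 1
q = (w + 2 * padding - dilation * (s - 1) - 1) // stride + 1

# im2col: (N, C*R*S, P*Q) columns of the padded, strided, dilated windows.
cols = F.unfold(x, kernel_size=(r, s), dilation=dilation, padding=padding, stride=stride)

# GEMM: (K, C*R*S) @ (N, C*R*S, P*Q) -> (N, K, P*Q), plus a broadcast bias.
out = weight.reshape(k, -1) @ cols + bias[:, None]
out = out.reshape(n, k, p, q)

torch.testing.assert_close(
    out,
    F.conv2d(x, weight, bias=bias, stride=stride, padding=padding, dilation=dilation),
    rtol=1e-4,
    atol=1e-4,
)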
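
Finally, an end-to-end usage sketch of the three operators added in this series, along the lines of the tests above (the device, shapes, dtype, and tolerances are illustrative; a CUDA device is assumed, since that is the only backend the tests exercise):

import torch
import torch.nn.functional as F

import ntops

device = "cuda"
x = torch.randn(2, 3, 112, 112, dtype=torch.float16, device=device)
weight = torch.randn(4, 3, 3, 3, dtype=torch.float16, device=device)
bias = torch.randn(4, dtype=torch.float16, device=device)

# Pooling: kernel_size is passed as an (h, w) pair; ceil_mode and return_indices are
# not supported yet, and avg_pool2d additionally requires count_include_pad=True.
pooled_max = ntops.torch.max_pool2d(x, kernel_size=(3, 3), stride=2, padding=1)
pooled_avg = ntops.torch.avg_pool2d(x, kernel_size=(3, 3), stride=2, padding=1)

# Convolution: groups != 1 is not supported yet.
y = ntops.torch.conv2d(x, weight, bias=bias, stride=2, padding=1)

torch.testing.assert_close(
    pooled_max, F.max_pool2d(x, kernel_size=(3, 3), stride=2, padding=1)
)
torch.testing.assert_close(
    y, F.conv2d(x, weight, bias=bias, stride=2, padding=1), rtol=1e-3, atol=1e-3
)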