diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index 084e52c..f6934ef 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -2,11 +2,13 @@ abs, add, addmm, + avg_pool2d, bitwise_and, bitwise_not, bitwise_or, bmm, clamp, + conv2d, cos, div, dropout, @@ -20,6 +22,7 @@ layer_norm, le, lt, + max_pool2d, mm, mul, ne, @@ -42,11 +45,13 @@ "abs", "add", "addmm", + "avg_pool2d", "bitwise_and", "bitwise_not", "bitwise_or", "bmm", "clamp", + "conv2d", "cos", "div", "dropout", @@ -60,6 +65,7 @@ "layer_norm", "le", "lt", + "max_pool2d", "mm", "mul", "ne", diff --git a/src/ntops/kernels/avg_pool2d.py b/src/ntops/kernels/avg_pool2d.py new file mode 100644 index 0000000..55588f0 --- /dev/null +++ b/src/ntops/kernels/avg_pool2d.py @@ -0,0 +1,42 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.pooling import arrangement + + +def application(input, output): + output = ntl.sum(input, axis=-1) / input.shape[-1] # noqa: F841 + + +def premake( + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + dtype=None, + block_size=None, +): + arrangement_ = functools.partial( + arrangement, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_h=stride_h, + stride_w=stride_w, + padding_h=padding_h, + padding_w=padding_w, + dilation_h=dilation_h, + dilation_w=dilation_w, + ceil_mode=ceil_mode, + block_size=block_size, + ) + + tensors = (Tensor(4, dtype=dtype), Tensor(4, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/conv2d.py b/src/ntops/kernels/conv2d.py new file mode 100644 index 0000000..96f24d8 --- /dev/null +++ b/src/ntops/kernels/conv2d.py @@ -0,0 +1,144 @@ +import copy +import functools + +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +from ntops.kernels import mm + + +def arrangement( + input, + weight, + bias, + output, + input_precision, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + block_size_m=None, + block_size_n=None, + block_size_k=None, +): + if stride_h is None: + stride_h = Symbol("stride_h", constexpr=True) + + if stride_w is None: + stride_w = Symbol("stride_w", constexpr=True) + + if padding_h is None: + padding_h = Symbol("padding_h", constexpr=True) + + if padding_w is None: + padding_w = Symbol("padding_w", constexpr=True) + + if dilation_h is None: + dilation_h = Symbol("dilation_h", constexpr=True) + + if dilation_w is None: + dilation_w = Symbol("dilation_w", constexpr=True) + + if block_size_m is None: + block_size_m = mm.BLOCK_SIZE_M + + if block_size_n is None: + block_size_n = mm.BLOCK_SIZE_N + + if block_size_k is None: + block_size_k = mm.BLOCK_SIZE_K + + mm_arrangement = functools.partial( + mm.arrangement, + block_size_m=block_size_m, + block_size_n=block_size_n, + block_size_k=block_size_k, + ) + + input_arranged = input.pad( + ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) + ) + input_arranged = input_arranged.tile( + (1, *weight.shape[1:]), + strides=(-1, -1, stride_h, stride_w), + dilation=(1, 1, dilation_h, dilation_w), + floor_mode=True, + ) + input_arranged = input_arranged.squeeze(1) + input_arranged.dtype = input_arranged.dtype.squeeze(0) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1) + + weight_arranged = 
weight.flatten(start_dim=1) + weight_arranged = weight_arranged.permute((1, 0)) + + bias_arranged = bias[None, :, None, None].expand( + (output.shape[0], -1, output.shape[2], output.shape[3]) + ) + bias_arranged = bias_arranged.permute((0, 2, 3, 1)).flatten(end_dim=3) + + output_arranged = output.permute((0, 2, 3, 1)).flatten(end_dim=3) + + _, _, bias_arranged, _ = mm_arrangement( + copy.deepcopy(input_arranged), + copy.deepcopy(weight_arranged), + bias_arranged, + copy.deepcopy(input_precision), + ) + + input_arranged, weight_arranged, output_arranged, input_precision_arranged = ( + mm_arrangement( + input_arranged, weight_arranged, output_arranged, input_precision + ) + ) + + return ( + input_arranged, + weight_arranged, + bias_arranged, + output_arranged, + input_precision_arranged, + ) + + +def application(input, weight, bias, output, input_precision): + mm_output = ntl.zeros(output.shape, dtype=ntl.float32) + mm.application(input, weight, mm_output, input_precision) + output = mm_output + bias + + +def premake( + input_precision=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + dtype=None, + block_size_m=None, + block_size_n=None, + block_size_k=None, +): + arrangement_ = functools.partial( + arrangement, + stride_h=stride_h, + stride_w=stride_w, + padding_h=padding_h, + padding_w=padding_w, + dilation_h=dilation_h, + dilation_w=dilation_w, + block_size_m=block_size_m, + block_size_n=block_size_n, + block_size_k=block_size_k, + ) + + input, weight, output = (Tensor(4, dtype=dtype) for _ in range(3)) + bias = Tensor(1, dtype=dtype) + input_precision = Tensor(0, dtype=dtype, constexpr=True, value=input_precision) + + tensors = (input, weight, bias, output, input_precision) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/max_pool2d.py b/src/ntops/kernels/max_pool2d.py new file mode 100644 index 0000000..c01edbc --- /dev/null +++ b/src/ntops/kernels/max_pool2d.py @@ -0,0 +1,42 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.pooling import arrangement + + +def application(input, output): + output = ntl.max(input, axis=-1) # noqa: F841 + + +def premake( + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + dtype=None, + block_size=None, +): + arrangement_ = functools.partial( + arrangement, + kernel_size_h=kernel_size_h, + kernel_size_w=kernel_size_w, + stride_h=stride_h, + stride_w=stride_w, + padding_h=padding_h, + padding_w=padding_w, + dilation_h=dilation_h, + dilation_w=dilation_w, + ceil_mode=ceil_mode, + block_size=block_size, + ) + + tensors = (Tensor(4, dtype=dtype, other=float("-inf")), Tensor(4, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/pooling.py b/src/ntops/kernels/pooling.py new file mode 100644 index 0000000..fcce19f --- /dev/null +++ b/src/ntops/kernels/pooling.py @@ -0,0 +1,68 @@ +import ninetoothed +from ninetoothed import Symbol + + +def arrangement( + input, + output, + kernel_size_h=None, + kernel_size_w=None, + stride_h=None, + stride_w=None, + padding_h=None, + padding_w=None, + dilation_h=None, + dilation_w=None, + ceil_mode=None, + block_size=None, +): + if kernel_size_h is None: + kernel_size_h = Symbol("kernel_size_h", constexpr=True, upper_bound=16) + + if kernel_size_w is None: + kernel_size_w = Symbol("kernel_size_w", constexpr=True, 
upper_bound=16) + + if stride_h is None: + stride_h = Symbol("stride_h", constexpr=True) + + if stride_w is None: + stride_w = Symbol("stride_w", constexpr=True) + + if padding_h is None: + padding_h = Symbol("padding_h", constexpr=True) + + if padding_w is None: + padding_w = Symbol("padding_w", constexpr=True) + + if dilation_h is None: + dilation_h = Symbol("dilation_h", constexpr=True) + + if dilation_w is None: + dilation_w = Symbol("dilation_w", constexpr=True) + + if ceil_mode is None: + ceil_mode = False + + if block_size is None: + block_size = ninetoothed.block_size() + + input_arranged = input.pad( + ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) + ) + input_arranged = input_arranged.tile( + (1, 1, kernel_size_h, kernel_size_w), + strides=(-1, -1, stride_h, stride_w), + dilation=(1, 1, dilation_h, dilation_w), + floor_mode=not ceil_mode, + ) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + input_arranged = input_arranged.tile((block_size, -1)) + + output_arranged = output.tile((1, 1, 1, 1)) + output_arranged = output_arranged.ravel() + output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) + output_arranged = output_arranged.tile((block_size, -1)) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + + return input_arranged, output_arranged diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 702877e..82fc596 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -1,11 +1,13 @@ from ntops.torch.abs import abs from ntops.torch.add import add from ntops.torch.addmm import addmm +from ntops.torch.avg_pool2d import avg_pool2d from ntops.torch.bitwise_and import bitwise_and from ntops.torch.bitwise_not import bitwise_not from ntops.torch.bitwise_or import bitwise_or from ntops.torch.bmm import bmm from ntops.torch.clamp import clamp +from ntops.torch.conv2d import conv2d from ntops.torch.cos import cos from ntops.torch.div import div from ntops.torch.dropout import dropout @@ -20,6 +22,7 @@ from ntops.torch.le import le from ntops.torch.lt import lt from ntops.torch.matmul import matmul +from ntops.torch.max_pool2d import max_pool2d from ntops.torch.mm import mm from ntops.torch.mul import mul from ntops.torch.ne import ne @@ -41,11 +44,13 @@ "abs", "add", "addmm", + "avg_pool2d", "bitwise_and", "bitwise_not", "bitwise_or", "bmm", "clamp", + "conv2d", "cos", "div", "dropout", @@ -60,6 +65,7 @@ "le", "lt", "matmul", + "max_pool2d", "mm", "mul", "ne", diff --git a/src/ntops/torch/avg_pool2d.py b/src/ntops/torch/avg_pool2d.py new file mode 100644 index 0000000..eeaf2d7 --- /dev/null +++ b/src/ntops/torch/avg_pool2d.py @@ -0,0 +1,61 @@ +import torch + +import ntops +from ntops.torch.pooling import _calculate_output_size +from ntops.torch.utils import _cached_make + + +def avg_pool2d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + if stride is None: + stride = kernel_size + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(padding, int): + padding = (padding, padding) + + assert not ceil_mode, "`ceil_mode` is not supported yet." + + assert count_include_pad, "`count_include_pad` is not supported yet." + + assert divisor_override is None, "`divisor_override` is not supported yet." 
+ + n, c, h, w = input.shape + + h_ = _calculate_output_size( + h, kernel_size[0], stride=stride[0], padding=padding[0], ceil_mode=ceil_mode + ) + w_ = _calculate_output_size( + w, kernel_size[1], stride=stride[1], padding=padding[1], ceil_mode=ceil_mode + ) + + output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device) + + kernel = _cached_make( + ntops.kernels.avg_pool2d.premake, + dilation_h=1, + dilation_w=1, + ceil_mode=ceil_mode, + ) + + kernel( + input, + output, + kernel_size_h=kernel_size[0], + kernel_size_w=kernel_size[1], + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + ) + + return output diff --git a/src/ntops/torch/conv2d.py b/src/ntops/torch/conv2d.py new file mode 100644 index 0000000..5dadd39 --- /dev/null +++ b/src/ntops/torch/conv2d.py @@ -0,0 +1,51 @@ +import torch + +import ntops +from ntops.torch.pooling import _calculate_output_size +from ntops.torch.utils import _cached_make, _get_matmul_input_precision + + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(padding, str): + if padding == "valid": + padding = 0 + + if isinstance(padding, int): + padding = (padding, padding) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + assert groups == 1, "`groups` is not supported yet." + + n, c, h, w = input.shape + k, _, r, s = weight.shape + + p = _calculate_output_size( + h, r, stride=stride[0], padding=padding[0], dilation=dilation[0] + ) + q = _calculate_output_size( + w, s, stride=stride[1], padding=padding[1], dilation=dilation[1] + ) + + output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + + if bias is None: + bias = torch.zeros((k,), dtype=output.dtype, device=output.device) + + kernel = _cached_make( + ntops.kernels.conv2d.premake, + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + ) + + kernel(input, weight, bias, output, _get_matmul_input_precision()) + + return output diff --git a/src/ntops/torch/max_pool2d.py b/src/ntops/torch/max_pool2d.py new file mode 100644 index 0000000..f4eb55c --- /dev/null +++ b/src/ntops/torch/max_pool2d.py @@ -0,0 +1,69 @@ +import torch + +import ntops +from ntops.torch.pooling import _calculate_output_size +from ntops.torch.utils import _cached_make + + +def max_pool2d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + if stride is None: + stride = kernel_size + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(padding, int): + padding = (padding, padding) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + assert not ceil_mode, "`ceil_mode` is not supported yet." + + assert not return_indices, "`return_indices` is not supported yet." 
+ + n, c, h, w = input.shape + + h_ = _calculate_output_size( + h, + kernel_size[0], + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + ceil_mode=ceil_mode, + ) + w_ = _calculate_output_size( + w, + kernel_size[1], + stride=stride[1], + padding=padding[1], + dilation=dilation[1], + ceil_mode=ceil_mode, + ) + + output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device) + + kernel = _cached_make(ntops.kernels.max_pool2d.premake, ceil_mode=ceil_mode) + + kernel( + input, + output, + kernel_size_h=kernel_size[0], + kernel_size_w=kernel_size[1], + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + ) + + return output diff --git a/src/ntops/torch/pooling.py b/src/ntops/torch/pooling.py new file mode 100644 index 0000000..9820938 --- /dev/null +++ b/src/ntops/torch/pooling.py @@ -0,0 +1,19 @@ +import math + + +def _calculate_output_size( + input_size, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False +): + if stride is None: + stride = kernel_size + + int_ = math.ceil if ceil_mode else math.floor + + result = int_( + (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + ) + + if ceil_mode and (result - 1) * stride >= input_size + padding: + result -= 1 + + return result diff --git a/tests/test_avg_pool2d.py b/tests/test_avg_pool2d.py new file mode 100644 index 0000000..fe36390 --- /dev/null +++ b/tests/test_avg_pool2d.py @@ -0,0 +1,47 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize( + "dtype, rtol, atol", ((torch.float32, 1e-5, 1e-5), (torch.float16, 1e-3, 1e-3)) +) +@pytest.mark.parametrize("ceil_mode", (False,)) +@pytest.mark.parametrize("padding", (0, 1, (2, 3))) +@pytest.mark.parametrize("stride", (None, 1, (2, 3))) +@pytest.mark.parametrize("kernel_size", ((1, 1), (3, 3))) +@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),)) +def test_avg_pool2d( + n, c, h, w, kernel_size, stride, padding, ceil_mode, dtype, device, rtol, atol +): + padding_ = padding + + if isinstance(padding_, int): + padding_ = (padding_, padding_) + + if padding_[0] > kernel_size[0] / 2 or padding_[1] > kernel_size[1] / 2: + pytest.skip(reason="Invalid padding.") + + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.avg_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + ) + reference_output = F.avg_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + ) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py new file mode 100644 index 0000000..279d66b --- /dev/null +++ b/tests/test_conv2d.py @@ -0,0 +1,33 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize( + "dtype, rtol, atol", ((torch.float32, 1e-5, 1e-5), (torch.float16, 1e-3, 1e-3)) +) +@pytest.mark.parametrize("dilation", (1, 2, (2, 3))) +@pytest.mark.parametrize("padding", (0, 1, (2, 3))) +@pytest.mark.parametrize("stride", (1, 2, (2, 3))) +@pytest.mark.parametrize("r, s", ((1, 1), 
(3, 3))) +@pytest.mark.parametrize("n, c, h, w, k", ((2, 3, 112, 112, 4),)) +def test_conv2d( + n, c, h, w, k, r, s, stride, padding, dilation, dtype, device, rtol, atol +): + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + weight = torch.randn((k, c, r, s), dtype=dtype, device=device) + bias = torch.randn((k,), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.conv2d( + input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation + ) + reference_output = F.conv2d( + input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation + ) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_max_pool2d.py b/tests/test_max_pool2d.py new file mode 100644 index 0000000..9f29485 --- /dev/null +++ b/tests/test_max_pool2d.py @@ -0,0 +1,48 @@ +import pytest +import torch +import torch.nn.functional as F + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("device", ("cuda",)) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float16)) +@pytest.mark.parametrize("ceil_mode", (False,)) +@pytest.mark.parametrize("dilation", (1, 2, (2, 3))) +@pytest.mark.parametrize("padding", (0, 1, (2, 3))) +@pytest.mark.parametrize("stride", (None, 1, (2, 3))) +@pytest.mark.parametrize("kernel_size", ((1, 1), (3, 3))) +@pytest.mark.parametrize("n, c, h, w", ((2, 3, 112, 112),)) +def test_max_pool2d( + n, c, h, w, kernel_size, stride, padding, dilation, ceil_mode, dtype, device +): + padding_ = padding + + if isinstance(padding_, int): + padding_ = (padding_, padding_) + + if padding_[0] > kernel_size[0] / 2 or padding_[1] > kernel_size[1] / 2: + pytest.skip(reason="Invalid padding.") + + input = torch.randn((n, c, h, w), dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.max_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + reference_output = F.max_pool2d( + input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + assert torch.allclose(ninetoothed_output, reference_output)
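
For orientation, a note on the pooling path: `pooling.arrangement` pads the input, tiles it into (kernel_size_h, kernel_size_w) windows with the given stride and dilation, and flattens each window onto the last axis of a block, so `avg_pool2d.application` and `max_pool2d.application` only need to reduce with `ntl.sum` / `ntl.max` along axis -1. The standalone PyTorch sketch below expresses the same computation with `F.unfold`; it is for reference only, not code from this patch, and the name `avg_pool2d_reference` is illustrative. It covers the average case, where zero padding contributes to the window sum (matching `count_include_pad=True`); the max kernel instead fills out-of-bounds elements with `float("-inf")` via the `other=` argument of its input tensor.

import torch
import torch.nn.functional as F


def avg_pool2d_reference(x, kernel_size, stride, padding):
    # Treat channels as batch entries so every unfolded column holds one window.
    n, c, h, w = x.shape
    kh, kw = kernel_size
    windows = F.unfold(
        x.reshape(n * c, 1, h, w), kernel_size, padding=padding, stride=stride
    )  # shape: (n * c, kh * kw, number of output positions)
    out = windows.mean(dim=1)  # the ntl.sum(input, axis=-1) / input.shape[-1] reduction
    h_out = (h + 2 * padding[0] - (kh - 1) - 1) // stride[0] + 1
    w_out = (w + 2 * padding[1] - (kw - 1) - 1) // stride[1] + 1
    return out.reshape(n, c, h_out, w_out)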
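
For the convolution, `conv2d.arrangement` is an implicit-GEMM (im2col) formulation: the padded input is tiled into weight-shaped patches and flattened into an (N*P*Q, C*R*S) operand, the weight is flattened to (K, C*R*S) and transposed, and the existing `mm.arrangement` / `mm.application` pair performs the matrix multiply before the broadcast bias is added. A minimal PyTorch sketch of the same view follows, assuming tuple-valued stride, padding, and dilation; `conv2d_reference` and its local names are illustrative only.

import torch
import torch.nn.functional as F


def conv2d_reference(x, weight, bias, stride, padding, dilation):
    n, c, h, w = x.shape
    k, _, r, s = weight.shape
    p = (h + 2 * padding[0] - dilation[0] * (r - 1) - 1) // stride[0] + 1
    q = (w + 2 * padding[1] - dilation[1] * (s - 1) - 1) // stride[1] + 1
    # im2col: one row per output position, one column per (c, r, s) element.
    cols = F.unfold(x, (r, s), dilation=dilation, padding=padding, stride=stride)
    cols = cols.transpose(1, 2).reshape(n * p * q, c * r * s)
    # Mirrors weight_arranged: flatten to (k, c*r*s), then transpose for the GEMM.
    out = cols @ weight.reshape(k, c * r * s).t() + bias
    # Mirrors output_arranged: channels-last rows reshaped back to (n, k, p, q).
    return out.reshape(n, p, q, k).permute(0, 3, 1, 2)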
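
Finally, `_calculate_output_size` follows the usual PyTorch sizing rule, floor((H + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1) (ceil when ceil_mode is set), and the trailing adjustment drops a last window that would start entirely inside the right or bottom padding, mirroring the rule PyTorch documents for its pooling layers. A quick spot check with illustrative numbers:

from ntops.torch.pooling import _calculate_output_size

# 112 -> 56 with a 3x3 window, stride 2, padding 1.
assert _calculate_output_size(112, 3, stride=2, padding=1) == 56
# A 1x1 window with unit stride and no padding keeps the spatial size.
assert _calculate_output_size(112, 1, stride=1, padding=0) == 112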