diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index 12b6ec1ae2b..78b87df7297 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -1229,7 +1229,7 @@ inline U qdot_safe(
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
-  const int in_vec_size_g = in_vec_size / group_size;
+  const int in_vec_size_g = (in_vec_size + group_size - 1) / group_size;
   const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
   const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
@@ -1283,8 +1283,8 @@ inline U qdot_safe(
 
       U s = sl[0];
       U b = bl[0];
-      result[row] +=
-          qdot(wl, x_thread, s, b, sum);
+      result[row] += qdot_safe(
+          wl, x_thread, s, b, sum, remaining);
     }
   }
 
diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py
index 04405f2da7d..8e0c4965170 100644
--- a/backends/apple/metal/tests/test_modules.py
+++ b/backends/apple/metal/tests/test_modules.py
@@ -283,6 +283,56 @@ def forward(self, x: torch.Tensor):
 }
 
 
+# -------------------------------------------------------------------------
+class LinearInt4_QMV_IMPL_small_odd(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(8, 3, bias=True)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+
+MODULE_REGISTRY["linear_int4_qmv_impl_small_odd"] = {
+    "model_class": LinearInt4_QMV_IMPL_small_odd,
+    "input_shapes": [(1, 8)],
+    "description": "Linear int4 quantization dispatching to qmv_impl",
+    "qlinear": "fpa4w",
+    "qlinear_group_size": 32,
+    "compare_to_unquantized": False,
+    "atol_float32": 5e-2,
+    "rtol_float32": 5e-2,
+    "atol_bfloat16": 1e-1,
+    "rtol_bfloat16": 1e-1,
+    "skip": not TORCHAO_AVAILABLE,
+}
+
+
+# -------------------------------------------------------------------------
+class LinearInt4_QMV_IMPL_small_even(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(8, 10, bias=True)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+
+MODULE_REGISTRY["linear_int4_qmv_impl_small_even"] = {
+    "model_class": LinearInt4_QMV_IMPL_small_even,
+    "input_shapes": [(1, 8)],
+    "description": "Linear int4 quantization dispatching to qmv_impl",
+    "qlinear": "fpa4w",
+    "qlinear_group_size": 32,
+    "compare_to_unquantized": False,
+    "atol_float32": 5e-2,
+    "rtol_float32": 5e-2,
+    "atol_bfloat16": 1e-1,
+    "rtol_bfloat16": 1e-1,
+    "skip": not TORCHAO_AVAILABLE,
+}
+
+
 # -------------------------------------------------------------------------
 # Convolution Modules
 # -------------------------------------------------------------------------