diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py
index 8e0c4965170..403ce355381 100644
--- a/backends/apple/metal/tests/test_modules.py
+++ b/backends/apple/metal/tests/test_modules.py
@@ -287,7 +287,7 @@ def forward(self, x: torch.Tensor):
 class LinearInt4_QMV_IMPL_small_odd(nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = nn.Linear(8, 3, bias=True)
+        self.linear = nn.Linear(32, 3, bias=True)
 
     def forward(self, x: torch.Tensor):
         return self.linear(x)
@@ -295,7 +295,7 @@ def forward(self, x: torch.Tensor):
 
 MODULE_REGISTRY["linear_int4_qmv_impl_small_odd"] = {
     "model_class": LinearInt4_QMV_IMPL_small_odd,
-    "input_shapes": [(1, 8)],
+    "input_shapes": [(1, 32)],
     "description": "Linear int4 quantization dispatching to qmv_impl",
     "qlinear": "fpa4w",
     "qlinear_group_size": 32,
@@ -312,7 +312,7 @@ def forward(self, x: torch.Tensor):
 class LinearInt4_QMV_IMPL_small_even(nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = nn.Linear(8, 10, bias=True)
+        self.linear = nn.Linear(32, 10, bias=True)
 
     def forward(self, x: torch.Tensor):
         return self.linear(x)
@@ -320,7 +320,7 @@ def forward(self, x: torch.Tensor):
 
 MODULE_REGISTRY["linear_int4_qmv_impl_small_even"] = {
     "model_class": LinearInt4_QMV_IMPL_small_even,
-    "input_shapes": [(1, 8)],
+    "input_shapes": [(1, 32)],
     "description": "Linear int4 quantization dispatching to qmv_impl",
     "qlinear": "fpa4w",
     "qlinear_group_size": 32,
@@ -694,12 +694,14 @@ def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32)
     else:
         raise ValueError(f"Unsupported linear quantization config '{qlinear}'.")
 
-    def linear_filter(module, fqn):
-        if isinstance(module, torch.nn.Linear):
-            # Check if hidden dimension is divisible by group size
-            return qlinear_group_size == 0 or (
-                module.weight.shape[1] % qlinear_group_size == 0
-            )
+    def linear_filter(m, fqn):
+        if isinstance(m, torch.nn.Linear):
+            if m.weight.shape[1] % qlinear_group_size != 0:
+                raise ValueError(
+                    f"Metal int4 quantization requires weight dimension (K) to be multiple of group_size. "
+                    f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]}, group_size={qlinear_group_size})"  # noqa: E501
+                )
+            return True
         return False
 
     quantize_(model, linear_config, filter_fn=linear_filter)
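
For reference, the updated filter semantics can be exercised outside the test harness. The sketch below is a minimal, standalone illustration and is not part of the patch: `check_int4_group_size` and `GROUP_SIZE` are hypothetical names, and it only reproduces the divisibility check that the new `linear_filter` performs before being passed to `quantize_(model, linear_config, filter_fn=linear_filter)`.

```python
import torch.nn as nn

# Hypothetical constant mirroring "qlinear_group_size": 32 in the registry entries above.
GROUP_SIZE = 32


def check_int4_group_size(module: nn.Module, fqn: str, group_size: int = GROUP_SIZE) -> bool:
    """Standalone sketch of the new filter: accept nn.Linear layers only when the
    weight's K dimension (weight.shape[1]) is a multiple of the group size, and
    raise instead of silently skipping layers that are not."""
    if isinstance(module, nn.Linear):
        k = module.weight.shape[1]
        if k % group_size != 0:
            raise ValueError(
                f"Layer {fqn}: K={k} is not a multiple of group_size={group_size}"
            )
        return True
    return False


# nn.Linear(32, 3) has weight shape (3, 32), so K=32 passes -- matching the
# updated test modules that moved from in_features=8 to in_features=32.
assert check_int4_group_size(nn.Linear(32, 3, bias=True), "good_layer")

# The old in_features=8 shape now raises instead of being silently excluded.
try:
    check_int4_group_size(nn.Linear(8, 3, bias=True), "bad_layer")
except ValueError as err:
    print(err)
```

This is also why the `linear_int4_qmv_impl_small_*` modules and their `input_shapes` change from 8 to 32: with the filter now raising on `K % group_size != 0` rather than quietly returning False, the test shapes must satisfy the group-size constraint.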