From 17a90851551fc53313bb479a7bf1f2d9251b13bb Mon Sep 17 00:00:00 2001 From: Eli Amesefe Date: Thu, 5 Feb 2026 19:08:27 -0800 Subject: [PATCH] Add rank2 unit tests to test_linear (#17254) Summary: Add rank2 unit tests to test_linear Differential Revision: D92111806 --- backends/arm/test/ops/test_linear.py | 50 ++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 08a3d04fcec..556417dcdb6 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -60,6 +60,40 @@ ), } +test_data_rank2_FP = { + # test_name: (test_data, out_features, has_bias) + "model_linear_rank2_zeros": lambda: ( + torch.zeros(10, 20), + 15, + True, + ), + "model_linear_rank2_ones": lambda: ( + torch.ones(2, 240), + 960, + False, + ), + "model_linear_rank2_negative_ones": lambda: ( + torch.ones(10, 20) * (-1), + 20, + True, + ), + "model_linear_rank2_rand": lambda: ( + torch.rand(2, 240), + 960, + True, + ), + "model_linear_rank2_negative_large_rand": lambda: ( + torch.rand(10, 20) * (-100), + 30, + False, + ), + "model_linear_rank2_large_randn": lambda: ( + torch.randn(15, 20) * 100, + 20, + True, + ), +} + test_data_rank4_FP = { # test_name: (test_data, out_features, has_bias) "model_linear_rank4_zeros": lambda: ( @@ -101,6 +135,13 @@ for q in [True, False] } +# Generate a new test set paired with per_channel_quant=True/False. +test_data_rank2_INT = { + f"{k},per_channel_quant={q}": (lambda v=v, q=q: (*v(), q)) + for (k, v) in test_data_rank2_FP.items() + for q in [True, False] +} + # Generate a new test set paired with per_channel_quant=True/False. test_data_rank4_INT = { f"{k},per_channel_quant={q}": (lambda v=v, q=q: (*v(), q)) @@ -192,7 +233,10 @@ def test_linear_tosa_INT_a8w4(test_data: torch.Tensor): pipeline.run() -@common.parametrize("test_data", test_data_rank1_INT) +@common.parametrize( + "test_data", + test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT, +) @common.XfailIfNoCorstone300 def test_linear_u55_INT(test_data: torch.Tensor): test_data, out_features, has_bias, per_channel_quantization = test_data() @@ -213,7 +257,7 @@ def test_linear_u55_INT(test_data: torch.Tensor): @common.parametrize( "test_data", - test_data_rank1_INT | test_data_rank4_INT, + test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT, ) @common.XfailIfNoCorstone320 def test_linear_u85_INT(test_data: torch.Tensor): @@ -264,7 +308,7 @@ def test_linear_vgf_quant(test_data: torch.Tensor): pipeline.run() -test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT +test_data_all_16a8w = test_data_rank1_INT | test_data_rank2_INT | test_data_rank4_INT @common.parametrize("test_data", test_data_all_16a8w)