diff --git a/backends/cortex_m/quantizer/pattern_checkers.py b/backends/cortex_m/quantizer/pattern_checkers.py index 0210cf4dee3..a4c5d5ee4cb 100644 --- a/backends/cortex_m/quantizer/pattern_checkers.py +++ b/backends/cortex_m/quantizer/pattern_checkers.py @@ -94,10 +94,7 @@ def check_quantization_config(cls, quantization_config): is_per_tensor = PatternCheck.is_per_tensor( quantization_config.input_activation ) and PatternCheck.is_per_tensor(quantization_config.output_activation) - is_int8 = ( - quantization_config.input_activation.dtype == torch.int8 - and quantization_config.output_activation.dtype == torch.int8 - ) + is_int8 = cls.is_int8_activations(quantization_config) return is_per_tensor and is_int8 @@ -128,10 +125,7 @@ def check_quantization_config(cls, quantization_config): """ Checks that the quantization config uses per-tensor int8 quantization. """ - is_int8 = ( - quantization_config.input_activation.dtype == torch.int8 - and quantization_config.output_activation.dtype == torch.int8 - ) + is_int8 = cls.is_int8_activations(quantization_config) return is_int8