diff --git a/backends/apple/coreml/test/test_image_processing.py b/backends/apple/coreml/test/test_image_processing.py new file mode 100644 index 00000000000..28a79fdbe36 --- /dev/null +++ b/backends/apple/coreml/test/test_image_processing.py @@ -0,0 +1,670 @@ +# Copyright © 2025 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +""" +Tests for CoreML image preprocessing models. + +These tests serve as reference examples for how to deploy ImagePreprocessor +and ImagePostprocessor to CoreML. Each test demonstrates the recommended +dtype and compute precision settings for different use cases. + +## Dtype and Compute Precision Guidelines + +### Pattern 1: Full fp16 Pipeline (Most Common) +For SDR operations, simple scaling, ImageNet normalization: +```python +ep = ImagePreprocessor.from_scale_0_1( + shape=(1, 3, 224, 224), + input_dtype=torch.float16, +) +compile_specs = CoreMLBackend.generate_compile_specs( + compute_precision=ct.precision.FLOAT16, # Best ANE performance + minimum_deployment_target=ct.target.iOS18, +) +``` + +### Pattern 2: fp16 I/O with fp32 Compute (Precision-Sensitive Operations) +For HDR10 PQ transfer function which has high-power exponents: +```python +ep = ImagePreprocessor.from_hdr10( + shape=(1, 3, 1080, 1920), + input_dtype=torch.float16, # Memory efficient I/O + output_dtype=torch.float16, + bit_depth=10, +) +compile_specs = CoreMLBackend.generate_compile_specs( + compute_precision=ct.precision.FLOAT32, # Required for PQ accuracy + minimum_deployment_target=ct.target.iOS18, +) +``` + +Reference correctness tests (comparison against colour-science) are in +extension/vision/test/test_image_processing.py. These CoreML tests focus on +verifying the CoreML delegate matches EP output and demonstrating deployment. +""" + +import unittest + +import coremltools as ct +import executorch.exir +import numpy as np +import torch +from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.backends.apple.coreml.test.test_coreml_utils import ( + IS_VALID_TEST_RUNTIME, +) +from executorch.extension.vision.image_processing import ( + ColorGamut, + ColorLayout, + ImagePostprocessor, + ImagePreprocessor, + TransferFunction, +) + +if IS_VALID_TEST_RUNTIME: + from executorch.runtime import Runtime + + +class TestCoreMLImagePreprocessor(unittest.TestCase): + """ + Tests for lowering ImagePreprocessor to CoreML. + + Each test method demonstrates the recommended deployment pattern for + a specific preprocessor factory method. + """ + + # ==================== Helper Methods ==================== + + def _lower_to_coreml( + self, + ep: torch.export.ExportedProgram, + compute_precision: ct.precision = ct.precision.FLOAT16, + ): + """Lower ExportedProgram to CoreML-delegated ExecutorchProgram. + + Args: + ep: The ExportedProgram to lower. + compute_precision: CoreML compute precision. + - ct.precision.FLOAT16: Best ANE performance (default) + - ct.precision.FLOAT32: For precision-sensitive operations + + Returns: + ExecutorchProgram ready for execution. 
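+
+        After partitioning, the helper walks the edge graph and asserts that
+        every remaining call_function node is either a delegate call or a
+        getitem, i.e. nothing fell back to portable ops.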
+ """ + compile_specs = CoreMLBackend.generate_compile_specs( + compute_precision=compute_precision, + minimum_deployment_target=ct.target.iOS18, + ) + partitioner = CoreMLPartitioner(compile_specs=compile_specs) + + edge_program = executorch.exir.to_edge_transform_and_lower( + ep, partitioner=[partitioner] + ) + + # Verify all ops are delegated to CoreML + for node in edge_program.exported_program().graph.nodes: + if node.op == "call_function": + target_str = str(node.target) + is_delegate = "executorch_call_delegate" in target_str + is_getitem = "getitem" in target_str + self.assertTrue( + is_delegate or is_getitem, + f"Found non-delegated op: {node.target}", + ) + + return edge_program.to_executorch() + + def _run_coreml(self, executorch_program, inputs: torch.Tensor) -> np.ndarray: + """Execute CoreML-delegated program and return output.""" + if not IS_VALID_TEST_RUNTIME: + return None + + runtime = Runtime.get() + program = runtime.load_program(executorch_program.buffer) + method = program.load_method("forward") + return method.execute([inputs])[0].numpy() + + def _generate_8bit_input(self, shape=(1, 3, 64, 64)) -> torch.Tensor: + """Generate 8-bit test input (values 0-255).""" + return torch.randint(0, 256, shape, dtype=torch.float16).float() + + def _generate_10bit_input(self, shape=(1, 3, 64, 64)) -> torch.Tensor: + """Generate 10-bit test input (values 0-1023).""" + return torch.randint(0, 1024, shape, dtype=torch.float16).float() + + def _generate_12bit_input(self, shape=(1, 3, 64, 64)) -> torch.Tensor: + """Generate 12-bit test input (values 0-4095).""" + return torch.randint(0, 4096, shape, dtype=torch.float16).float() + + # ==================== SDR Preprocessors (fp16 pipeline) ==================== + + def test_from_scale_0_1(self): + """ + Example: Scale 8-bit input to [0, 1] range. + + Use case: Generic preprocessing for models expecting normalized input. + Precision: Full fp16 pipeline (best performance). + """ + shape = (1, 3, 224, 224) + + # Create preprocessor with fp16 I/O + ep = ImagePreprocessor.from_scale_0_1( + shape=shape, + input_dtype=torch.float16, + ) + + # Lower to CoreML with fp16 compute (best ANE performance) + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + # Run and verify + test_input = self._generate_8bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + # Compare with EP output + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + def test_from_scale_neg1_1(self): + """ + Example: Scale 8-bit input to [-1, 1] range. + + Use case: Models expecting symmetric normalized input (e.g., some GANs). + Precision: Full fp16 pipeline. + """ + shape = (1, 3, 224, 224) + + ep = ImagePreprocessor.from_scale_neg1_1( + shape=shape, + input_dtype=torch.float16, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + test_input = self._generate_8bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + def test_from_imagenet(self): + """ + Example: ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]). 
+ + Use case: Classification models trained on ImageNet (ResNet, EfficientNet, etc.). + Precision: Full fp16 pipeline. + """ + shape = (1, 3, 224, 224) + + ep = ImagePreprocessor.from_imagenet( + shape=shape, + input_dtype=torch.float16, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + test_input = self._generate_8bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + def test_from_sdr_to_linear(self): + """ + Example: Convert SDR (sRGB gamma) to linear light. + + Use case: HDR processing pipelines, tone mapping, color grading. + Precision: Full fp16 pipeline (sRGB gamma is well-behaved). + """ + shape = (1, 3, 224, 224) + + ep = ImagePreprocessor.from_sdr( + shape=shape, + input_dtype=torch.float16, + output_dtype=torch.float16, + normalize_to_linear=True, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + test_input = self._generate_8bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + # ==================== HDR Preprocessors ==================== + + def test_from_hlg(self): + """ + Example: Convert HLG (Hybrid Log-Gamma) to linear light. + + Use case: Processing HLG HDR content (common in broadcast). + Precision: Full fp16 pipeline (HLG is well-behaved in fp16). + """ + shape = (1, 3, 1080, 1920) + + ep = ImagePreprocessor.from_hlg( + shape=shape, + input_dtype=torch.float16, + output_dtype=torch.float16, + bit_depth=10, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + test_input = self._generate_10bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.05, atol=0.05) + + def test_from_hdr10(self): + """ + Example: Convert HDR10 (PQ/ST.2084) to linear light. + + Use case: Processing HDR10 content (streaming, UHD Blu-ray). + + IMPORTANT: HDR10 uses the PQ transfer function which has high-power + exponents (m2=78.84). This causes significant precision loss in fp16. + Use fp32 compute precision for accurate results. + + Precision: fp16 I/O with fp32 compute (required for PQ accuracy). 
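+
+        For reference, the PQ EOTF being linearized is
+        Y = (max(E^(1/m2) - c1, 0) / (c2 - c3 * E^(1/m2)))^(1/m1),
+        where the large m2 exponent is the precision-sensitive step.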
+ """ + shape = (1, 3, 1080, 1920) + + # Create with fp16 I/O for memory efficiency + ep = ImagePreprocessor.from_hdr10( + shape=shape, + input_dtype=torch.float16, + output_dtype=torch.float16, + bit_depth=10, + ) + + # Use fp32 compute precision for PQ accuracy + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT32 + ) + + test_input = self._generate_10bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + # For comparison, create fp32 EP to match fp32 compute precision + ep_fp32 = ImagePreprocessor.from_hdr10( + shape=shape, + input_dtype=torch.float32, + output_dtype=torch.float32, + bit_depth=10, + ) + ep_out = ep_fp32.module()(test_input.to(torch.float32)).detach().numpy() + ep_out = ep_out.astype(np.float16) # Cast to match CoreML output dtype + + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + # ==================== Color Layout Conversions ==================== + + def test_bgr_to_rgb(self): + """ + Example: Convert BGR input to RGB output. + + Use case: OpenCV uses BGR format; convert for RGB-expecting models. + Precision: Full fp16 pipeline. + """ + shape = (1, 3, 224, 224) + + ep = ImagePreprocessor.from_scale_0_1( + shape=shape, + input_dtype=torch.float16, + input_color=ColorLayout.BGR, + output_color=ColorLayout.RGB, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + test_input = self._generate_8bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + def test_rgb_to_grayscale(self): + """ + Example: Convert RGB to grayscale. + + Use case: Models expecting single-channel input. + Precision: Full fp16 pipeline. + """ + shape = (1, 3, 224, 224) + + ep = ImagePreprocessor.from_scale_0_1( + shape=shape, + input_dtype=torch.float16, + input_color=ColorLayout.RGB, + output_color=ColorLayout.GRAYSCALE, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + test_input = self._generate_8bit_input(shape).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + # ==================== Gamut Conversion ==================== + + def test_bt2020_to_bt709(self): + """ + Example: Convert BT.2020 (wide gamut) to BT.709 (SDR gamut). + + Use case: HDR to SDR conversion, displaying HDR content on SDR screens. + Precision: Full fp16 pipeline. 
+ """ + shape = (1, 3, 1080, 1920) + + model = ImagePreprocessor( + bit_depth=10, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.LINEAR, + input_gamut=ColorGamut.BT2020, + output_gamut=ColorGamut.BT709, + ) + model.eval() + test_input = self._generate_10bit_input(shape).to(torch.float16) + ep = torch.export.export(model, (test_input,), strict=True) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + +class TestCoreMLImagePostprocessor(unittest.TestCase): + """ + Tests for lowering ImagePostprocessor to CoreML. + + Each test method demonstrates the recommended deployment pattern for + a specific postprocessor factory method. + """ + + # ==================== Helper Methods ==================== + + def _lower_to_coreml( + self, + ep: torch.export.ExportedProgram, + compute_precision: ct.precision = ct.precision.FLOAT16, + ): + """Lower ExportedProgram to CoreML-delegated ExecutorchProgram.""" + compile_specs = CoreMLBackend.generate_compile_specs( + compute_precision=compute_precision, + minimum_deployment_target=ct.target.iOS18, + ) + partitioner = CoreMLPartitioner(compile_specs=compile_specs) + + edge_program = executorch.exir.to_edge_transform_and_lower( + ep, partitioner=[partitioner] + ) + + for node in edge_program.exported_program().graph.nodes: + if node.op == "call_function": + target_str = str(node.target) + is_delegate = "executorch_call_delegate" in target_str + is_getitem = "getitem" in target_str + self.assertTrue( + is_delegate or is_getitem, + f"Found non-delegated op: {node.target}", + ) + + return edge_program.to_executorch() + + def _run_coreml(self, executorch_program, inputs: torch.Tensor) -> np.ndarray: + """Execute CoreML-delegated program and return output.""" + if not IS_VALID_TEST_RUNTIME: + return None + + runtime = Runtime.get() + program = runtime.load_program(executorch_program.buffer) + method = program.load_method("forward") + return method.execute([inputs])[0].numpy() + + # ==================== SDR Postprocessors (fp16 pipeline) ==================== + + def test_from_scale_0_1(self): + """ + Example: Convert [0, 1] normalized output to 8-bit [0, 255]. + + Use case: Converting model output back to displayable image. + Precision: Full fp16 pipeline. + """ + shape = (1, 3, 224, 224) + + ep = ImagePostprocessor.from_scale_0_1( + shape=shape, + input_dtype=torch.float16, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + # Input is [0, 1] normalized + test_input = torch.rand(shape, dtype=torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + def test_from_scale_neg1_1(self): + """ + Example: Convert [-1, 1] normalized output to 8-bit [0, 255]. + + Use case: GAN output conversion. + Precision: Full fp16 pipeline. 
+ """ + shape = (1, 3, 224, 224) + + ep = ImagePostprocessor.from_scale_neg1_1( + shape=shape, + input_dtype=torch.float16, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + # Input is [-1, 1] normalized + test_input = torch.rand(shape, dtype=torch.float16) * 2 - 1 + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + def test_from_imagenet(self): + """ + Example: Reverse ImageNet normalization. + + Use case: Visualizing model intermediate outputs or reconstructions. + Precision: Full fp16 pipeline. + """ + shape = (1, 3, 224, 224) + + ep = ImagePostprocessor.from_imagenet( + shape=shape, + input_dtype=torch.float16, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + # Input is ImageNet-normalized + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + test_input = ((torch.rand(shape) - mean) / std).to(torch.float16) + + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.02, atol=0.1) + + def test_from_linear_to_srgb(self): + """ + Example: Convert linear light to sRGB gamma. + + Use case: Displaying linear light processing results. + Precision: Full fp16 pipeline. + """ + shape = (1, 3, 224, 224) + + ep = ImagePostprocessor.from_linear_to_srgb( + shape=shape, + input_dtype=torch.float16, + output_dtype=torch.float16, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + # Input is linear light [0, 1] + test_input = torch.rand(shape, dtype=torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.01, atol=0.01) + + # ==================== HDR Postprocessors ==================== + + def test_from_linear_to_hlg(self): + """ + Example: Convert linear light to HLG. + + Use case: Encoding processed content for HLG HDR display. + Precision: Full fp16 pipeline (HLG is well-behaved). + """ + shape = (1, 3, 1080, 1920) + + ep = ImagePostprocessor.from_linear_to_hlg( + shape=shape, + input_dtype=torch.float16, + output_dtype=torch.float16, + ) + + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT16 + ) + + # Input is linear light [0, 1] + test_input = torch.rand(shape, dtype=torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + ep_out = ep.module()(test_input).detach().numpy() + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.05, atol=0.05) + + def test_from_linear_to_hdr10(self): + """ + Example: Convert linear light to HDR10 (PQ/ST.2084). + + Use case: Encoding processed content for HDR10 display. + + IMPORTANT: HDR10 uses the PQ transfer function which has high-power + exponents (m2=78.84). Use fp32 compute precision for accurate results. + + Precision: fp16 I/O with fp32 compute (required for PQ accuracy). 
+ """ + shape = (1, 3, 1080, 1920) + + # Create with fp16 I/O for memory efficiency + ep = ImagePostprocessor.from_linear_to_hdr10( + shape=shape, + input_dtype=torch.float16, + output_dtype=torch.float16, + ) + + # Use fp32 compute precision for PQ accuracy + executorch_program = self._lower_to_coreml( + ep, compute_precision=ct.precision.FLOAT32 + ) + + # Input is linear light [0, 1], avoid very small values + test_input = (torch.rand(shape) * 0.99 + 0.01).to(torch.float16) + coreml_out = self._run_coreml(executorch_program, test_input) + + if coreml_out is None: + self.skipTest("CoreML runtime not available") + + # For comparison, create fp32 EP to match fp32 compute precision + ep_fp32 = ImagePostprocessor.from_linear_to_hdr10( + shape=shape, + input_dtype=torch.float32, + output_dtype=torch.float32, + ) + ep_out = ep_fp32.module()(test_input.to(torch.float32)).detach().numpy() + ep_out = ep_out.astype(np.float16) # Cast to match CoreML output dtype + + np.testing.assert_allclose(coreml_out, ep_out, rtol=0.05, atol=35.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/vision/README.md b/extension/vision/README.md new file mode 100644 index 00000000000..ef276b38229 --- /dev/null +++ b/extension/vision/README.md @@ -0,0 +1,257 @@ +# ExecuTorch Vision Extension + +Image preprocessing and postprocessing utilities for on-device inference with ExecuTorch. + +## Overview + +This module provides `ImagePreprocessor` and `ImagePostprocessor` classes that generate `ExportedProgram` objects ready to be lowered to any ExecuTorch backend (CoreML, MPS, XNNPACK, etc.). + +### Key Features + +- **SDR Support**: 8-bit sRGB content (photos, standard video) +- **HDR Support**: 10/12-bit PQ (HDR10, Dolby Vision) and HLG (broadcast HDR) +- **Color Gamut Conversion**: BT.2020 ↔ BT.709 wide color gamut support +- **Standard Normalizations**: ImageNet, [0,1], [-1,1] scaling +- **Color Layout**: RGB ↔ BGR conversion, grayscale +- **Precision Control**: fp16/fp32 output dtype selection + +## Quick Start + +```python +from executorch.extension.vision import ImagePreprocessor, ImagePostprocessor + +# Simple [0, 255] → [0, 1] preprocessing +preprocessor = ImagePreprocessor.from_scale_0_1( + shape=(1, 3, 480, 640), + input_dtype=torch.float16, +) + +# ImageNet normalization +preprocessor = ImagePreprocessor.from_imagenet( + shape=(1, 3, 224, 224), + input_dtype=torch.float16, +) + +# HDR10 → linear BT.709 (for HDR video processing) +preprocessor = ImagePreprocessor.from_hdr10( + shape=(1, 3, 1080, 1920), + input_dtype=torch.float32, # fp32 recommended for PQ precision + output_dtype=torch.float16, +) + +# Lower to your backend +from executorch.exir import to_edge_transform_and_lower +program = to_edge_transform_and_lower(preprocessor, partitioner=[YourPartitioner()]) +``` + +## ImagePreprocessor + +Converts input images to normalized tensors for model inference. 
+ +### Factory Methods + +| Method | Input | Output | Use Case | +|--------|-------|--------|----------| +| `from_scale_0_1()` | [0, 255] | [0, 1] | Simple normalization | +| `from_scale_neg1_1()` | [0, 255] | [-1, 1] | Zero-centered models | +| `from_imagenet()` | [0, 255] | ImageNet normalized | Classification models | +| `from_sdr()` | 8-bit sRGB | Linear or normalized | SDR with gamma correction | +| `from_hdr10()` | 10-bit PQ BT.2020 | Linear BT.709 | HDR10/Dolby Vision | +| `from_hlg()` | 10-bit HLG BT.2020 | Linear BT.709 | Broadcast HDR | + +### Processing Pipeline + +``` +Input Image + │ + ▼ +┌─────────────────────────┐ +│ 1. Color Layout │ BGR → RGB (if needed) +├─────────────────────────┤ +│ 2. Normalize to [0,1] │ Divide by max_value (255, 1023, 4095) +├─────────────────────────┤ +│ 3. Inverse Transfer │ PQ/HLG/sRGB → Linear +├─────────────────────────┤ +│ 4. Gamut Conversion │ BT.2020 → BT.709 (if needed) +├─────────────────────────┤ +│ 5. Output Transfer │ Linear → sRGB (if needed) +├─────────────────────────┤ +│ 6. Bias & Scale │ (x + bias) * scale +├─────────────────────────┤ +│ 7. Output Layout │ RGB → Grayscale/BGR (if needed) +└─────────────────────────┘ + │ + ▼ +Normalized Tensor +``` + +## ImagePostprocessor + +Converts model output tensors back to displayable images. Performs the inverse operations of `ImagePreprocessor`. + +### Factory Methods + +| Method | Input | Output | Use Case | +|--------|-------|--------|----------| +| `from_scale_0_1()` | [0, 1] | [0, 255] | Simple denormalization | +| `from_scale_neg1_1()` | [-1, 1] | [0, 255] | Zero-centered models | +| `from_imagenet()` | ImageNet normalized | [0, 255] | Classification models | +| `from_linear_to_srgb()` | Linear | 8-bit sRGB | SDR output | +| `from_linear_to_hdr10()` | Linear BT.709 | 10-bit PQ BT.2020 | HDR10 output | +| `from_linear_to_hlg()` | Linear BT.709 | 10-bit HLG BT.2020 | Broadcast HDR output | + +## HDR Processing + +### Precision Considerations + +⚠️ **Important**: PQ (HDR10/Dolby Vision) transfer functions require **fp32 compute precision** for accurate results. The PQ EOTF contains an exponent of m2=78.84, which causes significant precision loss when computed in fp16. + +| Transfer Function | Recommended Precision | Notes | +|-------------------|----------------------|-------| +| **PQ (HDR10)** | fp32 input | m2=78.84 exponent loses precision in fp16 | +| **HLG** | fp16 OK | Piecewise sqrt/log is fp16-friendly | +| **sRGB** | fp16 OK | Simple x^2.2 power function | + +### HDR10 (PQ / Dolby Vision) + +```python +# Preprocessing: HDR10 input → linear for model +# ⚠️ Use fp32 input_dtype for accurate PQ decoding +preprocessor = ImagePreprocessor.from_hdr10( + shape=(1, 3, 1080, 1920), + input_dtype=torch.float32, # IMPORTANT: fp32 required for PQ precision + output_dtype=torch.float16, # Output can be fp16 for model inference +) + +# Postprocessing: model output → HDR10 display +# ⚠️ Use fp32 output_dtype for accurate PQ encoding +postprocessor = ImagePostprocessor.from_linear_to_hdr10( + shape=(1, 3, 1080, 1920), + input_dtype=torch.float16, + output_dtype=torch.float32, # IMPORTANT: fp32 required for PQ precision +) +``` + +### HLG (Broadcast HDR) + +HLG uses a piecewise sqrt/log function that is more numerically stable in fp16. 
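+
+In the form implemented here, the inverse OETF is x^2 / 3 for x <= 0.5 and
+(exp((x - c) / a) + b) / 12 otherwise, with a = 0.17883277, b = 0.28466892,
+c = 0.55991073. No large exponents are involved, so fp16 holds up well.
+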
+ +```python +# Preprocessing: HLG input → linear for model +preprocessor = ImagePreprocessor.from_hlg( + shape=(1, 3, 1080, 1920), + input_dtype=torch.float16, # fp16 OK for HLG + output_dtype=torch.float16, +) + +# Postprocessing: model output → HLG display +postprocessor = ImagePostprocessor.from_linear_to_hlg( + shape=(1, 3, 1080, 1920), + input_dtype=torch.float16, + output_dtype=torch.float16, # fp16 OK for HLG +) +``` + +## Platform Integration Notes + +The following operations should be done **outside** the model using platform-native APIs for best performance: + +| Operation | iOS | Android | +|-----------|-----|---------| +| uint8 ↔ float | vDSP (~1ms for 512×512) | RenderScript | +| YUV → RGB | vImage | YuvImage | +| Resize/Crop | vImage, Metal | Bitmap APIs | + +The preprocessor/postprocessor handles the **numerically intensive** operations (transfer functions, gamut conversion, normalization) that benefit from acceleration. + +## Transfer Functions + +### Supported Functions + +| Function | Standard | Use Case | +|----------|----------|----------| +| **sRGB** | IEC 61966-2-1 | Standard displays, web | +| **PQ** | SMPTE ST.2084 | HDR10, Dolby Vision | +| **HLG** | BT.2100 | Broadcast HDR (BBC/NHK) | +| **Linear** | - | ML models, compositing | + +### Implementation Details + +- **PQ (Perceptual Quantizer)**: Uses the ST.2084 EOTF with constants m1=0.1593, m2=78.84, c1=0.8359, c2=18.85, c3=18.69. The large m2 exponent causes precision loss in fp16 - use fp32 for accurate results. + +- **HLG (Hybrid Log-Gamma)**: Piecewise function with sqrt for low values and log for high values. More fp16-friendly than PQ. + +- **sRGB**: Uses the x^2.2 / x^(1/2.2) approximation rather than the piecewise sRGB transfer function. + +## Color Gamut + +### Supported Gamuts + +| Gamut | Standard | Coverage | +|-------|----------|----------| +| **BT.709** | Rec. 709 / sRGB | Standard SDR | +| **BT.2020** | Rec. 2020 | Wide color gamut (HDR) | + +### Conversion Matrices + +The module includes accurate 3×3 matrices for BT.709 ↔ BT.2020 conversion, validated against the `colour-science` reference library. + +## Testing + +Tests compare implementations against the `colour-science` library which provides reference implementations of ITU/SMPTE standards. + +```bash +# Run tests +python -m unittest extension.vision.test.test_image_processing -v + +# Or directly +python extension/vision/test/test_image_processing.py -v +``` + +### Test Coverage + +- Transfer function accuracy vs colour-science reference +- Gamut conversion matrix validation +- Full E2E pipeline tests for all factory methods +- fp16 and fp32 precision validation +- Roundtrip tests (forward → inverse) + +## API Reference + +### ImagePreprocessor + +```python +class ImagePreprocessor(torch.nn.Module): + """ + Args: + bit_depth: Input bit depth (8, 10, or 12). Default 8. + input_transfer: Transfer function of input (SRGB, PQ, HLG, LINEAR). + output_transfer: Desired transfer function of output. + input_gamut: Color gamut of input (BT709, BT2020). + output_gamut: Desired color gamut of output. + input_color: Color layout of input (RGB, BGR). + output_color: Desired color layout (RGB, BGR, GRAYSCALE). + channel_bias: Per-channel bias [R, G, B]. + channel_scale: Per-channel scale [R, G, B]. + preset: Preset name ("scale_0_1", "scale_neg1_1", "imagenet"). + output_dtype: Output dtype (torch.float16 or torch.float32). 
+ """ +``` + +### ImagePostprocessor + +```python +class ImagePostprocessor(torch.nn.Module): + """ + Args: + bit_depth: Output bit depth (8, 10, or 12). Default 8. + input_transfer: Transfer function of input (LINEAR, SRGB). + output_transfer: Desired transfer function of output (SRGB, PQ, HLG). + input_gamut: Color gamut of input (BT709, BT2020). + output_gamut: Desired color gamut of output. + input_color: Color layout of input (RGB, BGR, GRAYSCALE). + output_color: Desired color layout (RGB, BGR). + preset: Preset name for inverse normalization. + output_dtype: Output dtype (torch.float16 or torch.float32). + """ +``` diff --git a/extension/vision/TARGETS b/extension/vision/TARGETS new file mode 100644 index 00000000000..273386f5245 --- /dev/null +++ b/extension/vision/TARGETS @@ -0,0 +1,32 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "image_processing", + srcs = [ + "__init__.py", + "image_processing.py", + ], + visibility = ["PUBLIC"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_test( + name = "test_image_processing", + srcs = [ + "test/__init__.py", + "test/test_image_processing.py", + ], + deps = [ + "fbsource//third-party/pypi/colour-science:colour-science", + "fbsource//third-party/pypi/parameterized:parameterized", + ":image_processing", + "//caffe2:torch", + ], +) diff --git a/extension/vision/__init__.py b/extension/vision/__init__.py new file mode 100644 index 00000000000..cc13577ed57 --- /dev/null +++ b/extension/vision/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Vision extension module for image processing utilities.""" + +from executorch.extension.vision.image_processing import ( + ColorGamut, + ImagePostprocessor, + ImagePreprocessor, + TransferFunction, +) + +__all__ = [ + "ImagePreprocessor", + "ImagePostprocessor", + "TransferFunction", + "ColorGamut", +] diff --git a/extension/vision/image_processing.py b/extension/vision/image_processing.py new file mode 100644 index 00000000000..3ceb7711395 --- /dev/null +++ b/extension/vision/image_processing.py @@ -0,0 +1,1010 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Image preprocessing and postprocessing utilities for on-device inference. 
+ +This module provides: +- ImagePreprocessor: Convert input images to normalized tensors for model inference +- ImagePostprocessor: Convert model output tensors back to displayable images + +Common operations handled: +- Color space conversion (RGB ↔ BGR, grayscale) +- Per-channel bias and scale normalization +- HDR support (PQ, HLG transfer functions) +- Wide color gamut (BT.2020 ↔ BT.709 conversion) +- Multiple bit depths (8, 10, 12-bit) + +NOTE: The following should be done OUTSIDE the model using platform-native APIs: +- uint8 ↔ float conversion (use vDSP on iOS, ~1ms for 512x512) +- YUV → RGB conversion (use vImage on iOS) +- Resize/crop (use vImage or Metal on iOS) + +Usage Examples: + from executorch.extension.vision import ImagePreprocessor + + # Get an ExportedProgram for [0, 255] → [0, 1] preprocessing + ep = ImagePreprocessor.from_scale_0_1( + shape=(1, 3, 480, 640), + input_dtype=torch.float16, + ) + + # For HDR10 with float32 precision (recommended for accurate PQ) + ep = ImagePreprocessor.from_hdr10( + shape=(1, 3, 1080, 1920), + input_dtype=torch.float32, # Use float32 for HDR precision + output_dtype=torch.float16, + ) + + # Lower to any backend + from executorch.exir import to_edge_transform_and_lower + program = to_edge_transform_and_lower(ep, partitioner=[YourPartitioner()]) +""" + +import warnings +from enum import Enum +from typing import List, Optional, Tuple + +import torch +import torch.export + + +# ============================================================================= +# Enums +# ============================================================================= + + +class ColorLayout(Enum): + """Color layout options matching CoreML's color_layout.""" + + RGB = "RGB" + BGR = "BGR" + GRAYSCALE = "GRAYSCALE" + + +class TransferFunction(Enum): + """Transfer function (gamma curve) options.""" + + SRGB = "srgb" # Standard sRGB gamma (~2.2) + LINEAR = "linear" # Linear (no gamma) + PQ = "pq" # Perceptual Quantizer (HDR10, Dolby Vision) + HLG = "hlg" # Hybrid Log-Gamma (broadcast HDR) + + +class ColorGamut(Enum): + """Color gamut options.""" + + BT709 = "bt709" # Standard sRGB/Rec.709 (SDR) + BT2020 = "bt2020" # Wide color gamut (HDR) + + +# ============================================================================= +# Shared Constants +# ============================================================================= + +# BT.2020 to BT.709 color matrix (3x3) +BT2020_TO_BT709 = torch.tensor( + [ + [1.6605, -0.5876, -0.0728], + [-0.1246, 1.1329, -0.0083], + [-0.0182, -0.1006, 1.1187], + ] +) + +# BT.709 to BT.2020 color matrix (inverse) +BT709_TO_BT2020 = torch.tensor( + [ + [0.6274, 0.3293, 0.0433], + [0.0691, 0.9195, 0.0114], + [0.0164, 0.0880, 0.8956], + ] +) + +# Luminance weights for RGB to grayscale conversion +LUMINANCE_WEIGHTS = torch.tensor([0.299, 0.587, 0.114]) + + +# ============================================================================= +# Shared Transfer Functions +# ============================================================================= + + +def apply_srgb_gamma(x: torch.Tensor) -> torch.Tensor: + """Convert linear to sRGB gamma. Approximation: x^(1/2.2)""" + return torch.pow(x.clamp(min=1e-6), 1.0 / 2.2) + + +def apply_srgb_inverse(x: torch.Tensor) -> torch.Tensor: + """Convert sRGB gamma to linear. Approximation: x^2.2""" + return torch.pow(x.clamp(min=1e-6), 2.2) + + +def apply_pq_inverse(x: torch.Tensor) -> torch.Tensor: + """ + Convert PQ (Perceptual Quantizer) to linear. + PQ is used in HDR10 and Dolby Vision. 
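+
+    Implements the ST.2084 EOTF:
+        Y = (max(E^(1/m2) - c1, 0) / (c2 - c3 * E^(1/m2)))^(1/m1)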
+ """ + m1 = 0.1593017578125 + m2 = 78.84375 + c1 = 0.8359375 + c2 = 18.8515625 + c3 = 18.6875 + + x = x.clamp(min=1e-6, max=1.0) + x_m2 = torch.pow(x, 1.0 / m2) + numerator = (x_m2 - c1).clamp(min=0.0) + denominator = (c2 - c3 * x_m2).clamp(min=1e-6) + return torch.pow(numerator / denominator, 1.0 / m1) + + +def apply_pq_forward(x: torch.Tensor) -> torch.Tensor: + """ + Convert linear to PQ (Perceptual Quantizer). + Inverse of apply_pq_inverse. + """ + m1 = 0.1593017578125 + m2 = 78.84375 + c1 = 0.8359375 + c2 = 18.8515625 + c3 = 18.6875 + + x = x.clamp(min=1e-6, max=1.0) + x_m1 = torch.pow(x, m1) + numerator = c1 + c2 * x_m1 + denominator = 1.0 + c3 * x_m1 + return torch.pow(numerator / denominator, m2) + + +def apply_hlg_inverse(x: torch.Tensor) -> torch.Tensor: + """ + Convert HLG (Hybrid Log-Gamma) to linear. + HLG is used in broadcast HDR (BBC/NHK). + """ + a = 0.17883277 + b = 0.28466892 + c = 0.55991073 + + x = x.clamp(min=1e-6, max=1.0) + low = (x * x) / 3.0 + high = (torch.exp((x - c) / a) + b) / 12.0 + return torch.where(x <= 0.5, low, high) + + +def apply_hlg_forward(x: torch.Tensor) -> torch.Tensor: + """ + Convert linear to HLG (Hybrid Log-Gamma). + Inverse of apply_hlg_inverse. + """ + a = 0.17883277 + b = 0.28466892 + c = 0.55991073 + + x = x.clamp(min=1e-6, max=1.0) + low = torch.sqrt(3.0 * x) + high = a * torch.log(12.0 * x - b) + c + return torch.where(x <= 1.0 / 12.0, low, high) + + +def apply_gamut_conversion(x: torch.Tensor, color_matrix: torch.Tensor) -> torch.Tensor: + """Apply 3x3 color matrix for gamut conversion.""" + # [B, 3, H, W] -> [B, H, W, 3] for matmul + x = x.permute(0, 2, 3, 1) + x = torch.matmul(x, color_matrix.T) + # [B, H, W, 3] -> [B, 3, H, W] + x = x.permute(0, 3, 1, 2) + return x.clamp(min=0.0, max=1.0) + + +# ============================================================================= +# Shared Presets +# ============================================================================= + +# Standard normalization presets +# For preprocessor: bias and scale applied as output = (input + bias) * scale +# For postprocessor: inverse applied as output = (input / scale) - bias +PRESETS = { + # Simple pass-through [0, 1] -> [0, 1] + "none": { + "bias": [0.0, 0.0, 0.0], + "scale": [1.0, 1.0, 1.0], + }, + # Scale [0, 1] to [0, 1] (identity, but explicit) + "scale_0_1": { + "bias": [0.0, 0.0, 0.0], + "scale": [1.0, 1.0, 1.0], + }, + # Zero-centered [-1, 1] from [0, 1] input + "scale_neg1_1": { + "bias": [-0.5, -0.5, -0.5], + "scale": [2.0, 2.0, 2.0], + }, + # ImageNet normalization from [0, 1] input + # Formula: (x - mean) / std + "imagenet": { + "bias": [-0.485, -0.456, -0.406], + "scale": [1 / 0.229, 1 / 0.224, 1 / 0.225], + }, +} + + +# ============================================================================= +# ImagePreprocessor +# ============================================================================= + + +class ImagePreprocessor(torch.nn.Module): + """ + Comprehensive image preprocessing model - replacement for CoreML's ImageType. + + Handles all common preprocessing for image/video models: + - SDR: 8-bit sRGB content (photos, standard video) + - HDR: 10/12-bit PQ or HLG content (HDR10, Dolby Vision, broadcast HDR) + - Color layout conversion (RGB ↔ BGR, grayscale) + - Per-channel bias and scale normalization + - Color gamut conversion (BT.2020 ↔ BT.709) + + Processing pipeline: + 1. Convert input color layout (BGR → RGB if needed) + 2. Normalize to [0, 1] based on bit depth + 3. 
Apply inverse transfer function (linearize): gamma/PQ/HLG → linear + 4. Apply color gamut conversion (BT.2020 → BT.709 if needed) + 5. Apply output transfer function (linear → gamma if needed) + 6. Apply per-channel bias and scale (normalization) + 7. Convert to output color layout (grayscale, BGR if needed) + + Args: + bit_depth: Input bit depth (8, 10, or 12). Default 8 for SDR. + input_transfer: Transfer function of input (SRGB, PQ, HLG, LINEAR). + output_transfer: Desired transfer function of output (SRGB, LINEAR). + input_gamut: Color gamut of input (BT709, BT2020). Default BT709. + output_gamut: Desired color gamut of output (BT709, BT2020). Default BT709. + input_color: Color layout of input (RGB, BGR). Default RGB. + output_color: Desired color layout (RGB, BGR, GRAYSCALE). Default RGB. + channel_bias: Per-channel bias [R, G, B] applied after all conversions. + channel_scale: Per-channel scale [R, G, B] applied after bias. + preset: Optional preset name that sets bias/scale. + output_dtype: Output data type (torch.float16 or torch.float32). + + Input: + float16 tensor [B, C, H, W] with values in [0, max_val] based on bit_depth + + Output: + float16/float32 tensor [B, C', H, W] normalized and converted + """ + + def __init__( + self, + bit_depth: int = 8, + input_transfer: TransferFunction = TransferFunction.LINEAR, + output_transfer: TransferFunction = TransferFunction.LINEAR, + input_gamut: ColorGamut = ColorGamut.BT709, + output_gamut: ColorGamut = ColorGamut.BT709, + input_color: ColorLayout = ColorLayout.RGB, + output_color: ColorLayout = ColorLayout.RGB, + channel_bias: Optional[List[float]] = None, + channel_scale: Optional[List[float]] = None, + preset: Optional[str] = None, + output_dtype: torch.dtype = torch.float16, + ): + super().__init__() + + if bit_depth not in (8, 10, 12): + raise ValueError(f"bit_depth must be 8, 10, or 12, got {bit_depth}") + + self.bit_depth = bit_depth + self.max_value = float((2**bit_depth) - 1) + self.output_dtype = output_dtype + + # Store flags for control flow (avoids enum comparisons during tracing) + self.input_is_bgr = input_color == ColorLayout.BGR + self.output_is_bgr = output_color == ColorLayout.BGR + self.output_is_grayscale = output_color == ColorLayout.GRAYSCALE + self.input_is_srgb = input_transfer == TransferFunction.SRGB + self.input_is_pq = input_transfer == TransferFunction.PQ + self.input_is_hlg = input_transfer == TransferFunction.HLG + self.output_is_srgb = output_transfer == TransferFunction.SRGB + + # Use preset if specified + if preset is not None: + if preset not in PRESETS: + raise ValueError( + f"Unknown preset: {preset}. 
Available: {list(PRESETS.keys())}" + ) + preset_config = PRESETS[preset] + channel_bias = preset_config["bias"] + channel_scale = preset_config["scale"] + + # Default: no additional normalization + if channel_bias is None: + channel_bias = [0.0, 0.0, 0.0] + if channel_scale is None: + channel_scale = [1.0, 1.0, 1.0] + + # Register color matrix if gamut conversion needed + if input_gamut != output_gamut: + if input_gamut == ColorGamut.BT2020 and output_gamut == ColorGamut.BT709: + self.register_buffer("color_matrix", BT2020_TO_BT709.to(torch.float16)) + elif input_gamut == ColorGamut.BT709 and output_gamut == ColorGamut.BT2020: + self.register_buffer("color_matrix", BT709_TO_BT2020.to(torch.float16)) + else: + self.color_matrix = None + + # Register grayscale weights if needed + if output_color == ColorLayout.GRAYSCALE: + self.register_buffer( + "luminance_weights", + LUMINANCE_WEIGHTS.view(1, 3, 1, 1).to(torch.float16), + ) + gray_bias = sum(channel_bias) / 3.0 + gray_scale = sum(channel_scale) / 3.0 + self.register_buffer( + "bias", torch.tensor([gray_bias]).view(1, 1, 1, 1).to(torch.float16) + ) + self.register_buffer( + "scale", torch.tensor([gray_scale]).view(1, 1, 1, 1).to(torch.float16) + ) + else: + self.register_buffer( + "bias", torch.tensor(channel_bias).view(1, 3, 1, 1).to(torch.float16) + ) + self.register_buffer( + "scale", torch.tensor(channel_scale).view(1, 3, 1, 1).to(torch.float16) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Step 1: Handle BGR → RGB if needed + if self.input_is_bgr: + x = x.flip(dims=[1]) + + # Step 2: Normalize to [0, 1] based on bit depth + x = x / self.max_value + + # Step 3: Apply inverse transfer function (to linear) + if self.input_is_srgb: + x = apply_srgb_inverse(x) + elif self.input_is_pq: + x = apply_pq_inverse(x) + elif self.input_is_hlg: + x = apply_hlg_inverse(x) + + # Step 4: Apply color gamut conversion if needed + if self.color_matrix is not None: + x = apply_gamut_conversion(x, self.color_matrix) + + # Step 5: Apply output transfer function (from linear) + if self.output_is_srgb: + x = apply_srgb_gamma(x) + + # Step 6: Convert to grayscale if requested + if self.output_is_grayscale: + x = (x * self.luminance_weights).sum(dim=1, keepdim=True) + + # Step 7: Apply normalization (bias and scale) + x = (x + self.bias) * self.scale + + # Step 8: Convert RGB → BGR if needed + if self.output_is_bgr: + x = x.flip(dims=[1]) + + # Step 9: Convert to output dtype + return x.to(self.output_dtype) + + # ==================== Factory Methods ==================== + # These return ExportedProgram ready to lower to any backend + + @classmethod + def _export( + cls, + model: "ImagePreprocessor", + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype, + ) -> torch.export.ExportedProgram: + """Helper to export a model to ExportedProgram.""" + # Convert model buffers to match input dtype for consistent precision + model = model.to(input_dtype) + model.eval() + example_inputs = (torch.randn(*shape, dtype=input_dtype),) + return torch.export.export(model, example_inputs, strict=True) + + @classmethod + def from_scale_0_1( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + input_color: ColorLayout = ColorLayout.RGB, + output_color: ColorLayout = ColorLayout.RGB, + ) -> torch.export.ExportedProgram: + """ + Create and export preprocessor that scales [0, 255] → [0, 1]. 
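+
+        Equivalent to output = input / 255.0: the scale_0_1 preset is the
+        identity on top of the 8-bit normalization.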
+ + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype (use float32 for precision) + output_dtype: Output tensor dtype + input_color: Input color layout (RGB, BGR) + output_color: Output color layout (RGB, BGR, GRAYSCALE) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.LINEAR, + input_color=input_color, + output_color=output_color, + preset="scale_0_1", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_scale_neg1_1( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + input_color: ColorLayout = ColorLayout.RGB, + output_color: ColorLayout = ColorLayout.RGB, + ) -> torch.export.ExportedProgram: + """ + Create and export preprocessor that scales [0, 255] → [-1, 1]. + + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype (use float32 for precision) + output_dtype: Output tensor dtype + input_color: Input color layout (RGB, BGR) + output_color: Output color layout (RGB, BGR, GRAYSCALE) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.LINEAR, + input_color=input_color, + output_color=output_color, + preset="scale_neg1_1", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_imagenet( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + input_color: ColorLayout = ColorLayout.RGB, + ) -> torch.export.ExportedProgram: + """ + Create and export preprocessor with ImageNet normalization. + + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + input_color: Input color layout (RGB, BGR) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.LINEAR, + input_color=input_color, + preset="imagenet", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_hdr10( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, # CoreML uses fp16 internally + output_dtype: torch.dtype = torch.float16, + bit_depth: int = 10, + output_transfer: TransferFunction = TransferFunction.LINEAR, + output_gamut: ColorGamut = ColorGamut.BT709, + preset: Optional[str] = None, + ) -> torch.export.ExportedProgram: + """ + Create and export preprocessor for HDR10 content. + HDR10 uses PQ transfer function and BT.2020 color gamut. + + NOTE: CoreML uses fp16 internally. The PQ transfer function has limited + precision in fp16 due to high-power exponents (m2=78.84). This may cause + ~5% error in bright regions compared to float32 reference implementations. 
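+
+        Example (same pattern as the module docstring):
+            ep = ImagePreprocessor.from_hdr10(
+                shape=(1, 3, 1080, 1920),
+                input_dtype=torch.float32,  # fp32 for accurate PQ decoding
+                output_dtype=torch.float16,
+            )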
+ + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + bit_depth: Input bit depth (10 or 12) + output_transfer: Output transfer function (LINEAR or SRGB) + output_gamut: Output color gamut (BT709 or BT2020) + preset: Optional normalization preset + + Returns: + ExportedProgram ready to lower to any backend + """ + if input_dtype == torch.float16: + warnings.warn( + "HDR10 (PQ transfer function) has significant precision loss in fp16 " + "due to high-power exponents (m2=78.84). Expect up to 50% relative error " + "compared to fp32. Consider using fp32 for input_dtype if accuracy " + "is critical.", + UserWarning, + stacklevel=2, + ) + model = cls( + bit_depth=bit_depth, + input_transfer=TransferFunction.PQ, + output_transfer=output_transfer, + input_gamut=ColorGamut.BT2020, + output_gamut=output_gamut, + preset=preset, + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_hlg( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, # CoreML uses fp16 internally + output_dtype: torch.dtype = torch.float16, + bit_depth: int = 10, + output_transfer: TransferFunction = TransferFunction.LINEAR, + output_gamut: ColorGamut = ColorGamut.BT709, + preset: Optional[str] = None, + ) -> torch.export.ExportedProgram: + """ + Create and export preprocessor for HLG (Hybrid Log-Gamma) content. + HLG is used in broadcast HDR and typically uses BT.2020 color gamut. + + NOTE: CoreML uses fp16 internally, which may affect precision. + + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + bit_depth: Input bit depth (10 or 12) + output_transfer: Output transfer function (LINEAR or SRGB) + output_gamut: Output color gamut (BT709 or BT2020) + preset: Optional normalization preset + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=bit_depth, + input_transfer=TransferFunction.HLG, + output_transfer=output_transfer, + input_gamut=ColorGamut.BT2020, + output_gamut=output_gamut, + preset=preset, + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_sdr( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + normalize_to_linear: bool = False, + preset: Optional[str] = "scale_0_1", + ) -> torch.export.ExportedProgram: + """ + Create and export preprocessor for standard SDR content. + SDR uses sRGB gamma and BT.709 color gamut. 
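+
+        With normalize_to_linear=True, the input is linearized with the
+        x^2.2 sRGB approximation after scaling to [0, 1].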
+ + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + normalize_to_linear: If True, convert sRGB gamma to linear + preset: Normalization preset (scale_0_1, scale_neg1_1, imagenet) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=( + TransferFunction.SRGB + if normalize_to_linear + else TransferFunction.LINEAR + ), + output_transfer=TransferFunction.LINEAR, + input_gamut=ColorGamut.BT709, + output_gamut=ColorGamut.BT709, + preset=preset, + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + +# ============================================================================= +# ImagePostprocessor +# ============================================================================= + + +class ImagePostprocessor(torch.nn.Module): + """ + Convert model output tensor back to displayable image format. + Inverse of ImagePreprocessor. + + Processing pipeline (inverse of preprocessor): + 1. Convert input color layout (BGR → RGB if needed) + 2. Apply inverse normalization: x = (x / scale) - bias + 3. Convert from grayscale to RGB if needed + 4. Apply inverse output transfer function (sRGB → linear if needed) + 5. Apply inverse color gamut conversion (BT.709 → BT.2020 if needed) + 6. Apply forward transfer function (linear → gamma/PQ/HLG if needed) + 7. Scale to [0, max_value] based on bit depth + 8. Convert output color layout (RGB → BGR if needed) + + Args: + bit_depth: Output bit depth (8, 10, or 12). Default 8 for SDR. + input_transfer: Transfer function of model output (SRGB, LINEAR). + output_transfer: Desired transfer function of output (SRGB, PQ, HLG, LINEAR). + input_gamut: Color gamut of model output (BT709, BT2020). Default BT709. + output_gamut: Desired color gamut of output (BT709, BT2020). Default BT709. + input_color: Color layout of model output (RGB, BGR). Default RGB. + output_color: Desired color layout (RGB, BGR). Default RGB. + channel_bias: Per-channel bias used by preprocessor (will be inverted). + channel_scale: Per-channel scale used by preprocessor (will be inverted). + preset: Optional preset name matching preprocessor preset. + output_dtype: Output data type (torch.float16 or torch.float32). + + Input: + float16/float32 tensor [B, C, H, W] - model output (normalized) + + Output: + float16 tensor [B, C, H, W] with values in [0, max_val] based on bit_depth + Ready for uint8 conversion via vDSP. 
+ """ + + def __init__( + self, + bit_depth: int = 8, + input_transfer: TransferFunction = TransferFunction.LINEAR, + output_transfer: TransferFunction = TransferFunction.LINEAR, + input_gamut: ColorGamut = ColorGamut.BT709, + output_gamut: ColorGamut = ColorGamut.BT709, + input_color: ColorLayout = ColorLayout.RGB, + output_color: ColorLayout = ColorLayout.RGB, + channel_bias: Optional[List[float]] = None, + channel_scale: Optional[List[float]] = None, + preset: Optional[str] = None, + output_dtype: torch.dtype = torch.float16, + ): + super().__init__() + + if bit_depth not in (8, 10, 12): + raise ValueError(f"bit_depth must be 8, 10, or 12, got {bit_depth}") + + self.bit_depth = bit_depth + self.max_value = float((2**bit_depth) - 1) + self.output_dtype = output_dtype + + # Store flags for control flow + self.input_is_bgr = input_color == ColorLayout.BGR + self.output_is_bgr = output_color == ColorLayout.BGR + self.input_is_srgb = input_transfer == TransferFunction.SRGB + self.output_is_srgb = output_transfer == TransferFunction.SRGB + self.output_is_pq = output_transfer == TransferFunction.PQ + self.output_is_hlg = output_transfer == TransferFunction.HLG + + # Use preset if specified + if preset is not None: + if preset not in PRESETS: + raise ValueError( + f"Unknown preset: {preset}. Available: {list(PRESETS.keys())}" + ) + preset_config = PRESETS[preset] + channel_bias = preset_config["bias"] + channel_scale = preset_config["scale"] + + # Default: no normalization to invert + if channel_bias is None: + channel_bias = [0.0, 0.0, 0.0] + if channel_scale is None: + channel_scale = [1.0, 1.0, 1.0] + + # Register color matrix if gamut conversion needed (inverse direction) + if input_gamut != output_gamut: + if input_gamut == ColorGamut.BT709 and output_gamut == ColorGamut.BT2020: + self.register_buffer("color_matrix", BT709_TO_BT2020.to(torch.float16)) + elif input_gamut == ColorGamut.BT2020 and output_gamut == ColorGamut.BT709: + self.register_buffer("color_matrix", BT2020_TO_BT709.to(torch.float16)) + else: + self.color_matrix = None + + # Register bias and scale for inverse normalization + self.register_buffer( + "bias", torch.tensor(channel_bias).view(1, 3, 1, 1).to(torch.float16) + ) + self.register_buffer( + "scale", torch.tensor(channel_scale).view(1, 3, 1, 1).to(torch.float16) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Step 1: Handle BGR → RGB if needed + if self.input_is_bgr: + x = x.flip(dims=[1]) + + # Step 2: Inverse normalization: x = (x / scale) - bias + x = (x / self.scale) - self.bias + + # Step 3: Apply inverse input transfer function (if model output was in sRGB) + if self.input_is_srgb: + x = apply_srgb_inverse(x) + + # Step 4: Apply color gamut conversion if needed + if self.color_matrix is not None: + x = apply_gamut_conversion(x, self.color_matrix) + + # Step 5: Apply output transfer function (to target gamma) + if self.output_is_srgb: + x = apply_srgb_gamma(x) + elif self.output_is_pq: + x = apply_pq_forward(x) + elif self.output_is_hlg: + x = apply_hlg_forward(x) + + # Step 6: Scale to [0, max_value] based on bit depth + x = x * self.max_value + + # Step 7: Clamp to valid range + x = x.clamp(min=0.0, max=self.max_value) + + # Step 8: Convert RGB → BGR if needed + if self.output_is_bgr: + x = x.flip(dims=[1]) + + # Step 9: Convert to output dtype + return x.to(self.output_dtype) + + # ==================== Factory Methods ==================== + # These return ExportedProgram ready to lower to any backend + + @classmethod + def _export( + cls, + 
model: "ImagePostprocessor", + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype, + ) -> torch.export.ExportedProgram: + """Helper to export a model to ExportedProgram.""" + # Convert model buffers to match input dtype for consistent precision + model = model.to(input_dtype) + model.eval() + example_inputs = (torch.randn(*shape, dtype=input_dtype),) + return torch.export.export(model, example_inputs, strict=True) + + @classmethod + def from_scale_0_1( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + input_color: ColorLayout = ColorLayout.RGB, + output_color: ColorLayout = ColorLayout.RGB, + ) -> torch.export.ExportedProgram: + """ + Create and export postprocessor for model output in [0, 1] range. + Converts [0, 1] → [0, 255]. + + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + input_color: Input color layout (RGB, BGR) + output_color: Output color layout (RGB, BGR) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.LINEAR, + input_color=input_color, + output_color=output_color, + preset="scale_0_1", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_scale_neg1_1( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + input_color: ColorLayout = ColorLayout.RGB, + output_color: ColorLayout = ColorLayout.RGB, + ) -> torch.export.ExportedProgram: + """ + Create and export postprocessor for model output in [-1, 1] range. + Converts [-1, 1] → [0, 255]. + Common for GANs and diffusion models. + + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + input_color: Input color layout (RGB, BGR) + output_color: Output color layout (RGB, BGR) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.LINEAR, + input_color=input_color, + output_color=output_color, + preset="scale_neg1_1", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_imagenet( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + output_color: ColorLayout = ColorLayout.RGB, + ) -> torch.export.ExportedProgram: + """ + Create and export postprocessor for ImageNet-normalized model output. + Inverts ImageNet normalization → [0, 255]. 
+ + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + output_color: Output color layout (RGB, BGR) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.LINEAR, + output_color=output_color, + preset="imagenet", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_linear_to_srgb( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, + output_dtype: torch.dtype = torch.float16, + input_color: ColorLayout = ColorLayout.RGB, + output_color: ColorLayout = ColorLayout.RGB, + ) -> torch.export.ExportedProgram: + """ + Create and export postprocessor that converts linear [0, 1] to sRGB [0, 255]. + Useful for HDR models that output linear light. + + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + input_color: Input color layout (RGB, BGR) + output_color: Output color layout (RGB, BGR) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=8, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.SRGB, + input_color=input_color, + output_color=output_color, + preset="scale_0_1", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_linear_to_hdr10( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float16, # CoreML uses fp16 internally + output_dtype: torch.dtype = torch.float16, + bit_depth: int = 10, + input_gamut: ColorGamut = ColorGamut.BT709, + ) -> torch.export.ExportedProgram: + """ + Create and export postprocessor that converts linear [0, 1] to HDR10. + Outputs PQ-encoded BT.2020 content. + + NOTE: CoreML uses fp16 internally, which may affect PQ precision. + + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype + output_dtype: Output tensor dtype + bit_depth: Output bit depth (10 or 12) + input_gamut: Input color gamut (BT709 or BT2020) + + Returns: + ExportedProgram ready to lower to any backend + """ + if input_dtype == torch.float16: + warnings.warn( + "HDR10 (PQ transfer function) has significant precision loss in fp16 " + "due to high-power exponents (m2=78.84). Expect up to 50% relative error " + "compared to fp32. Consider using fp32 for input_dtype if accuracy " + "is critical.", + UserWarning, + stacklevel=2, + ) + model = cls( + bit_depth=bit_depth, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.PQ, + input_gamut=input_gamut, + output_gamut=ColorGamut.BT2020, + preset="scale_0_1", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) + + @classmethod + def from_linear_to_hlg( + cls, + shape: Tuple[int, int, int, int], + input_dtype: torch.dtype = torch.float32, # float32 recommended for HDR + output_dtype: torch.dtype = torch.float16, + bit_depth: int = 10, + input_gamut: ColorGamut = ColorGamut.BT709, + ) -> torch.export.ExportedProgram: + """ + Create and export postprocessor that converts linear [0, 1] to HLG. + Outputs HLG-encoded BT.2020 content for broadcast HDR. + + NOTE: float32 input_dtype is recommended for accurate HLG calculations. 
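+
+        Example (sketch; the 1080p shape is illustrative): fp32 input for
+        accurate transfer math, fp16 output for compact code values::
+
+            ep = ImagePostprocessor.from_linear_to_hlg(
+                shape=(1, 3, 1080, 1920),
+                input_dtype=torch.float32,
+                output_dtype=torch.float16,
+            )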
+ + Args: + shape: Input shape (batch, channels, height, width) + input_dtype: Input tensor dtype (float32 recommended for HDR) + output_dtype: Output tensor dtype + bit_depth: Output bit depth (10 or 12) + input_gamut: Input color gamut (BT709 or BT2020) + + Returns: + ExportedProgram ready to lower to any backend + """ + model = cls( + bit_depth=bit_depth, + input_transfer=TransferFunction.LINEAR, + output_transfer=TransferFunction.HLG, + input_gamut=input_gamut, + output_gamut=ColorGamut.BT2020, + preset="scale_0_1", + output_dtype=output_dtype, + ) + return cls._export(model, shape, input_dtype) diff --git a/extension/vision/test/__init__.py b/extension/vision/test/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/extension/vision/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/extension/vision/test/test_image_processing.py b/extension/vision/test/test_image_processing.py new file mode 100644 index 00000000000..0fdf8ee13a5 --- /dev/null +++ b/extension/vision/test/test_image_processing.py @@ -0,0 +1,856 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for image processing utilities. + +Compares our implementations against the colour-science library +which provides reference implementations of ITU/SMPTE standards. +""" + +import unittest + +import numpy as np +import torch +from executorch.extension.vision.image_processing import ( + apply_hlg_forward, + apply_hlg_inverse, + apply_pq_forward, + apply_pq_inverse, + apply_srgb_gamma, + apply_srgb_inverse, + BT2020_TO_BT709, + BT709_TO_BT2020, + ImagePostprocessor, + ImagePreprocessor, + LUMINANCE_WEIGHTS, +) +from parameterized import parameterized + +# colour-science is optional - tests will skip if not installed +try: + import colour + + HAS_COLOUR = True +except ImportError: + HAS_COLOUR = False + + +# ============================================================================= +# Constants +# ============================================================================= + +IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) +IMAGENET_STD = np.array([0.229, 0.224, 0.225]) + +IMAGENET_MEAN_TORCH = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1) +IMAGENET_STD_TORCH = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1) + + +def max_value_for_bit_depth(bit_depth: int) -> int: + """Return max pixel value for given bit depth (e.g., 255 for 8-bit, 1023 for 10-bit).""" + return (2**bit_depth) - 1 + + +def requires_colour(test_func): + """Decorator to skip tests if colour-science is not installed.""" + return unittest.skipUnless(HAS_COLOUR, "colour-science not installed")(test_func) + + +class TestTransferFunctionsAgainstReference(unittest.TestCase): + """Test transfer functions against colour-science reference implementations.""" + + @requires_colour + def test_pq_inverse_against_reference(self): + """Test PQ (ST.2084) EOTF against colour-science.""" + # Test values across the PQ range + test_values = np.linspace(0.01, 1.0, 100) + + # Our implementation + x = torch.tensor(test_values, dtype=torch.float32).view(1, 1, 10, 10) + ours = apply_pq_inverse(x).numpy().flatten() + + # Reference: colour-science ST.2084 EOTF + # Note: colour uses 
absolute luminance in cd/m² for ST 2084 output
+        ref = colour.eotf(test_values, "ST 2084", L_p=10000)
+        # Normalize to [0, 1] by the 10,000 nit peak
+        ref = ref / 10000.0
+
+        max_error = np.abs(ours - ref).max()
+        self.assertLess(
+            max_error,
+            1e-4,
+            f"PQ inverse max error {max_error:.6f} exceeds tolerance",
+        )
+
+    @requires_colour
+    def test_pq_forward_against_reference(self):
+        """Test PQ (ST.2084) inverse EOTF against colour-science."""
+        # Test linear values
+        test_values = np.linspace(0.001, 1.0, 100)
+
+        # Our implementation
+        x = torch.tensor(test_values, dtype=torch.float32).view(1, 1, 10, 10)
+        ours = apply_pq_forward(x).numpy().flatten()
+
+        # Reference: colour-science ST.2084 inverse EOTF (linear light → PQ
+        # code values). colour expects absolute luminance, so scale the
+        # relative values by the 10,000 nit peak.
+        ref = colour.eotf_inverse(test_values * 10000, "ST 2084", L_p=10000)
+
+        max_error = np.abs(ours - ref).max()
+        self.assertLess(
+            max_error,
+            1e-4,
+            f"PQ forward max error {max_error:.6f} exceeds tolerance",
+        )
+
+    @requires_colour
+    def test_hlg_inverse_against_reference(self):
+        """Test HLG OETF inverse against colour-science."""
+        # Test values in the HLG range
+        test_values = np.linspace(0.01, 1.0, 100)
+
+        # Our implementation
+        x = torch.tensor(test_values, dtype=torch.float32).view(1, 1, 10, 10)
+        ours = apply_hlg_inverse(x).numpy().flatten()
+
+        # Reference: colour-science HLG OETF inverse
+        ref = colour.models.oetf_inverse_BT2100_HLG(test_values)
+
+        max_error = np.abs(ours - ref).max()
+        self.assertLess(
+            max_error,
+            1e-5,
+            f"HLG inverse max error {max_error:.6f} exceeds tolerance",
+        )
+
+    @requires_colour
+    def test_hlg_forward_against_reference(self):
+        """Test HLG OETF against colour-science."""
+        # Test linear values
+        test_values = np.linspace(0.001, 1.0, 100)
+
+        # Our implementation
+        x = torch.tensor(test_values, dtype=torch.float32).view(1, 1, 10, 10)
+        ours = apply_hlg_forward(x).numpy().flatten()
+
+        # Reference: colour-science HLG OETF
+        ref = colour.models.oetf_BT2100_HLG(test_values)
+
+        max_error = np.abs(ours - ref).max()
+        self.assertLess(
+            max_error,
+            1e-5,
+            f"HLG forward max error {max_error:.6f} exceeds tolerance",
+        )
+
+    @requires_colour
+    def test_srgb_gamma_against_reference(self):
+        """Test sRGB gamma (x^(1/2.2) approximation) against colour-science.
+
+        Note: We use the x^(1/2.2) approximation, not the piecewise sRGB
+        transfer, so the reference below is the same pure power law.
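+
+        For context, the exact sRGB OETF is piecewise: 12.92 * x for
+        x <= 0.0031308, and 1.055 * x**(1/2.4) - 0.055 above. The pure power
+        law diverges from it most near black, which is why sampling starts
+        at 0.1.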
+ """ + # Test in mid-to-high range where approximation is reasonable + test_values = np.linspace(0.1, 1.0, 100) + + # Our implementation (x^(1/2.2) approximation) + x = torch.tensor(test_values, dtype=torch.float32).view(1, 1, 10, 10) + ours = apply_srgb_gamma(x).numpy().flatten() + + # Simple gamma approximation reference + ref = np.power(test_values, 1.0 / 2.2) + + max_error = np.abs(ours - ref).max() + self.assertLess( + max_error, + 1e-6, + f"sRGB gamma max error {max_error:.6f} exceeds tolerance", + ) + + @requires_colour + def test_srgb_inverse_against_reference(self): + """Test sRGB inverse gamma (x^2.2 approximation) against colour-science.""" + # Test in mid-to-high range + test_values = np.linspace(0.1, 1.0, 100) + + # Our implementation (x^2.2 approximation) + x = torch.tensor(test_values, dtype=torch.float32).view(1, 1, 10, 10) + ours = apply_srgb_inverse(x).numpy().flatten() + + # Simple gamma approximation reference + ref = np.power(test_values, 2.2) + + max_error = np.abs(ours - ref).max() + self.assertLess( + max_error, + 1e-6, + f"sRGB inverse max error {max_error:.6f} exceeds tolerance", + ) + + +class TestGamutMatricesAgainstReference(unittest.TestCase): + """Test color gamut conversion matrices against colour-science.""" + + @requires_colour + def test_bt2020_to_bt709_matrix(self): + """Test BT.2020 to BT.709 conversion matrix against colour-science.""" + # Get reference matrix from colour-science + # BT.2020 primaries and whitepoint + bt2020 = colour.RGB_COLOURSPACES["ITU-R BT.2020"] + bt709 = colour.RGB_COLOURSPACES["ITU-R BT.709"] + + # Compute conversion matrix + ref_matrix = colour.matrix_RGB_to_RGB(bt2020, bt709) + + ours = BT2020_TO_BT709.numpy() + + max_error = np.abs(ours - ref_matrix).max() + self.assertLess( + max_error, + 0.01, # Allow some tolerance for different derivations + f"BT.2020→BT.709 matrix max error {max_error:.6f}", + ) + + @requires_colour + def test_bt709_to_bt2020_matrix(self): + """Test BT.709 to BT.2020 conversion matrix against colour-science.""" + bt2020 = colour.RGB_COLOURSPACES["ITU-R BT.2020"] + bt709 = colour.RGB_COLOURSPACES["ITU-R BT.709"] + + # Compute conversion matrix + ref_matrix = colour.matrix_RGB_to_RGB(bt709, bt2020) + + ours = BT709_TO_BT2020.numpy() + + max_error = np.abs(ours - ref_matrix).max() + self.assertLess( + max_error, + 0.01, + f"BT.709→BT.2020 matrix max error {max_error:.6f}", + ) + + @requires_colour + def test_gamut_matrices_are_inverses(self): + """Verify that our gamut matrices are inverses of each other.""" + product = BT2020_TO_BT709 @ BT709_TO_BT2020 + identity = torch.eye(3) + + max_error = (product - identity).abs().max().item() + self.assertLess( + max_error, + 1e-4, + f"Gamut matrices not inverses, max error {max_error:.6f}", + ) + + +class TestLuminanceWeights(unittest.TestCase): + """Test luminance weights for grayscale conversion.""" + + def test_luminance_weights_are_bt601(self): + """Verify luminance weights match BT.601 standard.""" + # BT.601 standard weights + bt601_weights = np.array([0.299, 0.587, 0.114]) + + ours = LUMINANCE_WEIGHTS.numpy() + + max_error = np.abs(ours - bt601_weights).max() + self.assertLess( + max_error, + 1e-6, + f"Luminance weights don't match BT.601: {ours} vs {bt601_weights}", + ) + + def test_luminance_weights_sum_to_one(self): + """Luminance weights should sum to 1.0.""" + weight_sum = LUMINANCE_WEIGHTS.sum().item() + self.assertAlmostEqual( + weight_sum, + 1.0, + places=6, + msg=f"Luminance weights sum to {weight_sum}, expected 1.0", + ) + + +class 
TestTransferFunctionRoundtrip(unittest.TestCase): + """Test that forward/inverse transfer functions are true inverses.""" + + def setUp(self): + torch.manual_seed(42) + + def test_pq_roundtrip(self): + """PQ forward then inverse should return original values.""" + original = torch.rand(1, 3, 32, 32, dtype=torch.float32) * 0.9 + 0.05 + encoded = apply_pq_forward(original) + decoded = apply_pq_inverse(encoded) + + max_error = (original - decoded).abs().max().item() + self.assertLess( + max_error, + 2e-4, + f"PQ roundtrip max error {max_error:.6f}", + ) + + def test_hlg_roundtrip(self): + """HLG forward then inverse should return original values.""" + original = torch.rand(1, 3, 32, 32, dtype=torch.float32) * 0.9 + 0.05 + encoded = apply_hlg_forward(original) + decoded = apply_hlg_inverse(encoded) + + max_error = (original - decoded).abs().max().item() + self.assertLess( + max_error, + 1e-5, + f"HLG roundtrip max error {max_error:.6f}", + ) + + def test_srgb_roundtrip(self): + """sRGB gamma then inverse should return original values.""" + original = torch.rand(1, 3, 32, 32, dtype=torch.float32) * 0.9 + 0.05 + encoded = apply_srgb_gamma(original) + decoded = apply_srgb_inverse(encoded) + + max_error = (original - decoded).abs().max().item() + self.assertLess( + max_error, + 1e-5, + f"sRGB roundtrip max error {max_error:.6f}", + ) + + +# ============================================================================= +# E2E Tests: ExportedProgram vs colour-science reference +# ============================================================================= + + +def generate_test_pattern( + height: int, width: int, bit_depth: int, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + """ + Generate a synthetic test pattern with gradients, color bars, and gray ramp. + + Returns: [1, 3, H, W] tensor with values in [0, max_value] + """ + max_value = (2**bit_depth) - 1 + img = torch.zeros(1, 3, height, width, dtype=dtype) + + h_third = height // 3 + + # Top third: horizontal gradient per channel + for c in range(3): + gradient = torch.linspace(0, max_value, width, dtype=dtype) + img[0, c, :h_third, :] = gradient.unsqueeze(0).expand(h_third, -1) + # Offset each channel slightly + img[0, c, :h_third, :] = (img[0, c, :h_third, :] * (0.8 + 0.1 * c)).clamp( + 0, max_value + ) + + # Middle third: color bars (R, G, B, C, M, Y, W, K) + colors = [ + [max_value, 0, 0], # Red + [0, max_value, 0], # Green + [0, 0, max_value], # Blue + [0, max_value, max_value], # Cyan + [max_value, 0, max_value], # Magenta + [max_value, max_value, 0], # Yellow + [max_value, max_value, max_value], # White + [0, 0, 0], # Black + ] + bar_width = width // len(colors) + for i, color in enumerate(colors): + start = i * bar_width + end = start + bar_width if i < len(colors) - 1 else width + for c in range(3): + img[0, c, h_third : 2 * h_third, start:end] = color[c] + + # Bottom third: gray ramp + gray_ramp = torch.linspace(0, max_value, width, dtype=dtype) + for c in range(3): + img[0, c, 2 * h_third :, :] = gray_ramp.unsqueeze(0).expand( + height - 2 * h_third, -1 + ) + + return img + + +# Reference functions using colour-science + + +def _convert_gamut(img: np.ndarray, source: str, dest: str) -> np.ndarray: + """Convert between color gamuts using colour-science. 
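+
+    Both BT.709 and BT.2020 share the D65 white point, so no chromatic
+    adaptation is required; chromatic_adaptation_transform=None keeps this
+    reference consistent with the fixed 3x3 matrices under test.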
+ + Args: + img: Image array in CHW format + source: Source colorspace (e.g., "ITU-R BT.2020") + dest: Destination colorspace (e.g., "ITU-R BT.709") + + Returns: + Converted image in CHW format + """ + img = np.moveaxis(img, 0, -1) + img = colour.RGB_to_RGB( + img, + colour.RGB_COLOURSPACES[source], + colour.RGB_COLOURSPACES[dest], + chromatic_adaptation_transform=None, + ) + return np.moveaxis(img, -1, 0) + + +def ref_scale_0_1(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: Simple [0, max] → [0, 1] scaling.""" + max_value = (2**bit_depth) - 1 + return (img / max_value).astype(np.float32) + + +def ref_scale_neg1_1(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: [0, max] → [-1, 1] scaling.""" + max_value = (2**bit_depth) - 1 + return ((img / max_value - 0.5) * 2.0).astype(np.float32) + + +def ref_imagenet(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: ImageNet normalization.""" + max_value = max_value_for_bit_depth(bit_depth) + img = img / max_value + mean = IMAGENET_MEAN.reshape(3, 1, 1) + std = IMAGENET_STD.reshape(3, 1, 1) + return ((img - mean) / std).astype(np.float32) + + +@requires_colour +def ref_hdr10_to_linear_bt709(img: np.ndarray, bit_depth: int = 10) -> np.ndarray: + """Reference: HDR10 (PQ BT.2020) → linear BT.709.""" + max_value = (2**bit_depth) - 1 + img = img / max_value + img = colour.models.eotf_ST2084(img) / 10000.0 + img = _convert_gamut(img, "ITU-R BT.2020", "ITU-R BT.709") + return np.clip(img, 0, 1).astype(np.float32) + + +@requires_colour +def ref_hlg_to_linear_bt709(img: np.ndarray, bit_depth: int = 10) -> np.ndarray: + """Reference: HLG BT.2020 → linear BT.709.""" + max_value = (2**bit_depth) - 1 + img = img / max_value + img = colour.models.oetf_inverse_BT2100_HLG(img) + img = _convert_gamut(img, "ITU-R BT.2020", "ITU-R BT.709") + return np.clip(img, 0, 1).astype(np.float32) + + +def ref_sdr_to_linear(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: SDR (sRGB gamma) → linear.""" + max_value = (2**bit_depth) - 1 + img = img / max_value + return np.power(np.clip(img, 1e-6, 1.0), 2.2).astype(np.float32) + + +# Postprocessor reference functions (inverse operations) + + +def ref_post_scale_0_1(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: [0, 1] → [0, max] scaling.""" + max_value = (2**bit_depth) - 1 + return (img * max_value).astype(np.float32) + + +def ref_post_scale_neg1_1(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: [-1, 1] → [0, max] scaling.""" + max_value = (2**bit_depth) - 1 + return ((img / 2.0 + 0.5) * max_value).astype(np.float32) + + +def ref_post_imagenet(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: Reverse ImageNet normalization.""" + max_value = max_value_for_bit_depth(bit_depth) + mean = IMAGENET_MEAN.reshape(3, 1, 1) + std = IMAGENET_STD.reshape(3, 1, 1) + img = img * std + mean + return np.clip(img * max_value, 0, max_value).astype(np.float32) + + +@requires_colour +def ref_linear_bt709_to_hdr10(img: np.ndarray, bit_depth: int = 10) -> np.ndarray: + """Reference: linear BT.709 → HDR10 (PQ BT.2020).""" + max_value = (2**bit_depth) - 1 + img = _convert_gamut(img, "ITU-R BT.709", "ITU-R BT.2020") + img = np.clip(img, 0, 1) + img = colour.models.eotf_inverse_ST2084(img * 10000) + return (img * max_value).astype(np.float32) + + +@requires_colour +def ref_linear_bt709_to_hlg(img: np.ndarray, bit_depth: int = 10) -> np.ndarray: + """Reference: linear BT.709 → HLG BT.2020.""" + max_value = (2**bit_depth) - 1 + 
img = _convert_gamut(img, "ITU-R BT.709", "ITU-R BT.2020") + img = np.clip(img, 0, 1) + img = colour.models.oetf_BT2100_HLG(img) + return (img * max_value).astype(np.float32) + + +def ref_linear_to_srgb(img: np.ndarray, bit_depth: int = 8) -> np.ndarray: + """Reference: linear → sRGB gamma [0, 255].""" + max_value = (2**bit_depth) - 1 + img = np.power(np.clip(img, 1e-6, 1.0), 1.0 / 2.2) + return (img * max_value).astype(np.float32) + + +class TestE2EPipelines(unittest.TestCase): + """ + End-to-end tests comparing ExportedProgram output against colour-science. + + These tests validate that our image processing pipelines produce results + matching the reference colour-science implementations. + """ + + def setUp(self): + torch.manual_seed(42) + + # ==================== Tolerance Constants ==================== + # Tolerances determined empirically - see test analysis for actual error measurements. + # fp16 tolerances are looser due to inherent precision limitations. + + TOLERANCES = { + torch.float32: {"rtol": 1e-3, "atol": 1e-3}, # Simple scaling: max diff ~0.0003 + torch.float16: {"rtol": 0.01, "atol": 0.01}, + } + + HDR_TOLERANCES = { + torch.float32: { + "rtol": 0.005, + "atol": 1e-3, + }, # HLG transfer is very accurate (~6e-8) + torch.float16: {"rtol": 0.05, "atol": 0.05}, + } + + HDR10_TOLERANCES = { + torch.float32: {"rtol": 0.005, "atol": 1e-3}, + torch.float16: { + "rtol": 0.5, + "atol": 0.25, + }, # PQ has significant fp16 precision loss due to m2=78.84 exponent + } + + IMAGENET_TOLERANCES = { + torch.float32: {"rtol": 0.001, "atol": 0.001}, + torch.float16: {"rtol": 0.01, "atol": 0.01}, + } + + POST_IMAGENET_TOLERANCES = { + torch.float32: {"rtol": 0.001, "atol": 0.5}, # atol in [0, 255] range + torch.float16: {"rtol": 0.01, "atol": 1.0}, + } + + POST_HDR10_TOLERANCES = { + torch.float32: {"rtol": 0.005, "atol": 2.0}, # atol in [0, 1023] range + torch.float16: {"rtol": 0.1, "atol": 25.0}, # PQ fp16 precision loss + } + + POST_HLG_TOLERANCES = { + torch.float32: { + "rtol": 0.005, + "atol": 2.0, + }, # Edge cases near zero can have larger errors + torch.float16: {"rtol": 0.02, "atol": 5.0}, + } + + # ==================== Helper Methods ==================== + + def _run_exported_program( + self, ep: torch.export.ExportedProgram, inputs: torch.Tensor + ) -> np.ndarray: + """Run an ExportedProgram and return output as numpy array.""" + module = ep.module() + with torch.no_grad(): + output = module(inputs) + return output.numpy() + + def _compare_with_reference( + self, + ep: torch.export.ExportedProgram, + reference_fn, + test_img: torch.Tensor, + bit_depth: int, + rtol: float, + atol: float, + ): + """Compare ExportedProgram output against reference implementation.""" + ep_out = self._run_exported_program(ep, test_img) + reference_out = reference_fn(test_img[0].float().numpy(), bit_depth=bit_depth) + + np.testing.assert_allclose( + ep_out[0], + reference_out, + rtol=rtol, + atol=atol, + err_msg=f"ExportedProgram does not match reference (dtype={test_img.dtype})", + ) + + def _test_pipeline( + self, + processor_class: str, + factory_method: str, + reference_fn, + bit_depth: int = 8, + factory_kwargs: dict = None, + tolerances: dict = None, + input_transform=None, + uses_output_dtype: bool = False, + ): + """ + Test a processor factory method with both fp32 and fp16. 
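+
+        Example (sketch of a direct call, mirroring the parameterized cases)::
+
+            self._test_pipeline(
+                "ImagePreprocessor",
+                "from_scale_0_1",
+                ref_scale_0_1,
+                bit_depth=8,
+            )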
+ + Args: + processor_class: "ImagePreprocessor" or "ImagePostprocessor" + factory_method: Name of the factory method (e.g., "from_scale_0_1") + reference_fn: Reference function for comparison + bit_depth: Bit depth for test pattern + factory_kwargs: Additional kwargs for the factory method + tolerances: Override default tolerances {torch.float32: {...}, torch.float16: {...}} + input_transform: Optional transform to apply to test input + uses_output_dtype: If True, also pass output_dtype=dtype to factory method + """ + from executorch.extension.vision.image_processing import ( + ImagePostprocessor, + ImagePreprocessor, + ) + + processor_cls = ( + ImagePreprocessor + if processor_class == "ImagePreprocessor" + else ImagePostprocessor + ) + factory_kwargs = factory_kwargs or {} + tolerances = tolerances or self.TOLERANCES + shape = (1, 3, 64, 96) + + for dtype in [torch.float32, torch.float16]: + with self.subTest(dtype=dtype): + kwargs = {"shape": shape, "input_dtype": dtype, **factory_kwargs} + if uses_output_dtype: + kwargs["output_dtype"] = dtype + + ep = getattr(processor_cls, factory_method)(**kwargs) + + test_img = generate_test_pattern( + height=64, width=96, bit_depth=bit_depth, dtype=dtype + ) + if input_transform: + test_img = input_transform(test_img.to(torch.float32)).to(dtype) + + tol = tolerances.get(dtype, tolerances[torch.float32]) + self._compare_with_reference( + ep=ep, + reference_fn=reference_fn, + test_img=test_img, + bit_depth=bit_depth, + **tol, + ) + + # ==================== Parameterized E2E Tests ==================== + + # Preprocessor test cases: (name, factory_method, reference_fn, bit_depth, factory_kwargs, tolerances_key, input_transform, uses_output_dtype, needs_colour) + PREPROCESSOR_CASES = [ + ( + "scale_0_1", + "from_scale_0_1", + ref_scale_0_1, + 8, + None, + None, + None, + False, + False, + ), + ( + "scale_neg1_1", + "from_scale_neg1_1", + ref_scale_neg1_1, + 8, + None, + None, + None, + False, + False, + ), + ( + "imagenet", + "from_imagenet", + ref_imagenet, + 8, + None, + "IMAGENET_TOLERANCES", + None, + False, + False, + ), + ( + "hdr10_to_linear_bt709", + "from_hdr10", + ref_hdr10_to_linear_bt709, + 10, + None, + "HDR10_TOLERANCES", + lambda x: x.clamp(min=100), + True, + True, + ), + ( + "hlg_to_linear_bt709", + "from_hlg", + ref_hlg_to_linear_bt709, + 10, + None, + "HDR_TOLERANCES", + None, + True, + True, + ), + ( + "sdr_to_linear", + "from_sdr", + ref_sdr_to_linear, + 8, + {"normalize_to_linear": True}, + None, + None, + True, + False, + ), + ] + + # Postprocessor test cases: (name, factory_method, reference_fn, bit_depth, factory_kwargs, tolerances_key, input_transform, uses_output_dtype, needs_colour) + POSTPROCESSOR_CASES = [ + ( + "scale_0_1", + "from_scale_0_1", + ref_post_scale_0_1, + 8, + None, + None, + lambda x: x / 255.0, + False, + False, + ), + ( + "scale_neg1_1", + "from_scale_neg1_1", + ref_post_scale_neg1_1, + 8, + None, + None, + lambda x: (x / 255.0 - 0.5) * 2.0, + False, + False, + ), + ( + "imagenet", + "from_imagenet", + ref_post_imagenet, + 8, + None, + "POST_IMAGENET_TOLERANCES", + lambda x: (x / 255.0 - IMAGENET_MEAN_TORCH) / IMAGENET_STD_TORCH, + False, + False, + ), + ( + "linear_to_hdr10", + "from_linear_to_hdr10", + ref_linear_bt709_to_hdr10, + 10, + None, + "POST_HDR10_TOLERANCES", + lambda x: (x / 1023.0).clamp(min=0.01), + True, + True, + ), + ( + "linear_to_hlg", + "from_linear_to_hlg", + ref_linear_bt709_to_hlg, + 10, + None, + "POST_HLG_TOLERANCES", + lambda x: x / 1023.0, + True, + True, + ), + ( + 
"linear_to_srgb", + "from_linear_to_srgb", + ref_linear_to_srgb, + 8, + None, + None, + lambda x: x / 255.0, + True, + False, + ), + ] + + @parameterized.expand(PREPROCESSOR_CASES) + def test_e2e_preprocessor( + self, + name, + factory_method, + reference_fn, + bit_depth, + factory_kwargs, + tolerances_key, + input_transform, + uses_output_dtype, + needs_colour, + ): + """E2E test for ImagePreprocessor pipelines.""" + if needs_colour and not HAS_COLOUR: + self.skipTest("colour-science not installed") + + tolerances = getattr(self, tolerances_key) if tolerances_key else None + self._test_pipeline( + "ImagePreprocessor", + factory_method, + reference_fn, + bit_depth=bit_depth, + factory_kwargs=factory_kwargs, + tolerances=tolerances, + input_transform=input_transform, + uses_output_dtype=uses_output_dtype, + ) + + @parameterized.expand(POSTPROCESSOR_CASES) + def test_e2e_postprocessor( + self, + name, + factory_method, + reference_fn, + bit_depth, + factory_kwargs, + tolerances_key, + input_transform, + uses_output_dtype, + needs_colour, + ): + """E2E test for ImagePostprocessor pipelines.""" + if needs_colour and not HAS_COLOUR: + self.skipTest("colour-science not installed") + + tolerances = getattr(self, tolerances_key) if tolerances_key else None + self._test_pipeline( + "ImagePostprocessor", + factory_method, + reference_fn, + bit_depth=bit_depth, + factory_kwargs=factory_kwargs, + tolerances=tolerances, + input_transform=input_transform, + uses_output_dtype=uses_output_dtype, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/pyproject.toml b/pyproject.toml index 7a4ce277ade..fc34609b3e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ classifiers = [ requires-python = ">=3.10,<3.14" dependencies=[ + "colour-science", "expecttest", "flatbuffers", "hypothesis",