diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 84d006eefe..de81c01b40 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -203,3 +203,6 @@ fishjam Fishjam deinitialize Deinitialize +fastsam +promptable +topk diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index 03770c2720..a4868f92ae 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -189,6 +189,14 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} /> + ); diff --git a/apps/computer-vision/app/index.tsx b/apps/computer-vision/app/index.tsx index 15b9d8650b..e67e7eb5cb 100644 --- a/apps/computer-vision/app/index.tsx +++ b/apps/computer-vision/app/index.tsx @@ -47,6 +47,12 @@ export default function Home() { > Pose Estimation + router.navigate('segment_anything/')} + > + Segment Anything + router.navigate('ocr/')} diff --git a/apps/computer-vision/app/instance_segmentation/index.tsx b/apps/computer-vision/app/instance_segmentation/index.tsx index dba53875e5..f669c383d5 100644 --- a/apps/computer-vision/app/instance_segmentation/index.tsx +++ b/apps/computer-vision/app/instance_segmentation/index.tsx @@ -11,6 +11,8 @@ import { YOLO26X_SEG, RF_DETR_NANO_SEG, InstanceSegmentationModelSources, + FASTSAM_S, + FASTSAM_X, } from 'react-native-executorch'; import { View, @@ -35,6 +37,8 @@ const MODELS: ModelOption[] = [ { label: 'Yolo26L', value: YOLO26L_SEG }, { label: 'Yolo26X', value: YOLO26X_SEG }, { label: 'RF-DeTR Nano', value: RF_DETR_NANO_SEG }, + { label: 'FastSAM-S', value: FASTSAM_S }, + { label: 'FastSAM-X', value: FASTSAM_X }, ]; export default function InstanceSegmentationScreen() { diff --git a/apps/computer-vision/app/segment_anything/index.tsx b/apps/computer-vision/app/segment_anything/index.tsx new file mode 100644 index 0000000000..037a988327 --- /dev/null +++ b/apps/computer-vision/app/segment_anything/index.tsx @@ -0,0 +1,613 @@ +import React, { useContext, useEffect, useRef, useState } from 'react'; +import { + View, + StyleSheet, + Text, + TextInput, + TouchableOpacity, + TouchableWithoutFeedback, + GestureResponderEvent, + Keyboard, + KeyboardAvoidingView, + Platform, +} from 'react-native'; +import { + Canvas, + Rect, + Skia, + useImage, + type SkImage, + ColorType, + AlphaType, +} from '@shopify/react-native-skia'; +import { + useInstanceSegmentation, + useImageEmbeddings, + useTextEmbeddings, + FASTSAM_S, + FASTSAM_X, + CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED, + CLIP_VIT_BASE_PATCH32_TEXT, + InstanceSegmentationModelSources, + SegmentedInstance, + FastSAMLabel, + selectByPoint, + selectByBox, + selectByText, + Bbox, +} from 'react-native-executorch'; +import { GeneratingContext } from '../../context'; +import { ModelPicker, ModelOption } from '../../components/ModelPicker'; +import { BottomBar } from '../../components/BottomBar'; +import { StatsBar } from '../../components/StatsBar'; +import Spinner from '../../components/Spinner'; +import ScreenWrapper from '../../ScreenWrapper'; +import ImageWithMasks, { + buildDisplayInstances, + DisplayInstance, +} from '../../components/ImageWithMasks'; +import { getImage } from '../../utils'; +import ColorPalette from '../../colors'; + +type PromptMode = 'point' | 'box' | 'text'; + +const MODELS: ModelOption[] = [ + { label: 'FastSAM-S', value: FASTSAM_S }, + { label: 'FastSAM-X', value: FASTSAM_X }, +]; + +export default function SegmentAnythingScreen() { + const { setGlobalGenerating } = useContext(GeneratingContext); + + const 
[selectedModel, setSelectedModel] = + useState(FASTSAM_S); + const [mode, setMode] = useState('point'); + const [inferenceTime, setInferenceTime] = useState(null); + + const [imageUri, setImageUri] = useState(''); + const [imageSize, setImageSize] = useState({ width: 0, height: 0 }); + + const rawInstancesRef = useRef[]>([]); + const [selection, setSelection] = useState([]); + + const [draftBox, setDraftBox] = useState(null); + const boxStartRef = useRef<{ x: number; y: number } | null>(null); + const layoutRef = useRef({ width: 0, height: 0 }); + + const { isReady, isGenerating, downloadProgress, forward, error } = + useInstanceSegmentation({ model: selectedModel }); + + const clipImage = useImageEmbeddings({ + model: CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED, + }); + const clipText = useTextEmbeddings({ model: CLIP_VIT_BASE_PATCH32_TEXT }); + const skiaSource = useImage(imageUri || null); + + const [textPrompt, setTextPrompt] = useState(''); + const [textBusy, setTextBusy] = useState(false); + const [embeddingProgress, setEmbeddingProgress] = useState<{ + done: number; + total: number; + } | null>(null); + const instanceEmbeddingsRef = useRef(null); + + useEffect(() => { + setGlobalGenerating(isGenerating); + }, [isGenerating, setGlobalGenerating]); + + function applyMatch( + match: SegmentedInstance | null + ): void { + setSelection(match ? buildDisplayInstances([match]) : []); + } + + function touchToImageCoords(touchX: number, touchY: number) { + const { width: cw, height: ch } = layoutRef.current; + const { width: iw, height: ih } = imageSize; + if (iw === 0 || ih === 0) return null; + const scale = Math.min(cw / iw, ch / ih); + return { + x: (touchX - (cw - iw * scale) / 2) / scale, + y: (touchY - (ch - ih * scale) / 2) / scale, + }; + } + + function handleTap(e: GestureResponderEvent) { + if (mode !== 'point' || rawInstancesRef.current.length === 0) return; + const c = touchToImageCoords( + e.nativeEvent.locationX, + e.nativeEvent.locationY + ); + if (!c) return; + applyMatch( + selectByPoint(rawInstancesRef.current, Math.round(c.x), Math.round(c.y)) + ); + } + + function handleBoxStart(e: GestureResponderEvent) { + if (mode !== 'box') return; + const c = touchToImageCoords( + e.nativeEvent.locationX, + e.nativeEvent.locationY + ); + if (!c) return; + boxStartRef.current = c; + setDraftBox({ x1: c.x, y1: c.y, x2: c.x, y2: c.y }); + } + + function handleBoxMove(e: GestureResponderEvent) { + if (mode !== 'box' || !boxStartRef.current) return; + const c = touchToImageCoords( + e.nativeEvent.locationX, + e.nativeEvent.locationY + ); + if (!c) return; + const s = boxStartRef.current; + setDraftBox({ + x1: Math.min(s.x, c.x), + y1: Math.min(s.y, c.y), + x2: Math.max(s.x, c.x), + y2: Math.max(s.y, c.y), + }); + } + + function handleBoxEnd(e: GestureResponderEvent) { + if (mode !== 'box' || !boxStartRef.current) return; + const c = touchToImageCoords( + e.nativeEvent.locationX, + e.nativeEvent.locationY + ); + const s = boxStartRef.current; + boxStartRef.current = null; + setDraftBox(null); + if (!c || rawInstancesRef.current.length === 0) return; + applyMatch( + selectByBox(rawInstancesRef.current, { + x1: Math.min(s.x, c.x), + y1: Math.min(s.y, c.y), + x2: Math.max(s.x, c.x), + y2: Math.max(s.y, c.y), + }) + ); + } + + async function runTextPrompt() { + Keyboard.dismiss(); + const instances = rawInstancesRef.current; + if ( + !textPrompt.trim() || + instances.length === 0 || + !skiaSource || + !clipImage.isReady || + !clipText.isReady || + textBusy + ) { + return; + } + setTextBusy(true); + 
try { + if (!instanceEmbeddingsRef.current) { + setEmbeddingProgress({ done: 0, total: instances.length }); + const embeddings: Float32Array[] = []; + for (let i = 0; i < instances.length; i++) { + const inst = instances[i]!; + embeddings.push( + await cropAndEmbed( + skiaSource, + inst.bbox, + inst.mask, + inst.maskWidth, + inst.maskHeight, + clipImage.forward + ) + ); + setEmbeddingProgress({ done: i + 1, total: instances.length }); + } + instanceEmbeddingsRef.current = embeddings; + setEmbeddingProgress(null); + } + const textEmb = await clipText.forward(textPrompt); + const match = selectByText( + instances, + instanceEmbeddingsRef.current, + textEmb + ); + applyMatch(match); + } catch (e) { + console.error(e); + } finally { + setTextBusy(false); + } + } + + const handleCameraPress = async (isCamera: boolean) => { + Keyboard.dismiss(); + const image = await getImage(isCamera); + if (!image?.uri) return; + setImageUri(image.uri); + setImageSize({ width: image.width ?? 0, height: image.height ?? 0 }); + rawInstancesRef.current = []; + instanceEmbeddingsRef.current = null; + setSelection([]); + setInferenceTime(null); + }; + + const runForward = async () => { + Keyboard.dismiss(); + if (!imageUri) return; + try { + const start = Date.now(); + const output = await forward(imageUri, { + confidenceThreshold: 0.4, + iouThreshold: 0.9, + maxInstances: 50, + returnMaskAtOriginalResolution: true, + }); + setInferenceTime(Date.now() - start); + rawInstancesRef.current = output; + instanceEmbeddingsRef.current = null; + setSelection([]); + } catch (e) { + console.error(e); + } + }; + + if (!isReady && error) { + return ( + + + Error Loading Model + {error.message} + + + ); + } + + if (!isReady) { + return ( + + ); + } + + const { width: cw, height: ch } = layoutRef.current; + const { width: iw, height: ih } = imageSize; + const drawScale = iw > 0 && ih > 0 ? Math.min(cw / iw, ch / ih) : 1; + const offsetX = (cw - iw * drawScale) / 2; + const offsetY = (ch - ih * drawScale) / 2; + + const stepHint = !imageUri + ? null + : inferenceTime === null + ? 'Tap Run to detect instances' + : rawInstancesRef.current.length === 0 + ? 'No instances detected — try another image' + : selection.length === 0 + ? 'Tap a point, draw a box, or describe an object' + : null; + + return ( + + + + + + { + layoutRef.current = { + width: e.nativeEvent.layout.width, + height: e.nativeEvent.layout.height, + }; + }} + onTouchStart={(e) => { + Keyboard.dismiss(); + if (mode === 'point') handleTap(e); + else if (mode === 'box') handleBoxStart(e); + }} + onTouchMove={handleBoxMove} + onTouchEnd={handleBoxEnd} + > + + {draftBox && iw > 0 && ( + + + + )} + + {!imageUri && ( + + Segment Anything + + Segment any object in an image. (1) Pick an image, (2) tap + Run to detect instances, (3) tap a point, draw a box, or + describe an object to segment it. 
+ + + )} + + + + {stepHint && {stepHint}} + + + {(['point', 'box', 'text'] as PromptMode[]).map((m) => { + const promptDisabled = rawInstancesRef.current.length === 0; + return ( + { + if (m !== 'text') Keyboard.dismiss(); + setMode(m); + }} + disabled={promptDisabled} + > + + {m[0]!.toUpperCase() + m.slice(1)} + + + ); + })} + + + {mode === 'text' && ( + + + {(() => { + const findInactive = + !textPrompt.trim() || + rawInstancesRef.current.length === 0 || + !clipImage.isReady || + !clipText.isReady; + return ( + + Find + + ); + })()} + + )} + {mode === 'text' && embeddingProgress && ( + + Embedding instances {embeddingProgress.done}/ + {embeddingProgress.total} (subsequent text queries are instant) + + )} + + { + if (m.modelName === selectedModel.modelName) return; + setSelectedModel(m); + rawInstancesRef.current = []; + instanceEmbeddingsRef.current = null; + setSelection([]); + setInferenceTime(null); + }} + /> + + 0 + ? rawInstancesRef.current.length + : null + } + /> + + + + + + ); +} + +async function cropAndEmbed( + image: SkImage, + bbox: Bbox, + mask: Uint8Array, + maskWidth: number, + maskHeight: number, + forward: (input: string) => Promise +): Promise { + const imgW = image.width(); + const imgH = image.height(); + const surface = Skia.Surface.MakeOffscreen(imgW, imgH); + if (!surface) throw new Error('Failed to create offscreen Skia surface'); + const canvas = surface.getCanvas(); + canvas.clear(Skia.Color('white')); + + const x1 = Math.max(0, Math.round(bbox.x1)); + const y1 = Math.max(0, Math.round(bbox.y1)); + const x2 = Math.min(imgW, Math.round(bbox.x2)); + const y2 = Math.min(imgH, Math.round(bbox.y2)); + const w = x2 - x1; + const h = y2 - y1; + if (w > 0 && h > 0) { + canvas.drawImageRect( + image, + { x: x1, y: y1, width: w, height: h }, + { x: x1, y: y1, width: w, height: h }, + Skia.Paint() + ); + } + + const inversePixels = new Uint8Array(mask.length * 4); + for (let i = 0; i < mask.length; i++) { + const outside = mask[i]! === 0; + const idx = i * 4; + inversePixels[idx] = outside ? 255 : 0; + inversePixels[idx + 1] = outside ? 255 : 0; + inversePixels[idx + 2] = outside ? 255 : 0; + inversePixels[idx + 3] = outside ? 
255 : 0; + } + const inverseData = Skia.Data.fromBytes(inversePixels); + const inverseMaskImg = Skia.Image.MakeImage( + { + width: maskWidth, + height: maskHeight, + colorType: ColorType.RGBA_8888, + alphaType: AlphaType.Premul, + }, + inverseData, + maskWidth * 4 + ); + if (inverseMaskImg) { + canvas.drawImageRect( + inverseMaskImg, + { x: 0, y: 0, width: maskWidth, height: maskHeight }, + { + x: bbox.x1, + y: bbox.y1, + width: bbox.x2 - bbox.x1, + height: bbox.y2 - bbox.y1, + }, + Skia.Paint() + ); + } + + const base64 = surface.makeImageSnapshot().encodeToBase64(); + inverseData.dispose(); + return forward(`data:image/png;base64,${base64}`); +} + +const styles = StyleSheet.create({ + flex: { flex: 1 }, + container: { flex: 6, width: '100%' }, + imageContainer: { flex: 1, width: '100%', padding: 16 }, + imageTouchArea: { flex: 1, position: 'relative' }, + infoContainer: { alignItems: 'center', padding: 16, gap: 8 }, + infoTitle: { fontSize: 18, fontWeight: '600', color: 'navy' }, + infoText: { + fontSize: 14, + color: '#555', + textAlign: 'center', + lineHeight: 20, + }, + modeRow: { + flexDirection: 'row', + justifyContent: 'center', + paddingVertical: 8, + gap: 8, + }, + modeBtn: { + paddingHorizontal: 18, + paddingVertical: 8, + borderRadius: 8, + borderWidth: 1, + borderColor: ColorPalette.primary, + backgroundColor: '#fff', + }, + modeBtnActive: { backgroundColor: ColorPalette.primary }, + modeBtnDisabled: { borderColor: '#cbd5e1', backgroundColor: '#f8fafc' }, + modeBtnText: { fontSize: 14, fontWeight: '600', color: ColorPalette.primary }, + modeBtnTextActive: { color: '#fff' }, + modeBtnTextDisabled: { color: '#cbd5e1' }, + textRow: { + flexDirection: 'row', + alignItems: 'center', + paddingHorizontal: 16, + paddingBottom: 8, + gap: 8, + }, + textInput: { + flex: 1, + backgroundColor: '#fff', + borderWidth: 1, + borderColor: ColorPalette.primary, + borderRadius: 12, + paddingHorizontal: 14, + paddingVertical: 12, + fontSize: 16, + color: '#0f172a', + }, + textBtn: { + backgroundColor: ColorPalette.primary, + borderRadius: 12, + paddingVertical: 14, + width: 80, + alignItems: 'center', + }, + textBtnDisabled: { backgroundColor: '#cbd5e1' }, + textBtnLabel: { color: '#fff', fontWeight: '700', fontSize: 16 }, + statusLine: { + paddingHorizontal: 16, + paddingBottom: 6, + fontSize: 12, + color: '#64748b', + }, + stepHint: { + paddingHorizontal: 16, + paddingTop: 6, + fontSize: 13, + fontWeight: '500', + color: ColorPalette.primary, + textAlign: 'center', + }, + errorContainer: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + padding: 32, + }, + errorTitle: { + fontSize: 20, + fontWeight: '700', + color: '#e74c3c', + marginBottom: 12, + }, + errorText: { fontSize: 14, color: '#555', textAlign: 'center' }, +}); diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index 4020d20023..7a399f443f 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -54,6 +54,8 @@ type ModelId = | 'segmentationSelfie' | 'instanceSegmentationYolo26n' | 'instanceSegmentationRfdetr' + | 'instanceSegmentationFastsamS' + | 'instanceSegmentationFastsamX' | 'poseEstimationYolo26n' | 'ocr' | 'styleTransferCandy' @@ -87,6 +89,8 @@ const TASKS: Task[] = [ variants: [ { id: 'instanceSegmentationYolo26n', label: 'YOLO26N Seg' }, { id: 'instanceSegmentationRfdetr', label: 'RF-DETR Nano Seg' }, + { id: 'instanceSegmentationFastsamS', label: 'FastSAM-S' }, + { id: 
'instanceSegmentationFastsamX', label: 'FastSAM-X' }, ], }, { @@ -284,6 +288,8 @@ export default function VisionCameraScreen() { activeModel as | 'instanceSegmentationYolo26n' | 'instanceSegmentationRfdetr' + | 'instanceSegmentationFastsamS' + | 'instanceSegmentationFastsamX' } /> )} diff --git a/apps/computer-vision/components/ImageWithMasks.tsx b/apps/computer-vision/components/ImageWithMasks.tsx index bd768909b2..8bb435f47a 100644 --- a/apps/computer-vision/components/ImageWithMasks.tsx +++ b/apps/computer-vision/components/ImageWithMasks.tsx @@ -156,7 +156,7 @@ export default function ImageWithMasks({ /> {instances.length > 0 && ( - + {instances.map((inst, idx) => { const mx = inst.bbox.x1 * scale + offsetX; diff --git a/apps/computer-vision/components/vision_camera/tasks/InstanceSegmentationTask.tsx b/apps/computer-vision/components/vision_camera/tasks/InstanceSegmentationTask.tsx index 8bcdfb3844..52251f6e3e 100644 --- a/apps/computer-vision/components/vision_camera/tasks/InstanceSegmentationTask.tsx +++ b/apps/computer-vision/components/vision_camera/tasks/InstanceSegmentationTask.tsx @@ -6,9 +6,12 @@ import { SegmentedInstance, YOLO26N_SEG, RF_DETR_NANO_SEG, + FASTSAM_S, + FASTSAM_X, useInstanceSegmentation, CocoLabel, CocoLabelYolo, + FastSAMLabel, } from 'react-native-executorch'; import { Canvas, Image as SkiaImage } from '@shopify/react-native-skia'; import { labelColor, labelColorBg } from '../../utils/colors'; @@ -20,7 +23,9 @@ import { type InstSegModelId = | 'instanceSegmentationYolo26n' - | 'instanceSegmentationRfdetr'; + | 'instanceSegmentationRfdetr' + | 'instanceSegmentationFastsamS' + | 'instanceSegmentationFastsamX'; type Props = TaskProps & { activeModel: InstSegModelId }; @@ -44,9 +49,21 @@ export default function InstanceSegmentationTask({ model: RF_DETR_NANO_SEG, preventLoad: activeModel !== 'instanceSegmentationRfdetr', }); + const fastsamS = useInstanceSegmentation({ + model: FASTSAM_S, + preventLoad: activeModel !== 'instanceSegmentationFastsamS', + }); + const fastsamX = useInstanceSegmentation({ + model: FASTSAM_X, + preventLoad: activeModel !== 'instanceSegmentationFastsamX', + }); - const active = - activeModel === 'instanceSegmentationYolo26n' ? yolo26n : rfdetr; + const active = { + instanceSegmentationYolo26n: yolo26n, + instanceSegmentationRfdetr: rfdetr, + instanceSegmentationFastsamS: fastsamS, + instanceSegmentationFastsamX: fastsamX, + }[activeModel]; const [instances, setInstances] = useState([]); const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); @@ -74,7 +91,8 @@ export default function InstanceSegmentationTask({ (p: { results: | SegmentedInstance[] - | SegmentedInstance[]; + | SegmentedInstance[] + | SegmentedInstance[]; imageWidth: number; imageHeight: number; }) => { diff --git a/docs/docs/02-benchmarks/inference-time.md b/docs/docs/02-benchmarks/inference-time.md index faef5c603d..8ef213238f 100644 --- a/docs/docs/02-benchmarks/inference-time.md +++ b/docs/docs/02-benchmarks/inference-time.md @@ -230,17 +230,23 @@ slower for very large images, which can increase total time. ## Instance Segmentation :::note -Times presented in the tables are measured for YOLO models with input size equal to 512. Other input sizes may yield slower or faster inference times. RF-DETR Nano Seg uses a fixed resolution of 312×312. +Times presented in the tables are measured for YOLO models with input size equal +to 512. Other input sizes may yield slower or faster inference times. RF-DETR +Nano Seg uses a fixed resolution of 312×312. 
:::

-| Model | Samsung Galaxy S24 (XNNPACK) [ms] | Iphone 17 pro (XNNPACK) [ms] |
-| ---------------- | --------------------------------- | ---------------------------- |
-| YOLO26N_SEG | 92 | 90 |
-| YOLO26S_SEG | 220 | 188 |
-| YOLO26M_SEG | 570 | 550 |
-| YOLO26L_SEG | 680 | 608 |
-| YOLO26X_SEG | 1410 | 1338 |
-| RF_DETR_NANO_SEG | 549 | 330 |
+| Model | Samsung Galaxy S24 [ms] | iPhone 17 Pro [ms] | Pixel 10 [ms] |
+| :------------------------- | :---------------------: | :----------------: | :-----------: |
+| YOLO26N_SEG (XNNPACK) | 92 | 90 | 93 |
+| YOLO26S_SEG (XNNPACK) | 220 | 188 | 193 |
+| YOLO26M_SEG (XNNPACK) | 570 | 550 | 481 |
+| YOLO26L_SEG (XNNPACK) | 680 | 608 | 582 |
+| YOLO26X_SEG (XNNPACK) | 1410 | 1338 | 1191 |
+| RF_DETR_NANO_SEG (XNNPACK) | 549 | 330 | 428 |
+| FASTSAM_S (XNNPACK) | - | 30 | 286 |
+| FASTSAM_X (XNNPACK) | - | 2520 | 1993 |
+| FASTSAM_S (Core ML) | - | 51 | - |
+| FASTSAM_X (Core ML) | - | 72 | - |

## Text to image

diff --git a/docs/docs/02-benchmarks/model-size.md b/docs/docs/02-benchmarks/model-size.md
index 8dea094839..6d7f7cb753 100644
--- a/docs/docs/02-benchmarks/model-size.md
+++ b/docs/docs/02-benchmarks/model-size.md
@@ -22,14 +22,16 @@ title: Model Size

## Instance Segmentation

-| Model | XNNPACK [MB] |
-| ---------------- | :----------: |
-| YOLO26N_SEG | 11.6 |
-| YOLO26S_SEG | 42.3 |
-| YOLO26M_SEG | 95.4 |
-| YOLO26L_SEG | 113 |
-| YOLO26X_SEG | 252 |
-| RF_DETR_NANO_SEG | 124 |
+| Model | XNNPACK [MB] | Core ML FP32 [MB] | Core ML FP16 [MB] |
+| ---------------- | :----------: | :---------------: | :---------------: |
+| YOLO26N_SEG | 11.6 | - | - |
+| YOLO26S_SEG | 42.3 | - | - |
+| YOLO26M_SEG | 95.4 | - | - |
+| YOLO26L_SEG | 113 | - | - |
+| YOLO26X_SEG | 252 | - | - |
+| RF_DETR_NANO_SEG | 124 | - | - |
+| FASTSAM_S | 47.3 | 47.8 | 24.2 |
+| FASTSAM_X | 289 | 290 | 145 |

## Style Transfer

diff --git a/docs/docs/03-hooks/02-computer-vision/useInstanceSegmentation.md b/docs/docs/03-hooks/02-computer-vision/useInstanceSegmentation.md
index 14e2ff8478..6835262a6a 100644
--- a/docs/docs/03-hooks/02-computer-vision/useInstanceSegmentation.md
+++ b/docs/docs/03-hooks/02-computer-vision/useInstanceSegmentation.md
@@ -132,3 +132,55 @@ YOLO models use the [`CocoLabelYolo`](../../06-api-reference/enumerations/CocoLa
| yolo26l-seg | 80 | [COCO (YOLO)](../../06-api-reference/enumerations/CocoLabelYolo.md) | 384, 512, 640 |
| yolo26x-seg | 80 | [COCO (YOLO)](../../06-api-reference/enumerations/CocoLabelYolo.md) | 384, 512, 640 |
| rfdetr-nano-seg | 91 | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | 312 (fixed) |
+| fastsam-s | 1 | [FastSAMLabel](../../06-api-reference/enumerations/FastSAMLabel.md) | 640 (fixed) |
+| fastsam-x | 1 | [FastSAMLabel](../../06-api-reference/enumerations/FastSAMLabel.md) | 640 (fixed) |
+
+:::tip
+FastSAM models are class-agnostic, so they segment every instance without classifying it. That makes them a good fit for promptable selection workflows.
+:::
+
+## Promptable selection
+
+Instance segmentation models return a list of segmented instances. After `forward()`, you can use prompt-based selectors to pick the instance you want. Use point selection for tap-to-select or cutout tools, box selection for drag-to-outline workflows, and text selection for search or describe-it-in-words flows. For example, a photo-editing app can use point selection to isolate a person, a custom-sticker or background-removal flow can use box selection, and a shopping app can use text selection to find a product by name or description:
+
+1. Load an instance segmentation model with `useInstanceSegmentation`.
+2. Run `forward(image)` once to get the detected instances.
+3. Use a selector to pick the instance or instances matching the user's prompt.
+4. Re-run the selector when the prompt changes; you do not need to call `forward` again unless the image changes.
+
+```typescript
+import {
+  useInstanceSegmentation,
+  selectByPoint,
+  selectByBox,
+  selectByText,
+  FASTSAM_X,
+} from 'react-native-executorch';
+
+const model = useInstanceSegmentation({ model: FASTSAM_X });
+
+try {
+  const instances = await model.forward(imageUri);
+
+  // Point: the smallest instance whose mask covers (x, y).
+  const pointMatch = selectByPoint(instances, x, y);
+  console.log('point match:', pointMatch?.bbox);
+
+  // Box: the instance with highest IoU with the prompt box.
+  const boxMatch = selectByBox(instances, { x1, y1, x2, y2 });
+  console.log('box match:', boxMatch?.bbox);
+
+  // Text: highest cosine similarity between text and per-instance image
+  // embeddings (you must provide the embeddings, e.g. with CLIP).
+  const textMatch = selectByText(instances, instanceEmbeddings, textEmbedding);
+  console.log('text match:', textMatch?.bbox);
+} catch (error) {
+  console.error(error);
+}
+```
+
+:::tip
+Use FastSAM-S for faster performance on simple images with non-overlapping
+instances and FastSAM-X for better accuracy on complex scenes with many
+overlapping objects.
+:::
diff --git a/packages/react-native-executorch/src/constants/commonVision.ts b/packages/react-native-executorch/src/constants/commonVision.ts
index ecea0f8069..6221d5701e 100644
--- a/packages/react-native-executorch/src/constants/commonVision.ts
+++ b/packages/react-native-executorch/src/constants/commonVision.ts
@@ -200,3 +200,14 @@ export enum CocoLabelYolo {
HAIR_DRIER = 78,
TOOTHBRUSH = 79,
}
+
+/**
+ * Class label for FastSAM models.
+ *
+ * FastSAM is class-agnostic and produces a single "object" class for every
+ * detected region. Use this enum when working with `fastsam-s` or `fastsam-x`.
+ * @category Types
+ */
+export enum FastSAMLabel {
+  OBJECT = 0,
+}
diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts
index 6895601e1e..387dfdc3d8 100644
--- a/packages/react-native-executorch/src/constants/modelUrls.ts
+++ b/packages/react-native-executorch/src/constants/modelUrls.ts
@@ -1010,6 +1010,32 @@ export const SELFIE_SEGMENTATION = {
modelSource: SELFIE_SEGMENTATION_MODEL,
} as const;
+
+// FastSAM Instance Segmentation
+const FASTSAM_S_SEG_MODEL =
+  Platform.OS === 'ios'
+    ? `${URL_PREFIX}-fast-sam/${NEXT_VERSION_TAG}/fastsam-s/coreml/fastsam_s_coreml_fp16.pte`
+    : `${URL_PREFIX}-fast-sam/${NEXT_VERSION_TAG}/fastsam-s/xnnpack/fastsam_s_xnnpack_fp32.pte`;
+const FASTSAM_X_SEG_MODEL =
+  Platform.OS === 'ios'
+    ?
`${URL_PREFIX}-fast-sam/${NEXT_VERSION_TAG}/fastsam-x/coreml/fastsam_x_coreml_fp16.pte` + : `${URL_PREFIX}-fast-sam/${NEXT_VERSION_TAG}/fastsam-x/xnnpack/fastsam_x_xnnpack_fp32.pte`; + +/** + * @category Models - Instance Segmentation + */ +export const FASTSAM_S = { + modelName: 'fastsam-s', + modelSource: FASTSAM_S_SEG_MODEL, +} as const; + +/** + * @category Models - Instance Segmentation + */ +export const FASTSAM_X = { + modelName: 'fastsam-x', + modelSource: FASTSAM_X_SEG_MODEL, +} as const; + /** * @category Models - Instance Segmentation */ @@ -1352,6 +1378,8 @@ export const MODEL_REGISTRY = { YOLO26L_SEG, YOLO26X_SEG, RF_DETR_NANO_SEG, + FASTSAM_S, + FASTSAM_X, CLIP_VIT_BASE_PATCH32_IMAGE, CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED, ALL_MINILM_L6_V2, diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 96d167a7d2..84d6da5150 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -212,6 +212,7 @@ export * from './utils/BaseResourceFetcherClass'; export * from './utils/llm'; export * from './common/Logger'; export * from './utils/llms/context_strategy'; +export * from './utils/segmentAnythingPrompts'; // types export * from './types/objectDetection'; diff --git a/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts index 2e70e6bdec..e7e96f2deb 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts @@ -23,6 +23,7 @@ import { import { CocoLabel, CocoLabelYolo, + FastSAMLabel, IMAGENET1K_MEAN, IMAGENET1K_STD, } from '../../constants/commonVision'; @@ -39,6 +40,18 @@ const YOLO_SEG_CONFIG = { }, } satisfies InstanceSegmentationConfig; +const FASTSAM_CONFIG = { + preprocessorConfig: undefined, + labelMap: FastSAMLabel, + availableInputSizes: undefined, + defaultInputSize: undefined, + defaultConfidenceThreshold: 0.5, + defaultIouThreshold: 0.9, + postprocessorConfig: { + applyNMS: true, + }, +} satisfies InstanceSegmentationConfig; + const RF_DETR_NANO_SEG_CONFIG = { preprocessorConfig: { normMean: IMAGENET1K_MEAN, normStd: IMAGENET1K_STD }, labelMap: CocoLabel, @@ -81,10 +94,13 @@ const ModelConfigs = { 'yolo26l-seg': YOLO_SEG_CONFIG, 'yolo26x-seg': YOLO_SEG_CONFIG, 'rfdetr-nano-seg': RF_DETR_NANO_SEG_CONFIG, + 'fastsam-s': FASTSAM_CONFIG, + 'fastsam-x': FASTSAM_CONFIG, } as const satisfies Record< InstanceSegmentationModelName, | InstanceSegmentationConfig | InstanceSegmentationConfig + | InstanceSegmentationConfig >; /** @internal */ diff --git a/packages/react-native-executorch/src/types/instanceSegmentation.ts b/packages/react-native-executorch/src/types/instanceSegmentation.ts index 869f0cdcd7..ff7f4ae314 100644 --- a/packages/react-native-executorch/src/types/instanceSegmentation.ts +++ b/packages/react-native-executorch/src/types/instanceSegmentation.ts @@ -114,7 +114,9 @@ export type InstanceSegmentationModelSources = | { modelName: 'yolo26m-seg'; modelSource: ResourceSource } | { modelName: 'yolo26l-seg'; modelSource: ResourceSource } | { modelName: 'yolo26x-seg'; modelSource: ResourceSource } - | { modelName: 'rfdetr-nano-seg'; modelSource: ResourceSource }; + | { modelName: 'rfdetr-nano-seg'; modelSource: ResourceSource } + | { modelName: 'fastsam-s'; modelSource: ResourceSource } + | { 
modelName: 'fastsam-x'; modelSource: ResourceSource };

/**
 * Union of all built-in instance segmentation model names.
diff --git a/packages/react-native-executorch/src/utils/commonVision.ts b/packages/react-native-executorch/src/utils/commonVision.ts
new file mode 100644
index 0000000000..7cd9b2a44b
--- /dev/null
+++ b/packages/react-native-executorch/src/utils/commonVision.ts
@@ -0,0 +1,10 @@
+import { Bbox } from '../types/objectDetection';
+
+/**
+ * Calculates the area of a bounding box.
+ * @param bbox - Bounding box to calculate area for.
+ * @returns Area of the bounding box.
+ */
+export function bboxArea(bbox: Bbox): number {
+  return Math.max(bbox.x2 - bbox.x1, 0) * Math.max(bbox.y2 - bbox.y1, 0);
+}
diff --git a/packages/react-native-executorch/src/utils/segmentAnythingPrompts.ts b/packages/react-native-executorch/src/utils/segmentAnythingPrompts.ts
new file mode 100644
index 0000000000..db854705c7
--- /dev/null
+++ b/packages/react-native-executorch/src/utils/segmentAnythingPrompts.ts
@@ -0,0 +1,157 @@
+import { LabelEnum } from '../types/common';
+import { Bbox } from '../types/objectDetection';
+import { SegmentedInstance } from '../types/instanceSegmentation';
+import { bboxArea } from './commonVision';
+
+/**
+ * Selects the best matching instance for a given point prompt.
+ *
+ * Finds all instances whose mask covers the point (x, y), then returns the one
+ * with the smallest bounding box area (ties broken by highest confidence).
+ * @param instances - Array of segmented instances returned by `forward()`.
+ * @param x - X coordinate in original image space.
+ * @param y - Y coordinate in original image space.
+ * @returns The best matching instance, or `null` if no mask covers the point.
+ */
+export function selectByPoint(
+  instances: SegmentedInstance[],
+  x: number,
+  y: number
+): SegmentedInstance | null {
+  const px = Math.round(x);
+  const py = Math.round(y);
+
+  const matches = instances.filter((inst) => {
+    const mx = px - Math.round(inst.bbox.x1);
+    const my = py - Math.round(inst.bbox.y1);
+    if (mx < 0 || my < 0 || mx >= inst.maskWidth || my >= inst.maskHeight) {
+      return false;
+    }
+    return inst.mask[my * inst.maskWidth + mx] === 1;
+  });
+
+  if (matches.length === 0) return null;
+
+  return matches.reduce((best, inst) => {
+    const boxArea = bboxArea(inst.bbox);
+    const bestBoxArea = bboxArea(best.bbox);
+    if (boxArea !== bestBoxArea) return boxArea < bestBoxArea ? inst : best;
+
+    return inst.score > best.score ? inst : best;
+  });
+}
+
+/**
+ * Selects the best matching instance for a given box prompt.
+ *
+ * Finds all instances that overlap with the prompt box, then returns the one
+ * with the highest IoU with that box (ties broken by highest confidence).
+ * @param instances - Array of segmented instances returned by `forward()`.
+ * @param box - The prompt bounding box in image coordinates.
+ * @returns The best matching instance, or `null` if no instance overlaps.
+ */
+export function selectByBox(
+  instances: SegmentedInstance[],
+  box: Bbox
+): SegmentedInstance | null {
+  const { x1: px1, y1: py1, x2: px2, y2: py2 } = box;
+  const promptArea = bboxArea(box);
+
+  type Match = {
+    iou: number;
+    score: number;
+    inst: SegmentedInstance;
+  };
+  let best: Match | null = null;
+
+  for (const inst of instances) {
+    const { x1, y1, x2, y2 } = inst.bbox;
+    const interX1 = Math.max(px1, x1);
+    const interY1 = Math.max(py1, y1);
+    const interX2 = Math.min(px2, x2);
+    const interY2 = Math.min(py2, y2);
+    const interArea =
+      Math.max(interX2 - interX1, 0) * Math.max(interY2 - interY1, 0);
+    if (interArea <= 0) continue;
+
+    const detArea = bboxArea(inst.bbox);
+    const iou = interArea / (promptArea + detArea - interArea + 1e-7);
+
+    if (
+      best === null ||
+      iou > best.iou ||
+      (iou === best.iou && inst.score > best.score)
+    ) {
+      best = { iou, score: inst.score, inst };
+    }
+  }
+
+  return best?.inst ?? null;
+}
+
+/**
+ * Selects the best matching instance(s) for a text prompt.
+ *
+ * Returns the instance(s) whose image embedding has the highest cosine similarity
+ * with the text embedding (a dot product, so embeddings should be L2-normalized).
+ * The caller is responsible for producing the embeddings (e.g. with CLIP) and
+ * passing them in the same order as `instances`.
+ * @param instances - Array of segmented instances returned by `forward()`.
+ * @param instanceEmbeddings - Image embedding for each instance, in the same order as `instances`.
+ * @param textEmbedding - Embedding of the text prompt.
+ * @param topk - Number of top matches to return (defaults to 1).
+ * @returns The best matching instance (or null) if topk is 1, otherwise an array of the topk matching instances.
+ */
+export function selectByText(
+  instances: SegmentedInstance[],
+  instanceEmbeddings: Float32Array[],
+  textEmbedding: Float32Array,
+  topk?: 1
+): SegmentedInstance | null;
+export function selectByText(
+  instances: SegmentedInstance[],
+  instanceEmbeddings: Float32Array[],
+  textEmbedding: Float32Array,
+  topk: number
+): SegmentedInstance[];
+export function selectByText(
+  instances: SegmentedInstance[],
+  instanceEmbeddings: Float32Array[],
+  textEmbedding: Float32Array,
+  topk = 1
+): SegmentedInstance | null | SegmentedInstance[] {
+  if (instances.length === 0) return topk === 1 ? null : [];
+  if (instances.length !== instanceEmbeddings.length) {
+    throw new Error(
+      `selectByText: instances (${instances.length}) ` +
+        `and instanceEmbeddings (${instanceEmbeddings.length}) ` +
+        `must have the same length`
+    );
+  }
+
+  const scores = instanceEmbeddings.map((emb) => {
+    let dot = 0; // dot of unit vectors equals cosine similarity
+    for (let j = 0; j < emb.length; j++) {
+      dot += emb[j]! * textEmbedding[j]!;
+    }
+    return dot;
+  });
+
+  if (topk === 1) {
+    let bestIdx = 0;
+    let bestScore = -Infinity;
+    for (let i = 0; i < scores.length; i++) {
+      if (scores[i]! > bestScore) {
+        bestScore = scores[i]!;
+        bestIdx = i;
+      }
+    }
+    return instances[bestIdx]!;
+  }
+
+  return instances
+    .map((instance, index) => ({ instance, score: scores[index]! }))
+    .sort((a, b) => b.score - a.score)
+    .slice(0, topk)
+    .map((item) => item.instance);
+}
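A minimal usage sketch of the `topk` overload of `selectByText`, assuming the per-instance CLIP embeddings were already computed (e.g. with a crop-and-embed approach like the one in the demo screen) and are L2-normalized; the `rankByPrompt` helper, the top-3 count, and the `SegmentedInstance<FastSAMLabel>` type arguments are illustrative assumptions rather than library API:

```typescript
import {
  selectByText,
  SegmentedInstance,
  FastSAMLabel,
} from 'react-native-executorch';

// Hypothetical helper: rank every detected instance against a text prompt
// and keep the three closest matches, so the UI can offer disambiguation
// instead of committing to a single best instance.
function rankByPrompt(
  instances: SegmentedInstance<FastSAMLabel>[],
  instanceEmbeddings: Float32Array[], // one CLIP image embedding per instance
  textEmbedding: Float32Array // CLIP embedding of the user's prompt
): SegmentedInstance<FastSAMLabel>[] {
  // A numeric topk greater than 1 resolves to the array-returning overload,
  // which is sorted by descending similarity.
  return selectByText(instances, instanceEmbeddings, textEmbedding, 3);
}
```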