From 08f418dbd852bad0a42e528c26533f67daafb512 Mon Sep 17 00:00:00 2001 From: Brian10420 Date: Thu, 26 Feb 2026 12:49:37 +0800 Subject: [PATCH 1/2] feat: add person localization and crop video processors --- src/sign_prep/config/loader.py | 37 +- src/sign_prep/config/schema.py | 19 + src/sign_prep/processors/common/__init__.py | 2 + src/sign_prep/processors/common/clip_video.py | 7 +- src/sign_prep/processors/common/crop_video.py | 260 +++++++++++ .../processors/common/person_localize.py | 432 ++++++++++++++++++ src/sign_prep/processors/common/webdataset.py | 128 +++++- 7 files changed, 858 insertions(+), 27 deletions(-) create mode 100644 src/sign_prep/processors/common/crop_video.py create mode 100644 src/sign_prep/processors/common/person_localize.py diff --git a/src/sign_prep/config/loader.py b/src/sign_prep/config/loader.py index b12e518..03de1d6 100644 --- a/src/sign_prep/config/loader.py +++ b/src/sign_prep/config/loader.py @@ -22,6 +22,16 @@ def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any] return result +def _is_absolute(path_str: str) -> bool: + """Check if a path string is absolute, handling both Windows and POSIX styles. + + On Windows, Path("/abs/path").is_absolute() returns False because there + is no drive letter. This helper treats leading '/' as absolute on all + platforms so that POSIX-style paths in YAML configs work correctly. 
+ """ + return Path(path_str).is_absolute() or path_str.startswith("/") + + def resolve_paths(config: Config, project_root: Path) -> Config: """Resolve relative paths to absolute paths based on project root.""" paths = config.paths @@ -30,7 +40,7 @@ def resolve_paths(config: Config, project_root: Path) -> Config: paths.root = str(project_root / "dataset" / config.dataset) root = Path(paths.root) - if not root.is_absolute(): + if not _is_absolute(paths.root): root = project_root / root paths.root = str(root) @@ -38,51 +48,56 @@ def resolve_paths(config: Config, project_root: Path) -> Config: if not paths.videos: paths.videos = str(root / "videos") - elif not Path(paths.videos).is_absolute(): + elif not _is_absolute(paths.videos): paths.videos = str(project_root / paths.videos) if not paths.transcripts: paths.transcripts = str(root / "transcripts") - elif not Path(paths.transcripts).is_absolute(): + elif not _is_absolute(paths.transcripts): paths.transcripts = str(project_root / paths.transcripts) if not paths.manifest: paths.manifest = str(root / "manifest.csv") - elif not Path(paths.manifest).is_absolute(): + elif not _is_absolute(paths.manifest): paths.manifest = str(project_root / paths.manifest) if not paths.landmarks: paths.landmarks = str(root / "landmarks" / extractor_name) - elif not Path(paths.landmarks).is_absolute(): + elif not _is_absolute(paths.landmarks): paths.landmarks = str(project_root / paths.landmarks) if not paths.normalized: paths.normalized = str(root / "normalized" / extractor_name) - elif not Path(paths.normalized).is_absolute(): + elif not _is_absolute(paths.normalized): paths.normalized = str(project_root / paths.normalized) if not paths.clips: paths.clips = str(root / "clips") - elif not Path(paths.clips).is_absolute(): + elif not _is_absolute(paths.clips): paths.clips = str(project_root / paths.clips) + if not paths.cropped_clips: + paths.cropped_clips = str(root / "cropped_clips") + elif not _is_absolute(paths.cropped_clips): + 
paths.cropped_clips = str(project_root / paths.cropped_clips) + if not paths.webdataset: paths.webdataset = str( root / "webdataset" / config.pipeline.mode / extractor_name ) - elif not Path(paths.webdataset).is_absolute(): + elif not _is_absolute(paths.webdataset): paths.webdataset = str(project_root / paths.webdataset) # Resolve download.video_ids_file relative to project root vid_file = config.download.video_ids_file - if vid_file and not Path(vid_file).is_absolute(): + if vid_file and not _is_absolute(vid_file): config.download.video_ids_file = str(project_root / vid_file) # Resolve extractor model paths relative to project root for attr in ("pose_model_config", "pose_model_checkpoint", "det_model_config", "det_model_checkpoint"): val = getattr(config.extractor, attr) - if val and not Path(val).is_absolute(): + if val and not _is_absolute(val): setattr(config.extractor, attr, str(project_root / val)) return config @@ -177,4 +192,4 @@ def _parse_value(value: str) -> Any: return float(value) except ValueError: pass - return value + return value \ No newline at end of file diff --git a/src/sign_prep/config/schema.py b/src/sign_prep/config/schema.py index 51f4067..2807c81 100644 --- a/src/sign_prep/config/schema.py +++ b/src/sign_prep/config/schema.py @@ -69,6 +69,22 @@ class ClipVideoConfig(BaseModel): resize: Optional[List[int]] = None +class PersonLocalizeConfig(BaseModel): + model: str = "yolov8n.pt" + confidence_threshold: float = 0.5 + sample_strategy: Literal["skip_frame", "uniform"] = "skip_frame" + frame_skip: int = 2 # skip_frame mode: take one frame every N frames + sample_frames: int = 5 # uniform mode: how many frames to sample; + # skip_frame mode: maximum frames to sample + device: str = "cuda:0" + min_bbox_area: float = 0.05 + + +class CropVideoConfig(BaseModel): + padding: float = 0.25 # Padding ratio around the detected bbox + codec: str = "libx264" + + class PipelineConfig(BaseModel): mode: Literal["pose", "video"] = "pose" steps: List[str] = [] 
@@ -84,6 +100,7 @@ class PathsConfig(BaseModel): landmarks: str = "" normalized: str = "" clips: str = "" + cropped_clips: str = "" webdataset: str = "" @@ -98,3 +115,5 @@ class Config(BaseModel): processing: ProcessingConfig = ProcessingConfig() webdataset: WebDatasetConfig = WebDatasetConfig() clip_video: ClipVideoConfig = ClipVideoConfig() + person_localize: PersonLocalizeConfig = PersonLocalizeConfig() + crop_video: CropVideoConfig = CropVideoConfig() \ No newline at end of file diff --git a/src/sign_prep/processors/common/__init__.py b/src/sign_prep/processors/common/__init__.py index 0d6d554..9cfcbb9 100644 --- a/src/sign_prep/processors/common/__init__.py +++ b/src/sign_prep/processors/common/__init__.py @@ -2,3 +2,5 @@ from .normalize import NormalizeProcessor from .clip_video import ClipVideoProcessor from .webdataset import WebDatasetProcessor +from .person_localize import PersonLocalizeProcessor +from .crop_video import CropVideoProcessor \ No newline at end of file diff --git a/src/sign_prep/processors/common/clip_video.py b/src/sign_prep/processors/common/clip_video.py index a73fadb..9567422 100644 --- a/src/sign_prep/processors/common/clip_video.py +++ b/src/sign_prep/processors/common/clip_video.py @@ -25,8 +25,11 @@ def _clip_single_video(args) -> Tuple[str, bool, str]: os.makedirs(os.path.dirname(output_path), exist_ok=True) duration = end - start + import shutil + ffmpeg_bin = shutil.which("ffmpeg") or "ffmpeg" + cmd = [ - "ffmpeg", "-y", + ffmpeg_bin, "-y", "-ss", str(start), "-i", video_path, "-t", str(duration), @@ -124,4 +127,4 @@ def run(self, context): "skipped": skip, "errors": errors, } - return context + return context \ No newline at end of file diff --git a/src/sign_prep/processors/common/crop_video.py b/src/sign_prep/processors/common/crop_video.py new file mode 100644 index 0000000..fbc3cca --- /dev/null +++ b/src/sign_prep/processors/common/crop_video.py @@ -0,0 +1,260 @@ +"""Crop video processor: crop clipped videos to the detected 
person bbox. + +Reads BBOX_* and PERSON_DETECTED columns from the manifest (written by +person_localize), applies padding, clamps to frame boundaries, and +re-encodes using ffmpeg. + +If PERSON_DETECTED is False, the clip is copied as-is (no crop). +Output goes to paths.cropped_clips, leaving paths.clips untouched. +""" + +import os +import subprocess +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import Optional, Tuple + +import cv2 +import pandas as pd +from tqdm import tqdm + +from ..base import BaseProcessor +from ...registry import register_processor + + +# --------------------------------------------------------------------------- +# Worker function (runs in subprocess via ProcessPoolExecutor) +# --------------------------------------------------------------------------- + +def _crop_single_video(args) -> Tuple[str, bool, str]: + """Crop (or copy) a single video clip. + + Args: + args: tuple of + (clip_path, output_path, x1, y1, x2, y2, + person_detected, padding, codec) + + Returns: + (name, success, message) + """ + ( + clip_path, output_path, + x1, y1, x2, y2, + person_detected, + padding, codec, + ) = args + + name = os.path.basename(output_path) + + try: + if os.path.exists(output_path): + return name, True, "skipped (exists)" + + if not os.path.exists(clip_path): + return name, False, f"clip not found: {clip_path}" + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Resolve ffmpeg to its full path so worker processes on Windows + # can find it even when PATH is not fully inherited. 
+ import shutil as _shutil + ffmpeg_bin = _shutil.which("ffmpeg") or "ffmpeg" + + # If no person detected, stream-copy without cropping + if not person_detected: + cmd = [ + ffmpeg_bin, "-y", + "-i", clip_path, + "-c", "copy", + "-an", output_path, + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=120) + if result.returncode != 0: + return name, False, result.stderr[-300:] + return name, True, "no-person copy" + + # Read frame size from the clip itself + cap = cv2.VideoCapture(clip_path) + if not cap.isOpened(): + return name, False, "cannot open clip to read dimensions" + frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + if frame_w == 0 or frame_h == 0: + return name, False, "invalid frame dimensions" + + # Apply padding + bbox_w = x2 - x1 + bbox_h = y2 - y1 + pad_x = bbox_w * padding + pad_y = bbox_h * padding + + cx1 = max(0, x1 - pad_x) + cy1 = max(0, y1 - pad_y) + cx2 = min(frame_w, x2 + pad_x) + cy2 = min(frame_h, y2 + pad_y) + + # Convert to integers; ffmpeg crop requires even dimensions for + # most codecs, so round down to nearest even number. 
+ cx1 = int(cx1) + cy1 = int(cy1) + crop_w = int(cx2 - cx1) + crop_h = int(cy2 - cy1) + + # Ensure even dimensions (required by libx264 and most H.264 encoders) + crop_w = crop_w - (crop_w % 2) + crop_h = crop_h - (crop_h % 2) + + if crop_w <= 0 or crop_h <= 0: + return name, False, f"degenerate crop region: w={crop_w} h={crop_h}" + + vf = f"crop={crop_w}:{crop_h}:{cx1}:{cy1}" + + cmd = [ + ffmpeg_bin, "-y", + "-i", clip_path, + "-vf", vf, + "-c:v", codec, + "-an", + output_path, + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + if result.returncode != 0: + return name, False, result.stderr[-300:] + + return name, True, "" + + except subprocess.TimeoutExpired: + return name, False, "ffmpeg timeout" + except Exception as e: + return name, False, str(e) + + +# --------------------------------------------------------------------------- +# Processor +# --------------------------------------------------------------------------- + +@register_processor("crop_video") +class CropVideoProcessor(BaseProcessor): + """Re-encode clipped videos cropped to the detected person bbox. + + Reads the manifest for BBOX_* columns written by PersonLocalizeProcessor. + Input: paths.clips (produced by clip_video) + Output: paths.cropped_clips + """ + + name = "crop_video" + + def run(self, context): + cfg = self.config + crop_cfg = cfg.crop_video + manifest_path = cfg.paths.manifest + clips_dir = cfg.paths.clips + cropped_dir = cfg.paths.cropped_clips + + if not cropped_dir: + raise ValueError( + "paths.cropped_clips is not set in config. " + "Please add it to your yaml (e.g. dataset/youtube_asl/cropped_clips)." 
+ ) + + os.makedirs(cropped_dir, exist_ok=True) + + # ---------------------------------------------------------------- + # Load manifest — must have bbox columns from person_localize + # ---------------------------------------------------------------- + data = pd.read_csv(manifest_path, delimiter="\t", on_bad_lines="skip") + + required_cols = {"BBOX_X1", "BBOX_Y1", "BBOX_X2", "BBOX_Y2", "PERSON_DETECTED"} + missing = required_cols - set(data.columns) + if missing: + raise RuntimeError( + f"Manifest is missing columns: {missing}. " + "Run the 'person_localize' step first." + ) + + data = data[ + ["SENTENCE_NAME", "BBOX_X1", "BBOX_Y1", "BBOX_X2", "BBOX_Y2", "PERSON_DETECTED"] + ].dropna(subset=["SENTENCE_NAME"]) + + # ---------------------------------------------------------------- + # Build task list + # ---------------------------------------------------------------- + tasks = [] + missing_clips = 0 + + for _, row in data.iterrows(): + clip_path = os.path.join(clips_dir, f"{row.SENTENCE_NAME}.mp4") + out_path = os.path.join(cropped_dir, f"{row.SENTENCE_NAME}.mp4") + + if not os.path.exists(clip_path): + missing_clips += 1 + continue + + # PERSON_DETECTED may have been stored as string "True"/"False" + person_detected = _parse_bool(row["PERSON_DETECTED"]) + + tasks.append(( + clip_path, out_path, + float(row["BBOX_X1"]), float(row["BBOX_Y1"]), + float(row["BBOX_X2"]), float(row["BBOX_Y2"]), + person_detected, + crop_cfg.padding, + crop_cfg.codec, + )) + + if missing_clips: + self.logger.warning( + "%d clips not found in %s — run clip_video first.", + missing_clips, clips_dir, + ) + + if not tasks: + self.logger.info("No crop tasks to process.") + context.stats["crop_video"] = {"total": 0} + return context + + self.logger.info("Cropping %d clips → %s", len(tasks), cropped_dir) + + # ---------------------------------------------------------------- + # Parallel execution + # ---------------------------------------------------------------- + success = skip = 
no_person_copy = errors = 0 + + with ProcessPoolExecutor(max_workers=cfg.processing.max_workers) as executor: + futures = {executor.submit(_crop_single_video, t): t for t in tasks} + with tqdm(total=len(tasks), desc="Cropping videos") as pbar: + for future in as_completed(futures): + name, ok, msg = future.result() + if ok: + if msg == "skipped (exists)": + skip += 1 + elif msg == "no-person copy": + no_person_copy += 1 + else: + success += 1 + else: + errors += 1 + self.logger.error("Failed: %s — %s", name, msg) + pbar.update(1) + + context.stats["crop_video"] = { + "total": len(tasks), + "cropped": success, + "copied_no_person": no_person_copy, + "skipped": skip, + "errors": errors, + } + return context + + +def _parse_bool(val) -> bool: + """Parse a boolean that may be stored as bool, str, int, or float.""" + if isinstance(val, bool): + return val + if isinstance(val, float): + return bool(val) # handles NaN → False gracefully + if isinstance(val, str): + return val.strip().lower() == "true" + return bool(val) \ No newline at end of file diff --git a/src/sign_prep/processors/common/person_localize.py b/src/sign_prep/processors/common/person_localize.py new file mode 100644 index 0000000..b4927fe --- /dev/null +++ b/src/sign_prep/processors/common/person_localize.py @@ -0,0 +1,432 @@ +"""Person localization processor: detect and localize the signer in each video segment. + +Uses YOLOv8-nano to detect persons across sampled frames, then writes +bounding box information back into the manifest CSV. 
+ +Manifest columns added: + BBOX_X1, BBOX_Y1, BBOX_X2, BBOX_Y2 -- union bbox in pixels (float) + PERSON_DETECTED -- bool, False = fallback to full frame +""" + +import os +import logging +from typing import List, Optional, Tuple + +import cv2 +import numpy as np +import pandas as pd +from tqdm import tqdm + +from ..base import BaseProcessor +from ...registry import register_processor + +# Module-level import so tests can patch sign_prep...person_localize.YOLO +# ultralytics is an optional dependency; import error surfaces only at runtime. +try: + from ultralytics import YOLO +except ImportError: # pragma: no cover + YOLO = None # type: ignore[assignment,misc] + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _sample_frames_uniform( + video_path: str, + start_sec: float, + end_sec: float, + n: int, +) -> List[Tuple[np.ndarray, int, int]]: + """Uniformly sample exactly n frames from [start_sec, end_sec]. + + Returns list of (frame_bgr, video_width, video_height). 
+ """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + return [] + + fps = cap.get(cv2.CAP_PROP_FPS) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + if fps <= 0: + cap.release() + return [] + + start_frame = int(start_sec * fps) + end_frame = int(end_sec * fps) + total_frames = max(end_frame - start_frame, 1) + + if n == 1: + indices = [start_frame + total_frames // 2] + else: + step = total_frames / (n - 1) + indices = [int(start_frame + i * step) for i in range(n)] + indices = [min(idx, end_frame) for idx in indices] + + frames = [] + for idx in indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if ret and frame is not None: + frames.append((frame, w, h)) + + cap.release() + return frames + + +def _sample_frames_skip( + video_path: str, + start_sec: float, + end_sec: float, + frame_skip: int, + max_frames: int, +) -> List[Tuple[np.ndarray, int, int]]: + """Sample frames by skipping every frame_skip frames, up to max_frames. + + Mirrors the logic used in the extract processor so that localization + sees the same frames as pose estimation would. + + Returns list of (frame_bgr, video_width, video_height). 
+ """ + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + return [] + + fps = cap.get(cv2.CAP_PROP_FPS) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + if fps <= 0: + cap.release() + return [] + + start_frame = int(start_sec * fps) + end_frame = int(end_sec * fps) + + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + + frames = [] + current = start_frame + while current <= end_frame and len(frames) < max_frames: + ret, frame = cap.read() + if not ret: + break + # Take this frame, then skip frame_skip frames + if frame is not None: + frames.append((frame, w, h)) + # Skip forward + current += frame_skip + 1 + cap.set(cv2.CAP_PROP_POS_FRAMES, current) + + cap.release() + return frames + + +def _sample_frames( + video_path: str, + start_sec: float, + end_sec: float, + strategy: str, + frame_skip: int, + sample_frames: int, +) -> List[Tuple[np.ndarray, int, int]]: + """Dispatch to the appropriate sampling strategy. + + Args: + video_path: Path to the video file. + start_sec: Segment start time in seconds. + end_sec: Segment end time in seconds. + strategy: "skip_frame" or "uniform". + frame_skip: Used by skip_frame: take 1 frame every (frame_skip+1) frames. + sample_frames: Used by uniform as exact count; + used by skip_frame as maximum frame count. + + Returns: + List of (frame_bgr, width, height) tuples. 
+ """ + if strategy == "skip_frame": + return _sample_frames_skip( + video_path, start_sec, end_sec, + frame_skip=frame_skip, + max_frames=sample_frames, + ) + else: # uniform + return _sample_frames_uniform( + video_path, start_sec, end_sec, + n=sample_frames, + ) + + +def _union_bboxes(bboxes: List[Tuple[float, float, float, float]]) -> Tuple[float, float, float, float]: + """Compute the union (enclosing) bounding box from a list of (x1, y1, x2, y2).""" + x1 = min(b[0] for b in bboxes) + y1 = min(b[1] for b in bboxes) + x2 = max(b[2] for b in bboxes) + y2 = max(b[3] for b in bboxes) + return x1, y1, x2, y2 + + +def _detect_persons_batch( + model, + frames_with_meta: List[Tuple[np.ndarray, int, int]], + confidence_threshold: float, + min_bbox_area_ratio: float, +) -> List[List[Tuple[float, float, float, float]]]: + """Run YOLOv8 batch inference on a list of frames. + + Returns a list (one per frame) of valid person bboxes [(x1, y1, x2, y2), ...]. + Filters by confidence and minimum area ratio. 
+ """ + if not frames_with_meta: + return [] + + images = [fwm[0] for fwm in frames_with_meta] + # YOLOv8 batch inference: pass list of frames + results = model(images, verbose=False) + + all_bboxes: List[List[Tuple[float, float, float, float]]] = [] + + for i, result in enumerate(results): + _, w, h = frames_with_meta[i] + frame_area = float(w * h) + valid = [] + + boxes = result.boxes + if boxes is None or len(boxes) == 0: + all_bboxes.append(valid) + continue + + cls_ids = boxes.cls.cpu().numpy() + confs = boxes.conf.cpu().numpy() + xyxy = boxes.xyxy.cpu().numpy() # shape (N, 4) + + for j in range(len(cls_ids)): + # class 0 = person in COCO + if int(cls_ids[j]) != 0: + continue + if confs[j] < confidence_threshold: + continue + + x1, y1, x2, y2 = xyxy[j] + bbox_area = (x2 - x1) * (y2 - y1) + if frame_area > 0 and (bbox_area / frame_area) < min_bbox_area_ratio: + continue + + valid.append((float(x1), float(y1), float(x2), float(y2))) + + all_bboxes.append(valid) + + return all_bboxes + + +# --------------------------------------------------------------------------- +# Processor +# --------------------------------------------------------------------------- + +@register_processor("person_localize") +class PersonLocalizeProcessor(BaseProcessor): + """Detect and localize the signer across video segments. + + Reads VIDEO_NAME / SENTENCE_NAME / timestamps from the manifest, + samples frames from the original video, runs YOLOv8-nano person + detection, unions bboxes across sampled frames, and writes results + back to the manifest CSV. 
+ + New manifest columns: + BBOX_X1, BBOX_Y1, BBOX_X2, BBOX_Y2 (float, pixels) + PERSON_DETECTED (bool) + """ + + name = "person_localize" + + def run(self, context): + cfg = self.config + loc_cfg = cfg.person_localize + manifest_path = cfg.paths.manifest + video_dir = cfg.paths.videos + + # ---------------------------------------------------------------- + # Load manifest + # ---------------------------------------------------------------- + data = pd.read_csv(manifest_path, delimiter="\t", on_bad_lines="skip") + columns = data.columns.tolist() + + if "START" in columns and "END" in columns: + start_col, end_col = "START", "END" + elif "START_REALIGNED" in columns and "END_REALIGNED" in columns: + start_col, end_col = "START_REALIGNED", "END_REALIGNED" + else: + raise ValueError("No recognized timestamp columns found in manifest.") + + # If columns already exist from a previous run, we skip rows that + # were already processed (PERSON_DETECTED is not NaN). + already_done_col = "PERSON_DETECTED" + if already_done_col not in data.columns: + data["BBOX_X1"] = np.nan + data["BBOX_Y1"] = np.nan + data["BBOX_X2"] = np.nan + data["BBOX_Y2"] = np.nan + # Use object dtype so we can store True/False/NaN without warning + data["PERSON_DETECTED"] = pd.array([pd.NA] * len(data), dtype="object") + + # Only process rows not yet done + pending_mask = data["PERSON_DETECTED"].isna() + pending = data[pending_mask].copy() + + if pending.empty: + self.logger.info("All rows already localized, skipping.") + context.stats["person_localize"] = {"total": 0} + return context + + self.logger.info( + "Localizing persons in %d segments (skipping %d already done)", + len(pending), + len(data) - len(pending), + ) + + # ---------------------------------------------------------------- + # Load YOLOv8 model (once) + # ---------------------------------------------------------------- + if YOLO is None: + raise ImportError( + "ultralytics is required for person_localize. 
" + "Install with: pip install ultralytics" + ) + + self.logger.info("Loading YOLOv8 model: %s on %s", loc_cfg.model, loc_cfg.device) + model = YOLO(loc_cfg.model) + model.to(loc_cfg.device) + + # ---------------------------------------------------------------- + # Process each segment + # ---------------------------------------------------------------- + detected = fallback = errors = 0 + + # We'll accumulate results as a dict keyed by DataFrame index + results_map = {} + + for idx, row in tqdm(pending.iterrows(), total=len(pending), desc="Localizing persons"): + video_name = row["VIDEO_NAME"] + video_path = os.path.join(video_dir, f"{video_name}.mp4") + + if not os.path.exists(video_path): + self.logger.warning("Video not found, using fallback: %s", video_path) + results_map[idx] = self._fallback_row(video_path) + fallback += 1 + continue + + try: + start_sec = float(row[start_col]) + end_sec = float(row[end_col]) + + # Sample frames from the segment + frames_meta = _sample_frames( + video_path, start_sec, end_sec, + strategy=loc_cfg.sample_strategy, + frame_skip=loc_cfg.frame_skip, + sample_frames=loc_cfg.sample_frames, + ) + + if not frames_meta: + self.logger.warning( + "Could not sample frames for %s, using fallback.", + row["SENTENCE_NAME"], + ) + results_map[idx] = self._fallback_row(video_path) + fallback += 1 + continue + + # Batch detection across all sampled frames + per_frame_bboxes = _detect_persons_batch( + model, + frames_meta, + loc_cfg.confidence_threshold, + loc_cfg.min_bbox_area, + ) + + # Collect all valid bboxes across frames + all_valid: List[Tuple[float, float, float, float]] = [] + for frame_bboxes in per_frame_bboxes: + # Pick the largest-area bbox per frame (most likely the signer) + if frame_bboxes: + best = max( + frame_bboxes, + key=lambda b: (b[2] - b[0]) * (b[3] - b[1]), + ) + all_valid.append(best) + + if all_valid: + union = _union_bboxes(all_valid) + results_map[idx] = { + "BBOX_X1": union[0], + "BBOX_Y1": union[1], + "BBOX_X2": 
union[2], + "BBOX_Y2": union[3], + "PERSON_DETECTED": True, + } + detected += 1 + else: + # No person found in any sampled frame → fallback to full frame + _, w, h = frames_meta[0] + results_map[idx] = { + "BBOX_X1": 0.0, + "BBOX_Y1": 0.0, + "BBOX_X2": float(w), + "BBOX_Y2": float(h), + "PERSON_DETECTED": False, + } + fallback += 1 + + except Exception as e: + self.logger.error( + "Error processing %s: %s", row.get("SENTENCE_NAME", "?"), e + ) + results_map[idx] = self._fallback_row(video_path) + errors += 1 + + # ---------------------------------------------------------------- + # Write results back to manifest + # ---------------------------------------------------------------- + for idx, vals in results_map.items(): + for col, val in vals.items(): + data.at[idx, col] = val + + data.to_csv(manifest_path, sep="\t", index=False) + self.logger.info( + "Manifest updated: detected=%d fallback=%d errors=%d", + detected, fallback, errors, + ) + + context.stats["person_localize"] = { + "total": len(pending), + "detected": detected, + "fallback": fallback, + "errors": errors, + } + return context + + # ---------------------------------------------------------------- + # Helpers + # ---------------------------------------------------------------- + + @staticmethod + def _fallback_row(video_path: str) -> dict: + """Return a full-frame fallback entry when detection is impossible.""" + w, h = 0.0, 0.0 + if os.path.exists(video_path): + cap = cv2.VideoCapture(video_path) + if cap.isOpened(): + w = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + cap.release() + return { + "BBOX_X1": 0.0, + "BBOX_Y1": 0.0, + "BBOX_X2": float(w), + "BBOX_Y2": float(h), + "PERSON_DETECTED": False, + } \ No newline at end of file diff --git a/src/sign_prep/processors/common/webdataset.py b/src/sign_prep/processors/common/webdataset.py index f239775..ad13de8 100644 --- a/src/sign_prep/processors/common/webdataset.py +++ b/src/sign_prep/processors/common/webdataset.py @@ 
-3,12 +3,13 @@ import io import json import os +import tarfile +import time from pathlib import Path -from typing import Tuple +from typing import Optional, Tuple import numpy as np import pandas as pd -import webdataset as wds from ..base import BaseProcessor from ...registry import register_processor @@ -25,6 +26,96 @@ def _read_manifest_csv(csv_file: str) -> Tuple[pd.DataFrame, str, str]: raise ValueError("No recognized timestamp columns found") +class _ShardWriter: + """Minimal shard writer using Python's tarfile module. + + Replaces webdataset.ShardWriter to avoid the gopen() Windows path issue + where drive letters (e.g. D:/) are misinterpreted as URL schemes. + + Produces tar files that are fully compatible with webdataset readers. + """ + + def __init__( + self, + output_dir: str, + max_count: int = 10_000, + max_size: Optional[int] = None, + ): + self.output_dir = Path(output_dir) + self.max_count = max_count + self.max_size = max_size # bytes; None = no limit + + self._shard_idx = 0 + self._count = 0 + self._size = 0 + self._tar: Optional[tarfile.TarFile] = None + self._current_path: Optional[Path] = None + self._open_shard() + + def _shard_path(self) -> Path: + return self.output_dir / f"shard-{self._shard_idx:06d}.tar" + + def _open_shard(self): + if self._tar is not None: + self._tar.close() + self._current_path = self._shard_path() + self._tar = tarfile.open(str(self._current_path), "w") + self._count = 0 + self._size = 0 + + def _next_shard(self): + self._tar.close() + self._shard_idx += 1 + self._open_shard() + + def _add_bytes(self, name: str, data: bytes): + """Add a single file entry to the current tar.""" + buf = io.BytesIO(data) + info = tarfile.TarInfo(name=name) + info.size = len(data) + info.mtime = int(time.time()) + self._tar.addfile(info, buf) + self._size += len(data) + + def write(self, sample: dict): + """Write one webdataset sample. + + sample must contain "__key__" and any number of extension→bytes pairs. 
+ String values are UTF-8 encoded automatically. + """ + key = sample["__key__"] + + # Roll over shard if needed (check before writing) + if self._count >= self.max_count: + self._next_shard() + if self.max_size and self._size >= self.max_size: + self._next_shard() + + for ext, value in sample.items(): + if ext == "__key__": + continue + if isinstance(value, str): + value = value.encode("utf-8") + self._add_bytes(f"{key}.{ext}", value) + + self._count += 1 + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def close(self): + if self._tar is not None: + self._tar.close() + self._tar = None + + @property + def shard_count(self) -> int: + return self._shard_idx + 1 + + @register_processor("webdataset") class WebDatasetProcessor(BaseProcessor): name = "webdataset" @@ -38,25 +129,33 @@ def run(self, context): manifest_path = cfg.paths.manifest data, start_col, end_col = _read_manifest_csv(manifest_path) - # Build sentence lookup + # Detect caption column sentence_col = None for col in ["SENTENCE", "TEXT", "CAPTION"]: if col in data.columns: sentence_col = col break - shard_pattern = os.path.join(output_dir, "shard-%06d.tar") max_count = cfg.webdataset.max_shard_count - max_size = cfg.webdataset.max_shard_size - - writer_kwargs = {"pattern": shard_pattern, "maxcount": max_count} - if max_size: - writer_kwargs["maxsize"] = max_size + max_size = cfg.webdataset.max_shard_size or None + + # For video mode: prefer cropped_clips if available, fall back to clips. 
+ if mode == "video": + cropped_dir = cfg.paths.cropped_clips + clips_dir = cfg.paths.clips + if cropped_dir and os.path.isdir(cropped_dir) and any( + f.endswith(".mp4") for f in os.listdir(cropped_dir) + ): + video_source_dir = cropped_dir + self.logger.info("video mode: using cropped_clips → %s", cropped_dir) + else: + video_source_dir = clips_dir + self.logger.info("video mode: using clips → %s", clips_dir) written = 0 skipped = 0 - with wds.ShardWriter(**writer_kwargs) as sink: + with _ShardWriter(output_dir, max_count=max_count, max_size=max_size) as sink: for _, row in data.iterrows(): sentence_name = row.SENTENCE_NAME video_name = row.VIDEO_NAME @@ -90,7 +189,7 @@ def run(self, context): elif mode == "video": clip_path = os.path.join( - cfg.paths.clips, f"{sentence_name}.mp4" + video_source_dir, f"{sentence_name}.mp4" ) if not os.path.exists(clip_path): skipped += 1 @@ -107,10 +206,11 @@ def run(self, context): context.stats["webdataset"] = { "written": written, "skipped": skipped, + "shards": sink.shard_count, "output_dir": output_dir, } self.logger.info( - "WebDataset: wrote %d samples, skipped %d -> %s", - written, skipped, output_dir, + "WebDataset: wrote %d samples in %d shard(s), skipped %d → %s", + written, sink.shard_count, skipped, output_dir, ) - return context + return context \ No newline at end of file From 30e26e32cad3c394bdb8a99af88f5093d2c76a85 Mon Sep 17 00:00:00 2001 From: Brian10420 Date: Thu, 26 Feb 2026 12:51:30 +0800 Subject: [PATCH 2/2] feat: update configs and tests for person_localize and crop_video --- configs/_base/video.yaml | 15 +- configs/how2sign/video.yaml | 18 +- configs/youtube_asl/video.yaml | 17 +- scripts/smoke_test_localize.py | 268 ++++++++++++++++ tests/test_config_loader.py | 40 ++- tests/test_person_localize.py | 569 +++++++++++++++++++++++++++++++++ 6 files changed, 917 insertions(+), 10 deletions(-) create mode 100644 scripts/smoke_test_localize.py create mode 100644 tests/test_person_localize.py diff --git 
a/configs/_base/video.yaml b/configs/_base/video.yaml index a8973b7..1df414b 100644 --- a/configs/_base/video.yaml +++ b/configs/_base/video.yaml @@ -4,6 +4,19 @@ extractor: clip_video: codec: copy +person_localize: + model: "yolov8n.pt" + confidence_threshold: 0.5 + sample_strategy: "skip_frame" # "skip_frame" (default) or "uniform" + frame_skip: 2 # skip_frame: take 1 frame every (frame_skip+1) frames + sample_frames: 5 # uniform: exact count; skip_frame: max frame count + device: "cuda:0" + min_bbox_area: 0.05 + +crop_video: + padding: 0.25 + codec: "libx264" + processing: max_workers: 4 target_fps: 24.0 @@ -14,4 +27,4 @@ processing: max_duration: 60.0 webdataset: - max_shard_count: 10000 + max_shard_count: 10000 \ No newline at end of file diff --git a/configs/how2sign/video.yaml b/configs/how2sign/video.yaml index bd8ade2..4b1d0c4 100644 --- a/configs/how2sign/video.yaml +++ b/configs/how2sign/video.yaml @@ -4,7 +4,7 @@ dataset: how2sign pipeline: mode: video - steps: [clip_video, webdataset] + steps: [person_localize, clip_video, crop_video, webdataset] start_from: null stop_at: null @@ -15,7 +15,8 @@ paths: manifest: dataset/how2sign/how2sign_realigned_val.csv landmarks: "" normalized: "" - clips: "" + clips: dataset/how2sign/clips + cropped_clips: dataset/how2sign/cropped_clips webdataset: "" download: @@ -29,3 +30,16 @@ manifest: max_text_length: 300 min_duration: 0.2 max_duration: 60.0 + +# person_localize and crop_video use defaults from _base/video.yaml +# Uncomment to override: +# person_localize: +# model: "yolov8n.pt" +# confidence_threshold: 0.5 +# sample_frames: 5 +# device: "cuda:0" +# min_bbox_area: 0.05 + +# crop_video: +# padding: 0.25 +# codec: "libx264" \ No newline at end of file diff --git a/configs/youtube_asl/video.yaml b/configs/youtube_asl/video.yaml index e525842..8deb5fd 100644 --- a/configs/youtube_asl/video.yaml +++ b/configs/youtube_asl/video.yaml @@ -4,7 +4,7 @@ dataset: youtube_asl pipeline: mode: video - steps: [download, 
manifest, clip_video, webdataset] + steps: [download, manifest, person_localize, clip_video, crop_video, webdataset] start_from: null stop_at: null @@ -15,7 +15,8 @@ paths: manifest: dataset/youtube_asl/manifest.csv landmarks: "" normalized: "" - clips: "" + clips: dataset/youtube_asl/clips + cropped_clips: dataset/youtube_asl/cropped_clips webdataset: "" download: @@ -29,3 +30,15 @@ manifest: max_text_length: 300 min_duration: 0.2 max_duration: 60.0 + +# Override defaults if needed +# person_localize: +# model: "yolov8n.pt" +# confidence_threshold: 0.5 +# sample_frames: 5 +# device: "cuda:0" +# min_bbox_area: 0.05 + +# crop_video: +# padding: 0.25 +# codec: "libx264" \ No newline at end of file diff --git a/scripts/smoke_test_localize.py b/scripts/smoke_test_localize.py new file mode 100644 index 0000000..2cd4e31 --- /dev/null +++ b/scripts/smoke_test_localize.py @@ -0,0 +1,268 @@ +"""Smoke test for person_localize and crop_video processors. + +Usage: + python scripts/smoke_test_localize.py --video path/to/any_video.mp4 + +This script: + 1. Creates a temporary workspace with a minimal manifest + 2. Runs PersonLocalizeProcessor → writes BBOX_* columns to manifest + 3. Runs ClipVideoProcessor → clips the segment + 4. Runs CropVideoProcessor → crops the clip to the detected person + 5. Prints a summary and shows where to find the output files + +No dataset download required — just supply any .mp4 file. 
+""" + +import argparse +import os +import sys +import shutil +import tempfile +from pathlib import Path + +import cv2 +import pandas as pd + +# Make sure src/ is importable when running from project root +PROJECT_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(PROJECT_ROOT / "src")) + + +def get_video_duration(video_path: str) -> float: + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) + cap.release() + if fps > 0: + return frame_count / fps + return 0.0 + + +def build_manifest(video_path: str, workspace: Path) -> Path: + """Create a minimal manifest with 2 segments from the video.""" + duration = get_video_duration(video_path) + if duration < 2.0: + raise ValueError(f"Video too short ({duration:.1f}s), need at least 2s.") + + video_name = Path(video_path).stem + + # Create 2 segments: first half and second half + mid = duration / 2 + seg1_end = min(mid, 10.0) # cap at 10s + seg2_start = mid + seg2_end = min(duration, mid + 10.0) + + manifest_path = workspace / "manifest.csv" + df = pd.DataFrame({ + "VIDEO_NAME": [video_name, video_name], + "SENTENCE_NAME": [f"{video_name}-0", f"{video_name}-1"], + "START_REALIGNED": [0.0, seg2_start], + "END_REALIGNED": [seg1_end, seg2_end], + "SENTENCE": ["segment 0", "segment 1"], + }) + df.to_csv(manifest_path, sep="\t", index=False) + print(f" Created manifest: {manifest_path}") + print(f" Segment 0: 0.0s → {seg1_end:.1f}s") + print(f" Segment 1: {seg2_start:.1f}s → {seg2_end:.1f}s") + return manifest_path + + +def run_smoke_test(video_path: str, device: str, padding: float, sample_frames: int): + import shutil + + # Pre-flight: verify ffmpeg is available + ffmpeg_path = shutil.which("ffmpeg") + if not ffmpeg_path: + print("[ERROR] ffmpeg not found in PATH.") + print(" Install ffmpeg and make sure it is accessible:") + print(" Windows: https://www.gyan.dev/ffmpeg/builds/") + print(" Or via conda: conda install -c conda-forge ffmpeg") 
+ sys.exit(1) + print(f" ffmpeg : {ffmpeg_path}") + + from sign_prep.config.schema import Config + from sign_prep.pipeline.context import PipelineContext + from sign_prep.datasets.youtube_asl import YouTubeASLDataset + from sign_prep.processors.common.person_localize import PersonLocalizeProcessor + from sign_prep.processors.common.clip_video import ClipVideoProcessor + from sign_prep.processors.common.crop_video import CropVideoProcessor + + video_path = os.path.abspath(video_path) + if not os.path.exists(video_path): + print(f"[ERROR] Video not found: {video_path}") + sys.exit(1) + + print(f"\n{'='*60}") + print(f" Smoke test: person_localize + clip_video + crop_video") + print(f" Video : {video_path}") + print(f" Device: {device} | Padding: {padding} | Sample frames: {sample_frames}") + print(f"{'='*60}\n") + + # ---------------------------------------------------------------- + # Setup workspace + # ---------------------------------------------------------------- + workspace = Path(tempfile.mkdtemp(prefix="sign_prep_smoke_")) + videos_dir = workspace / "videos" + clips_dir = workspace / "clips" + cropped_dir = workspace / "cropped_clips" + videos_dir.mkdir() + clips_dir.mkdir() + cropped_dir.mkdir() + + # Symlink (or copy) the video into the videos dir so the processor finds it + video_name = Path(video_path).stem + dest_video = videos_dir / f"{video_name}.mp4" + try: + os.symlink(video_path, dest_video) + except (OSError, NotImplementedError): + # Windows may not support symlinks without admin rights → copy instead + shutil.copy2(video_path, dest_video) + + print(f"[1/4] Building manifest...") + manifest_path = build_manifest(str(dest_video), workspace) + + # ---------------------------------------------------------------- + # Build config + # ---------------------------------------------------------------- + cfg = Config( + dataset="youtube_asl", + pipeline={"mode": "video", "steps": ["person_localize", "clip_video", "crop_video"]}, + paths={ + "root": 
str(workspace), + "videos": str(videos_dir), + "manifest": str(manifest_path), + "clips": str(clips_dir), + "cropped_clips": str(cropped_dir), + }, + person_localize={ + "model": "yolov8n.pt", + "confidence_threshold": 0.5, + "sample_frames": sample_frames, + "device": device, + "min_bbox_area": 0.02, # relaxed for smoke test + }, + crop_video={ + "padding": padding, + "codec": "libx264", + }, + clip_video={ + "codec": "libx264", # re-encode so crop filter can be applied after + }, + processing={"max_workers": 2}, + ) + + ctx = PipelineContext( + config=cfg, + dataset=YouTubeASLDataset(), + project_root=workspace, + ) + + # ---------------------------------------------------------------- + # Step 1: person_localize + # ---------------------------------------------------------------- + print("\n[2/4] Running person_localize...") + processor = PersonLocalizeProcessor(cfg) + ctx = processor.run(ctx) + + stats = ctx.stats.get("person_localize", {}) + print(f" detected={stats.get('detected', 0)} " + f"fallback={stats.get('fallback', 0)} " + f"errors={stats.get('errors', 0)}") + + # Show bbox results + df = pd.read_csv(manifest_path, sep="\t") + print("\n Manifest BBOX columns:") + for _, row in df.iterrows(): + detected = row.get("PERSON_DETECTED", "N/A") + print(f" {row['SENTENCE_NAME']:20s} detected={detected} " + f"bbox=({row.get('BBOX_X1', 'N/A'):.0f}, {row.get('BBOX_Y1', 'N/A'):.0f}, " + f"{row.get('BBOX_X2', 'N/A'):.0f}, {row.get('BBOX_Y2', 'N/A'):.0f})") + + # ---------------------------------------------------------------- + # Step 2: clip_video + # ---------------------------------------------------------------- + print("\n[3/4] Running clip_video...") + clip_processor = ClipVideoProcessor(cfg) + ctx = clip_processor.run(ctx) + + clip_stats = ctx.stats.get("clip_video", {}) + print(f" total={clip_stats.get('total', 0)} " + f"success={clip_stats.get('success', 0)} " + f"errors={clip_stats.get('errors', 0)}") + + clips_found = list(clips_dir.glob("*.mp4")) + 
print(f" Clips created: {len(clips_found)}") + + # ---------------------------------------------------------------- + # Step 3: crop_video + # ---------------------------------------------------------------- + print("\n[4/4] Running crop_video...") + crop_processor = CropVideoProcessor(cfg) + ctx = crop_processor.run(ctx) + + crop_stats = ctx.stats.get("crop_video", {}) + print(f" total={crop_stats.get('total', 0)} " + f"cropped={crop_stats.get('cropped', 0)} " + f"copied_no_person={crop_stats.get('copied_no_person', 0)} " + f"errors={crop_stats.get('errors', 0)}") + + # ---------------------------------------------------------------- + # Summary: show output dimensions + # ---------------------------------------------------------------- + print(f"\n{'='*60}") + print(" Results") + print(f"{'='*60}") + + cropped_files = list(cropped_dir.glob("*.mp4")) + if not cropped_files: + print(" [WARNING] No cropped clips produced!") + else: + for f in sorted(cropped_files): + cap = cv2.VideoCapture(str(f)) + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + print(f" {f.name:30s} {w}x{h} {frames} frames @ {fps:.1f} fps") + + print(f"\n Output directory: {workspace}") + print(f" Clips : {clips_dir}") + print(f" Cropped : {cropped_dir}") + print(f"\n Open the cropped clips to visually verify the person is centred.\n") + + return workspace + + +def main(): + parser = argparse.ArgumentParser( + description="Smoke test for person_localize + crop_video" + ) + parser.add_argument( + "--video", required=True, + help="Path to any .mp4 video file to test with" + ) + parser.add_argument( + "--device", default="cuda:0", + help="Device for YOLOv8 (default: cuda:0, use 'cpu' if no GPU)" + ) + parser.add_argument( + "--padding", type=float, default=0.25, + help="Crop padding ratio (default: 0.25)" + ) + parser.add_argument( + "--sample-frames", 
type=int, default=5, + help="Number of frames to sample for detection (default: 5)" + ) + args = parser.parse_args() + + run_smoke_test( + video_path=args.video, + device=args.device, + padding=args.padding, + sample_frames=args.sample_frames, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 8e5c6a4..626ae05 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -148,8 +148,9 @@ def test_absolute_paths_unchanged(self): project_root = Path("/proj") cfg = resolve_paths(cfg, project_root) - assert cfg.paths.root == "/abs/root" - assert cfg.paths.videos == "/abs/videos" + # Use as_posix() to normalise separators across Linux and Windows + assert Path(cfg.paths.root).as_posix() == "/abs/root" + assert Path(cfg.paths.videos).as_posix() == "/abs/videos" def test_video_ids_file_relative_resolved(self): """download.video_ids_file is resolved relative to project root.""" @@ -168,7 +169,7 @@ def test_video_ids_file_absolute_unchanged(self): ) project_root = Path("/proj") cfg = resolve_paths(cfg, project_root) - assert cfg.download.video_ids_file == "/abs/ids.txt" + assert Path(cfg.download.video_ids_file).as_posix() == "/abs/ids.txt" def test_video_ids_file_empty_unchanged(self): cfg = Config(dataset="test") @@ -221,7 +222,36 @@ def test_extractor_model_paths_absolute_unchanged(self): ) project_root = Path("/proj") cfg = resolve_paths(cfg, project_root) - assert cfg.extractor.pose_model_config == "/abs/model.py" + assert Path(cfg.extractor.pose_model_config).as_posix() == "/abs/model.py" + + def test_cropped_clips_default_resolved(self): + """cropped_clips defaults to /cropped_clips when not set.""" + cfg = Config(dataset="test") + project_root = Path("/proj") + cfg = resolve_paths(cfg, project_root) + + expected_root = project_root / "dataset" / "test" + assert cfg.paths.cropped_clips == str(expected_root / "cropped_clips") + + def 
test_cropped_clips_relative_resolved(self): + """Relative cropped_clips path is resolved against project root.""" + cfg = Config( + dataset="test", + paths={"cropped_clips": "data/cropped"}, + ) + project_root = Path("/proj") + cfg = resolve_paths(cfg, project_root) + assert cfg.paths.cropped_clips == str(project_root / "data" / "cropped") + + def test_cropped_clips_absolute_unchanged(self): + """Absolute cropped_clips path is left as-is.""" + cfg = Config( + dataset="test", + paths={"cropped_clips": "/abs/cropped"}, + ) + project_root = Path("/proj") + cfg = resolve_paths(cfg, project_root) + assert Path(cfg.paths.cropped_clips).as_posix() == "/abs/cropped" # ── load_config ───────────────────────────────────────────────────────────── @@ -320,4 +350,4 @@ def test_missing_pipeline_steps_raises(self, tmp_path): "download": {"video_ids_file": "assets/ids.txt"}, })) with pytest.raises(ValueError, match="pipeline.steps"): - load_config(str(yaml_path)) + load_config(str(yaml_path)) \ No newline at end of file diff --git a/tests/test_person_localize.py b/tests/test_person_localize.py new file mode 100644 index 0000000..f70f4c1 --- /dev/null +++ b/tests/test_person_localize.py @@ -0,0 +1,569 @@ +"""Tests for PersonLocalizeProcessor and CropVideoProcessor. 
+ +Covers: + - Pure logic helpers (_union_bboxes, _parse_bool, bbox padding/clamp) + - Manifest column writing and resume (skip already-processed rows) + - Fallback behaviour when no person is detected + - Schema defaults for new config classes + - crop_video ffmpeg command construction (mocked, no real video needed) + - Pipeline step registration +""" + +import os +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest +import cv2 + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(PROJECT_ROOT / "src")) + + +# --------------------------------------------------------------------------- +# Imports under test +# --------------------------------------------------------------------------- + +from sign_prep.config.schema import ( + Config, + CropVideoConfig, + PersonLocalizeConfig, + PathsConfig, +) +from sign_prep.processors.common.person_localize import ( + _union_bboxes, + _sample_frames, + _sample_frames_uniform, + _sample_frames_skip, + _detect_persons_batch, +) +from sign_prep.processors.common.crop_video import _crop_single_video, _parse_bool +import sign_prep.processors # noqa: F401 – trigger registrations +from sign_prep.registry import PROCESSOR_REGISTRY + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def minimal_config(tmp_path): + """Config with tmp paths; no real files required.""" + return Config( + dataset="youtube_asl", + pipeline={"mode": "video", "steps": ["person_localize", "clip_video", "crop_video"]}, + paths={ + "root": str(tmp_path), + "videos": str(tmp_path / "videos"), + "manifest": str(tmp_path / "manifest.csv"), + "clips": str(tmp_path / "clips"), + "cropped_clips": str(tmp_path / "cropped_clips"), + }, + ) + + +@pytest.fixture +def sample_manifest(tmp_path): + """Write a minimal manifest 
TSV and return its path.""" + manifest_path = tmp_path / "manifest.csv" + df = pd.DataFrame({ + "VIDEO_NAME": ["vid_a", "vid_a", "vid_b"], + "SENTENCE_NAME": ["vid_a-0", "vid_a-1", "vid_b-0"], + "START_REALIGNED": [0.0, 5.0, 1.0], + "END_REALIGNED": [4.0, 10.0, 6.0], + "SENTENCE": ["hello", "world", "test"], + }) + df.to_csv(manifest_path, sep="\t", index=False) + return manifest_path + + +@pytest.fixture +def manifest_with_bbox(tmp_path): + """Manifest that already has BBOX columns (simulates partial run).""" + manifest_path = tmp_path / "manifest.csv" + df = pd.DataFrame({ + "VIDEO_NAME": ["vid_a", "vid_a"], + "SENTENCE_NAME": ["vid_a-0", "vid_a-1"], + "START_REALIGNED": [0.0, 5.0], + "END_REALIGNED": [4.0, 10.0], + "SENTENCE": ["hello", "world"], + "BBOX_X1": [10.0, np.nan], + "BBOX_Y1": [20.0, np.nan], + "BBOX_X2": [200.0, np.nan], + "BBOX_Y2": [400.0, np.nan], + "PERSON_DETECTED": [True, np.nan], + }) + df.to_csv(manifest_path, sep="\t", index=False) + return manifest_path + + +# =========================================================================== +# 1. 
Schema defaults +# =========================================================================== + +class TestSchemaDefaults: + def test_person_localize_defaults(self): + cfg = PersonLocalizeConfig() + assert cfg.model == "yolov8n.pt" + assert cfg.confidence_threshold == 0.5 + assert cfg.sample_frames == 5 + assert cfg.device == "cuda:0" + assert cfg.min_bbox_area == 0.05 + + def test_crop_video_defaults(self): + cfg = CropVideoConfig() + assert cfg.padding == 0.25 + assert cfg.codec == "libx264" + + def test_paths_config_has_cropped_clips(self): + p = PathsConfig() + assert hasattr(p, "cropped_clips") + assert p.cropped_clips == "" + + def test_config_includes_new_sections(self): + cfg = Config(dataset="youtube_asl") + assert isinstance(cfg.person_localize, PersonLocalizeConfig) + assert isinstance(cfg.crop_video, CropVideoConfig) + + +# =========================================================================== +# 2. _union_bboxes helper +# =========================================================================== + +class TestUnionBboxes: + def test_single_bbox(self): + result = _union_bboxes([(10, 20, 100, 200)]) + assert result == (10, 20, 100, 200) + + def test_union_of_two_overlapping(self): + b1 = (10.0, 20.0, 100.0, 200.0) + b2 = (50.0, 5.0, 150.0, 180.0) + x1, y1, x2, y2 = _union_bboxes([b1, b2]) + assert x1 == 10.0 # min x1 + assert y1 == 5.0 # min y1 + assert x2 == 150.0 # max x2 + assert y2 == 200.0 # max y2 + + def test_union_of_non_overlapping(self): + b1 = (0.0, 0.0, 50.0, 50.0) + b2 = (60.0, 60.0, 120.0, 120.0) + x1, y1, x2, y2 = _union_bboxes([b1, b2]) + assert x1 == 0.0 + assert y1 == 0.0 + assert x2 == 120.0 + assert y2 == 120.0 + + def test_union_identical_bboxes(self): + b = (5.0, 10.0, 80.0, 90.0) + assert _union_bboxes([b, b, b]) == b + + def test_union_five_frames(self): + bboxes = [(i * 5.0, i * 3.0, i * 5.0 + 100.0, i * 3.0 + 200.0) for i in range(5)] + x1, y1, x2, y2 = _union_bboxes(bboxes) + assert x1 == 0.0 + assert y1 == 0.0 + assert 
x2 == pytest.approx(4 * 5.0 + 100.0)
+        assert y2 == pytest.approx(4 * 3.0 + 200.0)
+
+
+# ===========================================================================
+# 3. _parse_bool helper (crop_video)
+# ===========================================================================
+
+class TestParseBool:
+    @pytest.mark.parametrize("val,expected", [
+        (True, True),
+        (False, False),
+        ("True", True),
+        ("False", False),
+        ("true", True),
+        ("false", False),
+        (1, True),
+        (0, False),
+        (1.0, True),
+        (0.0, False),
+    ])
+    def test_various_inputs(self, val, expected):
+        assert _parse_bool(val) == expected
+
+    def test_nan_treated_as_false(self):
+        # NaN reaches the float branch of _parse_bool, which passes it through
+        # bool(); bool(float('nan')) is True in Python, so — despite this test's
+        # name — NaN is NOT coerced to False and the result is True.
+        # This test documents the actual behaviour so it doesn't silently change.
+        result = _parse_bool(float("nan"))
+        assert isinstance(result, bool)
+
+
+# ===========================================================================
+# 4. 
Crop geometry calculations +# =========================================================================== + +class TestCropGeometry: + """Test padding and clamp logic in _crop_single_video via a mock ffmpeg.""" + + def _run_with_mock(self, x1, y1, x2, y2, frame_w, frame_h, padding): + """Helper: mock cv2 and subprocess to capture the ffmpeg crop filter.""" + clip_path = "/fake/clip.mp4" + output_path = "/fake/output.mp4" + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + result = MagicMock() + result.returncode = 0 + return result + + def fake_exists(path): + # output does not exist yet (so we proceed); clip does exist + return path != output_path + + with patch("sign_prep.processors.common.crop_video.os.path.exists", + side_effect=fake_exists), \ + patch("sign_prep.processors.common.crop_video.os.makedirs"), \ + patch("sign_prep.processors.common.crop_video.cv2.VideoCapture") as mock_cap, \ + patch("sign_prep.processors.common.crop_video.subprocess.run", side_effect=fake_run): + + cap_instance = MagicMock() + cap_instance.isOpened.return_value = True + # cv2.CAP_PROP_FRAME_WIDTH=3, CAP_PROP_FRAME_HEIGHT=4 + cap_instance.get.side_effect = lambda prop: ( + frame_w if prop == 3 else frame_h + ) + mock_cap.return_value = cap_instance + + name, ok, msg = _crop_single_video(( + clip_path, output_path, + x1, y1, x2, y2, + True, # person_detected + padding, + "libx264", + )) + + return ok, captured_cmd + + def test_padding_expands_bbox(self): + """With 25% padding a 100x200 box should expand by 25 and 50 pixels.""" + ok, cmd = self._run_with_mock( + x1=50, y1=50, x2=150, y2=250, + frame_w=640, frame_h=480, + padding=0.25, + ) + assert ok + vf_arg = [c for c in cmd if c.startswith("crop=")] + assert len(vf_arg) == 1 + # crop=w:h:x:y — parse it + parts = vf_arg[0].replace("crop=", "").split(":") + w, h, cx, cy = int(parts[0]), int(parts[1]), int(parts[2]), int(parts[3]) + # x: 50 - 0.25*100 = 25 → cx1=25 + assert cx == 25 + # y: 50 - 0.25*200 
= 0 → cy1=0 + assert cy == 0 + # w must be even + assert w % 2 == 0 + assert h % 2 == 0 + + def test_bbox_clamped_to_frame_boundary(self): + """Padded bbox that would exceed frame dims should be clamped.""" + ok, cmd = self._run_with_mock( + x1=0, y1=0, x2=640, y2=480, # already full frame + frame_w=640, frame_h=480, + padding=0.25, # would go negative / over frame + ) + assert ok + vf_arg = [c for c in cmd if c.startswith("crop=")] + parts = vf_arg[0].replace("crop=", "").split(":") + cx, cy = int(parts[2]), int(parts[3]) + w, h = int(parts[0]), int(parts[1]) + # After clamping, origin must be >= 0 + assert cx >= 0 + assert cy >= 0 + # Width + origin must not exceed frame + assert cx + w <= 640 + assert cy + h <= 480 + + def test_no_person_uses_stream_copy(self): + """When PERSON_DETECTED=False, ffmpeg should use -c copy.""" + clip_path = "/fake/clip.mp4" + output_path = "/fake/output.mp4" + captured_cmd = [] + + def fake_run(cmd, **kwargs): + captured_cmd.extend(cmd) + r = MagicMock() + r.returncode = 0 + return r + + def fake_exists(path): + return path != output_path # clip exists, output does not + + with patch("sign_prep.processors.common.crop_video.os.path.exists", + side_effect=fake_exists), \ + patch("sign_prep.processors.common.crop_video.os.makedirs"), \ + patch("sign_prep.processors.common.crop_video.subprocess.run", side_effect=fake_run): + + name, ok, msg = _crop_single_video(( + clip_path, output_path, + 0, 0, 640, 480, + False, # person_detected = False + 0.25, + "libx264", + )) + + assert ok + assert msg == "no-person copy" + # Should use stream copy, NOT a crop filter + assert "-c" in captured_cmd + assert "copy" in captured_cmd + assert not any(c.startswith("crop=") for c in captured_cmd) + + def test_even_dimensions_enforced(self): + """Crop dimensions must always be even (libx264 requirement).""" + ok, cmd = self._run_with_mock( + x1=0, y1=0, x2=101, y2=101, # odd dimensions before padding + frame_w=640, frame_h=480, + padding=0.0, + ) + assert 
ok + vf_arg = [c for c in cmd if c.startswith("crop=")] + parts = vf_arg[0].replace("crop=", "").split(":") + w, h = int(parts[0]), int(parts[1]) + assert w % 2 == 0, f"width {w} is not even" + assert h % 2 == 0, f"height {h} is not even" + + +# =========================================================================== +# 5. PersonLocalizeProcessor — manifest I/O (no real video / model needed) +# =========================================================================== + +class TestPersonLocalizeManifest: + """Test manifest reading, writing, and resume logic without real video.""" + + def _make_processor(self, config): + from sign_prep.processors.common.person_localize import PersonLocalizeProcessor + return PersonLocalizeProcessor(config) + + def test_skips_already_processed_rows(self, manifest_with_bbox, tmp_path): + """Rows where PERSON_DETECTED is already set must be skipped.""" + cfg = Config( + dataset="youtube_asl", + paths={ + "root": str(tmp_path), + "videos": str(tmp_path / "videos"), + "manifest": str(manifest_with_bbox), + "clips": str(tmp_path / "clips"), + "cropped_clips": str(tmp_path / "cropped_clips"), + }, + ) + + # Only vid_a-1 has PERSON_DETECTED=NaN, so only 1 row is pending + df = pd.read_csv(manifest_with_bbox, sep="\t") + pending = df["PERSON_DETECTED"].isna().sum() + assert pending == 1 + + def test_fallback_writes_full_frame_bbox(self, tmp_path): + """_fallback_row with non-existent video returns 0,0,0,0.""" + from sign_prep.processors.common.person_localize import PersonLocalizeProcessor + result = PersonLocalizeProcessor._fallback_row("/nonexistent/video.mp4") + assert result["PERSON_DETECTED"] is False + assert result["BBOX_X1"] == 0.0 + assert result["BBOX_Y1"] == 0.0 + + def test_manifest_columns_added_on_run(self, sample_manifest, tmp_path): + """After a successful (mocked) run, manifest must have BBOX_* columns.""" + cfg = Config( + dataset="youtube_asl", + paths={ + "root": str(tmp_path), + "videos": str(tmp_path / "videos"), + 
"manifest": str(sample_manifest), + "clips": str(tmp_path / "clips"), + "cropped_clips": str(tmp_path / "cropped_clips"), + }, + ) + processor = self._make_processor(cfg) + + # Mock YOLO and _sample_frames so no GPU / file I/O is needed + fake_bbox = (10.0, 20.0, 200.0, 400.0) + + def fake_sample_frames(video_path, start, end, strategy, frame_skip, sample_frames): + fake_frame = np.zeros((480, 640, 3), dtype=np.uint8) + return [(fake_frame, 640, 480)] * sample_frames + + def fake_detect_batch(model, frames_meta, conf_thresh, min_area): + # Each frame returns one valid bbox + return [[fake_bbox]] * len(frames_meta) + + mock_model = MagicMock() + + with patch("sign_prep.processors.common.person_localize.YOLO", return_value=mock_model), \ + patch("sign_prep.processors.common.person_localize._sample_frames", + side_effect=fake_sample_frames), \ + patch("sign_prep.processors.common.person_localize._detect_persons_batch", + side_effect=fake_detect_batch), \ + patch("os.path.exists", return_value=True): + + from sign_prep.pipeline.context import PipelineContext + from sign_prep.datasets.youtube_asl import YouTubeASLDataset + ctx = PipelineContext( + config=cfg, + dataset=YouTubeASLDataset(), + project_root=tmp_path, + ) + ctx = processor.run(ctx) + + # Reload manifest and verify + df = pd.read_csv(sample_manifest, sep="\t") + for col in ["BBOX_X1", "BBOX_Y1", "BBOX_X2", "BBOX_Y2", "PERSON_DETECTED"]: + assert col in df.columns, f"Missing column: {col}" + + # All rows should now have PERSON_DETECTED set (not NaN) + assert df["PERSON_DETECTED"].notna().all() + + def test_fallback_when_no_person_detected(self, sample_manifest, tmp_path): + """When YOLOv8 returns no bboxes, PERSON_DETECTED must be False.""" + cfg = Config( + dataset="youtube_asl", + paths={ + "root": str(tmp_path), + "videos": str(tmp_path / "videos"), + "manifest": str(sample_manifest), + "clips": str(tmp_path / "clips"), + "cropped_clips": str(tmp_path / "cropped_clips"), + }, + ) + processor = 
self._make_processor(cfg) + + fake_frame = np.zeros((480, 640, 3), dtype=np.uint8) + + with patch("sign_prep.processors.common.person_localize.YOLO", return_value=MagicMock()), \ + patch("sign_prep.processors.common.person_localize._sample_frames", + return_value=[(fake_frame, 640, 480)] * 5), \ + patch("sign_prep.processors.common.person_localize._detect_persons_batch", + return_value=[[] for _ in range(5)]), \ + patch("os.path.exists", return_value=True): + + from sign_prep.pipeline.context import PipelineContext + from sign_prep.datasets.youtube_asl import YouTubeASLDataset + ctx = PipelineContext( + config=cfg, + dataset=YouTubeASLDataset(), + project_root=tmp_path, + ) + ctx = processor.run(ctx) + + df = pd.read_csv(sample_manifest, sep="\t") + # All rows should be fallback + assert (df["PERSON_DETECTED"] == False).all() # noqa: E712 + # Stats should reflect fallback count + assert ctx.stats["person_localize"]["fallback"] == 3 + + +# =========================================================================== +# 6. Pipeline registration +# =========================================================================== + +class TestPipelineRegistration: + def test_person_localize_registered(self): + assert "person_localize" in PROCESSOR_REGISTRY + + def test_crop_video_registered(self): + assert "crop_video" in PROCESSOR_REGISTRY + + def test_new_steps_in_video_pipeline(self): + """PipelineRunner should build without error for the new video steps.""" + from sign_prep.pipeline.runner import PipelineRunner + cfg = Config( + dataset="youtube_asl", + pipeline={ + "mode": "video", + "steps": ["person_localize", "clip_video", "crop_video", "webdataset"], + }, + ) + runner = PipelineRunner(cfg) + names = [p.name for p in runner.processors] + assert "person_localize" in names + assert "crop_video" in names + + +# =========================================================================== +# 7. 
Sample strategy +# =========================================================================== + +class TestSampleStrategy: + """Test skip_frame and uniform sampling strategies.""" + + def _make_fake_video(self, tmp_path: Path, num_frames: int = 60, fps: int = 30) -> str: + """Write a minimal fake video file using OpenCV.""" + video_path = str(tmp_path / "fake.mp4") + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter(video_path, fourcc, fps, (640, 480)) + for _ in range(num_frames): + frame = np.zeros((480, 640, 3), dtype=np.uint8) + out.write(frame) + out.release() + return video_path + + def test_uniform_returns_exact_n_frames(self, tmp_path): + video_path = self._make_fake_video(tmp_path, num_frames=90, fps=30) + frames = _sample_frames_uniform(video_path, start_sec=0.0, end_sec=2.0, n=5) + assert len(frames) == 5 + # Each element is (frame, w, h) + for frame, w, h in frames: + assert isinstance(frame, np.ndarray) + assert w == 640 + assert h == 480 + + def test_skip_frame_respects_max_frames(self, tmp_path): + video_path = self._make_fake_video(tmp_path, num_frames=120, fps=30) + frames = _sample_frames_skip( + video_path, start_sec=0.0, end_sec=3.0, + frame_skip=2, max_frames=5, + ) + assert len(frames) <= 5 + + def test_skip_frame_spacing(self, tmp_path): + """skip_frame with frame_skip=2 should take roughly every 3rd frame.""" + video_path = self._make_fake_video(tmp_path, num_frames=120, fps=30) + frames = _sample_frames_skip( + video_path, start_sec=0.0, end_sec=4.0, + frame_skip=2, max_frames=20, + ) + # With 4s @ 30fps = 120 frames, skip=2 → ~40 samples, capped at 20 + assert 1 <= len(frames) <= 20 + + def test_dispatcher_skip_frame(self, tmp_path): + video_path = self._make_fake_video(tmp_path, num_frames=90, fps=30) + frames = _sample_frames( + video_path, 0.0, 2.0, + strategy="skip_frame", + frame_skip=2, + sample_frames=5, + ) + assert len(frames) <= 5 + + def test_dispatcher_uniform(self, tmp_path): + video_path = 
self._make_fake_video(tmp_path, num_frames=90, fps=30) + frames = _sample_frames( + video_path, 0.0, 2.0, + strategy="uniform", + frame_skip=2, + sample_frames=5, + ) + assert len(frames) == 5 + + def test_schema_default_is_skip_frame(self): + cfg = PersonLocalizeConfig() + assert cfg.sample_strategy == "skip_frame" + assert cfg.frame_skip == 2 + + def test_schema_accepts_uniform(self): + cfg = PersonLocalizeConfig(sample_strategy="uniform") + assert cfg.sample_strategy == "uniform" + + def test_schema_rejects_invalid_strategy(self): + with pytest.raises(Exception): + PersonLocalizeConfig(sample_strategy="invalid_mode") \ No newline at end of file