Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any

import torch
import torch.nn.functional as F
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
PipelineState,
Expand Down Expand Up @@ -708,6 +709,34 @@ def _encode_with_conditioning(self, components, block_state, current_start):

batch_size, channels, num_frames, height, width = input_frames_data.shape

# Guard against spatial underflow: the WAN VAE encoder contains a 3×3 spatial
# convolution kernel. If either spatial dimension is < 3 the forward pass
# will raise:
# RuntimeError: Calculated padded input size per channel: (2 x 513).
# Kernel size: (3 x 3). Kernel size can't be greater than actual input size
# Observed in prod (2026-03-15) on krea-realtime-video with job
# 5193400c-da0f-4eef-8bdd-dd0fdd26c1db: 2,372 errors over 11 minutes.
# This is the spatial analogue of the 3×1×1 temporal kernel guard (issue #673).
# Pad to the minimum safe size rather than hard-crashing.
# Related: #557 (same block — spatial kernel underflow)
_MIN_SPATIAL = 3 # WAN VAE first-layer 3×3 conv
if height < _MIN_SPATIAL or width < _MIN_SPATIAL:
new_h = max(height, _MIN_SPATIAL)
new_w = max(width, _MIN_SPATIAL)
logger.warning(
f"VaceEncodingBlock._encode_with_conditioning: vace_input_frames spatial "
f"dimensions ({height}×{width}) are below the 3×3 spatial convolution "
f"kernel minimum. Padding input from ({height}×{width}) to ({new_h}×{new_w})."
)
# F.pad takes dims in reverse order: (W_left, W_right, H_top, H_bottom, ...)
input_frames_data = F.pad(
input_frames_data, (0, new_w - width, 0, new_h - height)
)
# Also patch block_state so the downstream resolution check passes
block_state.height = new_h
block_state.width = new_w
height, width = new_h, new_w

# Validate resolution
if height != block_state.height or width != block_state.width:
raise ValueError(
Expand Down Expand Up @@ -767,6 +796,13 @@ def _encode_with_conditioning(self, components, block_state, current_start):
raise ValueError(
f"VaceEncodingBlock._encode_with_conditioning: vace_input_masks must have 1 channel, got {mask_channels}"
)
# Spatially pad mask to match frames if we padded the frames above
if mask_height < height or mask_width < width:
input_masks_data = F.pad(
input_masks_data,
(0, width - mask_width, 0, height - mask_height),
)
mask_height, mask_width = height, width
if (
mask_frames != num_frames
or mask_height != height
Expand Down
Loading