Skip to content

[MAX] Add WanTokenizer and WanContext for Wan architecture#22

Open
jglee-sqbits wants to merge 1 commit into
mainfrom
add/wan/tokenizer
Open

[MAX] Add WanTokenizer and WanContext for Wan architecture#22
jglee-sqbits wants to merge 1 commit into
mainfrom
add/wan/tokenizer

Conversation

@jglee-sqbits
Copy link
Copy Markdown
Collaborator

@jglee-sqbits jglee-sqbits commented Apr 9, 2026

Summary

  • Add WanTokenizer and WanContext for Wan video generation architecture
  • Follow the architecture-specific tokenizer pattern (like Qwen) instead of adding model-specific branches to shared code

Changes

New files:

  • wan/context.pyWanContext(PixelContext) with num_frames, guidance_scale_2, step_coefficients, boundary_timestep
  • wan/tokenizer.pyWanTokenizer(PixelGenerationTokenizer) overriding new_context() for flow shift, MoE boundary, 5D latent reshape

Notes

Test plan

  • Verified with simple_offline_video_generation.py — 480p 17-frame video generated successfully on H200

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Wan-specific architecture components, including WanContext and WanTokenizer, to support video generation and MoE features. Feedback highlights several areas for improvement: WanContext contains redundant fields already defined in its base class, and WanTokenizer includes duplicated logic for flow shift and latent calculations that should be consolidated. Additionally, the temporary patching of _class_name during tokenizer initialization is flagged as potentially unnecessary complexity.

Comment on lines +29 to +39
num_frames: int | None = field(default=None)
"""Number of frames for video generation."""

guidance_scale_2: float | None = field(default=None)
"""Secondary guidance scale for low-noise expert (MoE models)."""

step_coefficients: npt.NDArray[np.float32] | None = field(default=None)
"""Pre-computed scheduler step coefficients."""

boundary_timestep: float | None = field(default=None)
"""Timestep threshold for switching between high/low noise experts."""
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The fields num_frames, guidance_scale_2, step_coefficients, and boundary_timestep are already defined in the base class PixelContext. Redefining them here is redundant and can lead to maintenance issues. If the intention is to move these fields out of the shared PixelContext, they should be removed from the base class in this PR. Otherwise, they should be removed from WanContext to avoid shadowing.

Comment on lines +48 to +56
diffusers_config = pipeline_config.model.diffusers_config
original_class_name = diffusers_config.get("_class_name")
diffusers_config["_class_name"] = "FluxPipeline"
try:
super().__init__(
model_path, pipeline_config, subfolder, **kwargs
)
finally:
diffusers_config["_class_name"] = original_class_name
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This temporary patching of _class_name to "FluxPipeline" appears unnecessary as WanPipeline and WanImageToVideoPipeline are still present in the PipelineClassName enum in pixel_tokenizer.py. If the goal is to remove them from the enum, that change should be included in this PR. Otherwise, this logic should be removed to avoid unnecessary complexity.

Comment on lines +92 to +144
base = await super().new_context(request, input_image=input_image)

video_options = request.body.provider_options.video
image_options = request.body.provider_options.image

num_frames: int | None = (
video_options.num_frames if video_options else None
)
guidance_scale_2: float | None = (
video_options.guidance_scale_2 if video_options else None
)

height = base.height
width = base.width
timesteps: npt.NDArray[np.float32] = base.timesteps
sigmas: npt.NDArray[np.float32] = base.sigmas

if getattr(self._scheduler, "use_flow_sigmas", False):
self._scheduler.flow_shift = self._select_wan_flow_shift(
height, width
)
latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
timesteps, sigmas = self._scheduler.retrieve_timesteps_and_sigmas(
image_seq_len, base.num_inference_steps
)

boundary_timestep: float | None = None
boundary_ratio = self.diffusers_config.get("boundary_ratio")
if boundary_ratio is not None:
boundary_timestep = float(boundary_ratio) * float(
getattr(self._scheduler, "num_train_timesteps", 1000)
)

step_coefficients: npt.NDArray[np.float32] | None = None
if hasattr(self._scheduler, "build_step_coefficients"):
step_coefficients = self._scheduler.build_step_coefficients()

latents = base.latents
if num_frames is not None:
vae_scale_factor_temporal = 4
latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
shape_5d = (
image_options.num_images,
self._num_channels_latents,
latent_frames,
latent_height,
latent_width,
)
latents = self._randn_tensor(shape_5d, request.body.seed)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic in new_context for calculating flow shift, sigmas, and reshaping latents is currently duplicated from the base class PixelGenerationTokenizer.new_context. Since super().new_context() is called at the beginning, it already populates base with these values. To properly extract the logic as intended, the Wan-specific branches should be removed from the base class. Additionally, latent_height and latent_width are calculated twice within this method (lines 113-114 and 135-136); they should be calculated once and reused.

@jglee-sqbits jglee-sqbits changed the base branch from add/wan/pipeline-t2v to main April 9, 2026 06:01
@jglee-sqbits jglee-sqbits force-pushed the add/wan/tokenizer branch 4 times, most recently from a44a04b to 496bb7d Compare April 9, 2026 09:09
Signed-off-by: jglee-sqbits <jingu.lee@squeezebits.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant