[MAX] Add WanTokenizer and WanContext for Wan architecture#22
jglee-sqbits wants to merge 1 commit into
Code Review
This pull request introduces Wan-specific architecture components, including WanContext and WanTokenizer, to support video generation and MoE features. Feedback highlights several areas for improvement: WanContext contains redundant fields already defined in its base class, and WanTokenizer includes duplicated logic for flow shift and latent calculations that should be consolidated. Additionally, the temporary patching of _class_name during tokenizer initialization is flagged as potentially unnecessary complexity.
    num_frames: int | None = field(default=None)
    """Number of frames for video generation."""

    guidance_scale_2: float | None = field(default=None)
    """Secondary guidance scale for low-noise expert (MoE models)."""

    step_coefficients: npt.NDArray[np.float32] | None = field(default=None)
    """Pre-computed scheduler step coefficients."""

    boundary_timestep: float | None = field(default=None)
    """Timestep threshold for switching between high/low noise experts."""
The fields num_frames, guidance_scale_2, step_coefficients, and boundary_timestep are already defined in the base class PixelContext. Redefining them here is redundant and can lead to maintenance issues. If the intention is to move these fields out of the shared PixelContext, they should be removed from the base class in this PR. Otherwise, they should be removed from WanContext to avoid shadowing.
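To illustrate the shadowing concern, here is a minimal sketch using stand-in dataclasses. `BaseContext`, `RedundantContext`, `LeanContext`, and `field_names` are hypothetical names, not the real PixelContext or WanContext, and only two of the four fields are shown.

```python
from dataclasses import dataclass, field, fields
from typing import List, Optional


@dataclass
class BaseContext:
    """Hypothetical stand-in for the shared base context."""

    num_frames: Optional[int] = field(default=None)
    guidance_scale_2: Optional[float] = field(default=None)


@dataclass
class RedundantContext(BaseContext):
    """Redeclares fields the base already provides (the flagged pattern)."""

    num_frames: Optional[int] = field(default=None)
    guidance_scale_2: Optional[float] = field(default=None)


@dataclass
class LeanContext(BaseContext):
    """Inherits the same fields without redeclaring them."""


def field_names(cls) -> List[str]:
    # dataclasses.fields() reports inherited and redeclared fields alike.
    return [f.name for f in fields(cls)]
```

Both subclasses end up with an identical field list, so the redeclaration adds nothing except a second place where defaults and docstrings have to be kept in sync.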
    diffusers_config = pipeline_config.model.diffusers_config
    original_class_name = diffusers_config.get("_class_name")
    diffusers_config["_class_name"] = "FluxPipeline"
    try:
        super().__init__(
            model_path, pipeline_config, subfolder, **kwargs
        )
    finally:
        diffusers_config["_class_name"] = original_class_name
This temporary patching of _class_name to "FluxPipeline" appears unnecessary as WanPipeline and WanImageToVideoPipeline are still present in the PipelineClassName enum in pixel_tokenizer.py. If the goal is to remove them from the enum, that change should be included in this PR. Otherwise, this logic should be removed to avoid unnecessary complexity.
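If the override does turn out to be needed, one detail of the restore step may be worth a look: when `_class_name` is absent, `.get()` returns `None`, and the `finally` block then writes the key back as `None` instead of leaving it absent. A minimal sketch of a scoped override that handles this case, assuming the config behaves like a plain `dict`; `temporary_key` is a hypothetical helper, not part of the codebase.

```python
from contextlib import contextmanager

# Sentinel that distinguishes "key absent" from "value is None".
_MISSING = object()


@contextmanager
def temporary_key(config: dict, key: str, value):
    """Set config[key] = value inside the with-block, then restore it."""
    original = config.get(key, _MISSING)
    config[key] = value
    try:
        yield config
    finally:
        if original is _MISSING:
            # The key did not exist before, so remove it rather than store None.
            config.pop(key, None)
        else:
            config[key] = original
```

With a helper like this, the base-class call would sit inside `with temporary_key(diffusers_config, "_class_name", "FluxPipeline"):`, keeping the patch and its cleanup in one place.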
    base = await super().new_context(request, input_image=input_image)

    video_options = request.body.provider_options.video
    image_options = request.body.provider_options.image

    num_frames: int | None = (
        video_options.num_frames if video_options else None
    )
    guidance_scale_2: float | None = (
        video_options.guidance_scale_2 if video_options else None
    )

    height = base.height
    width = base.width
    timesteps: npt.NDArray[np.float32] = base.timesteps
    sigmas: npt.NDArray[np.float32] = base.sigmas

    if getattr(self._scheduler, "use_flow_sigmas", False):
        self._scheduler.flow_shift = self._select_wan_flow_shift(
            height, width
        )
        latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
        latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
        image_seq_len = (latent_height // 2) * (latent_width // 2)
        timesteps, sigmas = self._scheduler.retrieve_timesteps_and_sigmas(
            image_seq_len, base.num_inference_steps
        )

    boundary_timestep: float | None = None
    boundary_ratio = self.diffusers_config.get("boundary_ratio")
    if boundary_ratio is not None:
        boundary_timestep = float(boundary_ratio) * float(
            getattr(self._scheduler, "num_train_timesteps", 1000)
        )

    step_coefficients: npt.NDArray[np.float32] | None = None
    if hasattr(self._scheduler, "build_step_coefficients"):
        step_coefficients = self._scheduler.build_step_coefficients()

    latents = base.latents
    if num_frames is not None:
        vae_scale_factor_temporal = 4
        latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
        latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
        latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
        shape_5d = (
            image_options.num_images,
            self._num_channels_latents,
            latent_frames,
            latent_height,
            latent_width,
        )
        latents = self._randn_tensor(shape_5d, request.body.seed)
The logic in new_context for calculating flow shift, sigmas, and reshaping latents is currently duplicated from the base class PixelGenerationTokenizer.new_context. Since super().new_context() is called at the beginning, it already populates base with these values. To properly extract the logic as intended, the Wan-specific branches should be removed from the base class. Additionally, latent_height and latent_width are calculated twice within this method (lines 113-114 and 135-136); they should be calculated once and reused.
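One way to compute the latent dimensions once and reuse them is to factor the arithmetic into a small pure function. The sketch below copies the formulas from the snippet above; `LatentDims` and `wan_latent_dims` are hypothetical names, and the `vae_scale_factor=8` and 480x832 values in the usage note are illustrative assumptions, not values taken from this PR.

```python
from typing import NamedTuple, Optional


class LatentDims(NamedTuple):
    height: int
    width: int
    frames: Optional[int]
    image_seq_len: int


def wan_latent_dims(
    height: int,
    width: int,
    vae_scale_factor: int,
    num_frames: Optional[int] = None,
    vae_scale_factor_temporal: int = 4,
) -> LatentDims:
    """Compute latent height/width once so both call sites can share them."""
    # Same rounding as the reviewed code: floor to a multiple of 2 latent rows/cols.
    latent_height = 2 * (int(height) // (vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (vae_scale_factor * 2))

    # Temporal compression only applies when a frame count is given.
    latent_frames = None
    if num_frames is not None:
        latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1

    # Sequence length as used for the scheduler's timestep/sigma lookup.
    image_seq_len = (latent_height // 2) * (latent_width // 2)
    return LatentDims(latent_height, latent_width, latent_frames, image_seq_len)
```

For example, `wan_latent_dims(480, 832, 8, num_frames=17)` gives a 60 x 104 latent grid with 5 latent frames and a sequence length of 1560. The flow-shift branch and the 5D shape construction could then both read from the same result.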
af48815 to 9f5b0c3
a44a04b to 496bb7d
Signed-off-by: jglee-sqbits <jingu.lee@squeezebits.com>
496bb7d to d48a369
Summary
WanTokenizer and WanContext for Wan video generation architecture

Changes
New files:
wan/context.py: WanContext(PixelContext) with num_frames, guidance_scale_2, step_coefficients, boundary_timestep
wan/tokenizer.py: WanTokenizer(PixelGenerationTokenizer) overriding new_context() for flow shift, MoE boundary, 5D latent reshape

Notes
arch.py and shared code cleanup will be done in the "[MAX] Add Wan T2V diffusion pipeline with MoE support" (modular/modular#6302) update.

Test plan
simple_offline_video_generation.py: 480p 17-frame video generated successfully on H200