Conversation
Force-pushed from 67233e9 to e281524.
Force-pushed from e281524 to 349d080.
Could we move the assets to a public GCS path or use an existing hf dataset link?
| f"{_frame_summary('mask', mask_video)}" | ||
| ) | ||
|
|
||
| animate_settings = _get_animate_inference_settings(config) |
Could you also add LoRA support?
I wonder if there is a need for a separate generate script? Can we add this to the existing generate_wan.py file?
Animate needs a different LoRA for replacement; I'll add it in the next PR.
For the generate script, the inputs are different: we have a reference image and video plus a replacement image and video, so I think it's better to have a separate one.
Please resolve conflicts and enable support for diagnostics and profiling as in this PR.
… into sagarchapara/wananimate-pipeline
# Conflicts:
#	.gitignore
#	src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py
```diff
 # RL pipelines may produce mp4 outputs
 *.mp4
-!assets/wan_animate/**/*.mp4
+assets/wan_animate/
```
Let's remove the assets/wan_animate folder altogether and remove this entry.
```diff
 max_logging.log(f"Saved video to {video_path}")

-if getattr(config, "enable_profiler", False):
+if max_utils.profiler_enabled(config):
```
Does this support both ml-diagnostics and enable-profiler options?
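For reference, a check covering both options might look like the sketch below. This is only an assumption about the helper's behavior, not the actual max_utils implementation, and the flag name `enable_ml_diagnostics` is hypothetical:

```python
# Sketch only: a config-driven check that honors either option.
# The real max_utils.profiler_enabled in MaxDiffusion may differ,
# and `enable_ml_diagnostics` is a hypothetical flag name.
def profiler_enabled(config) -> bool:
    return bool(
        getattr(config, "enable_profiler", False)
        or getattr(config, "enable_ml_diagnostics", False)
    )
```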
```diff
-reference_image_path: "assets/wan_animate/src_ref.png"
-pose_video_path: "assets/wan_animate/src_pose.mp4"
-face_video_path: "assets/wan_animate/src_face.mp4"
+reference_image_path: ""
```
Can we point to some default image path on Hugging Face?
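For example, the defaults could resolve to files in a Hugging Face dataset repo. A minimal sketch using the real `huggingface_hub` API; the repo id and filenames below are placeholders, not an existing dataset:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id and filenames; swap in a real public dataset repo.
reference_image_path = hf_hub_download(
    repo_id="org/wan-animate-assets", filename="src_ref.png", repo_type="dataset"
)
pose_video_path = hf_hub_download(
    repo_id="org/wan-animate-assets", filename="src_pose.mp4", repo_type="dataset"
)
```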
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
The Pull Request introduces the Wan Animate pipeline, which includes the transformer model architecture, inference entry point, and necessary utilities. The implementation is comprehensive and follows the established patterns in the repository, including support for segment-based inference and parity with Diffusers.
🔍 General Feedback
- Performance Optimization: The current implementation of the transformer re-encodes the face video frames during every denoising step. Since the face video is static throughout the inference process, this encoding can be pre-computed once per segment to significantly reduce redundant computation and speed up generation (see the sketch after this list).
- Compilation Efficiency: The generation script performs two full inference passes. For high-resolution video generation, this double work is expensive. Consider reducing the number of steps in the first (compile) pass.
- Robustness: Added checks for optional inputs in the transformer to prevent potential runtime errors when `face_pixel_values` is not provided.
- Code Quality: The reuse of the Wan attention operator and the integration with the existing configuration system are well done. The use of `nnx.scan` for the transformer blocks ensures memory efficiency during inference.
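A minimal sketch of the pre-computation pattern, using stand-in encoder and denoising functions with toy shapes rather than the actual pipeline API:

```python
import jax
import jax.numpy as jnp


def motion_encoder(frames):
    # Stand-in for the real motion encoder; depends only on the face video.
    return jnp.mean(frames, axis=(1, 2, 3))


def denoise_step(latents, t, motion_vec):
    # Stand-in for one denoising step that consumes the cached encoding.
    return latents - 0.01 * t * jnp.mean(motion_vec)


face_pixel_values = jnp.ones((8, 3, 64, 64))  # (T, C, H, W), toy data
latents = jnp.zeros((1, 4, 8, 8))

# Encode once per segment: the face video never changes across steps.
motion_vec = motion_encoder(face_pixel_values)


def step(carry, t):
    return denoise_step(carry, t, motion_vec), None


latents, _ = jax.lax.scan(step, latents, jnp.arange(50.0))
```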
```python
    num_inference_steps=config.num_inference_steps,
    mode=mode,
)
```
🟡 The script performs two full inference passes. While the first pass is intended for compilation timing, it could be optimized by using a smaller number of steps (e.g., 1 or 2) or a dummy input to reduce total execution time, especially for high-resolution generations where inference is expensive.
```suggestion
# First pass: compile with a single denoising step to save time.
_ = pipeline(
    image=image,
    pose_video=pose_video,
    face_video=face_video,
    background_video=background_video,
    mask_video=mask_video,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    segment_frame_length=animate_settings["segment_frame_length"],
    prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"],
    motion_encode_batch_size=animate_settings["motion_encode_batch_size"],
    guidance_scale=animate_settings["guidance_scale"],
    num_inference_steps=1,
    mode=mode,
)
```
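A note on this suggestion: assuming the pipeline's denoising loop reuses a single jitted step function, one step is enough to trigger compilation. If the step count is instead baked into the compiled graph (for example as a `lax.scan` length), a 1-step timing pass would compile a different program than the real run, so the saving should be verified.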
```python
query = self.to_q(hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
```
🟠 Performance: The face video encoding and motion encoding are performed during every denoising step. Since they depend only on the static face video and are independent of the noisy latents and timestep, they should be pre-computed outside the denoising loop to significantly improve performance.
```suggestion
value = self.to_v(encoder_hidden_states)

# 4. Batched face & motion encoding (pre-compute this outside the denoising loop if possible).
if face_pixel_values is not None:
    _, face_channels, num_face_frames, face_height, face_width = face_pixel_values.shape
    # Rearrange from (B, C, T, H, W) to (B*T, C, H, W).
    face_pixel_values_reshaped = jnp.transpose(face_pixel_values, (0, 2, 1, 3, 4))
    face_pixel_values_reshaped = jnp.reshape(
        face_pixel_values_reshaped, (-1, face_channels, face_height, face_width)
    )
    total_face_frames = face_pixel_values_reshaped.shape[0]
    motion_encode_batch_size = motion_encode_batch_size or self.motion_encoder_batch_size
    # Pad so the frame count divides evenly by the encode batch size.
    pad_len = (motion_encode_batch_size - (total_face_frames % motion_encode_batch_size)) % motion_encode_batch_size
    if pad_len > 0:
        pad_tensor = jnp.zeros(
            (pad_len, face_channels, face_height, face_width),
            dtype=face_pixel_values_reshaped.dtype,
        )
        face_pixel_values_reshaped = jnp.concatenate([face_pixel_values_reshaped, pad_tensor], axis=0)
    # Reshape into fixed-size chunks for scan.
    num_chunks = face_pixel_values_reshaped.shape[0] // motion_encode_batch_size
    face_chunks = jnp.reshape(
        face_pixel_values_reshaped,
        (num_chunks, motion_encode_batch_size, face_channels, face_height, face_width),
    )

    # Use jax.lax.scan to iterate over chunks and bound peak memory.
    def encode_chunk_fn(carry, chunk):
        encoded_chunk = self.motion_encoder(chunk)
        return carry, encoded_chunk

    _, motion_vec_chunks = jax.lax.scan(encode_chunk_fn, None, face_chunks)
    motion_vec = jnp.reshape(motion_vec_chunks, (-1, motion_vec_chunks.shape[-1]))
    # Remove the padding if it was added.
    if pad_len > 0:
        motion_vec = motion_vec[:-pad_len]
    motion_vec = jnp.reshape(motion_vec, (batch_size, num_face_frames, -1))
    # Apply the face encoder and prepend a zero vector for the first frame.
    motion_vec = self.face_encoder(motion_vec)
    pad_face = jnp.zeros_like(motion_vec[:, :1])
    motion_vec = jnp.concatenate([pad_face, motion_vec], axis=1)
else:
    motion_vec = None
```
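As a quick sanity check on the padding arithmetic used above (a toy example, not part of the suggestion):

```python
# With 10 frames and an encode batch size of 4, pad by 2 so 12 frames
# split into 3 full chunks, then drop the 2 padded rows after encoding.
total_face_frames, motion_encode_batch_size = 10, 4
pad_len = (motion_encode_batch_size - (total_face_frames % motion_encode_batch_size)) % motion_encode_batch_size
assert pad_len == 2
assert (total_face_frames + pad_len) % motion_encode_batch_size == 0
```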
Wan Animate Pipeline
This CL adds the Wan Animate pipeline.
Links
Performance
Configuration