From 252c089c60a32ab99c9145e55f81c936c29868e3 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 29 May 2025 12:34:37 +0400 Subject: [PATCH 1/2] add wan2.1 notebook --- .ci/spellcheck/.pyspelling.wordlist.txt | 2 + notebooks/wan2.1-text-to-video/README.md | 35 + .../wan2.1-text-to-video/gradio_helper.py | 59 ++ .../wan2.1-text-to-video/ov_wan_helper.py | 552 +++++++++++++ .../wan2.1-text-to-video.ipynb | 759 ++++++++++++++++++ 5 files changed, 1407 insertions(+) create mode 100644 notebooks/wan2.1-text-to-video/README.md create mode 100644 notebooks/wan2.1-text-to-video/gradio_helper.py create mode 100644 notebooks/wan2.1-text-to-video/ov_wan_helper.py create mode 100644 notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb diff --git a/.ci/spellcheck/.pyspelling.wordlist.txt b/.ci/spellcheck/.pyspelling.wordlist.txt index 319af808147..b113b0b66ae 100644 --- a/.ci/spellcheck/.pyspelling.wordlist.txt +++ b/.ci/spellcheck/.pyspelling.wordlist.txt @@ -84,6 +84,7 @@ blockwise BLACKBOX boolean CatVTON +CausVid CentOS centric CFG @@ -941,6 +942,7 @@ sparsity Sparisty sparsified sparsify +spatio spatiotemporal spectrogram spectrograms diff --git a/notebooks/wan2.1-text-to-video/README.md b/notebooks/wan2.1-text-to-video/README.md new file mode 100644 index 00000000000..4fac72338f3 --- /dev/null +++ b/notebooks/wan2.1-text-to-video/README.md @@ -0,0 +1,35 @@ +# Text to Video generation with Wan2.1 and OpenVINO + + Wan2.1 is a comprehensive and open suite of video foundation models that pushes the boundaries of video generation. + + Built upon the mainstream diffusion transformer paradigm, Wan 2.1 achieves significant advancements in generative capabilities through a series of innovations, including our novel spatio-temporal variational autoencoder (VAE), scalable pre-training strategies, large-scale data construction, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility. + + You can find more details about model in [model card](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) and [original repository](https://github.com/Wan-Video/Wan2.1) + + In this tutorial we consider how to convert, optimize and run Wan2.1 model using OpenVINO. + Additionally, for achieving inference speedup, we will apply [CausVid](https://causvid.github.io/) distillation approach using LoRA. + + ![](https://causvid.github.io/images/methods.jpg) + + Current video diffusion models achieve impressive generation quality but struggle in interactive applications due to bidirectional attention dependencies. The generation of a single frame requires the model to process the entire sequence, including the future. CausVid address this limitation by adapting a pretrained bidirectional diffusion transformer to an autoregressive transformer that generates frames on-the-fly. To further reduce latency, the authors extend distribution matching distillation (DMD) to videos, distilling 50-step diffusion model into a 4-step generator. + + The method distills a many-step, bidirectional video diffusion model into a 4-step, causal generator. The training process consists of two stages: + 1. Student Initialization: Initialization of the causal student by pretraining it on a small set of ODE solution pairs generated by the bidirectional teacher. This step helps stabilize the subsequent distillation training. + 2. Asymmetric Distillation: Using the bidirectional teacher model, we train the causal student generator through a distribution matching distillation loss. 
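
In this notebook the distilled CausVid weights are applied to the base Wan2.1 pipeline as a LoRA adapter before the model is converted to OpenVINO IR. The sketch below only illustrates that step; it mirrors the logic in `ov_wan_helper.py`, and the adapter repository, file name and 0.95 adapter weight are taken from that helper rather than being an official recipe.

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from huggingface_hub import hf_hub_download

# Load the original Wan2.1 1.3B text-to-video pipeline (VAE kept in FP32, as in ov_wan_helper.py)
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float16)

# Download the distilled CausVid LoRA and fuse it into the transformer weights
lora_path = hf_hub_download(
    repo_id="Kijai/WanVideo_comfy",
    filename="Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors",
)
pipe.load_lora_weights(lora_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
pipe.fuse_lora()  # the fused pipeline is what gets converted to OpenVINO IR
```

After fusion, the notebook runs generation with `num_inference_steps=4` and `guidance_scale=1.0` instead of the default 50-step schedule.
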
+ +More details about CausVid can be found in the [paper](https://arxiv.org/abs/2412.07772), [original repository](https://github.com/tianweiy/CausVi) and [project page](https://causvid.github.io/) + + +## Notebook contents +This tutorial consists of the following steps: +- Prerequisites +- Convert and Optimize model +- Run inference pipeline +- Interactive inference + +## Installation instructions +This is a self-contained example that relies solely on its own code.
+We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. +For details, please refer to [Installation Guide](../../README.md). + + diff --git a/notebooks/wan2.1-text-to-video/gradio_helper.py b/notebooks/wan2.1-text-to-video/gradio_helper.py new file mode 100644 index 00000000000..595ec2c3733 --- /dev/null +++ b/notebooks/wan2.1-text-to-video/gradio_helper.py @@ -0,0 +1,59 @@ +import gradio as gr +import torch +from diffusers.utils import export_to_video +import numpy as np + +MAX_SEED = np.iinfo(np.int32).max + + +def make_demo(pipeline): + def generate_video(prompt, negative_prompt="", guidance_scale=1.0, seed=42, progress=gr.Progress(track_tqdm=True)): + output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=480, + width=832, + num_frames=20, + guidance_scale=guidance_scale, + num_inference_steps=4, + generator=torch.Generator().manual_seed(seed), + ).frames[0] + + video_path = "output.mp4" + export_to_video(output, video_path, fps=10) + return video_path + + iface = gr.Interface( + fn=generate_video, + inputs=[ + gr.Textbox(label="Prompt", placeholder="Enter your video prompt here"), + gr.Textbox(label="Negative Prompt", placeholder="Optional negative prompt", value=""), + gr.Slider( + label="Guidance scale", + minimum=0.0, + maximum=20.0, + step=0.1, + value=1.0, + ), + gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=42, + ), + ], + outputs=gr.Video(label="Generated Video"), + title="Wan2.1-T2V-1.3B OpenVINO Video Generator", + flagging_mode="never", + examples=[ + ["a penguin playfully dancing in the snow, Antarctica", "", 1.0, 42], + [ + "A cat walks on the grass, realistic", + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", + 2.5, + 678, + ], + ], + ) + return iface diff --git a/notebooks/wan2.1-text-to-video/ov_wan_helper.py b/notebooks/wan2.1-text-to-video/ov_wan_helper.py new file mode 100644 index 00000000000..59fe7e617dc --- /dev/null +++ b/notebooks/wan2.1-text-to-video/ov_wan_helper.py @@ -0,0 +1,552 @@ +from pathlib import Path +import gc + +import torch +from diffusers import AutoencoderKLWan, WanPipeline, DiffusionPipeline, UniPCMultistepScheduler +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer + +import nncf +import openvino as ov + +from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder +from openvino.frontend.pytorch.patch_model import __make_16bit_traceable +from dataclasses import dataclass +from typing import Optional, Union +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +import ftfy +import regex as re +import html +from huggingface_hub import hf_hub_download + + +def cleanup_torchscript_cache(): + """ + Helper for removing cached model representation + """ + torch._C._jit_clear_class_registry() + torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + torch.jit._state._clear_class_state() + + +TEXT_ENCODER_PATH = "text_encoder.xml" +VAE_DECODER_PATH = "vae_decoder.xml" +TRANSFORMER_PATH = "transformer.xml" + +LORA_REPO_ID = "Kijai/WanVideo_comfy" +lora_filE_name = 
"Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors" + + +def convert_pipeline(model_id, output_dir, apply_lora=True, compression_config=None): + output_dir = Path(output_dir) + + if all([(output_dir / model_path).exists() for model_path in [TEXT_ENCODER_PATH, VAE_DECODER_PATH, TRANSFORMER_PATH]]): + print(f"✅ {model_id} model already converted. You can find results in {output_dir}") + return + + print(f"⌛ {model_id} conversion started. Be patient, it may takes some time.") + print("⌛ Load Original model") + vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.float16) + if apply_lora: + causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=lora_filE_name) + pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora") + pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95]) + pipe.fuse_lora() + + transformer = pipe.transformer + transformer.eval() + vae = pipe.vae + vae.eval() + text_encoder = pipe.text_encoder + text_encoder.eval() + tokenizer = pipe.tokenizer + scheduler = pipe.scheduler + tokenizer.save_pretrained(output_dir / "tokenizer") + scheduler.save_pretrained(output_dir / "scheduler") + del pipe + gc.collect() + + if not (output_dir / TRANSFORMER_PATH).exists(): + print("⌛ Convert Transformer model") + transformer_inputs = { + "hidden_states": torch.ones([1, 16, 5, 60, 104]), + "timestep": torch.tensor([1000.0]), + "encoder_hidden_states": torch.ones([1, 512, 4096]), + } + transformer.eval() + __make_16bit_traceable(transformer) + ts_decoder = TorchScriptPythonDecoder(transformer, example_input=transformer_inputs, trace_kwargs={"check_trace": False}) + with torch.no_grad(): + ov_model = ov.convert_model(ts_decoder, example_input=transformer_inputs) + if compression_config is not None: + ov_model = nncf.compress_weights(ov_model, **compression_config) + ov.save_model(ov_model, output_dir / TRANSFORMER_PATH) + del ov_model + cleanup_torchscript_cache() + print("✅ Transformer model successfully converted") + + del transformer + gc.collect() + if not (output_dir / TEXT_ENCODER_PATH).exists(): + print("⌛ Convert Text Encoder model") + __make_16bit_traceable(text_encoder) + with torch.no_grad(): + ov_model = ov.convert_model(text_encoder, example_input=torch.ones((1, 226), dtype=torch.long)) + if compression_config is not None: + ov_model = nncf.compress_weights(ov_model, **compression_config) + ov.save_model(ov_model, output_dir / TEXT_ENCODER_PATH) + del ov_model + cleanup_torchscript_cache() + print(f"✅ Text Encoder successfully converted") + del text_encoder + gc.collect() + + if not (output_dir / VAE_DECODER_PATH).exists(): + print("⌛ Convert VAE Decoder model") + vae.forward = vae.decode + __make_16bit_traceable(vae) + for up_block in vae.decoder.up_blocks: + if up_block.upsamplers is not None: + up_block.upsamplers[0].resample[0].mode = "nearest" + with torch.no_grad(): + ov_model = ov.convert_model(vae, example_input=(torch.ones((1, 16, 5, 60, 104)))) + if compression_config is not None: + ov_model = nncf.compress_weights(ov_model, **compression_config) + ov.save_model(ov_model, output_dir / VAE_DECODER_PATH) + cleanup_torchscript_cache() + print(f"✅ VAE Decoder successfully converted") + del vae + gc.collect() + print(f"✅ Model successfully converted and can be found in {output_dir}") + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for Wan pipelines. 
+ + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +core = ov.Core() + + +class OVWanPipeline(DiffusionPipeline): + def __init__(self, model_dir, device_map="CPU", ov_config=None): + model_dir = Path(model_dir) + tokenizer = AutoTokenizer.from_pretrained(model_dir / "tokenizer") + scheduler = UniPCMultistepScheduler.from_pretrained(model_dir / "scheduler") + if isinstance(device_map, str): + device_map = {"transformer": device_map, "text_encoder": device_map, "vae": device_map} + transformer_model = core.read_model(model_dir / TRANSFORMER_PATH) + transformer = core.compile_model(transformer_model, device_map["transformer"], ov_config) + text_encoder_model = core.read_model(model_dir / TEXT_ENCODER_PATH) + text_encoder = core.compile_model(text_encoder_model, device_map["text_encoder"], ov_config) + vae = core.compile_model(model_dir / VAE_DECODER_PATH, device_map["vae"], ov_config) + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 4 + self.vae_scale_factor_spatial = 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.z_dim = 16 + self.latents_mean = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + self.latents_std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916] + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, list[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = torch.from_numpy(self.text_encoder(text_input_ids)[0]) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, list[str]], + negative_prompt: Optional[Union[str, list[str]]] = None, + 
do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}.") + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. 
Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.") + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and (not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=torch.float32) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, list[str]] = None, + negative_prompt: Union[str, list[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = None + self._current_timestep = None + self._interrupt = False + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = 16 + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + generator, + latents, + ) + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents + timestep = t.expand(latents.shape[0]) + + noise_pred = torch.from_numpy(self.transformer([latent_model_input, timestep, prompt_embeds])[0]) + + if self.do_classifier_free_guidance: + noise_uncond = torch.from_numpy(self.transformer([latent_model_input, timestep, negative_prompt_embeds])[0]) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + video = torch.from_numpy(self.vae(latents)[0]) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb b/notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb new file mode 100644 index 00000000000..79adb1ab507 --- /dev/null +++ b/notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb @@ -0,0 +1,759 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Text to Video generation with Wan2.1 and OpenVINO\n", + "\n", + " Wan2.1 is a comprehensive and open suite of video foundation models that pushes the boundaries of video generation.\n", + "\n", + " Built upon the mainstream diffusion transformer paradigm, Wan 2.1 achieves significant advancements in generative capabilities through a series of innovations, including our novel spatio-temporal variational autoencoder (VAE), scalable pre-training strategies, large-scale data construction, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility.\n", + "\n", + " You can find more details about model in [model card](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) and [original repository](https://github.com/Wan-Video/Wan2.1)\n", + "\n", + " In this tutorial we consider how to convert, optimize and run Wan2.1 model using OpenVINO.\n", + " Additionally, for achieving inference speedup, we will apply [CausVid](https://causvid.github.io/) distillation approach using LoRA.\n", + "\n", + " ![](https://causvid.github.io/images/methods.jpg)\n", + "\n", + " Current video diffusion models achieve impressive generation quality but struggle in interactive applications due to bidirectional attention dependencies. The generation of a single frame requires the model to process the entire sequence, including the future. CausVid address this limitation by adapting a pretrained bidirectional diffusion transformer to an autoregressive transformer that generates frames on-the-fly. 
To further reduce latency, the authors extend distribution matching distillation (DMD) to videos, distilling 50-step diffusion model into a 4-step generator.\n", + "\n", + " The method distills a many-step, bidirectional video diffusion model into a 4-step, causal generator. The training process consists of two stages: \n", + " 1. Student Initialization: Initialization of the causal student by pretraining it on a small set of ODE solution pairs generated by the bidirectional teacher. This step helps stabilize the subsequent distillation training.\n", + " 2. Asymmetric Distillation: Using the bidirectional teacher model, we train the causal student generator through a distribution matching distillation loss. \n", + "\n", + "More details about CausVid can be found in the [paper](https://arxiv.org/abs/2412.07772), [original repository](https://github.com/tianweiy/CausVi) and [project page](https://causvid.github.io/)\n", + "#### Table of contents:\n", + "\n", + "- [Prerequisites](#Prerequisites)\n", + "- [Convert model to OpenVINO Intermediate Representation](#Convert-model-to-OpenVINO-Intermediate-Representation)\n", + " - [Compress model weights](#Compress-model-weights)\n", + "- [Prepare model inference pipeline](#Prepare-model-inference-pipeline)\n", + " - [Select inference device](#Select-inference-device)\n", + "- [Run OpenVINO Model Inference](#Run-OpenVINO-Model-Inference)\n", + "- [Interactive demo](#Interactive-demo)\n", + "\n", + "\n", + "### Installation Instructions\n", + "\n", + "This is a self-contained example that relies solely on its own code.\n", + "\n", + "We recommend running the notebook in a virtual environment. You only need a Jupyter server to start.\n", + "For details, please refer to [Installation Guide](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/README.md#-installation-guide).\n", + "\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -q \"torch>=2.1\" \"git+https://github.com/huggingface/diffusers.git\" \"transformers>=4.49.0\" \"accelerate\" \"safetensors\" \"sentencepiece\" \"peft>=0.7.0\" \"ftfy\" \"gradio>=4.19\" \"opencv-python\" --extra-index-url https://download.pytorch.org/whl/cpu\n", + "%pip install --pre -U \"openvino>=2025.1.0\" \"nncf>=2.16.0\" --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "from pathlib import Path\n", + "\n", + "if not Path(\"ov_wan_helper.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/wan2.1-text-to-video/ov_wan_helper.py\")\n", + " open(\"ov_wan_helper.py\", \"w\").write(r.text)\n", + "\n", + "if not Path(\"gradio_helper.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/wan2.1-text-to-video/gradio_helper.py\")\n", + " open(\"gradio_helper.py\", \"w\").write(r.text)\n", + "\n", + "if not Path(\"notebook_utils.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py\")\n", + " open(\"notebook_utils.py\", \"w\").write(r.text)" + ] + }, + { + "attachments": 
{}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert model to OpenVINO Intermediate Representation\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "\n", + "Wan2.1 is PyTorch model. OpenVINO supports PyTorch models via conversion to OpenVINO Intermediate Representation (IR). [OpenVINO model conversion API](https://docs.openvino.ai/2024/openvino-workflow/model-preparation.html#convert-a-model-with-python-convert-model) should be used for these purposes. `ov.convert_model` function accepts original PyTorch model instance and example input for tracing and returns `ov.Model` representing this model in OpenVINO framework. Converted model can be used for saving on disk using `ov.save_model` function or directly loading on device using `core.complie_model`.\n", + "\n", + "Model consist of 3 parts:\n", + "* **Text Encoder** to encode input multi-language text, incorporating cross-attention within each transformer block to embed the text into the model structure\n", + "* **Diffusion Transformer** for step by step denoising of generated video guided by text instructions.\n", + "* **VAE Decoder** to decode generated video represented in latent space.\n", + "\n", + "Model performs text-to-video generation task. For preserving original model flexibility, we will convert each part separately.\n", + "\n", + "The script `ov_wan_helper.py` contains helper function for model conversion, please check its content if you interested in conversion details.\n", + "\n", + "### Compress model weights\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "For reducing memory consumption, weights compression optimization can be applied using [NNCF](https://github.com/openvinotoolkit/nncf). \n", + "\n", + "
\n", + " Click here for more details about weight compression\n", + "Weight compression aims to reduce the memory footprint of a model. It can also lead to significant performance improvement for large memory-bound models, such as Large Language Models (LLMs). LLMs and other models, which require extensive memory to store the weights during inference, can benefit from weight compression in the following ways:\n", + "\n", + "* enabling the inference of exceptionally large models that cannot be accommodated in the memory of the device;\n", + "\n", + "* improving the inference performance of the models by reducing the latency of the memory access when computing the operations with weights, for example, Linear layers.\n", + "\n", + "[Neural Network Compression Framework (NNCF)](https://github.com/openvinotoolkit/nncf) provides 4-bit / 8-bit mixed weight quantization as a compression method primarily designed to optimize LLMs. The main difference between weights compression and full model quantization (post-training quantization) is that activations remain floating-point in the case of weights compression which leads to a better accuracy. Weight compression for LLMs provides a solid inference performance improvement which is on par with the performance of the full model quantization. In addition, weight compression is data-free and does not require a calibration dataset, making it easy to use.\n", + "\n", + "`nncf.compress_weights` function can be used for performing weights compression. The function accepts an OpenVINO model and other compression parameters. Compared to INT8 compression, INT4 compression improves performance even more, but introduces a minor drop in prediction quality.\n", + "\n", + "More details about weights compression, can be found in [OpenVINO documentation](https://docs.openvino.ai/2024/openvino-workflow/model-optimization-guide/weight-compression.html).\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c0fbf714670947e48155681658ab588f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Dropdown(description='Model format:', index=2, options=('FP16', 'INT8', 'INT4'), value='INT4')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import ipywidgets as widgets\n", + "\n", + "# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry\n", + "from notebook_utils import collect_telemetry\n", + "\n", + "collect_telemetry(\"wan2.1-text-to-video.ipynb\")\n", + "\n", + "model_id = \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\"\n", + "model_base_dir = Path(model_id.split(\"/\")[-1])\n", + "\n", + "model_format = widgets.Dropdown(\n", + " options=[\"FP16\", \"INT8\", \"INT4\"],\n", + " value=\"INT4\",\n", + " description=\"Model format:\",\n", + ")\n", + "\n", + "model_format" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import nncf\n", + "\n", + "model_dir = model_base_dir / model_format.value\n", + "\n", + "if model_format.value == \"INT4\":\n", + " weights_compression_config = {\"mode\": nncf.CompressWeightsMode.INT4_ASYM, \"group_size\": 64, \"ratio\": 1.0}\n", + "elif model_format.value == \"INT8\":\n", + " weights_compression_config = {\"mode\": nncf.CompressWeightsMode.INT8_ASYM}\n", + "else:\n", + " weights_compression_config = None" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Multiple distributions found for package optimum. Picked distribution: optimum\n", + "The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n" + ] + } + ], + "source": [ + "from ov_wan_helper import convert_pipeline\n", + "\n", + "# Uncomment the line to see model conversion code\n", + "# ??convert_pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Wan-AI/Wan2.1-T2V-1.3B-Diffusers model already converted. You can find results in Wan2.1-T2V-1.3B-Diffusers/INT4\n" + ] + } + ], + "source": [ + "convert_pipeline(model_id, model_dir, compression_config=weights_compression_config)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare model inference pipeline\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "\n", + "`OVWanPipeline` defined in `ov_wan_helper.py` provides unified interface for running model inference. It accepts model directory and target device map for inference." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from ov_wan_helper import OVWanPipeline\n", + "\n", + "# Uncomment the line to see model inference code\n", + "# ??OVWanPipeline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Select inference device\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "You can specify inference device for each pipeline component or use the same device for all of them using widgets below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4e3a666ff4004b4eb2fa75e2b362d921", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Dropdown(description='Transformer', index=1, options=('CPU', 'AUTO'), value='AUTO'), Dropdown(d…" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from notebook_utils import device_widget\n", + "\n", + "device_transformer = device_widget(exclude=[\"NPU\"], description=\"Transformer\")\n", + "device_text_encoder = device_widget(exclude=[\"NPU\"], description=\"Text Encoder\")\n", + "device_vae = device_widget(exclude=[\"NPU\"], description=\"VAE Decoder\")\n", + "\n", + "devices = widgets.VBox([device_transformer, device_text_encoder, device_vae])\n", + "\n", + "devices" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "device_map = {\"transformer\": device_transformer.value, \"text_encoder\": device_text_encoder.value, \"vae\": device_vae.value}\n", + "\n", + "ov_pipe = OVWanPipeline(model_dir, device_map)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run OpenVINO Model Inference\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers.utils import export_to_video\n", + "\n", + "prompt = \"A cat walks on the grass, realistic\"\n", + "negative_prompt = \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n", + "\n", + "output = ov_pipe(prompt=prompt, negative_prompt=negative_prompt, height=480, width=832, num_frames=20, guidance_scale=1.0, num_inference_steps=4).frames[0]\n", + "export_to_video(output, \"output.mp4\", fps=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Video\n", + "\n", + "display(Video(\"output.mp4\"))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interactive demo\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from gradio_helper import make_demo\n", + "\n", + "demo = make_demo(ov_pipe)\n", + "\n", + "try:\n", + " demo.launch(debug=True)\n", + "except Exception:\n", + " demo.launch(share=True, debug=True)\n", + "# if you are launching remotely, specify server_name and server_port\n", + "# demo.launch(server_name='your server name', server_port='server port in int')\n", + "# Read more in the docs: https://gradio.app/docs/" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + 
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "openvino_notebooks": { + "imageUrl": "https://media.githubusercontent.com/media/Lightricks/LTX-Video/refs/heads/main//docs/_static/ltx-video_example_00003.gif?raw=true", + "tags": { + "categories": [ + "Model Demos", + "AI Trends" + ], + "libraries": [], + "other": [], + "tasks": [ + "Text-to-Video" + ] + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "05592f84f5104aefb789af1ac56241cc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DescriptionStyleModel", + "state": { + "description_width": "" + } + }, + "0c470d92ab6d4f909c03fa31a0e7bbc7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DescriptionStyleModel", + "state": { + "description_width": "" + } + }, + "0f298906ffd145b0b024e5c5d4de40b6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "children": [ + "IPY_MODEL_5249107cf961483f8eb8e3c54109983c", + "IPY_MODEL_7827e06e83654ab9b12e0c3782dba700", + "IPY_MODEL_b2bc6fb669514af8a2bb12b0ec55c618" + ], + "layout": "IPY_MODEL_abc2e27b9de6460fbdf561ee3213e1c1" + } + }, + "1055cdeb567649a8a2746539a0cf7b87": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "25a664bf06ef4a7c9a0ab503ee497759": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "2efd8892df254338b4873d33037c3e24": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DropdownModel", + "state": { + "_options_labels": [ + "CPU", + "AUTO" + ], + "description": "Transformer", + "index": 1, + "layout": "IPY_MODEL_546abc5d11bf4bc0aeff2d4ae46c230a", + "style": "IPY_MODEL_05592f84f5104aefb789af1ac56241cc" + } + }, + "3decb2c6214e4232b8a236cb0115db22": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "children": [ + "IPY_MODEL_91e7bedb19cb4430836adaa0ea35143e", + "IPY_MODEL_8770dcd83a5046c4a47d8345a8201ed1", + "IPY_MODEL_90bde2e6e00646b296c8629611adc831" + ], + "layout": "IPY_MODEL_4b20a03efb6c4c83a41f5873f59393a0" + } + }, + "48cd0b498d9942fca34c1b6cecdb5596": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "4b20a03efb6c4c83a41f5873f59393a0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "4e3a666ff4004b4eb2fa75e2b362d921": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "VBoxModel", + "state": { + "children": [ + "IPY_MODEL_2efd8892df254338b4873d33037c3e24", + "IPY_MODEL_dddb537713304325bce235618653f4ed", + "IPY_MODEL_e6d005e13b5e4c928267665bfbd2120f" + ], + "layout": "IPY_MODEL_debe50f8e1b74013b4b0ab3db4e47052" + } + }, + "5249107cf961483f8eb8e3c54109983c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "layout": "IPY_MODEL_6e21a98be6e243d5baaa895a29397936", + "style": "IPY_MODEL_ded4fe557ee14efebae06eb737e08d9c", + "value": "100%" + } + }, + "54539ede5d2e4cbeb69b9930a1d1097a": { + 
"model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "description_width": "" + } + }, + "546abc5d11bf4bc0aeff2d4ae46c230a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "5b9c545caad74c49a4d2760419b9564d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "637f184f36a84a138f51f30a6bae2aa7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DescriptionStyleModel", + "state": { + "description_width": "" + } + }, + "6e21a98be6e243d5baaa895a29397936": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "77beacf3065e4e5c8429bc6124ad9c93": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "7827e06e83654ab9b12e0c3782dba700": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "bar_style": "success", + "layout": "IPY_MODEL_e42d849d6cc4420dba0ded96ba3f07fe", + "max": 4, + "style": "IPY_MODEL_54539ede5d2e4cbeb69b9930a1d1097a", + "value": 4 + } + }, + "8770dcd83a5046c4a47d8345a8201ed1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "bar_style": "success", + "layout": "IPY_MODEL_cd8971c29586478098064f7b5238c3fb", + "max": 4, + "style": "IPY_MODEL_ab82dee3c8184723ad48f7c557ccc226", + "value": 4 + } + }, + "90bde2e6e00646b296c8629611adc831": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "layout": "IPY_MODEL_48cd0b498d9942fca34c1b6cecdb5596", + "style": "IPY_MODEL_d114193526614448836db914c1a96ce3", + "value": " 4/4 [05:18<00:00, 79.47s/steps]" + } + }, + "91e7bedb19cb4430836adaa0ea35143e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "layout": "IPY_MODEL_9668c3451df24eef8c3bd2b3deb17859", + "style": "IPY_MODEL_d6a81425fc474d9eb2a3d078e1dbd529", + "value": "100%" + } + }, + "9668c3451df24eef8c3bd2b3deb17859": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "9ff69292e7114d86ac7b828532542b8e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "ab82dee3c8184723ad48f7c557ccc226": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "description_width": "" + } + }, + "abc2e27b9de6460fbdf561ee3213e1c1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "b2bc6fb669514af8a2bb12b0ec55c618": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "layout": "IPY_MODEL_5b9c545caad74c49a4d2760419b9564d", + "style": "IPY_MODEL_77beacf3065e4e5c8429bc6124ad9c93", + "value": " 4/4 [02:35<00:00, 38.67s/it]" + } + }, + "c0fbf714670947e48155681658ab588f": { + 
"model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DropdownModel", + "state": { + "_options_labels": [ + "FP16", + "INT8", + "INT4" + ], + "description": "Model format:", + "index": 2, + "layout": "IPY_MODEL_9ff69292e7114d86ac7b828532542b8e", + "style": "IPY_MODEL_637f184f36a84a138f51f30a6bae2aa7" + } + }, + "cd8971c29586478098064f7b5238c3fb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "d114193526614448836db914c1a96ce3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "d29d1a223c2d4fabbe1fac8167d0dc5f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DescriptionStyleModel", + "state": { + "description_width": "" + } + }, + "d6a81425fc474d9eb2a3d078e1dbd529": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "dddb537713304325bce235618653f4ed": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DropdownModel", + "state": { + "_options_labels": [ + "CPU", + "AUTO" + ], + "description": "Text Encoder", + "index": 1, + "layout": "IPY_MODEL_25a664bf06ef4a7c9a0ab503ee497759", + "style": "IPY_MODEL_d29d1a223c2d4fabbe1fac8167d0dc5f" + } + }, + "debe50f8e1b74013b4b0ab3db4e47052": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "ded4fe557ee14efebae06eb737e08d9c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "e42d849d6cc4420dba0ded96ba3f07fe": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "e6d005e13b5e4c928267665bfbd2120f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "DropdownModel", + "state": { + "_options_labels": [ + "CPU", + "AUTO" + ], + "description": "VAE Decoder", + "index": 1, + "layout": "IPY_MODEL_1055cdeb567649a8a2746539a0cf7b87", + "style": "IPY_MODEL_0c470d92ab6d4f909c03fa31a0e7bbc7" + } + } + }, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From ef4350b502fa8ffe3e8a46a8d9e8f1a5292c1340 Mon Sep 17 00:00:00 2001 From: eaidova Date: Thu, 29 May 2025 13:58:47 +0400 Subject: [PATCH 2/2] skip tests --- .ci/ignore_treon_docker.txt | 3 ++- .ci/skipped_notebooks.yml | 6 ++++++ notebooks/wan2.1-text-to-video/README.md | 2 +- .../wan2.1-text-to-video/wan2.1-text-to-video.ipynb | 9 +++++++-- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.ci/ignore_treon_docker.txt b/.ci/ignore_treon_docker.txt index e6afdb84abf..6803c637d09 100644 --- a/.ci/ignore_treon_docker.txt +++ b/.ci/ignore_treon_docker.txt @@ -89,4 +89,5 @@ notebooks/minicpm-o-omnimodal-chatbot/minicpm-o-omnimodal-chatbot.ipynb notebooks/kokoro/kokoro.ipynb notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb notebooks/intern-video2-classiciation/intern-video2-classification.ipynb 
-notebooks/flex.2-image-generation/flex.2-image-generation.ipynb \ No newline at end of file +notebooks/flex.2-image-generation/flex.2-image-generation.ipynb +notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb \ No newline at end of file diff --git a/.ci/skipped_notebooks.yml b/.ci/skipped_notebooks.yml index d94b190ff11..92943bc0c9d 100644 --- a/.ci/skipped_notebooks.yml +++ b/.ci/skipped_notebooks.yml @@ -550,3 +550,9 @@ skips: - os: - macos-13 +- notebook: notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb + skips: + - os: + - macos-13 + - ubuntu-22.04 + - windows-2019 diff --git a/notebooks/wan2.1-text-to-video/README.md b/notebooks/wan2.1-text-to-video/README.md index 4fac72338f3..8547ca3bd0e 100644 --- a/notebooks/wan2.1-text-to-video/README.md +++ b/notebooks/wan2.1-text-to-video/README.md @@ -17,7 +17,7 @@ 1. Student Initialization: Initialization of the causal student by pretraining it on a small set of ODE solution pairs generated by the bidirectional teacher. This step helps stabilize the subsequent distillation training. 2. Asymmetric Distillation: Using the bidirectional teacher model, we train the causal student generator through a distribution matching distillation loss. -More details about CausVid can be found in the [paper](https://arxiv.org/abs/2412.07772), [original repository](https://github.com/tianweiy/CausVi) and [project page](https://causvid.github.io/) +More details about CausVid can be found in the [paper](https://arxiv.org/abs/2412.07772), [original repository](https://github.com/tianweiy/CausVid) and [project page](https://causvid.github.io/) ## Notebook contents diff --git a/notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb b/notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb index 79adb1ab507..8cae39c36d2 100644 --- a/notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb +++ b/notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb @@ -24,7 +24,7 @@ " 1. Student Initialization: Initialization of the causal student by pretraining it on a small set of ODE solution pairs generated by the bidirectional teacher. This step helps stabilize the subsequent distillation training.\n", " 2. Asymmetric Distillation: Using the bidirectional teacher model, we train the causal student generator through a distribution matching distillation loss. 
\n", "\n", - "More details about CausVid can be found in the [paper](https://arxiv.org/abs/2412.07772), [original repository](https://github.com/tianweiy/CausVi) and [project page](https://causvid.github.io/)\n", + "More details about CausVid can be found in the [paper](https://arxiv.org/abs/2412.07772), [original repository](https://github.com/tianweiy/CausVid) and [project page](https://causvid.github.io/)\n", "#### Table of contents:\n", "\n", "- [Prerequisites](#Prerequisites)\n", @@ -61,8 +61,13 @@ "metadata": {}, "outputs": [], "source": [ + "import platform\n", + "\n", "%pip install -q \"torch>=2.1\" \"git+https://github.com/huggingface/diffusers.git\" \"transformers>=4.49.0\" \"accelerate\" \"safetensors\" \"sentencepiece\" \"peft>=0.7.0\" \"ftfy\" \"gradio>=4.19\" \"opencv-python\" --extra-index-url https://download.pytorch.org/whl/cpu\n", - "%pip install --pre -U \"openvino>=2025.1.0\" \"nncf>=2.16.0\" --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly" + "%pip install --pre -U \"openvino>=2025.1.0\" \"nncf>=2.16.0\" --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly\n", + "\n", + "if platform.system() == \"Darwin\":\n", + " %pip install -q \"numpy<2.0.0\"" ] }, {