diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py
index a939a06cbd46..c6be52d455b8 100644
--- a/scripts/convert_sana_video_to_diffusers.py
+++ b/scripts/convert_sana_video_to_diffusers.py
@@ -12,6 +12,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from diffusers import (
+    AutoencoderKLLTX2Video,
     AutoencoderKLWan,
     DPMSolverMultistepScheduler,
     FlowMatchEulerDiscreteScheduler,
@@ -24,7 +25,10 @@
 
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
-ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+ckpt_ids = [
+    "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
+    "Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
+]
 # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
@@ -92,12 +96,22 @@ def main(args):
     if args.video_size == 480:
         sample_size = 30  # Wan-VAE: 8xp2 downsample factor
         patch_size = (1, 2, 2)
+        in_channels = 16
+        out_channels = 16
     elif args.video_size == 720:
-        sample_size = 22  # Wan-VAE: 32xp1 downsample factor
+        sample_size = 22  # DC-AE-V: 32xp1 downsample factor
         patch_size = (1, 1, 1)
+        in_channels = 32
+        out_channels = 32
     else:
         raise ValueError(f"Video size {args.video_size} is not supported.")
 
+    if args.vae_type == "ltx2":
+        sample_size = 22
+        patch_size = (1, 1, 1)
+        in_channels = 128
+        out_channels = 128
+
     for depth in range(layer_num):
         # Transformer blocks.
         converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -182,8 +196,8 @@ def main(args):
     # Transformer
     with CTX():
         transformer_kwargs = {
-            "in_channels": 16,
-            "out_channels": 16,
+            "in_channels": in_channels,
+            "out_channels": out_channels,
             "num_attention_heads": 20,
             "attention_head_dim": 112,
             "num_layers": 20,
@@ -235,9 +249,12 @@ def main(args):
     else:
         print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
         # VAE
-        vae = AutoencoderKLWan.from_pretrained(
-            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
-        )
+        if args.vae_type == "ltx2":
+            vae_path = args.vae_path or "Lightricks/LTX-2"
+            vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
+        else:
+            vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+            vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
 
         # Text Encoder
         text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -314,7 +331,23 @@ def main(args):
         choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
         help="Scheduler type to use.",
     )
-    parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
+    parser.add_argument(
+        "--vae_type",
+        default="wan",
+        type=str,
+        choices=["wan", "ltx2"],
+        help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
+    )
+    parser.add_argument(
+        "--vae_path",
+        default=None,
+        type=str,
+        required=False,
+        help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
+    )
+    parser.add_argument(
+        "--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
     parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
    parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
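For context, a hypothetical sketch of loading and running a pipeline converted with `--save_full_pipeline --vae_type ltx2`. This is not part of the diff: the dump path and the generation arguments (`prompt`, `num_frames`, `height`, `width`) are placeholders following the usual diffusers video-pipeline conventions, and it assumes `SanaVideoPipeline` is exported at the top level of `diffusers` as usual.

```python
# Hypothetical usage sketch, not part of this diff. The pipeline path and generation
# arguments are placeholders; only the --vae_type ltx2 conversion above is from the PR.
import torch

from diffusers import SanaVideoPipeline
from diffusers.utils import export_to_video

pipe = SanaVideoPipeline.from_pretrained("path/to/converted-sana-video-ltx2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumed call signature, following other diffusers video pipelines.
frames = pipe(
    prompt="a red panda playing in the snow",
    num_frames=81,
    height=480,
    width=832,
).frames[0]
export_to_video(frames, "sample.mp4", fps=16)
```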
diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
index 8b44dfc1143c..7ae85639e358 100644
--- a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
@@ -24,7 +24,7 @@
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             The tokenizer used to tokenize the prompt.
         text_encoder ([`Gemma2PreTrainedModel`]):
             Text encoder model to encode the input prompts.
-        vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+        vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
         transformer ([`SanaVideoTransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ def __init__(
         self,
         tokenizer: GemmaTokenizer | GemmaTokenizerFast,
         text_encoder: Gemma2PreTrainedModel,
-        vae: AutoencoderDC | AutoencoderKLWan,
+        vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
         transformer: SanaVideoTransformer3DModel,
         scheduler: DPMSolverMultistepScheduler,
     ):
@@ -223,8 +223,19 @@ def __init__(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+        if getattr(self, "vae", None):
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
+                self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
+            elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
+                self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
+                self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
+            else:
+                self.vae_scale_factor_temporal = 4
+                self.vae_scale_factor_spatial = 8
+        else:
+            self.vae_scale_factor_temporal = 4
+            self.vae_scale_factor_spatial = 8
         self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -985,14 +996,21 @@ def __call__(
                 if is_torch_version(">=", "2.5.0")
                 else torch_accelerator_module.OutOfMemoryError
             )
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                latents.device, latents.dtype
-            )
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                latents_mean = self.vae.latents_mean
+                latents_std = self.vae.latents_std
+                z_dim = self.vae.config.latent_channels
+            elif isinstance(self.vae, AutoencoderKLWan):
+                latents_mean = torch.tensor(self.vae.config.latents_mean)
+                latents_std = torch.tensor(self.vae.config.latents_std)
+                z_dim = self.vae.config.z_dim
+            else:
+                latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                z_dim = latents.shape[1]
+
+            latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
             try:
                 video = self.vae.decode(latents, return_dict=False)[0]
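A minimal sketch of the decode-time rescaling the branches above feed into, on dummy tensors only: because `latents_std` is stored as the reciprocal of the per-channel std, `latents / latents_std + latents_mean` is the ordinary un-normalization `latents * std + mean`. The shapes and statistics below are arbitrary placeholders.

```python
# Minimal sketch on dummy tensors: with latents_std held as 1/std, the pipeline's
# "latents / latents_std + latents_mean" equals the usual "latents * std + mean".
import torch

z_dim = 16
latents = torch.randn(1, z_dim, 3, 8, 8)  # (batch, channels, frames, height, width)
mean = torch.randn(z_dim)
std = torch.rand(z_dim) + 0.5

latents_mean = mean.view(1, z_dim, 1, 1, 1)
latents_std = 1.0 / std.view(1, z_dim, 1, 1, 1)  # reciprocal, as in __call__ above

denormalized = latents / latents_std + latents_mean
assert torch.allclose(denormalized, latents * std.view(1, z_dim, 1, 1, 1) + latents_mean)
```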
diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
index b90d7c6f5a60..81df1d0759da 100644
--- a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
@@ -26,7 +26,7 @@
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import SanaLoraLoaderMixin
-from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             The tokenizer used to tokenize the prompt.
         text_encoder ([`Gemma2PreTrainedModel`]):
             Text encoder model to encode the input prompts.
-        vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+        vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
         transformer ([`SanaVideoTransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ def __init__(
         self,
         tokenizer: GemmaTokenizer | GemmaTokenizerFast,
         text_encoder: Gemma2PreTrainedModel,
-        vae: AutoencoderDC | AutoencoderKLWan,
+        vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
         transformer: SanaVideoTransformer3DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
     ):
@@ -213,8 +213,19 @@ def __init__(
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+        if getattr(self, "vae", None):
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
+                self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
+            elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
+                self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
+                self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
+            else:
+                self.vae_scale_factor_temporal = 4
+                self.vae_scale_factor_spatial = 8
+        else:
+            self.vae_scale_factor_temporal = 4
+            self.vae_scale_factor_spatial = 8
         self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -687,14 +698,18 @@ def prepare_latents(
         image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
         image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
 
-        latents_mean = (
-            torch.tensor(self.vae.config.latents_mean)
-            .view(1, -1, 1, 1, 1)
-            .to(image_latents.device, image_latents.dtype)
-        )
-        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
-            image_latents.device, image_latents.dtype
-        )
+        if isinstance(self.vae, AutoencoderKLLTX2Video):
+            _latents_mean = self.vae.latents_mean
+            _latents_std = self.vae.latents_std
+        elif isinstance(self.vae, AutoencoderKLWan):
+            _latents_mean = torch.tensor(self.vae.config.latents_mean)
+            _latents_std = torch.tensor(self.vae.config.latents_std)
+        else:
+            _latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
+            _latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
+
+        latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
+        latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
         image_latents = (image_latents - latents_mean) * latents_std
 
         latents[:, :, 0:1] = image_latents.to(dtype)
@@ -1034,14 +1049,21 @@ def __call__(
                 if is_torch_version(">=", "2.5.0")
                 else torch_accelerator_module.OutOfMemoryError
             )
-            latents_mean = (
-                torch.tensor(self.vae.config.latents_mean)
-                .view(1, self.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                latents.device, latents.dtype
-            )
+            if isinstance(self.vae, AutoencoderKLLTX2Video):
+                latents_mean = self.vae.latents_mean
+                latents_std = self.vae.latents_std
+                z_dim = self.vae.config.latent_channels
+            elif isinstance(self.vae, AutoencoderKLWan):
+                latents_mean = torch.tensor(self.vae.config.latents_mean)
+                latents_std = torch.tensor(self.vae.config.latents_std)
+                z_dim = self.vae.config.z_dim
+            else:
+                latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
+                z_dim = latents.shape[1]
+
+            latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+            latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
             latents = latents / latents_std + latents_mean
             try:
                 video = self.vae.decode(latents, return_dict=False)[0]
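The image-to-video changes keep the two normalization directions consistent: `prepare_latents` maps encoder output into the transformer's latent space with `(image_latents - mean) * (1 / std)`, and the decode path above undoes it with `latents / (1 / std) + mean`. A small round-trip sketch on dummy tensors (shapes and statistics are placeholders):

```python
# Round-trip sketch on dummy tensors: the encode-side normalization used in
# prepare_latents and the decode-side rescaling used before vae.decode are inverses.
import torch

z_dim = 16
image_latents = torch.randn(2, z_dim, 1, 8, 8)
mean = torch.randn(z_dim).view(1, z_dim, 1, 1, 1)
inv_std = 1.0 / (torch.rand(z_dim) + 0.5).view(1, z_dim, 1, 1, 1)  # stored as 1/std

normalized = (image_latents - mean) * inv_std  # prepare_latents direction
recovered = normalized / inv_std + mean        # decode direction

assert torch.allclose(recovered, image_latents, atol=1e-6)
```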