From 00da4b6c4f8db0aa6d639366aae98370f20b2250 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 27 Jan 2026 19:34:09 +0800 Subject: [PATCH 1/6] add video_vae and dit for ltx-2 --- diffsynth/configs/model_configs.py | 26 +- diffsynth/core/attention/attention.py | 2 +- diffsynth/models/ltx2_common.py | 253 +++ diffsynth/models/ltx2_dit.py | 1442 ++++++++++++ diffsynth/models/ltx2_video_vae.py | 1969 +++++++++++++++++ .../utils/state_dict_converters/ltx2_dit.py | 9 + .../state_dict_converters/ltx2_video_vae.py | 22 + diffsynth/utils/test/load_model.py | 22 + 8 files changed, 3743 insertions(+), 2 deletions(-) create mode 100644 diffsynth/models/ltx2_common.py create mode 100644 diffsynth/models/ltx2_dit.py create mode 100644 diffsynth/models/ltx2_video_vae.py create mode 100644 diffsynth/utils/state_dict_converters/ltx2_dit.py create mode 100644 diffsynth/utils/state_dict_converters/ltx2_video_vae.py create mode 100644 diffsynth/utils/test/load_model.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index c93f5e9a3..7809ee2ca 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -591,4 +591,28 @@ }, ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series +ltx2_series = [ + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_dit", + "model_class": "diffsynth.models.ltx2_dit.LTXModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_encoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_video_vae_decoder", + "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", + }, +] + +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series diff --git a/diffsynth/core/attention/attention.py b/diffsynth/core/attention/attention.py index 15b55a43a..630d37545 100644 --- a/diffsynth/core/attention/attention.py +++ b/diffsynth/core/attention/attention.py @@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern=" if k_pattern != required_in_pattern: k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims) if v_pattern != required_in_pattern: - v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims) + v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims) return q, k, v diff --git a/diffsynth/models/ltx2_common.py b/diffsynth/models/ltx2_common.py new file mode 100644 index 000000000..fbcae4b9e --- /dev/null +++ b/diffsynth/models/ltx2_common.py @@ -0,0 +1,253 @@ +from dataclasses import dataclass +from typing import NamedTuple +import torch +from torch import nn + +class VideoPixelShape(NamedTuple): + """ + Shape of the tensor representing the 
video pixel array. Assumes BGR channel format. + """ + + batch: int + frames: int + height: int + width: int + fps: float + + +class SpatioTemporalScaleFactors(NamedTuple): + """ + Describes the spatiotemporal downscaling between decoded video space and + the corresponding VAE latent grid. + """ + + time: int + width: int + height: int + + @classmethod + def default(cls) -> "SpatioTemporalScaleFactors": + return cls(time=8, width=32, height=32) + + +VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default() + + +class VideoLatentShape(NamedTuple): + """ + Shape of the tensor representing video in VAE latent space. + The latent representation is a 5D tensor with dimensions ordered as + (batch, channels, frames, height, width). Spatial and temporal dimensions + are downscaled relative to pixel space according to the VAE's scale factors. + """ + + batch: int + channels: int + frames: int + height: int + width: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.height, self.width]) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "VideoLatentShape": + return VideoLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + height=shape[3], + width=shape[4], + ) + + def mask_shape(self) -> "VideoLatentShape": + return self._replace(channels=1) + + @staticmethod + def from_pixel_shape( + shape: VideoPixelShape, + latent_channels: int = 128, + scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS, + ) -> "VideoLatentShape": + frames = (shape.frames - 1) // scale_factors[0] + 1 + height = shape.height // scale_factors[1] + width = shape.width // scale_factors[2] + + return VideoLatentShape( + batch=shape.batch, + channels=latent_channels, + frames=frames, + height=height, + width=width, + ) + + def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape": + return self._replace( + channels=3, + frames=(self.frames - 1) * scale_factors.time + 1, + height=self.height * scale_factors.height, + width=self.width * scale_factors.width, + ) + + +class AudioLatentShape(NamedTuple): + """ + Shape of audio in VAE latent space: (batch, channels, frames, mel_bins). + mel_bins is the number of frequency bins from the mel-spectrogram encoding. 
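+    With the defaults used by `from_duration` (sample_rate=16000, hop_length=160,
+    audio_latent_downsample_factor=4), there are 25 latent frames per second, i.e.
+    one latent frame per 40 ms of audio.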
+ """ + + batch: int + channels: int + frames: int + mel_bins: int + + def to_torch_shape(self) -> torch.Size: + return torch.Size([self.batch, self.channels, self.frames, self.mel_bins]) + + def mask_shape(self) -> "AudioLatentShape": + return self._replace(channels=1, mel_bins=1) + + @staticmethod + def from_torch_shape(shape: torch.Size) -> "AudioLatentShape": + return AudioLatentShape( + batch=shape[0], + channels=shape[1], + frames=shape[2], + mel_bins=shape[3], + ) + + @staticmethod + def from_duration( + batch: int, + duration: float, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor) + + return AudioLatentShape( + batch=batch, + channels=channels, + frames=round(duration * latents_per_second), + mel_bins=mel_bins, + ) + + @staticmethod + def from_video_pixel_shape( + shape: VideoPixelShape, + channels: int = 8, + mel_bins: int = 16, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + ) -> "AudioLatentShape": + return AudioLatentShape.from_duration( + batch=shape.batch, + duration=float(shape.frames) / float(shape.fps), + channels=channels, + mel_bins=mel_bins, + sample_rate=sample_rate, + hop_length=hop_length, + audio_latent_downsample_factor=audio_latent_downsample_factor, + ) + + +@dataclass(frozen=True) +class LatentState: + """ + State of latents during the diffusion denoising process. + Attributes: + latent: The current noisy latent tensor being denoised. + denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising). + positions: Positional indices for each latent element, used for positional embeddings. + clean_latent: Initial state of the latent before denoising, may include conditioning latents. + """ + + latent: torch.Tensor + denoise_mask: torch.Tensor + positions: torch.Tensor + clean_latent: torch.Tensor + + def clone(self) -> "LatentState": + return LatentState( + latent=self.latent.clone(), + denoise_mask=self.denoise_mask.clone(), + positions=self.positions.clone(), + clean_latent=self.clean_latent.clone(), + ) + + +class PixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + # Normalize by the root-mean-square (RMS). + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor: + """Root-mean-square (RMS) normalize `x` over its last dimension. + Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized + shape and forwards `weight` and `eps`. 
+ """ + return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps) + + +@dataclass(frozen=True) +class Modality: + """ + Input data for a single modality (video or audio) in the transformer. + Bundles the latent tokens, timestep embeddings, positional information, + and text conditioning context for processing by the diffusion transformer. + """ + + latent: ( + torch.Tensor + ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension + timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps + positions: ( + torch.Tensor + ) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens + context: torch.Tensor + enabled: bool = True + context_mask: torch.Tensor | None = None + + +def to_denoised( + sample: torch.Tensor, + velocity: torch.Tensor, + sigma: float | torch.Tensor, + calc_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Convert the sample and its denoising velocity to denoised sample. + Returns: + Denoised sample + """ + if isinstance(sigma, torch.Tensor): + sigma = sigma.to(calc_dtype) + return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py new file mode 100644 index 000000000..a4cfd5c34 --- /dev/null +++ b/diffsynth/models/ltx2_dit.py @@ -0,0 +1,1442 @@ +import math +import functools +from dataclasses import dataclass, replace +from enum import Enum +from typing import Optional, Tuple, Callable +import numpy as np +import torch +from torch._prims_common import DeviceLikeType +from einops import rearrange +from .ltx2_common import rms_norm, Modality +from ..core.attention.attention import attention_forward + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(torch.nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + ): + super().__init__() + + self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim + + self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): + """ + For PixArt-Alpha. 
+ Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__( + self, + embedding_dim: int, + size_emb_dim: int, + ): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + return timesteps_emb + + +class PerturbationType(Enum): + """Types of attention perturbations for STG (Spatio-Temporal Guidance).""" + + SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" + SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" + SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" + SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" + + +@dataclass(frozen=True) +class Perturbation: + """A single perturbation specifying which attention type to skip and in which blocks.""" + + type: PerturbationType + blocks: list[int] | None # None means all blocks + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.type != perturbation_type: + return False + + if self.blocks is None: + return True + + return block in self.blocks + + +@dataclass(frozen=True) +class PerturbationConfig: + """Configuration holding a list of perturbations for a single sample.""" + + perturbations: list[Perturbation] | None + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.perturbations is None: + return False + + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty() -> "PerturbationConfig": + return PerturbationConfig([]) + + +@dataclass(frozen=True) +class BatchedPerturbationConfig: + """Perturbation configurations for a batch, with utilities for generating attention masks.""" + + perturbations: list[PerturbationConfig] + + def mask( + self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype + ) -> torch.Tensor: + mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) + for batch_idx, perturbation in enumerate(self.perturbations): + if perturbation.is_perturbed(perturbation_type, block): + mask[batch_idx] = 0 + + return mask + + def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: + mask = self.mask(perturbation_type, block, values.device, values.dtype) + return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) + + def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty(batch_size: int) -> "BatchedPerturbationConfig": + return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) + + +class AdaLayerNormSingle(torch.nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). 
+ Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, + size_emb_dim=embedding_dim // 3, + ) + + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTXRopeType(Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + +def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, +) -> torch.Tensor: + if rope_type == LTXRopeType.INTERLEAVED: + return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) + elif rope_type == LTXRopeType.SPLIT: + return apply_split_rotary_emb(input_tensor, *freqs_cis) + else: + raise ValueError(f"Invalid rope type: {rope_type}") + + + +def apply_interleaved_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +def apply_split_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + needs_reshape = False + if input_tensor.ndim != 4 and cos_freqs.ndim == 4: + b, h, t, _ = cos_freqs.shape + input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + + output = split_input * cos_freqs.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + + first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) + + output = rearrange(output, "... d r -> ... 
(d r)") + if needs_reshape: + output = output.swapaxes(1, 2).reshape(b, t, -1) + + return output + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + np.log(start) / np.log(theta), + np.log(end) / np.log(theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_pytorch( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + + +def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), ( + f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" + ) + fractional_positions = torch.stack( + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + dim=-1, + ) + return fractional_positions + + +def generate_freqs( + indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool +) -> torch.Tensor: + if use_middle_indices_grid: + assert len(indices_grid.shape) == 4 + assert indices_grid.shape[-1] == 2 + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] + + fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = indices.to(device=fractional_positions.device) + + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + return freqs + + +def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) + + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq + + +def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + 
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq + + +def precompute_freqs_cis( + indices_grid: torch.Tensor, + dim: int, + out_dtype: torch.dtype, + theta: float = 10000.0, + max_pos: list[int] | None = None, + use_middle_indices_grid: bool = False, + num_attention_heads: int = 32, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, +) -> tuple[torch.Tensor, torch.Tensor]: + if max_pos is None: + max_pos = [20, 2048, 2048] + + indices = freq_grid_generator(theta, indices_grid.shape[1], dim) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if rope_type == LTXRopeType.SPLIT: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + + +class Attention(torch.nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + norm_eps: float = 1e-6, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + ) -> None: + super().__init__() + self.rope_type = rope_type + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + + self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + + self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) + self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) + + self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + ) -> torch.Tensor: + q = self.to_q(x) + context = x if context is None else context + k = self.to_k(context) + v = self.to_v(context) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pe is not None: + q = apply_rotary_emb(q, pe, self.rope_type) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) + + # Reshape for attention_forward using unflatten + q = q.unflatten(-1, (self.heads, self.dim_head)) + k = k.unflatten(-1, (self.heads, self.dim_head)) + v = v.unflatten(-1, (self.heads, self.dim_head)) + + out = attention_forward( + q=q, + k=k, + v=v, + q_pattern="b s n d", + k_pattern="b s n d", + v_pattern="b s n d", + out_pattern="b s n d", + attn_mask=mask + ) + + # Reshape back to original format + out = out.flatten(2, 3) + return self.to_out(out) + + +class PixArtAlphaTextProjection(torch.nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = torch.nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = torch.nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@dataclass(frozen=True) +class TransformerArgs: + x: torch.Tensor + context: torch.Tensor + context_mask: torch.Tensor + timesteps: torch.Tensor + embedded_timestep: torch.Tensor + positional_embeddings: torch.Tensor + cross_positional_embeddings: torch.Tensor | None + cross_scale_shift_timestep: torch.Tensor | None + cross_gate_timestep: torch.Tensor | None + enabled: bool + + + +class TransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + use_middle_indices_grid: bool, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + ) -> None: + self.patchify_proj = patchify_proj + self.adaln = adaln + self.caption_projection = caption_projection + self.inner_dim = inner_dim + self.max_pos = max_pos + self.num_attention_heads = num_attention_heads + self.use_middle_indices_grid = use_middle_indices_grid + self.timestep_scale_multiplier = timestep_scale_multiplier + self.double_precision_rope = double_precision_rope + self.positional_embedding_theta = positional_embedding_theta + self.rope_type = rope_type + + def _prepare_timestep( + self, timestep: torch.Tensor, batch_size: int, hidden_dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare timestep embeddings.""" + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + return timestep, embedded_timestep + + def _prepare_context( + self, + context: torch.Tensor, + x: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Prepare context for transformer blocks.""" + batch_size = x.shape[0] + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: + """Prepare attention mask.""" + if attention_mask is None or torch.is_floating_point(attention_mask): + return attention_mask + + return (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, 
attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + + def _prepare_positional_embeddings( + self, + positions: torch.Tensor, + inner_dim: int, + max_pos: list[int], + use_middle_indices_grid: bool, + num_attention_heads: int, + x_dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare positional embeddings.""" + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + pe = precompute_freqs_cis( + positions, + dim=inner_dim, + out_dtype=x_dtype, + theta=self.positional_embedding_theta, + max_pos=max_pos, + use_middle_indices_grid=use_middle_indices_grid, + num_attention_heads=num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + return pe + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + x = self.patchify_proj(modality.latent) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], modality.latent.dtype) + context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) + attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + return TransformerArgs( + x=x, + context=context, + context_mask=attention_mask, + timesteps=timestep, + embedded_timestep=embedded_timestep, + positional_embeddings=pe, + cross_positional_embeddings=None, + cross_scale_shift_timestep=None, + cross_gate_timestep=None, + enabled=modality.enabled, + ) + + +class MultiModalTransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + caption_projection: PixArtAlphaTextProjection, + cross_scale_shift_adaln: AdaLayerNormSingle, + cross_gate_adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + cross_pe_max_pos: int, + use_middle_indices_grid: bool, + audio_cross_attention_dim: int, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + av_ca_timestep_scale_multiplier: int, + ) -> None: + self.simple_preprocessor = TransformerArgsPreprocessor( + patchify_proj=patchify_proj, + adaln=adaln, + caption_projection=caption_projection, + inner_dim=inner_dim, + max_pos=max_pos, + num_attention_heads=num_attention_heads, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + double_precision_rope=double_precision_rope, + positional_embedding_theta=positional_embedding_theta, + rope_type=rope_type, + ) + self.cross_scale_shift_adaln = cross_scale_shift_adaln + self.cross_gate_adaln = cross_gate_adaln + self.cross_pe_max_pos = cross_pe_max_pos + self.audio_cross_attention_dim = audio_cross_attention_dim + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + def prepare( + self, + modality: Modality, + ) -> TransformerArgs: + transformer_args = self.simple_preprocessor.prepare(modality) + cross_pe = self.simple_preprocessor._prepare_positional_embeddings( + positions=modality.positions[:, 0:1, :], + inner_dim=self.audio_cross_attention_dim, + max_pos=[self.cross_pe_max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.simple_preprocessor.num_attention_heads, + 
x_dtype=modality.latent.dtype, + ) + + cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( + timestep=modality.timesteps, + timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, + batch_size=transformer_args.x.shape[0], + hidden_dtype=modality.latent.dtype, + ) + + return replace( + transformer_args, + cross_positional_embeddings=cross_pe, + cross_scale_shift_timestep=cross_scale_shift_timestep, + cross_gate_timestep=cross_gate_timestep, + ) + + def _prepare_cross_attention_timestep( + self, + timestep: torch.Tensor, + timestep_scale_multiplier: int, + batch_size: int, + hidden_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare cross attention timestep embeddings.""" + timestep = timestep * timestep_scale_multiplier + + av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier + + scale_shift_timestep, _ = self.cross_scale_shift_adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) + gate_noise_timestep, _ = self.cross_gate_adaln( + timestep.flatten() * av_ca_factor, + hidden_dtype=hidden_dtype, + ) + gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) + + return scale_shift_timestep, gate_noise_timestep + + +@dataclass +class TransformerConfig: + dim: int + heads: int + d_head: int + context_dim: int + + +class BasicAVTransformerBlock(torch.nn.Module): + def __init__( + self, + idx: int, + video: TransformerConfig | None = None, + audio: TransformerConfig | None = None, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + norm_eps: float = 1e-6, + ): + super().__init__() + + self.idx = idx + if video is not None: + self.attn1 = Attention( + query_dim=video.dim, + heads=video.heads, + dim_head=video.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.attn2 = Attention( + query_dim=video.dim, + context_dim=video.context_dim, + heads=video.heads, + dim_head=video.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.ff = FeedForward(video.dim, dim_out=video.dim) + self.scale_shift_table = torch.nn.Parameter(torch.empty(6, video.dim)) + + if audio is not None: + self.audio_attn1 = Attention( + query_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.audio_attn2 = Attention( + query_dim=audio.dim, + context_dim=audio.context_dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(6, audio.dim)) + + if audio is not None and video is not None: + # Q: Video, K,V: Audio + self.audio_to_video_attn = Attention( + query_dim=video.dim, + context_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = Attention( + query_dim=audio.dim, + context_dim=video.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + ) + + self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) + self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) + + self.norm_eps = norm_eps + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, 
indices: slice + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, slice(None, None) + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None) + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( # noqa: PLR0915 + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig | None = None, + ) -> tuple[TransformerArgs | None, TransformerArgs | None]: + batch_size = video.x.shape[0] + if perturbations is None: + perturbations = BatchedPerturbationConfig.empty(batch_size) + + vx = video.x if video is not None else None + ax = audio.x if audio is not None else None + + run_vx = video is not None and video.enabled and vx.numel() > 0 + run_ax = audio is not None and audio.enabled and ax.numel() > 0 + + run_a2v = run_vx and (audio is not None and ax.numel() > 0) + run_v2a = run_ax and (video is not None and vx.numel() > 0) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) + ) + if not perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx): + norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa + v_mask = perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) + vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa * v_mask + + vx = vx + self.attn2(rms_norm(vx, eps=self.norm_eps), context=video.context, mask=video.context_mask) + + del vshift_msa, vscale_msa, vgate_msa + + if run_ax: + ashift_msa, ascale_msa, agate_msa = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) + ) + + if not perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx): + norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa + a_mask = perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) + ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa * a_mask + + ax = ax + self.audio_attn2(rms_norm(ax, eps=self.norm_eps), context=audio.context, mask=audio.context_mask) + + del ashift_msa, ascale_msa, agate_msa + + # Audio - Video cross attention. 
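+        # Video queries attend to audio tokens (A2V) and audio queries attend to video
+        # tokens (V2A). Each direction applies its own AdaLN scale/shift to the normed
+        # inputs and gates its output, and can be zeroed per sample via the STG
+        # perturbation masks below.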
+ if run_a2v or run_v2a: + vx_norm3 = rms_norm(vx, eps=self.norm_eps) + ax_norm3 = rms_norm(ax, eps=self.norm_eps) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + video.cross_scale_shift_timestep, + video.cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + shift_ca_video_hidden_states_a2v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) + vx = vx + ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=video.cross_positional_embeddings, + k_pe=audio.cross_positional_embeddings, + ) + * gate_out_a2v + * a2v_mask + ) + + if run_v2a: + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) + ax = ax + ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=audio.cross_positional_embeddings, + k_pe=video.cross_positional_embeddings, + ) + * gate_out_v2a + * v2a_mask + ) + + del gate_out_a2v, gate_out_v2a + del ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + ) + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None) + ) + vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp + vx = vx + self.ff(vx_scaled) * vgate_mlp + + del vshift_mlp, vscale_mlp, vgate_mlp + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None) + ) + ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp + ax = ax + self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp + + return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None + + +class GELUApprox(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = torch.nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(self.proj(x), approximate="tanh") + + +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: + super().__init__() + inner_dim = int(dim * mult) + project_in = GELUApprox(dim, inner_dim) + + self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class LTXModelType(Enum): + 
AudioVideo = "ltx av model" + VideoOnly = "ltx video only model" + AudioOnly = "ltx audio only model" + + def is_video_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) + + def is_audio_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) + + +class LTXModel(torch.nn.Module): + """ + LTX model transformer implementation. + This class implements the transformer blocks for the LTX model. + """ + + def __init__( # noqa: PLR0913 + self, + *, + model_type: LTXModelType = LTXModelType.AudioVideo, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + in_channels: int = 128, + out_channels: int = 128, + num_layers: int = 48, + cross_attention_dim: int = 4096, + norm_eps: float = 1e-06, + caption_channels: int = 3840, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = [20, 2048, 2048], + timestep_scale_multiplier: int = 1000, + use_middle_indices_grid: bool = True, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_in_channels: int = 128, + audio_out_channels: int = 128, + audio_cross_attention_dim: int = 2048, + audio_positional_embedding_max_pos: list[int] | None = [20], + av_ca_timestep_scale_multiplier: int = 1000, + rope_type: LTXRopeType = LTXRopeType.SPLIT, + double_precision_rope: bool = True, + ): + super().__init__() + self._enable_gradient_checkpointing = False + self.use_middle_indices_grid = use_middle_indices_grid + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.timestep_scale_multiplier = timestep_scale_multiplier + self.positional_embedding_theta = positional_embedding_theta + self.model_type = model_type + cross_pe_max_pos = None + if model_type.is_video_enabled(): + if positional_embedding_max_pos is None: + positional_embedding_max_pos = [20, 2048, 2048] + self.positional_embedding_max_pos = positional_embedding_max_pos + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self._init_video( + in_channels=in_channels, + out_channels=out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_audio_enabled(): + if audio_positional_embedding_max_pos is None: + audio_positional_embedding_max_pos = [20] + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim + self._init_audio( + in_channels=audio_in_channels, + out_channels=audio_out_channels, + caption_channels=caption_channels, + norm_eps=norm_eps, + ) + + if model_type.is_video_enabled() and model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + self.audio_cross_attention_dim = audio_cross_attention_dim + self._init_audio_video(num_scale_shift_values=4) + + self._init_preprocessors(cross_pe_max_pos) + # Initialize transformer blocks + self._init_transformer_blocks( + num_layers=num_layers, + attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, + cross_attention_dim=cross_attention_dim, + audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, + audio_cross_attention_dim=audio_cross_attention_dim, + norm_eps=norm_eps, + ) + + def _init_video( + self, + 
in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize video-specific components.""" + # Video input components + self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) + + self.adaln_single = AdaLayerNormSingle(self.inner_dim) + + # Video caption projection + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.inner_dim, + ) + + # Video output components + self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) + self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) + self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) + + def _init_audio( + self, + in_channels: int, + out_channels: int, + caption_channels: int, + norm_eps: float, + ) -> None: + """Initialize audio-specific components.""" + + # Audio input components + self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) + + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, + hidden_size=self.audio_inner_dim, + ) + + # Audio output components + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) + self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) + self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) + + def _init_audio_video( + self, + num_scale_shift_values: int, + ) -> None: + """Initialize audio-video cross-attention components.""" + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=1, + ) + + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=1, + ) + + def _init_preprocessors( + self, + cross_pe_max_pos: int | None = None, + ) -> None: + """Initialize preprocessors for LTX.""" + + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, + 
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + ) + elif self.model_type.is_video_enabled(): + self.video_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + caption_projection=self.caption_projection, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + elif self.model_type.is_audio_enabled(): + self.audio_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + caption_projection=self.audio_caption_projection, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + ) + + def _init_transformer_blocks( + self, + num_layers: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + norm_eps: float, + ) -> None: + """Initialize transformer blocks for LTX.""" + video_config = ( + TransformerConfig( + dim=self.inner_dim, + heads=self.num_attention_heads, + d_head=attention_head_dim, + context_dim=cross_attention_dim, + ) + if self.model_type.is_video_enabled() + else None + ) + audio_config = ( + TransformerConfig( + dim=self.audio_inner_dim, + heads=self.audio_num_attention_heads, + d_head=audio_attention_head_dim, + context_dim=audio_cross_attention_dim, + ) + if self.model_type.is_audio_enabled() + else None + ) + self.transformer_blocks = torch.nn.ModuleList( + [ + BasicAVTransformerBlock( + idx=idx, + video=video_config, + audio=audio_config, + rope_type=self.rope_type, + norm_eps=norm_eps, + ) + for idx in range(num_layers) + ] + ) + + def set_gradient_checkpointing(self, enable: bool) -> None: + """Enable or disable gradient checkpointing for transformer blocks. + Gradient checkpointing trades compute for memory by recomputing activations + during the backward pass instead of storing them. This can significantly + reduce memory usage at the cost of ~20-30% slower training. 
+ Args: + enable: Whether to enable gradient checkpointing + """ + self._enable_gradient_checkpointing = enable + + def _process_transformer_blocks( + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[TransformerArgs, TransformerArgs]: + """Process transformer blocks for LTXAV.""" + + # Process transformer blocks + for block in self.transformer_blocks: + if self._enable_gradient_checkpointing and self.training: + # Use gradient checkpointing to save memory during training. + # With use_reentrant=False, we can pass dataclasses directly - + # PyTorch will track all tensor leaves in the computation graph. + video, audio = torch.utils.checkpoint.checkpoint( + block, + video, + audio, + perturbations, + use_reentrant=False, + ) + else: + video, audio = block( + video=video, + audio=audio, + perturbations=perturbations, + ) + + return video, audio + + def _process_output( + self, + scale_shift_table: torch.Tensor, + norm_out: torch.nn.LayerNorm, + proj_out: torch.nn.Linear, + x: torch.Tensor, + embedded_timestep: torch.Tensor, + ) -> torch.Tensor: + """Process output for LTXV.""" + # Apply scale-shift modulation + scale_shift_values = ( + scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + x = norm_out(x) + x = x * (1 + scale) + shift + x = proj_out(x) + return x + + def forward( + self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for LTX models. + Returns: + Processed output tensors + """ + if not self.model_type.is_video_enabled() and video is not None: + raise ValueError("Video is not enabled for this model") + if not self.model_type.is_audio_enabled() and audio is not None: + raise ValueError("Audio is not enabled for this model") + + video_args = self.video_args_preprocessor.prepare(video) if video is not None else None + audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None + # Process transformer blocks + video_out, audio_out = self._process_transformer_blocks( + video=video_args, + audio=audio_args, + perturbations=perturbations, + ) + + # Process output + vx = ( + self._process_output( + self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep + ) + if video_out is not None + else None + ) + ax = ( + self._process_output( + self.audio_scale_shift_table, + self.audio_norm_out, + self.audio_proj_out, + audio_out.x, + audio_out.embedded_timestep, + ) + if audio_out is not None + else None + ) + return vx, ax diff --git a/diffsynth/models/ltx2_video_vae.py b/diffsynth/models/ltx2_video_vae.py new file mode 100644 index 000000000..34db4fe76 --- /dev/null +++ b/diffsynth/models/ltx2_video_vae.py @@ -0,0 +1,1969 @@ +import itertools +import math +from dataclasses import replace, dataclass +from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from enum import Enum +from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape +from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings + + +class NormLayerType(Enum): + GROUP_NORM = "group_norm" + PIXEL_NORM = "pixel_norm" + + +class LogVarianceType(Enum): + PER_CHANNEL = "per_channel" + UNIFORM = "uniform" + 
CONSTANT = "constant" + NONE = "none" + + +class PaddingModeType(Enum): + ZEROS = "zeros" + REFLECT = "reflect" + REPLICATE = "replicate" + CIRCULAR = "circular" + + +class DualConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + ) -> None: + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.") + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = out_channels if in_channels < out_channels else in_channels + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + )) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / torch.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / torch.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward( + self, + x: torch.Tensor, + use_conv3d: bool = False, + skip_time_conv: bool = False, + ) -> torch.Tensor: + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def 
forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + b, _, _, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self) -> torch.Tensor: + return self.weight2 + + +class CausalConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode.value, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor: + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self) -> torch.Tensor: + return self.conv.weight + + +def make_conv_nd( # noqa: PLR0913 + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS, +) -> nn.Module: + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return nn.Conv2d( + in_channels=in_channels, + 
out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias: bool = True, +) -> nn.Module: + if dims == 2: + return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims in (3, (2, 1)): + return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks + and moves pixels from each block into separate channels (space-to-depth). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching). + For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw) + Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from + channels back into patch_size x patch_size blocks (depth-to-space). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal expansion). 
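+    For 4D: (B, Cx(patch_size_hw^2), H, W) -> (B, C, Hxpatch_size_hw, Wxpatch_size_hw)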
+    For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw)
+    Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512)
+    """
+    if patch_size_hw == 1 and patch_size_t == 1:
+        return x
+
+    if x.dim() == 4:
+        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
+    elif x.dim() == 5:
+        x = rearrange(
+            x,
+            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
+            p=patch_size_t,
+            q=patch_size_hw,
+            r=patch_size_hw,
+        )
+    else:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+
+    return x
+
+
+class PerChannelStatistics(nn.Module):
+    """
+    Per-channel statistics for normalizing and denormalizing the latent representation.
+    These statistics are computed over the entire dataset and stored in the model's checkpoint as part of the VAE state_dict.
+    """
+
+    def __init__(self, latent_channels: int = 128):
+        super().__init__()
+        self.register_buffer("std-of-means", torch.empty(latent_channels))
+        self.register_buffer("mean-of-means", torch.empty(latent_channels))
+        self.register_buffer("mean-of-stds", torch.empty(latent_channels))
+        self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(latent_channels))
+        self.register_buffer("channel", torch.empty(latent_channels))
+
+    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
+        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(
+            1, -1, 1, 1, 1).to(x)
+
+    def normalize(self, x: torch.Tensor) -> torch.Tensor:
+        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(
+            1, -1, 1, 1, 1).to(x)
+
+
+class ResnetBlock3D(nn.Module):
+    r"""
+    A Resnet block.
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of output channels for the first conv layer. If None, same as `in_channels`.
+        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
+        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
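+        norm_layer (`NormLayerType`, *optional*, defaults to `NormLayerType.PIXEL_NORM`):
+            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
+        inject_noise (`bool`, *optional*, defaults to `False`):
+            Whether to inject StyleGAN-style per-channel spatial noise after each convolution.
+        timestep_conditioning (`bool`, *optional*, defaults to `False`):
+            Whether to modulate the hidden states with scale/shift values derived from the timestep embedding.
+        spatial_padding_mode (`PaddingModeType`, *optional*, defaults to `PaddingModeType.ZEROS`):
+            Padding mode used by the spatial convolutions.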
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = (make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels else nn.Identity()) + + # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout + # avoiding the need for dimension rearrangement used in standard nn.LayerNorm + self.norm3 = (nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True) + if in_channels != out_channels else nn.Identity()) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels)) + + def _feed_spatial_noise( + self, + hidden_states: torch.Tensor, + per_channel_scale: torch.Tensor, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] 
+ hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + ada_values = self.scale_shift_table[None, ..., None, None, None].to( + device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
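+    Raises:
+        ValueError: If `timestep_conditioning` is `True` and no `timestep` is passed to `forward`.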
+ """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: NormLayerType = NormLayerType.GROUP_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=in_channels * 4, + size_emb_dim=0) + + self.res_blocks = nn.ModuleList([ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) for _ in range(num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + timestep_embed = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, + causal=causal, + timestep=timestep_embed, + generator=generator, + ) + + return hidden_states + + +class SpaceToDepthDownsample(nn.Module): + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + stride: Tuple[int, int, int], + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + + def __init__( + self, + dims: int | Tuple[int, int], + in_channels: int, + stride: Tuple[int, int, int], + residual: bool = False, + out_channels_reduction_factor: int = 1, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + 
dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +def compute_trapezoidal_mask_1d( + length: int, + ramp_left: int, + ramp_right: int, + left_starts_from_0: bool = False, +) -> torch.Tensor: + """ + Generate a 1D trapezoidal blending mask with linear ramps. + Args: + length: Output length of the mask. + ramp_left: Fade-in length on the left. + ramp_right: Fade-out length on the right. + left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. + Useful for temporal tiles where the first tile is causal. + Returns: + A 1D tensor of shape `(length,)` with values in [0, 1]. + """ + if length <= 0: + raise ValueError("Mask length must be positive.") + + ramp_left = max(0, min(ramp_left, length)) + ramp_right = max(0, min(ramp_right, length)) + + mask = torch.ones(length) + + if ramp_left > 0: + interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 + fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1] + if not left_starts_from_0: + fade_in = fade_in[1:] + mask[:ramp_left] *= fade_in + + if ramp_right > 0: + fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1] + mask[-ramp_right:] *= fade_out + + return mask.clamp_(0, 1) + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap. + Args: + tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32. + tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0. + """ + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" + ) + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap. + Args: + tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8. 
+ tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles. + Must be divisible by 8. Defaults to 0. + """ + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + if self.tile_overlap_in_frames >= self.tile_size_in_frames: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" + ) + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap. + Attributes: + spatial_config: Configuration for splitting spatial dimensions into tiles. + temporal_config: Configuration for splitting temporal dimension into tiles. + """ + + spatial_config: SpatialTilingConfig | None = None + temporal_config: TemporalTilingConfig | None = None + + @classmethod + def default(cls) -> "TilingConfig": + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) + + +@dataclass(frozen=True) +class DimensionIntervals: + """Intervals which a single dimension of the latent space is split into. + Each interval is defined by its start, end, left ramp, and right ramp. + The start and end are the indices of the first and last element (exclusive) in the interval. + Ramps are regions of the interval where the value of the mask tensor is + interpolated between 0 and 1 for blending with neighboring intervals. + The left ramp and right ramp values are the lengths of the left and right ramps. + """ + + starts: List[int] + ends: List[int] + left_ramps: List[int] + right_ramps: List[int] + + +@dataclass(frozen=True) +class LatentIntervals: + """Intervals which the latent tensor of given shape is split into. + Each dimension of the latent space is split into intervals based on the length along said dimension. + """ + + original_shape: torch.Size + dimension_intervals: Tuple[DimensionIntervals, ...] + + +# Operation to split a single dimension of the tensor into intervals based on the length along the dimension. +SplitOperation = Callable[[int], DimensionIntervals] +# Operation to map the intervals in input dimension to slices and masks along a corresponding output dimension. +MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]] + + +def default_split_operation(length: int) -> DimensionIntervals: + return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0]) + + +DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation + + +def default_mapping_operation(_intervals: DimensionIntervals,) -> tuple[list[slice], list[torch.Tensor | None]]: + return [slice(0, None)], [None] + + +DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation + + +class Tile(NamedTuple): + """ + Represents a single tile. + Attributes: + in_coords: + Tuple of slices specifying where to cut the tile from the INPUT tensor. 
+ out_coords: + Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor. + masks_1d: + Per-dimension masks in OUTPUT units. + These are used to create all-dimensional blending mask. + Methods: + blend_mask: + Create a single N-D mask from the per-dimension masks. + """ + + in_coords: Tuple[slice, ...] + out_coords: Tuple[slice, ...] + masks_1d: Tuple[Tuple[torch.Tensor, ...]] + + @property + def blend_mask(self) -> torch.Tensor: + num_dims = len(self.out_coords) + per_dimension_masks: List[torch.Tensor] = [] + + for dim_idx in range(num_dims): + mask_1d = self.masks_1d[dim_idx] + view_shape = [1] * num_dims + if mask_1d is None: + # Broadcast mask along this dimension (length 1). + one = torch.ones(1) + + view_shape[dim_idx] = 1 + per_dimension_masks.append(one.view(*view_shape)) + continue + + # Reshape (L,) -> (1, ..., L, ..., 1) so masks across dimensions broadcast-multiply. + view_shape[dim_idx] = mask_1d.shape[0] + per_dimension_masks.append(mask_1d.view(*view_shape)) + + # Multiply per-dimension masks to form the full N-D mask (separable blending window). + combined_mask = per_dimension_masks[0] + for mask in per_dimension_masks[1:]: + combined_mask = combined_mask * mask + + return combined_mask + + +def create_tiles_from_intervals_and_mappers( + intervals: LatentIntervals, + mappers: List[MappingOperation], +) -> List[Tile]: + full_dim_input_slices = [] + full_dim_output_slices = [] + full_dim_masks_1d = [] + for axis_index in range(len(intervals.original_shape)): + dimension_intervals = intervals.dimension_intervals[axis_index] + starts = dimension_intervals.starts + ends = dimension_intervals.ends + input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)] + output_slices, masks_1d = mappers[axis_index](dimension_intervals) + full_dim_input_slices.append(input_slices) + full_dim_output_slices.append(output_slices) + full_dim_masks_1d.append(masks_1d) + + tiles = [] + tile_in_coords = list(itertools.product(*full_dim_input_slices)) + tile_out_coords = list(itertools.product(*full_dim_output_slices)) + tile_mask_1ds = list(itertools.product(*full_dim_masks_1d)) + for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True): + tiles.append(Tile( + in_coords=in_coord, + out_coords=out_coord, + masks_1d=mask_1d, + )) + return tiles + + +def create_tiles( + latent_shape: torch.Size, + splitters: List[SplitOperation], + mappers: List[MappingOperation], +) -> List[Tile]: + if len(splitters) != len(latent_shape): + raise ValueError(f"Number of splitters must be equal to number of dimensions in latent shape, " + f"got {len(splitters)} and {len(latent_shape)}") + if len(mappers) != len(latent_shape): + raise ValueError(f"Number of mappers must be equal to number of dimensions in latent shape, " + f"got {len(mappers)} and {len(latent_shape)}") + intervals = [splitter(length) for splitter, length in zip(splitters, latent_shape, strict=True)] + latent_intervals = LatentIntervals(original_shape=latent_shape, dimension_intervals=tuple(intervals)) + return create_tiles_from_intervals_and_mappers(latent_intervals, mappers) + + +def _make_encoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + 
in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + return block, out_channels + + +class LTX2VideoEncoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Encoder. Encodes video frames into a latent representation. + The encoder compresses the input video through a series of downsampling operations controlled by + patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W'). + Compression Behavior: + The total compression is determined by: + 1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4) + 2. 
Sequential compression through encoder_blocks based on their stride patterns + Compression blocks apply 2x compression in specified dimensions: + - "compress_time" / "compress_time_res": temporal only + - "compress_space" / "compress_space_res": spatial only (H and W) + - "compress_all" / "compress_all_res": all dimensions (F, H, W) + - "res_x" / "res_x_y": no compression + Standard LTX Video configuration: + - patch_size=4 + - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res + - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32 + - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16) + - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...) + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels. For RGB images, this is 3. + out_channels: The number of output channels (latent channels). For latent channels, this is 128. + encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: The patch size for initial spatial compression. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 3, + out_channels: int = 128, + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, + encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + encoder_blocks = [['res_x', { + 'num_layers': 4 + }], ['compress_space_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 6 + }], ['compress_time_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 6 + }], ['compress_all_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 2 + }], ['compress_all_res', { + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 2 + }]] + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for normalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) + + in_channels = in_channels * patch_size**2 + feature_channels = out_channels + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in encoder_blocks: + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_encoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks.append(block) + + # out + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, 
num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == LogVarianceType.PER_CHANNEL: + conv_out_channels *= 2 + elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + conv_out_channels += 1 + elif latent_log_var != LogVarianceType.NONE: + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=conv_out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r""" + Encode video frames into normalized latent representation. + Args: + sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...). + Returns: + Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32. + Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16). + """ + # Validate frame count + frames_count = sample.shape[2] + if ((frames_count - 1) % 8) != 0: + raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames " + "(e.g., 1, 9, 17, ...). Please check your input.") + + # Initial spatial compression: trade spatial resolution for channel depth + # This reduces H,W by patch_size and increases channels, making convolutions more efficient + # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4 + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + for down_block in self.down_blocks: + sample = down_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == LogVarianceType.UNIFORM: + # Uniform Variance: model outputs N means and 1 shared log-variance channel. + # We need to expand the single logvar to match the number of means channels + # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels). + # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129) + # Target shape: (B, 2*N, ...) where first N are means, last N are logvar + + if sample.shape[1] < 2: + raise ValueError(f"Invalid channel count for UNIFORM mode: expected at least 2 channels " + f"(N means + 1 logvar), got {sample.shape[1]}") + + # Extract means (first N channels) and logvar (last 1 channel) + means = sample[:, :-1, ...] # (B, N, ...) + logvar = sample[:, -1:, ...] # (B, 1, ...) + + # Repeat logvar N times to match means channels + # Use expand/repeat pattern that works for both 4D and 5D tensors + num_channels = means.shape[1] + repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2) + repeated_logvar = logvar.repeat(*repeat_shape) # (B, N, ...) + + # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar] + sample = torch.cat([means, repeated_logvar], dim=1) + elif self.latent_log_var == LogVarianceType.CONSTANT: + sample = sample[:, :-1, ...] 
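+            # CONSTANT mode: drop the predicted log-variance channel and append a fixed log-variance
+            # with the same channel count as the means.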
+ approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + # Split into means and logvar, then normalize means + means, _ = torch.chunk(sample, 2, dim=1) + return self.per_channel_statistics.normalize(means) + + +def _make_decoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + timestep_conditioning: bool, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_config["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels // block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 2, 2), + residual=block_config.get("residual", False), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + return block, out_channels + + +class LTX2VideoDecoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Decoder. Decodes latent representation into video frames. + The decoder upsamples latents through a series of upsampling operations (inverse of encoder). + Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration. + Upsampling blocks expand dimensions by 2x in specified dimensions: + - "compress_time": temporal only + - "compress_space": spatial only (H and W) + - "compress_all": all dimensions (F, H, W) + - "res_x" / "res_x_y" / "attn_res_x": no upsampling + Causal Mode: + causal=False (standard): Symmetric padding, allows future frame dependencies. + causal=True: Causal padding, each frame depends only on past/current frames. 
+ First frame removed after temporal upsampling in both modes. Output shape unchanged. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes. + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels (latent channels). Default is 128. + out_channels: The number of output channels. For RGB images, this is 3. + decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion: + H -> Hx4, W -> Wx4. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding. + When True, uses causal padding (past/current frames only). + timestep_conditioning: Whether to condition the decoder on timestep for denoising. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 128, + out_channels: int = 3, + decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + causal: bool = False, + timestep_conditioning: bool = False, + decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, + ): + super().__init__() + + # Spatiotemporal downscaling between decoded video space and VAE latents. + # According to the LTXV paper, the standard configuration downsamples + # video inputs by a factor of 8 in the temporal dimension and 32 in + # each spatial dimension (height and width). This parameter determines how + # many video frames and pixels correspond to a single latent cell. 
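+        # The hard-coded block list below is the standard LTX-2 decoder layout; it replaces
+        # whatever is passed via the `decoder_blocks` argument.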
+ decoder_blocks = [['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }], ['compress_all', { + 'residual': True, + 'multiplier': 2 + }], ['res_x', { + 'num_layers': 5, + 'inject_noise': False + }]] + self.video_downscale_factors = SpatioTemporalScaleFactors( + time=8, + width=32, + height=32, + ) + + self.patch_size = patch_size + out_channels = out_channels * patch_size**2 + self.causal = causal + self.timestep_conditioning = timestep_conditioning + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) + + # Noise and timestep parameters for decoder conditioning + self.decode_noise_scale = 0.025 + self.decode_timestep = 0.05 + + # Compute initial feature_channels by going through blocks in reverse + # This determines the channel width at the start of the decoder + feature_channels = in_channels + for block_name, block_params in list(reversed(decoder_blocks)): + block_config = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + feature_channels = feature_channels * block_config.get("multiplier", 2) + if block_name == "compress_all": + feature_channels = feature_channels * block_config.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(decoder_blocks)): + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_decoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + timestep_conditioning=timestep_conditioning, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks.append(block) + + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=feature_channels * 2, + size_emb_dim=0) + self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels)) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r""" + Decode latent representation into video frames. + Args: + sample: Latent tensor (B, 128, F', H', W'). 
+ timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None. + generator: Random generator for deterministic noise injection (if inject_noise=True in blocks). + Returns: + Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512). + Note: First frame is removed after temporal upsampling regardless of causal mode. + When causal=False, allows future frame dependencies in convolutions but maintains same output shape. + """ + batch_size = sample.shape[0] + + # Add noise if timestep conditioning is enabled + if self.timestep_conditioning: + noise = (torch.randn( + sample.size(), + generator=generator, + dtype=sample.dtype, + device=sample.device, + ) * self.decode_noise_scale) + + sample = noise + (1.0 - self.decode_noise_scale) * sample + + # Denormalize latents + sample = self.per_channel_statistics.un_normalize(sample) + + # Use default decode_timestep if timestep not provided + if timestep is None and self.timestep_conditioning: + timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype) + + sample = self.conv_in(sample, causal=self.causal) + + scaled_timestep = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample) + + for up_block in self.up_blocks: + if isinstance(up_block, UNetMidBlock3D): + block_kwargs = { + "causal": self.causal, + "timestep": scaled_timestep if self.timestep_conditioning else None, + "generator": generator, + } + sample = up_block(sample, **block_kwargs) + elif isinstance(up_block, ResnetBlock3D): + sample = up_block(sample, causal=self.causal, generator=generator) + else: + sample = up_block(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None].to( + device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + # Final spatial expansion: reverse the initial patchify from encoder + # Moves pixels from channels back to spatial dimensions + # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4 + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + def _prepare_tiles( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + ) -> List[Tile]: + splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape) + mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape) + if tiling_config is not None and tiling_config.spatial_config is not None: + cfg = tiling_config.spatial_config + long_side = max(latent.shape[3], latent.shape[4]) + + def enable_on_axis(axis_idx: int, factor: int) -> None: + size = cfg.tile_size_in_pixels // factor + overlap = cfg.tile_overlap_in_pixels // factor + axis_length = latent.shape[axis_idx] + lower_threshold = max(2, 
overlap + 1) + tile_size = max(lower_threshold, round(size * axis_length / long_side)) + splitters[axis_idx] = split_in_spatial(tile_size, overlap) + mappers[axis_idx] = to_mapping_operation(map_spatial_slice, factor) + + enable_on_axis(3, self.video_downscale_factors.height) + enable_on_axis(4, self.video_downscale_factors.width) + + if tiling_config is not None and tiling_config.temporal_config is not None: + cfg = tiling_config.temporal_config + tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time + overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time + splitters[2] = split_in_temporal(tile_size, overlap) + mappers[2] = to_mapping_operation(map_temporal_slice, self.video_downscale_factors.time) + + return create_tiles(latent.shape, splitters, mappers) + + def tiled_decode( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> Iterator[torch.Tensor]: + """ + Decode a latent tensor into video frames using tiled processing. + Splits the latent tensor into tiles, decodes each tile individually, + and yields video chunks as they become available. + Args: + latent: Input latent tensor (B, C, F', H', W'). + tiling_config: Tiling configuration for the latent tensor. + timestep: Optional timestep for decoder conditioning. + generator: Optional random generator for deterministic decoding. + Yields: + Video chunks (B, C, T, H, W) by temporal slices; + """ + + # Calculate full video shape from latent shape to get spatial dimensions + full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors) + tiles = self._prepare_tiles(latent, tiling_config) + + temporal_groups = self._group_tiles_by_temporal_slice(tiles) + + # State for temporal overlap handling + previous_chunk = None + previous_weights = None + previous_temporal_slice = None + + for temporal_group_tiles in temporal_groups: + curr_temporal_slice = temporal_group_tiles[0].out_coords[2] + + # Calculate the shape of the temporal buffer for this group of tiles. + # The temporal length depends on whether this is the first tile (starts at 0) or not. + # - First tile: (frames - 1) * scale + 1 + # - Subsequent tiles: frames * scale + # This logic is handled by TemporalAxisMapping and reflected in out_coords. + temporal_tile_buffer_shape = full_video_shape._replace(frames=curr_temporal_slice.stop - + curr_temporal_slice.start,) + + buffer = torch.zeros( + temporal_tile_buffer_shape.to_torch_shape(), + device=latent.device, + dtype=latent.dtype, + ) + + curr_weights = self._accumulate_temporal_group_into_buffer( + group_tiles=temporal_group_tiles, + buffer=buffer, + latent=latent, + timestep=timestep, + generator=generator, + ) + + # Blend with previous temporal chunk if it exists + if previous_chunk is not None: + # Check if current temporal slice overlaps with previous temporal slice + if previous_temporal_slice.stop > curr_temporal_slice.start: + overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start + temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None) + + # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer + # with its trapezoidal mask, and curr_weights accumulates the same mask. 
In the overlap blend we add + # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the + # previous buffers, then later normalize by weights. + previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :] + previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[:, :, + slice(0, overlap_len), :, :] + + buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :] + curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[:, :, + temporal_overlap_slice, :, :] + + # Yield the non-overlapping part of the previous chunk + previous_weights = previous_weights.clamp(min=1e-8) + yield_len = curr_temporal_slice.start - previous_temporal_slice.start + yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :] + + # Update state for next iteration + previous_chunk = buffer + previous_weights = curr_weights + previous_temporal_slice = curr_temporal_slice + + # Yield any remaining chunk + if previous_chunk is not None: + previous_weights = previous_weights.clamp(min=1e-8) + yield previous_chunk / previous_weights + + def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]: + """Group tiles by their temporal output slice.""" + if not tiles: + return [] + + groups = [] + current_slice = tiles[0].out_coords[2] + current_group = [] + + for tile in tiles: + tile_slice = tile.out_coords[2] + if tile_slice == current_slice: + current_group.append(tile) + else: + groups.append(current_group) + current_slice = tile_slice + current_group = [tile] + + # Add the final group + if current_group: + groups.append(current_group) + + return groups + + def _accumulate_temporal_group_into_buffer( + self, + group_tiles: List[Tile], + buffer: torch.Tensor, + latent: torch.Tensor, + timestep: torch.Tensor | None, + generator: torch.Generator | None, + ) -> torch.Tensor: + """ + Decode and accumulate all tiles of a temporal group into a local buffer. + The buffer is local to the group and always starts at time 0; temporal coordinates + are rebased by subtracting temporal_slice.start. 
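+        Returns:
+            The per-element blend weights accumulated at the same coordinates as the buffer;
+            the caller later divides the accumulated buffer by these weights to normalize
+            overlapping tiles.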
+ """ + temporal_slice = group_tiles[0].out_coords[2] + + weights = torch.zeros_like(buffer) + + for tile in group_tiles: + decoded_tile = self.forward(latent[tile.in_coords], timestep, generator) + mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype) + temporal_offset = tile.out_coords[2].start - temporal_slice.start + # Use the tile's output coordinate length, not the decoded tile's length, + # as the decoder may produce a different number of frames than expected + expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start + decoded_temporal_len = decoded_tile.shape[2] + + # Ensure we don't exceed the buffer or decoded tile bounds + actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset) + + chunk_coords = ( + slice(None), # batch + slice(None), # channels + slice(temporal_offset, temporal_offset + actual_temporal_len), + tile.out_coords[3], # height + tile.out_coords[4], # width + ) + + # Slice decoded_tile and mask to match the actual length we're writing + decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :] + mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask + + buffer[chunk_coords] += decoded_slice * mask_slice + weights[chunk_coords] += mask_slice + + return weights + + +def decode_video( + latent: torch.Tensor, + video_decoder: LTX2VideoDecoder, + tiling_config: TilingConfig | None = None, + generator: torch.Generator | None = None, +) -> Iterator[torch.Tensor]: + """ + Decode a video latent tensor with the given decoder. + Args: + latent: Tensor [c, f, h, w] + video_decoder: Decoder module. + tiling_config: Optional tiling settings. + generator: Optional random generator for deterministic decoding. + Yields: + Decoded chunk [f, h, w, c], uint8 in [0, 255]. + """ + + def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: + frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8) + frames = rearrange(frames[0], "c f h w -> f h w c") + return frames + + if tiling_config is not None: + for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator): + yield convert_to_uint8(frames) + else: + decoded_video = video_decoder(latent, generator=generator) + yield convert_to_uint8(decoded_video) + + +def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: + """ + Get the number of video chunks for a given number of frames and tiling configuration. + Args: + num_frames: Number of frames in the video. + tiling_config: Tiling configuration. + Returns: + Number of video chunks. 
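+    Example:
+        With tile_size_in_frames=80 and tile_overlap_in_frames=24 (illustrative values, not
+        checkpoint defaults), the stride between chunks is 56 frames, so num_frames=121 gives
+        ceil((121 - 1) / 56) = 3 chunks.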
+ """ + if not tiling_config or not tiling_config.temporal_config: + return 1 + cfg = tiling_config.temporal_config + frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames + return (num_frames - 1 + frame_stride - 1) // frame_stride + + +def split_in_spatial(size: int, overlap: int) -> SplitOperation: + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) + + return split + + +def split_in_temporal(size: int, overlap: int) -> SplitOperation: + non_causal_split = split_in_spatial(size, overlap) + + def split(dimension_size: int) -> DimensionIntervals: + if dimension_size <= size: + return DEFAULT_SPLIT_OPERATION(dimension_size) + intervals = non_causal_split(dimension_size) + starts = intervals.starts + starts[1:] = [s - 1 for s in starts[1:]] + left_ramps = intervals.left_ramps + left_ramps[1:] = [r + 1 for r in left_ramps[1:]] + return replace(intervals, starts=starts, left_ramps=left_ramps) + + return split + + +def to_mapping_operation( + map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor]], + scale: int, +) -> MappingOperation: + + def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]: + output_slices: list[slice] = [] + masks_1d: list[torch.Tensor | None] = [] + number_of_slices = len(intervals.starts) + for i in range(number_of_slices): + start = intervals.starts[i] + end = intervals.ends[i] + left_ramp = intervals.left_ramps[i] + right_ramp = intervals.right_ramps[i] + output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale) + output_slices.append(output_slice) + masks_1d.append(mask_1d) + return output_slices, masks_1d + + return map_op + + +def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp = 1 + (left_ramp - 1) * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True) + + +def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = end * scale + left_ramp = left_ramp * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False) diff --git a/diffsynth/utils/state_dict_converters/ltx2_dit.py b/diffsynth/utils/state_dict_converters/ltx2_dit.py new file mode 100644 index 000000000..baffb9a6d --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_dit.py @@ -0,0 +1,9 @@ +def LTXModelStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("model.diffusion_model."): + new_name = name.replace("model.diffusion_model.", "") + if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."): + continue + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_video_vae.py 
b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py new file mode 100644 index 000000000..132897de2 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_video_vae.py @@ -0,0 +1,22 @@ +def LTX2VideoEncoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.encoder."): + new_name = name.replace("vae.encoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2VideoDecoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vae.decoder."): + new_name = name.replace("vae.decoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("vae.per_channel_statistics."): + new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/test/load_model.py b/diffsynth/utils/test/load_model.py new file mode 100644 index 000000000..f21fa606b --- /dev/null +++ b/diffsynth/utils/test/load_model.py @@ -0,0 +1,22 @@ +import torch +from diffsynth.models.model_loader import ModelPool +from diffsynth.core.loader import ModelConfig + + +def test_model_loading(model_name, + model_config: ModelConfig, + vram_limit: float = None, + device="cpu", + torch_dtype=torch.bfloat16): + model_pool = ModelPool() + model_config.download_if_necessary() + vram_config = model_config.vram_config() + vram_config["computation_dtype"] = torch_dtype + vram_config["computation_device"] = device + model_pool.auto_load_model( + model_config.path, + vram_config=vram_config, + vram_limit=vram_limit, + clear_parameters=model_config.clear_parameters, + ) + return model_pool.fetch_model(model_name) From 8d303b47e924120a7ba41a97c79e46e46ed00836 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 28 Jan 2026 16:09:22 +0800 Subject: [PATCH 2/6] add audio_vae, audio_vocoder, text_encoder, connector and upsampler for ltx2 --- diffsynth/configs/model_configs.py | 42 +- diffsynth/models/ltx2_audio_vae.py | 1337 +++++++++++++++++ diffsynth/models/ltx2_common.py | 88 +- diffsynth/models/ltx2_text_encoder.py | 366 +++++ diffsynth/models/ltx2_upsampler.py | 313 ++++ .../state_dict_converters/ltx2_audio_vae.py | 32 + .../ltx2_text_encoder.py | 31 + diffsynth/utils/test/load_model.py | 22 - 8 files changed, 2207 insertions(+), 24 deletions(-) create mode 100644 diffsynth/models/ltx2_audio_vae.py create mode 100644 diffsynth/models/ltx2_text_encoder.py create mode 100644 diffsynth/models/ltx2_upsampler.py create mode 100644 diffsynth/utils/state_dict_converters/ltx2_audio_vae.py create mode 100644 diffsynth/utils/state_dict_converters/ltx2_text_encoder.py delete mode 100644 diffsynth/utils/test/load_model.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 7809ee2ca..f427fa6a3 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -613,6 +613,46 @@ "model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder", "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter", }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + 
"model_name": "ltx2_audio_vae_decoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_audio_vocoder", + "model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter", + }, + # { # not used currently + # # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + # "model_hash": "aca7b0bbf8415e9c98360750268915fc", + # "model_name": "ltx2_audio_vae_encoder", + # "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder", + # "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter", + # }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "aca7b0bbf8415e9c98360750268915fc", + "model_name": "ltx2_text_encoder_post_modules", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter", + }, + { + # Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors") + "model_hash": "33917f31c4a79196171154cca39f165e", + "model_name": "ltx2_text_encoder", + "model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors") + "model_hash": "c79c458c6e99e0e14d47e676761732d2", + "model_name": "ltx2_latent_upsampler", + "model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler", + }, ] - MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series diff --git a/diffsynth/models/ltx2_audio_vae.py b/diffsynth/models/ltx2_audio_vae.py new file mode 100644 index 000000000..dd4694166 --- /dev/null +++ b/diffsynth/models/ltx2_audio_vae.py @@ -0,0 +1,1337 @@ +from typing import Set, Tuple, Optional, List +from enum import Enum +import math +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer + +class AudioPatchifier(Patchifier): + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + shift: int = 0, + ): + """ + Patchifier tailored for spectrogram/audio latents. + Args: + patch_size: Number of mel bins combined into a single patch. This + controls the resolution along the frequency axis. + sample_rate: Original waveform sampling rate. Used to map latent + indices back to seconds so downstream consumers can align audio + and video cues. + hop_length: Window hop length used for the spectrogram. Determines + how many real-time samples separate two consecutive latent frames. 
+ audio_latent_downsample_factor: Ratio between spectrogram frames and + latent frames; compensates for additional downsampling inside the + VAE encoder. + is_causal: When True, timing is shifted to account for causal + receptive fields so timestamps do not peek into the future. + shift: Integer offset applied to the latent indices. Enables + constructing overlapping windows from the same latent sequence. + """ + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + self._patch_size = (1, patch_size, patch_size) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: AudioLatentShape) -> int: + return tgt_shape.frames + + def _get_audio_latent_time_in_sec( + self, + start_latent: int, + end_latent: int, + dtype: torch.dtype, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Converts latent indices into real-time seconds while honoring causal + offsets and the configured hop length. + Args: + start_latent: Inclusive start index inside the latent sequence. This + sets the first timestamp returned. + end_latent: Exclusive end index. Determines how many timestamps get + generated. + dtype: Floating-point dtype used for the returned tensor, allowing + callers to control precision. + device: Target device for the timestamp tensor. When omitted the + computation occurs on CPU to avoid surprising GPU allocations. + """ + if device is None: + device = torch.device("cpu") + + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + + if self.is_causal: + # Frame offset for causal alignment. + # The "+1" ensures the timestamp corresponds to the first sample that is fully available. + causal_offset = 1 + audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0) + + return audio_mel_frame * self.hop_length / self.sample_rate + + def _compute_audio_timings( + self, + batch_size: int, + num_steps: int, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame. + This helper method underpins `get_patch_grid_bounds` for the audio patchifier. + Args: + batch_size: Number of sequences to broadcast the timings over. + num_steps: Number of latent frames (time steps) to convert into timestamps. + device: Device on which the resulting tensor should reside. + """ + resolved_device = device + if resolved_device is None: + resolved_device = torch.device("cpu") + + start_timings = self._get_audio_latent_time_in_sec( + self.shift, + num_steps + self.shift, + torch.float32, + resolved_device, + ) + start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + end_timings = self._get_audio_latent_time_in_sec( + self.shift + 1, + num_steps + self.shift + 1, + torch.float32, + resolved_device, + ) + end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + return torch.stack([start_timings, end_timings], dim=-1) + + def patchify( + self, + audio_latents: torch.Tensor, + ) -> torch.Tensor: + """ + Flattens the audio latent tensor along time. Use `get_patch_grid_bounds` + to derive timestamps for each latent frame based on the configured hop + length and downsampling. + Args: + audio_latents: Latent tensor to patchify. 
+ Returns: + Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the + corresponding timing metadata when needed. + """ + audio_latents = einops.rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + return audio_latents + + def unpatchify( + self, + audio_latents: torch.Tensor, + output_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Restores the `(B, C, T, F)` spectrogram tensor from flattened patches. + Use `get_patch_grid_bounds` to recompute the timestamps that describe each + frame's position in real time. + Args: + audio_latents: Latent tensor to unpatchify. + output_shape: Shape of the unpatched output tensor. + Returns: + Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing + metadata associated with the restored latents. + """ + # audio_latents shape: (batch, time, freq * channels) + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=output_shape.channels, + f=output_shape.mel_bins, + ) + + return audio_latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the temporal bounds `[inclusive start, exclusive end)` for every + patch emitted by `patchify`. For audio this corresponds to timestamps in + seconds aligned with the original spectrogram grid. + The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where: + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores the `[start, end)` timestamps per patch + Args: + output_shape: Audio grid specification describing the number of time steps. + device: Target device for the returned tensor. + """ + if not isinstance(output_shape, AudioLatentShape): + raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates") + + return self._compute_audio_timings(output_shape.batch, output_shape.frames, device) + + +class AttentionType(Enum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class AttnBlock(torch.nn.Module): + def __init__( + self, + in_channels: int, + norm_type: NormType = NormType.GROUP, + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn( + 
in_channels: int, + attn_type: AttentionType = AttentionType.VANILLA, + norm_type: NormType = NormType.GROUP, +) -> torch.nn.Module: + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return torch.nn.Identity() + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class CausalityAxis(Enum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +class CausalConv2d(torch.nn.Module): + """ + A causal 2D convolution. + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = torch.nn.modules.utils._pair(kernel_size) + dilation = torch.nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + padding: tuple[int, int, int, int] | None = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis | None = None, +) -> torch.nn.Module: + """ + Create a 2D convolution layer that can be either causal or non-causal. + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. 
+ Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding="same", + ), + ] + ) + + self.convs2 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2, strict=True): + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv1(xt) + xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) + xt = conv2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)): + super(ResBlock2, self).__init__() + self.convs = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv in self.convs: + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv(xt) + x = xt + x + return x + + +class ResnetBlock(torch.nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = torch.nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + 
self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class Downsample(torch.nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. 
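+            # Non-causal path: plain 2x average pooling; kernel size equals stride, so no padding is required.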
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +def build_downsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, +) -> tuple[torch.nn.ModuleList, int]: + """Build the downsampling path with residual blocks, attention, and downsampling layers.""" + down_modules = torch.nn.ModuleList() + curr_res = resolution + in_ch_mult = (1, *tuple(ch_mult)) + block_in = ch + + for i_level in range(num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + down_modules.append(down) + + return down_modules, block_in + + +class Upsample(torch.nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. 
+ match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +def build_upsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, + initial_block_channels: int, +) -> tuple[torch.nn.ModuleList, int]: + """Build the upsampling path with residual blocks, attention, and upsampling layers.""" + up_modules = torch.nn.ModuleList() + block_in = initial_block_channels + curr_res = resolution // (2 ** (num_resolutions - 1)) + + for level in reversed(range(num_resolutions)): + stage = torch.nn.Module() + stage.block = torch.nn.ModuleList() + stage.attn = torch.nn.ModuleList() + block_out = ch * ch_mult[level] + + for _ in range(num_res_blocks + 1): + stage.block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + if level != 0: + stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under AudioVAE state_dict. 
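+    normalize applies (x - mean-of-means) / std-of-means per channel and un_normalize inverts it,
+    so un_normalize(normalize(x)) recovers x. The hyphenated buffer names ("std-of-means",
+    "mean-of-means") match the keys stored in the checkpoint and are therefore read via
+    get_buffer rather than attribute access.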
+ """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +def build_mid_block( + channels: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + add_attention: bool, +) -> torch.nn.Module: + """Build the middle block with two ResNet blocks and optional attention.""" + mid = torch.nn.Module() + mid.block_1 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity() + mid.block_2 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + return mid + + +def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor: + """Run features through the middle block.""" + features = mid.block_1(features, temb=None) + features = mid.attn_1(features) + return mid.block_2(features, temb=None) + + +class LTX2AudioEncoder(torch.nn.Module): + """ + Encoder that compresses audio spectrograms into latent representations. + The encoder uses a series of downsampling blocks with residual connections, + attention mechanisms, and configurable causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int = 128, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Set[int] = set(), + dropout: float = 0.0, + resamp_with_conv: bool = True, + in_channels: int = 2, + resolution: int = 256, + z_channels: int = 8, + double_z: bool = True, + attn_type: AttentionType = AttentionType.VANILLA, + mid_block_add_attention: bool = False, + norm_type: NormType = NormType.PIXEL, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + sample_rate: int = 16000, + mel_hop_length: int = 160, + n_fft: int = 1024, + is_causal: bool = True, + mel_bins: int = 64, + **_ignore_kwargs, + ) -> None: + """ + Initialize the Encoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + ch: Base number of feature channels used in the first convolution layer. + ch_mult: Multiplicative factors for the number of channels at each resolution level. + num_res_blocks: Number of residual blocks to use at each resolution level. + attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention. + resolution: Input spatial resolution of the spectrogram (height, width). + z_channels: Number of channels in the latent representation. + norm_type: Normalization layer type to use within the network (e.g., group, batch). + causality_axis: Axis along which convolutions should be causal (e.g., time axis). + sample_rate: Audio sample rate in Hz for the input signals. + mel_hop_length: Hop length used when computing the mel spectrogram. 
+ n_fft: FFT size used to compute the spectrogram. + mel_bins: Number of mel-frequency bins in the input spectrogram. + in_channels: Number of channels in the input spectrogram tensor. + double_z: If True, predict both mean and log-variance (doubling latent channels). + is_causal: If True, use causal convolutions suitable for streaming setups. + dropout: Dropout probability used in residual and mid blocks. + attn_type: Type of attention mechanism to use in attention blocks. + resamp_with_conv: If True, perform resolution changes using strided convolutions. + mid_block_add_attention: If True, add an attention block in the mid-level of the encoder. + """ + super().__init__() + + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + self.causality_axis = causality_axis + self.attn_type = attn_type + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + self.non_linearity = torch.nn.SiLU() + + self.down, block_in = build_downsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + ) + + self.mid = build_mid_block( + channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + + self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: + """ + Encode audio spectrogram into latent representations. 
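+        Latents are normalized with the dataset-level per-channel statistics before being
+        returned. With the default three-level configuration (ch_mult=(1, 2, 4)), time and
+        frequency are each downsampled by a factor of 4; e.g. a stereo spectrogram of shape
+        (1, 2, 480, 64) maps to latents of shape (1, 8, 120, 16) (illustrative shapes).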
+ Args: + spectrogram: Input spectrogram of shape (batch, channels, time, frequency) + Returns: + Encoded latent representation of shape (batch, channels, frames, mel_bins) + """ + h = self.conv_in(spectrogram) + h = self._run_downsampling_path(h) + h = run_mid_block(self.mid, h) + h = self._finalize_output(h) + + return self._normalize_latents(h) + + def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx in range(self.num_res_blocks): + h = stage.block[block_idx](h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != self.num_resolutions - 1: + h = stage.downsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + h = self.norm_out(h) + h = self.non_linearity(h) + return self.conv_out(h) + + def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor: + """ + Normalize encoder latents using per-channel statistics. + When the encoder is configured with ``double_z=True``, the final + convolution produces twice the number of latent channels, typically + interpreted as two concatenated tensors along the channel dimension + (e.g., mean and variance or other auxiliary parameters). + This method intentionally uses only the first half of the channels + (the "mean" component) as input to the patchifier and normalization + logic. The remaining channels are left unchanged by this method and + are expected to be consumed elsewhere in the VAE pipeline. + If ``double_z=False``, the encoder output already contains only the + mean latents and the chunking operation simply returns that tensor. + """ + means = torch.chunk(latent_output, 2, dim=1)[0] + latent_shape = AudioLatentShape( + batch=means.shape[0], + channels=means.shape[1], + frames=means.shape[2], + mel_bins=means.shape[3], + ) + latent_patched = self.patchifier.patchify(means) + latent_normalized = self.per_channel_statistics.normalize(latent_patched) + return self.patchifier.unpatchify(latent_normalized, latent_shape) + + +class LTX2AudioDecoder(torch.nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + The decoder mirrors the encoder structure with configurable channel multipliers, + attention resolutions, and causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int = 128, + out_ch: int = 2, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Set[int] = set(), + resolution: int=256, + z_channels: int=8, + norm_type: NormType = NormType.PIXEL, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + ) -> None: + """ + Initialize the Decoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions + - resolution, z_channels + - norm_type, causality_axis + """ + super().__init__() + + # Internal behavioural defaults that are not driven by the checkpoint. 
+ resamp_with_conv = True + attn_type = AttentionType.VANILLA + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.out_ch = out_ch + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.z_channels = z_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + self.attn_type = attn_type + + base_block_channels = ch * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, z_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = torch.nn.SiLU() + self.mid = build_mid_block( + channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + self.up, final_block_channels = build_upsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + initial_block_channels=base_block_channels, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + """ + Decode latent features back to audio spectrograms. 
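+        The latents are first de-normalized with the per-channel statistics, decoded through
+        the mid and upsampling blocks, and finally cropped or padded to the target spectrogram
+        shape derived from the latent frame count.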
+ Args: + sample: Encoded latent representation of shape (batch, channels, frames, mel_bins) + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + sample, target_shape = self._denormalize_latents(sample) + + h = self.conv_in(sample) + h = run_mid_block(self.mid, h) + h = self._run_upsampling_path(h) + h = self._finalize_output(h) + + return self._adjust_output_shape(h, target_shape) + + def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + if self.causality_axis != CausalityAxis.NONE: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + return sample, target_shape + + def _adjust_output_shape( + self, + decoded_output: torch.Tensor, + target_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Adjust output shape to match target dimensions for variable-length audio. + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: AudioLatentShape describing (batch, channels, time, mel bins) + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + h = block(h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != 0 and hasattr(stage, "upsample"): + h = stage.upsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return h + + h = 
self.norm_out(h) + h = self.non_linearity(h) + h = self.conv_out(h) + return torch.tanh(h) if self.tanh_out else h + + +class LTX2Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from Mel spectrograms. + Args: + resblock_kernel_sizes: List of kernel sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`. + upsample_rates: List of upsampling rates. + This value is read from the checkpoint at `config.vocoder.upsample_rates`. + upsample_kernel_sizes: List of kernel sizes for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`. + resblock_dilation_sizes: List of dilation sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`. + upsample_initial_channel: Initial number of channels for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`. + stereo: Whether to use stereo output. + This value is read from the checkpoint at `config.vocoder.stereo`. + resblock: Type of residual block to use. + This value is read from the checkpoint at `config.vocoder.resblock`. + output_sample_rate: Waveform sample rate. + This value is read from the checkpoint at `config.vocoder.output_sample_rate`. + """ + + def __init__( + self, + resblock_kernel_sizes: List[int] | None = [3, 7, 11], + upsample_rates: List[int] | None = [6, 5, 2, 2, 2], + upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4], + resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel: int = 1024, + stereo: bool = True, + resblock: str = "1", + output_sample_rate: int = 24000, + ): + super().__init__() + + # Initialize default values if not provided. Note that mutable default values are not supported. + if resblock_kernel_sizes is None: + resblock_kernel_sizes = [3, 7, 11] + if upsample_rates is None: + upsample_rates = [6, 5, 2, 2, 2] + if upsample_kernel_sizes is None: + upsample_kernel_sizes = [16, 15, 8, 4, 4] + if resblock_dilation_sizes is None: + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + + self.output_sample_rate = output_sample_rate + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 + self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)): + self.ups.append( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i, _ in enumerate(self.ups): + ch = upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True): + self.resblocks.append(resblock_class(ch, kernel_size, dilations)) + + out_channels = 2 if stereo else 1 + final_channels = upsample_initial_channel // (2**self.num_upsamples) + self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3) + + self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the vocoder. + Args: + x: Input Mel spectrogram tensor. 
Can be either: + - 3D: (batch_size, time, mel_bins) for mono + - 4D: (batch_size, 2, time, mel_bins) for stereo + Returns: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time) + + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = einops.rearrange(x, "b s c t -> b (s c) t") + + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + start = i * self.num_kernels + end = start + self.num_kernels + + # Evaluate all resblocks with the same input tensor so they can run + # independently (and thus in parallel on accelerator hardware) before + # aggregating their outputs via mean. + block_outputs = torch.stack( + [self.resblocks[idx](x) for idx in range(start, end)], + dim=0, + ) + + x = block_outputs.mean(dim=0) + + x = self.conv_post(F.leaky_relu(x)) + return torch.tanh(x) + + +def decode_audio(latent: torch.Tensor, audio_decoder: "LTX2AudioDecoder", vocoder: "LTX2Vocoder") -> torch.Tensor: + """ + Decode an audio latent representation using the provided audio decoder and vocoder. + Args: + latent: Input audio latent tensor. + audio_decoder: Model to decode the latent to waveform features. + vocoder: Model to convert decoded features to audio waveform. + Returns: + Decoded audio as a float tensor. + """ + decoded_audio = audio_decoder(latent) + decoded_audio = vocoder(decoded_audio).squeeze(0).float() + return decoded_audio diff --git a/diffsynth/models/ltx2_common.py b/diffsynth/models/ltx2_common.py index fbcae4b9e..25e46bc6e 100644 --- a/diffsynth/models/ltx2_common.py +++ b/diffsynth/models/ltx2_common.py @@ -1,7 +1,9 @@ from dataclasses import dataclass -from typing import NamedTuple +from typing import NamedTuple, Protocol, Tuple import torch from torch import nn +from enum import Enum + class VideoPixelShape(NamedTuple): """ @@ -180,6 +182,13 @@ def clone(self) -> "LatentState": ) +class NormType(Enum): + """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm).""" + + GROUP = "group" + PIXEL = "pixel" + + class PixelNorm(nn.Module): """ Per-pixel (per-location) RMS normalization layer. @@ -209,6 +218,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x / rms +def build_normalization_layer( + in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP +) -> nn.Module: + """ + Create a normalization layer based on the normalization type. + Args: + in_channels: Number of input channels + num_groups: Number of groups for group normalization + normtype: Type of normalization: "group" or "pixel" + Returns: + A normalization layer + """ + if normtype == NormType.GROUP: + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == NormType.PIXEL: + return PixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") + + def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor: """Root-mean-square (RMS) normalize `x` over its last dimension. 
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized @@ -251,3 +279,61 @@ def to_denoised( if isinstance(sigma, torch.Tensor): sigma = sigma.to(calc_dtype) return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) + + + +class Patchifier(Protocol): + """ + Protocol for patchifiers that convert latent tensors into patches and assemble them back. + """ + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + ... + """ + Convert latent tensors into flattened patch tokens. + Args: + latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. + """ + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: AudioLatentShape | VideoLatentShape, + ) -> torch.Tensor: + """ + Converts latent tensors between spatio-temporal formats and flattened sequence representations. + Args: + latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`. + output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or + VideoLatentShape. + Returns: + Dense latent tensor restored from the flattened representation. + """ + + @property + def patch_size(self) -> Tuple[int, int, int]: + ... + """ + Returns the patch size as a tuple of (temporal, height, width) dimensions + """ + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: torch.device | None = None, + ) -> torch.Tensor: + ... + """ + Compute metadata describing where each latent patch resides within the + grid specified by `output_shape`. + Args: + output_shape: Target grid layout for the patches. + device: Target device for the returned tensor. + Returns: + Tensor containing patch coordinate metadata such as spatial or temporal intervals. 
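+        Example (illustrative sketch; `VideoLatentPatchifier` from ltx2_video_vae.py is one concrete
+        implementation of this protocol, and the shape values below are arbitrary):
+            >>> patchifier = VideoLatentPatchifier(patch_size=1)
+            >>> shape = VideoLatentShape(batch=1, channels=128, frames=3, height=16, width=24)
+            >>> patchifier.get_patch_grid_bounds(shape).shape  # (batch, axes, num_patches, [start, end))
+            torch.Size([1, 3, 1152, 2])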
+ """ diff --git a/diffsynth/models/ltx2_text_encoder.py b/diffsynth/models/ltx2_text_encoder.py new file mode 100644 index 000000000..535d5e578 --- /dev/null +++ b/diffsynth/models/ltx2_text_encoder.py @@ -0,0 +1,366 @@ +import torch +from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer +from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention, + FeedForward) +from .ltx2_common import rms_norm + + +class LTX2TextEncoder(Gemma3ForConditionalGeneration): + def __init__(self): + config = Gemma3Config( + **{ + "architectures": ["Gemma3ForConditionalGeneration"], + "boi_token_index": 255999, + "dtype": "bfloat16", + "eoi_token_index": 256000, + "eos_token_id": [1, 106], + "image_token_index": 262144, + "initializer_range": 0.02, + "mm_tokens_per_image": 256, + "model_type": "gemma3", + "text_config": { + "_sliding_window_pattern": 6, + "attention_bias": False, + "attention_dropout": 0.0, + "attn_logit_softcapping": None, + "cache_implementation": "hybrid", + "dtype": "bfloat16", + "final_logit_softcapping": None, + "head_dim": 256, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 3840, + "initializer_range": 0.02, + "intermediate_size": 15360, + "layer_types": [ + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", + "sliding_attention", "sliding_attention", "sliding_attention", "full_attention" + ], + "max_position_embeddings": 131072, + "model_type": "gemma3_text", + "num_attention_heads": 16, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "query_pre_attn_scalar": 256, + "rms_norm_eps": 1e-06, + "rope_local_base_freq": 10000, + "rope_scaling": { + "factor": 8.0, + "rope_type": "linear" + }, + "rope_theta": 1000000, + "sliding_window": 1024, + "sliding_window_pattern": 6, + "use_bidirectional_attention": False, + "use_cache": True, + "vocab_size": 262208 + }, + "transformers_version": "4.57.3", + "vision_config": { + "attention_dropout": 0.0, + "dtype": "bfloat16", + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 896, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14, + "vision_use_head": False + } + }) + super().__init__(config) + + +class LTXVGemmaTokenizer: + """ + Tokenizer wrapper for Gemma models compatible with LTXV processes. + This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders, + ensuring correct settings and output formatting for downstream consumption. 
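+    Example (illustrative; the tokenizer directory below is a placeholder path, not a bundled checkpoint):
+        >>> tokenizer = LTXVGemmaTokenizer("path/to/gemma-3-tokenizer", max_length=1024)
+        >>> tokens = tokenizer.tokenize_with_weights("a red fox walking through snow")
+        >>> len(tokens["gemma"])  # always padded/truncated to max_length
+        1024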
+ """ + + def __init__(self, tokenizer_path: str, max_length: int = 1024): + """ + Initialize the tokenizer. + Args: + tokenizer_path (str): Path to the pretrained tokenizer files or model directory. + max_length (int, optional): Max sequence length for encoding. Defaults to 256. + """ + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, local_files_only=True, model_max_length=max_length + ) + # Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much. + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.max_length = max_length + + def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]: + """ + Tokenize the given text and return token IDs and attention weights. + Args: + text (str): The input string to tokenize. + return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples. + If False (default), omits the indices. + Returns: + dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]: + A dictionary with a "gemma" key mapping to: + - a list of (token_id, attention_mask) tuples if return_word_ids is False; + - a list of (token_id, attention_mask, index) tuples if return_word_ids is True. + Example: + >>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8) + >>> tokenizer.tokenize_with_weights("hello world") + {'gemma': [(1234, 1), (5678, 1), (2, 0), ...]} + """ + text = text.strip() + encoded = self.tokenizer( + text, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_tensors="pt", + ) + input_ids = encoded.input_ids + attention_mask = encoded.attention_mask + tuples = [ + (token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True)) + ] + out = {"gemma": tuples} + + if not return_word_ids: + # Return only (token_id, attention_mask) pairs, omitting token position + out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()} + + return out + + +class GemmaFeaturesExtractorProjLinear(torch.nn.Module): + """ + Feature extractor module for Gemma models. + This module applies a single linear projection to the input tensor. + It expects a flattened feature tensor of shape (batch_size, 3840*49). + The linear layer maps this to a (batch_size, 3840) embedding. + Attributes: + aggregate_embed (torch.nn.Linear): Linear projection layer. + """ + + def __init__(self) -> None: + """ + Initialize the GemmaFeaturesExtractorProjLinear module. + The input dimension is expected to be 3840 * 49, and the output is 3840. + """ + super().__init__() + self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the feature extractor. + Args: + x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49). + Returns: + torch.Tensor: Output tensor of shape (batch_size, 3840). 
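+        Example (illustrative shapes only; the layer is randomly initialized here):
+            >>> proj = GemmaFeaturesExtractorProjLinear()
+            >>> feats = torch.zeros(2, 3840 * 49)
+            >>> proj(feats).shape
+            torch.Size([2, 3840])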
+ """ + return self.aggregate_embed(x) + + +class _BasicTransformerBlock1D(torch.nn.Module): + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + ): + super().__init__() + + self.attn1 = Attention( + query_dim=dim, + heads=heads, + dim_head=dim_head, + rope_type=rope_type, + ) + + self.ff = FeedForward( + dim, + dim_out=dim, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = rms_norm(hidden_states) + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(torch.nn.Module): + """ + Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or + other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can + substitute padded positions with learnable registers. The module is highly configurable for head size, number of + layers, and register usage. + Args: + attention_head_dim (int): Dimension of each attention head (default=128). + num_attention_heads (int): Number of attention heads (default=30). + num_layers (int): Number of transformer layers (default=2). + positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0). + positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]). + causal_temporal_positioning (bool): If True, uses causal attention (default=False). + num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables + register replacement. (default=128) + rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE). + double_precision_rope (bool): Use double precision rope calculation (default=False). 
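+    Example (illustrative; only shows the derived token width, no forward pass is run):
+        >>> connector = Embeddings1DConnector(attention_head_dim=128, num_attention_heads=30, num_layers=2)
+        >>> connector.inner_dim  # feature width expected by forward(), i.e. 30 * 128
+        3840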
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + attention_head_dim: int = 128, + num_attention_heads: int = 30, + num_layers: int = 2, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = [4096], + causal_temporal_positioning: bool = False, + num_learnable_registers: int | None = 128, + rope_type: LTXRopeType = LTXRopeType.SPLIT, + double_precision_rope: bool = True, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = ( + positional_embedding_max_pos if positional_embedding_max_pos is not None else [1] + ) + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = torch.nn.ModuleList( + [ + _BasicTransformerBlock1D( + dim=self.inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = torch.nn.Parameter( + torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0 + ) + + def _replace_padded_with_learnable_registers( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[1] % self.num_learnable_registers == 0, ( + f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers " + f"{self.num_learnable_registers}." + ) + + num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers + learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1)) + attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int() + + non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :] + non_zero_nums = non_zero_hidden_states.shape[1] + pad_length = hidden_states.shape[1] - non_zero_nums + adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0) + flipped_mask = torch.flip(attention_mask_binary, dims=[1]) + hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers + + attention_mask = torch.full_like( + attention_mask, + 0.0, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + return hidden_states, attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of Embeddings1DConnector. + Args: + hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]). + attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states). + Returns: + tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask. 
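+        Note (illustrative): with the default num_learnable_registers=128, the sequence length must be a
+        multiple of 128 (e.g. the 1024-token Gemma context produced by LTXVGemmaTokenizer), because padded
+        positions are replaced by tiled copies of the learnable registers before the 1D transformer blocks.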
+ """ + if self.num_learnable_registers: + hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask) + + indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device) + indices_grid = indices_grid[None, None, :] + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + freqs_cis = precompute_freqs_cis( + indices_grid=indices_grid, + dim=self.inner_dim, + out_dtype=hidden_states.dtype, + theta=self.positional_embedding_theta, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + + for block in self.transformer_1d_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis) + + hidden_states = rms_norm(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextEncoderPostModules(torch.nn.Module): + def __init__(self,): + super().__init__() + self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear() + self.embeddings_connector = Embeddings1DConnector() + self.audio_embeddings_connector = Embeddings1DConnector() \ No newline at end of file diff --git a/diffsynth/models/ltx2_upsampler.py b/diffsynth/models/ltx2_upsampler.py new file mode 100644 index 000000000..862ca14bc --- /dev/null +++ b/diffsynth/models/ltx2_upsampler.py @@ -0,0 +1,313 @@ +import math +from typing import Optional, Tuple +import torch +from einops import rearrange +import torch.nn.functional as F +from .ltx2_video_vae import LTX2VideoEncoder + +class PixelShuffleND(torch.nn.Module): + """ + N-dimensional pixel shuffle operation for upsampling tensors. + Args: + dims (int): Number of dimensions to apply pixel shuffle to. + - 1: Temporal (e.g., frames) + - 2: Spatial (e.g., height and width) + - 3: Spatiotemporal (e.g., depth, height, width) + upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension. + For dims=1, only the first value is used. + For dims=2, the first two values are used. + For dims=3, all three values are used. + The input tensor is rearranged so that the channel dimension is split into + smaller channels and upscaling factors, and the upscaling factors are moved + into the corresponding spatial/temporal dimensions. + Note: + This operation is equivalent to the patchifier operation in for the models. Consider + using this class instead. + """ + + def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + else: + raise ValueError(f"Unsupported dims: {self.dims}") + + +class ResBlock(torch.nn.Module): + """ + Residual block with two convolutional layers, group normalization, and SiLU activation. + Args: + channels (int): Number of input and output channels. 
+ mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels` + if not specified. + dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3. + """ + + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + assert dims in (2, 3) + assert isinstance(stride, int) + assert stride >= 1 + assert kernel_size >= 3 + assert kernel_size % 2 == 1 + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. + k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + return self._apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self._apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor: + c = x2d.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + return x2d + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}") + return mapping[float(scale)] + + +class SpatialRationalResampler(torch.nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + Args: + mid_channels (`int`): Number of intermediate channels for the convolution layer + scale (`float`): Spatial scaling factor. 
Supported values are: + - 0.75: Downsample by 3/4 (reduce spatial size) + - 1.5: Upsample by 3/2 (increase spatial size) + - 2.0: Upsample by 2x (double spatial size) + - 4.0: Upsample by 4x (quadruple spatial size) + Any other value will raise a ValueError. + """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class LTX2LatentUpsampler(torch.nn.Module): + """ + Model to upsample VAE latents spatially and/or temporally. + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + spatial_scale (`float`): Scale factor for spatial upsampling + rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling + """ + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = conv(mid_channels, in_channels, 
kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + # remove the first frame after upsampling. + # This is done because the first frame encodes one pixel frame. + x = x[:, :, 1:, :, :] + elif isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + +def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor: + """ + Apply upsampling to the latent representation using the provided upsampler, + with normalization and un-normalization based on the video encoder's per-channel statistics. + Args: + latent: Input latent tensor of shape [B, C, F, H, W]. + video_encoder: VideoEncoder with per_channel_statistics for normalization. + upsampler: LTX2LatentUpsampler module to perform upsampling. + Returns: + torch.Tensor: Upsampled and re-normalized latent tensor. + """ + latent = video_encoder.per_channel_statistics.un_normalize(latent) + latent = upsampler(latent) + latent = video_encoder.per_channel_statistics.normalize(latent) + return latent diff --git a/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py new file mode 100644 index 000000000..c9bb66d25 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/ltx2_audio_vae.py @@ -0,0 +1,32 @@ +def LTX2AudioEncoderStateDictConverter(state_dict): + # Not used + state_dict_ = {} + for name in state_dict: + if name.startswith("audio_vae.encoder."): + new_name = name.replace("audio_vae.encoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2AudioDecoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("audio_vae.decoder."): + new_name = name.replace("audio_vae.decoder.", "") + state_dict_[new_name] = state_dict[name] + elif name.startswith("audio_vae.per_channel_statistics."): + new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.") + state_dict_[new_name] = state_dict[name] + return state_dict_ + + +def LTX2VocoderStateDictConverter(state_dict): + state_dict_ = {} + for name in state_dict: + if name.startswith("vocoder."): + new_name = name.replace("vocoder.", "") + state_dict_[new_name] = state_dict[name] + return state_dict_ diff --git a/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py b/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py new file mode 100644 index 000000000..b7e528f7c --- /dev/null 
+++ b/diffsynth/utils/state_dict_converters/ltx2_text_encoder.py @@ -0,0 +1,31 @@ +def LTX2TextEncoderStateDictConverter(state_dict): + state_dict_ = {} + for key in state_dict: + if key.startswith("language_model.model."): + new_key = key.replace("language_model.model.", "model.language_model.") + elif key.startswith("vision_tower."): + new_key = key.replace("vision_tower.", "model.vision_tower.") + elif key.startswith("multi_modal_projector."): + new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.") + elif key.startswith("language_model.lm_head."): + new_key = key.replace("language_model.lm_head.", "lm_head.") + else: + continue + state_dict_[new_key] = state_dict[key] + state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight") + return state_dict_ + + +def LTX2TextEncoderPostModulesStateDictConverter(state_dict): + state_dict_ = {} + for key in state_dict: + if key.startswith("text_embedding_projection."): + new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.") + elif key.startswith("model.diffusion_model.video_embeddings_connector."): + new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.") + elif key.startswith("model.diffusion_model.audio_embeddings_connector."): + new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.") + else: + continue + state_dict_[new_key] = state_dict[key] + return state_dict_ diff --git a/diffsynth/utils/test/load_model.py b/diffsynth/utils/test/load_model.py deleted file mode 100644 index f21fa606b..000000000 --- a/diffsynth/utils/test/load_model.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -from diffsynth.models.model_loader import ModelPool -from diffsynth.core.loader import ModelConfig - - -def test_model_loading(model_name, - model_config: ModelConfig, - vram_limit: float = None, - device="cpu", - torch_dtype=torch.bfloat16): - model_pool = ModelPool() - model_config.download_if_necessary() - vram_config = model_config.vram_config() - vram_config["computation_dtype"] = torch_dtype - vram_config["computation_device"] = device - model_pool.auto_load_model( - model_config.path, - vram_config=vram_config, - vram_limit=vram_limit, - clear_parameters=model_config.clear_parameters, - ) - return model_pool.fetch_model(model_name) From b1a2782ad7575d70e447c70c59c8c612f6cce80e Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 29 Jan 2026 16:30:15 +0800 Subject: [PATCH 3/6] support ltx2 one-stage pipeline --- diffsynth/diffusion/flow_match.py | 28 +- diffsynth/models/ltx2_common.py | 32 ++ diffsynth/models/ltx2_dit.py | 10 +- diffsynth/models/ltx2_video_vae.py | 339 ++++++++++++- diffsynth/pipelines/ltx2_audio_video.py | 451 ++++++++++++++++++ diffsynth/utils/data/media_io.py | 106 ++++ .../model_inference/LTX-2-T2AV-OneStage.py | 46 ++ 7 files changed, 1005 insertions(+), 7 deletions(-) create mode 100644 diffsynth/pipelines/ltx2_audio_video.py create mode 100644 diffsynth/utils/data/media_io.py create mode 100644 examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 2d6b3676c..aed8a3527 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -4,13 +4,14 @@ class FlowMatchScheduler(): - def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"): + def __init__(self, template: Literal["FLUX.1", "Wan", 
"Qwen-Image", "FLUX.2", "Z-Image", "LTX-2"] = "FLUX.1"): self.set_timesteps_fn = { "FLUX.1": FlowMatchScheduler.set_timesteps_flux, "Wan": FlowMatchScheduler.set_timesteps_wan, "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image, "FLUX.2": FlowMatchScheduler.set_timesteps_flux2, "Z-Image": FlowMatchScheduler.set_timesteps_z_image, + "LTX-2": FlowMatchScheduler.set_timesteps_ltx2, }.get(template, FlowMatchScheduler.set_timesteps_flux) self.num_train_timesteps = 1000 @@ -121,7 +122,30 @@ def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift timestep_id = torch.argmin((timesteps - timestep).abs()) timesteps[timestep_id] = timestep return sigmas, timesteps - + + @staticmethod + def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, stretch=True, terminal=0.1): + dynamic_shift_len = dynamic_shift_len or 4096 + sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( + image_seq_len=dynamic_shift_len, + base_seq_len=1024, + max_seq_len=4096, + base_shift=0.95, + max_shift=2.05, + ) + num_train_timesteps = 1000 + sigma_min = 0.0 + sigma_max = 1.0 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1 - terminal) + sigmas = 1.0 - (one_minus_z / scale_factor) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + def set_training_weight(self): steps = 1000 x = self.timesteps diff --git a/diffsynth/models/ltx2_common.py b/diffsynth/models/ltx2_common.py index 25e46bc6e..a06ccd6b0 100644 --- a/diffsynth/models/ltx2_common.py +++ b/diffsynth/models/ltx2_common.py @@ -337,3 +337,35 @@ def get_patch_grid_bounds( Returns: Tensor containing patch coordinate metadata such as spatial or temporal intervals. """ + + +def get_pixel_coords( + latent_coords: torch.Tensor, + scale_factors: SpatioTemporalScaleFactors, + causal_fix: bool = False, +) -> torch.Tensor: + """ + Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling + each axis (frame/time, height, width) with the corresponding VAE downsampling factors. + Optionally compensate for causal encoding that keeps the first frame at unit temporal scale. + Args: + latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`. + scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied + per axis. + causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs + that treat frame zero differently still yield non-negative timestamps. + """ + # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout. + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width) + scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape) + + # Apply per-axis scaling to convert latent bounds into pixel-space coordinates. + pixel_coords = latent_coords * scale_tensor + + if causal_fix: + # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`. + # Shift and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] 
+ 1 - scale_factors[0]).clamp(min=0) + + return pixel_coords diff --git a/diffsynth/models/ltx2_dit.py b/diffsynth/models/ltx2_dit.py index a4cfd5c34..43a27f642 100644 --- a/diffsynth/models/ltx2_dit.py +++ b/diffsynth/models/ltx2_dit.py @@ -514,7 +514,7 @@ def forward( out_pattern="b s n d", attn_mask=mask ) - + # Reshape back to original format out = out.flatten(2, 3) return self.to_out(out) @@ -1398,7 +1398,7 @@ def _process_output( x = proj_out(x) return x - def forward( + def _forward( self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -1440,3 +1440,9 @@ def forward( else None ) return vx, ax + + def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps): + video = Modality(video_latents, video_timesteps, video_positions, video_context) + audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) + vx, ax = self._forward(video=video, audio=audio, perturbations=None) + return vx, ax diff --git a/diffsynth/models/ltx2_video_vae.py b/diffsynth/models/ltx2_video_vae.py index 34db4fe76..0bd13314c 100644 --- a/diffsynth/models/ltx2_video_vae.py +++ b/diffsynth/models/ltx2_video_vae.py @@ -1,5 +1,6 @@ import itertools import math +import einops from dataclasses import replace, dataclass from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional import torch @@ -7,9 +8,138 @@ from torch import nn from torch.nn import functional as F from enum import Enum -from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape +from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings +VAE_SPATIAL_FACTOR = 32 +VAE_TEMPORAL_FACTOR = 8 + + +class VideoLatentPatchifier(Patchifier): + def __init__(self, patch_size: int): + # Patch sizes for video latents. + self._patch_size = ( + 1, # temporal dimension + patch_size, # height dimension + patch_size, # width dimension + ) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: VideoLatentShape) -> int: + return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size) + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + + return latents + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: VideoLatentShape, + ) -> torch.Tensor: + assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier" + + patch_grid_frames = output_shape.frames // self._patch_size[0] + patch_grid_height = output_shape.height // self._patch_size[1] + patch_grid_width = output_shape.width // self._patch_size[2] + + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=patch_grid_frames, + h=patch_grid_height, + w=patch_grid_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + + return latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the per-dimension bounds [inclusive start, exclusive end) for every + patch produced by `patchify`. 
The bounds are expressed in the original + video grid coordinates: frame/time, height, and width. + The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where: + - axis 1 (size 3) enumerates (frame/time, height, width) dimensions + - axis 3 (size 2) stores `[start, end)` indices within each dimension + Args: + output_shape: Video grid description containing frames, height, and width. + device: Device of the latent tensor. + """ + if not isinstance(output_shape, VideoLatentShape): + raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates") + + frames = output_shape.frames + height = output_shape.height + width = output_shape.width + batch_size = output_shape.batch + + # Validate inputs to ensure positive dimensions + assert frames > 0, f"frames must be positive, got {frames}" + assert height > 0, f"height must be positive, got {height}" + assert width > 0, f"width must be positive, got {width}" + assert batch_size > 0, f"batch_size must be positive, got {batch_size}" + + # Generate grid coordinates for each dimension (frame, height, width) + # We use torch.arange to create the starting coordinates for each patch. + # indexing='ij' ensures the dimensions are in the order (frame, height, width). + grid_coords = torch.meshgrid( + torch.arange(start=0, end=frames, step=self._patch_size[0], device=device), + torch.arange(start=0, end=height, step=self._patch_size[1], device=device), + torch.arange(start=0, end=width, step=self._patch_size[2], device=device), + indexing="ij", + ) + + # Stack the grid coordinates to create the start coordinates tensor. + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_starts = torch.stack(grid_coords, dim=0) + + # Create a tensor containing the size of a single patch: + # (frame_patch_size, height_patch_size, width_patch_size). + # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates. + patch_size_delta = torch.tensor( + self._patch_size, + device=patch_starts.device, + dtype=patch_starts.dtype, + ).view(3, 1, 1, 1) + + # Calculate end coordinates: start + patch_size + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_ends = patch_starts + patch_size_delta + + # Stack start and end coordinates together along the last dimension + # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end] + latent_coords = torch.stack((patch_starts, patch_ends), dim=-1) + + # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence. + # Final Shape: (batch_size, 3, num_patches, 2) + latent_coords = einops.repeat( + latent_coords, + "c f h w bounds -> b c (f h w) bounds", + b=batch_size, + bounds=2, + ) + + return latent_coords + class NormLayerType(Enum): GROUP_NORM = "group_norm" @@ -1339,6 +1469,185 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor: return self.per_channel_statistics.normalize(means) + def tiled_encode_video( + self, + video: torch.Tensor, + tile_size: int = 512, + tile_overlap: int = 128, + ) -> torch.Tensor: + """Encode video using spatial tiling for memory efficiency. + Splits the video into overlapping spatial tiles, encodes each tile separately, + and blends the results using linear feathering in the overlap regions. 
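+        Worked example (illustrative; any resolution that is a multiple of 32 behaves the same way):
+        a 121-frame 704x1216 input with tile_size=512 and tile_overlap=128 is split into overlapping
+        512x512 crops whose latents are feather-blended into one [B, 128, 16, 22, 38] tensor
+        (704 / 32 = 22, 1216 / 32 = 38, 1 + (121 - 1) / 8 = 16 frames).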
+ Args: + video: Input tensor of shape [B, C, F, H, W] + tile_size: Tile size in pixels (must be divisible by 32) + tile_overlap: Overlap between tiles in pixels (must be divisible by 32) + Returns: + Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent] + """ + batch, _channels, frames, height, width = video.shape + device = video.device + dtype = video.dtype + + # Validate tile parameters + if tile_size % VAE_SPATIAL_FACTOR != 0: + raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}") + if tile_overlap % VAE_SPATIAL_FACTOR != 0: + raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}") + if tile_overlap >= tile_size: + raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})") + + # If video fits in a single tile, use regular encoding + if height <= tile_size and width <= tile_size: + return self.forward(video) + + # Calculate output dimensions + # VAE compresses: H -> H/32, W -> W/32, F -> 1 + (F-1)/8 + output_height = height // VAE_SPATIAL_FACTOR + output_width = width // VAE_SPATIAL_FACTOR + output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR + + # Latent channels (128 for LTX-2) + # Get from a small test encode or assume 128 + latent_channels = 128 + + # Initialize output and weight tensors + output = torch.zeros( + (batch, latent_channels, output_frames, output_height, output_width), + device=device, + dtype=dtype, + ) + weights = torch.zeros( + (batch, 1, output_frames, output_height, output_width), + device=device, + dtype=dtype, + ) + + # Calculate tile positions with overlap + # Step size is tile_size - tile_overlap + step_h = tile_size - tile_overlap + step_w = tile_size - tile_overlap + + h_positions = list(range(0, max(1, height - tile_overlap), step_h)) + w_positions = list(range(0, max(1, width - tile_overlap), step_w)) + + # Ensure last tile covers the edge + if h_positions[-1] + tile_size < height: + h_positions.append(height - tile_size) + if w_positions[-1] + tile_size < width: + w_positions.append(width - tile_size) + + # Remove duplicates and sort + h_positions = sorted(set(h_positions)) + w_positions = sorted(set(w_positions)) + + # Overlap in latent space + overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR + overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR + + # Process each tile + for h_pos in h_positions: + for w_pos in w_positions: + # Calculate tile boundaries in input space + h_start = max(0, h_pos) + w_start = max(0, w_pos) + h_end = min(h_start + tile_size, height) + w_end = min(w_start + tile_size, width) + + # Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR + tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR + tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR + + if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR: + continue + + # Adjust end positions + h_end = h_start + tile_h + w_end = w_start + tile_w + + # Extract tile + tile = video[:, :, :, h_start:h_end, w_start:w_end] + + # Encode tile + encoded_tile = self.forward(tile) + + # Get actual encoded dimensions + _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape + + # Calculate output positions + out_h_start = h_start // VAE_SPATIAL_FACTOR + out_w_start = w_start // VAE_SPATIAL_FACTOR + out_h_end = min(out_h_start + tile_out_height, output_height) + out_w_end = min(out_w_start + tile_out_width, output_width) + + # Trim encoded tile if necessary + actual_tile_h = out_h_end - out_h_start + 
actual_tile_w = out_w_end - out_w_start + encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w] + + # Create blending mask with linear feathering at edges + mask = torch.ones( + (1, 1, tile_out_frames, actual_tile_h, actual_tile_w), + device=device, + dtype=dtype, + ) + + # Apply feathering at edges (linear blend in overlap regions) + # Left edge + if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h: + fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1) + + # Right edge (bottom in height dimension) + if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h: + fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1) + + # Top edge (left in width dimension) + if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w: + fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1) + + # Bottom edge (right in width dimension) + if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w: + fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] + mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1) + + # Accumulate weighted results + output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask + weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask + + # Normalize by weights (avoid division by zero) + output = output / (weights + 1e-8) + + return output + + def encode( + self, + video: torch.Tensor, + tiled=False, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + **kwargs, + ) -> torch.Tensor: + device = next(self.parameters()).device + vae_dtype = next(self.parameters()).dtype + if video.ndim == 4: + video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W] + video = video.to(device=device, dtype=vae_dtype) + # Choose encoding method based on tiling flag + if tiled: + latents = self.tiled_encode_video( + video=video, + tile_size=tile_size_in_pixels, + tile_overlap=tile_overlap_in_pixels, + ) + else: + # Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W'] + latents = self.forward(video) + return latents + + def _make_decoder_block( block_name: str, block_config: dict[str, Any], @@ -1850,6 +2159,30 @@ def _accumulate_temporal_group_into_buffer( return weights + def decode( + self, + latent: torch.Tensor, + tiled=False, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + tile_size_in_frames: Optional[int] = 128, + tile_overlap_in_frames: Optional[int] = 24, + ) -> torch.Tensor: + if tiled: + tiling_config = TilingConfig( + spatial_config=SpatialTilingConfig( + tile_size_in_pixels=tile_size_in_pixels, + tile_overlap_in_pixels=tile_overlap_in_pixels, + ), + temporal_config=TemporalTilingConfig( + tile_size_in_frames=tile_size_in_frames, + tile_overlap_in_frames=tile_overlap_in_frames, + ), + ) + tiles = self.tiled_decode(latent, tiling_config) + return torch.cat(list(tiles), dim=2) + else: + return self.forward(latent) def decode_video( latent: torch.Tensor, @@ -1875,10 +2208,10 @@ def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: if tiling_config is not None: for frames in video_decoder.tiled_decode(latent, tiling_config, 
generator=generator): - yield convert_to_uint8(frames) + return convert_to_uint8(frames) else: decoded_video = video_decoder(latent, generator=generator) - yield convert_to_uint8(decoded_video) + return convert_to_uint8(decoded_video) def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py new file mode 100644 index 000000000..10812881b --- /dev/null +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -0,0 +1,451 @@ +import torch, types +import numpy as np +from PIL import Image +from einops import repeat +from typing import Optional, Union +from einops import rearrange +import numpy as np +from PIL import Image +from tqdm import tqdm +from typing import Optional +from typing_extensions import Literal +from transformers import AutoImageProcessor, Gemma3Processor +import einops + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion import FlowMatchScheduler +from ..core import ModelConfig, gradient_checkpoint_forward +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer +from ..models.ltx2_dit import LTXModel +from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier +from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier +from ..models.ltx2_upsampler import LTX2LatentUpsampler +from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS + + +class LTX2AudioVideoPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): + super().__init__( + device=device, + torch_dtype=torch_dtype, + height_division_factor=32, + width_division_factor=32, + time_division_factor=8, + time_division_remainder=1, + ) + self.scheduler = FlowMatchScheduler("LTX-2") + self.text_encoder: LTX2TextEncoder = None + self.tokenizer: LTXVGemmaTokenizer = None + self.processor: Gemma3Processor = None + self.text_encoder_post_modules: LTX2TextEncoderPostModules = None + self.dit: LTXModel = None + self.video_vae_encoder: LTX2VideoEncoder = None + self.video_vae_decoder: LTX2VideoDecoder = None + self.audio_vae_encoder: LTX2AudioEncoder = None + self.audio_vae_decoder: LTX2AudioDecoder = None + self.audio_vocoder: LTX2Vocoder = None + self.upsampler: LTX2LatentUpsampler = None + + self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1) + self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1) + + self.in_iteration_models = ("dit",) + self.units = [ + LTX2AudioVideoUnit_PipelineChecker(), + LTX2AudioVideoUnit_ShapeChecker(), + LTX2AudioVideoUnit_PromptEmbedder(), + LTX2AudioVideoUnit_NoiseInitializer(), + LTX2AudioVideoUnit_InputVideoEmbedder(), + ] + self.post_units = [ + LTX2AudioVideoPostUnit_UnPatchifier(), + ] + self.model_fn = model_fn_ltx2 + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + vram_limit: float = None, + ): + # Initialize pipeline + pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) 
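+        # Illustrative call (the model id and file pattern below are assumptions; adjust them to the
+        # checkpoints you actually use):
+        #   pipe = LTX2AudioVideoPipeline.from_pretrained(
+        #       torch_dtype=torch.bfloat16,
+        #       device="cuda",
+        #       model_configs=[ModelConfig(model_id="Lightricks/LTX-2",
+        #                                  origin_file_pattern="ltx-2-19b-dev.safetensors")],
+        #   )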
+ + # Fetch models + pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder") + tokenizer_config.download_if_necessary() + pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path) + image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True) + pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer) + + pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules") + pipe.dit = model_pool.fetch_model("ltx2_dit") + pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder") + pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder") + pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder") + pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder") + pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") + # Optional + # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder") + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: Optional[str] = "", + # Image-to-video + input_image: Optional[Image.Image] = None, + denoising_strength: float = 1.0, + # Randomness + seed: Optional[int] = None, + rand_device: Optional[str] = "cpu", + # Shape + height: Optional[int] = 512, + width: Optional[int] = 768, + num_frames=121, + # Classifier-free guidance + cfg_scale: Optional[float] = 3.0, + cfg_merge: Optional[bool] = False, + # Scheduler + num_inference_steps: Optional[int] = 40, + # VAE tiling + tiled: Optional[bool] = True, + tile_size_in_pixels: Optional[int] = 512, + tile_overlap_in_pixels: Optional[int] = 128, + tile_size_in_frames: Optional[int] = 128, + tile_overlap_in_frames: Optional[int] = 24, + # Two-Stage Pipeline + use_two_stage: Optional[bool] = True, + # progress_bar + progress_bar_cmd=tqdm, + ): + # Scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) + + # Inputs + inputs_posi = { + "prompt": prompt, + } + inputs_nega = { + "negative_prompt": negative_prompt, + } + inputs_shared = { + "input_image": input_image, + "seed": seed, "rand_device": rand_device, + "height": height, "width": width, "num_frames": num_frames, + "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, + "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, + "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, + "use_two_stage": True + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # inputs_posi.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/text_encodings.pt")) + # inputs_nega.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/negative_text_encodings.pt")) + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["video_latents"] = 
self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) + + # post-denoising, pre-decoding processing logic + for unit in self.post_units: + inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # Decode + self.load_models_to_device(['video_vae_decoder']) + video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, + tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames) + video = self.vae_output_to_video(video) + self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder']) + decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"]) + decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float() + return video, decoded_audio + + + def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) + noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) + if cfg_scale != 1.0: + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) + if isinstance(noise_pred_posi, tuple): + noise_pred = tuple( + n_nega + cfg_scale * (n_posi - n_nega) + for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) + ) + else: + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + return noise_pred + + +class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + + def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + pass + return inputs_shared, inputs_posi, inputs_nega + + +class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit): + """ + # TODO: Adjust with two stage pipeline + For two-stage pipelines, the resolution must be divisible by 64. + For one-stage pipelines, the resolution must be divisible by 32. 
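+    Worked example (illustrative): the pipeline defaults height=512, width=768, num_frames=121 already
+    satisfy the one-stage constraints, since 512 and 768 are multiples of 32 and 121 = 8 * 15 + 1 matches
+    time_division_factor=8 with time_division_remainder=1.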
+ """ + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames): + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames} + + +class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): + + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("video_context", "audio_context"), + onload_model_names=("text_encoder", "text_encoder_post_modules"), + ) + + def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return (attention_mask - 1).to(dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max + + def _run_connectors(self, pipe, encoded_input: torch.Tensor, + attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype) + + encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector( + encoded_input, + connector_attention_mask, + ) + + # restore the mask values to int64 + attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64) + attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1]) + encoded = encoded * attention_mask + + encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector( + encoded_input, connector_attention_mask) + + return encoded, encoded_for_audio, attention_mask.squeeze(-1) + + def _norm_and_concat_padded_batch( + self, + encoded_text: torch.Tensor, + sequence_lengths: torch.Tensor, + padding_side: str = "right", + ) -> torch.Tensor: + """Normalize and flatten multi-layer hidden states, respecting padding. + Performs per-batch, per-layer normalization using masked mean and range, + then concatenates across the layer dimension. + Args: + encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers]. + sequence_lengths: Number of valid (non-padded) tokens per batch item. + padding_side: Whether padding is on "left" or "right". + Returns: + Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers], + with padded positions zeroed out. 
+ """ + b, t, d, l = encoded_text.shape # noqa: E741 + device = encoded_text.device + # Build mask: [B, T, 1, 1] + token_indices = torch.arange(t, device=device)[None, :] # [1, T] + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [B, T] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = t - sequence_lengths[:, None] # [B, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = rearrange(mask, "b t -> b t 1 1") + eps = 1e-6 + # Compute masked mean: [B, 1, 1, L] + masked = encoded_text.masked_fill(~mask, 0.0) + denom = (sequence_lengths * d).view(b, 1, 1, 1) + mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps) + # Compute masked min/max: [B, 1, 1, L] + x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + range_ = x_max - x_min + # Normalize only the valid tokens + normed = 8 * (encoded_text - mean) / (range_ + eps) + # concat to be [Batch, T, D * L] - this preserves the original structure + normed = normed.reshape(b, t, -1) # [B, T, D * L] + # Apply mask to preserve original padding (set padded positions to 0) + mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l) + normed = normed.masked_fill(~mask_flattened, 0.0) + + return normed + + def _run_feature_extractor(self, + pipe, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + padding_side: str = "right") -> torch.Tensor: + encoded_text_features = torch.stack(hidden_states, dim=-1) + encoded_text_features_dtype = encoded_text_features.dtype + sequence_lengths = attention_mask.sum(dim=-1) + normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features, + sequence_lengths, + padding_side=padding_side) + + return pipe.text_encoder_post_modules.feature_extractor_linear( + normed_concated_encoded_text_features.to(encoded_text_features_dtype)) + + def _preprocess_text(self, + pipe, + text: str, + padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """ + Encode a given string into feature tensors suitable for downstream tasks. + Args: + text (str): Input string to encode. + Returns: + tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask. 
+ """ + token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"] + input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device) + attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device) + outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + projected = self._run_feature_extractor(pipe, + hidden_states=outputs.hidden_states, + attention_mask=attention_mask, + padding_side=padding_side) + return projected, attention_mask + + def encode_prompt(self, pipe, text, padding_side="left"): + encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side) + video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask) + return video_encoding, audio_encoding, attention_mask + + def process(self, pipe: LTX2AudioVideoPipeline, prompt: str): + pipe.load_models_to_device(self.onload_model_names) + video_context, audio_context, _ = self.encode_prompt(pipe, prompt) + return {"video_context": video_context, "audio_context": audio_context} + + +class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "num_frames", "seed", "rand_device",), + output_params=("video_noise", "audio_noise",), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): + video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) + video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels) + video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) + video_noise = pipe.video_patchifier.patchify(video_noise) + + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float() + video_positions[:, 0, ...] = video_positions[:, 0, ...] 
/ frame_rate + video_positions = video_positions.to(pipe.torch_dtype) + + audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape) + audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) + audio_noise = pipe.audio_patchifier.patchify(audio_noise) + audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) + return { + "video_noise": video_noise, + "audio_noise": audio_noise, + "video_positions": video_positions, + "audio_positions": audio_positions, + "video_latent_shape": video_latent_shape, + "audio_latent_shape": audio_latent_shape + } + + +class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"), + output_params=("video_latents", "audio_latents"), + onload_model_names=("video_vae_encoder") + ) + + def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride): + if input_video is None: + return {"video_latents": video_noise, "audio_latents": audio_noise} + else: + # TODO: implement video-to-video + raise NotImplementedError("Video-to-video not implemented yet.") + + +class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit): + + def __init__(self): + super().__init__( + input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"), + output_params=("video_latents", "audio_latents"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents): + video_latents = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape) + audio_latents = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape) + return {"video_latents": video_latents, "audio_latents": audio_latents} + + +def model_fn_ltx2( + dit: LTXModel, + video_latents=None, + video_context=None, + video_positions=None, + audio_latents=None, + audio_context=None, + audio_positions=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + #TODO: support gradient checkpointing + timestep = timestep.float() / 1000. 
+ video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) + audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1) + vx, ax = dit( + video_latents=video_latents, + video_positions=video_positions, + video_context=video_context, + video_timesteps=video_timesteps, + audio_latents=audio_latents, + audio_positions=audio_positions, + audio_context=audio_context, + audio_timesteps=audio_timesteps, + ) + return vx, ax diff --git a/diffsynth/utils/data/media_io.py b/diffsynth/utils/data/media_io.py new file mode 100644 index 000000000..a95fdbc4e --- /dev/null +++ b/diffsynth/utils/data/media_io.py @@ -0,0 +1,106 @@ + +from fractions import Fraction +import torch +import av +from tqdm import tqdm + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. 
+ """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + +def write_video_audio_ltx2( + video: list[Image.Image], + audio: torch.Tensor | None, + output_path: str, + fps: int = 24, + audio_sample_rate: int | None = 24000, +) -> None: + + width, height = video[0].size + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame in tqdm(video, total=len(video)): + frame = av.VideoFrame.from_image(frame) + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py new file mode 100644 index 000000000..afbdd4be7 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -0,0 +1,46 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +prompt = """ +INT. OVEN – DAY. Static camera from inside the oven, looking outward through the slightly fogged glass door. Warm golden light glows around freshly baked cookies. The baker’s face fills the frame, eyes wide with focus, his breath fogging the glass as he leans in. Subtle reflections move across the glass as steam rises. +Baker (whispering dramatically): “Today… I achieve perfection.” +He leans even closer, nose nearly touching the glass. 
+""" +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +height, width, num_frames = 512, 768, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=False, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_onestage.mp4', + fps=24, + audio_sample_rate=24000, +) From 4f23caa55fed9445b75936ca63c58dd69194e673 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 30 Jan 2026 16:55:40 +0800 Subject: [PATCH 4/6] support ltx2 two stage pipeline & vram --- .../configs/vram_management_module_maps.py | 33 ++++ diffsynth/diffusion/flow_match.py | 41 +++-- diffsynth/models/ltx2_text_encoder.py | 2 +- diffsynth/pipelines/ltx2_audio_video.py | 166 +++++++++++------- diffsynth/utils/data/media_io.py | 1 + .../model_inference/LTX-2-T2AV-OneStage.py | 6 +- .../model_inference/LTX-2-T2AV-TwoStage.py | 63 +++++++ 7 files changed, 227 insertions(+), 85 deletions(-) create mode 100644 examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index a1813fb4b..908361e45 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -210,4 +210,37 @@ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", }, + "diffsynth.models.ltx2_dit.LTXModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": { + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.ConvTranspose1d": 
"diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index aed8a3527..3f1ee8c2e 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -124,25 +124,30 @@ def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift return sigmas, timesteps @staticmethod - def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, stretch=True, terminal=0.1): - dynamic_shift_len = dynamic_shift_len or 4096 - sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( - image_seq_len=dynamic_shift_len, - base_seq_len=1024, - max_seq_len=4096, - base_shift=0.95, - max_shift=2.05, - ) + def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None): num_train_timesteps = 1000 - sigma_min = 0.0 - sigma_max = 1.0 - sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength - sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] - sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1)) - # Shift terminal - one_minus_z = 1.0 - sigmas - scale_factor = one_minus_z[-1] / (1 - terminal) - sigmas = 1.0 - (one_minus_z / scale_factor) + if special_case == "stage2": + sigmas = torch.Tensor([0.909375, 0.725, 0.421875]) + elif special_case == "ditilled_stage1": + sigmas = torch.Tensor([0.95, 0.8, 0.5, 0.2]) + else: + dynamic_shift_len = dynamic_shift_len or 4096 + sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( + image_seq_len=dynamic_shift_len, + base_seq_len=1024, + max_seq_len=4096, + base_shift=0.95, + max_shift=2.05, + ) + sigma_min = 0.0 + sigma_max = 1.0 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1 - terminal) + sigmas = 1.0 - (one_minus_z / scale_factor) timesteps = sigmas * num_train_timesteps return sigmas, timesteps diff --git a/diffsynth/models/ltx2_text_encoder.py b/diffsynth/models/ltx2_text_encoder.py index 535d5e578..7fb94a7b0 100644 --- a/diffsynth/models/ltx2_text_encoder.py +++ b/diffsynth/models/ltx2_text_encoder.py @@ -363,4 +363,4 @@ def __init__(self,): super().__init__() self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear() self.embeddings_connector = Embeddings1DConnector() - self.audio_embeddings_connector = Embeddings1DConnector() \ No newline at end of file + self.audio_embeddings_connector = 
Embeddings1DConnector() diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 10812881b..c921eacf9 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -60,10 +60,8 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): LTX2AudioVideoUnit_NoiseInitializer(), LTX2AudioVideoUnit_InputVideoEmbedder(), ] - self.post_units = [ - LTX2AudioVideoPostUnit_UnPatchifier(), - ] self.model_fn = model_fn_ltx2 + # self.lora_loader = LTX2LoRALoader @staticmethod def from_pretrained( @@ -71,6 +69,7 @@ def from_pretrained( device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config: Optional[ModelConfig] = None, vram_limit: float = None, ): # Initialize pipeline @@ -90,14 +89,67 @@ def from_pretrained( pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder") pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder") pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder") - pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") - # Optional + + # Stage 2 + if stage2_lora_config is not None: + stage2_lora_config.download_if_necessary() + pipe.stage2_lora_path = stage2_lora_config.path + pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") + # Optional, currently not used # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder") # VRAM Management pipe.vram_management_enabled = pipe.check_vram_management_state() return pipe + def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) + noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) + if cfg_scale != 1.0: + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) + if isinstance(noise_pred_posi, tuple): + noise_pred = tuple( + n_nega + cfg_scale * (n_posi - n_nega) + for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) + ) + else: + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + return noise_pred + + def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline=True, progress_bar_cmd=tqdm): + if use_two_stage_pipeline: + latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) + self.load_models_to_device(self.in_iteration_models + ('upsampler',)) + latent = self.upsampler(latent) + latent = self.video_vae_encoder.per_channel_statistics.normalize(latent) + latent = self.video_patchifier.patchify(latent) + self.scheduler.set_timesteps(special_case="stage2") + inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")}) + inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent + inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"]) + inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - 
self.scheduler.sigmas[0]) * inputs_shared["audio_latents"] + self.load_models_to_device(self.in_iteration_models) + self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, + noise_pred=noise_pred_video, **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, + noise_pred=noise_pred_audio, **inputs_shared) + inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"]) + inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"]) + return inputs_shared @torch.no_grad() def __call__( @@ -126,14 +178,15 @@ def __call__( tile_overlap_in_pixels: Optional[int] = 128, tile_size_in_frames: Optional[int] = 128, tile_overlap_in_frames: Optional[int] = 24, - # Two-Stage Pipeline - use_two_stage: Optional[bool] = True, + # Special Pipelines + use_two_stage_pipeline: Optional[bool] = False, + use_distilled_pipeline: Optional[bool] = False, # progress_bar progress_bar_cmd=tqdm, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) - + # self.load_lora(self.dit, self.stage2_lora_path) # Inputs inputs_posi = { "prompt": prompt, @@ -148,14 +201,12 @@ def __call__( "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, - "use_two_stage": True + "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) - # inputs_posi.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/text_encodings.pt")) - # inputs_nega.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/negative_text_encodings.pt")) - # Denoise + # Denoise Stage 1 self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): @@ -164,12 +215,16 @@ def __call__( self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **models, timestep=timestep, progress_id=progress_id ) - inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared) - inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, + noise_pred=noise_pred_video, **inputs_shared) + 
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, + noise_pred=noise_pred_audio, **inputs_shared) + inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"]) + inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"]) + + # Denoise Stage 2 + inputs_shared = self.stage2_denoise(cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline, progress_bar_cmd) - # post-denoising, pre-decoding processing logic - for unit in self.post_units: - inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) # Decode self.load_models_to_device(['video_vae_decoder']) video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, @@ -181,39 +236,21 @@ def __call__( return video, decoded_audio - def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): - if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) - self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) - noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) - if cfg_scale != 1.0: - if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) - noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) - if isinstance(noise_pred_posi, tuple): - noise_pred = tuple( - n_nega + cfg_scale * (n_posi - n_nega) - for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) - ) - else: - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) - else: - noise_pred = noise_pred_posi - return noise_pred - - class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): def __init__(self): super().__init__(take_over=True) def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega): - pass + if inputs_shared.get("use_two_stage_pipeline", False): + if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None): + raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.") + if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None): + raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.") return inputs_shared, inputs_posi, inputs_nega class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit): """ - # TODO: Adjust with two stage pipeline For two-stage pipelines, the resolution must be divisible by 64. For one-stage pipelines, the resolution must be divisible by 32. 
""" @@ -223,8 +260,14 @@ def __init__(self): output_params=("height", "width", "num_frames"), ) - def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames): + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False): + if use_two_stage_pipeline: + self.width_division_factor = 64 + self.height_division_factor = 64 height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + if use_two_stage_pipeline: + self.width_division_factor = 32 + self.height_division_factor = 32 return {"height": height, "width": width, "num_frames": num_frames} @@ -327,10 +370,12 @@ def _run_feature_extractor(self, return pipe.text_encoder_post_modules.feature_extractor_linear( normed_concated_encoded_text_features.to(encoded_text_features_dtype)) - def _preprocess_text(self, - pipe, - text: str, - padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + def _preprocess_text( + self, + pipe, + text: str, + padding_side: str = "left", + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Encode a given string into feature tensors suitable for downstream tasks. Args: @@ -339,8 +384,8 @@ def _preprocess_text(self, tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask. """ token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"] - input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device) - attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device) + input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device) + attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device) outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) projected = self._run_feature_extractor(pipe, hidden_states=outputs.hidden_states, @@ -362,11 +407,11 @@ def process(self, pipe: LTX2AudioVideoPipeline, prompt: str): class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "seed", "rand_device",), + input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"), output_params=("video_noise", "audio_noise",), ) - def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): + def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels) video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) @@ -390,6 +435,15 @@ def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, "audio_latent_shape": audio_latent_shape } + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False): + if use_two_stage_pipeline: + stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate) + stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) + initial_dict = stage1_dict + initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()}) + return 
initial_dict + else: + return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): @@ -407,20 +461,6 @@ def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_ raise NotImplementedError("Video-to-video not implemented yet.") -class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit): - - def __init__(self): - super().__init__( - input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"), - output_params=("video_latents", "audio_latents"), - ) - - def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents): - video_latents = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape) - audio_latents = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape) - return {"video_latents": video_latents, "audio_latents": audio_latents} - - def model_fn_ltx2( dit: LTXModel, video_latents=None, diff --git a/diffsynth/utils/data/media_io.py b/diffsynth/utils/data/media_io.py index a95fdbc4e..04501860c 100644 --- a/diffsynth/utils/data/media_io.py +++ b/diffsynth/utils/data/media_io.py @@ -3,6 +3,7 @@ import torch import av from tqdm import tqdm +from PIL import Image def _resample_audio( diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py index afbdd4be7..a96efbf22 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -31,16 +31,16 @@ video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, - seed=43, + seed=10, height=height, width=width, num_frames=num_frames, - tiled=False, + tiled=True, ) write_video_audio_ltx2( video=video, audio=audio, - output_path='ltx2_onestage.mp4', + output_path='ltx2_onestage_oven.mp4', fps=24, audio_sample_rate=24000, ) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py new file mode 100644 index 000000000..b966b0a80 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py @@ -0,0 +1,63 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) + +prompt = """ +INT. OVEN – DAY. Static camera from inside the oven, looking outward through the slightly fogged glass door. 
Warm golden light glows around freshly baked cookies. The baker’s face fills the frame, eyes wide with focus, his breath fogging the glass as he leans in. Subtle reflections move across the glass as steam rises. +Baker (whispering dramatically): “Today… I achieve perfection.” +He leans even closer, nose nearly touching the glass. +""" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, + num_inference_steps=40, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_oven.mp4', + fps=24, + audio_sample_rate=24000, +) From 9f07d65ebb9505b74bc81cb75d6bdb78b26812a0 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 30 Jan 2026 17:40:30 +0800 Subject: [PATCH 5/6] support ltx2 distilled pipeline --- diffsynth/diffusion/flow_match.py | 2 +- diffsynth/pipelines/ltx2_audio_video.py | 29 ++++++---- .../LTX-2-T2AV-DistilledPipeline.py | 57 +++++++++++++++++++ .../model_inference/LTX-2-T2AV-OneStage.py | 10 +--- .../model_inference/LTX-2-T2AV-TwoStage.py | 11 +--- 5 files changed, 82 insertions(+), 27 deletions(-) create mode 100644 examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 3f1ee8c2e..9b31466e0 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -129,7 +129,7 @@ def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_ if special_case == "stage2": sigmas = torch.Tensor([0.909375, 0.725, 0.421875]) elif special_case == "ditilled_stage1": - sigmas = torch.Tensor([0.95, 0.8, 0.5, 0.2]) + sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]) else: dynamic_shift_len = dynamic_shift_len or 4096 sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index c921eacf9..5973e9de2 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -61,7 +61,6 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): LTX2AudioVideoUnit_InputVideoEmbedder(), ] self.model_fn = model_fn_ltx2 - # 
self.lora_loader = LTX2LoRALoader @staticmethod def from_pretrained( @@ -89,12 +88,12 @@ def from_pretrained( pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder") pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder") pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder") + pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") # Stage 2 if stage2_lora_config is not None: stage2_lora_config.download_if_necessary() pipe.stage2_lora_path = stage2_lora_config.path - pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") # Optional, currently not used # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder") @@ -122,8 +121,8 @@ def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, i noise_pred = noise_pred_posi return noise_pred - def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline=True, progress_bar_cmd=tqdm): - if use_two_stage_pipeline: + def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm): + if inputs_shared["use_two_stage_pipeline"]: latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) self.load_models_to_device(self.in_iteration_models + ('upsampler',)) latent = self.upsampler(latent) @@ -135,12 +134,13 @@ def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"]) inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"] self.load_models_to_device(self.in_iteration_models) - self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8) + if not inputs_shared["use_distilled_pipeline"]: + self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( - self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega, **models, timestep=timestep, progress_id=progress_id ) inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, @@ -185,7 +185,8 @@ def __call__( progress_bar_cmd=tqdm, ): # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, + special_case="ditilled_stage1" if use_distilled_pipeline else None) # self.load_lora(self.dit, self.stage2_lora_path) # Inputs inputs_posi = { @@ -223,7 +224,7 @@ def __call__( inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"]) # Denoise Stage 2 - inputs_shared = self.stage2_denoise(cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline, progress_bar_cmd) + inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd) # Decode self.load_models_to_device(['video_vae_decoder']) @@ -241,11 +242,17 @@ def __init__(self): super().__init__(take_over=True) def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, 
inputs_posi, inputs_nega): + if inputs_shared.get("use_distilled_pipeline", False): + inputs_shared["use_two_stage_pipeline"] = True + inputs_shared["cfg_scale"] = 1.0 + print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.") if inputs_shared.get("use_two_stage_pipeline", False): - if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None): - raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.") + # distill pipeline also uses two-stage, but it does not needs lora + if not inputs_shared.get("use_distilled_pipeline", False): + if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None): + raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.") if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None): - raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.") + raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.") return inputs_shared, inputs_posi, inputs_nega diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py new file mode 100644 index 000000000..8ee36bfaf --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py @@ -0,0 +1,57 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) + +prompt = "A girl is speaking: “I enjoy working with Diffsynth-Studio, it's a great tool.”" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect 
timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_distilled_pipeline=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_distilled.mp4', + fps=24, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py index a96efbf22..4331392b0 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -21,17 +21,13 @@ ], tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), ) -prompt = """ -INT. OVEN – DAY. Static camera from inside the oven, looking outward through the slightly fogged glass door. Warm golden light glows around freshly baked cookies. The baker’s face fills the frame, eyes wide with focus, his breath fogging the glass as he leans in. Subtle reflections move across the glass as steam rises. -Baker (whispering dramatically): “Today… I achieve perfection.” -He leans even closer, nose nearly touching the glass. -""" +prompt = "A girl is speaking: “I enjoy working with Diffsynth-Studio, it's a great tool.”" negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, - seed=10, + seed=43, height=height, width=width, num_frames=num_frames, @@ -40,7 +36,7 @@ write_video_audio_ltx2( video=video, audio=audio, - output_path='ltx2_onestage_oven.mp4', + output_path='ltx2_onestage.mp4', fps=24, audio_sample_rate=24000, ) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py index b966b0a80..73e68f026 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py @@ -24,11 +24,7 @@ stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), ) -prompt = """ -INT. OVEN – DAY. Static camera from inside the oven, looking outward through the slightly fogged glass door. Warm golden light glows around freshly baked cookies. 
The baker’s face fills the frame, eyes wide with focus, his breath fogging the glass as he leans in. Subtle reflections move across the glass as steam rises. -Baker (whispering dramatically): “Today… I achieve perfection.” -He leans even closer, nose nearly touching the glass. -""" +prompt = "A girl is speaking: “I enjoy working with Diffsynth-Studio, it's a great tool.”" negative_prompt = ( "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " @@ -46,18 +42,17 @@ video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, - seed=0, + seed=43, height=height, width=width, num_frames=num_frames, tiled=True, use_two_stage_pipeline=True, - num_inference_steps=40, ) write_video_audio_ltx2( video=video, audio=audio, - output_path='ltx2_twostage_oven.mp4', + output_path='ltx2_twostage.mp4', fps=24, audio_sample_rate=24000, ) From 1c8a0f8317a74917283bd42610c7f62450f802aa Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Sat, 31 Jan 2026 13:55:52 +0800 Subject: [PATCH 6/6] refactor patchify --- diffsynth/models/ltx2_audio_vae.py | 14 ++++++++++++++ diffsynth/models/ltx2_video_vae.py | 18 ++++++++++++++++++ diffsynth/pipelines/ltx2_audio_video.py | 20 +++++++++++--------- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/diffsynth/models/ltx2_audio_vae.py b/diffsynth/models/ltx2_audio_vae.py index dd4694166..708ded7ef 100644 --- a/diffsynth/models/ltx2_audio_vae.py +++ b/diffsynth/models/ltx2_audio_vae.py @@ -167,6 +167,20 @@ def unpatchify( return audio_latents + def unpatchify_audio( + self, + audio_latents: torch.Tensor, + channels: int, + mel_bins: int + ) -> torch.Tensor: + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=channels, + f=mel_bins, + ) + return audio_latents + def get_patch_grid_bounds( self, output_shape: AudioLatentShape | VideoLatentShape, diff --git a/diffsynth/models/ltx2_video_vae.py b/diffsynth/models/ltx2_video_vae.py index 0bd13314c..ebc9483c6 100644 --- a/diffsynth/models/ltx2_video_vae.py +++ b/diffsynth/models/ltx2_video_vae.py @@ -68,6 +68,24 @@ def unpatchify( return latents + def unpatchify_video( + self, + latents: torch.Tensor, + frames: int, + height: int, + width: int, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=frames, + h=height // self._patch_size[1], + w=width // self._patch_size[2], + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents + def get_patch_grid_bounds( self, output_shape: AudioLatentShape | VideoLatentShape, diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 5973e9de2..4a96e36f7 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -127,11 +127,9 @@ def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_c self.load_models_to_device(self.in_iteration_models + ('upsampler',)) latent = self.upsampler(latent) latent = self.video_vae_encoder.per_channel_statistics.normalize(latent) - latent = self.video_patchifier.patchify(latent) self.scheduler.set_timesteps(special_case="stage2") inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")}) inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent - 
inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"]) inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"] self.load_models_to_device(self.in_iteration_models) if not inputs_shared["use_distilled_pipeline"]: @@ -147,8 +145,6 @@ def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_c noise_pred=noise_pred_video, **inputs_shared) inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) - inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"]) - inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"]) return inputs_shared @torch.no_grad() @@ -187,7 +183,6 @@ def __call__( # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None) - # self.load_lora(self.dit, self.stage2_lora_path) # Inputs inputs_posi = { "prompt": prompt, @@ -203,6 +198,7 @@ def __call__( "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, + "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -220,8 +216,6 @@ def __call__( noise_pred=noise_pred_video, **inputs_shared) inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) - inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"]) - inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"]) # Denoise Stage 2 inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd) @@ -422,7 +416,6 @@ def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels) video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) - video_noise = pipe.video_patchifier.patchify(video_noise) latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device) video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float() @@ -431,7 +424,6 @@ def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape) audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) - audio_noise = pipe.audio_patchifier.patchify(audio_noise) audio_positions = 
pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device) return { "video_noise": video_noise, @@ -473,14 +465,21 @@ def model_fn_ltx2( video_latents=None, video_context=None, video_positions=None, + video_patchifier=None, audio_latents=None, audio_context=None, audio_positions=None, + audio_patchifier=None, timestep=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, ): + # patchify + b, c_v, f, h, w = video_latents.shape + _, c_a, _, mel_bins = audio_latents.shape + video_latents = video_patchifier.patchify(video_latents) + audio_latents = audio_patchifier.patchify(audio_latents) #TODO: support gradient checkpointing timestep = timestep.float() / 1000. video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) @@ -495,4 +494,7 @@ def model_fn_ltx2( audio_context=audio_context, audio_timesteps=audio_timesteps, ) + # unpatchify + vx = video_patchifier.unpatchify_video(vx, f, h, w) + ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) return vx, ax
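
Note on the guidance merge in cfg_guided_model_fn: the positive and negative predictions are combined with the standard classifier-free guidance formula that appears in the diff. A minimal standalone sketch (the function and tensor names here are illustrative, not part of the pipeline API):

    import torch

    def cfg_merge(noise_pred_posi: torch.Tensor, noise_pred_nega: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # Extrapolate from the unconditional prediction toward the conditional one.
        # With cfg_scale == 1.0 this reduces to the conditional prediction alone,
        # which is why the distilled pipeline disables CFG by forcing cfg_scale to 1.0.
        return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

In the pipeline this merge is applied element-wise to both the video and audio noise predictions when the model returns a tuple.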
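
Note on stage2_denoise: after the latent upsampler runs, the stage-1 result is re-noised to the first sigma of the fixed stage-2 schedule ([0.909375, 0.725, 0.421875]) before the refinement steps. A hedged sketch of that re-noising step, using placeholder tensor shapes:

    import torch

    sigmas_stage2 = torch.tensor([0.909375, 0.725, 0.421875])  # fixed stage-2 schedule from set_timesteps_ltx2
    upsampled_latent = torch.randn(1, 128, 16, 32, 48)          # illustrative latent shape
    fresh_noise = torch.randn_like(upsampled_latent)

    # Linear interpolation toward noise at the first stage-2 sigma, matching
    # video_latents = sigma_0 * noise + (1 - sigma_0) * latent in the pipeline.
    sigma_0 = sigmas_stage2[0]
    renoised = sigma_0 * fresh_noise + (1 - sigma_0) * upsampled_latent

The same interpolation is applied to the audio latents before the short stage-2 denoising loop starts.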
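
Note on the patchify refactor in PATCH 6/6: model_fn_ltx2 now patchifies the latents on entry and unpatchifies the predictions on exit, using the einops pattern shown in unpatchify_video. A self-contained round-trip sketch, assuming a spatial patch size of (1, 1) and illustrative tensor shapes:

    import torch
    import einops

    b, c, f, h, w = 1, 128, 16, 16, 24   # illustrative video latent shape
    p = q = 1                            # spatial patch size along height / width
    latents = torch.randn(b, c, f, h, w)

    # Patchify: flatten (frames, height, width) into a token axis and pack channels with the patch.
    tokens = einops.rearrange(latents, "b c f (h p) (w q) -> b (f h w) (c p q)", p=p, q=q)

    # Unpatchify: the inverse pattern used by unpatchify_video.
    restored = einops.rearrange(
        tokens,
        "b (f h w) (c p q) -> b c f (h p) (w q)",
        f=f, h=h // p, w=w // q, p=p, q=q,
    )
    assert torch.equal(latents, restored)

The audio path follows the same idea with the simpler "b t (c f) -> b c t f" layout from unpatchify_audio.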