-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Add Support for LTX-2.3 Models #13217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 35 commits
6c7e720
e90b90a
f768f8d
cde6748
236eb8d
1e89cb3
5a44adb
4ff3168
835bed6
19004ef
4dfcfeb
13292dd
0528fde
c5e1fcc
50da4df
4206280
e719d32
fbb50d9
de3f869
5056aa8
f875031
652d363
d018534
c0bb2ef
ab0e5b5
f78c3da
63b3c9f
6188af2
89f8cc4
f1a812a
145e8e4
8a58073
93247a0
17b53f0
6ee66c9
c016ce5
2feb460
2740409
5d8b634
b0723de
67a9ce3
4cbedd7
8a9a148
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -237,7 +237,7 @@ def forward( | |
|
|
||
|
|
||
| # Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d | ||
| class LTXVideoDownsampler3d(nn.Module): | ||
| class LTX2VideoDownsampler3d(nn.Module): | ||
| def __init__( | ||
| self, | ||
| in_channels: int, | ||
|
|
@@ -285,10 +285,11 @@ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Ten | |
|
|
||
|
|
||
| # Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d | ||
| class LTXVideoUpsampler3d(nn.Module): | ||
| class LTX2VideoUpsampler3d(nn.Module): | ||
| def __init__( | ||
| self, | ||
| in_channels: int, | ||
| out_channels: int | None = None, | ||
| stride: int | tuple[int, int, int] = 1, | ||
| residual: bool = False, | ||
| upscale_factor: int = 1, | ||
|
|
@@ -300,7 +301,8 @@ def __init__( | |
| self.residual = residual | ||
| self.upscale_factor = upscale_factor | ||
|
|
||
| out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor | ||
| out_channels = out_channels or in_channels | ||
| out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor | ||
|
|
||
| self.conv = LTX2VideoCausalConv3d( | ||
| in_channels=in_channels, | ||
|
|
@@ -408,7 +410,7 @@ def __init__( | |
| ) | ||
| elif downsample_type == "spatial": | ||
| self.downsamplers.append( | ||
| LTXVideoDownsampler3d( | ||
| LTX2VideoDownsampler3d( | ||
| in_channels=in_channels, | ||
| out_channels=out_channels, | ||
| stride=(1, 2, 2), | ||
|
|
@@ -417,7 +419,7 @@ def __init__( | |
| ) | ||
| elif downsample_type == "temporal": | ||
| self.downsamplers.append( | ||
| LTXVideoDownsampler3d( | ||
| LTX2VideoDownsampler3d( | ||
| in_channels=in_channels, | ||
| out_channels=out_channels, | ||
| stride=(2, 1, 1), | ||
|
|
@@ -426,7 +428,7 @@ def __init__( | |
| ) | ||
| elif downsample_type == "spatiotemporal": | ||
| self.downsamplers.append( | ||
| LTXVideoDownsampler3d( | ||
| LTX2VideoDownsampler3d( | ||
| in_channels=in_channels, | ||
| out_channels=out_channels, | ||
| stride=(2, 2, 2), | ||
|
|
@@ -580,6 +582,7 @@ def __init__( | |
| resnet_eps: float = 1e-6, | ||
| resnet_act_fn: str = "swish", | ||
| spatio_temporal_scale: bool = True, | ||
| upsample_type: str = "spatiotemporal", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this go last in the init params to prevent a backwards-breaking change in case someone is using positional arguments?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I put |
||
| inject_noise: bool = False, | ||
| timestep_conditioning: bool = False, | ||
| upsample_residual: bool = False, | ||
|
|
@@ -609,17 +612,38 @@ def __init__( | |
|
|
||
| self.upsamplers = None | ||
| if spatio_temporal_scale: | ||
| self.upsamplers = nn.ModuleList( | ||
| [ | ||
| LTXVideoUpsampler3d( | ||
| out_channels * upscale_factor, | ||
| self.upsamplers = nn.ModuleList() | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like if upsample_type == "spatial":
stride = (1, 2, 2)
elif upsample_type == "temporal":
stride = (2, 1, 1)
elif upsample_type == "spatio_temporal":
stride = (2, 2, 2)
self.upsamplers.append(..., strides=strides)
WDYT? |
||
| if upsample_type == "spatial": | ||
| self.upsamplers.append( | ||
| LTX2VideoUpsampler3d( | ||
| in_channels=out_channels * upscale_factor, | ||
| stride=(1, 2, 2), | ||
| residual=upsample_residual, | ||
| upscale_factor=upscale_factor, | ||
| spatial_padding_mode=spatial_padding_mode, | ||
| ) | ||
| ) | ||
| elif upsample_type == "temporal": | ||
| self.upsamplers.append( | ||
| LTX2VideoUpsampler3d( | ||
| in_channels=out_channels * upscale_factor, | ||
| stride=(2, 1, 1), | ||
| residual=upsample_residual, | ||
| upscale_factor=upscale_factor, | ||
| spatial_padding_mode=spatial_padding_mode, | ||
| ) | ||
| ) | ||
| elif upsample_type == "spatiotemporal": | ||
| self.upsamplers.append( | ||
| LTX2VideoUpsampler3d( | ||
| in_channels=out_channels * upscale_factor, | ||
| stride=(2, 2, 2), | ||
| residual=upsample_residual, | ||
| upscale_factor=upscale_factor, | ||
| spatial_padding_mode=spatial_padding_mode, | ||
| ) | ||
| ] | ||
| ) | ||
| ) | ||
|
|
||
| resnets = [] | ||
| for _ in range(num_layers): | ||
|
|
@@ -716,7 +740,7 @@ def __init__( | |
| "LTX2VideoDownBlock3D", | ||
| "LTX2VideoDownBlock3D", | ||
| ), | ||
| spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), | ||
| spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), | ||
| layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), | ||
| downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), | ||
| patch_size: int = 4, | ||
|
|
@@ -726,6 +750,9 @@ def __init__( | |
| spatial_padding_mode: str = "zeros", | ||
| ): | ||
| super().__init__() | ||
| num_encoder_blocks = len(layers_per_block) | ||
| if isinstance(spatio_temporal_scaling, bool): | ||
| spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) | ||
|
|
||
| self.patch_size = patch_size | ||
| self.patch_size_t = patch_size_t | ||
|
|
@@ -860,19 +887,27 @@ def __init__( | |
| in_channels: int = 128, | ||
| out_channels: int = 3, | ||
| block_out_channels: tuple[int, ...] = (256, 512, 1024), | ||
| spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), | ||
| spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), | ||
| layers_per_block: tuple[int, ...] = (5, 5, 5, 5), | ||
| upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"), | ||
| patch_size: int = 4, | ||
| patch_size_t: int = 1, | ||
| resnet_norm_eps: float = 1e-6, | ||
| is_causal: bool = False, | ||
| inject_noise: tuple[bool, ...] = (False, False, False), | ||
| inject_noise: bool | tuple[bool, ...] = (False, False, False), | ||
| timestep_conditioning: bool = False, | ||
| upsample_residual: tuple[bool, ...] = (True, True, True), | ||
| upsample_residual: bool | tuple[bool, ...] = (True, True, True), | ||
| upsample_factor: tuple[int, ...] = (2, 2, 2), | ||
| spatial_padding_mode: str = "reflect", | ||
| ) -> None: | ||
| super().__init__() | ||
| num_decoder_blocks = len(layers_per_block) | ||
| if isinstance(spatio_temporal_scaling, bool): | ||
| spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1) | ||
| if isinstance(inject_noise, bool): | ||
| inject_noise = (inject_noise,) * num_decoder_blocks | ||
| if isinstance(upsample_residual, bool): | ||
| upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) | ||
|
|
||
| self.patch_size = patch_size | ||
| self.patch_size_t = patch_size_t | ||
|
|
@@ -917,6 +952,7 @@ def __init__( | |
| num_layers=layers_per_block[i + 1], | ||
| resnet_eps=resnet_norm_eps, | ||
| spatio_temporal_scale=spatio_temporal_scaling[i], | ||
| upsample_type=upsample_type[i], | ||
| inject_noise=inject_noise[i + 1], | ||
| timestep_conditioning=timestep_conditioning, | ||
| upsample_residual=upsample_residual[i], | ||
|
|
@@ -1058,11 +1094,12 @@ def __init__( | |
| decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024), | ||
| layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), | ||
| decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5), | ||
| spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), | ||
| decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), | ||
| decoder_inject_noise: tuple[bool, ...] = (False, False, False, False), | ||
| spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True), | ||
| decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True), | ||
| decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False), | ||
| downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), | ||
| upsample_residual: tuple[bool, ...] = (True, True, True), | ||
| upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"), | ||
| upsample_residual: bool | tuple[bool, ...] = (True, True, True), | ||
| upsample_factor: tuple[int, ...] = (2, 2, 2), | ||
| timestep_conditioning: bool = False, | ||
| patch_size: int = 4, | ||
|
|
@@ -1077,6 +1114,16 @@ def __init__( | |
| temporal_compression_ratio: int = None, | ||
| ) -> None: | ||
| super().__init__() | ||
| num_encoder_blocks = len(layers_per_block) | ||
| num_decoder_blocks = len(decoder_layers_per_block) | ||
| if isinstance(spatio_temporal_scaling, bool): | ||
| spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1) | ||
| if isinstance(decoder_spatio_temporal_scaling, bool): | ||
| decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1) | ||
| if isinstance(decoder_inject_noise, bool): | ||
| decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks | ||
| if isinstance(upsample_residual, bool): | ||
| upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1) | ||
|
|
||
| self.encoder = LTX2VideoEncoder3d( | ||
| in_channels=in_channels, | ||
|
|
@@ -1098,6 +1145,7 @@ def __init__( | |
| block_out_channels=decoder_block_out_channels, | ||
| spatio_temporal_scaling=decoder_spatio_temporal_scaling, | ||
| layers_per_block=decoder_layers_per_block, | ||
| upsample_type=upsample_type, | ||
| patch_size=patch_size, | ||
| patch_size_t=patch_size_t, | ||
| resnet_norm_eps=resnet_norm_eps, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where did this pop up? Distillation checkpoint?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
`prompt_adaln` and `audio_prompt_adaln` modules are used by both the full model and the distilled model to calculate scale/shift modulation parameters for the text `encoder_hidden_states` for the video and audio modalities respectively. (I believe this is in place of the `caption_projection` modules, which were removed in LTX-2.3.)