-
Notifications
You must be signed in to change notification settings - Fork 7k
Expand file tree
/
Copy pathautoencoder_kl_ltx2.py
More file actions
1568 lines (1344 loc) · 65.2 KB
/
autoencoder_kl_ltx2.py
File metadata and controls
1568 lines (1344 loc) · 65.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class PerChannelRMSNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values
across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.channel_dim = channel_dim
self.eps = eps
def forward(self, x: torch.Tensor, channel_dim: int | None = None) -> torch.Tensor:
"""
Apply RMS normalization along the configured dimension.
"""
channel_dim = channel_dim or self.channel_dim
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True)
# Normalize by the root-mean-square (RMS).
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime
class LTX2VideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int] = 3,
stride: int | tuple[int, int, int] = 1,
dilation: int | tuple[int, int, int] = 1,
groups: int = 1,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
height_pad = self.kernel_size[1] // 2
width_pad = self.kernel_size[2] // 2
padding = (0, height_pad, width_pad)
self.conv = nn.Conv3d(
in_channels,
out_channels,
self.kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
padding=padding,
padding_mode=spatial_padding_mode,
)
def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
time_kernel_size = self.kernel_size[0]
if causal:
pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1))
hidden_states = torch.concatenate([pad_left, hidden_states], dim=2)
else:
pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2)
hidden_states = self.conv(hidden_states)
return hidden_states
# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding
# mode is configurable
class LTX2VideoResnetBlock3d(nn.Module):
    r"""
    A 3D ResNet block used in the LTX 2.0 audiovisual model.

    Computes
    `norm1 -> [scale/shift] -> act -> conv1 -> [noise] -> norm2 -> [scale/shift] -> act -> dropout -> conv2 -> [noise]`
    and adds the input (passed through `norm3` + a 1x1x1 `conv_shortcut` when the channel count changes) as a
    residual. Bracketed steps exist only when the corresponding feature is enabled.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate applied before `conv2`.
        eps (`float`, defaults to `1e-6`):
            Epsilon of the `norm3` LayerNorm on the shortcut path (created only when `in_channels != out_channels`).
            The `PerChannelRMSNorm` layers use their own default epsilon.
        elementwise_affine (`bool`, defaults to `False`):
            Not read by this block; `norm3` is always created with `elementwise_affine=True`.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function name passed to `get_activation`.
        inject_noise (`bool`, defaults to `False`):
            If True, adds spatial Gaussian noise scaled by learnable per-channel factors after `conv1` and `conv2`.
        timestep_conditioning (`bool`, defaults to `False`):
            If True, applies a scale/shift derived from `temb` plus a learned table after `norm1` and after `norm2`.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode used by the two causal convolutions for the height/width dimensions.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int | None = None,
        dropout: float = 0.0,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        non_linearity: str = "swish",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()
        out_channels = out_channels or in_channels
        self.nonlinearity = get_activation(non_linearity)
        self.norm1 = PerChannelRMSNorm()
        self.conv1 = LTX2VideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.norm2 = PerChannelRMSNorm()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = LTX2VideoCausalConv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            spatial_padding_mode=spatial_padding_mode,
        )
        # Shortcut projection is only needed when the channel count changes.
        self.norm3 = None
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
            # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d
            self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
        # NOTE(review): these scales are sized by in_channels but are added to conv1/conv2 outputs, which have
        # out_channels channels; they only broadcast when the two counts match (or in_channels == 1).
        self.per_channel_scale1 = None
        self.per_channel_scale2 = None
        if inject_noise:
            self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
            self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
        # Learned offsets for (shift_1, scale_1, shift_2, scale_2), added to the incoming timestep embedding.
        self.scale_shift_table = None
        if timestep_conditioning:
            self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
    def forward(
        self,
        inputs: torch.Tensor,
        temb: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        causal: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            inputs: Input video features; dim 1 is treated as channels by the norms and convolutions.
            temb: Timestep embedding; required when the block was built with `timestep_conditioning=True`.
            generator: Optional RNG used for the injected noise.
            causal: Forwarded to both causal convolutions to select causal or symmetric temporal padding.

        Returns:
            Output features with `out_channels` channels.
        """
        hidden_states = inputs
        hidden_states = self.norm1(hidden_states)
        if self.scale_shift_table is not None:
            # NOTE(review): assumes temb is (batch, 4 * in_channels, 1, 1, 1); it is split along dim 1 into four
            # chunks of in_channels each, offset by the learned table.
            temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
            shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
            hidden_states = hidden_states * (1 + scale_1) + shift_1
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states, causal=causal)
        if self.per_channel_scale1 is not None:
            # One (height, width) noise map, shared across batch and frames, weighted per channel.
            spatial_shape = hidden_states.shape[-2:]
            spatial_noise = torch.randn(
                spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
            )[None]
            hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
        hidden_states = self.norm2(hidden_states)
        if self.scale_shift_table is not None:
            hidden_states = hidden_states * (1 + scale_2) + shift_2
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, causal=causal)
        if self.per_channel_scale2 is not None:
            # Same noise scheme as above, with a fresh sample and the second scale.
            spatial_shape = hidden_states.shape[-2:]
            spatial_noise = torch.randn(
                spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
            )[None]
            hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
        # Shortcut path: LayerNorm over channels (moved to the last dim for nn.LayerNorm), then 1x1x1 projection.
        if self.norm3 is not None:
            inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
        if self.conv_shortcut is not None:
            inputs = self.conv_shortcut(inputs)
        hidden_states = hidden_states + inputs
        return hidden_states
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
class LTX2VideoDownsampler3d(nn.Module):
    """
    Space-to-depth downsampler with a parameter-free residual path.

    A causal 3x3x3 convolution produces `out_channels / prod(stride)` channels. Each `stride` window of its output is
    then folded into the channel axis, giving `out_channels` channels. The residual is the same fold applied to the
    (padded) input, averaged over groups of `group_size` consecutive channels so it also has `out_channels` channels.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels after the stride window has been folded into channels.
        stride: Downsampling factor per (time, height, width); an int applies to all three axes.
        spatial_padding_mode: Padding mode forwarded to the convolution for the height/width dimensions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int | tuple[int, int, int] = 1,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()
        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        # Use the normalized `self.stride`: indexing the raw argument raises TypeError for an int stride.
        stride_volume = self.stride[0] * self.stride[1] * self.stride[2]
        self.group_size = (in_channels * stride_volume) // out_channels
        out_channels = out_channels // stride_volume
        self.conv = LTX2VideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
        """
        Downsample `hidden_states` by `self.stride`.

        The frame count after prepending `stride[0] - 1` frames, and the height/width, must be divisible by the
        corresponding stride, otherwise `unflatten` raises.
        """
        # Prepend a copy of the first `stride[0] - 1` frames.
        # NOTE(review): presumably so that 1 + k * stride[0] frames fold evenly; confirm against the caller.
        hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
        # Residual: split each spatial/temporal axis into (blocks, stride), move the stride factors next to the
        # channel axis, fold them in, then average groups of `group_size` channels.
        residual = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        residual = residual.unflatten(1, (-1, self.group_size))
        residual = residual.mean(dim=2)
        # Main path: convolve, then apply the same space-to-depth fold.
        hidden_states = self.conv(hidden_states, causal=causal)
        hidden_states = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        hidden_states = hidden_states + residual
        return hidden_states
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
class LTX2VideoUpsampler3d(nn.Module):
    """
    Depth-to-space upsampler built on a causal 3x3x3 convolution.

    The convolution expands channels by `prod(stride) / upscale_factor`. The result is then unfolded so that groups
    of channels become the (time, height, width) stride factors. With `residual=True`, the input is unfolded the same
    way and its channels are repeated so it can be added to the convolution output. The first `stride[0] - 1` output
    frames are dropped.

    Args:
        in_channels: Number of input channels.
        out_channels: Channel count used to size the convolution (defaults to `in_channels`). The block returns
            `out_channels / upscale_factor` channels.
        stride: Upsampling factor per (time, height, width); an int applies to all three axes.
        residual: Whether to add the unfolded, channel-repeated input to the output.
        upscale_factor: Additional channel reduction applied on top of the depth-to-space unfold.
        spatial_padding_mode: Padding mode forwarded to the convolution for the height/width dimensions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None = None,
        stride: int | tuple[int, int, int] = 1,
        residual: bool = False,
        upscale_factor: int = 1,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()
        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.residual = residual
        self.upscale_factor = upscale_factor
        out_channels = out_channels or in_channels
        # Use the normalized `self.stride`: indexing the raw argument raises TypeError for an int stride.
        stride_volume = self.stride[0] * self.stride[1] * self.stride[2]
        out_channels = (out_channels * stride_volume) // upscale_factor
        self.conv = LTX2VideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
        """Upsample `hidden_states` by `self.stride`, dropping the first `stride[0] - 1` output frames."""
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        if self.residual:
            # Depth-to-space on the input: (B, C, T, H, W) -> (B, C / prod(stride), T*s_t, H*s_h, W*s_w).
            residual = hidden_states.reshape(
                batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
            )
            residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
            # Repeat channels so the residual's channel count matches the convolution branch.
            repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
            residual = residual.repeat(1, repeats, 1, 1, 1)
            residual = residual[:, :, self.stride[0] - 1 :]
        hidden_states = self.conv(hidden_states, causal=causal)
        # Same depth-to-space unfold on the convolution output.
        hidden_states = hidden_states.reshape(
            batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
        )
        hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
        hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
        if self.residual:
            hidden_states = hidden_states + residual
        return hidden_states
# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d
class LTX2VideoDownBlock3D(nn.Module):
    r"""
    Down block used in the LTX 2.0 video VAE encoder.

    Applies `num_layers` `LTX2VideoResnetBlock3d` layers at constant width (`in_channels`), followed by an optional
    downsampling layer.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels of the `LTX2VideoDownsampler3d` variants. If None, defaults to `in_channels`.
            It is not used when `downsample_type="conv"` or `spatio_temporal_scale=False`; the output then keeps
            `in_channels` channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value passed to the resnet layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether to append a downsampling layer. If False, the output has the same shape as the input.
        downsample_type (`str`, defaults to `"conv"`):
            Downsampling layer used when `spatio_temporal_scale` is True:
            - `"conv"`: causal convolution with stride (2, 2, 2)
            - `"spatial"`: `LTX2VideoDownsampler3d` with stride (1, 2, 2)
            - `"temporal"`: `LTX2VideoDownsampler3d` with stride (2, 1, 1)
            - `"spatiotemporal"`: `LTX2VideoDownsampler3d` with stride (2, 2, 2)
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode used by the convolutions for the height/width dimensions.

    Raises:
        ValueError: If `spatio_temporal_scale` is True and `downsample_type` is not one of the values above.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        downsample_type: str = "conv",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        resnets = []
        for _ in range(num_layers):
            resnets.append(
                LTX2VideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_padding_mode=spatial_padding_mode,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.downsamplers = None
        if spatio_temporal_scale:
            # (time, height, width) strides of the LTX2VideoDownsampler3d variants.
            downsampler_strides = {
                "spatial": (1, 2, 2),
                "temporal": (2, 1, 1),
                "spatiotemporal": (2, 2, 2),
            }
            self.downsamplers = nn.ModuleList()
            if downsample_type == "conv":
                self.downsamplers.append(
                    LTX2VideoCausalConv3d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=3,
                        stride=(2, 2, 2),
                        spatial_padding_mode=spatial_padding_mode,
                    )
                )
            elif downsample_type in downsampler_strides:
                self.downsamplers.append(
                    LTX2VideoDownsampler3d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        stride=downsampler_strides[downsample_type],
                        spatial_padding_mode=spatial_padding_mode,
                    )
                )
            else:
                # Previously an unknown type silently produced an empty list, i.e. no downsampling at all.
                raise ValueError(
                    f"Unknown downsample_type: {downsample_type!r}. "
                    "Expected one of 'conv', 'spatial', 'temporal', 'spatiotemporal'."
                )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        causal: bool = True,
    ) -> torch.Tensor:
        r"""Run the resnet layers, then the optional downsampler. `causal` is forwarded to every layer."""
        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
            else:
                hidden_states = resnet(hidden_states, temb, generator, causal=causal)
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, causal=causal)
        return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d
class LTX2VideoMidBlock3d(nn.Module):
    r"""
    Bottleneck block of the LTX 2.0 video VAE.

    Runs `num_layers` width-preserving `LTX2VideoResnetBlock3d` layers. When `timestep_conditioning` is enabled, the
    `temb` passed to `forward` is first run through `PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)`.
    The result is viewed as `(batch, -1, 1, 1, 1)` before being handed to every resnet layer.

    Args:
        in_channels (`int`):
            Number of input (and output) channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value passed to the resnet layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        inject_noise (`bool`, defaults to `False`):
            Whether the resnet layers inject learned-scale spatial noise.
        timestep_conditioning (`bool`, defaults to `False`):
            Whether to embed `temb` and condition the resnet layers on it.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode used by the convolutions for the height/width dimensions.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ) -> None:
        super().__init__()
        # The embedder is registered before the resnets, keeping the original parameter-creation order.
        self.time_embedder = (
            PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) if timestep_conditioning else None
        )
        self.resnets = nn.ModuleList(
            [
                LTX2VideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        causal: bool = True,
    ) -> torch.Tensor:
        r"""Optionally embed `temb`, then run every resnet layer with the same `temb`, `generator` and `causal`."""
        batch = hidden_states.size(0)
        if self.time_embedder is not None:
            temb = self.time_embedder(
                timestep=temb.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch,
                hidden_dtype=hidden_states.dtype,
            ).view(batch, -1, 1, 1, 1)
        checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
        for block in self.resnets:
            if checkpointing:
                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb, generator, causal)
            else:
                hidden_states = block(hidden_states, temb, generator, causal=causal)
        return hidden_states
# Like LTXVideoUpBlock3d, but with the updated LTX2VideoResnetBlock3d (including the optional conv_in resnet block)
class LTX2VideoUpBlock3d(nn.Module):
    r"""
    Up block used in the LTX 2.0 video VAE decoder.

    The block runs in three stages:
    - If `in_channels != out_channels`, a `conv_in` resnet block maps between the two.
    - If `spatio_temporal_scale` is True, an `LTX2VideoUpsampler3d` upsamples the features.
    - Finally, `num_layers` `LTX2VideoResnetBlock3d` layers run at `out_channels` width.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value passed to the resnet layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether to add an upsampling layer. The upsampler is built for `out_channels * upscale_factor` input
            channels.
        upsample_type (`str`, defaults to `"spatiotemporal"`):
            Upsampling stride used when `spatio_temporal_scale` is True:
            - `"spatial"`: (1, 2, 2)
            - `"temporal"`: (2, 1, 1)
            - `"spatiotemporal"`: (2, 2, 2)
        inject_noise (`bool`, defaults to `False`):
            Whether the resnet layers inject learned-scale spatial noise.
        timestep_conditioning (`bool`, defaults to `False`):
            Whether to embed `temb` and condition the resnet layers on it.
        upsample_residual (`bool`, defaults to `False`):
            Whether the upsampler adds its unfolded input as a residual.
        upscale_factor (`int`, defaults to `1`):
            Channel reduction factor of the upsampler.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode used by the convolutions for the height/width dimensions.

    Raises:
        ValueError: If `spatio_temporal_scale` is True and `upsample_type` is not one of the values above.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        upsample_type: str = "spatiotemporal",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        upsample_residual: bool = False,
        upscale_factor: int = 1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        self.time_embedder = None
        if timestep_conditioning:
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
        self.conv_in = None
        if in_channels != out_channels:
            self.conv_in = LTX2VideoResnetBlock3d(
                in_channels=in_channels,
                out_channels=out_channels,
                dropout=dropout,
                eps=resnet_eps,
                non_linearity=resnet_act_fn,
                inject_noise=inject_noise,
                timestep_conditioning=timestep_conditioning,
                spatial_padding_mode=spatial_padding_mode,
            )
        self.upsamplers = None
        if spatio_temporal_scale:
            # (time, height, width) strides of the supported upsampling patterns.
            upsampler_strides = {
                "spatial": (1, 2, 2),
                "temporal": (2, 1, 1),
                "spatiotemporal": (2, 2, 2),
            }
            if upsample_type not in upsampler_strides:
                # Previously an unknown type silently produced an empty list, i.e. no upsampling at all.
                raise ValueError(
                    f"Unknown upsample_type: {upsample_type!r}. Expected one of 'spatial', 'temporal', 'spatiotemporal'."
                )
            self.upsamplers = nn.ModuleList(
                [
                    LTX2VideoUpsampler3d(
                        in_channels=out_channels * upscale_factor,
                        stride=upsampler_strides[upsample_type],
                        residual=upsample_residual,
                        upscale_factor=upscale_factor,
                        spatial_padding_mode=spatial_padding_mode,
                    )
                ]
            )
        resnets = []
        for _ in range(num_layers):
            resnets.append(
                LTX2VideoResnetBlock3d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
        causal: bool = True,
    ) -> torch.Tensor:
        r"""Run conv_in (if any), embed `temb` (if conditioned), upsample (if any), then the resnet layers."""
        # NOTE(review): conv_in receives the raw `temb`, before the time embedder runs; confirm this is intended.
        if self.conv_in is not None:
            hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal)
        if self.time_embedder is not None:
            temb = self.time_embedder(
                timestep=temb.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=hidden_states.size(0),
                hidden_dtype=hidden_states.dtype,
            )
            temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, causal=causal)
        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
            else:
                hidden_states = resnet(hidden_states, temb, generator, causal=causal)
        return hidden_states
# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is
# different, as is the layers_per_block (the 2.0 VAE is bigger)
class LTX2VideoEncoder3d(nn.Module):
    r"""
    The `LTX2VideoEncoder3d` layer of a variational autoencoder that encodes input video samples to their latent
    representation.

    The processing pipeline is:
    - Patchify the input, folding each `patch_size_t x patch_size x patch_size` window into the channel axis.
    - Apply `conv_in`.
    - Run the down blocks, then the mid block.
    - Apply an output head of `PerChannelRMSNorm`, SiLU and a causal convolution.

    Args:
        in_channels (`int`, defaults to 3):
            Number of input channels (before patchification).
        out_channels (`int`, defaults to 128):
            Number of latent channels. Also used as the output width of `conv_in`.
        block_out_channels (`tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`):
            The number of output channels for each down block.
        down_block_types (`tuple[str, ...]`, defaults to four `"LTX2VideoDownBlock3D"` entries):
            Type of each down block. Only `"LTX2VideoDownBlock3D"` is supported; anything else raises `ValueError`.
        spatio_temporal_scaling (`bool` or `tuple[bool, ...]`, defaults to `(True, True, True, True)`):
            Whether each down block contains a downscaling layer. A single bool is repeated
            `len(layers_per_block) - 1` times.
        layers_per_block (`tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`):
            Number of resnet layers for each down block, followed by the number for the mid block (last entry).
        downsample_type (`tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`):
            The downsampling pattern per down block. Per-block values can be
            - `"spatial"` (downsample spatial dims by 2x)
            - `"temporal"` (downsample temporal dim by 2x)
            - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x)
            - `"conv"` (strided causal convolution over all dims)
        patch_size (`int`, defaults to `4`):
            The size of spatial patches.
        patch_size_t (`int`, defaults to `1`):
            The size of temporal patches.
        resnet_norm_eps (`float`, defaults to `1e-6`):
            Epsilon value for ResNet normalization layers.
        is_causal (`bool`, defaults to `True`):
            Default temporal causality, used when `forward` is called with `causal=None`.
        spatial_padding_mode (`str`, defaults to `"zeros"`):
            Padding mode used by the convolutions for the height/width dimensions.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 128,
        block_out_channels: tuple[int, ...] = (256, 512, 1024, 2048),
        down_block_types: tuple[str, ...] = (
            "LTX2VideoDownBlock3D",
            "LTX2VideoDownBlock3D",
            "LTX2VideoDownBlock3D",
            "LTX2VideoDownBlock3D",
        ),
        spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
        layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
        downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        is_causal: bool = True,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        num_encoder_blocks = len(layers_per_block)
        if isinstance(spatio_temporal_scaling, bool):
            spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        # forward() folds a (patch_size_t, patch_size, patch_size) window into channels, so conv_in must accept
        # that many channels. Unchanged for the default patch_size_t=1; previously any patch_size_t > 1 crashed.
        self.in_channels = in_channels * patch_size_t * patch_size**2
        self.is_causal = is_causal
        output_channel = out_channels
        self.conv_in = LTX2VideoCausalConv3d(
            in_channels=self.in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )
        # down blocks
        num_block_out_channels = len(block_out_channels)
        self.down_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            if down_block_types[i] == "LTX2VideoDownBlock3D":
                down_block = LTX2VideoDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    num_layers=layers_per_block[i],
                    resnet_eps=resnet_norm_eps,
                    spatio_temporal_scale=spatio_temporal_scaling[i],
                    downsample_type=downsample_type[i],
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"Unknown down block type: {down_block_types[i]}")
            self.down_blocks.append(down_block)
        # mid block
        self.mid_block = LTX2VideoMidBlock3d(
            in_channels=output_channel,
            num_layers=layers_per_block[-1],
            resnet_eps=resnet_norm_eps,
            spatial_padding_mode=spatial_padding_mode,
        )
        # out: one extra channel that forward() replicates (see the end of forward)
        self.norm_out = PerChannelRMSNorm()
        self.conv_act = nn.SiLU()
        self.conv_out = LTX2VideoCausalConv3d(
            in_channels=output_channel,
            out_channels=out_channels + 1,
            kernel_size=3,
            stride=1,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor, causal: bool | None = None) -> torch.Tensor:
        r"""
        Encode a batch of videos.

        Args:
            hidden_states: Tensor of shape (batch, channels, frames, height, width). `frames` must be divisible by
                `patch_size_t` and `height`/`width` by `patch_size`, otherwise the patchify reshape fails.
            causal: Temporal causality for all layers; `None` falls back to `self.is_causal`.

        Returns:
            Tensor with `2 * out_channels` channels: the first `out_channels` channels from `conv_out`, followed by its
            last channel repeated `out_channels` times.
        """
        p = self.patch_size
        p_t = self.patch_size_t
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        # Explicit `is None`: `causal or self.is_causal` would ignore an explicit `causal=False`.
        causal = self.is_causal if causal is None else causal
        hidden_states = hidden_states.reshape(
            batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
        )
        # Fold patches into channels in (C, p_t, p_width, p_height) order. Note the width patch precedes the height
        # patch; presumably this must match the layout the pretrained conv_in weights expect, so do not reorder.
        hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
        hidden_states = self.conv_in(hidden_states, causal=causal)
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for down_block in self.down_blocks:
                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal)
            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal)
        else:
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states, causal=causal)
            hidden_states = self.mid_block(hidden_states, causal=causal)
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states, causal=causal)
        # conv_out emits out_channels + 1 channels. Repeating the last one (out_channels - 1) more times gives
        # 2 * out_channels channels.
        # NOTE(review): presumably read as (mean, shared log-variance) by DiagonalGaussianDistribution; confirm.
        last_channel = hidden_states[:, -1:]
        last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
        hidden_states = torch.cat([hidden_states, last_channel], dim=1)
        return hidden_states
# Like LTX 1.0 LTXVideoDecoder3d, but by default has only 3 symmetric up blocks, each residual with upsample_factor 2 (causality is selected at runtime; non-causal by default)
class LTX2VideoDecoder3d(nn.Module):
r"""
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
sample.
Args:
in_channels (`int`, defaults to 128):
Number of latent channels.
out_channels (`int`, defaults to 3):
Number of output channels.
block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
The number of output channels for each block.
spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`:
Whether a block should contain spatio-temporal upscaling layers or not.
layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
The number of layers per block.
patch_size (`int`, defaults to `4`):
The size of spatial patches.
patch_size_t (`int`, defaults to `1`):
The size of temporal patches.
resnet_norm_eps (`float`, defaults to `1e-6`):
Epsilon value for ResNet normalization layers.
is_causal (`bool`, defaults to `False`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
timestep_conditioning (`bool`, defaults to `False`):
Whether to condition the model on timesteps.
"""
def __init__(
self,
in_channels: int = 128,
out_channels: int = 3,
block_out_channels: tuple[int, ...] = (256, 512, 1024),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: bool | tuple[bool, ...] = (False, False, False),
timestep_conditioning: bool = False,
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[bool, ...] = (2, 2, 2),
spatial_padding_mode: str = "reflect",
) -> None:
super().__init__()
num_decoder_blocks = len(layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
if isinstance(inject_noise, bool):
inject_noise = (inject_noise,) * num_decoder_blocks
if isinstance(upsample_residual, bool):
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.out_channels = out_channels * patch_size**2
self.is_causal = is_causal
block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
layers_per_block = tuple(reversed(layers_per_block))
inject_noise = tuple(reversed(inject_noise))
upsample_residual = tuple(reversed(upsample_residual))
upsample_factor = tuple(reversed(upsample_factor))
output_channel = block_out_channels[0]
self.conv_in = LTX2VideoCausalConv3d(
in_channels=in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
spatial_padding_mode=spatial_padding_mode,
)
self.mid_block = LTX2VideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[0],
resnet_eps=resnet_norm_eps,
inject_noise=inject_noise[0],
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
# up blocks
num_block_out_channels = len(block_out_channels)
self.up_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel // upsample_factor[i]
output_channel = block_out_channels[i] // upsample_factor[i]
up_block = LTX2VideoUpBlock3d(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
upsample_type=upsample_type[i],
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
upscale_factor=upsample_factor[i],
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks.append(up_block)
# out
self.norm_out = PerChannelRMSNorm()
self.conv_act = nn.SiLU()
self.conv_out = LTX2VideoCausalConv3d(
in_channels=output_channel,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
spatial_padding_mode=spatial_padding_mode,
)
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor | None = None,
causal: bool | None = None,
) -> torch.Tensor:
causal = causal or self.is_causal
hidden_states = self.conv_in(hidden_states, causal=causal)
if self.timestep_scale_multiplier is not None:
temb = temb * self.timestep_scale_multiplier
if torch.is_grad_enabled() and self.gradient_checkpointing: