
[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratios) #9708

Merged
101 commits merged on Dec 6, 2024
Changes from 1 commit
6e616a9
first add a script for DC-AE;
lawrence-cj Oct 18, 2024
d2e187a
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 23, 2024
90e8939
DC-AE init
chenjy2003 Oct 23, 2024
825c975
replace triton with custom implementation
chenjy2003 Oct 23, 2024
3a44fa4
1. rename file and remove un-used codes;
lawrence-cj Oct 23, 2024
55b2615
no longer rely on omegaconf and dataclass
chenjy2003 Oct 25, 2024
6fb7fdb
merge
chenjy2003 Oct 25, 2024
c323e76
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 25, 2024
da7caa5
replace custom activation with diffusers activation
chenjy2003 Oct 25, 2024
fb6d92a
remove dc_ae attention in attention_processor.py
chenjy2003 Oct 25, 2024
5e63a1a
inherit from ModelMixin
chenjy2003 Oct 25, 2024
72cce2b
inherit from ConfigMixin
chenjy2003 Oct 25, 2024
8f9b4e4
dc-ae reduce to one file
chenjy2003 Oct 31, 2024
b7f68f9
Merge remote-tracking branch 'upstream/main' into DC-AE
chenjy2003 Oct 31, 2024
6d96b95
Merge branch 'huggingface:main' into DC-AE
lawrence-cj Nov 4, 2024
3c3cc51
Merge remote-tracking branch 'refs/remotes/origin/main' into DC-AE
lawrence-cj Nov 6, 2024
1448681
update downsample and upsample
chenjy2003 Nov 9, 2024
bf40fe8
merge
chenjy2003 Nov 9, 2024
dd7718a
clean code
chenjy2003 Nov 9, 2024
19986a5
support DecoderOutput
chenjy2003 Nov 9, 2024
3481e23
Merge branch 'main' into DC-AE
lawrence-cj Nov 9, 2024
0e818df
Merge branch 'main' into DC-AE
lawrence-cj Nov 13, 2024
c6eb233
remove get_same_padding and val2tuple
chenjy2003 Nov 14, 2024
59de0a3
remove autocast and some assert
chenjy2003 Nov 14, 2024
ea604a4
update ResBlock
chenjy2003 Nov 14, 2024
80dce02
remove contents within super().__init__
chenjy2003 Nov 14, 2024
1752afd
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 16, 2024
883bcf4
remove opsequential
chenjy2003 Nov 16, 2024
25ae389
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 16, 2024
96e844b
update other blocks to support the removal of build_norm
chenjy2003 Nov 16, 2024
59b6e25
Merge branch 'main' into DC-AE
sayakpaul Nov 16, 2024
7ce9ff2
remove build encoder/decoder project in/out
chenjy2003 Nov 16, 2024
30d6308
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 16, 2024
cab56b1
remove inheritance of RMSNorm2d from LayerNorm
chenjy2003 Nov 16, 2024
b42bb54
remove reset_parameters for RMSNorm2d
chenjy2003 Nov 20, 2024
2e04a99
remove device and dtype in RMSNorm2d __init__
chenjy2003 Nov 20, 2024
b4f75f2
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
c82f828
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
22ea5fd
Update src/diffusers/models/autoencoders/dc_ae.py
lawrence-cj Nov 21, 2024
4f5cbb4
remove op_list & build_block
chenjy2003 Nov 26, 2024
2f6bbad
remove build_stage_main
chenjy2003 Nov 26, 2024
4495783
Merge branch 'main' into DC-AE
lawrence-cj Nov 26, 2024
4d3c026
change file name to autoencoder_dc
chenjy2003 Nov 28, 2024
e007057
Merge branch 'DC-AE' of github.com:lawrence-cj/diffusers into DC-AE
chenjy2003 Nov 28, 2024
d3d9c84
move LiteMLA to attention.py
chenjy2003 Nov 28, 2024
be9826c
align with other vae decode output;
lawrence-cj Nov 28, 2024
20da201
add DC-AE into init files;
lawrence-cj Nov 28, 2024
5ed50e9
update
a-r-r-o-w Nov 28, 2024
2d59056
make quality && make style;
lawrence-cj Nov 28, 2024
c1c02a2
quick push before dgx disappears again
a-r-r-o-w Nov 28, 2024
1f8a3b3
update
a-r-r-o-w Nov 28, 2024
7b9d7e5
make style
a-r-r-o-w Nov 28, 2024
bf6c211
update
a-r-r-o-w Nov 28, 2024
a2ec5f8
update
a-r-r-o-w Nov 28, 2024
f5876c5
fix
a-r-r-o-w Nov 28, 2024
44034a6
refactor
a-r-r-o-w Nov 29, 2024
6379241
refactor
a-r-r-o-w Nov 29, 2024
77571a8
refactor
a-r-r-o-w Nov 29, 2024
c4d0867
update
a-r-r-o-w Nov 30, 2024
0bdb7ef
possibly change to nn.Linear
a-r-r-o-w Nov 30, 2024
54e933b
refactor
a-r-r-o-w Nov 30, 2024
babc9f5
Merge branch 'main' into aryan-dcae
a-r-r-o-w Nov 30, 2024
3d5faaf
make fix-copies
a-r-r-o-w Nov 30, 2024
65edfa5
resolve conflicts & merge
a-r-r-o-w Dec 1, 2024
ca3ac4d
replace vae with ae
chenjy2003 Dec 3, 2024
9ef7b59
replace get_block_from_block_type to get_block
chenjy2003 Dec 3, 2024
074817c
replace downsample_block_type from Conv to conv for consistency
chenjy2003 Dec 3, 2024
64de66a
add scaling factors
chenjy2003 Dec 3, 2024
0bda5c5
incorporate changes for all checkpoints
a-r-r-o-w Dec 4, 2024
eb64d52
make style
a-r-r-o-w Dec 4, 2024
4a224ce
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 4, 2024
30c3238
move mla to attention processor file; split qkv conv to linears
a-r-r-o-w Dec 4, 2024
39a947c
refactor
a-r-r-o-w Dec 4, 2024
68f817a
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 4, 2024
da834d5
add tests
a-r-r-o-w Dec 4, 2024
632ad3b
Merge branch 'main' into DC-AE
lawrence-cj Dec 4, 2024
d6c748c
from original file loader
a-r-r-o-w Dec 4, 2024
46eb504
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 4, 2024
31f9fc6
add docs
a-r-r-o-w Dec 4, 2024
6f29e2a
add standard autoencoder methods
a-r-r-o-w Dec 4, 2024
b6e8fba
combine attention processor
yiyixuxu Dec 4, 2024
f862bae
fix tests
a-r-r-o-w Dec 5, 2024
f9fce24
update
a-r-r-o-w Dec 5, 2024
e594745
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 5, 2024
3c0b1ca
minor fix
chenjy2003 Dec 5, 2024
91057d4
minor fix
chenjy2003 Dec 5, 2024
67aa715
Merge branch 'main' into DC-AE
lawrence-cj Dec 5, 2024
eda66e1
minor fix & in/out shortcut rename
chenjy2003 Dec 5, 2024
e3d33e6
minor fix
chenjy2003 Dec 5, 2024
cc97502
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 5, 2024
2b370df
make style
a-r-r-o-w Dec 5, 2024
94355ab
fix paper link
chenjy2003 Dec 6, 2024
a191f07
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 6, 2024
116c049
update docs
a-r-r-o-w Dec 6, 2024
b6e0aba
update single file loading
a-r-r-o-w Dec 6, 2024
ec4e84f
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 6, 2024
dbae8f1
make style
a-r-r-o-w Dec 6, 2024
042c2a0
remove single file loading support; todo for DN6
a-r-r-o-w Dec 6, 2024
f2525b9
Apply suggestions from code review
a-r-r-o-w Dec 6, 2024
d3d224c
Merge branch 'main' into DC-AE
a-r-r-o-w Dec 6, 2024
6122b84
add abstract
a-r-r-o-w Dec 6, 2024
from original file loader
a-r-r-o-w committed Dec 4, 2024
commit d6c748c7e665cbf0f4392dedd2b4b32253b3d462
4 changes: 0 additions & 4 deletions scripts/convert_dcae_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -8,10 +8,6 @@
from diffusers import AutoencoderDC


def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)


def remap_qkv_(key: str, state_dict: Dict[str, Any]):
qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)
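The `remap_qkv_` helper above splits a fused QKV weight into separate Q, K, V tensors with `torch.chunk`. A minimal, torch-free sketch of the same chunking idea — `chunk_rows` here is a hypothetical stand-in for `torch.chunk(tensor, 3, dim=0)`, using nested lists in place of tensors:

```python
def chunk_rows(matrix, n):
    """Split a matrix (a list of rows) into n equal parts along dim 0,
    mimicking torch.chunk(tensor, n, dim=0) for evenly divisible sizes."""
    step = len(matrix) // n
    return [matrix[i * step:(i + 1) * step] for i in range(n)]

# A fused QKV weight with 6 rows splits into three 2-row weights.
qkv = [[1], [2], [3], [4], [5], [6]]
q, k, v = chunk_rows(qkv, 3)
print(q, k, v)  # [[1], [2]] [[3], [4]] [[5], [6]]
```

In the real converter, each chunk is then written back under a `to_q`/`to_k`/`to_v` key derived from the original fused key.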
8 changes: 7 additions & 1 deletion src/diffusers/loaders/single_file_model.py
@@ -23,12 +23,14 @@
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_autoencoder_dc_config_from_original,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
@@ -82,6 +84,10 @@
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderDC": {
"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers,
"config_mapping_fn": create_autoencoder_dc_config_from_original,
},
}


@@ -228,7 +234,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
f"`original_config` has been provided for {mapping_class_name} but no mapping function "
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
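The `"AutoencoderDC"` entry added above registers per-class conversion hooks that the single-file loader looks up by class name. A simplified sketch of that dispatch pattern — the names below are illustrative, not the actual diffusers internals:

```python
def convert_autoencoder_dc(config, checkpoint):
    # Stand-in for convert_autoencoder_dc_checkpoint_to_diffusers:
    # a real mapping fn would rename/remap state-dict keys here.
    return {"converted": True, "num_keys": len(checkpoint)}

# Registry mapping a model class name to its conversion hooks,
# mirroring the SINGLE_FILE_LOADABLE_CLASSES table in the diff.
LOADABLE_CLASSES = {
    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc},
}

def from_single_file(class_name, config, checkpoint):
    entry = LOADABLE_CLASSES.get(class_name)
    if entry is None:
        raise ValueError(f"No single-file support for {class_name}")
    return entry["checkpoint_mapping_fn"](config, checkpoint)

state = from_single_file("AutoencoderDC", {}, {"w": 1})
```

The table-driven design keeps `from_single_file` itself model-agnostic: supporting a new architecture only means adding a registry entry.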
256 changes: 255 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
@@ -12,7 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the Stable Diffusion checkpoints."""

"""
Conversion scripts for the various modeling checkpoints. These scripts convert original model implementations to
Diffusers adapted versions. This usually only involves renaming/remapping the state dict keys and changing some
modeling components partially (for example, splitting a single QKV linear to individual Q, K, V layers).
"""

import copy
import os
@@ -92,6 +97,7 @@
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"autoencoder_dc": "decoder.stages.0.op_list.0.main.conv.conv.weight",
Collaborator: We would need to infer the model repo type using this key right? That still has to be added.

Member: Oh sorry, missed it. Adding now, but not sure how this worked before then 🤔
}
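The `"autoencoder_dc"` entry above lets the loader detect a DC-AE checkpoint by probing for a characteristic state-dict key. A minimal sketch of that detection idea — the key name is taken from the diff, while the helper itself is illustrative:

```python
# Signature keys per model type, as in the CHECKPOINT_KEY_NAMES table above.
CHECKPOINT_KEY_NAMES = {
    "autoencoder_dc": "decoder.stages.0.op_list.0.main.conv.conv.weight",
}

def infer_checkpoint_type(state_dict):
    """Return the first model type whose signature key appears in the checkpoint."""
    for model_type, key in CHECKPOINT_KEY_NAMES.items():
        if key in state_dict:
            return model_type
    return None

ckpt = {"decoder.stages.0.op_list.0.main.conv.conv.weight": [0.0]}
assert infer_checkpoint_type(ckpt) == "autoencoder_dc"
```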

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -2198,3 +2204,251 @@ def swap_scale_shift(weight):
)

return converted_state_dict


def create_autoencoder_dc_config_from_original(original_config, checkpoint, **kwargs):
Collaborator: I think for new single file models let's not rely on the original configs anymore. This was for legacy support for the SD1.5/XL models with yaml configs. It's better to infer the diffusers config from the checkpoint and use that for loading.

Member (a-r-r-o-w, Dec 6, 2024): This might be a little difficult here, so please let me know if you have any suggestions on what to do. Some DCAE checkpoints have the exact same structure and configuration, except for scaling_factor. For example, `dc-ae-f128c512-in-1.0-diffusers` and `dc-ae-f128c512-mix-1.0-diffusers` only differ in their scaling factor. I'm unsure how we would determine this just by the model structure. Do we rely on the user passing it as a config correctly, and document this info somewhere?

Collaborator: I think that's fine since in the snippet in the docs, we're doing the same thing just with original_config instead of config right?

Member: Updated usage to config now and verified that it works. Thank you for the fixes and suggestions!
model_name = original_config.get("model_name", "dc-ae-f32c32-sana-1.0")
print("trying:", model_name)

if model_name in ["dc-ae-f32c32-sana-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"decoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
"downsample_block_type": "conv",
"upsample_block_type": "interpolate",
"decoder_norm_types": "rms_norm",
"decoder_act_fns": "silu",
"scaling_factor": 0.41407,
}
elif model_name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
}
if model_name == "dc-ae-f32c32-in-1.0":
config["scaling_factor"] = 0.3189
elif model_name == "dc-ae-f32c32-mix-1.0":
config["scaling_factor"] = 0.4552
elif model_name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
config = {
"latent_channels": 128,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
}
if model_name == "dc-ae-f64c128-in-1.0":
config["scaling_factor"] = 0.2889
elif model_name == "dc-ae-f64c128-mix-1.0":
config["scaling_factor"] = 0.4538
elif model_name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
config = {
"latent_channels": 512,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
}
if model_name == "dc-ae-f128c512-in-1.0":
config["scaling_factor"] = 0.4883
elif model_name == "dc-ae-f128c512-mix-1.0":
config["scaling_factor"] = 0.3620

config.update({"model_name": model_name})

return config
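A pattern worth noting in the configs above: each additional encoder stage doubles the spatial compression, so the f32 variants have 6 stages, f64 has 7, and f128 has 8, giving a downsampling factor of 2**(stages - 1). This is an inference from the configs and the checkpoint naming, not a statement in the diff — a quick check of the relationship:

```python
# Number of encoder stages per DC-AE variant, read off the configs above
# (length of encoder_block_out_channels).
stages = {"f32": 6, "f64": 7, "f128": 8}

for name, n in stages.items():
    # Assumption: the first stage keeps resolution; each later stage halves it.
    factor = 2 ** (n - 1)
    assert name == f"f{factor}", (name, factor)
```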


def convert_autoencoder_dc_checkpoint_to_diffusers(config, checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
model_name = config.pop("model_name")

def remap_qkv_(key: str, state_dict):
qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()

def remap_proj_conv_(key: str, state_dict):
parent_module, _, _ = key.rpartition(".proj.conv.weight")
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()

AE_KEYS_RENAME_DICT = {
# common
"main.": "",
"op_list.": "",
"context_module": "attn",
"local_module": "conv_out",
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
# If there were more scales, there would be more layers, so a loop would be better to handle this
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
"depth_conv.conv": "conv_depth",
"inverted_conv.conv": "conv_inverted",
"point_conv.conv": "conv_point",
"point_conv.norm": "norm",
"conv.conv.": "conv.",
"conv1.conv": "conv1",
"conv2.conv": "conv2",
"conv2.norm": "norm",
"proj.norm": "norm_out",
# encoder
"encoder.project_in.conv": "encoder.conv_in",
"encoder.project_out.0.conv": "encoder.conv_out",
"encoder.stages": "encoder.down_blocks",
# decoder
"decoder.project_in.conv": "decoder.conv_in",
"decoder.project_out.0": "decoder.norm_out",
"decoder.project_out.2.conv": "decoder.conv_out",
"decoder.stages": "decoder.up_blocks",
}

AE_F32C32_KEYS = {
"encoder.project_in.conv": "encoder.conv_in.conv",
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
"encoder.project_in.conv": "encoder.conv_in.conv",
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
"encoder.project_in.conv": "encoder.conv_in.conv",
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_SPECIAL_KEYS_REMAP = {
"qkv.conv.weight": remap_qkv_,
"proj.conv.weight": remap_proj_conv_,
}

if "f32c32" in model_name and "sana" not in model_name:
AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
elif "f64c128" in model_name:
AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
elif "f128c512" in model_name:
AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)

for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)

for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)

return converted_state_dict
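The conversion above works in two passes: a plain substring-rename pass over every key, then a second pass that dispatches keys matching special patterns to in-place handler functions. A toy, torch-free sketch of that pattern, with hypothetical rename rules and handler:

```python
# Pass-1 rules: substring replacements, like AE_KEYS_RENAME_DICT above.
RENAME = {"conv.conv.": "conv.", "stages": "down_blocks"}

def split_fused(key, state_dict):
    # Pass-2 handler: pop a fused value and write derived keys,
    # analogous to remap_qkv_ splitting a fused QKV weight.
    a, b = state_dict.pop(key)
    state_dict[key + ".a"] = a
    state_dict[key + ".b"] = b

SPECIAL = {"fused": split_fused}

def convert(state_dict):
    # Pass 1: apply every substring rename to every key.
    for key in list(state_dict):
        new_key = key
        for old, new in RENAME.items():
            new_key = new_key.replace(old, new)
        state_dict[new_key] = state_dict.pop(key)
    # Pass 2: dispatch special keys to their in-place handlers.
    for key in list(state_dict):
        for pattern, handler in SPECIAL.items():
            if pattern in key:
                handler(key, state_dict)
    return state_dict

sd = {"encoder.stages.0.conv.conv.weight": 1, "fused": (2, 3)}
out = convert(sd)
# → {"encoder.down_blocks.0.conv.weight": 1, "fused.a": 2, "fused.b": 3}
```

Snapshotting the keys with `list(state_dict)` before each pass is what makes it safe to pop and insert entries while iterating — the real converter does the same.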
3 changes: 2 additions & 1 deletion src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -20,6 +20,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ..activations import get_activation
from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin
@@ -394,7 +395,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class AutoencoderDC(ModelMixin, ConfigMixin):
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
[SANA](https://arxiv.org/abs/2410.10629).