-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratios) #9708
Changes from 1 commit
6e616a9
d2e187a
90e8939
825c975
3a44fa4
55b2615
6fb7fdb
c323e76
da7caa5
fb6d92a
5e63a1a
72cce2b
8f9b4e4
b7f68f9
6d96b95
3c3cc51
1448681
bf40fe8
dd7718a
19986a5
3481e23
0e818df
c6eb233
59de0a3
ea604a4
80dce02
1752afd
883bcf4
25ae389
96e844b
59b6e25
7ce9ff2
30d6308
cab56b1
b42bb54
2e04a99
b4f75f2
c82f828
22ea5fd
4f5cbb4
2f6bbad
4495783
4d3c026
e007057
d3d9c84
be9826c
20da201
5ed50e9
2d59056
c1c02a2
1f8a3b3
7b9d7e5
bf6c211
a2ec5f8
f5876c5
44034a6
6379241
77571a8
c4d0867
0bdb7ef
54e933b
babc9f5
3d5faaf
65edfa5
ca3ac4d
9ef7b59
074817c
64de66a
0bda5c5
eb64d52
4a224ce
30c3238
39a947c
68f817a
da834d5
632ad3b
d6c748c
46eb504
31f9fc6
6f29e2a
b6e8fba
f862bae
f9fce24
e594745
3c0b1ca
91057d4
67aa715
eda66e1
e3d33e6
cc97502
2b370df
94355ab
a191f07
116c049
b6e0aba
ec4e84f
dbae8f1
042c2a0
f2525b9
d3d224c
6122b84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,12 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Conversion script for the Stable Diffusion checkpoints.""" | ||
|
||
""" | ||
Conversion scripts for the various modeling checkpoints. These scripts convert original model implementations to | ||
Diffusers adapted versions. This usually only involves renaming/remapping the state dict keys and changing some | ||
modeling components partially (for example, splitting a single QKV linear to individual Q, K, V layers). | ||
""" | ||
|
||
import copy | ||
import os | ||
|
@@ -92,6 +97,7 @@ | |
"double_blocks.0.img_attn.norm.key_norm.scale", | ||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", | ||
], | ||
"autoencoder_dc": "decoder.stages.0.op_list.0.main.conv.conv.weight", | ||
} | ||
|
||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = { | ||
|
@@ -2198,3 +2204,251 @@ def swap_scale_shift(weight): | |
) | ||
|
||
return converted_state_dict | ||
|
||
|
||
def create_autoencoder_dc_config_from_original(original_config, checkpoint, **kwargs):
    """Build the diffusers `AutoencoderDC` config for a known DC-AE checkpoint.

    Several official DC-AE checkpoints share the exact same tensor structure and
    differ only in `scaling_factor`, so the variant cannot be inferred from the
    state dict alone; it is selected via the `model_name` entry of
    `original_config` instead (see the review discussion on this PR).

    Args:
        original_config: Mapping that may contain a `model_name` entry naming the
            official DC-AE variant; defaults to `"dc-ae-f32c32-sana-1.0"`.
        checkpoint: Original state dict. Unused here, but kept so this helper has
            the same signature as the other `create_*_config_from_original` helpers.
        kwargs: Ignored; accepted for interface compatibility.

    Returns:
        A dict of `AutoencoderDC` keyword arguments, plus a `model_name` entry
        that `convert_autoencoder_dc_checkpoint_to_diffusers` pops later.

    Raises:
        ValueError: If `model_name` is not a known DC-AE variant. (Previously an
            unknown name crashed with an opaque `NameError` on `config`.)
    """
    model_name = original_config.get("model_name", "dc-ae-f32c32-sana-1.0")

    if model_name == "dc-ae-f32c32-sana-1.0":
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif model_name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        # Same architecture; only the latent scaling factor differs per variant.
        if model_name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif model_name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif model_name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if model_name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif model_name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif model_name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if model_name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif model_name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError(f"Unknown DC-AE model name: {model_name!r}")

    # Record the resolved variant; the checkpoint converter pops this entry.
    config["model_name"] = model_name

    return config
|
||
|
||
def convert_autoencoder_dc_checkpoint_to_diffusers(config, checkpoint, **kwargs):
    """Remap an original DC-AE state dict to the diffusers `AutoencoderDC` layout.

    Args:
        config: Config dict produced by `create_autoencoder_dc_config_from_original`;
            its `model_name` entry is popped here to select variant-specific renames.
        checkpoint: Original state dict. It is drained in place (every entry is
            popped) and the converted tensors are returned in a new dict.
        kwargs: Ignored; accepted for interface compatibility.

    Returns:
        A new state dict with diffusers-style keys.
    """
    # Drain the incoming checkpoint so the caller's dict ends up empty, matching
    # the behavior of the other conversion helpers in this file.
    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
    model_name = config.pop("model_name")

    def remap_qkv_(key: str, state_dict):
        # The original model fuses Q/K/V into a single 1x1 conv; diffusers uses
        # separate linear projections, so split along the output-channel dim and
        # squeeze away the trailing 1x1 spatial dims.
        fused_qkv = state_dict.pop(key)
        to_q, to_k, to_v = torch.chunk(fused_qkv, 3, dim=0)
        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
        state_dict[f"{parent_module}.to_q.weight"] = to_q.squeeze()
        state_dict[f"{parent_module}.to_k.weight"] = to_k.squeeze()
        state_dict[f"{parent_module}.to_v.weight"] = to_v.squeeze()

    def remap_proj_conv_(key: str, state_dict):
        # 1x1 output-projection conv becomes a linear `to_out` layer.
        parent_module, _, _ = key.rpartition(".proj.conv.weight")
        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()

    AE_KEYS_RENAME_DICT = {
        # common
        "main.": "",
        "op_list.": "",
        "context_module": "attn",
        "local_module": "conv_out",
        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
        # If there were more scales, there would be more layers, so a loop would be better to handle this
        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
        "depth_conv.conv": "conv_depth",
        "inverted_conv.conv": "conv_inverted",
        "point_conv.conv": "conv_point",
        "point_conv.norm": "norm",
        "conv.conv.": "conv.",
        "conv1.conv": "conv1",
        "conv2.conv": "conv2",
        "conv2.norm": "norm",
        "proj.norm": "norm_out",
        # encoder
        "encoder.project_in.conv": "encoder.conv_in",
        "encoder.project_out.0.conv": "encoder.conv_out",
        "encoder.stages": "encoder.down_blocks",
        # decoder
        "decoder.project_in.conv": "decoder.conv_in",
        "decoder.project_out.0": "decoder.norm_out",
        "decoder.project_out.2.conv": "decoder.conv_out",
        "decoder.stages": "decoder.up_blocks",
    }

    # All non-SANA variants (f32c32/f64c128/f128c512) keep an extra `.conv` level
    # on the in/out projections. The original code carried three byte-identical
    # per-variant dicts (AE_F32C32_KEYS / AE_F64C128_KEYS / AE_F128C512_KEYS);
    # they are consolidated here.
    AE_NON_SANA_KEYS = {
        "encoder.project_in.conv": "encoder.conv_in.conv",
        "decoder.project_out.2.conv": "decoder.conv_out.conv",
    }

    AE_SPECIAL_KEYS_REMAP = {
        "qkv.conv.weight": remap_qkv_,
        "proj.conv.weight": remap_proj_conv_,
    }

    if (
        ("f32c32" in model_name and "sana" not in model_name)
        or "f64c128" in model_name
        or "f128c512" in model_name
    ):
        AE_KEYS_RENAME_DICT.update(AE_NON_SANA_KEYS)

    # First pass: plain substring renames, applied in dict-insertion order.
    for key in list(converted_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    # Second pass: structural remaps (fused-QKV split, projection squeeze).
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)

    return converted_state_dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would need to infer the model repo type using this key right? That still has to be added.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh sorry, missed it. Adding now, but not sure how this worked before then 🤔