Skip to content

Commit

Permalink
[Vae and AutoencoderKL] Final clean of LDM checkpoints (open-mmlab#137)
Browse files Browse the repository at this point in the history
* [Vae and AutoencoderKL clean]

* save intermediate finished work

* more progress

* more progress

* finish modeling code

* save intermediate

* finish

* Correct tests
  • Loading branch information
patrickvonplaten authored Jul 28, 2022
1 parent e05f03a commit 3100bc9
Show file tree
Hide file tree
Showing 6 changed files with 490 additions and 312 deletions.
181 changes: 153 additions & 28 deletions scripts/convert_ddpm_original_checkpoint_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
import argparse
import json
import torch
Expand Down Expand Up @@ -64,7 +64,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s

target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3

old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
Expand All @@ -79,7 +79,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue

new_path = new_path.replace('down.', 'downsample_blocks.')
new_path = new_path.replace('down.', 'down_blocks.')
new_path = new_path.replace('up.', 'up_blocks.')

if additional_replacements is not None:
Expand Down Expand Up @@ -111,36 +111,36 @@ def convert_ddpm_checkpoint(checkpoint, config):
new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']

num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}
num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

for i in range(num_downsample_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1)
for i in range(num_down_blocks):
block_id = (i - 1) // (config['layers_per_block'] + 1)

if any('downsample' in layer for layer in downsample_blocks[i]):
new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
if any('downsample' in layer for layer in down_blocks[i]):
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

if any('block' in layer for layer in downsample_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('block' in layer for layer in down_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_blocks > 0:
for j in range(config['num_res_blocks']):
for j in range(config['layers_per_block']):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint)

if any('attn' in layer for layer in downsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
if any('attn' in layer for layer in down_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_attn > 0:
for j in range(config['num_res_blocks']):
for j in range(config['layers_per_block']):
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

Expand Down Expand Up @@ -176,7 +176,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_blocks > 0:
for j in range(config['num_res_blocks'] + 1):
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
Expand All @@ -186,7 +186,7 @@ def convert_ddpm_checkpoint(checkpoint, config):
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_attn > 0:
for j in range(config['num_res_blocks'] + 1):
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
Expand All @@ -195,6 +195,117 @@ def convert_ddpm_checkpoint(checkpoint, config):
return new_checkpoint


def convert_vq_autoenc_checkpoint(checkpoint, config):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
new_checkpoint = {}

new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']

new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']

new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']

new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']

num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

for i in range(num_down_blocks):
block_id = (i - 1) // (config['layers_per_block'] + 1)

if any('downsample' in layer for layer in down_blocks[i]):
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']

if any('block' in layer for layer in down_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_blocks > 0:
for j in range(config['layers_per_block']):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint)

if any('attn' in layer for layer in down_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_attn > 0:
for j in range(config['layers_per_block']):
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]

# Mid new 2
paths = renew_resnet_paths(mid_block_1_layers)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
])

paths = renew_resnet_paths(mid_block_2_layers)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
])

paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
])

for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i

if any('upsample' in layer for layer in up_blocks[i]):
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']

if any('block' in layer for layer in up_blocks[i]):
num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_blocks > 0:
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

if any('attn' in layer for layer in up_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}

if num_attn > 0:
for j in range(config['layers_per_block'] + 1):
replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
if "quantize.embedding.weight" in checkpoint:
new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]

return new_checkpoint


if __name__ == "__main__":
parser = argparse.ArgumentParser()

Expand All @@ -220,15 +331,29 @@ def convert_ddpm_checkpoint(checkpoint, config):
with open(args.config_file) as f:
config = json.loads(f.read())

converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
# unet case
key_prefix_set = set(key.split(".")[0] for key in checkpoint.keys())
if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
else:
converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)

if "ddpm" in config:
del config["ddpm"]

model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)
if config["_class_name"] == "VQModel":
model = VQModel(**config)
model.load_state_dict(converted_checkpoint)
model.save_pretrained(args.dump_path)
elif config["_class_name"] == "AutoencoderKL":
model = AutoencoderKL(**config)
model.load_state_dict(converted_checkpoint)
model.save_pretrained(args.dump_path)
else:
model = UNet2DModel(**config)
model.load_state_dict(converted_checkpoint)

scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

pipe = DDPMPipeline(unet=model, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)
pipe = DDPMPipeline(unet=model, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)
10 changes: 7 additions & 3 deletions src/diffusers/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ def __init__(

self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
else:
self.time_emb_proj = None

self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
Expand Down Expand Up @@ -364,8 +367,9 @@ def set_weight(self, resnet):
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data

self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
if self.time_emb_proj is not None:
self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
self.time_emb_proj.bias.data = resnet.temb_proj.bias.data

self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
Expand Down
Loading

0 comments on commit 3100bc9

Please sign in to comment.