
Support for premade checkpoints #24

Open

rockerBOO opened this issue Jan 2, 2023 · 7 comments

rockerBOO commented Jan 2, 2023

Would it be possible to use models based on the CompVis checkpoint format used by stabilityai and supported in HF diffusers? My personal goals are:

  • Support for 1.5 and 2.1 ckpt models (converted over to .ot)
  • Support for third-party merges and training checkpoints (based on 1.5 or 2.1, so they should work the same way), e.g. Analog Diffusion 1.0

I tried the following to convert the file over and listed the tensor names using tensor-tools. Maybe these components can be extracted and put back together?

import numpy as np
import torch

# Load on the CPU so this also works without a GPU; the weights live under
# the "state_dict" key of the checkpoint.
model = torch.load("./data/analog-diffusion-1.0.ckpt", map_location="cpu")
x = {k: v.numpy() for k, v in model["state_dict"].items()}
np.savez("./data/analog-diffusion-1.0.npz", **x)

cargo run --release --example tensor-tools cp ./data/analog-diffusion-1.0.npz ./data/analog-diffusion-1.0.ot
cargo run --release --example tensor-tools ls ./data/analog-diffusion-1.0.ot
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.0.0.weight Tensor[[320, 4, 3, 3], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.0.0.bias Tensor[[320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.0.weight Tensor[[1280, 320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.0.bias Tensor[[1280], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.2.weight Tensor[[1280, 1280], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.2.bias Tensor[[1280], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.norm.weight Tensor[[320], Half]
...
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight Tensor[[768, 768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias Tensor[[768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight Tensor[[768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias Tensor[[768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight Tensor[[3072, 768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias Tensor[[3072], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight Tensor[[768, 3072], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias Tensor[[768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight Tensor[[768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias Tensor[[768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.final_layer_norm.weight Tensor[[768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.final_layer_norm.bias Tensor[[768], Half]

Full list: analog-diffusion-1.0.ot.log

Thanks!


rockerBOO commented Jan 3, 2023

OK, I did some further investigation.

These are all based on the default inference configs defined in:

  • CompVis/stable-diffusion for 1.x
  • Stability-AI/stablediffusion for 2.x

The inference config files:

These lines read the config in and instantiate/call the modules named in it.

Classes inside CompVis/stable-diffusion


Inside this repo we have src/pipelines/stable_diffusion.rs, which has loaders for the VAE/autoencoder, UNet, scheduler, and CLIP (CLIP for 1.x and OpenCLIP for 2.x). Is this something where a LatentDiffusion model would need to be added to support this?


rockerBOO commented Jan 8, 2023

OK, so it seems these checkpoint files contain all the components, and the parameters are named in a way that lets them be rebuilt from the Python modules.

I took some time to read through the diffusers library some more (I wasn't using it before, mostly just the CompVis forks). HF diffusers has a from_pretrained which lets you load a model from the Hugging Face hub or from a local copy. Putting a model on HF has you create a model_index.json that maps each component to a class in HF diffusers or transformers.

To put these together, they have diffusion pipelines. These allow various pipelines to be built, including the Stable Diffusion pipeline, which ties all these parts together. This library has a stable_diffusion pipeline as well.
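
For reference, loading one of these premade checkpoints through HF diffusers looks roughly like this (the model id is just an example of a repo laid out with a model_index.json):

from diffusers import StableDiffusionPipeline

# Downloads the components listed in model_index.json (unet, vae, text_encoder, ...)
# and wires them into a single pipeline; a local directory path works as well.
pipe = StableDiffusionPipeline.from_pretrained("wavymulder/Analog-Diffusion")
image = pipe("analog style portrait photo").images[0]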

So, to use premade checkpoints that bundle all these components, there needs to be a translation process that routes each named parameter in the checkpoint to the right component (autoencoder, unet, text encoder). I started trying to map the names onto a possible internal representation, but I don't have a clear enough picture yet.

For example

model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight Tensor[[1280, 1280], Half]
model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias Tensor[[1280], Half]

Might be inside this data structure:

UNetMidBlock2DCrossAttn {
	attn_resnets: vec![(
		SpatialTransformer {
			transformer_blocks: vec![BasicTransformerBlock {
				attn1
				ff
				attn2: CrossAttention {
					to_q
					to_k
					to_v
					to_out: nn::Linear {
						ws: Tensor
						bs: Option<Tensor>
					},
				},
				norm1
				norm2
				norm3
			}, ...]
		}, _
	)]
}

But a few of the conversions are less clear.

Well, it's a little clearer in my head now, but let me know if you have any opinions about this thought process. It seems doable, but translating between the libraries and between Python and Rust is taking me some time to understand.

@rockerBOO (Author)

I read a helpful post about Latent Diffusion that documented the CompVis version and explained what the various entries in the checkpoint files are. Now it all seems to match up. The only thing that's a little off is that there are no resnet names inside the CompVis version, but they are there inside the local unet.

model.diffusion_model = unet
first_stage_model = autoencoder
cond_stage_model.transformer = CLIPTextEmbedder/ClipTextTransformer

cond_stage_model.transformer

./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.embeddings.position_ids Tensor[[1, 77], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.embeddings.token_embedding.weight Tensor[[49408, 768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.embeddings.position_embedding.weight Tensor[[77, 768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight Tensor[[768, 768], Half]
./data/analog-diffusion-1.0.ot: cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias Tensor[[768], Half]

matches up with the CLIP model pytorch_model.ot:

➜ cargo run --release --example tensor-tools ls ./data/pytorch_model.ot
    Finished release [optimized] target(s) in 0.07s
     Running `target/release/examples/tensor-tools ls ./data/pytorch_model.ot`
./data/pytorch_model.ot: text_model.embeddings.position_ids Tensor[[1, 77], Int64]
./data/pytorch_model.ot: text_model.embeddings.token_embedding.weight Tensor[[49408, 768], Float]
./data/pytorch_model.ot: text_model.embeddings.position_embedding.weight Tensor[[77, 768], Float]
./data/pytorch_model.ot: text_model.encoder.layers.0.self_attn.k_proj.bias Tensor[[768], Float]

first_stage_model

./data/analog-diffusion-1.0.ot: first_stage_model.encoder.conv_in.weight Tensor[[128, 3, 3, 3], Half]
./data/analog-diffusion-1.0.ot: first_stage_model.encoder.conv_in.bias Tensor[[128], Half]
./data/analog-diffusion-1.0.ot: first_stage_model.encoder.down.0.block.0.norm1.weight Tensor[[128], Half]
./data/analog-diffusion-1.0.ot: first_stage_model.encoder.down.0.block.0.norm1.bias Tensor[[128], Half]

Matches the vae.ot:

./data/vae.ot: encoder.conv_in.weight Tensor[[128, 3, 3, 3], Float]
./data/vae.ot: encoder.conv_in.bias Tensor[[128], Float]
./data/vae.ot: encoder.down_blocks.0.resnets.0.norm1.weight Tensor[[128], Float]
./data/vae.ot: encoder.down_blocks.0.resnets.0.norm1.bias Tensor[[128], Float]

model.diffusion_model

./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.0.0.weight Tensor[[320, 4, 3, 3], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.0.0.bias Tensor[[320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.0.weight Tensor[[1280, 320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.0.bias Tensor[[1280], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.2.weight Tensor[[1280, 1280], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.time_embed.2.bias Tensor[[1280], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.norm.weight Tensor[[320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.norm.bias Tensor[[320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.proj_in.weight Tensor[[320, 320, 1, 1], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.proj_in.bias Tensor[[320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight Tensor[[320, 320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight Tensor[[320, 320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight Tensor[[320, 320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight Tensor[[320, 320], Half]
./data/analog-diffusion-1.0.ot: model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias Tensor[[320], Half]

Matches unet.ot:

./data/unet.ot: conv_in.weight Tensor[[320, 4, 3, 3], Float]
./data/unet.ot: conv_in.bias Tensor[[320], Float]
./data/unet.ot: time_embedding.linear_1.weight Tensor[[1280, 320], Float]
./data/unet.ot: time_embedding.linear_1.bias Tensor[[1280], Float]
./data/unet.ot: time_embedding.linear_2.weight Tensor[[1280, 1280], Float]
./data/unet.ot: time_embedding.linear_2.bias Tensor[[1280], Float]
./data/unet.ot: down_blocks.0.attentions.0.norm.weight Tensor[[320], Float]
./data/unet.ot: down_blocks.0.attentions.0.norm.bias Tensor[[320], Float]
./data/unet.ot: down_blocks.0.attentions.0.proj_in.weight Tensor[[320, 320, 1, 1], Float]
./data/unet.ot: down_blocks.0.attentions.0.proj_in.bias Tensor[[320], Float]
./data/unet.ot: down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight Tensor[[320, 320], Float]
./data/unet.ot: down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight Tensor[[320, 320], Float]
./data/unet.ot: down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight Tensor[[320, 320], Float]
./data/unet.ot: down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight Tensor[[320, 320], Float]
./data/unet.ot: down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias Tensor[[320], Float]
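
From these side-by-side listings, a renaming table for the first few unet parameters would look something like the sketch below. This is only a handful of entries matched by position and shape; a real converter would need an entry or rule for every parameter, and I haven't verified these beyond the shapes lining up.

# Partial rename map (CompVis/ldm checkpoint name -> name used by the local unet.ot),
# inferred from the listings above. Sketch only; not a complete mapping.
UNET_RENAMES = {
    "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight",
    "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias",
    "model.diffusion_model.time_embed.0.weight": "time_embedding.linear_1.weight",
    "model.diffusion_model.time_embed.0.bias": "time_embedding.linear_1.bias",
    "model.diffusion_model.time_embed.2.weight": "time_embedding.linear_2.weight",
    "model.diffusion_model.time_embed.2.bias": "time_embedding.linear_2.bias",
    "model.diffusion_model.input_blocks.1.1.norm.weight": "down_blocks.0.attentions.0.norm.weight",
    "model.diffusion_model.input_blocks.1.1.norm.bias": "down_blocks.0.attentions.0.norm.bias",
}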

Now, looking at that, how would we actually load this in? We currently use load, which expects a separate file for each of these, but how would we load all of them at the same time and convert all the parameters over? The loaders use nn::VarStore, but I'm not exactly sure how it would work to load these parameters (load, load_from_stream, or load_partial?).

@LaurentMazare (Owner)

I'm not sure I understand the full context here, but my guess is that the easiest approach would be to perform the renaming on the Python side, write separate .npz files for each part, and then use tensor-tools to convert each of these to a .ot file.
There is some preliminary support for loading checkpoints saved from Python at the tch level, but sadly it's not fully working yet; see LaurentMazare/tch-rs#595 for more details.

@rockerBOO (Author)

That sounds good! I will write a Python script to convert it over to the various .npz files and report back. Subscribed to that issue for any developments on that front.
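
As a first rough sketch, something like the following could split the checkpoint by the top-level prefixes identified above and write one .npz per component (the per-parameter renaming inside each component, e.g. time_embed.0 -> time_embedding.linear_1, would still need to be applied before tensor-tools can produce .ot files the loaders accept):

import numpy as np
import torch

# Top-level prefixes from the mapping worked out earlier in this thread.
PREFIXES = {
    "unet": "model.diffusion_model.",
    "vae": "first_stage_model.",
    "clip": "cond_stage_model.transformer.",
}

state_dict = torch.load("./data/analog-diffusion-1.0.ckpt", map_location="cpu")["state_dict"]

for component, prefix in PREFIXES.items():
    # Strip the prefix but otherwise keep the original parameter names for now.
    tensors = {
        k[len(prefix):]: v.cpu().numpy()
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }
    np.savez(f"./data/analog-diffusion-1.0.{component}.npz", **tensors)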


rockerBOO commented Jan 12, 2023

I was working through making a conversion script when I realized that someone might have already written one. So I went to Google and found that diffusers has conversion scripts for the various formats. Those scripts should be able to convert any of the other versions into a diffusers version.

Since we want individual files, I will be trying to get the script to run in one step instead of two (rather than producing the full converted layout and then extracting out the parts).


rockerBOO commented Jan 12, 2023

The conversion scripts take your checkpoint and extract it into the Hugging Face diffusers format, with a number of folders and configuration JSON files.

feature_extractor  model_index.json  safety_checker  scheduler  text_encoder  tokenizer  unet  vae

This contains the extracted unet (unet/diffusion_pytorch_model.bin) and vae (vae/diffusion_pytorch_model.bin), which can then be converted to .npz and from there to .ot.
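
Converting those extracted files is the same trick as the snippet in the first comment, e.g. for the unet (paths assume the default diffusers output layout):

import numpy as np
import torch

# The keys in these per-component files already use the diffusers naming,
# so no renaming should be needed here.
unet = torch.load("./unet/diffusion_pytorch_model.bin", map_location="cpu")
np.savez("./data/unet.npz", **{k: v.cpu().numpy() for k, v in unet.items()})
# then: cargo run --release --example tensor-tools cp ./data/unet.npz ./data/unet.ot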

Downsides to this approach

It creates files we often don't need, like the text_encoder (general Stable Diffusion uses CLIP or OpenCLIP) and the safety_checker (which detects NSFW and other potentially harmful content); these can add another couple of gigabytes of files.

It is a multi-step process: make the Hugging Face version, convert that to .npz, and then convert the .npz to .ot.

For me, the unet needs to be fp16 to fit onto my GPU, so this would require another step of converting the fp32 weights, which I haven't gotten working yet.

Upside

Many models seem to use this conversion when they are put up on Hugging Face, so we can just download those files and use the unet and vae parts. Then it's a matter of converting them to .npz and .ot.


So for me the biggest issue at the moment is converting from fp32 to fp16; the rest seems doable. After I get this working, I will write a little walkthrough for other users.
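
An untested sketch of the fp16 cast, done while exporting to .npz (whether the rust side expects Half or Float here is something I still need to confirm):

import numpy as np
import torch

unet = torch.load("./unet/diffusion_pytorch_model.bin", map_location="cpu")
# Cast floating-point weights to half precision; leave anything else untouched.
half = {
    k: (v.half() if v.is_floating_point() else v).cpu().numpy()
    for k, v in unet.items()
}
np.savez("./data/unet-fp16.npz", **half)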
