Commit
* init pixart alpha pipeline * fix: import * script * add: vae to the pipeline * add: vae_scale_factor * add: checkpoint_path * clean conversion script a bit. * size embeddings. * fix: size embedding * update script * support for interpolation of position embedding. * support for conditioning. * .. * final layer * align if encode_prompt * support for caption embedding * refactor * start cross attention * cross_attention_dim * cross * support for resolution and aspect_ratio * support for caption projection * refactor patch embeddings * batch_size * up * commit * squeeze * fix final block * clean * fix: interpolation scale. * debugging * make --checkpoint_path non-required. * debugging * remove num_tokens * timesteps -> timestep * debug * update conversion script. * debug * clean * debug * fix * clean * fix * boom * some changes * boom * save * up * remove i * fix more tests * DPMSolverMultistepScheduler * fix * offloading * fix conversion script * remove print * remove support for negative prompt embeds. * typo. * remove extra kwargs * bring conversion script to where it was * fix * trying my luck * trying my luck again * again * clean up * up * update example * support for 512 * remove spacing * finalize docs. * test debug * fix: assertion values. * debug * fix: repeat * remove prints. * Apply suggestions from code review * Correct more * Apply suggestions from code review * Change all * Clean more * fix more * Fix more * Correct more * address patrick's comments. * remove unneeded args * clean up pipeline. * style * make the use of additional conditions better conditioned. * None better * dtype * height and width validation * add a note about size brackets. * fix * spit out slow test outputs. * fix? * fix optional test * fix more * remove unneeded comment * debug

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 2b23ec8 · commit d61889f
Showing 15 changed files with 1,501 additions and 30 deletions.

@@ -0,0 +1,36 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# PixArt

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/header_collage.png)

[PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis](https://huggingface.co/papers/2310.00426) is by Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.

The abstract from the paper is:

*The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost, as shown in Figure 1 and 2. To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PIXART-α's training speed markedly surpasses existing large-scale T2I models, e.g., PIXART-α only takes 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing 90% CO2 emissions. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PIXART-α excels in image quality, artistry, and semantic control. We hope PIXART-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.*

You can find the original codebase at [PixArt-alpha/PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) and all the available checkpoints at [PixArt-alpha](https://huggingface.co/PixArt-alpha).

Some notes about this pipeline:

* It uses a Transformer backbone (instead of a UNet) for denoising. As such, it has a similar architecture to [DiT](./dit.md).
* It was trained using text conditions computed from T5. This makes the pipeline better at following complex text prompts with intricate details.
* It is good at producing high-resolution images at different aspect ratios. To get the best results, the authors recommend some size brackets, which can be found [here](https://github.com/PixArt-alpha/PixArt-alpha/blob/08fbbd281ec96866109bdd2cdb75f2f58fb17610/diffusion/data/datasets/utils.py).
* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as Stable Diffusion XL, Imagen, and DALL-E 2, while being more efficient.

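A short usage sketch is included below for orientation (it is not part of the committed doc file); the checkpoint name is assumed to be one of the repositories under the [PixArt-alpha](https://huggingface.co/PixArt-alpha) organization, and the prompt is illustrative:

```python
import torch
from diffusers import PixArtAlphaPipeline

# Load a released PixArt-α checkpoint (name assumed) in half precision and move it to the GPU.
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("cactus.png")
```
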
## PixArtAlphaPipeline

[[autodoc]] PixArtAlphaPipeline
- all
- __call__
@@ -0,0 +1,198 @@
import argparse
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel


ckpt_id = "PixArt-alpha/PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
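# Mapping from target image size to the positional-embedding interpolation scale used by the
# reference inference script linked above (the base position-embedding grid corresponds to 512px).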
interpolation_scale = {512: 1, 1024: 2}


def main(args):
    all_state_dict = torch.load(args.orig_ckpt_path)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    if args.image_size == 1024:
        # Resolution.
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
            "ar_embedder.mlp.2.bias"
        )
    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

    for depth in range(28):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )

        # Attention is all you need 🤘

        # Self attention.
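        # The original checkpoint fuses the q, k, v projections into a single qkv matrix;
        # chunk it into three pieces to match diffusers' separate to_q / to_k / to_v layers.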
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention.
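        # Here the query projection is stored separately, but key and value are fused into one
        # kv_linear; split it in two to match diffusers' separate to_k / to_v layers.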
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # DiT XL/2
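    # For orientation: this is a DiT-XL/2-style backbone — 28 blocks with 16 heads of dim 72
    # (inner dim 1152) over 2x2 patches of the 4-channel VAE latents; out_channels doubles to 8
    # because the DiT-style head also predicts a variance term, and caption_channels matches
    # the hidden size of the T5-XXL text encoder (4096).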
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        cross_attention_dim=1152,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_single",
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        caption_channels=4096,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    assert transformer.pos_embed.pos_embed is not None
    state_dict.pop("pos_embed")
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        scheduler = DPMSolverMultistepScheduler()

        vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema")

        tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")
        text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")

        pipeline = PixArtAlphaPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )

        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[512, 1024],
        required=False,
        help="Image size of pretrained model, either 512 or 1024.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--only_transformer", default=True, type=bool, required=True) | ||

    args = parser.parse_args()
    main(args)
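
For reference, a minimal sketch (not part of the commit) of how the transformer saved by the default `--only_transformer` path could be assembled back into a full pipeline, mirroring the script's own `else` branch; the local paths below are placeholders:

from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel

ckpt_id = "PixArt-alpha/PixArt-alpha"

# "./pixart-dump/transformer" stands in for the folder written by `transformer.save_pretrained(...)`.
transformer = Transformer2DModel.from_pretrained("./pixart-dump/transformer")
pipe = PixArtAlphaPipeline(
    tokenizer=T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl"),
    text_encoder=T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl"),
    vae=AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema"),
    transformer=transformer,
    scheduler=DPMSolverMultistepScheduler(),
)
pipe.save_pretrained("./pixart-full-pipeline")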