[Feat] PixArt-Alpha #5642

Merged: 258 commits (Nov 6, 2023)
Commits
5866115
init pixart alpha pipeline
sayakpaul Oct 30, 2023
1e347fb
fix: import
sayakpaul Oct 30, 2023
38dfa67
script
sayakpaul Oct 30, 2023
62a2938
script
sayakpaul Oct 30, 2023
9be7087
script
sayakpaul Oct 30, 2023
0d8bcfc
add: vae to the pipeline
sayakpaul Oct 30, 2023
f86c90d
add: vae_scale_factor
sayakpaul Oct 30, 2023
2aef4b5
add: checkpoint_path
sayakpaul Oct 30, 2023
a7eeb76
clean conversion script a bit.
sayakpaul Oct 30, 2023
bc5ada3
size embeddings.
sayakpaul Oct 30, 2023
fb769b8
fix: size embedding
sayakpaul Oct 30, 2023
2f84eea
update scrip
sayakpaul Oct 30, 2023
90f5ace
support for interpolation of position embedding.
sayakpaul Oct 30, 2023
2c73290
support for conditioning.
sayakpaul Oct 30, 2023
a710107
..
sayakpaul Oct 30, 2023
8bfbb0a
..
sayakpaul Oct 30, 2023
9d92a10
..
sayakpaul Oct 30, 2023
5075573
final layer
sayakpaul Oct 30, 2023
0a32102
final layer
sayakpaul Oct 30, 2023
fea8df7
align if encode_prompt
sayakpaul Oct 30, 2023
c8d5bfa
support for caption embedding
sayakpaul Oct 30, 2023
ac54774
refactor
sayakpaul Oct 30, 2023
f304557
refactor
sayakpaul Oct 30, 2023
04e5342
refactor
sayakpaul Oct 30, 2023
cdec38b
start cross attention
sayakpaul Oct 30, 2023
ddaf8ce
start cross attention
sayakpaul Oct 30, 2023
afc43c7
cross_attention_dim
sayakpaul Oct 30, 2023
44bdcc1
cross
sayakpaul Oct 30, 2023
300911a
cross
sayakpaul Oct 30, 2023
f9f893c
support for resolution and aspect_ratio
sayakpaul Oct 31, 2023
546f4a0
support for caption projection
sayakpaul Oct 31, 2023
7a3ff2c
refactor patch embeddings
sayakpaul Oct 31, 2023
bf8ef00
batch_size
sayakpaul Oct 31, 2023
cca8355
up
sayakpaul Oct 31, 2023
1c3dd76
commit
sayakpaul Oct 31, 2023
2b821cd
commit
sayakpaul Oct 31, 2023
25930bb
commit.
sayakpaul Oct 31, 2023
944d44f
squeeze
sayakpaul Oct 31, 2023
2861c88
squeeze
sayakpaul Oct 31, 2023
9170d9a
squeeze
sayakpaul Oct 31, 2023
b4f3510
squeeze
sayakpaul Oct 31, 2023
b555a77
squeeze
sayakpaul Oct 31, 2023
79ef225
squeeze
sayakpaul Oct 31, 2023
14cf017
squeeze
sayakpaul Oct 31, 2023
f4e6bb6
squeeze
sayakpaul Oct 31, 2023
f3a3186
squeeze
sayakpaul Oct 31, 2023
3d113d7
squeeze
sayakpaul Oct 31, 2023
0c9b661
squeeze
sayakpaul Oct 31, 2023
fdd156a
squeeze.
sayakpaul Oct 31, 2023
5c66e16
squeeze.
sayakpaul Oct 31, 2023
11cdb4d
fix final block./
sayakpaul Oct 31, 2023
03cb83d
fix final block./
sayakpaul Oct 31, 2023
a7f9687
fix final block./
sayakpaul Oct 31, 2023
696aa61
clean
sayakpaul Oct 31, 2023
b698975
fix: interpolation scale.
sayakpaul Oct 31, 2023
5af0fb7
debugging'
sayakpaul Oct 31, 2023
18f7105
debugging'
sayakpaul Oct 31, 2023
791efb3
debugging'
sayakpaul Oct 31, 2023
facd99a
debugging'
sayakpaul Oct 31, 2023
fb1204e
debugging'
sayakpaul Oct 31, 2023
75f73a8
debugging'
sayakpaul Oct 31, 2023
ec635e8
debugging'
sayakpaul Oct 31, 2023
148d322
debugging'
sayakpaul Oct 31, 2023
fa9344d
debugging'
sayakpaul Oct 31, 2023
163167a
debugging'
sayakpaul Oct 31, 2023
030b69e
debugging'
sayakpaul Oct 31, 2023
8072f1a
debugging'
sayakpaul Oct 31, 2023
8930d44
debugging'
sayakpaul Oct 31, 2023
2c1deeb
debugging'
sayakpaul Oct 31, 2023
a9e179b
debugging'
sayakpaul Oct 31, 2023
84beae1
debugging'
sayakpaul Oct 31, 2023
9daa187
debugging'
sayakpaul Oct 31, 2023
f50b378
debugging'
sayakpaul Oct 31, 2023
67c731e
debugging'
sayakpaul Oct 31, 2023
822d522
debugging'
sayakpaul Oct 31, 2023
a21cb9d
debugging'
sayakpaul Oct 31, 2023
476e56f
debugging'
sayakpaul Oct 31, 2023
ada2141
debugging'
sayakpaul Oct 31, 2023
d9c7c28
debugging'
sayakpaul Oct 31, 2023
f3eacac
debugging'
sayakpaul Oct 31, 2023
5f01bd2
debugging'
sayakpaul Oct 31, 2023
1b2486f
debugging'
sayakpaul Oct 31, 2023
3b6fcd5
debugging'
sayakpaul Oct 31, 2023
3715aef
debugging'
sayakpaul Oct 31, 2023
f42672d
debugging'
sayakpaul Oct 31, 2023
51ddf4f
debugging'
sayakpaul Oct 31, 2023
2948123
debugging'
sayakpaul Oct 31, 2023
6d01157
debugging'
sayakpaul Oct 31, 2023
a257ca4
debugging'
sayakpaul Oct 31, 2023
4a5868f
debugging'
sayakpaul Oct 31, 2023
3661307
debugging'
sayakpaul Oct 31, 2023
7f0e425
debugging'
sayakpaul Oct 31, 2023
0bff9f6
debugging'
sayakpaul Oct 31, 2023
3d7d1a5
debugging'
sayakpaul Oct 31, 2023
ae4cf95
debugging'
sayakpaul Oct 31, 2023
6672755
debugging'
sayakpaul Oct 31, 2023
90e9b2f
debugging'
sayakpaul Oct 31, 2023
85eaa31
debugging
sayakpaul Oct 31, 2023
eca0c66
debugging
sayakpaul Oct 31, 2023
8f3fbfc
debugging
sayakpaul Oct 31, 2023
0033278
debugging
sayakpaul Oct 31, 2023
56b8770
debugging
sayakpaul Oct 31, 2023
57eca10
debugging
sayakpaul Oct 31, 2023
8a11fba
debugging
sayakpaul Nov 1, 2023
ca352f6
make --checkpoint_path non-required.
sayakpaul Nov 1, 2023
d426281
debugging
sayakpaul Nov 1, 2023
64b9539
debugging
sayakpaul Nov 1, 2023
9e791c1
debugging
sayakpaul Nov 1, 2023
dc8094a
debugging
sayakpaul Nov 1, 2023
8afdc91
debugging
sayakpaul Nov 1, 2023
2f35fc0
debugging
sayakpaul Nov 1, 2023
7290c9c
debugging
sayakpaul Nov 1, 2023
0a9425e
debugging
sayakpaul Nov 1, 2023
5312c0b
debugging
sayakpaul Nov 1, 2023
6cd0815
debugging
sayakpaul Nov 1, 2023
3f9653a
debugging
sayakpaul Nov 1, 2023
b9b6b40
debugging
sayakpaul Nov 1, 2023
9bd2255
debugging
sayakpaul Nov 1, 2023
1841963
debugging
sayakpaul Nov 1, 2023
196741a
debugging
sayakpaul Nov 1, 2023
675c19c
debugging
sayakpaul Nov 1, 2023
59d3d52
debugging
sayakpaul Nov 1, 2023
eff5e35
debugging
sayakpaul Nov 1, 2023
e7682d0
debugging
sayakpaul Nov 1, 2023
b0c7a7b
debugging
sayakpaul Nov 1, 2023
52f7da3
debugging
sayakpaul Nov 1, 2023
3d36d71
debugging
sayakpaul Nov 1, 2023
4a13640
debugging
sayakpaul Nov 1, 2023
c6055b3
debugging
sayakpaul Nov 1, 2023
9f39a53
debugging
sayakpaul Nov 1, 2023
0278462
debugging
sayakpaul Nov 1, 2023
9a2ac25
debugging
sayakpaul Nov 1, 2023
326d9a1
debugging
sayakpaul Nov 1, 2023
1602931
debugging
sayakpaul Nov 1, 2023
995fccf
debugging
sayakpaul Nov 1, 2023
7078d20
debugging
sayakpaul Nov 1, 2023
bdf2125
remove num_tokens
sayakpaul Nov 3, 2023
dbabf75
timesteps -> timestep
sayakpaul Nov 3, 2023
e860663
timesteps -> timestep
sayakpaul Nov 3, 2023
5bcbce8
timesteps -> timestep
sayakpaul Nov 3, 2023
e62cc85
timesteps -> timestep
sayakpaul Nov 3, 2023
c424643
timesteps -> timestep
sayakpaul Nov 3, 2023
10ee86c
timesteps -> timestep
sayakpaul Nov 3, 2023
2d25bd1
debug
sayakpaul Nov 3, 2023
051bb40
debug
sayakpaul Nov 3, 2023
bb10ad6
update conversion script.
sayakpaul Nov 3, 2023
58c8bf6
update conversion script.
sayakpaul Nov 3, 2023
1b1353b
update conversion script.
sayakpaul Nov 3, 2023
5c0d38c
debug
sayakpaul Nov 3, 2023
fe66654
debug
sayakpaul Nov 3, 2023
565ef63
debug
sayakpaul Nov 3, 2023
ebe7b10
clean
sayakpaul Nov 3, 2023
b42deb3
debug
sayakpaul Nov 3, 2023
fe35226
debug
sayakpaul Nov 3, 2023
ebea989
debug
sayakpaul Nov 3, 2023
3b767f6
debug
sayakpaul Nov 3, 2023
b44be3b
debug
sayakpaul Nov 3, 2023
424f435
debug
sayakpaul Nov 3, 2023
e116191
debug
sayakpaul Nov 3, 2023
0bc55c0
debug
sayakpaul Nov 3, 2023
0fbbf7e
deug
sayakpaul Nov 3, 2023
c281bf2
debug
sayakpaul Nov 3, 2023
5e42815
debug
sayakpaul Nov 3, 2023
d46c6e5
debug
sayakpaul Nov 3, 2023
0bccadd
fix
sayakpaul Nov 3, 2023
5a0aa54
fix
sayakpaul Nov 3, 2023
932ca92
fix
sayakpaul Nov 3, 2023
4550c42
fix
sayakpaul Nov 3, 2023
9305f4e
fix
sayakpaul Nov 3, 2023
2c8588b
fix
sayakpaul Nov 3, 2023
e879929
fix
sayakpaul Nov 3, 2023
ebbe35d
fix
sayakpaul Nov 3, 2023
7b32589
fix
sayakpaul Nov 3, 2023
84e7920
fix
sayakpaul Nov 3, 2023
b4ecd5f
fix
sayakpaul Nov 3, 2023
7439673
fix
sayakpaul Nov 3, 2023
f2b682c
fix
sayakpaul Nov 3, 2023
ad2825e
clean
sayakpaul Nov 3, 2023
c2ec596
fix
sayakpaul Nov 3, 2023
fc98faa
fix
sayakpaul Nov 3, 2023
b848369
boom
sayakpaul Nov 3, 2023
5cde4d2
boom
sayakpaul Nov 3, 2023
64b3a9b
some changes
patrickvonplaten Nov 3, 2023
047f239
Merge branch 'feat/pixart-alpha' of https://github.com/huggingface/di…
patrickvonplaten Nov 3, 2023
5a421e1
boom
sayakpaul Nov 3, 2023
c42585f
Merge branch 'feat/pixart-alpha' of https://github.com/huggingface/di…
patrickvonplaten Nov 3, 2023
02dae17
save
sayakpaul Nov 3, 2023
b09c7a3
Merge branch 'feat/pixart-alpha' of https://github.com/huggingface/di…
patrickvonplaten Nov 3, 2023
9152319
up
patrickvonplaten Nov 3, 2023
35483d2
remove i
patrickvonplaten Nov 3, 2023
00f2aad
fix more tests
patrickvonplaten Nov 3, 2023
c8a8171
DPMSolverMultistepScheduler
sayakpaul Nov 4, 2023
f8bcb26
fix
sayakpaul Nov 4, 2023
4944b90
offloading
sayakpaul Nov 4, 2023
746d503
fix conversion script
sayakpaul Nov 4, 2023
4790c68
fix conversion script
sayakpaul Nov 4, 2023
4f15269
remove print
sayakpaul Nov 4, 2023
7dfed94
resolve conflicts and nothing else
sayakpaul Nov 4, 2023
afc8931
remove support for negative prompt embeds.
sayakpaul Nov 4, 2023
8085203
typo.
sayakpaul Nov 4, 2023
9b80d46
remove extra kwargs
sayakpaul Nov 4, 2023
1aa7456
bring conversion script to where it was
sayakpaul Nov 4, 2023
66a6829
fix
sayakpaul Nov 4, 2023
193b43e
trying mu luck
sayakpaul Nov 4, 2023
6733afa
trying my luck again
sayakpaul Nov 4, 2023
e0bfbf8
again
sayakpaul Nov 4, 2023
4b4df35
again
sayakpaul Nov 4, 2023
976cc40
again
sayakpaul Nov 4, 2023
2b2a7a9
clean up
sayakpaul Nov 4, 2023
ae13a1b
up
sayakpaul Nov 4, 2023
2662f46
up
sayakpaul Nov 4, 2023
c8d46fc
update example
sayakpaul Nov 5, 2023
2d020c3
support for 512
sayakpaul Nov 5, 2023
4c0e3a2
Merge branch 'main' into feat/pixart-alpha
sayakpaul Nov 5, 2023
5333c41
remove spacing
sayakpaul Nov 5, 2023
3a41ace
finalize docs.
sayakpaul Nov 5, 2023
4ff114e
test debug
sayakpaul Nov 5, 2023
091b6fa
fix: assertion values.
sayakpaul Nov 5, 2023
a46423e
debug
sayakpaul Nov 5, 2023
cb86d5d
debug
sayakpaul Nov 5, 2023
add79b7
debug
sayakpaul Nov 5, 2023
693b8de
fix: repeat
sayakpaul Nov 5, 2023
cc3cdf8
remove prints.
sayakpaul Nov 5, 2023
067caee
Apply suggestions from code review
patrickvonplaten Nov 5, 2023
74c2d89
Apply suggestions from code review
patrickvonplaten Nov 5, 2023
662fef1
Correct more
patrickvonplaten Nov 5, 2023
6d70777
Apply suggestions from code review
patrickvonplaten Nov 5, 2023
23fca6e
Change all
patrickvonplaten Nov 5, 2023
5ce3467
Clean more
patrickvonplaten Nov 6, 2023
decfa3d
fix more
patrickvonplaten Nov 6, 2023
50f4d5d
Fix more
patrickvonplaten Nov 6, 2023
38d3b8f
Fix more
patrickvonplaten Nov 6, 2023
201cb4b
Correct more
patrickvonplaten Nov 6, 2023
b059505
address patrick's comments.
sayakpaul Nov 6, 2023
892a323
remove unneeded args
sayakpaul Nov 6, 2023
00403c4
clean up pipeline.
sayakpaul Nov 6, 2023
89df5e4
sty;e
sayakpaul Nov 6, 2023
6916263
make the use of additional conditions better conditioned.
sayakpaul Nov 6, 2023
a6a7b7d
None better
sayakpaul Nov 6, 2023
65f9a0e
dtype
sayakpaul Nov 6, 2023
053895a
height and width validation
sayakpaul Nov 6, 2023
9c326d6
add a note about size brackets.
sayakpaul Nov 6, 2023
5dbed8b
fix
sayakpaul Nov 6, 2023
38a1e47
spit out slow test outputs.
sayakpaul Nov 6, 2023
c23dd15
fix?
sayakpaul Nov 6, 2023
e56aa69
fix optional test
sayakpaul Nov 6, 2023
d05df42
fix more
sayakpaul Nov 6, 2023
4485253
Merge branch 'main' into feat/pixart-alpha
sayakpaul Nov 6, 2023
4c7cc1b
remove unneeded comment
sayakpaul Nov 6, 2023
40ae864
debug
sayakpaul Nov 6, 2023
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -268,6 +268,8 @@
title: Parallel Sampling of Diffusion Models
- local: api/pipelines/pix2pix_zero
title: Pix2Pix Zero
- local: api/pipelines/pixart
title: PixArt
- local: api/pipelines/pndm
title: PNDM
- local: api/pipelines/repaint
36 changes: 36 additions & 0 deletions docs/source/en/api/pipelines/pixart.md
@@ -0,0 +1,36 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# PixArt

![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/header_collage.png)

[PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis](https://huggingface.co/papers/2310.00426) is by Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.

The abstract from the paper is:

*The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost, as shown in Figure 1 and 2. To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PIXART-α's training speed markedly surpasses existing large-scale T2I models, e.g., PIXART-α only takes 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing 90% CO2 emissions. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PIXART-α excels in image quality, artistry, and semantic control. We hope PIXART-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.*

You can find the original codebase at [PixArt-alpha/PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) and all the available checkpoints at [PixArt-alpha](https://huggingface.co/PixArt-alpha).

Some notes about this pipeline (a minimal usage sketch follows these notes):

* It uses a Transformer backbone (instead of a UNet) for denoising. As such, it has an architecture similar to [DiT](./dit.md).
* It was trained using text conditions computed from T5, which makes the pipeline better at following complex text prompts with intricate details.
* It is good at producing high-resolution images at different aspect ratios. To get the best results, the authors recommend certain size brackets, which can be found [here](https://github.com/PixArt-alpha/PixArt-alpha/blob/08fbbd281ec96866109bdd2cdb75f2f58fb17610/diffusion/data/datasets/utils.py).
* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as Stable Diffusion XL, Imagen, and DALL-E 2, while being more efficient.
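
A minimal usage sketch (the checkpoint id below is an assumption; pick any PixArt-Alpha checkpoint from the [PixArt-alpha](https://huggingface.co/PixArt-alpha) organization):

```python
import torch

from diffusers import PixArtAlphaPipeline

# Assumed checkpoint id; substitute the 512px or 1024px variant you want to use.
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt=prompt).images[0]
image.save("cactus.png")
```

Since prompts are encoded with T5, long and detailed prompts tend to work well.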

## PixArtAlphaPipeline

[[autodoc]] PixArtAlphaPipeline
- all
- __call__
198 changes: 198 additions & 0 deletions scripts/convert_pixart_alpha_to_diffusers.py
@@ -0,0 +1,198 @@
import argparse
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, Transformer2DModel


ckpt_id = "PixArt-alpha/PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
interpolation_scale = {512: 1, 1024: 2}


def main(args):
    all_state_dict = torch.load(args.orig_ckpt_path)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    if args.image_size == 1024:
        # Resolution.
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
            "ar_embedder.mlp.2.bias"
        )
    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

    for depth in range(28):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )

        # Attention is all you need 🤘

        # Self attention.
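        # The original checkpoint stores a fused qkv projection per block; split it into the
        # separate to_q/to_k/to_v projections that diffusers attention modules expect.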
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention.
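        # Cross-attention keeps a standalone q projection, while k/v are fused in `kv_linear` and split below.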
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # DiT XL/2
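    # 16 heads x 72 dims per head = 1152 hidden size; sample_size is the latent resolution (image size / 8 from the VAE).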
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        cross_attention_dim=1152,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_single",
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        caption_channels=4096,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    assert transformer.pos_embed.pos_embed is not None
    state_dict.pop("pos_embed")
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        scheduler = DPMSolverMultistepScheduler()

        vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="sd-vae-ft-ema")

        tokenizer = T5Tokenizer.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")
        text_encoder = T5EncoderModel.from_pretrained(ckpt_id, subfolder="t5-v1_1-xxl")

        pipeline = PixArtAlphaPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )

        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[512, 1024],
        required=False,
        help="Image size of pretrained model, either 512 or 1024.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument("--only_transformer", default=True, type=bool, required=True)

    args = parser.parse_args()
    main(args)
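
For orientation, a sketch of how the script's output might be reloaded afterwards (the dump path is a placeholder, not something specified in this PR):

```python
import os

from diffusers import PixArtAlphaPipeline, Transformer2DModel

dump_path = "./pixart-alpha-converted"  # placeholder for whatever was passed as --dump_path

# With --only_transformer (the default), only the denoiser is written out:
transformer = Transformer2DModel.from_pretrained(os.path.join(dump_path, "transformer"))

# If the full pipeline was saved instead, it can be reloaded directly:
# pipe = PixArtAlphaPipeline.from_pretrained(dump_path)
```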
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -235,6 +235,7 @@
"LDMTextToImagePipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
"PixArtAlphaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -579,6 +580,7 @@
LDMTextToImagePipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
PixArtAlphaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,