[Feat] PixArt-Alpha (huggingface#5642)
* init pixart alpha pipeline

* fix: import

* script

* add: vae to the pipeline

* add: vae_scale_factor

* add: checkpoint_path

* clean conversion script a bit.

* size embeddings.

* fix: size embedding

* update script

* support for interpolation of position embedding.

* support for conditioning.

* ..

* final layer

* align if encode_prompt

* support for caption embedding

* refactor

* start cross attention

* cross_attention_dim

* cross

* support for resolution and aspect_ratio

* support for caption projection

* refactor patch embeddings

* batch_size

* up

* commit

* squeeze

* fix final block.

* clean

* fix: interpolation scale.

* debugging

* make --checkpoint_path non-required.

* debugging

* remove num_tokens

* timesteps -> timestep

* debug

* update conversion script.

* debug

* clean

* debug

* fix

* clean

* fix

* boom

* some changes

* boom

* save

* up

* remove i

* fix more tests

* DPMSolverMultistepScheduler

* fix

* offloading

* fix conversion script

* remove print

* remove support for negative prompt embeds.

* typo.

* remove extra kwargs

* bring conversion script to where it was

* fix

* trying my luck

* trying my luck again

* again

* clean up

* up

* update example

* support for 512

* remove spacing

* finalize docs.

* test debug

* fix: assertion values.

* debug

* fix: repeat

* remove prints.

* Apply suggestions from code review

* Correct more

* Apply suggestions from code review

* Change all

* Clean more

* fix more

* Fix more

* Correct more

* address patrick's comments.

* remove unneeded args

* clean up pipeline.

* style

* make the use of additional conditions better conditioned.

* None better

* dtype

* height and width validation

* add a note about size brackets.

* fix

* spit out slow test outputs.

* fix?

* fix optional test

* fix more

* remove unneeded comment

* debug

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
sayakpaul and patrickvonplaten authored Nov 6, 2023
1 parent f7578d9 commit 801ef05
Showing 10 changed files with 1,003 additions and 30 deletions.
2 changes: 2 additions & 0 deletions __init__.py
@@ -235,6 +235,7 @@
"LDMTextToImagePipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
"PixArtAlphaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -579,6 +580,7 @@
LDMTextToImagePipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
PixArtAlphaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
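
With PixArtAlphaPipeline exported from the top-level __init__.py, the new pipeline can be loaded like any other diffusers pipeline. The snippet below is a minimal, illustrative usage sketch rather than code from this commit; the PixArt-alpha/PixArt-XL-2-1024-MS checkpoint id, the prompt, and the output path are assumptions.

import torch
from diffusers import PixArtAlphaPipeline

# Load the PixArt-Alpha weights (transformer, VAE, and T5 text encoder);
# the checkpoint id below is an assumed example.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
)
pipe.to("cuda")

# Generate and save one image from a text prompt.
image = pipe(prompt="A small cactus with a happy face in the Sahara desert").images[0]
image.save("pixart_sample.png")
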
64 changes: 53 additions & 11 deletions models/attention.py
@@ -117,7 +117,8 @@ def __init__(
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
@@ -128,6 +129,8 @@

self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"

if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
@@ -152,7 +155,8 @@
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
@@ -171,7 +175,7 @@
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
self.attn2 = Attention(
query_dim=dim,
Expand All @@ -187,13 +191,19 @@ def __init__(
self.attn2 = None

# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
@@ -215,14 +225,25 @@ def forward(
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]

if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")

if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
@@ -242,19 +263,31 @@
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output

hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)

# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 2.5 ends

# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
if self.pos_embed is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")

if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
norm_hidden_states = self.pos_embed(norm_hidden_states)

attn_output = self.attn2(
@@ -266,11 +299,16 @@
hidden_states = attn_output + hidden_states

# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)

if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
@@ -291,8 +329,12 @@ def forward(

if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output

hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)

return hidden_states

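The ada_norm_single branch above replaces per-block adaLN projections with a single block-shared scale_shift_table, offset by a per-sample projection of the timestep embedding and split into six modulation vectors (shift/scale/gate for self-attention and for the feed-forward). The standalone sketch below mirrors that modulation step with purely illustrative shapes; it is not code from this commit.

import torch

batch_size, dim, num_tokens = 2, 8, 16                     # illustrative sizes only
scale_shift_table = torch.randn(6, dim) / dim**0.5         # learned, shared across tokens
timestep = torch.randn(batch_size, 6 * dim)                # per-sample adaLN-single projection

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)                                          # each chunk: (batch_size, 1, dim)

hidden_states = torch.randn(batch_size, num_tokens, dim)
norm_hidden_states = torch.nn.functional.layer_norm(hidden_states, (dim,))
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

attn_output = norm_hidden_states                           # stand-in for the self-attention output
hidden_states = gate_msa * attn_output + hidden_states     # gated residual, as in the block above

The same pattern is applied around the feed-forward with shift_mlp, scale_mlp, and gate_mlp.
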
119 changes: 113 additions & 6 deletions models/embeddings.py
@@ -66,17 +66,22 @@ def get_timestep_embedding(
return emb


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)

grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)

grid = grid.reshape([2, 1, grid_size, grid_size])
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
@@ -129,6 +134,7 @@ def __init__(
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
):
super().__init__()

@@ -144,16 +150,41 @@ def __init__(
else:
self.norm = None

pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent + self.pos_embed

# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed

return (latent + pos_embed).to(latent.dtype)


class TimestepEmbedding(nn.Module):
@@ -683,3 +714,79 @@ def forward(
objs = torch.cat([objs_text, objs_image], dim=1)

return objs


class CombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""

def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()

self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

self.use_additional_conditions = use_additional_conditions
if use_additional_conditions:
self.use_additional_conditions = True
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
if size.ndim == 1:
size = size[:, None]

if size.shape[0] != batch_size:
size = size.repeat(batch_size // size.shape[0], 1)
if size.shape[0] != batch_size:
raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")

current_batch_size, dims = size.shape[0], size.shape[1]
size = size.reshape(-1)
size_freq = self.additional_condition_proj(size).to(size.dtype)

size_emb = embedder(size_freq)
size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
return size_emb

def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)

if self.use_additional_conditions:
resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
aspect_ratio = self.apply_condition(
aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
)
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
else:
conditioning = timesteps_emb

return conditioning


class CaptionProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""

def __init__(self, in_features, hidden_size, num_tokens=120):
super().__init__()
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))

def forward(self, caption, force_drop_ids=None):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
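
PatchEmbed above regenerates the sinusoidal table whenever the incoming latent resolution differs from the one used at construction time; dividing the grid coordinates by (grid_size / base_size) and by interpolation_scale keeps the embedding frequencies aligned with the training resolution. The sketch below illustrates this with assumed grid sizes and assumes get_2d_sincos_pos_embed is importable from diffusers.models.embeddings; it is not code from this commit.

from diffusers.models.embeddings import get_2d_sincos_pos_embed

embed_dim = 72                   # must be divisible by 4 for the 2D sin-cos split
base_size = 64                   # training latent size // patch_size (assumed value)

# Table built at construction time for the training resolution.
pos_embed_train = get_2d_sincos_pos_embed(
    embed_dim, (base_size, base_size), base_size=base_size, interpolation_scale=1.0
)
print(pos_embed_train.shape)     # (64 * 64, 72)

# At a different latent resolution the grid is rebuilt with the same base_size,
# so coordinates are rescaled instead of the table being cropped or padded.
pos_embed_other = get_2d_sincos_pos_embed(
    embed_dim, (48, 48), base_size=base_size, interpolation_scale=1.0
)
print(pos_embed_other.shape)     # (48 * 48, 72)
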
37 changes: 35 additions & 2 deletions models/normalization.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings


class AdaLayerNorm(nn.Module):
@@ -77,6 +77,39 @@ def forward(
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""

def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
super().__init__()

self.emb = CombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)

self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
batch_size: int = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep


class AdaGroupNorm(nn.Module):
r"""
GroupNorm layer modified to incorporate timestep embeddings.
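
AdaLayerNormSingle above feeds the timestep (plus optional resolution and aspect-ratio) embeddings through one SiLU + Linear so that a single projection yields all six modulation vectors, which each transformer block then specializes with its own scale_shift_table. Below is a quick shape check, assuming the class is importable from diffusers.models.normalization; embedding_dim=1152 and the 1024x1024 resolution are illustrative values, not taken from this commit.

import torch
from diffusers.models.normalization import AdaLayerNormSingle

embedding_dim, batch_size = 1152, 2   # embedding_dim must be divisible by 3 when size conditioning is used
norm_single = AdaLayerNormSingle(embedding_dim, use_additional_conditions=True)

timestep = torch.randint(0, 1000, (batch_size,))
added_cond_kwargs = {
    "resolution": torch.tensor([[1024.0, 1024.0]] * batch_size),  # (height, width) per sample
    "aspect_ratio": torch.tensor([[1.0]] * batch_size),
}

conditioning, embedded_timestep = norm_single(
    timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=torch.float32
)
print(conditioning.shape)        # (2, 6 * 1152): split downstream into shift/scale/gate pairs
print(embedded_timestep.shape)   # (2, 1152)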
