From cb612f40979da1d5fa8461d26cc817ad6c42559c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 24 Oct 2023 08:49:43 +0530 Subject: [PATCH] [Core] Refactor activation and normalization layers (#5493) * move out the activations. * move normalization layers. * add doc. * add doc. * fix: paths * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * style --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 4 + docs/source/en/api/activations.md | 15 ++ docs/source/en/api/normalization.md | 15 ++ src/diffusers/models/activations.py | 93 ++++++++++ src/diffusers/models/attention.py | 172 +----------------- src/diffusers/models/normalization.py | 115 ++++++++++++ src/diffusers/models/resnet.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 2 +- .../pipelines/unidiffuser/modeling_uvit.py | 3 +- 9 files changed, 249 insertions(+), 172 deletions(-) create mode 100644 docs/source/en/api/activations.md create mode 100644 docs/source/en/api/normalization.md create mode 100644 src/diffusers/models/normalization.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 88da548bd597a..718feeaa11716 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -162,6 +162,10 @@ title: Conceptual Guides - sections: - sections: + - local: api/activations + title: Custom activation functions + - local: api/normalization + title: Custom normalization layers - local: api/attnprocessor title: Attention Processor - local: api/diffusion_pipeline diff --git a/docs/source/en/api/activations.md b/docs/source/en/api/activations.md new file mode 100644 index 0000000000000..684238420ce1a --- /dev/null +++ b/docs/source/en/api/activations.md @@ -0,0 +1,15 @@ +# Activation functions + +Customized activation functions for supporting various models in 🤗 Diffusers. + +## GELU + +[[autodoc]] models.activations.GELU + +## GEGLU + +[[autodoc]] models.activations.GEGLU + +## ApproximateGELU + +[[autodoc]] models.activations.ApproximateGELU \ No newline at end of file diff --git a/docs/source/en/api/normalization.md b/docs/source/en/api/normalization.md new file mode 100644 index 0000000000000..7e09976b15657 --- /dev/null +++ b/docs/source/en/api/normalization.md @@ -0,0 +1,15 @@ +# Normalization layers + +Customized normalization layers for supporting various models in 🤗 Diffusers. + +## AdaLayerNorm + +[[autodoc]] models.normalization.AdaLayerNorm + +## AdaLayerNormZero + +[[autodoc]] models.normalization.AdaLayerNormZero + +## AdaGroupNorm + +[[autodoc]] models.normalization.AdaGroupNorm \ No newline at end of file diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 46da899096c2d..e66d90040fd2a 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -1,5 +1,25 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F from torch import nn +from ..utils import USE_PEFT_BACKEND +from .lora import LoRACompatibleLinear + def get_activation(act_fn: str) -> nn.Module: """Helper function to get activation function from string. @@ -20,3 +40,76 @@ def get_activation(act_fn: str) -> nn.Module: return nn.ReLU() else: raise ValueError(f"Unsupported activation function: {act_fn}") + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + self.proj = linear_cls(dim_in, dim_out * 2) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, scale: float = 1.0): + args = () if USE_PEFT_BACKEND else (scale,) + hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 47608005d3745..80e2afa94a876 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,18 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import torch -import torch.nn.functional as F from torch import nn from ..utils import USE_PEFT_BACKEND from ..utils.torch_utils import maybe_allow_in_graph -from .activations import get_activation +from .activations import GEGLU, GELU, ApproximateGELU from .attention_processor import Attention -from .embeddings import CombinedTimestepLabelEmbeddings from .lora import LoRACompatibleLinear +from .normalization import AdaLayerNorm, AdaLayerNormZero @maybe_allow_in_graph @@ -331,168 +330,3 @@ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tens else: hidden_states = module(hidden_states) return hidden_states - - -class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. - """ - - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - self.approximate = approximate - - def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states - - -class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear - - self.proj = linear_cls(dim_in, dim_out * 2) - - def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - - def forward(self, hidden_states, scale: float = 1.0): - args = () if USE_PEFT_BACKEND else (scale,) - hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class ApproximateGELU(nn.Module): - r""" - The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2: - https://arxiv.org/abs/1606.08415. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.proj(x) - return x * torch.sigmoid(1.702 * x) - - -class AdaLayerNorm(nn.Module): - r""" - Norm layer modified to incorporate timestep embeddings. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the dictionary of embeddings. - """ - - def __init__(self, embedding_dim: int, num_embeddings: int): - super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) - - def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = torch.chunk(emb, 2) - x = self.norm(x) * (1 + scale) + shift - return x - - -class AdaLayerNormZero(nn.Module): - r""" - Norm layer adaptive layer norm zero (adaLN-Zero). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the dictionary of embeddings. - """ - - def __init__(self, embedding_dim: int, num_embeddings: int): - super().__init__() - - self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) - - def forward( - self, - x: torch.Tensor, - timestep: torch.Tensor, - class_labels: torch.LongTensor, - hidden_dtype: Optional[torch.dtype] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp - - -class AdaGroupNorm(nn.Module): - r""" - GroupNorm layer modified to incorporate timestep embeddings. - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the dictionary of embeddings. - num_groups (`int`): The number of groups to separate the channels into. - act_fn (`str`, *optional*, defaults to `None`): The activation function to use. - eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. - """ - - def __init__( - self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 - ): - super().__init__() - self.num_groups = num_groups - self.eps = eps - - if act_fn is None: - self.act = None - else: - self.act = get_activation(act_fn) - - self.linear = nn.Linear(embedding_dim, out_dim * 2) - - def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: - if self.act: - emb = self.act(emb) - emb = self.linear(emb) - emb = emb[:, :, None, None] - scale, shift = emb.chunk(2, dim=1) - - x = F.group_norm(x, self.num_groups, eps=self.eps) - x = x * (1 + scale) + shift - return x diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py new file mode 100644 index 0000000000000..dd451b5f3bfc2 --- /dev/null +++ b/src/diffusers/models/normalization.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .activations import get_activation +from .embeddings import CombinedTimestepLabelEmbeddings + + +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, num_embeddings: int): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class AdaLayerNormZero(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, num_embeddings: int): + super().__init__() + + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + + def forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + class_labels: torch.LongTensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaGroupNorm(nn.Module): + r""" + GroupNorm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + num_groups (`int`): The number of groups to separate the channels into. + act_fn (`str`, *optional*, defaults to `None`): The activation function to use. + eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. + """ + + def __init__( + self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) + + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 80bf269fc4e3c..8fe66aacf5db5 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -22,9 +22,9 @@ from ..utils import USE_PEFT_BACKEND from .activations import get_activation -from .attention import AdaGroupNorm from .attention_processor import SpatialNorm from .lora import LoRACompatibleConv, LoRACompatibleLinear +from .normalization import AdaGroupNorm class Upsample1D(nn.Module): diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index ebd45c09ae334..cfaedd717bef2 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -21,9 +21,9 @@ from ..utils import is_torch_version, logging from ..utils.torch_utils import apply_freeu from .activations import get_activation -from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel +from .normalization import AdaGroupNorm from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index b7829f76ec12f..6e97e0279350f 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -6,9 +6,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin -from ...models.attention import AdaLayerNorm, FeedForward +from ...models.attention import FeedForward from ...models.attention_processor import Attention from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed +from ...models.normalization import AdaLayerNorm from ...models.transformer_2d import Transformer2DModelOutput from ...utils import logging