[Core] Refactor activation and normalization layers (huggingface#5493)
* move out the activations.

* move normalization layers.

* add doc.

* add doc.

* fix: paths

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2 people authored and kashif committed Nov 11, 2023
1 parent a603cd7 commit cb612f4
Showing 9 changed files with 249 additions and 172 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
@@ -162,6 +162,10 @@
title: Conceptual Guides
- sections:
- sections:
- local: api/activations
title: Custom activation functions
- local: api/normalization
title: Custom normalization layers
- local: api/attnprocessor
title: Attention Processor
- local: api/diffusion_pipeline
15 changes: 15 additions & 0 deletions docs/source/en/api/activations.md
@@ -0,0 +1,15 @@
# Activation functions

Customized activation functions for supporting various models in 🤗 Diffusers.

## GELU

[[autodoc]] models.activations.GELU

## GEGLU

[[autodoc]] models.activations.GEGLU

## ApproximateGELU

[[autodoc]] models.activations.ApproximateGELU
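
The new page documents the classes exactly as relocated. Below is a minimal usage sketch, assuming a diffusers build that includes this refactor; the tensor shapes and dimensions are illustrative only.

```python
import torch
from diffusers.models.activations import GELU, GEGLU, ApproximateGELU

x = torch.randn(2, 16, 64)  # (batch, sequence length, dim_in) -- illustrative shape

# Each activation wraps its own projection, hence the dim_in/dim_out constructor arguments.
gelu = GELU(dim_in=64, dim_out=128, approximate="tanh")
geglu = GEGLU(dim_in=64, dim_out=128)                  # projects to 2 * dim_out, then gates
approx_gelu = ApproximateGELU(dim_in=64, dim_out=128)

print(gelu(x).shape, geglu(x).shape, approx_gelu(x).shape)  # each torch.Size([2, 16, 128])
```
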
15 changes: 15 additions & 0 deletions docs/source/en/api/normalization.md
@@ -0,0 +1,15 @@
# Normalization layers

Customized normalization layers for supporting various models in 🤗 Diffusers.

## AdaLayerNorm

[[autodoc]] models.normalization.AdaLayerNorm

## AdaLayerNormZero

[[autodoc]] models.normalization.AdaLayerNormZero

## AdaGroupNorm

[[autodoc]] models.normalization.AdaGroupNorm
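
A comparable sketch for the relocated normalization layers, based on the signatures moved into `models/normalization.py` by this commit; the shapes, timestep value, and dimensions below are illustrative assumptions.

```python
import torch
from diffusers.models.normalization import AdaLayerNorm, AdaGroupNorm

# AdaLayerNorm: LayerNorm whose scale/shift come from a learned timestep embedding.
x = torch.randn(2, 16, 64)            # (batch, sequence length, embedding_dim)
timestep = torch.tensor(10)           # a single discrete timestep index (illustrative)
ada_ln = AdaLayerNorm(embedding_dim=64, num_embeddings=1000)
print(ada_ln(x, timestep).shape)      # torch.Size([2, 16, 64])

# AdaGroupNorm: GroupNorm whose scale/shift come from a conditioning embedding.
feat = torch.randn(2, 32, 8, 8)       # (batch, channels, height, width)
cond = torch.randn(2, 64)             # (batch, embedding_dim)
ada_gn = AdaGroupNorm(embedding_dim=64, out_dim=32, num_groups=8, act_fn="silu")
print(ada_gn(feat, cond).shape)       # torch.Size([2, 32, 8, 8])
```
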
93 changes: 93 additions & 0 deletions src/diffusers/models/activations.py
@@ -1,5 +1,25 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import USE_PEFT_BACKEND
from .lora import LoRACompatibleLinear


def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
@@ -20,3 +40,76 @@ def get_activation(act_fn: str) -> nn.Module:
return nn.ReLU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")


class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
"""

def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate

def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states


class GEGLU(nn.Module):
r"""
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

self.proj = linear_cls(dim_in, dim_out * 2)

def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

def forward(self, hidden_states, scale: float = 1.0):
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
r"""
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
[paper](https://arxiv.org/abs/1606.08415).
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
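
The relocated file also carries the `get_activation` helper referenced in the hunk header above. A quick sketch of how it resolves string names to modules; `"relu"` is the only branch visible in this diff, so the other names mentioned in the comment are assumptions about the elided branches.

```python
import torch
from diffusers.models.activations import get_activation

# "relu" is visible in the diff above; names like "silu" or "gelu" are assumed
# to be handled by the branches collapsed out of this hunk.
act = get_activation("relu")
print(act(torch.tensor([-1.0, 0.0, 2.0])))  # tensor([0., 0., 2.])
```
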
172 changes: 3 additions & 169 deletions src/diffusers/models/attention.py
@@ -11,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import get_activation
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero


@maybe_allow_in_graph
@@ -331,168 +330,3 @@ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
else:
hidden_states = module(hidden_states)
return hidden_states


class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
"""

def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate

def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states


class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

self.proj = linear_cls(dim_in, dim_out * 2)

def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

def forward(self, hidden_states, scale: float = 1.0):
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
r"""
The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
https://arxiv.org/abs/1606.08415.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
"""

def __init__(self, embedding_dim: int, num_embeddings: int):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x


class AdaLayerNormZero(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
"""

def __init__(self, embedding_dim: int, num_embeddings: int):
super().__init__()

self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
class_labels: torch.LongTensor,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaGroupNorm(nn.Module):
r"""
GroupNorm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the dictionary of embeddings.
num_groups (`int`): The number of groups to separate the channels into.
act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
"""

def __init__(
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
):
super().__init__()
self.num_groups = num_groups
self.eps = eps

if act_fn is None:
self.act = None
else:
self.act = get_activation(act_fn)

self.linear = nn.Linear(embedding_dim, out_dim * 2)

def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
if self.act:
emb = self.act(emb)
emb = self.linear(emb)
emb = emb[:, :, None, None]
scale, shift = emb.chunk(2, dim=1)

x = F.group_norm(x, self.num_groups, eps=self.eps)
x = x * (1 + scale) + shift
return x
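
The classes deleted above are relocated, not removed. The refactored `attention.py` still imports several of these names itself (see its new import lines earlier in the diff), so some old paths may keep resolving incidentally, but downstream code should prefer the dedicated modules. A sketch of the canonical imports after this change:

```python
# Before this commit the classes lived alongside the attention blocks:
#   from diffusers.models.attention import GEGLU, GELU, ApproximateGELU
#   from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, AdaGroupNorm

# After this commit the canonical locations are the dedicated modules:
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero, AdaGroupNorm
```
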