feat: add embedders, cfg and time conditioning plugins
flavioschneider committed Dec 28, 2022
1 parent 97deaa5 commit a6de0f8
Showing 4 changed files with 135 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
# A-UNet

A toolbox that provides hackable building blocks for (1D/2D/3D) UNets, in PyTorch.
A toolbox that provides hackable building blocks for generic 1D/2D/3D UNets, in PyTorch.

## Install
```bash
121 changes: 116 additions & 5 deletions a_unet/blocks.py
@@ -1,7 +1,9 @@
from math import pi
from typing import Any, Callable, Optional, Sequence, Type, TypeVar, Union

import torch
from einops import pack, rearrange, repeat, unpack
import torch.nn.functional as F
from einops import pack, rearrange, reduce, repeat, unpack
from torch import Tensor, einsum, nn
from typing_extensions import TypeGuard

@@ -351,13 +353,79 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


def CFG(
"""
Embedders
"""


class NumberEmbedder(nn.Module):
    def __init__(self, features: int, dim: int = 256):
        super().__init__()
        assert dim % 2 == 0, f"dim must be divisible by 2, found {dim}"
        self.features = features
        self.weights = nn.Parameter(torch.randn(dim // 2))
        self.to_out = nn.Linear(in_features=dim + 1, out_features=features)

    def to_embedding(self, x: Tensor) -> Tensor:
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return self.to_out(fouriered)

    def forward(self, x: Union[Sequence[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            x = torch.tensor(x, device=self.weights.device)
        assert isinstance(x, Tensor)
        shape = x.shape
        x = rearrange(x, "... -> (...)")
        return self.to_embedding(x).view(*shape, self.features)  # type: ignore


class T5Embedder(nn.Module):
    def __init__(self, model: str = "t5-base", max_length: int = 64):
        super().__init__()
        from transformers import AutoTokenizer, T5EncoderModel

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.transformer = T5EncoderModel.from_pretrained(model)
        self.max_length = max_length

    @torch.no_grad()
    def forward(self, texts: Sequence[str]) -> Tensor:
        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        device = next(self.transformer.parameters()).device
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        self.transformer.eval()

        embedding = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )["last_hidden_state"]

        return embedding


"""
Plugins
"""


def ClassifierFreeGuidancePlugin(
    net_t: Type[nn.Module],
    embedding_max_length: int,
) -> Callable[..., nn.Module]:
    """Classifier-Free Guidance -> CFG(UNet, embedding_max_length=512)(...)"""

    def CFGNet(embedding_features: int, **kwargs) -> nn.Module:
    def Net(embedding_features: int, **kwargs) -> nn.Module:
        fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length,
            features=embedding_features,
@@ -371,7 +439,8 @@ def forward(
            embedding_mask_proba: float = 0.0,
            **kwargs,
        ):
            assert exists(embedding), "embedding required when using CFG"
            msg = "ClassifierFreeGuidancePlugin requires embedding"
            assert exists(embedding), msg
            b, device = embedding.shape[0], embedding.device
            embedding_mask = fixed_embedding(embedding)

@@ -393,4 +462,46 @@ def forward(

        return Module([fixed_embedding, net], forward)

    return CFGNet
    return Net


def TimeConditioningPlugin(
    net_t: Type[nn.Module],
    num_layers: int = 2,
) -> Callable[..., nn.Module]:
    """Adds time conditioning (e.g. for diffusion)"""

    def Net(modulation_features: Optional[int] = None, **kwargs) -> nn.Module:
        msg = "TimeConditioningPlugin requires modulation_features"
        assert exists(modulation_features), msg

        embedder = NumberEmbedder(features=modulation_features)
        mlp = Repeat(
            nn.Sequential(
                nn.Linear(modulation_features, modulation_features), nn.GELU()
            ),
            times=num_layers,
        )
        net = net_t(modulation_features=modulation_features, **kwargs)  # type: ignore

        def forward(
            x: Tensor,
            time: Optional[Tensor] = None,
            features: Optional[Tensor] = None,
            **kwargs,
        ):
            msg = "TimeConditioningPlugin requires time in forward"
            assert exists(time), msg
            # Process time to time_features
            time_features = F.gelu(embedder(time))
            time_features = mlp(time_features)
            # Overlap features if more than one per batch
            if time_features.ndim == 3:
                time_features = reduce(time_features, "b n d -> b d", "sum")
            # Merge time features with features if provided
            features = features + time_features if exists(features) else time_features
            return net(x, features=features, **kwargs)

        return Module([embedder, mlp, net], forward)

    return Net
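
Usage sketch (editor's note, not part of the commit): the embedders can be used on their own, while the plugins wrap a network type and return a new constructor. `DummyNet` below is a hypothetical stand-in that only mirrors the keyword interface the plugins expect; the feature sizes and tensor shapes are illustrative.

```python
import torch
from torch import Tensor, nn

from a_unet.blocks import (
    ClassifierFreeGuidancePlugin,
    NumberEmbedder,
    T5Embedder,
    TimeConditioningPlugin,
)

# Embedders used standalone
number_embedder = NumberEmbedder(features=256)
nums = number_embedder([0.2, 0.8])          # learned Fourier features, shape [2, 256]

text_embedder = T5Embedder(model="t5-base", max_length=64)
text = text_embedder(["a short caption"])   # T5 encoder output, shape [1, 64, 768]


# Hypothetical network type: only mirrors the keyword interface the plugins expect
class DummyNet(nn.Module):
    def __init__(self, modulation_features: int, embedding_features: int, **kwargs):
        super().__init__()

    def forward(self, x: Tensor, *, features=None, embedding=None, **kwargs) -> Tensor:
        return x


# Plugins wrap a network *type* and return a new constructor
UNet = TimeConditioningPlugin(ClassifierFreeGuidancePlugin(DummyNet, embedding_max_length=64))
net = UNet(modulation_features=1024, embedding_features=768)

x = torch.randn(1, 2, 1024)
y = net(x, time=torch.tensor([0.5]), embedding=text, embedding_mask_proba=0.1)
```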
29 changes: 17 additions & 12 deletions a_unet/unet/apex.py
@@ -42,7 +42,7 @@ def DownsampleItem(
    factor: Optional[int] = None,
    in_channels: Optional[int] = None,
    channels: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "DownsampleItem requires dim, factor, in_channels, channels"
    assert (
@@ -59,7 +59,7 @@ def UpsampleItem(
    factor: Optional[int] = None,
    channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "UpsampleItem requires dim, factor, channels, out_channels"
    assert (
@@ -78,7 +78,7 @@ def ResnetItem(
    dim: Optional[int] = None,
    channels: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "ResnetItem requires dim, channels, and resnet_groups"
    assert exists(dim) and exists(channels) and exists(resnet_groups), msg
@@ -93,7 +93,7 @@ def AttentionItem(
    channels: Optional[int] = None,
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "AttentionItem requires channels, attention_features, attention_heads"
    assert (
@@ -114,7 +114,7 @@ def CrossAttentionItem(
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    embedding_features: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "CrossAttentionItem requires channels, embedding_features, attention_*"
    assert (
@@ -149,7 +149,7 @@ def LinearAttentionItem(
    channels: Optional[int] = None,
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "LinearAttentionItem requires attention_features and attention_heads"
    assert (
@@ -170,7 +170,7 @@ def LinearCrossAttentionItem(
    attention_features: Optional[int] = None,
    attention_heads: Optional[int] = None,
    embedding_features: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "LinearCrossAttentionItem requires channels, embedding_features, attention_*"
    assert (
@@ -208,7 +208,7 @@ def SkipAdapterItem(
    dim: Optional[int] = None,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    **kwargs
    **kwargs,
):
    msg = "SkipAdapterItem requires dim, in_channels, out_channels"
    assert exists(dim) and exists(in_channels) and exists(out_channels), msg
@@ -244,7 +244,7 @@ def SkipModulateItem(
    dim: Optional[int] = None,
    out_channels: Optional[int] = None,
    modulation_features: Optional[int] = None,
    **kwargs
    **kwargs,
) -> nn.Module:
    msg = "SkipModulateItem requires dim, out_channels, modulation_features"
    assert exists(dim) and exists(out_channels) and exists(modulation_features), msg
@@ -268,7 +268,7 @@ def __init__(
        items_up: Optional[Sequence[Callable]] = None,
        out_channels: Optional[int] = None,
        inner_block: Optional[nn.Module] = None,
        **kwargs
        **kwargs,
    ):
        super().__init__()
        out_channels = default(out_channels, in_channels)
@@ -316,7 +316,7 @@ def __init__(
        in_channels: int,
        blocks: Sequence,
        out_channels: Optional[int] = None,
        **kwargs
        **kwargs,
    ):
        super().__init__()
        num_layers = len(blocks)
@@ -330,14 +330,19 @@ def Net(i: int) -> Optional[nn.Module]:
            out_ch = out_channels if i == 0 else in_ch

            return block_t(
                in_channels=in_ch, out_channels=out_ch, inner_block=Net(i + 1), **kwargs
                in_channels=in_ch,
                out_channels=out_ch,
                depth=i,
                inner_block=Net(i + 1),
                **kwargs,
            )

        self.net = Net(0)

    def forward(
        self,
        x: Tensor,
        *,
        features: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        channels: Optional[Sequence[Tensor]] = None,
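
Editor's note on the `forward` change above (not part of the commit): the bare `*` added to the signature makes `features`, `embedding`, and `channels` keyword-only at call sites. A minimal, self-contained illustration with a hypothetical `Net` class:

```python
from typing import Optional

import torch
from torch import Tensor, nn


class Net(nn.Module):
    # Mirrors the new signature: everything after the bare `*` is keyword-only
    def forward(self, x: Tensor, *, features: Optional[Tensor] = None, **kwargs) -> Tensor:
        return x if features is None else x + features.mean()


net = Net()
x = torch.randn(2, 8, 128)
net(x, features=torch.randn(2, 64))  # OK: passed by keyword
# net(x, torch.randn(2, 64))         # TypeError: positional conditioning no longer accepted
```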
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="a-unet",
packages=find_packages(exclude=[]),
version="0.0.6",
version="0.0.7",
license="MIT",
description="A-UNet",
long_description_content_type="text/markdown",
