Skip to content

Commit

Permalink
feat: add inject channels item
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Jan 1, 2023
1 parent 0f54f1f commit ce86eb8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
41 changes: 39 additions & 2 deletions a_unet/apex.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, List, Optional, Sequence
from typing import Callable, List, Optional, Sequence, no_type_check

import torch
from torch import Tensor, nn

from .blocks import (
Expand All @@ -14,6 +15,7 @@
MergeCat,
MergeModulate,
Modulation,
Module,
Packed,
ResnetBlock,
Select,
Expand All @@ -30,8 +32,9 @@

# Selections for item forward parameters
SelectX = Select(lambda x, *_: (x,))
SelectXE = Select(lambda x, f, e, *_: (x, e))
SelectXF = Select(lambda x, f, *_: (x, f))
SelectXE = Select(lambda x, f, e, *_: (x, e))
SelectXC = Select(lambda x, f, e, c, *_: (x, c))


""" Downsample / Upsample """
Expand Down Expand Up @@ -201,6 +204,40 @@ def FeedForwardItem(
)


def InjectChannelsItem(
dim: Optional[int] = None,
channels: Optional[int] = None,
depth: Optional[int] = None,
context_channels: Optional[int] = None,
**kwargs,
) -> nn.Module:
msg = "InjectChannelsItem requires dim, depth, channels, context_channels"
assert (
exists(dim) and exists(depth) and exists(channels) and exists(context_channels)
), msg
msg = "InjectChannelsItem requires context_channels > 0"
assert context_channels > 0, msg

conv = Conv(
dim=dim,
in_channels=channels + context_channels,
out_channels=channels,
kernel_size=1,
)

@no_type_check
def forward(x: Tensor, channels: Sequence[Optional[Tensor]]) -> Tensor:
msg_ = f"context `channels` at depth {depth} in forward"
assert depth < len(channels), f"Required {msg_}"
context = channels[depth]
shape = torch.Size([x.shape[0], context_channels, *x.shape[2:]])
msg = f"Required {msg_} to be tensor of shape {list(shape)}"
assert torch.is_tensor(context) and context.shape == shape, msg
return conv(torch.cat([x, context], dim=1)) + x

return SelectXC(Module)([conv], forward) # type: ignore


""" Skip Adapters """


Expand Down
2 changes: 1 addition & 1 deletion a_unet/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def TextConditioningPlugin(
features: int = embedder.embedding_features # type: ignore

def Net(embedding_features: int = features, **kwargs) -> nn.Module:
msg = f"TextConditioningPlugin requires embedding_features={features} "
msg = f"TextConditioningPlugin requires embedding_features={features}"
assert embedding_features == features, msg
net = net_t(embedding_features=embedding_features, **kwargs) # type: ignore

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="a-unet",
packages=find_packages(exclude=[]),
version="0.0.9",
version="0.0.10",
license="MIT",
description="A-UNet",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit ce86eb8

Please sign in to comment.