feat: set default scale to SkipCat, fix default embedder
flavioschneider committed Jan 5, 2023
1 parent ce86eb8 commit 50737ba
Showing 3 changed files with 14 additions and 10 deletions.
9 changes: 6 additions & 3 deletions a_unet/apex.py
@@ -231,7 +231,7 @@ def forward(x: Tensor, channels: Sequence[Optional[Tensor]]) -> Tensor:
         assert depth < len(channels), f"Required {msg_}"
         context = channels[depth]
         shape = torch.Size([x.shape[0], context_channels, *x.shape[2:]])
-        msg = f"Required {msg_} to be tensor of shape {list(shape)}"
+        msg = f"Required {msg_} to be tensor of shape {shape}, found {context.shape}"
         assert torch.is_tensor(context) and context.shape == shape, msg
         return conv(torch.cat([x, context], dim=1)) + x
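Not part of the diff, but worth noting: the new message interpolates both the expected and the actual shape, so a mismatch is diagnosable from the assertion text alone. A minimal sketch with hypothetical shapes:

```python
import torch

shape = torch.Size([1, 768, 128])   # expected context shape (hypothetical)
context = torch.randn(1, 512, 128)  # tensor that would fail the check

# Old message: only the expected shape, rendered via list(shape)
print(f"Required context to be tensor of shape {list(shape)}")
# -> Required context to be tensor of shape [1, 768, 128]

# New message: expected and found shapes side by side
print(f"Required context to be tensor of shape {shape}, found {context.shape}")
# -> Required context to be tensor of shape torch.Size([1, 768, 128]), found torch.Size([1, 512, 128])
```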

@@ -270,11 +270,14 @@ def SkipAdd(**kwargs) -> nn.Module:
 
 
 def SkipCat(
-    dim: Optional[int] = None, out_channels: Optional[int] = None, **kwargs
+    dim: Optional[int] = None,
+    out_channels: Optional[int] = None,
+    skip_scale: float = 2**-0.5,
+    **kwargs,
 ) -> nn.Module:
     msg = "SkipCat requires dim, out_channels"
     assert exists(dim) and exists(out_channels), msg
-    return MergeCat(dim=dim, channels=out_channels)
+    return MergeCat(dim=dim, channels=out_channels, scale=skip_scale)
 
 
 def SkipModulate(
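As the commit title says, SkipCat now applies a 2**-0.5 skip scale by default, threading the new `skip_scale` keyword through to MergeCat. A hedged usage sketch (argument values are hypothetical, not taken from the commit):

```python
from a_unet.apex import SkipCat

# Default: the skip branch is scaled by 2**-0.5 (~0.707) before concatenation
merge = SkipCat(dim=1, out_channels=64)

# Restore the old unscaled behavior by passing skip_scale=1.0
merge_unscaled = SkipCat(dim=1, out_channels=64, skip_scale=1.0)
```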
13 changes: 7 additions & 6 deletions a_unet/blocks.py
@@ -326,9 +326,9 @@ def MergeAdd():
     return Module([], lambda x, y, *_: x + y)
 
 
-def MergeCat(dim: int, channels: int) -> nn.Module:
+def MergeCat(dim: int, channels: int, scale: float = 2**-0.5) -> nn.Module:
     conv = Conv(dim=dim, in_channels=channels * 2, out_channels=channels, kernel_size=1)
-    return Module([conv], lambda x, y, *_: conv(torch.cat([x, y], dim=1)))
+    return Module([conv], lambda x, y, *_: conv(torch.cat([x * scale, y], dim=1)))
 
 
 def MergeModulate(dim: int, channels: int, modulation_features: int):
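The commit does not state the rationale, but 2**-0.5 is the standard variance-preserving skip scale: summing two independent unit-variance branches doubles the variance, and scaling by 2**-0.5 restores it. Here the scale is applied only to the skip input `x` before the concat and 1x1 conv, which down-weights the skip path analogously. A minimal numerical sketch of the textbook additive case:

```python
import torch

torch.manual_seed(0)
a = torch.randn(100_000)  # skip branch, unit variance
b = torch.randn(100_000)  # main branch, unit variance

print((a + b).var())                      # ~2.0: a naive sum doubles variance
print((2**-0.5 * a + 2**-0.5 * b).var())  # ~1.0: scaling restores unit variance
```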
@@ -509,9 +509,10 @@ def forward(
 
 
 def TextConditioningPlugin(
-    net_t: Type[nn.Module], embedder: nn.Module = T5Embedder()
+    net_t: Type[nn.Module], embedder: Optional[nn.Module] = None
 ) -> Callable[..., nn.Module]:
-    """Adds time conditioning (e.g. for diffusion)"""
+    """Adds text conditioning"""
+    embedder = embedder if exists(embedder) else T5Embedder()
     msg = "TextConditioningPlugin embedder requires embedding_features attribute"
     assert hasattr(embedder, "embedding_features"), msg
     features: int = embedder.embedding_features  # type: ignore
@@ -524,11 +525,11 @@ def Net(embedding_features: int = features, **kwargs) -> nn.Module:
         def forward(
             x: Tensor, text: Sequence[str], embedding: Optional[Tensor] = None, **kwargs
         ):
-            text_embedding = embedder(text)
+            text_embedding = embedder(text)  # type: ignore
             if exists(embedding):
                 text_embedding = torch.cat([text_embedding, embedding], dim=1)
             return net(x, embedding=text_embedding, **kwargs)
 
-        return Module([embedder, net], forward)
+        return Module([embedder, net], forward)  # type: ignore
 
     return Net
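A plausible motivation for the signature change (an assumption; the commit message only says "fix default embedder"): a default of `embedder: nn.Module = T5Embedder()` is evaluated once, when the `def` statement runs, so merely importing the module constructs (and typically downloads) the T5 embedder even if every caller supplies their own. Moving construction behind `Optional[...] = None` defers it to call time. The general Python pattern, with a hypothetical stand-in class:

```python
class ExpensiveModel:  # stand-in for something like T5Embedder (hypothetical)
    def __init__(self):
        print("loading weights...")

def eager(model=ExpensiveModel()):  # default constructed at definition/import time
    return model

def lazy(model=None):
    # default constructed on demand, only when no model is passed
    model = model if model is not None else ExpensiveModel()
    return model

own = object()
eager(own)  # "loading weights..." already printed when eager was defined
lazy(own)   # ExpensiveModel is never constructed
```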
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
     name="a-unet",
     packages=find_packages(exclude=[]),
-    version="0.0.10",
+    version="0.0.11",
     license="MIT",
     description="A-UNet",
     long_description_content_type="text/markdown",