Skip to content

Commit

Permalink
feat: add text conditioning plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Dec 30, 2022
1 parent 734e4f3 commit 0f54f1f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
27 changes: 27 additions & 0 deletions a_unet/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def __init__(self, model: str = "t5-base", max_length: int = 64):
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.transformer = T5EncoderModel.from_pretrained(model)
self.max_length = max_length
self.embedding_features = self.transformer.config.d_model

@torch.no_grad()
def forward(self, texts: Sequence[str]) -> Tensor:
Expand Down Expand Up @@ -505,3 +506,29 @@ def forward(
return Module([embedder, mlp, net], forward)

return Net


def TextConditioningPlugin(
net_t: Type[nn.Module], embedder: nn.Module = T5Embedder()
) -> Callable[..., nn.Module]:
"""Adds time conditioning (e.g. for diffusion)"""
msg = "TextConditioningPlugin embedder requires embedding_features attribute"
assert hasattr(embedder, "embedding_features"), msg
features: int = embedder.embedding_features # type: ignore

def Net(embedding_features: int = features, **kwargs) -> nn.Module:
msg = f"TextConditioningPlugin requires embedding_features={features} "
assert embedding_features == features, msg
net = net_t(embedding_features=embedding_features, **kwargs) # type: ignore

def forward(
x: Tensor, text: Sequence[str], embedding: Optional[Tensor] = None, **kwargs
):
text_embedding = embedder(text)
if exists(embedding):
text_embedding = torch.cat([text_embedding, embedding], dim=1)
return net(x, embedding=text_embedding, **kwargs)

return Module([embedder, net], forward)

return Net
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="a-unet",
packages=find_packages(exclude=[]),
version="0.0.8",
version="0.0.9",
license="MIT",
description="A-UNet",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 0f54f1f

Please sign in to comment.