Skip to content

Commit

Permalink
feat: add pretrained/frozen T5 embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 13, 2022
1 parent d16dfa7 commit 315ab44
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 4 deletions.
25 changes: 22 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ from audio_diffusion_pytorch import AudioDiffusionConditional

model = AudioDiffusionConditional(
in_channels=1,
embedding_max_length=512,
embedding_max_length=64,
embedding_features=768,
embedding_mask_proba=0.1 # Conditional dropout of batch elements
)

# Train on pairs of audio and embedding data (e.g. from a transformer output)
x = torch.randn(2, 1, 2 ** 18)
embedding = torch.randn(2, 512, 768)
embedding = torch.randn(2, 64, 768)
loss = model(x, embedding=embedding)
loss.backward()

# Given start embedding and noise sample new source
embedding = torch.randn(1, 512, 768)
embedding = torch.randn(1, 64, 768)
noise = torch.randn(1, 1, 2 ** 18)
sampled = model.sample(
noise,
Expand All @@ -115,6 +115,25 @@ sampled = model.sample(
) # [1, 1, 2 ** 18]
```

#### Text Conditional Generation
You can generate embeddings from text with a pretrained, frozen T5 transformer using `T5Embedder`, as follows (note that this requires `pip install transformers sentencepiece`, since the T5 tokenizer depends on SentencePiece):

```py
from audio_diffusion_pytorch import T5Embedder

# Load the pretrained T5 tokenizer and encoder. `max_length=64` matches the
# `embedding_max_length=64` used for the model in the previous example.
embedder = T5Embedder(model='t5-base', max_length=64)
# One text per batch item; the result is used as the `embedding` argument below
embedding = embedder(["First batch item text...", "Second batch item text..."])

# `model`, `x` and `noise` are the objects created in the previous example
loss = model(x, embedding=embedding)
# ...
sampled = model.sample(
noise,
embedding=embedding,
embedding_scale=5.0, # Classifier-free guidance scale
num_steps=5
)
```

## Usage with Components

### UNet1d
Expand Down
1 change: 1 addition & 0 deletions audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .modules import (
AutoEncoder1d,
MultiEncoder1d,
T5Embedder,
UNet1d,
UNetConditional1d,
Variational,
Expand Down
32 changes: 32 additions & 0 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,9 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


""" Conditioning """


class UNetConditional1d(UNet1d):
"""
UNet1d with classifier-free guidance on the token embeddings
Expand Down Expand Up @@ -1130,6 +1133,35 @@ def forward( # type: ignore
return out


class T5Embedder(nn.Module):
    """
    Text embedder that wraps a pretrained T5 tokenizer and encoder.

    The encoder is only used for inference: `forward` runs under
    `torch.no_grad()` and switches the encoder to eval mode on every call.

    Args:
        model: Name or path passed to `from_pretrained` for both the
            tokenizer and the encoder.
        max_length: Number of tokens every text is truncated/padded to.
    """

    def __init__(self, model: str = "t5-base", max_length: int = 64):
        super().__init__()
        # Imported lazily so that `transformers` is only needed when an
        # embedder is actually instantiated.
        from transformers import T5EncoderModel, T5Tokenizer

        self.tokenizer = T5Tokenizer.from_pretrained(model)
        self.transformer = T5EncoderModel.from_pretrained(model)
        self.max_length = max_length

    @torch.no_grad()
    def forward(self, texts: List[str]) -> Tensor:
        """
        Encode a batch of texts into token embeddings.

        Args:
            texts: One string per batch item.

        Returns:
            The encoder's `last_hidden_state` for the tokenized batch
            (per the `transformers` docs: [batch, max_length, features]),
            located on the same device as the encoder weights.
        """
        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        # The tokenizer always produces CPU tensors; move them to wherever
        # the encoder lives so the module keeps working after `.to(device)`.
        device = next(self.transformer.parameters()).device
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        # Disable dropout even if a parent module called `.train()`.
        self.transformer.eval()

        embedding = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )["last_hidden_state"]

        return embedding


"""
Encoders / Decoders
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.61",
version="0.0.62",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 315ab44

Please sign in to comment.