allow for prompt audio to be passed in as prime_wave_path to AudioLM
lucidrains committed Aug 1, 2023
1 parent 9fd9e45 commit 22951ab
Showing 3 changed files with 14 additions and 2 deletions.
README.md: 2 changes (1 addition & 1 deletion)
@@ -342,10 +342,10 @@ $ accelerate launch train.py
 - [x] allow for specialized relative positional embeddings in fine transformer based on absolute matching positions of quantizers between coarse and fine
 - [x] allow for grouped residual vq in soundstream (use `GroupedResidualVQ` from vector-quantize-pytorch lib), from <a href="https://arxiv.org/abs/2305.02765">hifi-codec</a>
 - [x] add flash attention with <a href="https://arxiv.org/abs/2305.19466">NoPE</a>
+- [x] accept prime wave in `AudioLM` as a path to an audio file, and auto resample for semantic vs acoustic
 
 - [ ] design a hierarchical coarse and fine transformer
 - [ ] investigate <a href="https://openreview.net/forum?id=H-VlwsYvVi">spec decoding</a>, first test in x-transformers, then port over if applicable
-- [ ] accept prime wave in `AudioLM` as a path to an audio file, and auto resample for semantic vs acoustic
 
 - [ ] redo the positional embeddings in the presence of groups in residual vq
 - [ ] test with speech synthesis for starters
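With this change, the prompt audio can be given to `AudioLM` directly as a file path instead of a pre-loaded tensor. Below is a minimal usage sketch, not part of the diff: it assumes `audiolm` is an `AudioLM` instance already assembled elsewhere (codec plus trained semantic, coarse and fine transformers), and `./prompt.wav` is a stand-in path to a prompt recording.

# sketch: hand over the path and let AudioLM load the file and resample it
# for the semantic and acoustic stages internally
generated_wav = audiolm(prime_wave_path = './prompt.wav')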
audiolm_pytorch/audiolm_pytorch.py: 12 changes (12 additions & 0 deletions)
@@ -10,6 +10,8 @@
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 
+import torchaudio
+
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

@@ -1901,6 +1903,7 @@ def forward(
         text_embeds: Optional[Tensor] = None,
         prime_wave = None,
         prime_wave_input_sample_hz = None,
+        prime_wave_path = None,
         max_length = 2048,
         return_coarse_generated_wave = False,
         mask_out_generated_fine_tokens = False
@@ -1911,7 +1914,16 @@ def forward(
         if exists(text):
             text_embeds = self.semantic.embed_text(text)
 
+        assert not (exists(prime_wave) and exists(prime_wave_path)), 'prompt audio must be given as either `prime_wave: Tensor` or `prime_wave_path: str`'
+
         if exists(prime_wave):
+            assert exists(prime_wave_input_sample_hz), 'the input sample frequency for the prompt audio must be given as `prime_wave_input_sample_hz: int`'
             prime_wave = prime_wave.to(self.device)
+        elif exists(prime_wave_path):
+            prime_wave_path = Path(prime_wave_path)
+            assert exists(prime_wave_path), f'file does not exist at {str(prime_wave_path)}'
+
+            prime_wave, prime_wave_input_sample_hz = torchaudio.load(str(prime_wave_path))
+            prime_wave = prime_wave.to(self.device)
 
         semantic_token_ids = self.semantic.generate(
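For comparison, the pre-existing tensor route is still supported: the caller loads the prompt audio and passes it as `prime_wave` along with its sample rate, which the new assertion requires whenever a raw tensor is given; passing both `prime_wave` and `prime_wave_path` is disallowed. A sketch of that explicit variant, assuming the same `audiolm` instance and prompt file as above:

import torchaudio

# load the prompt manually; torchaudio.load returns the waveform tensor and its native sample rate
prime_wave, prime_wave_input_sample_hz = torchaudio.load('./prompt.wav')

# pass the tensor together with its sample rate (required by the added assert);
# do not also pass prime_wave_path, since the two are mutually exclusive
generated_wav = audiolm(
    prime_wave = prime_wave,
    prime_wave_input_sample_hz = prime_wave_input_sample_hz
)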
audiolm_pytorch/version.py: 2 changes (1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = '1.2.23'
+__version__ = '1.2.24'
