Implementation of AudioLM, a Language Modeling Approach to Audio Generation out of Google Research, in PyTorch
It also extends the work to support conditioning with classifier-free guidance, using T5 text embeddings. This allows one to do text-to-audio or TTS, neither of which is offered in the paper. Yes, this means VALL-E can be trained from this repository; it is essentially the same.
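Classifier-free guidance, as typically used with autoregressive transformers, evaluates the model once with the condition (here, the T5 text embeddings) and once with the condition dropped, then extrapolates from the unconditional logits toward the conditional ones. The sketch below shows only that combination step on plain Python lists; the function name and the cond_scale value are illustrative assumptions, not this repository's API.

```python
def classifier_free_guidance(cond_logits, uncond_logits, cond_scale):
    """Extrapolate from unconditional toward conditional logits.

    cond_scale = 1.0 returns the conditional logits unchanged;
    values above 1.0 push the prediction further toward the condition.
    """
    if len(cond_logits) != len(uncond_logits):
        raise ValueError("logit vectors must have the same length")
    return [u + (c - u) * cond_scale for c, u in zip(cond_logits, uncond_logits)]


# toy logits over a 3-token vocabulary (hypothetical values)
cond = [2.0, 0.5, -1.0]    # model run with the text condition
uncond = [1.0, 1.0, 0.0]   # model run with the condition dropped
guided = classifier_free_guidance(cond, uncond, cond_scale=3.0)
print(guided)  # [4.0, -0.5, -3.0]
```

A cond_scale of 1.0 recovers the purely conditional prediction; larger values follow the text more closely at the cost of sample diversity.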
Please join if you are interested in replicating this work in the open
This repository now also contains an MIT-licensed version of SoundStream. It is also compatible with EnCodec, which is also MIT-licensed at the time of writing.
Update: AudioLM was essentially used to 'solve' music generation in the new MusicLM
In the future, this movie clip would no longer make any sense. You would just prompt an AI instead.
- Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research
- 🤗 Huggingface for their amazing accelerate and transformers libraries
- @eonglints and Joseph for offering their professional advice and expertise as well as pull requests!
- @djqualia, @yigityu, @inspirit, and @BlackFox1197 for helping with the debugging of soundstream
- Allen and LWprogramming for reviewing the code and submitting bug fixes!
- Ilya for finding an issue with multi-scale discriminator downsampling and for soundstream trainer improvements
- Andrey for identifying a missing loss in soundstream and guiding me through the proper mel spectrogram hyperparameters
- Alejandro and Ilya for sharing their results with training soundstream, and for working through a few issues with the local attention positional embeddings
- LWprogramming for adding Encodec compatibility!
- LWprogramming for finding an issue with handling of the EOS token when sampling from the FineTransformer!
- @YoungloLee for identifying a big bug in the 1d causal convolution for soundstream related to padding not accounting for strides!
- Hayden for pointing out some discrepancies in the multi-scale discriminator for Soundstream
$ pip install audiolm-pytorch
There are two options for the neural codec. If you want to use the pretrained 24kHz Encodec, just create an Encodec object as follows:
from audiolm_pytorch import EncodecWrapper
encodec = EncodecWrapper()
# Now you can use the encodec variable in the same way you'd use the soundstream variables below.
Otherwise, to stay more true to the original paper, you can use SoundStream. First, SoundStream needs to be trained on a large corpus of audio data
import torch
from audiolm_pytorch import SoundStream, SoundStreamTrainer
soundstream = SoundStream(
codebook_size = 4096,
rq_num_quantizers = 8,
rq_groups = 2, # this paper proposes using multi-headed residual vector quantization - https://arxiv.org/abs/2305.02765
use_lookup_free_quantizer = True, # whether to use residual lookup free quantization - there are now reports of successful usage of this unpublished technique
use_finite_scalar_quantizer = False, # whether to use residual finite scalar quantization
attn_window_size = 128, # local attention receptive field at bottleneck
attn_depth = 2 # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)
trainer = SoundStreamTrainer(
soundstream,
folder = '/path/to/audio/files',
batch_size = 4,
grad_accum_every = 8, # effective batch size of 32
data_max_length_seconds = 2, # train on 2 second audio
num_train_steps = 1_000_000
).cuda()
trainer.train()
# after a lot of training, you can test the autoencoding as so
soundstream.eval() # your soundstream must be in eval mode, to avoid having the residual dropout of the residual VQ necessary for training
audio = torch.randn(10080).cuda()
recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel
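The configuration above turns on lookup-free quantization inside a residual quantizer (use_lookup_free_quantizer = True, rq_num_quantizers = 8). As described in the 'Language Model Beats Diffusion' paper cited below, lookup-free quantization snaps each latent dimension to its sign, so a d-dimensional vector maps to one of 2**d codes without searching a codebook; codebook_size = 4096 thus corresponds to 12 sign bits. The pure-Python sketch below illustrates the idea with residual stages. The halving of the step size from one stage to the next is an assumption chosen for illustration, not necessarily what the library does.

```python
def lfq_stage(vector):
    """Quantize each dimension to -1 or +1 and return (index, quantized)."""
    bits = [1 if x > 0 else 0 for x in vector]
    quantized = [1.0 if b else -1.0 for b in bits]
    # read the sign bits as a binary number: this is the code index
    index = 0
    for b in bits:
        index = index * 2 + b
    return index, quantized


def residual_lfq(vector, num_quantizers):
    """Residual quantization: each stage encodes what earlier stages missed."""
    residual = list(vector)
    reconstruction = [0.0] * len(vector)
    indices = []
    scale = 1.0
    for _ in range(num_quantizers):
        index, q = lfq_stage(residual)
        step = [scale * v for v in q]
        reconstruction = [r + s for r, s in zip(reconstruction, step)]
        residual = [r - s for r, s in zip(residual, step)]
        indices.append(index)
        scale /= 2  # assumed: halve the step size for each later stage
    return indices, reconstruction


indices, recon = residual_lfq([0.7, -0.2, 0.3], num_quantizers=3)
print(indices, recon)  # [5, 2, 6] [0.75, -0.25, 0.25]
```

With 12 dimensions the largest index is lfq_stage([1.0] * 12)[0] == 4095, the last entry of a 4096-entry codebook.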
Your trained SoundStream can then be used as a generic tokenizer for audio
audio = torch.randn(1, 512 * 320)
codes = soundstream.tokenize(audio)
# you can now train anything with the codebook ids
recon_audio_from_codes = soundstream.decode_from_codebook_indices(codes)
# sanity check
assert torch.allclose(
recon_audio_from_codes,
soundstream(audio, return_recons_only = True)
)
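The example input of 512 * 320 samples is likely chosen so that it divides evenly by the encoder's total downsampling factor. Assuming encoder strides of (2, 4, 5, 8), as in the SoundStream paper (check your own configuration), each code frame covers 320 samples and carries one code per quantizer. The hypothetical helper below illustrates that arithmetic; it ignores rq_groups, which would add a group axis, and is not part of the library.

```python
import math


def code_grid_shape(num_samples, strides=(2, 4, 5, 8), num_quantizers=8):
    """Return (frames, quantizers) for a mono clip, assuming the given strides."""
    hop_length = math.prod(strides)  # samples consumed per code frame (320 here)
    # assumed: a partial last frame is padded and still yields codes
    frames = math.ceil(num_samples / hop_length)
    return frames, num_quantizers


frames, quantizers = code_grid_shape(512 * 320)
print(frames, quantizers, frames * quantizers)  # 512 8 4096
```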
You can also use soundstreams that are specific to AudioLM and MusicLM by importing AudioLMSoundStream and MusicLMSoundStream respectively
from audiolm_pytorch import AudioLMSoundStream, MusicLMSoundStream
soundstream = AudioLMSoundStream(...) # say you want the hyperparameters as in the AudioLM paper
# rest is the same as above
As of version 0.17.0, you can invoke the class method init_and_load_from on SoundStream to load from checkpoint files without having to remember your configurations.
from audiolm_pytorch import SoundStream
soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt')
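Loading without remembering configurations works when the checkpoint stores the constructor arguments next to the weights. The following is a generic, library-free sketch of that pattern. ToyCodec, save and from_checkpoint are hypothetical stand-ins, not SoundStream's actual implementation, and plain lists replace tensors so the example runs anywhere.

```python
import json
import os
import tempfile


class ToyCodec:
    """Hypothetical model whose constructor kwargs are saved with its weights."""

    def __init__(self, codebook_size=1024, num_quantizers=8):
        self._config = dict(codebook_size=codebook_size, num_quantizers=num_quantizers)
        self.codebook_size = codebook_size
        self.num_quantizers = num_quantizers
        self.weights = [0.0] * num_quantizers  # placeholder parameters

    def save(self, path):
        with open(path, "w") as f:
            json.dump({"config": self._config, "weights": self.weights}, f)

    @classmethod
    def from_checkpoint(cls, path):
        with open(path) as f:
            package = json.load(f)
        model = cls(**package["config"])    # rebuild with the saved hyperparameters
        model.weights = package["weights"]  # then restore the parameters
        return model


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "codec.json")
    original = ToyCodec(codebook_size=4096, num_quantizers=8)
    original.weights[0] = 1.5
    original.save(path)
    restored = ToyCodec.from_checkpoint(path)

print(restored.codebook_size, restored.weights[0])  # 4096 1.5
```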
To use Weights & Biases tracking, first set use_wandb_tracking = True on the SoundStreamTrainer, then do the following
trainer = SoundStreamTrainer(
soundstream,
...,
use_wandb_tracking = True
)
# wrap .train() with contextmanager, specifying project and run name
with trainer.wandb_tracker(project = 'soundstream', run = 'baseline'):
    trainer.train()
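Wrapping training in a context manager guarantees that the tracking run is opened before trainer.train() and closed afterwards, even if training raises. A minimal sketch of that structure with contextlib follows; start_run and finish_run are hypothetical placeholders for whatever the tracker actually does (for example wandb.init and wandb.finish), and this is not the trainer's own code.

```python
from contextlib import contextmanager

events = []  # records the order of calls for illustration


def start_run(project, run):
    events.append(f"start {project}/{run}")  # hypothetical: e.g. wandb.init(...)


def finish_run():
    events.append("finish")  # hypothetical: e.g. wandb.finish()


@contextmanager
def tracker(project, run=None):
    start_run(project, run)
    try:
        yield
    finally:
        finish_run()  # runs even if the body raises


with tracker(project="soundstream", run="baseline"):
    events.append("train")

print(events)  # ['start soundstream/baseline', 'train', 'finish']
```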
Then three separate transformers (SemanticTransformer, CoarseTransformer, FineTransformer) need to be trained
ex. SemanticTransformer
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer
# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
wav2vec = HubertWithKmeans(
checkpoint_path = './hubert/hubert_base_ls960.pt',
kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)
semantic_transformer = SemanticTransformer(
num_semantic_tokens = wav2vec.codebook_size,
dim = 1024,
depth = 6,
flash_attn = True
).cuda()
trainer = SemanticTransformerTrainer(
transformer = semantic_transformer,
wav2vec = wav2vec,
folder = '/path/to/audio/files',
batch_size = 1,
data_max_length = 320 * 32,
num_train_steps = 1_000_000
)
trainer.train()
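HubertWithKmeans turns audio into semantic tokens by extracting HuBERT features and assigning each frame to its nearest k-means centroid (500 centroids for the km500 checkpoint above). The to-do list also mentions "unique consecutive" processing, which collapses runs of repeated tokens. The sketch below shows both steps on toy 2-D features; real features are high-dimensional, and this is an illustration rather than the library's code.

```python
def nearest_centroid(feature, centroids):
    """Index of the centroid with the smallest squared Euclidean distance."""
    def sq_dist(c):
        return sum((f - x) ** 2 for f, x in zip(feature, c))
    return min(range(len(centroids)), key=lambda i: sq_dist(centroids[i]))


def unique_consecutive(tokens):
    """Collapse runs of identical neighboring tokens, e.g. [3, 3, 1] -> [3, 1]."""
    out = []
    for t in tokens:
        if not out or out[-1] != t:
            out.append(t)
    return out


centroids = [(0.0, 0.0), (1.0, 1.0), (0.0, 1.0)]  # toy 3-cluster codebook
frames = [(0.1, 0.0), (0.0, 0.2), (0.9, 1.1), (1.0, 0.8), (0.1, 0.9), (0.0, 0.1)]
tokens = [nearest_centroid(f, centroids) for f in frames]
print(tokens, unique_consecutive(tokens))  # [0, 0, 1, 1, 2, 0] [0, 1, 2, 0]
```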
ex. CoarseTransformer
import torch
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerTrainer
wav2vec = HubertWithKmeans(
checkpoint_path = './hubert/hubert_base_ls960.pt',
kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)
soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')
coarse_transformer = CoarseTransformer(
num_semantic_tokens = wav2vec.codebook_size,
codebook_size = 1024, # should match the codebook size of the trained codec (the SoundStream example above uses 4096)
num_coarse_quantizers = 3,
dim = 512,
depth = 6,
flash_attn = True
)
trainer = CoarseTransformerTrainer(
transformer = coarse_transformer,
codec = soundstream,
wav2vec = wav2vec,
folder = '/path/to/audio/files',
batch_size = 1,
data_max_length = 320 * 32,
num_train_steps = 1_000_000
)
trainer.train()
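The CoarseTransformer models the first num_coarse_quantizers = 3 quantizer levels of each codec frame as one flat sequence. A common way to do that, and an assumption here rather than a description of this repository's internals, is to interleave the codes frame by frame and offset each level's ids by level * codebook_size, so a single embedding table can tell the levels apart. The range check also shows why codebook_size has to cover every code the codec can emit.

```python
def flatten_with_offsets(codes, codebook_size):
    """codes: list of frames, each a list of per-quantizer ids (coarse levels only)."""
    flat = []
    for frame in codes:
        for level, code in enumerate(frame):
            if not 0 <= code < codebook_size:
                raise ValueError(f"code {code} is outside the codebook")
            flat.append(level * codebook_size + code)  # shift each level into its own id range
    return flat


def unflatten_with_offsets(flat, num_quantizers, codebook_size):
    """Inverse of flatten_with_offsets."""
    frames = [flat[i:i + num_quantizers] for i in range(0, len(flat), num_quantizers)]
    return [[token - level * codebook_size for level, token in enumerate(f)] for f in frames]


coarse_codes = [[5, 17, 1023], [0, 4, 9]]  # 2 frames x 3 coarse quantizers (toy values)
flat = flatten_with_offsets(coarse_codes, codebook_size=1024)
print(flat)  # [5, 1041, 3071, 0, 1028, 2057]
```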
ex. FineTransformer
import torch
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer
soundstream = SoundStream.init_and_load_from('/path/to/trained/soundstream.pt')
fine_transformer = FineTransformer(
num_coarse_quantizers = 3,
num_fine_quantizers = 5,
codebook_size = 1024, # should match the codebook size of the trained codec and of the coarse transformer
dim = 512,
depth = 6,
flash_attn = True
)
trainer = FineTransformerTrainer(
transformer = fine_transformer,
codec = soundstream,
folder = '/path/to/audio/files',
batch_size = 1,
data_max_length = 320 * 32,
num_train_steps = 1_000_000
)
trainer.train()
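The coarse and fine transformers split each frame's quantizer stack between them: with num_coarse_quantizers = 3 and num_fine_quantizers = 5, together they cover the 8 residual quantizers configured for SoundStream earlier (rq_num_quantizers = 8). The hypothetical helper below makes that bookkeeping explicit and fails fast on a mismatch; its names are illustrative and not part of the library.

```python
def split_coarse_fine(frame_codes, num_coarse, num_fine):
    """Split one frame's residual codes into the coarse prefix and the fine remainder."""
    if num_coarse + num_fine != len(frame_codes):
        raise ValueError(
            f"{num_coarse} coarse + {num_fine} fine quantizers != {len(frame_codes)} codec quantizers"
        )
    # residual quantization is ordered coarse to fine, so the first levels are the coarse ones
    return frame_codes[:num_coarse], frame_codes[num_coarse:]


frame = [12, 7, 300, 41, 8, 999, 3, 64]  # one frame from an 8-quantizer codec (toy values)
coarse, fine = split_coarse_fine(frame, num_coarse=3, num_fine=5)
print(coarse, fine)  # [12, 7, 300] [41, 8, 999, 3, 64]
```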
All together now
from audiolm_pytorch import AudioLM
audiolm = AudioLM(
wav2vec = wav2vec,
codec = soundstream,
semantic_transformer = semantic_transformer,
coarse_transformer = coarse_transformer,
fine_transformer = fine_transformer
)
generated_wav = audiolm(batch_size = 1)
# or with priming
generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))
# or with text condition, if given
generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echoes of bells'])
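Conceptually, AudioLM chains the three stages: semantic tokens are sampled first, the coarse transformer predicts the first acoustic quantizer levels conditioned on them, the fine transformer fills in the remaining levels, and the codec decodes the full code stack to a waveform. The sketch below traces only that data flow with stub functions that return deterministic placeholder tokens. Every function in it is hypothetical, semantic and acoustic frames are assumed to align one to one for simplicity (real token rates differ), and the 320-sample hop is an assumption.

```python
NUM_COARSE, NUM_FINE = 3, 5


def sample_semantic(num_tokens):
    return [i % 500 for i in range(num_tokens)]  # stub for the SemanticTransformer


def sample_coarse(semantic, num_frames):
    # stub for the CoarseTransformer: one list of NUM_COARSE ids per frame
    return [[(s + level) % 1024 for level in range(NUM_COARSE)] for s in semantic[:num_frames]]


def sample_fine(coarse):
    # stub for the FineTransformer: predicts the remaining levels of each frame
    return [[(frame[0] + level) % 1024 for level in range(NUM_FINE)] for frame in coarse]


def decode(codes, hop_length=320):
    # stub for the codec decoder: each frame becomes hop_length samples of silence
    return [0.0] * (len(codes) * hop_length)


semantic = sample_semantic(8)
coarse = sample_coarse(semantic, num_frames=8)
fine = sample_fine(coarse)
full_codes = [c + f for c, f in zip(coarse, fine)]  # 3 + 5 = 8 levels per frame
wave = decode(full_codes)
print(len(full_codes), len(full_codes[0]), len(wave))  # 8 8 2560
```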
Update: Looks like this will work, given 'VALL-E'
ex. Semantic Transformer
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer
wav2vec = HubertWithKmeans(
checkpoint_path = './hubert/hubert_base_ls960.pt',
kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)
semantic_transformer = SemanticTransformer(
num_semantic_tokens = 500,
dim = 1024,
depth = 6,
has_condition = True, # this will have to be set to True
cond_as_self_attn_prefix = True # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper
).cuda()
# mock text audio dataset (as an example)
# you will have to extend your own from `Dataset`, and return an audio tensor as well as a string (the audio description) in any order (the framework will autodetect and route it into the transformer)
from torch.utils.data import Dataset
class MockTextAudioDataset(Dataset):
    def __init__(self, length = 100, audio_length = 320 * 32):
        super().__init__()
        self.audio_length = audio_length
        self.len = length
    def __len__(self):
        return self.len
    def __getitem__(self, idx):
        mock_audio = torch.randn(self.audio_length)
        mock_caption = 'audio caption'
        return mock_caption, mock_audio
dataset = MockTextAudioDataset()
# instantiate semantic transformer trainer and train
trainer = SemanticTransformerTrainer(
transformer = semantic_transformer,
wav2vec = wav2vec,
dataset = dataset,
batch_size = 4,
grad_accum_every = 8,
data_max_length = 320 * 32,
num_train_steps = 1_000_000
)
trainer.train()
# after much training above
sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_size = 1, max_length = 128) # (1, < 128) - may terminate early if it detects [eos]
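The comment on generate notes that sampling may stop before max_length when the model emits [eos]. The generic loop below illustrates that behavior with a hypothetical next_token function standing in for the transformer; the EOS id and the deterministic toy model are assumptions for the example, not the trainer's sampling settings.

```python
EOS_ID = -1  # hypothetical end-of-sequence id


def generate(next_token, prime, max_length):
    """Autoregressively extend `prime` until EOS or max_length new tokens."""
    sequence = list(prime)
    for _ in range(max_length):
        token = next_token(sequence)
        if token == EOS_ID:
            break  # stop early; the EOS token itself is not kept
        sequence.append(token)
    return sequence[len(prime):]


def toy_model(sequence):
    # emits an increasing counter, then EOS once the sequence holds 5 tokens
    return EOS_ID if len(sequence) >= 5 else len(sequence)


print(generate(toy_model, prime=[0], max_length=128))  # [1, 2, 3, 4]
print(generate(toy_model, prime=[0], max_length=2))    # [1, 2]
```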
Because all the trainer classes use 🤗 Accelerator, you can easily do multi-GPU training with the accelerate command as follows
At the project root
$ accelerate config
Then, in the same directory
$ accelerate launch train.py
Todo
- complete CoarseTransformer
- use fairseq vq-wav2vec for embeddings
- add conditioning
- add classifier free guidance
- add unique consecutive for
- incorporate ability to use hubert intermediate features as semantic tokens, recommended by eonglints
- accommodate variable-length audio, bring in eos token
- make sure unique consecutive works with coarse transformer
- pretty printing all discriminator losses to log
- handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing
- complete sampling code for both Coarse and Fine Transformers, which will be tricky
- make sure full inference with or without prompting works on the AudioLM class
- complete full training code for soundstream, taking care of discriminator training
- add efficient gradient penalty for discriminators for soundstream
- wire up sample hz from sound dataset -> transformers, and have proper resampling within during training - think about whether to allow for dataset to have sound files of varying or enforce same sample hz
- full transformer training code for all three transformers
- refactor so semantic transformer has a wrapper that handles unique consecutives as well as wav to hubert or vq-wav2vec
- simply not self attend to eos token on the prompting side (semantic for coarse transformer, coarse for fine transformer)
- add structured dropout from forgetful causal masking, far better than traditional dropouts
- figure out how to suppress logging in fairseq
- assert that all three transformers passed into audiolm are compatible
- allow for specialized relative positional embeddings in fine transformer based on absolute matching positions of quantizers between coarse and fine
- allow for grouped residual vq in soundstream (use GroupedResidualVQ from vector-quantize-pytorch lib), from hifi-codec
- add flash attention with NoPE
- accept prime wave in AudioLM as a path to an audio file, and auto resample for semantic vs acoustic
- add key / value caching to all transformers, speeding up inference
- design a hierarchical coarse and fine transformer
- investigate spec decoding, first test in x-transformers, then port over if applicable
- redo the positional embeddings in the presence of groups in residual vq
- test with speech synthesis for starters
- cli tool, something like audiolm generate <wav.file | text> and save generated wav file to local directory
- return a list of waves in the case of variable-length audio
- just take care of the edge case in coarse transformer text conditioned training, where the raw wave is resampled at different frequencies. autodetermine how to route based on length
@inproceedings{Borsos2022AudioLMAL,
title = {AudioLM: a Language Modeling Approach to Audio Generation},
author = {Zal{\'a}n Borsos and Rapha{\"e}l Marinier and Damien Vincent and Eugene Kharitonov and Olivier Pietquin and Matthew Sharifi and Olivier Teboul and David Grangier and Marco Tagliasacchi and Neil Zeghidour},
year = {2022}
}
@misc{https://doi.org/10.48550/arxiv.2107.03312,
title = {SoundStream: An End-to-End Neural Audio Codec},
author = {Zeghidour, Neil and Luebs, Alejandro and Omran, Ahmed and Skoglund, Jan and Tagliasacchi, Marco},
publisher = {arXiv},
url = {https://arxiv.org/abs/2107.03312},
year = {2021}
}
@misc{shazeer2020glu,
title = {GLU Variants Improve Transformer},
author = {Noam Shazeer},
year = {2020},
url = {https://arxiv.org/abs/2002.05202}
}
@article{Shazeer2019FastTD,
title = {Fast Transformer Decoding: One Write-Head is All You Need},
author = {Noam M. Shazeer},
journal = {ArXiv},
year = {2019},
volume = {abs/1911.02150}
}
@article{Ho2022ClassifierFreeDG,
title = {Classifier-Free Diffusion Guidance},
author = {Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2207.12598}
}
@misc{crowson2022,
author = {Katherine Crowson},
url = {https://twitter.com/rivershavewings}
}
@misc{ding2021cogview,
title = {CogView: Mastering Text-to-Image Generation via Transformers},
author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
year = {2021},
eprint = {2105.13290},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@article{Liu2022FCMFC,
title = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
author = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.13432}
}
@inproceedings{anonymous2022normformer,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=GMYWzWztDx5},
note = {under review}
}
@misc{liu2021swin,
title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year = {2021},
eprint = {2111.09883},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@article{Li2021LocalViTBL,
title = {LocalViT: Bringing Locality to Vision Transformers},
author = {Yawei Li and K. Zhang and Jie Cao and Radu Timofte and Luc Van Gool},
journal = {ArXiv},
year = {2021},
volume = {abs/2104.05707}
}
@article{Defossez2022HighFN,
title = {High Fidelity Neural Audio Compression},
author = {Alexandre D{\'e}fossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.13438}
}
@article{Hu2017SqueezeandExcitationN,
title = {Squeeze-and-Excitation Networks},
author = {Jie Hu and Li Shen and Gang Sun},
journal = {2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year = {2017},
pages = {7132-7141}
}
@inproceedings{Yang2023HiFiCodecGV,
title = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
author = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
year = {2023}
}
@article{Kazemnejad2023TheIO,
title = {The Impact of Positional Encoding on Length Generalization in Transformers},
author = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.19466}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@misc{yu2023language,
title = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
author = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
year = {2023},
eprint = {2310.05737},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{Katsch2023GateLoopFD,
title = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
author = {Tobias Katsch},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:265018962}
}
@article{Fifty2024Restructuring,
title = {Restructuring Vector Quantization with the Rotation Trick},
author = {Christopher Fifty and Ronald G. Junkins and Dennis Duan and Aniketh Iyengar and Jerry W. Liu and Ehsan Amid and Sebastian Thrun and Christopher Ré},
journal = {ArXiv},
year = {2024},
volume = {abs/2410.06424},
url = {https://api.semanticscholar.org/CorpusID:273229218}
}
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}