diff --git a/audiolm_pytorch/trainer.py b/audiolm_pytorch/trainer.py index 148ca60..be22f18 100644 --- a/audiolm_pytorch/trainer.py +++ b/audiolm_pytorch/trainer.py @@ -572,7 +572,7 @@ def train_step(self): wave, = next(self.dl_iter) wave = wave.to(device) - with context(): + with self.accelerator.autocast(), context(): loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True) self.accelerator.backward(loss / self.grad_accum_every) @@ -610,7 +610,7 @@ def train_step(self): wave, = next(self.dl_iter) wave = wave.to(device) - with context(): + with self.accelerator.autocast(), context(): discr_losses = self.soundstream( wave, apply_grad_penalty = apply_grad_penalty, @@ -910,14 +910,18 @@ def train_step(self): logs = {} - # update vae (generator) + # update transformer + + for i in range(self.grad_accum_every): + is_last = i == (self.grad_accum_every - 1) + context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext - for _ in range(self.grad_accum_every): data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter)) - loss = self.train_wrapper(**data_kwargs, return_loss = True) + with self.accelerator.autocast(), context(): + loss = self.train_wrapper(**data_kwargs, return_loss = True) - self.accelerator.backward(loss / self.grad_accum_every) + self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) @@ -1177,17 +1181,21 @@ def train_step(self): logs = {} - # update vae (generator) + # update transformer + + for i in range(self.grad_accum_every): + is_last = i == (self.grad_accum_every - 1) + context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext - for _ in range(self.grad_accum_every): data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter))) - loss = self.train_wrapper( - **data_kwargs, - return_loss = True - ) + with self.accelerator.autocast(), context(): + loss = self.train_wrapper( + **data_kwargs, + return_loss = True + ) - self.accelerator.backward(loss / self.grad_accum_every) + self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) @@ -1453,13 +1461,18 @@ def train_step(self): logs = {} - # update vae (generator) + # update transformer + + for i in range(self.grad_accum_every): + is_last = i == (self.grad_accum_every - 1) + context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext - for _ in range(self.grad_accum_every): data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter)) - loss = self.train_wrapper(**data_kwargs, return_loss = True) - self.accelerator.backward(loss / self.grad_accum_every) + with self.accelerator.autocast(), context(): + loss = self.train_wrapper(**data_kwargs, return_loss = True) + + self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py index e5102d3..35424e8 100644 --- a/audiolm_pytorch/version.py +++ b/audiolm_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.9.0' +__version__ = '1.9.1'