
Commit

only gradient sync on last gradient accum step for all 3 transformer training

lucidrains committed Dec 11, 2023
1 parent 01f0008 commit b5ef1b6
Showing 2 changed files with 31 additions and 18 deletions.
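The pattern applied to all three transformer trainers, sketched below as a minimal standalone example (the model, optimizer, loader, and grad_accum_every names are illustrative, not the repository's): during gradient accumulation under Accelerate/DDP, every micro-batch except the last is run under accelerator.no_sync(model), so the cross-process gradient all-reduce fires only once per optimizer step instead of once per backward.

# Minimal sketch (not the repository code): gradient accumulation where the DDP
# gradient all-reduce runs only on the last micro-batch. All names are illustrative.
from contextlib import nullcontext
from functools import partial

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size = 8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

grad_accum_every = 4
dl_iter = iter(dataloader)

for i in range(grad_accum_every):
    is_last = i == (grad_accum_every - 1)

    # suppress gradient sync for every micro-batch except the last one
    context = partial(accelerator.no_sync, model) if not is_last else nullcontext

    x, y = next(dl_iter)

    with accelerator.autocast(), context():
        loss = nn.functional.mse_loss(model(x), y)

    # scale so the accumulated gradient matches an average over the full effective batch
    accelerator.backward(loss / grad_accum_every)

optimizer.step()
optimizer.zero_grad()

Scaling each loss by 1 / grad_accum_every keeps the accumulated gradient equal to the gradient of the average loss over the whole accumulation window.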
47 changes: 30 additions & 17 deletions audiolm_pytorch/trainer.py
@@ -572,7 +572,7 @@ def train_step(self):
             wave, = next(self.dl_iter)
             wave = wave.to(device)
 
-            with context():
+            with self.accelerator.autocast(), context():
                 loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)
 
             self.accelerator.backward(loss / self.grad_accum_every)
@@ -610,7 +610,7 @@ def train_step(self):
             wave, = next(self.dl_iter)
             wave = wave.to(device)
 
-            with context():
+            with self.accelerator.autocast(), context():
                 discr_losses = self.soundstream(
                     wave,
                     apply_grad_penalty = apply_grad_penalty,
@@ -910,14 +910,18 @@ def train_step(self):
 
         logs = {}
 
-        # update vae (generator)
+        # update transformer
 
-        for _ in range(self.grad_accum_every):
+        for i in range(self.grad_accum_every):
+            is_last = i == (self.grad_accum_every - 1)
+            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext
+
             data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
 
-            loss = self.train_wrapper(**data_kwargs, return_loss = True)
+            with self.accelerator.autocast(), context():
+                loss = self.train_wrapper(**data_kwargs, return_loss = True)
 
-            self.accelerator.backward(loss / self.grad_accum_every)
+            self.accelerator.backward(loss / self.grad_accum_every)
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
@@ -1177,17 +1181,21 @@ def train_step(self):
 
         logs = {}
 
-        # update vae (generator)
+        # update transformer
 
-        for _ in range(self.grad_accum_every):
+        for i in range(self.grad_accum_every):
+            is_last = i == (self.grad_accum_every - 1)
+            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext
+
             data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter)))
 
-            loss = self.train_wrapper(
-                **data_kwargs,
-                return_loss = True
-            )
+            with self.accelerator.autocast(), context():
+                loss = self.train_wrapper(
+                    **data_kwargs,
+                    return_loss = True
+                )
 
-            self.accelerator.backward(loss / self.grad_accum_every)
+            self.accelerator.backward(loss / self.grad_accum_every)
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
@@ -1453,13 +1461,18 @@ def train_step(self):
 
         logs = {}
 
-        # update vae (generator)
+        # update transformer
 
-        for _ in range(self.grad_accum_every):
+        for i in range(self.grad_accum_every):
+            is_last = i == (self.grad_accum_every - 1)
+            context = partial(self.accelerator.no_sync, self.train_wrapper) if not is_last else nullcontext
+
             data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
-            loss = self.train_wrapper(**data_kwargs, return_loss = True)
-
-            self.accelerator.backward(loss / self.grad_accum_every)
+
+            with self.accelerator.autocast(), context():
+                loss = self.train_wrapper(**data_kwargs, return_loss = True)
+
+            self.accelerator.backward(loss / self.grad_accum_every)
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
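Besides the sync change, the diff also wraps each forward pass in accelerator.autocast(). Below is a hedged sketch of that piece in isolation, with illustrative names; the mixed-precision mode is an assumption here and in practice is whatever the Accelerator was constructed with.

# Minimal sketch (illustrative names): run the forward pass under accelerator.autocast()
# so it follows the Accelerator's configured mixed precision, while accelerator.backward
# takes care of any loss scaling.
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()  # e.g. Accelerator(mixed_precision = 'fp16') on a GPU

model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(8, 16, device = accelerator.device)
y = torch.randn(8, 1, device = accelerator.device)

with accelerator.autocast():
    loss = nn.functional.mse_loss(model(x), y)

accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()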
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.9.0'
+__version__ = '1.9.1'

0 comments on commit b5ef1b6
