Skip to content

Commit

Permalink
Fix Capacitron training (#2086)
Browse files Browse the repository at this point in the history
  • Loading branch information
victor-shepardson authored Nov 1, 2022
1 parent 5ccef6e commit 5307a22
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion TTS/tts/models/base_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def get_data_loader(
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False, # shuffle is done in the dataset.
shuffle=True, # if there is no other sampler
collate_fn=dataset.collate_fn,
drop_last=False, # setting this False might cause issues in AMP training.
sampler=sampler,
Expand Down
6 changes: 3 additions & 3 deletions TTS/utils/capacitron_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def step(self):
self.param_groups = self.primary_optimizer.param_groups
self.primary_optimizer.step()

def zero_grad(self):
self.primary_optimizer.zero_grad()
self.secondary_optimizer.zero_grad()
def zero_grad(self, set_to_none=False):
    """Reset the gradients of both wrapped optimizers.

    Args:
        set_to_none: Forwarded to each optimizer's ``zero_grad``; when True,
            gradients are set to ``None`` instead of being zeroed in place.
    """
    # Primary first, then secondary — same order as the other delegating methods.
    for optimizer in (self.primary_optimizer, self.secondary_optimizer):
        optimizer.zero_grad(set_to_none)

def load_state_dict(self, state_dict):
self.primary_optimizer.load_state_dict(state_dict[0])
Expand Down

0 comments on commit 5307a22

Please sign in to comment.