diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index df64429d8f..b5da229a69 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -342,7 +342,7 @@ def get_data_loader( loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, - shuffle=False, # shuffle is done in the dataset. + shuffle=True, # if there is no other sampler collate_fn=dataset.collate_fn, drop_last=False, # setting this False might cause issues in AMP training. sampler=sampler, diff --git a/TTS/utils/capacitron_optimizer.py b/TTS/utils/capacitron_optimizer.py index fac7d8a06d..7206ffd508 100644 --- a/TTS/utils/capacitron_optimizer.py +++ b/TTS/utils/capacitron_optimizer.py @@ -38,9 +38,9 @@ def step(self): self.param_groups = self.primary_optimizer.param_groups self.primary_optimizer.step() - def zero_grad(self): - self.primary_optimizer.zero_grad() - self.secondary_optimizer.zero_grad() + def zero_grad(self, set_to_none=False): + self.primary_optimizer.zero_grad(set_to_none) + self.secondary_optimizer.zero_grad(set_to_none) def load_state_dict(self, state_dict): self.primary_optimizer.load_state_dict(state_dict[0])