diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 4ab0027072..a8a574c03c 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -603,10 +603,21 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): if wav_gen_prev is not None: wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] if wav_overlap is not None: - crossfade_wav = wav_chunk[:overlap_len] - crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) - wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) - wav_chunk[:overlap_len] += crossfade_wav + # cross fade the overlap section + if overlap_len > len(wav_chunk): + # wav_chunk is smaller than overlap_len, pass on last wav_gen + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):] + else: + # not expecting will hit here as problem happens on last chunk + wav_chunk = wav_gen[-overlap_len:] + return wav_chunk, wav_gen, None + else: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + wav_overlap = wav_gen[-overlap_len:] wav_gen_prev = wav_gen return wav_chunk, wav_gen_prev, wav_overlap