diff --git a/whisperx/decoding.py b/whisperx/decoding.py index ca608ca4..9d18cb6d 100644 --- a/whisperx/decoding.py +++ b/whisperx/decoding.py @@ -428,6 +428,13 @@ def apply(self, logits: Tensor, tokens: Tensor): if timestamps.numel() > 0: # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf + + # to force that timestamps are strictly increasing + if last_was_timestamp and not penultimate_was_timestamp: + timestamp_last = timestamps[-1] + else: + timestamp_last = timestamps[-1] + 1 + logits[k, self.tokenizer.timestamp_begin: timestamp_last] = -np.inf if tokens.shape[1] == self.sample_begin: # suppress generating non-timestamp tokens at the beginning