diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index d66a5923868dc5..6bf52d3e1b1bc9 100644 --- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -23,6 +23,7 @@ import datasets import numpy as np from datasets import load_dataset +from packaging import version from transformers import AutoProcessor from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor @@ -435,21 +436,19 @@ def test_word_time_stamp_integration(self): self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text) # output times - start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")] - end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")] + start_times = torch.tensor(self.get_from_offsets(word_time_stamps, "start_time")) + end_times = torch.tensor(self.get_from_offsets(word_time_stamps, "end_time")) # fmt: off - self.assertListEqual( - start_times, - [ - 1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66, - ], - ) - - self.assertListEqual( - end_times, - [ - 1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94, - ], - ) + expected_start_tensor = torch.tensor([1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66]) + + # TODO(Patrick): This if-else version statement should be removed once + # https://github.com/huggingface/datasets/issues/4889 is resolved + if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12.0"): + expected_end_tensor = torch.tensor([1.54, 1.88, 2.14, 2.46, 2.9, 3.16, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94]) + else: + expected_end_tensor = torch.tensor([1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94]) # fmt: on + + self.assertTrue(torch.allclose(start_times, expected_start_tensor, atol=0.01)) + self.assertTrue(torch.allclose(end_times, expected_end_tensor, atol=0.01))