[Wav2Vec2FeatureExtractor] Fix extractor.pad() dtype backwards compatibility (huggingface#13693)

* Force dtype, add tests

* Local torch imports

* Remove unused logic (always ndarray)
anton-l authored and Narsil committed Sep 25, 2021
1 parent bd78130 commit e88712c
Showing 3 changed files with 29 additions and 17 deletions.
19 changes: 2 additions & 17 deletions src/transformers/feature_extraction_sequence_utils.py
@@ -187,23 +187,6 @@ def pad(
         padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)

         required_input = processed_features[self.model_input_names[0]]
-        if required_input and not isinstance(required_input[0], np.ndarray):
-            # truncation
-            processed_features = self._truncate(
-                processed_features,
-                max_length=max_length,
-                pad_to_multiple_of=pad_to_multiple_of,
-                truncation=truncation,
-            )
-            # padding
-            processed_features = self._pad(
-                processed_features,
-                max_length=max_length,
-                padding_strategy=padding_strategy,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-            )
-            return BatchFeature(processed_features, tensor_type=return_tensors)

         batch_size = len(required_input)
         if not all(len(v) == batch_size for v in processed_features.values()):
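For context, the branch deleted above was an early return for inputs whose values had not yet been converted to NumPy arrays; per the commit message ("Remove unused logic (always ndarray)") it never fired, because values are always ndarrays by the time pad() runs. A small standalone sketch of why the guard's condition is always False, assuming upstream conversion via np.asarray (the sample values here are illustrative):

import numpy as np

# values reaching pad() have already been converted to ndarrays,
# so the deleted guard's condition can never be True:
required_input = [np.asarray([0.1, 0.2]), np.asarray([0.3, 0.4, 0.5])]
assert not (required_input and not isinstance(required_input[0], np.ndarray))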
@@ -240,6 +223,8 @@ def pad(
             for key, value in outputs.items():
                 if key not in batch_outputs:
                     batch_outputs[key] = []
+                if value.dtype is np.dtype(np.float64):
+                    value = value.astype(np.float32)
                 batch_outputs[key].append(value)

         return BatchFeature(batch_outputs, tensor_type=return_tensors)
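The two added lines above are the core of the fix: NumPy promotes Python floats to double precision, so features that arrive as plain lists (or as float64 arrays) would otherwise batch as float64, while Transformers models expect float32 inputs. A minimal standalone sketch of the behavior, using plain NumPy outside the extractor:

import numpy as np

# NumPy defaults to double precision for Python floats
value = np.asarray([0.1, 0.2, 0.3])
assert value.dtype == np.float64

# pad() now applies exactly this cast before batching
if value.dtype is np.dtype(np.float64):
    value = value.astype(np.float32)
assert value.dtype == np.float32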
13 changes: 13 additions & 0 deletions tests/test_feature_extraction_speech_to_text.py
@@ -235,3 +235,16 @@ def test_cepstral_mean_and_variance_normalization_trunc_longest(self):

         # make sure that if max_length < longest -> then pad to max_length
         self.assertEqual(input_features.shape, (3, 6, 24))
+
+    def test_double_precision_pad(self):
+        import torch
+
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
+        py_speech_inputs = np_speech_inputs.tolist()
+
+        for inputs in [py_speech_inputs, np_speech_inputs]:
+            np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
+            self.assertTrue(np_processed.input_features.dtype == np.float32)
+            pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
+            self.assertTrue(pt_processed.input_features.dtype == torch.float32)
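The test deliberately feeds both a Python list and a float64 ndarray, since both entry points must come out as float32. Outside the test harness, the same check can be run against a released Speech2Text extractor; a sketch, assuming network access and the facebook/s2t-small-librispeech-asr checkpoint (the checkpoint choice and 80-bin feature shape are illustrative):

import numpy as np
from transformers import Speech2TextFeatureExtractor

extractor = Speech2TextFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr")

# double-precision log-mel features, e.g. produced by an external pipeline
features = np.random.rand(100, 80).astype(np.float64)

batch = extractor.pad([{"input_features": features}], return_tensors="np")
print(batch.input_features.dtype)  # float32 with this fix; float64 before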
14 changes: 14 additions & 0 deletions tests/test_feature_extraction_wav2vec2.py
@@ -196,6 +196,20 @@ def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
         # make sure that if max_length > longest -> then pad to longest
         self.assertTrue(input_values.shape == (3, 1200))

+    @require_torch
+    def test_double_precision_pad(self):
+        import torch
+
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        np_speech_inputs = np.random.rand(100).astype(np.float64)
+        py_speech_inputs = np_speech_inputs.tolist()
+
+        for inputs in [py_speech_inputs, np_speech_inputs]:
+            np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
+            self.assertTrue(np_processed.input_values.dtype == np.float32)
+            pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
+            self.assertTrue(pt_processed.input_values.dtype == torch.float32)
+
     @slow
     @require_torch
     def test_pretrained_checkpoints_are_set_correctly(self):
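And the Wav2Vec2 side end to end; a minimal sketch under the same caveats, assuming the facebook/wav2vec2-base-960h checkpoint is reachable:

import numpy as np
import torch
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

# one second of fake 16 kHz mono audio in double precision
speech = np.random.rand(16000).astype(np.float64)

batch = extractor.pad([{"input_values": speech}], return_tensors="pt")
assert batch.input_values.dtype == torch.float32  # float64 before this fix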
