Skip to content

Commit

Permalink
[Test] Fix W2V-Conformer integration test (huggingface#17303)
Browse files Browse the repository at this point in the history
* [Test] Fix W2V-Conformer integration test

* correct w2v2

* up
  • Loading branch information
patrickvonplaten authored May 17, 2022
1 parent 28a0811 commit 10704e1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
1 change: 0 additions & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,6 @@ def forward(
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
>>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1442,7 +1442,7 @@ def compute_contrastive_logits(

@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2-base->wav2vec2-conformer-rel-pos-large,wav2vec2->wav2vec2_conformer
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
def forward(
self,
input_values: Optional[torch.Tensor],
Expand Down Expand Up @@ -1470,14 +1470,9 @@ def forward(
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
... )
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained(
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
... )
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,10 @@ def _mock_init_weights(self, module):
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
module.pos_bias_u.data.fill_(3)
if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
module.pos_bias_v.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
Expand Down

0 comments on commit 10704e1

Please sign in to comment.