Merge branch 'huggingface:main' into t5x_to_pytorch
bastings authored Dec 22, 2022
2 parents f04ea71 + 4d10ffd commit 0f279bb
Showing 1 changed file with 73 additions and 12 deletions.
85 changes: 73 additions & 12 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -272,6 +272,18 @@
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
input (see `past_key_values`). This is useful if you want more control over how to convert
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
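
A quick usage sketch of the new encoder-side argument (not part of the diff): the checkpoint name is illustrative, and it relies on the `get_encoder()` accessor added further down in this commit.

```python
import torch
from transformers import FSMTModel, FSMTTokenizer

# Illustrative checkpoint; any FSMT checkpoint should behave the same way.
mname = "facebook/wmt19-en-de"
tokenizer = FSMTTokenizer.from_pretrained(mname)
model = FSMTModel.from_pretrained(mname)

batch = tokenizer("Machine learning is great", return_tensors="pt")

# Build the encoder's embedded representation ourselves instead of passing input_ids.
# The model still applies its own embed_scale on top of what is passed in.
inputs_embeds = model.get_encoder().embed_tokens(batch["input_ids"])

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
outputs = model(
    inputs_embeds=inputs_embeds,
    attention_mask=batch["attention_mask"],
    decoder_input_ids=decoder_input_ids,
)
print(outputs.last_hidden_state.shape)  # (1, 1, d_model)
```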
@@ -470,6 +482,7 @@ def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: torch.Tensor = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
@@ -480,6 +493,8 @@ def forward(
input_ids (`torch.LongTensor`): tokens in the source language of shape
*(batch, src_len)*
attention_mask (`torch.LongTensor`): indicating which indices are padding tokens
inputs_embeds (`torch.FloatTensor`):
embedding vectors of shape *(batch, src_len, embed_dim)*
head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
@@ -499,8 +514,24 @@
if attention_mask is not None:
attention_mask = invert_mask(attention_mask)

inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_ids)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_ids)
elif inputs_embeds is not None:
inputs_embeds = inputs_embeds * self.embed_scale

# We assume zeros hidden states correspond to padding tokens
# and create `position_ids` where inputs_embeds[:, :, 0] == 0
position_ids = inputs_embeds[:, :, 0].masked_fill(
inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
)

embed_pos = self.embed_positions(position_ids)
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

x = inputs_embeds + embed_pos
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
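
The padding heuristic above can be seen in isolation with a small, self-contained sketch (shapes and `padding_idx` are made up for the demo):

```python
import torch

padding_idx = 1  # illustrative; FSMT reads the real value from its positional embedding
inputs_embeds = torch.randn(2, 5, 8)
inputs_embeds[0, 3:] = 0.0  # pretend the last two positions of sample 0 are padding

# A position counts as padding when the first feature of its embedding is exactly zero;
# such positions get padding_idx, so the sinusoidal position embedding treats them as padding.
first_feature = inputs_embeds[:, :, 0]
position_ids = first_feature.masked_fill(first_feature.eq(0), padding_idx)

print((position_ids == padding_idx)[0])  # tensor([False, False, False,  True,  True])
```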

@@ -675,6 +706,7 @@ def forward(
decoder_padding_mask: torch.Tensor,
decoder_causal_mask: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: bool = False,
@@ -717,15 +749,26 @@ def forward(
if encoder_padding_mask is not None:
encoder_padding_mask = invert_mask(encoder_padding_mask)

# embed positions
positions = self.embed_positions(input_ids) # , use_cache=use_cache)

if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
# assert input_ids.ne(self.padding_idx).any()
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
# embed positions
positions = self.embed_positions(input_ids)
if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
x = self.embed_tokens(input_ids) * self.embed_scale
elif inputs_embeds is not None:
# We assume zeros hidden states correspond to padding tokens
# and create `position_ids` where inputs_embeds[:, :, 0] == 0
position_ids = inputs_embeds[:, :, 0].masked_fill(
inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
)
positions = self.embed_positions(position_ids)
x = inputs_embeds * self.embed_scale
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

x = self.embed_tokens(input_ids) * self.embed_scale
x += positions
x = nn.functional.dropout(x, p=self.dropout, training=self.training)

@@ -1007,6 +1050,12 @@ def __init__(self, config: FSMTConfig):
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.encoder

def get_decoder(self):
return self.decoder

@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
@@ -1028,6 +1077,8 @@ def forward(
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
if decoder_input_ids is None:
@@ -1041,7 +1092,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# make masks if user doesn't supply
if not use_cache:
if not use_cache and input_ids is not None:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
self.config,
input_ids,
@@ -1052,12 +1103,14 @@ def forward(
else:
decoder_padding_mask, causal_mask = None, None

assert decoder_input_ids is not None
if decoder_input_ids is None and decoder_inputs_embeds is None:
raise ValueError("Make sure that `decoder_input_ids` or `decoder_inputs_embeds` are passed.")

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -1078,6 +1131,7 @@ def forward(
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
@@ -1148,6 +1202,8 @@ def forward(
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -1170,8 +1226,10 @@

outputs = self.model(
input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_inputs_embeds=decoder_inputs_embeds,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
@@ -1248,6 +1306,9 @@ def _reorder_cache(past, beam_idx):
def get_encoder(self):
return self.model.encoder

def get_decoder(self):
return self.model.decoder

def get_output_embeddings(self):
return self.model.decoder.embed_tokens
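
An end-to-end sketch exercising both new arguments through `FSMTForConditionalGeneration` (not part of the diff): it uses a tiny, randomly initialized config so it runs without downloading weights; all sizes are made up, and it relies on the `get_decoder()` accessor added in this commit.

```python
import torch
from transformers import FSMTConfig, FSMTForConditionalGeneration

config = FSMTConfig(
    src_vocab_size=128,
    tgt_vocab_size=128,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
)
model = FSMTForConditionalGeneration(config)

src_ids = torch.randint(3, config.src_vocab_size, (1, 7))
tgt_ids = torch.randint(3, config.tgt_vocab_size, (1, 5))

# Embed on both sides ourselves and skip input_ids / decoder_input_ids entirely.
inputs_embeds = model.get_encoder().embed_tokens(src_ids)
decoder_inputs_embeds = model.get_decoder().embed_tokens(tgt_ids)

outputs = model(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds)
print(outputs.logits.shape)  # (1, 5, tgt_vocab_size)
```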
