Skip to content

Commit

Permalink
Revert "Mamba / FalconMamba: Fix mamba left padding (#32677)"
Browse files Browse the repository at this point in the history
This reverts commit 91b799b.
  • Loading branch information
ArthurZucker committed Aug 20, 2024
1 parent 91b799b commit c0e2699
Show file tree
Hide file tree
Showing 22 changed files with 258 additions and 121 deletions.
29 changes: 26 additions & 3 deletions src/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,10 +619,33 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids.contiguous()}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{
Expand Down
15 changes: 8 additions & 7 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,17 +1132,18 @@ def prepare_inputs_for_generation(

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
Expand Down
15 changes: 8 additions & 7 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,17 +1403,18 @@ def prepare_inputs_for_generation(

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
Expand Down
33 changes: 30 additions & 3 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,10 +1123,37 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{
Expand Down
15 changes: 8 additions & 7 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,17 +1143,18 @@ def prepare_inputs_for_generation(

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
Expand Down
19 changes: 4 additions & 15 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)

return causal_mask


Expand Down Expand Up @@ -301,7 +300,6 @@ def forward(
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
Expand Down Expand Up @@ -502,11 +500,9 @@ def forward(

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
Expand All @@ -517,7 +513,6 @@ def forward(
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
Expand Down Expand Up @@ -583,7 +578,6 @@ def forward(
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]

residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
Expand Down Expand Up @@ -996,7 +990,6 @@ def forward(
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
Expand Down Expand Up @@ -1063,7 +1056,6 @@ def prepare_inputs_for_generation(
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
Expand All @@ -1079,10 +1071,10 @@ def prepare_inputs_for_generation(

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if (
isinstance(past_key_values, HybridCache)
Expand All @@ -1093,12 +1085,10 @@ def prepare_inputs_for_generation(
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device

batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
Expand All @@ -1109,7 +1099,6 @@ def prepare_inputs_for_generation(
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{
"position_ids": position_ids,
Expand Down
29 changes: 26 additions & 3 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,10 +815,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{
Expand Down
30 changes: 27 additions & 3 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,10 +1054,34 @@ def prepare_inputs_for_generation(
attention_mask = input_ids.new_ones(input_shape)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.embed_out.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{
"attention_mask": attention_mask,
Expand Down
29 changes: 26 additions & 3 deletions src/transformers/models/gptj/modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,10 +964,33 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{
Expand Down
15 changes: 8 additions & 7 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,17 +1265,18 @@ def prepare_inputs_for_generation(

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
Expand Down
Loading

0 comments on commit c0e2699

Please sign in to comment.