From 4e33582bf7571e5e495f6b23ebb5d5d4335a9c3c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Aug 2024 08:27:11 +0000 Subject: [PATCH 1/9] fix mamba left padding --- .../falcon_mamba/modeling_falcon_mamba.py | 48 ++++++++++++++++--- .../models/mamba/modeling_mamba.py | 48 ++++++++++++++++--- .../test_modeling_falcon_mamba.py | 29 +++++++++-- tests/models/mamba/test_modeling_mamba.py | 9 ++-- 4 files changed, 115 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 4bcd0e9d467d12..9af66d87609cf0 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -155,6 +155,7 @@ def cuda_kernels_forward( hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -179,6 +180,9 @@ def cuda_kernels_forward( else: hidden_states, gate = projected_states.chunk(2, dim=1) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if cache_params is not None and cache_position[0] > 0: @@ -200,6 +204,9 @@ def cuda_kernels_forward( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) @@ -259,6 +266,7 @@ def slow_forward( input_states, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype @@ -266,6 +274,9 @@ def slow_forward( projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] hidden_states, gate = projected_states.chunk(2, dim=1) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 2. Convolution sequence transformation if cache_params is not None: ssm_state = cache_params.ssm_states[self.layer_idx].clone() @@ -294,6 +305,9 @@ def slow_forward( ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 3. State Space Model sequence transformation # 3.a. 
Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) @@ -355,10 +369,11 @@ def forward( hidden_states, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): - return self.cuda_kernels_forward(hidden_states, cache_params, cache_position) - return self.slow_forward(hidden_states, cache_params, cache_position) + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) # Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba @@ -396,13 +411,16 @@ def forward( hidden_states, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position) + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) hidden_states = residual + hidden_states return hidden_states @@ -649,10 +667,15 @@ def forward( for mixer_block in self.layers: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask ) else: - hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -721,6 +744,7 @@ def prepare_inputs_for_generation( use_cache=None, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, **kwargs, ): if use_cache: @@ -733,6 +757,10 @@ def prepare_inputs_for_generation( ) if cache_position[0] > 0: input_ids = input_ids[:, -1].unsqueeze(-1) + + if attention_mask is not None: + attention_mask = attention_mask[:, -1].unsqueeze(-1) + else: # we initialize the `cache_position` to full size of `conv_states` at prefill stage # considering padding will be applied when input length is shorter, and truncation @@ -745,11 +773,17 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} + # In case cache is not used, manually add a new column in the attention mask + if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: + padd_length = input_ids.shape[-1] - attention_mask.shape[-1] + attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :padd_length])], dim=-1) + model_inputs.update( { "cache_params": cache_params, "use_cache": use_cache, "cache_position": cache_position, + "attention_mask": attention_mask, } ) return model_inputs @@ -760,11 +794,10 @@ def prepare_inputs_for_generation( 
output_type=FalconMambaCausalLMOutput, config_class=_CONFIG_FOR_DOC, ) - # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, # Ignored copy + attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_params: Optional[MambaCache] = None, labels: Optional[torch.LongTensor] = None, @@ -790,6 +823,7 @@ def forward( return_dict=return_dict, use_cache=use_cache, cache_position=cache_position, + attention_mask=attention_mask, ) hidden_states = falcon_mamba_outputs[0] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 23ab3ab142d075..798b484979d70e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -136,6 +136,7 @@ def cuda_kernels_forward( hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -160,6 +161,9 @@ def cuda_kernels_forward( else: hidden_states, gate = projected_states.chunk(2, dim=1) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if cache_params is not None and cache_position[0] > 0: @@ -181,6 +185,9 @@ def cuda_kernels_forward( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) @@ -226,13 +233,16 @@ def cuda_kernels_forward( return contextualized_states # fmt: off - def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None): + def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] hidden_states, gate = projected_states.chunk(2, dim=1) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 2. Convolution sequence transformation if cache_params is not None: ssm_state = cache_params.ssm_states[self.layer_idx].clone() @@ -261,6 +271,9 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + # 3. State Space Model sequence transformation # 3.a. 
Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) @@ -306,10 +319,11 @@ def forward( hidden_states, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): - return self.cuda_kernels_forward(hidden_states, cache_params, cache_position) - return self.slow_forward(hidden_states, cache_params, cache_position) + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) class MambaRMSNorm(nn.Module): @@ -346,13 +360,16 @@ def forward( hidden_states, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, ): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position) + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) hidden_states = residual + hidden_states return hidden_states @@ -557,6 +574,7 @@ def set_input_embeddings(self, new_embeddings): def forward( self, input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, cache_params: Optional[MambaCache] = None, use_cache: Optional[bool] = None, @@ -605,10 +623,15 @@ def forward( for mixer_block in self.layers: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask ) else: - hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -677,6 +700,7 @@ def prepare_inputs_for_generation( use_cache=None, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, **kwargs, ): if use_cache: @@ -689,6 +713,10 @@ def prepare_inputs_for_generation( ) if cache_position[0] > 0: input_ids = input_ids[:, -1].unsqueeze(-1) + + if attention_mask is not None: + attention_mask = attention_mask[:, -1].unsqueeze(-1) + else: # we initialize the `cache_position` to full size of `conv_states` at prefill stage # considering padding will be applied when input length is shorter, and truncation @@ -701,11 +729,17 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} + # In case cache is not used, manually add a new column in the attention mask + if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: + padd_length = input_ids.shape[-1] - attention_mask.shape[-1] + attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :padd_length])], dim=-1) + 
model_inputs.update( { "cache_params": cache_params, "use_cache": use_cache, "cache_position": cache_position, + "attention_mask": attention_mask, } ) return model_inputs @@ -719,6 +753,7 @@ def prepare_inputs_for_generation( def forward( self, input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_params: Optional[MambaCache] = None, labels: Optional[torch.LongTensor] = None, @@ -744,6 +779,7 @@ def forward( return_dict=return_dict, use_cache=use_cache, cache_position=cache_position, + attention_mask=attention_mask, ) hidden_states = mamba_outputs[0] diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 8e7c456e4a383b..6bccf6d645033e 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -101,6 +101,7 @@ def prepare_config_and_inputs( self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False ): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = ids_tensor([self.batch_size, self.seq_length], 1) sequence_labels = None token_labels = None @@ -119,7 +120,7 @@ def prepare_config_and_inputs( return ( config, input_ids, - None, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -153,6 +154,7 @@ def prepare_config_and_inputs_for_decoder(self): ( config, input_ids, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -161,6 +163,7 @@ def prepare_config_and_inputs_for_decoder(self): return ( config, input_ids, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -253,12 +256,12 @@ def prepare_config_and_inputs_for_common(self): ( config, input_ids, - _, + attention_mask, sequence_labels, token_labels, choice_labels, ) = self.prepare_config_and_inputs() - inputs_dict = {"input_ids": input_ids} + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} return config, inputs_dict @@ -491,3 +494,23 @@ def test_generation_torch_compile(self): self.tokenizer.batch_decode(out, skip_special_tokens=False)[0], "Hello today I am going to show you how to make a simple and easy to make paper plane.\nStep", ) + + def test_batched_generation(self): + model_id = "tiiuae/falcon-mamba-7b" + tok = AutoTokenizer.from_pretrained(model_id) + tok.pad_token_id = tok.eos_token_id + + texts = ["Hello today", "Hello my name is Younes and today"] + + EXPECTED_OUTPUT = [ + "Hello today I'm going to show you how to make a 3D model of a house.\n", + "Hello my name is Younes and today I will be talking about the topic of “The importance of the internet in our life”.\n", + ] + + inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device) + model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16) + + out = model.generate(**inputs, max_new_tokens=20) + out = tok.batch_decode(out, skip_special_tokens=True) + + self.assertListEqual(out, EXPECTED_OUTPUT) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index cd800da9765169..a27553247013ea 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -94,6 +94,7 @@ def prepare_config_and_inputs( self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False ): input_ids = 
ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = ids_tensor([self.batch_size, self.seq_length], 1) sequence_labels = None token_labels = None @@ -112,7 +113,7 @@ def prepare_config_and_inputs( return ( config, input_ids, - None, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -146,6 +147,7 @@ def prepare_config_and_inputs_for_decoder(self): ( config, input_ids, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -154,6 +156,7 @@ def prepare_config_and_inputs_for_decoder(self): return ( config, input_ids, + attention_mask, sequence_labels, token_labels, choice_labels, @@ -246,12 +249,12 @@ def prepare_config_and_inputs_for_common(self): ( config, input_ids, - _, + attention_mask, sequence_labels, token_labels, choice_labels, ) = self.prepare_config_and_inputs() - inputs_dict = {"input_ids": input_ids} + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} return config, inputs_dict From 39dab5a7cf7a450775f50d27db5740a723dacb45 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:38:00 +0400 Subject: [PATCH 2/9] Apply suggestions from code review Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> --- src/transformers/models/falcon_mamba/modeling_falcon_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 9af66d87609cf0..c4e7d0d23ecf6d 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -775,8 +775,8 @@ def prepare_inputs_for_generation( # In case cache is not used, manually add a new column in the attention mask if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: - padd_length = input_ids.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :padd_length])], dim=-1) + pad_length = input_ids.shape[-1] - attention_mask.shape[-1] + attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1) model_inputs.update( { From 97e9dfd8618eb36768bffdb400986bca94825218 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Aug 2024 09:38:19 +0000 Subject: [PATCH 3/9] fix copies --- src/transformers/models/mamba/modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 798b484979d70e..4a0a63691ba5c5 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -731,8 +731,8 @@ def prepare_inputs_for_generation( # In case cache is not used, manually add a new column in the attention mask if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: - padd_length = input_ids.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :padd_length])], dim=-1) + pad_length = input_ids.shape[-1] - attention_mask.shape[-1] + attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1) model_inputs.update( { From 8978e00919cec6f72c874a9762d25c292ea08d80 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Aug 2024 10:11:51 +0000 Subject: 
[PATCH 4/9] test with `inputs_embeds` --- .../models/falcon_mamba/test_modeling_falcon_mamba.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index 6bccf6d645033e..d75014f370d29f 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -514,3 +514,13 @@ def test_batched_generation(self): out = tok.batch_decode(out, skip_special_tokens=True) self.assertListEqual(out, EXPECTED_OUTPUT) + + # We test the same generations with inputs_embeds + with torch.no_grad(): + inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids")) + + inputs["inputs_embeds"] = inputs_embeds + out = model.generate(**inputs, max_new_tokens=20) + out = tok.batch_decode(out, skip_special_tokens=True) + + self.assertListEqual(out, EXPECTED_OUTPUT) From fcc05af7de77449e73f72ddc6651a06c4360230a Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 16 Aug 2024 12:12:37 +0400 Subject: [PATCH 5/9] Update src/transformers/models/falcon_mamba/modeling_falcon_mamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/falcon_mamba/modeling_falcon_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index c4e7d0d23ecf6d..71366e03d903ec 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -759,7 +759,7 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, -1].unsqueeze(-1) if attention_mask is not None: - attention_mask = attention_mask[:, -1].unsqueeze(-1) + attention_mask = None else: # we initialize the `cache_position` to full size of `conv_states` at prefill stage From fe2725eaab86ba11b074ff88872620bd75791dad Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Aug 2024 08:13:02 +0000 Subject: [PATCH 6/9] copies --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 4a0a63691ba5c5..b43902c4c595b1 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -715,7 +715,7 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, -1].unsqueeze(-1) if attention_mask is not None: - attention_mask = attention_mask[:, -1].unsqueeze(-1) + attention_mask = None else: # we initialize the `cache_position` to full size of `conv_states` at prefill stage From 6ea01ee7ab2b96c0c975d0390096d5ecc1dba506 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Aug 2024 08:22:31 +0000 Subject: [PATCH 7/9] clarify --- .../models/falcon_mamba/modeling_falcon_mamba.py | 6 +++--- src/transformers/models/mamba/modeling_mamba.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 71366e03d903ec..4d6c1a9c60cba7 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -773,10 +773,10 @@ def prepare_inputs_for_generation( else:
model_inputs = {"input_ids": input_ids.contiguous()} - # In case cache is not used, manually add a new column in the attention mask + # In case cache is not used, manually update the attention mask if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: - pad_length = input_ids.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1) + past_length = input_ids.shape[-1] - attention_mask.shape[-1] + attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1) model_inputs.update( { diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index b43902c4c595b1..ce462d5dea1534 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -729,10 +729,10 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} - # In case cache is not used, manually add a new column in the attention mask + # In case cache is not used, manually update the attention mask if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: - pad_length = input_ids.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1) + past_length = input_ids.shape[-1] - attention_mask.shape[-1] + attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1) model_inputs.update( { From 9a09c899c45a6e9493509d0c2057b671df6df0c2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Aug 2024 17:26:39 +0000 Subject: [PATCH 8/9] fix last comments --- .../models/falcon_mamba/modeling_falcon_mamba.py | 15 ++++++++------- src/transformers/models/mamba/modeling_mamba.py | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 0b4aa6124d26ef..2f0bcef54176d1 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -619,14 +619,13 @@ def set_input_embeddings(self, new_embeddings): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, # Ignored arg inputs_embeds: Optional[torch.LongTensor] = None, cache_params: Optional[MambaCache] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it + attention_mask: Optional[torch.LongTensor] = None, # Ignored arg ) -> Union[Tuple, FalconMambaOutput]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -735,6 +734,13 @@ def _update_model_kwargs_for_generation( and model_kwargs["cache_position"] is not None ): model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + return model_kwargs def prepare_inputs_for_generation( @@ -773,11 +779,6 @@ def 
prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} - # In case cache is not used, manually update the attention mask - if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: - past_length = input_ids.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1) - model_inputs.update( { "cache_params": cache_params, diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index ce462d5dea1534..1d131bfa02132f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -574,14 +574,13 @@ def set_input_embeddings(self, new_embeddings): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, cache_params: Optional[MambaCache] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it + attention_mask: Optional[torch.LongTensor] = None, # Ignored arg ) -> Union[Tuple, MambaOutput]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -691,6 +690,12 @@ def _update_model_kwargs_for_generation( ): model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + return model_kwargs def prepare_inputs_for_generation( @@ -729,11 +734,6 @@ def prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids.contiguous()} - # In case cache is not used, manually update the attention mask - if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape: - past_length = input_ids.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1) - model_inputs.update( { "cache_params": cache_params, From 6c16fc09a7f1e814c2ec5a79f4d249c1e37b2415 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Aug 2024 17:27:54 +0000 Subject: [PATCH 9/9] remove --- src/transformers/models/falcon_mamba/modeling_falcon_mamba.py | 2 +- src/transformers/models/mamba/modeling_mamba.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 2f0bcef54176d1..07374fe1dfd7b5 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -625,7 +625,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, # Ignored arg + attention_mask: Optional[torch.LongTensor] = None, ) -> Union[Tuple, FalconMambaOutput]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git 
a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 1d131bfa02132f..14a3dea1d1ccf8 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -580,7 +580,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, # Ignored arg + attention_mask: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MambaOutput]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
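
Taken together, this series threads `attention_mask` from generation down to every mixer block and zeroes the padded positions of the hidden states before the convolution and before the SSM scan, so left-padded batches no longer leak pad tokens into the recurrent state. Below is a minimal usage sketch mirroring the `test_batched_generation` test added in this series: the checkpoint id, prompts, and decoding settings come from that test, while the device placement and the explicit `padding_side` line are illustrative assumptions rather than part of the patches.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"  # checkpoint used by the new test
tok = AutoTokenizer.from_pretrained(model_id)
tok.pad_token_id = tok.eos_token_id  # the test reuses EOS as the padding token
tok.padding_side = "left"            # assumption: make the left-padding case explicit

texts = ["Hello today", "Hello my name is Younes and today"]
inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to("cuda")

model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16)

# With the fix, the attention mask reaches the mamba mixer, so the shorter
# prompt's pad positions are masked out of the conv and SSM states during
# prefill; for cached decoding steps the mask is dropped, since only one real
# token is fed per step.
out = model.generate(**inputs, max_new_tokens=20)
print(tok.batch_decode(out, skip_special_tokens=True))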