diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md
index 94eb2e2c2d528d..ddc873284c931b 100644
--- a/docs/source/en/model_doc/mamba.md
+++ b/docs/source/en/model_doc/mamba.md
@@ -102,3 +102,8 @@ trainer.train()
 
 [[autodoc]] MambaForCausalLM
     - forward
+
+## MambaForSequenceClassification
+
+[[autodoc]] MambaForSequenceClassification
+    - forward
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ecf7031086f4fa..2636448ef8cb93 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -2573,6 +2573,7 @@
     _import_structure["models.mamba"].extend(
         [
             "MambaForCausalLM",
+            "MambaForSequenceClassification",
             "MambaModel",
             "MambaPreTrainedModel",
         ]
@@ -7124,6 +7125,7 @@
     )
     from .models.mamba import (
         MambaForCausalLM,
+        MambaForSequenceClassification,
         MambaModel,
         MambaPreTrainedModel,
     )
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index bb47857bd0c690..b21949e2f27df7 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -906,6 +906,7 @@
         ("llama", "LlamaForSequenceClassification"),
         ("longformer", "LongformerForSequenceClassification"),
         ("luke", "LukeForSequenceClassification"),
+        ("mamba", "MambaForSequenceClassification"),
         ("markuplm", "MarkupLMForSequenceClassification"),
         ("mbart", "MBartForSequenceClassification"),
         ("mega", "MegaForSequenceClassification"),
diff --git a/src/transformers/models/mamba/__init__.py b/src/transformers/models/mamba/__init__.py
index 80cb8e1c68a21d..aeb330bbbe4864 100644
--- a/src/transformers/models/mamba/__init__.py
+++ b/src/transformers/models/mamba/__init__.py
@@ -33,6 +33,7 @@
 else:
     _import_structure["modeling_mamba"] = [
         "MambaForCausalLM",
+        "MambaForSequenceClassification",
         "MambaModel",
         "MambaPreTrainedModel",
     ]
@@ -49,6 +50,7 @@
     else:
         from .modeling_mamba import (
             MambaForCausalLM,
+            MambaForSequenceClassification,
             MambaModel,
             MambaPreTrainedModel,
         )
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 14a3dea1d1ccf8..d11e8786e649e0 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -21,7 +21,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...cache_utils import MambaCache
@@ -32,6 +32,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
+    replace_return_docstrings,
 )
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
 from .configuration_mamba import MambaConfig
@@ -459,6 +460,31 @@ class MambaOutput(ModelOutput):
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
 
 
+@dataclass
+class MambaSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        cache_params (`MambaCache`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    cache_params: Optional[MambaCache] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 @dataclass
 class MambaCausalLMOutput(ModelOutput):
     """
@@ -806,3 +832,131 @@ def forward(
             cache_params=mamba_outputs.cache_params,
             hidden_states=mamba_outputs.hidden_states,
         )
+
+
+@add_start_docstrings(
+    """
+    Mamba Model backbone with a sequence classification/regression head on top
+    (a linear layer on top of the pooled output) e.g. for GLUE tasks.
+
+    [`MambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it requires knowing the position of the last token.
+    If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
+    If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    MAMBA_START_DOCSTRING,
+)
+class MambaForSequenceClassification(MambaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.backbone = MambaModel(config)
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=MambaSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MambaSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[MambaCache] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[MambaSequenceClassifierOutput, Tuple[torch.FloatTensor]]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in `[0, ..., config.num_labels - 1]`.
+            If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + + last_hidden_states = mamba_outputs[0] + + if input_ids is not None: + batch_size, _ = input_ids.shape[:2] + else: + batch_size, _ = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None and batch_size > 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(last_hidden_states.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_last_hidden_states = last_hidden_states[ + torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths + ] + pooled_logits = self.classifier(pooled_last_hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype in [torch.long, torch.int]): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + + if not return_dict: + output = (pooled_logits,) + mamba_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MambaSequenceClassifierOutput( + loss=loss, + logits=pooled_logits, + cache_params=mamba_outputs.cache_params, + hidden_states=mamba_outputs.hidden_states, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index ce3f0045e3dfe1..ab1facba8ea2a2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5556,6 +5556,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MambaForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MambaModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 54d35917556f6d..4e44c559350782 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -35,6 +35,7 @@ from transformers import ( MambaForCausalLM, + MambaForSequenceClassification, MambaModel, ) from 
@@ -64,6 +65,7 @@ def __init__(
         num_choices=4,
         scope=None,
         tie_word_embeddings=True,
+        classifier_dropout=0.1,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -86,6 +88,7 @@ def __init__(
         self.eos_token_id = vocab_size - 1
         self.pad_token_id = vocab_size - 1
         self.tie_word_embeddings = tie_word_embeddings
+        self.classifier_dropout = classifier_dropout
 
     def get_large_model_config(self):
         return MambaConfig.from_pretrained("hf-internal-testing/mamba-2.8b")
@@ -136,6 +139,7 @@ def get_config(
             pad_token_id=self.pad_token_id,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embeddings=self.tie_word_embeddings,
+            classifier_dropout=self.classifier_dropout,
         )
 
     def get_pipeline_config(self):
@@ -182,6 +186,14 @@ def create_and_check_causal_lm(self, config, input_ids, *args):
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_sequence_classification(self, config, input_ids, sequence_labels, *args):
+        config.num_labels = self.num_labels
+        model = MambaForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, labels=sequence_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
     def create_and_check_state_equivalency(self, config, input_ids, *args):
         model = MambaModel(config=config)
         model.to(torch_device)
@@ -263,17 +275,24 @@ def prepare_config_and_inputs_for_common(self):
 )
 @require_torch
 class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
+    all_model_classes = (MambaModel, MambaForCausalLM, MambaForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
     has_attentions = False  # Mamba does not support attentions
     fx_compatible = False  # FIXME let's try to support this @ArthurZucker
     test_torchscript = False  # FIXME let's try to support this @ArthurZucker
     test_missing_keys = False
-    test_model_parallel = False
     test_pruning = False
     test_head_masking = False  # Mamba does not have attention heads
+    test_model_parallel = False
+    test_mismatched_shapes = False  # MambaMixer follows a different initialization
     pipeline_model_mapping = (
-        {"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {}
+        {
+            "feature-extraction": MambaModel,
+            "text-generation": MambaForCausalLM,
+            "text-classification": MambaForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
     )
 
     def setUp(self):
@@ -340,6 +359,10 @@ def test_mamba_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_causal_lm(*config_and_inputs)
 
+    def test_mamba_sequence_classification_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_sequence_classification(*config_and_inputs)
+
     def test_state_equivalency(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_state_equivalency(*config_and_inputs)
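For reviewers, a minimal usage sketch of the new head (not part of the patch). It assumes this branch is installed; `state-spaces/mamba-130m-hf` is used purely as an illustrative backbone checkpoint, which ships no classification head, so the classifier weights come out randomly initialized:

```python
# Usage sketch for the new MambaForSequenceClassification head (not part of the patch).
# Assumes this branch is installed; "state-spaces/mamba-130m-hf" is only an illustrative
# causal-LM checkpoint, so from_pretrained will warn about a freshly initialized classifier.
import torch

from transformers import AutoTokenizer, MambaForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
if tokenizer.pad_token is None:
    # the head needs a pad token to locate the last non-padding token in each row
    tokenizer.pad_token = tokenizer.eos_token

model = MambaForSequenceClassification.from_pretrained(
    "state-spaces/mamba-130m-hf",
    num_labels=2,
    pad_token_id=tokenizer.pad_token_id,
)

inputs = tokenizer(["I loved this film.", "Utterly boring."], padding=True, return_tensors="pt")
labels = torch.tensor([1, 0])

with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], labels=labels)

print(outputs.loss)          # scalar classification loss
print(outputs.logits.shape)  # (2, num_labels), pooled from the last non-padding token
```

Setting `pad_token_id` (here borrowed from the EOS token) matters because, as the class docstring says, the head pools the hidden state of the last non-padding token in each row, mirroring the GPT-2-style classification heads.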