Implement MambaForSequenceClassification #31155

Open

Adibvafa wants to merge 66 commits into main (base branch).

Commits (changes shown from 61 of 66 commits)
76c191c
Added MambaForSequenceClassification to src/transformers
Adibvafa May 31, 2024
4a54410
Updated docs with MambaForSequenceClassification
Adibvafa May 31, 2024
242d10c
Added tests for MambaForSequenceClassification.
Adibvafa May 31, 2024
4d17612
Fixed style errors.
Adibvafa May 31, 2024
3a1160c
Fixed errors with dummy objects.
Adibvafa May 31, 2024
1ba4ed6
Fixed style issues.
Adibvafa May 31, 2024
cb7b9ed
Fixed style issues with ruff.
Adibvafa May 31, 2024
29c83f3
Fixed incorrect example in docstring.
Adibvafa May 31, 2024
d5b8b90
Removed cache_params from MambaSequenceClassifierOutput
Adibvafa Jun 5, 2024
18ef0b0
Merge remote-tracking branch 'upstream/main' into main
Adibvafa Jun 7, 2024
3b7f419
Fixed issues with the incompatible test with MambaMixer initialization.
Adibvafa Jun 7, 2024
bfdee9c
Added MambaForSequenceClassification to src/transformers
Adibvafa May 31, 2024
6a37706
Updated docs with MambaForSequenceClassification
Adibvafa May 31, 2024
96ad016
Added tests for MambaForSequenceClassification.
Adibvafa May 31, 2024
898ce93
Fixed style errors.
Adibvafa May 31, 2024
b617614
Fixed errors with dummy objects.
Adibvafa May 31, 2024
eba99b0
Fixed style issues.
Adibvafa May 31, 2024
fe4badd
Fixed style issues with ruff.
Adibvafa May 31, 2024
d80efe1
Fixed incorrect example in docstring.
Adibvafa May 31, 2024
bdf4936
Removed cache_params from MambaSequenceClassifierOutput
Adibvafa Jun 5, 2024
734c35e
Fixed issues with the incompatible test with MambaMixer initialization.
Adibvafa Jun 7, 2024
97fdcbd
Merge branch 'main' of github.com:Adibvafa/MambaForSequenceClassifica…
Adibvafa Jun 11, 2024
64e53d1
Merge remote-tracking branch 'upstream/main' into main
Adibvafa Jun 11, 2024
41ea4a1
Remove parameter cache.
Adibvafa Jun 11, 2024
25fb588
Add cache_params and use_cache for MambaForSequenceClassification.
Adibvafa Jun 11, 2024
6ea2703
Improve code style.
Adibvafa Jun 11, 2024
f3d5188
Merge branch 'main' of github.com:huggingface/transformers into main
Adibvafa Jun 17, 2024
f9a23fa
Fix merge conflicts with new main.
Adibvafa Jun 21, 2024
b3065e9
Improve modeling_mamba.py using PR reviews.
Adibvafa Jun 21, 2024
3fbb1eb
Add MambaForSequenceClassification to list of available Mamba models …
Adibvafa Jun 21, 2024
8b29642
Merge branch 'main' of github.com:huggingface/transformers into main
Adibvafa Jun 21, 2024
301252d
Merge branch 'huggingface:main' into main
Adibvafa Jun 25, 2024
4fa8b85
Merge branch 'huggingface:main' into main
Adibvafa Jun 26, 2024
eb62e3c
Merge branch 'huggingface:main' into main
Adibvafa Jun 27, 2024
da4d7bc
Merge branch 'huggingface:main' into main
Adibvafa Jun 29, 2024
71f25ba
Merge branch 'huggingface:main' into main
Adibvafa Jul 3, 2024
23795db
Merge branch 'huggingface:main' into main
Adibvafa Jul 4, 2024
7e88d92
Merge branch 'huggingface:main' into main
Adibvafa Jul 4, 2024
ab10a42
Merge branch 'huggingface:main' into main
Adibvafa Jul 6, 2024
eafabc7
Merge branch 'huggingface:main' into main
Adibvafa Jul 8, 2024
c98e867
Merge branch 'huggingface:main' into main
Adibvafa Jul 8, 2024
f550318
Update test_modeling_mamba.py
Adibvafa Jul 13, 2024
2ae7e03
Merge branch 'main' into main
Adibvafa Jul 13, 2024
87b554f
Merge branch 'huggingface:main' into main
Adibvafa Jul 15, 2024
da9dabe
Merge branch 'huggingface:main' into main
Adibvafa Jul 17, 2024
d55d328
Merge branch 'huggingface:main' into main
Adibvafa Jul 18, 2024
80f4167
Merge branch 'huggingface:main' into main
Adibvafa Jul 18, 2024
fcb0bda
Merge branch 'huggingface:main' into main
Adibvafa Jul 19, 2024
781d6ac
Merge branch 'huggingface:main' into main
Adibvafa Jul 20, 2024
70e23c4
Merge branch 'huggingface:main' into main
Adibvafa Jul 22, 2024
e7c4ceb
Change classification head to a linear layer.
Adibvafa Jul 27, 2024
c3e4318
Update mamba configuration by adding use_mambapy.
Adibvafa Jul 27, 2024
dec2856
Fix merge conflicts and remove classifier dropout
Adibvafa Jul 27, 2024
7772f51
Merge branch 'huggingface:main' into main
Adibvafa Jul 29, 2024
683ec48
Merge branch 'huggingface:main' into main
Adibvafa Jul 30, 2024
172a3f7
Merge branch 'huggingface:main' into main
Adibvafa Aug 1, 2024
3aded36
Merge branch 'huggingface:main' into main
Adibvafa Aug 2, 2024
cc8b300
Merge branch 'huggingface:main' into main
Adibvafa Aug 5, 2024
df0f337
Merge branch 'huggingface:main' into main
Adibvafa Aug 6, 2024
6343b46
Merge branch 'huggingface:main' into main
Adibvafa Aug 7, 2024
3e27597
Merge branch 'huggingface:main' into main
Adibvafa Aug 9, 2024
de9afdc
Merge branch 'huggingface:main' into main
Adibvafa Aug 14, 2024
1b35cdd
Add **kwargs to MambaForSequenceClassification.
Adibvafa Aug 14, 2024
c9223f6
Merge branch 'huggingface:main' into main
Adibvafa Aug 17, 2024
514f636
Merge branch 'huggingface:main' into main
Adibvafa Aug 18, 2024
259f925
Merge branch 'huggingface:main' into main
Adibvafa Aug 27, 2024
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/mamba.md
@@ -102,3 +102,8 @@ trainer.train()

[[autodoc]] MambaForCausalLM
- forward

## MambaForSequenceClassification

[[autodoc]] MambaForSequenceClassification
- forward
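A minimal usage sketch of the new class (the checkpoint name and num_labels below are assumptions for illustration; the classification head is newly initialized, so the logits are untrained):

```python
import torch
from transformers import AutoTokenizer, MambaForSequenceClassification

# Assumption: "state-spaces/mamba-130m-hf" is used as the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForSequenceClassification.from_pretrained("state-spaces/mamba-130m-hf", num_labels=2)

inputs = tokenizer("Mamba scales linearly with sequence length.", return_tensors="pt")
labels = torch.tensor([1])

# The head pools the last (non-padding) token and computes a cross-entropy loss.
outputs = model(input_ids=inputs["input_ids"], labels=labels)
print(outputs.loss.item(), outputs.logits.shape)  # scalar loss, logits of shape (1, 2)
```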
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2552,6 +2552,7 @@
_import_structure["models.mamba"].extend(
[
"MambaForCausalLM",
"MambaForSequenceClassification",
"MambaModel",
"MambaPreTrainedModel",
]
@@ -7069,6 +7070,7 @@
)
from .models.mamba import (
MambaForCausalLM,
MambaForSequenceClassification,
MambaModel,
MambaPreTrainedModel,
)
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -899,6 +899,7 @@
("llama", "LlamaForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("luke", "LukeForSequenceClassification"),
("mamba", "MambaForSequenceClassification"),
("markuplm", "MarkupLMForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("mega", "MegaForSequenceClassification"),
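With the mapping entry above, AutoModelForSequenceClassification (and, by extension, the text-classification pipeline) resolves a Mamba config to the new head. A rough sketch, assuming the state-spaces/mamba-130m-hf checkpoint (the head weights would be freshly initialized):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

# Assumption: any Mamba checkpoint/config works here; only the config is fetched.
config = AutoConfig.from_pretrained("state-spaces/mamba-130m-hf", num_labels=3)
model = AutoModelForSequenceClassification.from_config(config)
print(type(model).__name__)  # MambaForSequenceClassification
```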
2 changes: 2 additions & 0 deletions src/transformers/models/mamba/__init__.py
@@ -33,6 +33,7 @@
else:
_import_structure["modeling_mamba"] = [
"MambaForCausalLM",
"MambaForSequenceClassification",
"MambaModel",
"MambaPreTrainedModel",
]
@@ -49,6 +50,7 @@
else:
from .modeling_mamba import (
MambaForCausalLM,
MambaForSequenceClassification,
MambaModel,
MambaPreTrainedModel,
)
155 changes: 154 additions & 1 deletion src/transformers/models/mamba/modeling_mamba.py
@@ -21,7 +21,7 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import MambaCache
@@ -32,6 +32,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
from .configuration_mamba import MambaConfig
@@ -441,6 +442,31 @@ class MambaOutput(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MambaSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.

Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
cache_params (`MambaCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
cache_params: Optional[MambaCache] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class MambaCausalLMOutput(ModelOutput):
"""
@@ -769,3 +795,130 @@ def forward(
cache_params=mamba_outputs.cache_params,
hidden_states=mamba_outputs.hidden_states,
)


@add_start_docstrings(
"""
Mamba Model backbone with a sequence classification/regression head on top
(a linear layer on top of the pooled output) e.g. for GLUE tasks.

[`MambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.

Since it does classification on the last token, it requires knowing the position of the last token.
If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
MAMBA_START_DOCSTRING,
)
class MambaForSequenceClassification(MambaPreTrainedModel):
Collaborator comment:
Why not just use a "Copied from" here? : # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mamba, LLAMA->MAMBA, self.transformer->self.model, transformer_outputs->model_outputs

Adibvafa (Contributor Author), Aug 5, 2024:
I'm not sure I understand your comment. The forward methods of Mamba and Llama for sequence classification seem different. Could you please elaborate? 🤗

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.backbone = MambaModel(config)
self.classifier = nn.Linear(config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MambaSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MambaSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,

Reviewer comment:
I think there is a `**kwargs,` missing in the forward function (lines 842-843).

Adibvafa (Contributor Author):
Good catch! I suppose it won't be necessary (as opposed to MambaForCausalLM), but having it is good. I will add it in a commit now.

) -> Union[MambaSequenceClassifierOutput, Tuple[torch.FloatTensor]]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss.
Indices should be in `[0, ..., config.num_labels - 1]`.
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

mamba_outputs = self.backbone(
input_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)

last_hidden_states = mamba_outputs[0]

if input_ids is not None:
batch_size, _ = input_ids.shape[:2]
else:
batch_size, _ = inputs_embeds.shape[:2]

if self.config.pad_token_id is None and batch_size > 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(last_hidden_states.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_last_hidden_states = last_hidden_states[
torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
]
pooled_logits = self.classifier(pooled_last_hidden_states)

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype in [torch.long, torch.int]):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)

if not return_dict:
output = (pooled_logits,) + mamba_outputs[1:]
return ((loss,) + output) if loss is not None else output

return MambaSequenceClassifierOutput(
loss=loss,
logits=pooled_logits,
cache_params=mamba_outputs.cache_params,
hidden_states=mamba_outputs.hidden_states,
)
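The pooling above follows the same pattern as other causal-LM classification heads: the index of the last non-padding token is recovered with an argmax over the pad mask plus a modulo, rather than reverse indexing, so the graph stays ONNX-exportable. A standalone sketch of just that indexing, using made-up tensors:

```python
import torch

pad_token_id = 0
# Two right-padded sequences (pad_token_id = 0).
input_ids = torch.tensor([
    [5, 7, 9, 0, 0],   # last real token at index 2
    [4, 6, 2, 8, 3],   # no padding: last real token at index 4
])

# argmax finds the first pad position; subtracting 1 gives the last real token.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# A row with no padding yields argmax 0, hence -1; the modulo maps -1 to the
# final position without negative (reverse) indexing.
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])

hidden = torch.randn(2, 5, 8)  # stand-in for last_hidden_states: (batch, seq_len, hidden_size)
pooled = hidden[torch.arange(2), sequence_lengths]  # shape (2, 8), one vector per sequence
```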
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -5528,6 +5528,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class MambaForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class MambaModel(metaclass=DummyObject):
_backends = ["torch"]

29 changes: 26 additions & 3 deletions tests/models/mamba/test_modeling_mamba.py
@@ -35,6 +35,7 @@

from transformers import (
MambaForCausalLM,
MambaForSequenceClassification,
MambaModel,
)
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -64,6 +65,7 @@ def __init__(
num_choices=4,
scope=None,
tie_word_embeddings=True,
classifier_dropout=0.1,
):
self.parent = parent
self.batch_size = batch_size
@@ -86,6 +88,7 @@
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
self.tie_word_embeddings = tie_word_embeddings
self.classifier_dropout = classifier_dropout

def get_large_model_config(self):
return MambaConfig.from_pretrained("hf-internal-testing/mamba-2.8b")
@@ -135,6 +138,7 @@ def get_config(
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
tie_word_embeddings=self.tie_word_embeddings,
classifier_dropout=self.classifier_dropout,
)

def get_pipeline_config(self):
@@ -179,6 +183,14 @@ def create_and_check_causal_lm(self, config, input_ids, *args):
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_sequence_classification(self, config, input_ids, sequence_labels, *args):
config.num_labels = self.num_labels
model = MambaForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_state_equivalency(self, config, input_ids, *args):
model = MambaModel(config=config)
model.to(torch_device)
@@ -260,17 +272,24 @@ def prepare_config_and_inputs_for_common(self):
)
@require_torch
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
all_model_classes = (MambaModel, MambaForCausalLM, MambaForSequenceClassification) if is_torch_available() else ()
all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
has_attentions = False # Mamba does not support attentions
fx_compatible = False # FIXME let's try to support this @ArthurZucker
test_torchscript = False # FIXME let's try to support this @ArthurZucker
test_missing_keys = False
test_model_parallel = False
test_pruning = False
test_head_masking = False # Mamba does not have attention heads
test_model_parallel = False
test_mismatched_shapes = False # MambaMixer follows a different initialization
vasqu (Contributor), Sep 5, 2024:
This seems a bit weird to me 🤔 Disabling the test_mismatched_shapes flag shouldn't be needed, imo.

Could you add get_input_embeddings and set_input_embeddings methods for the ForSeqClassification class and see if it fixes those tests?

pipeline_model_mapping = (
{"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {}
{
"feature-extraction": MambaModel,
"text-generation": MambaForCausalLM,
"text-classification": MambaForSequenceClassification,
}
if is_torch_available()
else {}
)

def setUp(self):
Expand Down Expand Up @@ -337,6 +356,10 @@ def test_mamba_lm_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm(*config_and_inputs)

def test_mamba_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_sequence_classification(*config_and_inputs)

def test_state_equivalency(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_state_equivalency(*config_and_inputs)
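Following up on the review comment about the disabled test_mismatched_shapes flag: a minimal sketch of the suggested accessors (hypothetical, not part of this diff) that simply delegate to the backbone's embedding layer, which is what the generic ModelTesterMixin checks look for:

```python
# Hypothetical methods on MambaForSequenceClassification (sketch of the reviewer's
# suggestion, not code from this PR).
def get_input_embeddings(self):
    return self.backbone.get_input_embeddings()

def set_input_embeddings(self, new_embeddings):
    self.backbone.set_input_embeddings(new_embeddings)
```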