Implement MambaForSequenceClassification #31155
base: main
Changes from 61 commits
src/transformers/models/mamba/modeling_mamba.py
@@ -21,7 +21,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...cache_utils import MambaCache

@@ -32,6 +32,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
+    replace_return_docstrings,
 )
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
 from .configuration_mamba import MambaConfig

@@ -441,6 +442,31 @@ class MambaOutput(ModelOutput):
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None


+@dataclass
+class MambaSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        cache_params (`MambaCache`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    cache_params: Optional[MambaCache] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 @dataclass
 class MambaCausalLMOutput(ModelOutput):
     """
@@ -769,3 +795,130 @@ def forward(
             cache_params=mamba_outputs.cache_params,
             hidden_states=mamba_outputs.hidden_states,
         )
+
+
+@add_start_docstrings(
+    """
+    Mamba Model backbone with a sequence classification/regression head on top
+    (a linear layer on top of the pooled output) e.g. for GLUE tasks.
+
+    [`MambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token.
+    If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
+    If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    MAMBA_START_DOCSTRING,
+)
+class MambaForSequenceClassification(MambaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.backbone = MambaModel(config)
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=MambaSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MambaSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[MambaCache] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+    ) -> Union[MambaSequenceClassifierOutput, Tuple[torch.FloatTensor]]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in `[0, ..., config.num_labels - 1]`.
+            If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        mamba_outputs = self.backbone(
+            input_ids,
+            cache_params=cache_params,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            use_cache=use_cache,
+        )
+
+        last_hidden_states = mamba_outputs[0]
+
+        if input_ids is not None:
+            batch_size, _ = input_ids.shape[:2]
+        else:
+            batch_size, _ = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None and batch_size > 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(last_hidden_states.device)
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_last_hidden_states = last_hidden_states[
+            torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
+        ]
+        pooled_logits = self.classifier(pooled_last_hidden_states)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype in [torch.long, torch.int]):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + mamba_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MambaSequenceClassifierOutput(
+            loss=loss,
+            logits=pooled_logits,
+            cache_params=mamba_outputs.cache_params,
+            hidden_states=mamba_outputs.hidden_states,
+        )
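The head itself is just `nn.Linear(config.hidden_size, num_labels, bias=False)` applied to one hidden state per sequence; the interesting part is how that state is selected. The standalone sketch below (illustrative toy values, not part of the PR) shows the same argmax-over-pad-mask plus modulo indexing used in the forward above: the first pad position minus one gives the last real token, and the modulo wraps the "no pad token found" case around to the final position without negative indexing (the ONNX-friendly trick noted in the code comment).

```python
# Toy illustration of the last-non-padding-token pooling used by the new classification head.
import torch

batch_size, seq_len, hidden_size, pad_token_id = 2, 6, 4, 0

input_ids = torch.tensor(
    [
        [5, 3, 8, 2, 0, 0],  # padded row: last real token sits at index 3
        [7, 1, 4, 9, 6, 2],  # unpadded row: no pad token anywhere
    ]
)
last_hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# First pad position per row (argmax returns 0 when no pad is found), step back one,
# then wrap with modulo so the "no pad" row resolves to the last index instead of -1.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths.tolist())  # [3, 5]

pooled = last_hidden_states[torch.arange(batch_size), sequence_lengths]
print(pooled.shape)  # torch.Size([2, 4]) -> one hidden state per sequence
```

Once pooled, the logits go through the usual `problem_type` dispatch: MSE for regression, cross-entropy for single-label classification, and BCE-with-logits for multi-label classification, exactly as in the forward above.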
tests/models/mamba/test_modeling_mamba.py
@@ -35,6 +35,7 @@

     from transformers import (
         MambaForCausalLM,
+        MambaForSequenceClassification,
         MambaModel,
     )
     from transformers.models.mamba.modeling_mamba import MambaCache

@@ -64,6 +65,7 @@ def __init__(
         num_choices=4,
         scope=None,
         tie_word_embeddings=True,
+        classifier_dropout=0.1,
     ):
         self.parent = parent
         self.batch_size = batch_size

@@ -86,6 +88,7 @@ def __init__(
         self.eos_token_id = vocab_size - 1
         self.pad_token_id = vocab_size - 1
         self.tie_word_embeddings = tie_word_embeddings
+        self.classifier_dropout = classifier_dropout

     def get_large_model_config(self):
         return MambaConfig.from_pretrained("hf-internal-testing/mamba-2.8b")

@@ -135,6 +138,7 @@ def get_config(
             pad_token_id=self.pad_token_id,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embeddings=self.tie_word_embeddings,
+            classifier_dropout=self.classifier_dropout,
         )

     def get_pipeline_config(self):

@@ -179,6 +183,14 @@ def create_and_check_causal_lm(self, config, input_ids, *args):
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

+    def create_and_check_sequence_classification(self, config, input_ids, sequence_labels, *args):
+        config.num_labels = self.num_labels
+        model = MambaForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, labels=sequence_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
     def create_and_check_state_equivalency(self, config, input_ids, *args):
         model = MambaModel(config=config)
         model.to(torch_device)

@@ -260,17 +272,24 @@ def prepare_config_and_inputs_for_common(self):
 )
 @require_torch
 class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
+    all_model_classes = (MambaModel, MambaForCausalLM, MambaForSequenceClassification) if is_torch_available() else ()
     all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
     has_attentions = False  # Mamba does not support attentions
     fx_compatible = False  # FIXME let's try to support this @ArthurZucker
     test_torchscript = False  # FIXME let's try to support this @ArthurZucker
     test_missing_keys = False
-    test_model_parallel = False
     test_pruning = False
     test_head_masking = False  # Mamba does not have attention heads
+    test_model_parallel = False
+    test_mismatched_shapes = False  # MambaMixer follows a different initialization
     pipeline_model_mapping = (
-        {"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {}
+        {
+            "feature-extraction": MambaModel,
+            "text-generation": MambaForCausalLM,
+            "text-classification": MambaForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
     )

     def setUp(self):

@@ -337,6 +356,10 @@ def test_mamba_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_causal_lm(*config_and_inputs)

+    def test_mamba_sequence_classification_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_sequence_classification(*config_and_inputs)
+
     def test_state_equivalency(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_state_equivalency(*config_and_inputs)
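For readers who want to try the new head directly, a minimal usage sketch along the lines of `create_and_check_sequence_classification` is shown below. It assumes this PR's `MambaForSequenceClassification` is importable from `transformers`; the tiny config values are illustrative only, not taken from the test file.

```python
import torch
from transformers import MambaConfig, MambaForSequenceClassification  # available once this PR lands

# Small random-weight model: 3-way classification with an explicit pad token id,
# so the last-token pooling can locate the last real token in padded batches.
config = MambaConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    pad_token_id=99,
    num_labels=3,
)
model = MambaForSequenceClassification(config).eval()

input_ids = torch.randint(0, 99, (2, 10))  # batch of 2 sequences of length 10, no padding here
labels = torch.tensor([0, 2])              # one class index per sequence

with torch.no_grad():
    outputs = model(input_ids, labels=labels)

print(outputs.logits.shape)  # torch.Size([2, 3]) -> (batch_size, num_labels)
print(outputs.loss)          # cross-entropy, since num_labels > 1 and labels are integers
```

Because the labels are integer class indices and `num_labels > 1`, `config.problem_type` resolves to `single_label_classification`, which matches the shape assertion in the tester method above.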
Review comment: Why not just use "Copied from" here?

# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mamba, LLAMA->MAMBA, self.transformer->self.model, transformer_outputs->model_outputs

Reply: I'm not sure if I understand your comment. The forward method of Mamba and Llama for sequence classification seems different. Could you please elaborate! 🤗