Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ModernBERT to Transformers #35158

Merged
merged 91 commits into from
Dec 19, 2024
Merged
Changes from 1 commit
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
6b5a823
initial cut of modernbert for transformers
warner-benjamin Dec 9, 2024
dafb203
small bug fixes
warner-benjamin Dec 10, 2024
df13def
fixes
warner-benjamin Dec 11, 2024
d09eabf
Update import
tomaarsen Dec 11, 2024
8c3afea
Use compiled mlp->mlp_norm to match research implementation
tomaarsen Dec 11, 2024
a40aaa9
Propagate changes in modular to modeling
tomaarsen Dec 11, 2024
9f0b8ca
Replace duplicate attn_out_dropout in favor of attention_dropout
tomaarsen Dec 11, 2024
900d8ec
Update BOS to CLS and EOS to SEP
tomaarsen Dec 11, 2024
caf8901
Set default classifier bias to False, matching research repo
tomaarsen Dec 11, 2024
8276602
Update tie_word_embeddings description
tomaarsen Dec 11, 2024
79e4bbb
Fix _init_weights for ForMaskedLM
tomaarsen Dec 11, 2024
b59bad9
Match base_model_prefix
tomaarsen Dec 11, 2024
e7bef53
Add compiled_head to match research repo outputs
tomaarsen Dec 11, 2024
120578b
Fix imports for ModernBertForMaskedLM
tomaarsen Dec 11, 2024
142ff11
Just use "gelu" default outright for classifier
tomaarsen Dec 11, 2024
b44abdc
Fix config name typo: initalizer -> initializer
tomaarsen Dec 11, 2024
3de8ebf
Remove some unused parameters in docstring. Still lots to edit there!
tomaarsen Dec 11, 2024
7a05b3f
Compile the embeddings forward
tomaarsen Dec 12, 2024
88b0ecf
Add drafts for ForSequenceClassification/ForTokenClassification
tomaarsen Dec 12, 2024
5e3d61d
Add initial SDPA support (not exactly equivalent to FA2 yet!)
tomaarsen Dec 12, 2024
2a3d378
Only use attention dropout if training
tomaarsen Dec 12, 2024
a2051d6
Add initial eager attention support (also not equivalent to FA2 yet!)
tomaarsen Dec 12, 2024
124f1fd
Add initial tests, output_attentions, output_hidden_states, prune_heads
tomaarsen Dec 13, 2024
38f959b
Remove kwargs from ModernBertForMaskedLM
tomaarsen Dec 13, 2024
f716943
Remove/adjust/skip improper tests; warn if padding but no attn mask
tomaarsen Dec 13, 2024
f41adaa
Run formatting etc.
tomaarsen Dec 13, 2024
d06654a
Run python utils/custom_init_isort.py
tomaarsen Dec 14, 2024
f9301f4
FlexAttention with unpadded sequences(matches FA2 within bf16 numerics)
staghado Dec 15, 2024
a356708
Reformat init_weights based on review
tomaarsen Dec 16, 2024
f83fdc0
self -> module in attention forwards
tomaarsen Dec 16, 2024
b444c15
Remove if config.tie_word_embeddings
tomaarsen Dec 16, 2024
5aaf273
Reformat output projection on a different line
tomaarsen Dec 16, 2024
0a8d044
Remove pruning
tomaarsen Dec 16, 2024
382e481
Remove assert
tomaarsen Dec 16, 2024
5d05e8e
Call contiguous() to simplify paths
tomaarsen Dec 16, 2024
98508c7
Remove prune_qkv_linear_layer
tomaarsen Dec 16, 2024
2c076c8
Format code
tomaarsen Dec 16, 2024
986c6fe
Keep as kwargs, only use if needed
tomaarsen Dec 16, 2024
5cd39ad
Remove unused codepaths & related config options
tomaarsen Dec 16, 2024
2d606b9
Remove 3d attn_mask test; fix token classification tuple output
tomaarsen Dec 16, 2024
8eb87e8
Reorder: attention_mask above position_ids, fixes gradient checkpointing
tomaarsen Dec 16, 2024
5d83c56
Merge branch 'main' into pr-35158
tomaarsen Dec 16, 2024
3a24af4
Fix usage if no FA2 or torch v2.5+
tomaarsen Dec 16, 2024
37a6030
Make torch.compile/triton optional
tomaarsen Dec 17, 2024
b3b4028
Separate pooling options into separate functions (cls, mean) - cls as…
tomaarsen Dec 17, 2024
b241a7e
Simplify _pad_modernbert_output, remove unused labels path
tomaarsen Dec 17, 2024
66f4603
Update tied weights to remove decoder.weight, simplify decoder loading
tomaarsen Dec 17, 2024
3eb786b
Adaptively set config.compile based on hf_device_map/device/resize, etc.
tomaarsen Dec 17, 2024
093b601
Merge branch 'main' of https://github.com/huggingface/transformers in…
tomaarsen Dec 17, 2024
28fc79e
Update ModernBertConfig docstring
tomaarsen Dec 17, 2024
612befa
Satisfy some consistency checks, add unfinished docs
tomaarsen Dec 17, 2024
ae32e8b
Merge branch 'main' of https://github.com/huggingface/transformers in…
tomaarsen Dec 17, 2024
f4e280a
Only set compile to False if there's more than 1 device
tomaarsen Dec 17, 2024
bc14967
Add docstrings for public ModernBert classes
tomaarsen Dec 17, 2024
0f17fb9
Dont replace docstring returns - ends up being duplicate
tomaarsen Dec 17, 2024
25b12b4
Fix mistake in toctree
tomaarsen Dec 17, 2024
f312eef
Reformat toctree
tomaarsen Dec 17, 2024
1e367df
Patched FlexAttention, SDPA, Eager with Local Attention
tomaarsen Dec 17, 2024
fb748ce
Implement FA2 -> SDPA -> Eager attn_impl defaulting, crucial
tomaarsen Dec 17, 2024
051233f
Patch test edge case with Idefics3 not working with 'attn_implementat…
tomaarsen Dec 17, 2024
6c01711
Repad all_hidden_states as well
tomaarsen Dec 17, 2024
5f7c566
rename config.compile to reference_compile
warner-benjamin Dec 18, 2024
c8a80e7
disable flex_attention since it crashes
warner-benjamin Dec 18, 2024
8962f05
Update modernbert.md
bclavie Dec 18, 2024
7e89f4d
Using dtype min to mask in eager
NohTow Dec 18, 2024
0742a1d
Fully remove flex attention for now
tomaarsen Dec 18, 2024
6c6cddb
Call contiguous to allow for .view()
tomaarsen Dec 18, 2024
e37e4ec
Copyright 2020 -> 2024
tomaarsen Dec 18, 2024
9afc480
Update/simplify __init__ structure
tomaarsen Dec 18, 2024
aa1bdb4
Remove "... if dropout_prob > 0 else identity"
tomaarsen Dec 18, 2024
659807f
re-use existing pad/unpad functions instead of creating new ones
staghado Dec 18, 2024
7955e39
remove flexattention method
staghado Dec 18, 2024
4145119
Compute attention_mask and local_attention_mask once in modeling
tomaarsen Dec 18, 2024
0e572d5
Simplify sequence classification prediction heads, only CLS now
tomaarsen Dec 18, 2024
e5dca63
Simplify module.training in eager attn
tomaarsen Dec 18, 2024
bf11173
Also export ModernBertPreTrainedModel
tomaarsen Dec 18, 2024
54ed5db
Update the documentation with links to finetuning scripts
tomaarsen Dec 18, 2024
a1bfae8
Explain local_attention_mask parameter in docstring
tomaarsen Dec 18, 2024
df7658a
Simplify _autoset_attn_implementation, rely on super()
tomaarsen Dec 18, 2024
b3404ed
Keep "in" to initialize Prediction head
tomaarsen Dec 18, 2024
e057bc2
add back mean pooling
warner-benjamin Dec 18, 2024
99c38ba
Use the pooling head in TokenClassification
warner-benjamin Dec 18, 2024
5114ed7
update copyright
warner-benjamin Dec 18, 2024
175fb95
Reset config._attn_implementation_internal on failure
tomaarsen Dec 18, 2024
8cedfc5
Allow optional attention_mask in ForMaskedLM head
warner-benjamin Dec 18, 2024
2380729
fix failing run_slow tests
warner-benjamin Dec 18, 2024
7686134
Add links to the paper
tomaarsen Dec 19, 2024
44275fd
Remove unpad_no_grad, always pad/unpad without gradients
tomaarsen Dec 19, 2024
d799d65
local_attention_mask -> sliding_window_mask
tomaarsen Dec 19, 2024
ed77867
Revert "Use the pooling head in TokenClassification"
tomaarsen Dec 19, 2024
92e17c6
Simplify pooling, 2 options via if-else
tomaarsen Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Separate pooling options into separate functions (cls, mean) - cls as…
… default
  • Loading branch information
tomaarsen committed Dec 17, 2024
commit b3b4028e826d14b623bc6c35f3541ab51a67b234
Original file line number Diff line number Diff line change
@@ -20,6 +20,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

from ...configuration_utils import PretrainedConfig
from ...utils.import_utils import is_triton_available

@@ -128,7 +130,7 @@ def __init__(
unpad_no_grad=True,
decoder_bias=True,
classifier_dropout=0.0,
classifier_pooling="mean",
classifier_pooling: Literal["cls", "mean"] = "cls",
classifier_bias=False,
classifier_activation="gelu",
deterministic_flash_attn=False,
@@ -178,6 +180,11 @@ def __init__(
self.sparse_pred_ignore_index = sparse_pred_ignore_index
self.compile = compile

if self.classifier_pooling not in ["cls", "mean"]:
raise ValueError(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
)

if self.compile is None:
self.compile = is_triton_available()

44 changes: 21 additions & 23 deletions src/transformers/models/modernbert/modeling_modernbert.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@
# limitations under the License.

import math
from enum import Enum
from typing import Optional, Tuple, Union

import torch
@@ -50,12 +49,6 @@
logger = logging.get_logger(__name__)


class ModernBertPoolingType(str, Enum):
cls = "cls"
mean = "mean"
max = "max"


class ApplyRotaryEmbUnpad(torch.autograd.Function):
@staticmethod
def forward(
@@ -624,6 +617,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(self.act(self.dense(hidden_states)))


def cls_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
return hidden_states[:, 0]


def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
return (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)


MODERNBERT_POOLING_FUNCTION = {
"cls": cls_pooling,
"mean": mean_pooling,
}


class ModernBertPoolingHead(nn.Module):
def __init__(self, config: ModernBertConfig):
super().__init__()
@@ -632,24 +639,15 @@ def __init__(self, config: ModernBertConfig):
self.act = ACT2FN[config.classifier_activation]
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity()
self.pooling_type = ModernBertPoolingType(config.classifier_pooling)
self.pooling = MODERNBERT_POOLING_FUNCTION[config.classifier_pooling]

def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool: Optional[bool] = True
) -> torch.Tensor:
if pool:
if self.pooling_type == ModernBertPoolingType.cls:
output = hidden_states[:, 0]
elif self.pooling_type == ModernBertPoolingType.mean:
output = hidden_states.mean(dim=1)
elif self.pooling_type == ModernBertPoolingType.max:
output = hidden_states.max(dim=1)[0]
else:
output = hidden_states

return self.drop(self.norm(self.act(self.dense(output))))

hidden_states = self.pooling(hidden_states, attention_mask)

# Copyright 2023 OLMo Authors
# License: Apache-2.0
return self.drop(self.norm(self.act(self.dense(hidden_states))))


def _unpad_modernbert_input(
@@ -1173,7 +1171,7 @@ def forward(
)
last_hidden_state = outputs[0]

pooled_output = self.head(last_hidden_state)
pooled_output = self.head(last_hidden_state, attention_mask)
logits = self.classifier(pooled_output)

loss = None
53 changes: 28 additions & 25 deletions src/transformers/models/modernbert/modular_modernbert.py
Original file line number Diff line number Diff line change
@@ -15,8 +15,7 @@
# limitations under the License.

import math
from enum import Enum
from typing import Optional, Tuple, Union
from typing import Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -162,7 +161,7 @@ def __init__(
unpad_no_grad=True,
decoder_bias=True,
classifier_dropout=0.0,
classifier_pooling="mean",
classifier_pooling: Literal["cls", "mean"] = "cls",
classifier_bias=False,
classifier_activation="gelu",
deterministic_flash_attn=False,
@@ -212,23 +211,18 @@ def __init__(
self.sparse_pred_ignore_index = sparse_pred_ignore_index
self.compile = compile

if self.classifier_pooling not in ["cls", "mean"]:
raise ValueError(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
)

if self.compile is None:
self.compile = is_triton_available()

if unpad_inputs is None:
self.unpad_inputs = self._attn_implementation in {"flash_attention_2", "flex_attention"}


class ModernBertPoolingType(str, Enum):
cls = "cls"
mean = "mean"
max = "max"


# Copyright 2023 OLMo Authors
# License: Apache-2.0


def _unpad_modernbert_input(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct me if I am wrong, the inputs are the same as any other LLMs no? In that case if you want to unpad you should be using

def _upad_input(
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
):
available in the utils!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid not quite. We're unpadding much earlier and repadding much later, so that e.g. even the MaskedLM can take advantage of it. As a result, _upad_input (and _flash_attention_forward) aren't viable here.

inputs: torch.Tensor,
attention_mask: torch.Tensor,
@@ -829,6 +823,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(self.act(self.dense(hidden_states)))


def cls_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
return hidden_states[:, 0]


def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
return (hidden_states * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)


MODERNBERT_POOLING_FUNCTION = {
"cls": cls_pooling,
"mean": mean_pooling,
}


class ModernBertPoolingHead(nn.Module):
def __init__(self, config: ModernBertConfig):
super().__init__()
@@ -837,20 +845,15 @@ def __init__(self, config: ModernBertConfig):
self.act = ACT2FN[config.classifier_activation]
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
self.drop = torch.nn.Dropout(config.classifier_dropout) if config.classifier_dropout > 0 else nn.Identity()
self.pooling_type = ModernBertPoolingType(config.classifier_pooling)
self.pooling = MODERNBERT_POOLING_FUNCTION[config.classifier_pooling]

def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool: Optional[bool] = True
) -> torch.Tensor:
if pool:
tomaarsen marked this conversation as resolved.
Show resolved Hide resolved
if self.pooling_type == ModernBertPoolingType.cls:
output = hidden_states[:, 0]
elif self.pooling_type == ModernBertPoolingType.mean:
output = hidden_states.mean(dim=1)
elif self.pooling_type == ModernBertPoolingType.max:
output = hidden_states.max(dim=1)[0]
else:
output = hidden_states
hidden_states = self.pooling(hidden_states, attention_mask)

return self.drop(self.norm(self.act(self.dense(output))))
return self.drop(self.norm(self.act(self.dense(hidden_states))))


class ModernBertPreTrainedModel(PreTrainedModel):
@@ -1284,7 +1287,7 @@ def forward(
)
last_hidden_state = outputs[0]

pooled_output = self.head(last_hidden_state)
pooled_output = self.head(last_hidden_state, attention_mask)
logits = self.classifier(pooled_output)

loss = None