Commit

reorganise rnabert
ZhiyuanChen committed Mar 27, 2024
1 parent 7adf925 commit 9468b54
Showing 4 changed files with 403 additions and 155 deletions.
143 changes: 143 additions & 0 deletions multimolecule/models/modeling_utils.py
@@ -0,0 +1,143 @@
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput


class MaskedLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, config):
        super().__init__()
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.activation = nn.GELU()
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Tie the bias to the decoder so the two are resized together with the vocabulary.
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        x = self.projection(features)
        x = self.activation(x)
        x = self.layer_norm(x)
        x = self.decoder(x)
        return x


class SequenceClassificationHead(nn.Module):
    """Head for sequence-level classification tasks."""

    num_labels: int

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_labels = config.num_labels
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

    def forward(
        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Sequence-level tasks predict one label per sequence, so classify the
        # pooled representation rather than the per-token hidden states.
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        x = self.dropout(pooled_output)
        x = self.projection(x)
        x = self.activation(x)
        x = self.dropout(x)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class TokenClassificationHead(nn.Module):
    """Head for token-level classification tasks."""

    num_labels: int

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_labels = config.num_labels
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

    def forward(
        self, outputs, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Token-level tasks need a prediction per position, so classify the
        # per-token hidden states rather than the pooled summary.
        token_output = outputs.last_hidden_state if return_dict else outputs[0]
        x = self.dropout(token_output)
        x = self.projection(x)
        x = self.activation(x)
        x = self.dropout(x)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
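The three heads above are backbone-agnostic: each takes a config at construction time and the backbone's outputs (plus optional labels) in forward. Below is a minimal smoke-test sketch of that contract; the SimpleNamespace stand-ins for the config and the backbone outputs are illustrative assumptions, not RnaBert's real defaults or output classes.

import torch
from types import SimpleNamespace

# Hypothetical config: field names mirror what the heads read above;
# the values are placeholders.
config = SimpleNamespace(
    hidden_size=120,
    vocab_size=25,
    num_labels=2,
    layer_norm_eps=1e-12,
    hidden_dropout_prob=0.1,
    use_return_dict=True,
    problem_type=None,
)

# Masked-LM head: per-token hidden states -> per-token vocabulary logits.
lm_head = MaskedLMHead(config)
hidden_states = torch.randn(2, 440, config.hidden_size)
print(lm_head(hidden_states).shape)  # torch.Size([2, 440, 25])

# Sequence head: integer labels with num_labels > 1 select the
# single_label_classification branch (CrossEntropyLoss).
seq_head = SequenceClassificationHead(config)
outputs = SimpleNamespace(
    pooler_output=hidden_states[:, 0],  # stand-in for the backbone's pooled output
    hidden_states=None,
    attentions=None,
)
result = seq_head(outputs, labels=torch.tensor([0, 1]))
print(result.logits.shape)  # torch.Size([2, 2]); result.loss is a scalar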
9 changes: 6 additions & 3 deletions multimolecule/models/rnabert/__init__.py
@@ -1,12 +1,15 @@
-from transformers import AutoConfig, AutoModel, AutoTokenizer
+from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoTokenizer

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_rnabert import RnaBertConfig
-from .modeling_rnabert import RnaBertModel
+from .modeling_rnabert import RnaBertForMaskedLM, RnaBertForSequenceClassification, RnaBertForTokenClassification, RnaBertModel

-__all__ = ["RnaBertConfig", "RnaBertModel", "RnaTokenizer"]
+__all__ = ["RnaBertConfig", "RnaBertForMaskedLM", "RnaBertForSequenceClassification", "RnaBertForTokenClassification", "RnaBertModel", "RnaTokenizer"]

AutoConfig.register("rnabert", RnaBertConfig)
AutoModel.register(RnaBertConfig, RnaBertModel)
+AutoModelForMaskedLM.register(RnaBertConfig, RnaBertForMaskedLM)
+AutoModelForSequenceClassification.register(RnaBertConfig, RnaBertForSequenceClassification)
+AutoModelForTokenClassification.register(RnaBertConfig, RnaBertForTokenClassification)
AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
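Once this module is imported, the registrations above let the standard transformers Auto classes resolve the "rnabert" model type without any RnaBert-specific imports at the call site. A short sketch of the intent, assuming the package is installed:

from multimolecule.models import rnabert  # noqa: F401  # importing runs the registrations
from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.for_model("rnabert")          # resolved to RnaBertConfig
model = AutoModelForMaskedLM.from_config(config)  # resolved to RnaBertForMaskedLM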
5 changes: 3 additions & 2 deletions multimolecule/models/rnabert/configuration_rnabert.py
@@ -9,7 +9,7 @@ class RnaBertConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
RnaBert model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the RnaBert
-[mana438/RNABERT](https://github.com/mana438/RNABERT/blob/master/RNA_bert_config.json) architecture.
+[mana438/RNABERT](https://github.com/mana438/RNABERT) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@@ -55,7 +55,8 @@ class RnaBertConfig(PretrainedConfig):
>>> # Initializing a model from the configuration
>>> model = RnaBertModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
-```"""
+```
+"""

model_type = "rnabert"
