Skip to content

Commit

Permalink
Unify RNA Tokenizer
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <this@zyc.ai>
  • Loading branch information
ZhiyuanChen committed Mar 26, 2024
1 parent bfc060c commit 08652c1
Show file tree
Hide file tree
Showing 15 changed files with 226 additions and 103 deletions.
11 changes: 9 additions & 2 deletions multimolecule/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from . import models
from transformers import AutoTokenizer

__all__ = ["models"]
from . import models, tokenizers
from .models import RnaBertConfig
from .tokenizers import RnaTokenizer

AutoTokenizer.register(RnaBertConfig, RnaTokenizer)


__all__ = ["models", "tokenizers"]
4 changes: 2 additions & 2 deletions multimolecule/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .rnabert import RnaBertConfig, RnaBertModel, RnaBertTokenizer
from .rnabert import RnaBertConfig, RnaBertModel

__all__ = ["RnaBertConfig", "RnaBertModel", "RnaBertTokenizer"]
__all__ = ["RnaBertConfig", "RnaBertModel"]
7 changes: 4 additions & 3 deletions multimolecule/models/rnabert/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from transformers import AutoConfig, AutoModel, AutoTokenizer

from multimolecule.tokenizers.rna import RnaTokenizer

from .configuration_rnabert import RnaBertConfig
from .modeling_rnabert import RnaBertModel
from .tokenization_rnabert import RnaBertTokenizer

__all__ = ["RnaBertConfig", "RnaBertModel", "RnaBertTokenizer"]
__all__ = ["RnaBertConfig", "RnaBertModel"]

AutoConfig.register("rnabert", RnaBertConfig)
AutoModel.register(RnaBertConfig, RnaBertModel)
AutoTokenizer.register(RnaBertConfig, RnaBertTokenizer)
AutoTokenizer.register(RnaBertConfig, RnaTokenizer)
20 changes: 11 additions & 9 deletions multimolecule/models/rnabert/configuration_rnabert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
logger = logging.get_logger(__name__)


DEFAULT_VOCAB_LIST = ["<pad>", "<mask>", "A", "T", "G", "C"]


class RnaBertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`RnaBertModel`]. It is used to instantiate a
Expand Down Expand Up @@ -64,34 +61,39 @@ class RnaBertConfig(PretrainedConfig):

def __init__(
self,
vocab_size=None,
mask_token_id=None,
pad_token_id=None,
vocab_size=25,
ss_vocab_size=8,
hidden_size=None,
multiple=None,
num_hidden_layers=6,
num_attention_heads=12,
intermediate_size=40,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=440,
initializer_range=0.02,
layer_norm_eps=1e-12,
vocab_list=None,
pad_token_id=0,
position_embedding_type="absolute",
use_cache=True,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
super().__init__(pad_token_id=pad_token_id, **kwargs)

self.vocab_size = vocab_size
self.ss_vocab_size = ss_vocab_size
if hidden_size is None:
hidden_size = num_attention_heads * multiple if multiple is not None else 120
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.vocab_list = vocab_list if vocab_list is not None else DEFAULT_VOCAB_LIST
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
58 changes: 45 additions & 13 deletions multimolecule/models/rnabert/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,61 @@

import chanfig
import torch
from torch import nn

from multimolecule.models import RnaBertConfig, RnaBertModel
from multimolecule.models.rnabert.configuration_rnabert import DEFAULT_VOCAB_LIST

CONFIG = {
"architectures": ["RnaBertModel"],
"attention_probs_dropout_prob": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 120,
"initializer_range": 0.02,
"intermediate_size": 40,
"layer_norm_eps": 1e-12,
"mask_token_id": 1,
"max_position_embeddings": 440,
"model_type": "rnabert",
"num_attention_heads": 12,
"num_hidden_layers": 6,
"position_embedding_type": "absolute",
"ss_size": 8,
"torch_dtype": "float32",
"vocab_size": 25,
"ss_vocab_size": 8,
"type_vocab_size": 2,
"pad_token_id": 0,
}

original_vocab_list = ["<pad>", "<mask>", "A", "U", "G", "C"]
vocab_list = [
"<pad>",
"<cls>",
"<eos>",
"<unk>",
"<mask>",
"<null>",
"A",
"C",
"G",
"U",
"N",
"X",
"V",
"H",
"D",
"B",
"M",
"R",
"W",
"S",
"Y",
"K",
".",
"*",
"-",
]


def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None):
if output_path is None:
output_path = "rnabert"
config = RnaBertConfig.from_dict(chanfig.NestedDict(CONFIG))
config.vocab_list = DEFAULT_VOCAB_LIST
config.vocab_size = len(config.vocab_list)
ckpt = torch.load(checkpoint_path)
config = RnaBertConfig.from_dict(chanfig.FlatDict(CONFIG))
ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu"))
bert_state_dict = ckpt
state_dict = {}

Expand All @@ -48,8 +71,17 @@ def convert_checkpoint(checkpoint_path: str, output_path: Optional[str] = None):
key = key.replace("beta", "bias")
state_dict[key] = value

word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
# nn.init.normal_(pos_embed.weight, std=0.02)
for original_token, new_token in zip(original_vocab_list, vocab_list):
original_index = original_vocab_list.index(original_token)
new_index = vocab_list.index(new_token)
word_embed.weight.data[new_index] = state_dict["embeddings.word_embeddings.weight"][original_index]
state_dict["embeddings.word_embeddings.weight"] = word_embed.weight.data

model.load_state_dict(state_dict)
model.save_pretrained(output_path)
model.save_pretrained(output_path, safe_serialization=True)
model.save_pretrained(output_path, safe_serialization=False)


if __name__ == "__main__":
Expand Down
12 changes: 5 additions & 7 deletions multimolecule/models/rnabert/modeling_rnabert.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ def _init_weights(self, module):


class RnaBertModel(RnaBertPreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.embeddings = RnaBertEmbeddings(config)
Expand Down Expand Up @@ -329,9 +328,8 @@ class RnaBertLMHead(nn.Module):
def __init__(self, config):
super().__init__()

self.predictions = MaskedWordPredictions(config)
config.vocab_size = config.ss_size
self.predictions_ss = MaskedWordPredictions(config)
self.predictions = MaskedWordPredictions(config, config.vocab_size)
self.predictions_ss = MaskedWordPredictions(config, config.ss_vocab_size)

self.seq_relationship = nn.Linear(config.hidden_size, 2)

Expand All @@ -345,13 +343,13 @@ def forward(self, sequence_output, pooled_output):


class MaskedWordPredictions(nn.Module):
def __init__(self, config):
def __init__(self, config, vocab_size):
super().__init__()

self.transform = RnaBertPredictionHeadTransform(config)

self.decoder = nn.Linear(in_features=config.hidden_size, out_features=config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder = nn.Linear(in_features=config.hidden_size, out_features=vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(vocab_size))

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
Expand Down
16 changes: 0 additions & 16 deletions multimolecule/models/rnabert/special_tokens_map.json

This file was deleted.

27 changes: 0 additions & 27 deletions multimolecule/models/rnabert/tokenizer_config.json

This file was deleted.

6 changes: 0 additions & 6 deletions multimolecule/models/rnabert/vocab.txt

This file was deleted.

3 changes: 3 additions & 0 deletions multimolecule/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .rna import RnaTokenizer

__all__ = ["RnaTokenizer"]
3 changes: 3 additions & 0 deletions multimolecule/tokenizers/rna/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .tokenization_rna import RnaTokenizer

__all__ = ["RnaTokenizer"]
37 changes: 37 additions & 0 deletions multimolecule/tokenizers/rna/special_tokens_map.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"cls_token": {
"content": "<cls>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<eos>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"mask_token": {
"content": "<mask>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}
Loading

0 comments on commit 08652c1

Please sign in to comment.