-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_rr_inter.py
128 lines (111 loc) · 4.79 KB
/
run_rr_inter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
from RNABERT.rnabert import BertModel
from utils import get_config
from torch.optim import AdamW
from RNAMSM.model import MSATransformer
from metrics import RRInterMetrics
from trainers import RRInterTrainer
from utils import str2bool, str2list
from losses import RRInterLoss
from datasets import GenerateRRInterTrainTest
from rr_inter import RNABertForRRInter, RNAFmForRRInter, RNAMsmForRRInter
from tokenizer import RNATokenizer
from collators import RRDataCollator
import RNAFM.fm as fm
# ========== Define constants
# Supported backbone language models; selects checkpoint/config/vocab by name.
MODELS = ["RNABERT", "RNAMSM", "RNAFM"]

# ========== Configuration
parser = argparse.ArgumentParser(
    'Implementation of RNA-RNA Interaction prediction.')
# model args
parser.add_argument('--model_name', type=str, default="RNABERT", choices=MODELS)
parser.add_argument('--vocab_path', type=str, default="./vocabs/")
parser.add_argument('--pretrained_model', type=str,
                    default="./checkpoints/")
parser.add_argument('--config_path', type=str,
                    default="./configs/")
parser.add_argument('--dataset', type=str, default="MirTarRAW",)
parser.add_argument('--dataset_dir', type=str, default="./data/rr")
# BUGFIX: was `type=bool` — argparse passes the raw CLI string to `type`, and
# bool("False") is True, so these flags could never be disabled from the CLI.
# Use the project's str2bool (already used for --train / --disable_tqdm).
parser.add_argument('--replace_T', type=str2bool, default=True)
parser.add_argument('--replace_U', type=str2bool, default=False)
parser.add_argument('--dataloader_num_workers', type=int, default=0)
# training args
parser.add_argument('--device', type=str, default='cpu')
# BUGFIX: was `type=list` — that splits one CLI string into characters
# ("26,40" -> ['2', '6', ',', '4', '0']). Accept whitespace-separated ints:
# `--max_seq_lens 26 40`. Default is unchanged.
parser.add_argument('--max_seq_lens', type=int, nargs='+', default=[26, 40])
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--train', type=str2bool, default=True)
parser.add_argument('--disable_tqdm', type=str2bool,
                    default=False, help='Disable tqdm display if true.')
parser.add_argument('--batch_size', type=int, default=50,
                    help='The number of samples used per step & per device.')
parser.add_argument('--num_train_epochs', type=int, default=50,
                    help='The number of epoch for training.')
parser.add_argument('--metrics', type=str2list,
                    default="F1s,Precision,Recall,Accuracy,AUC",)
# logging args
parser.add_argument('--logging_steps', type=int, default=1000,
                    help='Update visualdl logs every logging_steps.')
args = parser.parse_args()
if __name__ == "__main__":
    # ========== args check
    # Exactly one of replace_T / replace_U must be enabled (sequences are
    # normalized to a single alphabet). Raise instead of assert so the check
    # is not stripped under `python -O`.
    if not (args.replace_T ^ args.replace_U):
        raise ValueError("Only replace T or U.")

    # ========== Build tokenizer, model, criterion
    # Vocab / config / checkpoint files are looked up by model name under the
    # user-supplied directories.
    tokenizer = RNATokenizer(args.vocab_path + "{}.txt".format(args.model_name))

    if args.model_name == "RNABERT":
        model_config = get_config(
            args.config_path + "{}.json".format(args.model_name))
        model = BertModel(model_config)
        model = RNABertForRRInter(model)
        model._load_pretrained_bert(
            args.pretrained_model + "{}.pth".format(args.model_name))
    elif args.model_name == "RNAMSM":
        model_config = get_config(
            args.config_path + "{}.json".format(args.model_name))
        model = MSATransformer(**model_config)
        model = RNAMsmForRRInter(model)
        model._load_pretrained_bert(
            args.pretrained_model + "{}.pth".format(args.model_name))
    elif args.model_name == "RNAFM":
        # RNA-FM ships its own pretrained loader; no separate .pth path needed.
        model, alphabet = fm.pretrained.rna_fm_t12()
        model = RNAFmForRRInter(model)
    else:
        raise ValueError("Unknown model name: {}".format(args.model_name))
    model.to(args.device)

    _loss_fn = RRInterLoss()

    # ========== Prepare data
    # Train/eval split is fixed at 80/20 by the generator.
    datasets_generator = GenerateRRInterTrainTest(rr_dir=args.dataset_dir,
                                                  dataset=args.dataset,
                                                  split=0.8,)
    dataset_train, dataset_eval = datasets_generator.get()

    # ========== Create the data collator
    # Pads/truncates each side of the RNA-RNA pair to its own max length and
    # applies the T<->U alphabet replacement chosen above.
    _collate_fn = RRDataCollator(
        max_seq_lens=args.max_seq_lens,
        tokenizer=tokenizer,
        replace_T=args.replace_T,
        replace_U=args.replace_U)

    # ========== Create the learning_rate scheduler (if need) and optimizer
    optimizer = AdamW(params=model.parameters(), lr=args.learning_rate)

    # ========== Create the metrics
    _metric = RRInterMetrics(metrics=args.metrics)

    # ========== Training
    rr_inter_trainer = RRInterTrainer(
        args=args,
        tokenizer=tokenizer,
        model=model,
        train_dataset=dataset_train,
        eval_dataset=dataset_eval,
        data_collator=_collate_fn,
        loss_fn=_loss_fn,
        optimizer=optimizer,
        compute_metrics=_metric,
    )
    if args.train:
        for i_epoch in range(args.num_train_epochs):
            print("Epoch: {}".format(i_epoch))
            # NOTE(review): eval runs before train, so epoch 0 reports the
            # untrained model's metrics — presumably an intentional baseline;
            # confirm against RRInterTrainer's expected call order.
            rr_inter_trainer.eval(i_epoch)
            rr_inter_trainer.train(i_epoch)