-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfine_tune_clintox.py
121 lines (93 loc) · 3.56 KB
/
fine_tune_clintox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import pandas as pd
import torch
from transformers import RobertaForMaskedLM
from torch.utils.data import Dataset, DataLoader
import pickle
from tqdm import tqdm
from Tokenizer.MFBERT_Tokenizer import MFBERTTokenizer
import numpy as np
from sklearn.model_selection import train_test_split
# Fail fast unless exactly one CUDA device is visible to this process
# (CUDA_VISIBLE_DEVICES is pinned to '0' at the top of the script).
assert torch.cuda.device_count() == 1
# Maximum tokenized sequence length. ClintoxDataset carries its own 514
# value rather than reading this constant, so keep the two in sync.
MAX_LEN = 514
# Batch size passed to the training DataLoader.
TRAIN_BATCH_SIZE = 16
# Validation batch size; not referenced in the visible part of this script.
VALID_BATCH_SIZE = 8
# Number of full passes over the training set.
EPOCHS = 10
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-05
# Directory prefix under which the tokenizer model and dict file are looked up.
TOKENIZER_DIR = 'Tokenizer/'
class ClintoxDataset(Dataset):
    """Torch dataset yielding tokenized ClinTox examples.

    Examples are read from a pickled mapping; every ``(key, value)`` item is
    used as one ``(smiles, label)`` pair. Items are tokenized lazily in
    ``__getitem__`` and padded/truncated to ``max_len`` tokens.
    """

    def __init__(self, training=True,
                 data_path='Datasets/data_splits/ClinTox/train.pkl',
                 max_len=514):
        """Load the pickled examples and the MFBERT tokenizer.

        Args:
            training: Kept for backward compatibility; it is currently not
                used and the file at ``data_path`` is always loaded.
            data_path: Path of the pickle file holding the example mapping
                (presumably SMILES string -> integer label; verify against
                the data-split script).
            max_len: Fixed length every encoded sequence is padded or
                truncated to.
                NOTE(review): RoBERTa-style checkpoints often accept fewer
                tokens than their position-embedding size; confirm that a
                full 514-token input is valid for this model.
        """
        # SECURITY: pickle.load can execute arbitrary code while loading.
        # Only point data_path at files produced by a trusted pipeline.
        with open(data_path, 'rb') as f:
            traindata = pickle.load(f)
        # One (smiles, label) tuple per mapping item, in mapping order.
        self.data = list(traindata.items())
        self.tokenizer = MFBERTTokenizer.from_pretrained(
            TOKENIZER_DIR + 'Model/',
            dict_file=TOKENIZER_DIR + 'Model/dict.txt')
        self.max_len = max_len

    def __getitem__(self, idx):
        """Return the tensors for example ``idx``.

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'`` and ``'label'``,
            each a ``torch.long`` tensor.
        """
        smiles, target = self.data[idx][0], self.data[idx][1]
        inputs = self.tokenizer.encode_plus(
            smiles,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            # `padding='max_length'` is the current spelling of the
            # deprecated `pad_to_max_length=True`.
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
        )
        return {
            'input_ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(inputs['attention_mask'],
                                           dtype=torch.long),
            'label': torch.tensor(target, dtype=torch.long),
        }

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.data)
class MFBERTForClintox(torch.nn.Module):
    """Pre-trained encoder followed by dropout and a single-logit linear head.

    The attribute names ``l1``/``l2``/``l3`` are part of the saved-model
    layout (the script pickles the whole module) and must not be renamed.
    """

    def __init__(self):
        """Load the pre-trained weights from ``Model/pre-trained``."""
        super().__init__()
        pretrained = RobertaForMaskedLM.from_pretrained('Model/pre-trained')
        # Keep only the first child module of the masked-LM wrapper
        # (presumably the bare RoBERTa encoder; the LM head is dropped).
        self.l1 = next(pretrained.children())
        self.l2 = torch.nn.Dropout(0.2)
        self.l3 = torch.nn.Linear(768, 1)

    def forward(self, ids, mask):
        """Return one raw logit per input sequence.

        Args:
            ids: Token-id tensor passed to the encoder as ``input_ids``.
            mask: Attention-mask tensor passed as ``attention_mask``.

        Returns:
            Output of the linear head; no sigmoid is applied here.
        """
        encoded = self.l1(input_ids=ids, attention_mask=mask)
        # Mean over dim 1 of the encoder's first output.
        # NOTE(review): this average includes padded positions; confirm
        # that unmasked mean pooling is intended.
        pooled = torch.mean(encoded[0], dim=1)
        return self.l3(self.l2(pooled))
# Build the training dataset and move the classifier onto the GPU.
trainds = ClintoxDataset()
model = MFBERTForClintox().cuda()
# DataLoader settings: reshuffle every epoch; num_workers=0 keeps data
# loading in the main process.
train_params = {'batch_size': TRAIN_BATCH_SIZE,
'shuffle': True,
'num_workers': 0
}
training_loader = DataLoader(trainds, **train_params)
# Loss and optimizer. BCEWithLogitsLoss expects raw logits, which matches the
# model's linear head (its forward() applies no sigmoid).
loss_function = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
# Lowest batch loss seen so far; train() writes a "best" checkpoint only when
# a loss falls below this value, so nothing is saved until a loss under 1.
curminloss = 1
def train(epoch):
    """Run one pass over ``training_loader`` and update ``model`` in place.

    On every 100th batch (starting with batch 0) the batch loss is printed.
    If that loss is also below the module-level ``curminloss``, the whole
    model is saved as a new "best" checkpoint and ``curminloss`` is lowered.

    NOTE(review): the source's indentation was lost; the best-checkpoint
    check is assumed to sit inside the every-100-steps branch. Confirm.

    Args:
        epoch: Epoch index, used only in the progress message.
    """
    global curminloss
    model.train()
    batches = tqdm(enumerate(training_loader, 0), desc='ITERATION',
                   total=len(training_loader))
    for step, data in batches:
        ids = data['input_ids'].cuda()
        mask = data['attention_mask'].cuda()
        targets = data['label'].float().cuda()

        # Drop only the trailing logit dimension. A bare squeeze() would
        # also collapse a final batch of size 1 to a 0-d tensor, and the
        # loss would then reject the shape mismatch with the (1,) target.
        outputs = model(ids, mask).squeeze(-1)

        optimizer.zero_grad()
        loss = loss_function(outputs, targets)

        if step % 100 == 0:
            batch_loss = loss.item()
            print(f'Epoch: {epoch}, Loss: {batch_loss}')
            # save best model
            if batch_loss < curminloss:
                # torch.save does not create missing directories.
                os.makedirs('fine-tuned', exist_ok=True)
                torch.save(
                    model,
                    f'fine-tuned/Clintox_model_best_{batch_loss}.bin')
                curminloss = batch_loss
                print('saving best...')

        loss.backward()
        optimizer.step()
# Fine-tune for the configured number of epochs.
for epoch in tqdm(range(EPOCHS), desc='EPOCHS'):
    train(epoch)
# Persist the final model once training has finished.
# NOTE(review): assumed to run after the loop (file name says "last"),
# since the source's indentation was lost; confirm.
# torch.save does not create missing directories, so make sure it exists.
os.makedirs('fine-tuned', exist_ok=True)
torch.save(model, 'fine-tuned/Clintox_model_last.bin')