# -*- encoding: utf-8 -*-
'''
@Time : 2022/06/12 14:49:28
@Author : Chu Xiaokai
@Contact : xiaokaichu@gmail.com
'''
import numpy as np
import warnings
import sys
from metrics import *
from Transformer4Ranking.model import *
from paddle.io import DataLoader
import paddle.distributed as dist
from dataloader import *
from args import config
from datetime import datetime
random.seed(config.seed)
np.random.seed(config.seed)
paddle.set_device(f"gpu:{config.gpu_device}")
dist.init_parallel_env()
paddle.seed(config.seed)
print(config)
exp_settings = config.exp_settings
dt_string = datetime.today().strftime('%Y-%m-%d')
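# Build the Transformer ranker in 'finetune' mode; vocabulary size, embedding
# dimension, attention heads, layer count, and dropout all come from args.config.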
model = TransformerModel(
    ntoken=config.ntokens,
    hidden=config.emb_dim,
    nhead=config.nhead,
    nlayers=config.nlayers,
    dropout=config.dropout,
    mode='finetune'
)
# load pretrained model
if config.init_parameters != "":
    print('load warm up model ', config.init_parameters)
    ptm = paddle.load(config.init_parameters)
    for k, v in model.state_dict().items():
        if k not in ptm:
            print("warning: not loading " + k)
        else:
            print("loading " + k)
            v.set_value(ptm[k])
# optimizer setup
model = paddle.DataParallel(model)
scheduler = get_linear_schedule_with_warmup(config.lr, config.warmup_steps,
                                            config.max_steps)
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=scheduler,
    parameters=model.parameters(),
    weight_decay=config.weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params,
    grad_clip=nn.ClipGradByNorm(clip_norm=0.5)
)
criterion = nn.BCEWithLogitsLoss()
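# Data: the human-annotated validation split is used as the fine-tuning
# training data (data_type='finetune'), while the annotated test split is
# only scored for evaluation (data_type='annotate').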
vaild_annotate_dataset = TestDataset(config.valid_annotate_path, max_seq_len=config.max_seq_len, data_type='finetune')
vaild_annotate_loader = DataLoader(vaild_annotate_dataset, batch_size=config.eval_batch_size)
test_annotate_dataset = TestDataset(config.test_annotate_path, max_seq_len=config.max_seq_len, data_type='annotate')
test_annotate_loader = DataLoader(test_annotate_dataset, batch_size=config.eval_batch_size)
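# Fine-tuning loop: optimize a pointwise BCE loss on the annotation labels,
# log the loss every config.log_interval steps, evaluate DCG@10 / PNR on the
# annotated test set every config.eval_step steps, and save a checkpoint on
# rank 0 every config.save_step steps.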
idx = 0
for i in range(config.finetune_epoch):
    print(f'Start epoch {i}')
    for valid_data_batch in vaild_annotate_loader:
        model.train()
        optimizer.clear_grad()
        src_input, src_segment, src_padding_mask, label = valid_data_batch
        score = model(
            src=src_input,
            src_segment=src_segment,
            src_padding_mask=src_padding_mask,
        )
        ctr_loss = criterion(score, paddle.to_tensor(label, dtype=paddle.float32))
        ctr_loss.backward()
        optimizer.step()
        scheduler.step()
        if idx % config.log_interval == 0:
            print(f'{idx:5d}th step | loss {ctr_loss.item():5.6f}')
        if idx % config.eval_step == 0:
            model.eval()
            # ------------ evaluate on annotated data -------------- #
            total_scores = []
            for test_data_batch in test_annotate_loader:
                src_input, src_segment, src_padding_mask, label = test_data_batch
                score = model(
                    src=src_input,
                    src_segment=src_segment,
                    src_padding_mask=src_padding_mask,
                )
                score = score.cpu().detach().numpy().tolist()
                total_scores += score
            result_dict_ann = evaluate_all_metric(
                qid_list=test_annotate_dataset.total_qids,
                label_list=test_annotate_dataset.total_labels,
                score_list=total_scores,
                freq_list=test_annotate_dataset.total_freqs
            )
            print(
                f'{idx}th step valid annotate | '
                f'dcg@10: all {result_dict_ann["all_dcg@10"]:.6f} | '
                f'high {result_dict_ann["high_dcg@10"]:.6f} | '
                f'mid {result_dict_ann["mid_dcg@10"]:.6f} | '
                f'low {result_dict_ann["low_dcg@10"]:.6f} | '
                f'pnr {result_dict_ann["pnr"]:.6f}'
            )
        if idx % config.save_step == 0 and idx > 0 and paddle.distributed.get_rank() == 0:
            paddle.save(
                model.state_dict(),
                'model/{}-{}/save_steps{}_{:.5f}.model'.format(
                    'ckpt' if config.nlayers == 12 else '24L',
                    dt_string, idx, result_dict_ann['all_dcg@10'])
            )
        idx += 1