finetune.py
import fire
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from data_utils.constructed_discriminative_dataset import \
    ConstructedDiscriminativeDataset
from data_utils.data_utils import get_dataloaders
from models.discriminative_aligner import DiscriminativeAligner
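
# Training hyperparameters. With BATCH_SIZE = 1 and ACCUMULATE_GRAD_BATCHES = 16,
# one optimizer step covers an effective batch of 1 * 16 = 16 examples.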
BATCH_SIZE = 1
ACCUMULATE_GRAD_BATCHES = 16
NUM_WORKERS = 6
WARMUP_PROPORTION = 0.1
ADAM_EPSILON = 1e-8
WEIGHT_DECAY = 0.01
LR = 1e-5
VAL_CHECK_INTERVAL = 1. / 4  # run validation four times per training epoch


def main(dataset_name, n_epochs=1, dialog_context=None):
    # Build the train/dev splits of the constructed discriminative dataset
    # and wrap them in dataloaders.
    dataset = {split: ConstructedDiscriminativeDataset(
        dataset_name=dataset_name, split=split,
        dialog_context=dialog_context)
        for split in ['train', 'dev']}

    dataloader = get_dataloaders(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
        collate_fn='raw')
    model = DiscriminativeAligner(aggr_type=None)

    # One optimizer step happens every ACCUMULATE_GRAD_BATCHES batches; the
    # +1 compensates for the remainder dropped by the integer division.
    train_steps = n_epochs * (
        len(dataloader['train']) // ACCUMULATE_GRAD_BATCHES + 1)
    warmup_steps = int(train_steps * WARMUP_PROPORTION)

    model.set_hparams(
        batch_size=BATCH_SIZE,
        accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
        lr=LR,
        train_steps=train_steps,
        warmup_steps=warmup_steps,
        weight_decay=WEIGHT_DECAY,
        adam_epsilon=ADAM_EPSILON)
    # Checkpoint filename: 'disc', optionally suffixed with the dialog context.
    ckpt_filename = 'disc'
    if dialog_context is not None:
        ckpt_filename += f'_{dialog_context}'

    # Keep only the single best checkpoint as measured by validation F1.
    checkpoint_callback = ModelCheckpoint(
        dirpath=f'ckpts/{dataset_name}/',
        filename=ckpt_filename,
        monitor='val_f1',
        mode='max',
        save_top_k=1,
        verbose=True)
    # Note: checkpoint_callback= and gpus= follow an older pytorch_lightning
    # Trainer API; recent releases expect callbacks=[checkpoint_callback] and
    # accelerator='gpu', devices=1 instead.
    trainer = pl.Trainer(
        max_epochs=n_epochs,
        checkpoint_callback=checkpoint_callback,
        accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
        val_check_interval=VAL_CHECK_INTERVAL,
        gpus=1)
    # fit's train_dataloader= keyword likewise reflects the older API
    # (renamed train_dataloaders= in later pytorch_lightning releases).
    trainer.fit(
        model=model,
        train_dataloader=dataloader['train'],
        val_dataloaders=dataloader['dev'])


if __name__ == '__main__':
    fire.Fire(main)
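
# Example invocation (the dataset name below is hypothetical; python-fire
# exposes main()'s parameters as command-line arguments):
#
#   python finetune.py --dataset_name=qags_xsum --n_epochs=2
#
# The best checkpoint by validation F1 is written under ckpts/<dataset_name>/.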