# nlp_finetuning_lightning_google.py
# Start : Import the packages
import pandas as pd
import os
import pathlib
import zipfile
import wget
import gdown
import torch
from torch import nn
from torch import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import random_split
from torch.utils.data import RandomSampler
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import AutoConfig
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
# End : Import the packages
# %%
# Remark: pl.LightningModule is derived from torch.nn.Module. It has additional methods that are part of the
# lightning interface and that need to be defined by the user. Having these additional methods is very useful
# for several reasons:
# 1. Reasonable Expectations: Once you know the pytorch-lightning system you can more easily read other people's
#    code, because it is always structured in the same way
# 2. Less Boilerplate: The additional class methods make pl.LightningModule more powerful than the nn.Module
#    from plain pytorch. This means you have to write less of the repetitive boilerplate code
# 3. Perfect for the development lifecycle: PyTorch Lightning makes it very easy to switch from cpu to gpu/tpu.
#    Further, it supplies methods to quickly run your code on a fraction of the data, which is very useful in
#    the development process, especially for debugging
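# To make the division of labour concrete, here is a schematic (not the actual library code) of what
# pl.Trainer does with the methods defined below; treat it as a mental model only:
#
#     for epoch in range(max_epochs):
#         for batch in datamodule.train_dataloader():
#             loss = model.training_step(batch, batch_nb)['loss']
#             loss.backward(); optimizer.step(); optimizer.zero_grad()
#         for batch in datamodule.val_dataloader():
#             model.validation_step(batch, batch_nb)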
class Model(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # a very useful feature of pytorch lightning which leads to the named variables that are passed in
        # being available as self.hparams.<variable_name> We use this when referring to e.g.
        # self.hparams.learning_rate
        # freeze
        self._frozen = False
        # e.g. https://github.com/stefan-it/turkish-bert/issues/5
        config = AutoConfig.from_pretrained(self.hparams.pretrained,
                                            num_labels=5,  # num_labels=1 would imply regression
                                            output_attentions=False,
                                            output_hidden_states=False)
        print(config)
        A = AutoModelForSequenceClassification
        self.model = A.from_pretrained(self.hparams.pretrained, config=config)
        print('Model Type', type(self.model))
        # Possible choices for pretrained are:
        # distilbert-base-uncased
        # bert-base-uncased
        #
        # The BERT paper says: "[The] pre-trained BERT model can be fine-tuned with just one additional output
        # layer to create state-of-the-art models for a wide range of tasks, such as question answering and
        # language inference, without substantial task-specific architecture modifications."
        #
        # Huggingface/transformers provides access to such pretrained model versions, some of which have been
        # published by various community members.
        #
        # BertForSequenceClassification is one of those pretrained models, which is loaded automatically by
        # AutoModelForSequenceClassification because it corresponds to the pretrained weights of
        # "bert-base-uncased".
        #
        # Huggingface says about BertForSequenceClassification: "Bert Model transformer with a sequence
        # classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE
        # tasks."
        #
        # This part is easy: we instantiate the pretrained model (checkpoint).
        # But it's also incredibly important: e.g. by using "bert-base-uncased", we determine that the model
        # does not distinguish between lower and upper case. This might have a significant impact on model
        # performance!
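        # For example (an illustrative invocation, not a recommendation), a case-sensitive checkpoint can be
        # selected without touching the code, because "pretrained" is a command line argument defined below:
        #
        #     python nlp_finetuning_lightning_google.py --pretrained bert-base-cased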
    def forward(self, batch):
        # There are some choices as to how you can define the input to the forward function. I prefer it this
        # way, where the batch contains the input_ids, the attention mask and sometimes the labels (for
        # training).
        b_input_ids = batch[0]
        b_input_mask = batch[1]
        has_labels = len(batch) > 2
        b_labels = batch[2] if has_labels else None
        res = self.model(b_input_ids,
                         attention_mask=b_input_mask,
                         labels=b_labels)
        # If there are labels in the batch, this indicates training for the BertForSequenceClassification
        # model: the model output then contains both the training loss and the logits.
        if has_labels:
            loss, logits = res['loss'], res['logits']
        # If there are no labels in the batch, this indicates prediction for the BertForSequenceClassification
        # model: the model output then simply contains the logits.
        if not has_labels:
            loss, logits = None, res['logits']
        return loss, logits
    def training_step(self, batch, batch_nb):
        # training_step is a (virtual) method, specified in the interface, that the pl.LightningModule
        # class requires you to overwrite. This we do here, by virtue of this definition.
        loss, logits = self(
            batch
        )  # self refers to the model, which in turn accesses the forward method
        self.log('train_loss', loss)
        # pytorch lightning allows you to use various logging facilities, e.g. tensorboard. With tensorboard we
        # can track and easily visualise the progress of training. In this case we log the training loss.
        return {'loss': loss}
        # the training_step method expects either a dictionary or the loss itself as a return value

    def validation_step(self, batch, batch_nb):
        # validation_step is a (virtual) method, specified in the interface, that the pl.LightningModule
        # class wants you to overwrite, in case you want to do validation. This we do here, by virtue of this
        # definition.
        loss, logits = self(batch)
        # self refers to the model, which in turn accesses the forward method
        # Apart from the validation loss, we also want to track validation accuracy to get an idea of what the
        # model training has achieved "in real terms".
        labels = batch[2]
        predictions = torch.argmax(logits, dim=1)
        accuracy = (labels == predictions).float().mean()
        self.log('val_loss', loss)
        self.log('accuracy', accuracy)
        # the validation_step method expects a dictionary, which should at least contain the val_loss
        return {'val_loss': loss, 'val_accuracy': accuracy}
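        # A tiny worked example of the accuracy expression above (illustrative numbers only):
        # labels = tensor([4, 0, 2]), predictions = tensor([4, 1, 2])
        # -> (labels == predictions) = tensor([True, False, True]) -> .float().mean() = 0.6667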
    def validation_epoch_end(self, validation_step_outputs):
        # OPTIONAL. The second parameter of validation_epoch_end - we named it validation_step_outputs -
        # contains the outputs of validation_step, collected for all the batches over the entire epoch.
        # We use it to track progress of the entire epoch, by calculating averages.
        avg_loss = torch.stack([x['val_loss']
                                for x in validation_step_outputs]).mean()
        avg_accuracy = torch.stack(
            [x['val_accuracy'] for x in validation_step_outputs]).mean()
        tensorboard_logs = {'val_avg_loss': avg_loss, 'val_avg_accuracy': avg_accuracy}
        return {
            'val_loss': avg_loss,
            'log': tensorboard_logs,
            'progress_bar': {
                'avg_loss': avg_loss,
                'avg_accuracy': avg_accuracy
            }
        }
        # The validation_epoch_end method returns a dictionary, which should at least contain the val_loss.
        # We also use it to include the log - with the tensorboard logs. Further we define some values that
        # are displayed in the tqdm-based progress bar.
    def configure_optimizers(self):
        # configure_optimizers is a (virtual) method, specified in the interface, that the
        # pl.LightningModule class wants you to overwrite.
        # In this case we define that some parameters are optimized in a different way than others. In
        # particular we single out parameters that have 'bias' or 'LayerNorm.weight' in their names. For those
        # we do not use an optimization technique called weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.hparams.learning_rate,
                          eps=1e-8  # args.adam_epsilon - default is 1e-8.
                          )
        # We also use a learning rate scheduler that is supplied by transformers.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,  # Default value in run_glue.py
            num_training_steps=self.hparams.num_training_steps)
        return [optimizer], [scheduler]
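        # For illustration (parameter names are checkpoint-specific, so treat these as assumptions): with a
        # BERT-style encoder, a name such as 'model.bert.encoder.layer.0.attention.output.LayerNorm.weight'
        # or '...output.dense.bias' falls into the second group (no weight decay), while a weight matrix such
        # as '...output.dense.weight' falls into the first group (weight decay 0.01).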
    def freeze(self) -> None:
        # freeze all layers, except the final classifier layers
        for name, param in self.model.named_parameters():
            if 'classifier' not in name:  # classifier layer
                param.requires_grad = False
        self._frozen = True

    def unfreeze(self) -> None:
        if self._frozen:
            for name, param in self.model.named_parameters():
                if 'classifier' not in name:  # classifier layer
                    param.requires_grad = True
        self._frozen = False

    def on_epoch_start(self):
        """pytorch lightning hook"""
        if self.current_epoch < self.hparams.nr_frozen_epochs:
            self.freeze()
        if self.current_epoch >= self.hparams.nr_frozen_epochs:
            self.unfreeze()
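

# A minimal, self-contained sketch (defined here but never called) of how the Model class above can be
# smoke-tested outside the Trainer. The checkpoint name and hyperparameter values are illustrative
# assumptions that mirror the argparse defaults in the __main__ block below.
def _smoke_test_model():
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    model = Model(pretrained='distilbert-base-uncased',
                  learning_rate=2e-5,
                  num_training_steps=10,
                  nr_frozen_epochs=0)
    # Build one tiny batch in the same (input_ids, attention_mask, labels) layout used by Data.setup below.
    enc = tokenizer(['a great app', 'crashes all the time'],
                    padding='max_length',
                    truncation=True,
                    max_length=128,
                    return_tensors='pt')
    labels = torch.tensor([4, 0])  # scores 0..4, matching num_labels=5
    loss, logits = model((enc['input_ids'], enc['attention_mask'], labels))
    print('loss:', loss.item(), 'logits shape:', logits.shape)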
class Data(pl.LightningDataModule):
    # So here we finally arrive at the definition of our data class derived from pl.LightningDataModule.
    #
    # In earlier versions of pytorch lightning (prior to 0.9) the methods here were part of the model class
    # derived from pl.LightningModule. For better flexibility and readability the Data and Model related parts
    # were split out into two different classes:
    #
    # pl.LightningDataModule and pl.LightningModule
    #
    # with the Model related part remaining in pl.LightningModule
    #
    # This is explained in more detail in this video: https://www.youtube.com/watch?v=L---MBeSXFw
    def __init__(self, *args, **kwargs):
        super().__init__()
        # self.save_hyperparameters()
        if isinstance(args, tuple):
            args = args[0]
        self.hparams = args
        # cf this open issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/3232
        print('args:', args)
        print('kwargs:', kwargs)
        # print(f'self.hparams.pretrained:{self.hparams.pretrained}')
        print('Loading BERT tokenizer')
        print(f'PRETRAINED:{self.hparams.pretrained}')
        A = AutoTokenizer
        self.tokenizer = A.from_pretrained(self.hparams.pretrained)
        print('Type tokenizer:', type(self.tokenizer))
        # This part is easy: we instantiate the tokenizer.
        # But it's also incredibly important: e.g. by using "bert-base-uncased", we determine that before any
        # text is analysed it is all turned into lower case. This might have a significant impact on model
        # performance!
        #
        # BertTokenizer is the tokenizer, which is loaded automatically by AutoTokenizer because it was used
        # to train the model weights of "bert-base-uncased".
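        # For illustration (token ids depend on the checkpoint's vocabulary; with the uncased BERT/DistilBERT
        # vocabularies, 101 and 102 are the [CLS] and [SEP] ids):
        #
        #     self.tokenizer('Great app!')
        #     # -> {'input_ids': [101, ..., 102], 'attention_mask': [1, ..., 1]}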
    def prepare_data(self):
        # Even if you have a complicated setup, where you train on a cluster of multiple GPUs, prepare_data is
        # only run once on the cluster.
        # Typically - as done here - prepare_data just performs the time-consuming step of downloading the
        # data.
        print('Setting up dataset')
        prefix = 'https://drive.google.com/uc?id='
        id_apps = "1S6qMioqPJjyBLpLVz4gmRTnJHnjitnuV"
        id_reviews = "1zdmewp7ayS4js4VtrJEHzAheSW-5NBZv"
        pathlib.Path('./data').mkdir(parents=True, exist_ok=True)
        # Download the file (if we haven't already)
        if not os.path.exists('./data/apps.csv'):
            gdown.download(url=prefix + id_apps,
                           output='./data/apps.csv',
                           quiet=False)
        # Download the file (if we haven't already)
        if not os.path.exists('./data/reviews.csv'):
            gdown.download(url=prefix + id_reviews,
                           output='./data/reviews.csv',
                           quiet=False)
    def setup(self, stage=None):
        # Even if you have a complicated setup, where you train on a cluster of multiple GPUs, setup is run
        # once on every gpu of the cluster.
        # Typically - as done here - setup
        # - reads the previously downloaded data
        # - does some preprocessing such as tokenization
        # - splits the dataset into training and validation datasets
        # Load the dataset into a pandas dataframe.
        df = pd.read_csv("./data/reviews.csv", delimiter=',', header=0)
        if self.hparams.frac < 1:
            df = df.sample(frac=self.hparams.frac, random_state=0)
        # Shift the scores from 1..5 to 0..4, since the model expects class labels starting at 0.
        df['score'] -= 1
        # Report the number of sentences.
        print('Number of training sentences: {:,}\n'.format(df.shape[0]))
        # Get the lists of sentences and their labels.
        sentences = df.content.values
        labels = df.score.values
        t = self.tokenizer(
            sentences.tolist(),          # Sentences to encode.
            add_special_tokens=True,     # Add '[CLS]' and '[SEP]'
            max_length=128,              # Pad & truncate all sentences.
            padding='max_length',
            truncation=True,
            return_attention_mask=True,  # Construct attn. masks.
            return_tensors='pt'          # Return pytorch tensors.
        )
        # Convert the lists into tensors.
        input_ids = t['input_ids']
        attention_mask = t['attention_mask']
        labels = torch.tensor(labels)  # .float() if regression
        # Print sentence 0, now as a list of IDs:
        # print('Example')
        # print('Original: ', sentences[0])
        # print('Token IDs', input_ids[0])
        # print('End: Example')
        # Combine the training inputs into a TensorDataset.
        dataset = TensorDataset(input_ids, attention_mask, labels)
        # Create a 90-10 train-validation split.
        # Calculate the number of samples to include in each set.
        train_size = int(self.hparams.training_portion * len(dataset))
        val_size = len(dataset) - train_size
        print('{:>5,} training samples'.format(train_size))
        print('{:>5,} validation samples'.format(val_size))
        self.train_dataset, self.val_dataset = random_split(
            dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42))
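        # After this step, input_ids and attention_mask both have shape (num_reviews, 128) and labels has
        # shape (num_reviews,); the TensorDataset then yields (input_ids, attention_mask, label) triples,
        # which is the batch layout expected by Model.forward above.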
    def train_dataloader(self):
        # As explained above, train_dataloader was previously part of the model class derived from
        # pl.LightningModule. train_dataloader needs to return a DataLoader with the train_dataset.
        return DataLoader(
            self.train_dataset,  # The training samples.
            sampler=RandomSampler(
                self.train_dataset),  # Select batches randomly
            batch_size=self.hparams.batch_size  # Trains with this batch size.
        )

    def val_dataloader(self):
        # As explained above, val_dataloader was previously part of the model class derived from
        # pl.LightningModule. val_dataloader needs to return a DataLoader with the val_dataset.
        return DataLoader(
            self.val_dataset,  # The validation samples.
            sampler=RandomSampler(self.val_dataset),  # Select batches randomly
            batch_size=self.hparams.batch_size,  # Validates with this batch size.
            shuffle=False)
# %%
if __name__ == "__main__":
    # A few key aspects:
    # - pytorch lightning can add arguments to the parser automatically
    # - you can manually add your own specific arguments
    # - there is a little more code than seems necessary, because of a particular argument the scheduler
    #   needs. There is currently an open issue on this complication:
    #   https://github.com/PyTorchLightning/pytorch-lightning/issues/1038
    import argparse
    from argparse import ArgumentParser
    parser = ArgumentParser()
    # We use the very convenient Auto classes from huggingface. This way we can easily switch
    # between models and tokenizers, just by giving a different name of the pretrained model.
    #
    # BertForSequenceClassification is one of those pretrained models, which is loaded automatically by
    # AutoModelForSequenceClassification because it corresponds to the pretrained weights of
    # "bert-base-uncased".
    # Similarly BertTokenizer is one of those tokenizers, which is loaded automatically by AutoTokenizer
    # because it is the necessary tokenizer for the pretrained weights of "bert-base-uncased".
    parser.add_argument('--pretrained', type=str, default="distilbert-base-uncased")
    parser.add_argument('--nr_frozen_epochs', type=int, default=5)
    parser.add_argument('--training_portion', type=float, default=0.9)
    parser.add_argument('--batch_size', type=int, default=32)  # DataLoader expects an integer batch size
    parser.add_argument('--learning_rate', type=float, default=2e-5)
    parser.add_argument('--frac', type=float, default=1)
    # parser = Model.add_model_specific_args(parser)
    # parser = Data.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    # TODO start: remove this later
    # args.limit_train_batches = 10  # TODO remove this later
    # args.limit_val_batches = 5  # TODO remove this later
    # args.frac = 0.01  # TODO remove this later
    # args.fast_dev_run = True  # TODO remove this later
    # args.max_epochs = 2  # TODO remove this later
    logger = TensorBoardLogger(
        save_dir=os.getcwd(),
        version=1,
        name='lightning_logs')
    args.logger = logger  # Trainer.from_argparse_args picks this up from the args namespace
    # TODO end: remove this later
    # start : get training steps
    d = Data(args)
    d.prepare_data()
    d.setup()
    args.num_training_steps = len(d.train_dataloader()) * args.max_epochs
    # end : get training steps
    dict_args = vars(args)
    m = Model(**dict_args)
    args.early_stop_callback = EarlyStopping('val_loss')
    trainer = pl.Trainer.from_argparse_args(args)
    # fit the model to the data
    trainer.fit(m, d)
# %%
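# Example invocation from the command line (flag values are illustrative, not a recommendation; --gpus and
# --max_epochs come from the Trainer arguments added by pl.Trainer.add_argparse_args):
#
#     python nlp_finetuning_lightning_google.py --pretrained distilbert-base-uncased \
#         --batch_size 32 --learning_rate 2e-5 --frac 0.1 --max_epochs 3 --gpus 1
#
# Training progress can then be inspected with:
#
#     tensorboard --logdir lightning_logs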