Add save and load of checkpoints during training
tanaysoni committed Jan 10, 2020
1 parent fc824ff commit 096de30
Showing 3 changed files with 145 additions and 19 deletions.
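Taken together, the new arguments are meant to be used roughly as follows. This is a minimal sketch, assuming model, optimizer, warmup_linear, data_silo, device and n_gpu have been set up as in the usual FARM example scripts; the directory name and the numeric values are placeholders, not part of this commit.

from pathlib import Path
from farm.train import Trainer

trainer = Trainer.create_or_load_from_checkpoint(
    data_silo=data_silo,
    checkpointing_dir=Path("saved_checkpoints/my_run"),  # placeholder path
    model=model,
    optimizer=optimizer,
    epochs=2,
    n_gpu=n_gpu,
    warmup_linear=warmup_linear,
    evaluate_every=200,
    device=device,
    checkpoint_every=500,   # save a checkpoint every 500 train steps
    save_on_sigkill=True,   # save a checkpoint when the process receives SIGTERM
)
model = trainer.train()

If checkpointing_dir already contains checkpoints, create_or_load_from_checkpoint restores the latest one and ignores the remaining keyword arguments; otherwise it builds a fresh Trainer from them.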
4 changes: 2 additions & 2 deletions farm/data_handler/data_silo.py
@@ -41,7 +41,7 @@ def __init__(
automatic_loading=True,
max_multiprocessing_chunksize=2000,
max_processes=128,
checkpointing=False,
caching=False,
):
"""
:param processor: A dataset specific Processor object which will turn input (file or dict) into a Pytorch Dataset.
@@ -67,7 +67,7 @@ def __init__(
self.max_multiprocessing_chunksize = max_multiprocessing_chunksize

loaded_from_cache = False
if checkpointing: # Check if DataSets are present in cache
if caching: # Check if DataSets are present in cache
checksum = self._get_checksum()
dataset_path = Path(f"cache/data_silo/{checksum}")

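For reference, the renamed flag is used on the DataSilo side like this. A small sketch with a placeholder processor and batch size; when caching=True the silo checksums its configuration and reuses preprocessed datasets from cache/data_silo/<checksum> if they exist.

data_silo = DataSilo(
    processor=processor,  # placeholder: any FARM Processor instance
    batch_size=32,        # placeholder value
    caching=True,         # formerly `checkpointing`: load cached datasets when available
)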
149 changes: 132 additions & 17 deletions farm/train.py
@@ -1,10 +1,12 @@
from __future__ import absolute_import, division, print_function

import logging
import sys
import torch
from tqdm import tqdm

from farm.utils import MLFlowLogger as MlLogger
from farm.utils import GracefulKiller
from farm.eval import Evaluator
from farm.data_handler.data_silo import DataSilo
from farm.visual.ascii.images import GROWING_TREE
@@ -130,6 +132,7 @@ class Trainer:

def __init__(
self,
model,
optimizer,
data_silo,
epochs,
@@ -143,6 +146,11 @@ def __init__(
grad_acc_steps=1,
local_rank=-1,
early_stopping=None,
save_on_sigkill=False,
checkpoint_every=None,
checkpointing_dir=None,
from_step=1,
from_epoch=1
):
"""
:param optimizer: An optimizer object that determines the learning strategy to be used during training
@@ -175,6 +183,7 @@ def __init__(
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
:type early_stopping: EarlyStopping
"""
self.model = model
self.data_silo = data_silo
self.epochs = int(epochs)
self.optimizer = optimizer
@@ -190,6 +199,20 @@ def __init__(
self.local_rank = local_rank
self.log_params()
self.early_stopping = early_stopping
self.save_on_sigkill = save_on_sigkill
if save_on_sigkill:
self.sigkill_handler = GracefulKiller()
else:
self.sigkill_handler = None
self.checkpointing_dir = checkpointing_dir
self.checkpoint_every = checkpoint_every
if self.checkpoint_every and not checkpointing_dir:
raise Exception("checkpoint_path needs to be supplied when using checkpoint_every.")
if save_on_sigkill and not checkpointing_dir:
raise Exception("checkpoint_path needs to be supplied when using save_on_sigkill.")

self.from_epoch = from_epoch
self.from_step = from_step

# evaluator on dev set
if evaluator_dev is None and self.data_silo.get_data_loader("dev"):
@@ -209,39 +232,129 @@ def __init__(
)
self.evaluator_test = evaluator_test

def train(self, model):
@classmethod
def create_or_load_from_checkpoint(cls, data_silo, checkpointing_dir, **kwargs):
if checkpointing_dir.exists():
dirs = [d for d in checkpointing_dir.iterdir() if d.is_dir()]
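            # checkpoint folders are named epoch_<n>_step_<m>; the folder with the largest n * m is treated as the latest checkpoint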
total_steps_at_checkpoints = []
for d in dirs:
                epoch, step = [int(s) for s in d.name.split("_") if s.isdigit()]
total_steps_at_checkpoints.append((d, epoch * step))
total_steps_at_checkpoints.sort(key=lambda tup: tup[1], reverse=True)
latest_checkpoint_path = total_steps_at_checkpoints[0][0]

trainer = cls.load_from_checkpoint(checkpoint_path=latest_checkpoint_path, data_silo=data_silo)
logging.info(f"Resuming training from the latest train checkpoint at {latest_checkpoint_path} ...")
else:
logging.info(f"No train checkpoints found. Starting a new training ...")
trainer = Trainer(data_silo=data_silo, checkpointing_dir=checkpointing_dir, **kwargs)
return trainer

@classmethod
def load_from_checkpoint(cls, checkpoint_path, data_silo):
if checkpoint_path.exists():

trainer_checkpoint = torch.load(checkpoint_path / "trainer")
trainer_state_dict = trainer_checkpoint["trainer_state_dict"]

device = trainer_state_dict["device"]
model = AdaptiveModel.load(load_dir=checkpoint_path, device=device)

optimizer = trainer_checkpoint["optimizer"]
optimizer.log_learning_rate = False

trainer = Trainer(
data_silo=data_silo,
model=model,
**trainer_state_dict
)
return trainer

def save(self):
checkpoint_dir = self.checkpointing_dir / f"epoch_{self.from_epoch}_step_{self.from_step}"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

trainer_state_dict = self.get_state_dict()
self.model.save(checkpoint_dir)
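        # bundle the trainer state, the model weights and the optimizer object into a single "trainer" file next to the saved AdaptiveModel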
torch.save({
"trainer_state_dict": trainer_state_dict,
'model_state_dict': self.model.state_dict(),
"optimizer": self.optimizer,
}, checkpoint_dir / "trainer")

# TODO custom defined evaluators are not saved in the checkpoint.
logger.info(f"Saved a training checkpoint at {checkpoint_dir}")

def get_state_dict(self):
state_dict = {
"optimizer": self.optimizer,
"warmup_linear": self.warmup_linear,
"evaluate_every": self.evaluate_every,
"n_gpu": self.n_gpu,
"grad_acc_steps": self.grad_acc_steps,
"device": self.device,
"local_rank": self.local_rank,
"early_stopping": self.early_stopping,
"fp16": self.fp16,
"epochs": self.epochs,
"save_on_sigkill": self.save_on_sigkill,
"checkpointing_dir": self.checkpointing_dir,
"checkpoint_every": self.checkpoint_every,
"from_epoch": self.from_epoch,
"from_step": self.from_step,
}

return state_dict

def train(self):
""" Perform the training procedure. """

# connect the prediction heads with the right output from processor
model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)

# Check that the tokenizer fits the language model
model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer))
self.model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer))

logger.info(f"\n {GROWING_TREE}")
model.train()
self.model.train()
# multi GPU + distributed settings
if self.fp16:
model.half()
self.model.half()
if self.local_rank > -1:
model = WrappedDDP(model)
self.model = WrappedDDP(self.model)
elif self.n_gpu > 1:
model = WrappedDataParallel(model)
self.model = WrappedDataParallel(self.model)

do_stopping = False
evalnr = 0
loss = 0
for epoch in range(1, self.epochs + 1):
progress_bar = tqdm(self.data_loader_train)
for step, batch in enumerate(progress_bar):

resume_from_step = self.from_step

for epoch in range(self.from_epoch, self.epochs + 1):
            progress_bar = tqdm(self.data_loader_train)  # note: a shuffled train DataLoader will not replay the exact same batches when resuming
for step, batch in enumerate(progress_bar, start=1):
# when resuming training from a checkpoint, we want to fast forward to the step of the checkpoint
                if resume_from_step and step <= resume_from_step:
if resume_from_step == step:
resume_from_step = None
continue

if self.sigkill_handler and self.sigkill_handler.kill_now: # save the current state as a checkpoint
self.save()
sys.exit(0)

                if self.checkpoint_every and step % self.checkpoint_every == 0:  # save a checkpoint and continue training
self.save()

progress_bar.set_description(f"Train epoch {epoch}/{self.epochs} (Cur. train loss: {loss:.4f})")

# Move batch of samples to device
batch = {key: batch[key].to(self.device) for key in batch}

# Forward pass through model
logits = model.forward(**batch)
per_sample_loss = model.logits_to_loss(logits=logits, **batch)
logits = self.model.forward(**batch)
per_sample_loss = self.model.logits_to_loss(logits=logits, **batch)

loss = self.backward_propagate(per_sample_loss, step)

@@ -251,37 +364,39 @@ def train(self, model):
self.global_step % self.evaluate_every == 0
):
evalnr += 1
result = self.evaluator_dev.eval(model)
result = self.evaluator_dev.eval(self.model)
self.evaluator_dev.log_results(result, "Dev", self.global_step)
if self.early_stopping:
do_stopping, save_model, eval_value = self.early_stopping.check_stopping(result)
if save_model:
logger.info(
"Saving current best model to {}, eval={}".format(
self.early_stopping.save_dir, eval_value))
model.save(self.early_stopping.save_dir)
self.model.save(self.early_stopping.save_dir)
self.data_silo.processor.save(self.early_stopping.save_dir)
if do_stopping:
# log the stopping
logger.info("STOPPING EARLY AT EPOCH {}, STEP {}, EVALUATION {}".format(epoch, step, evalnr))
if do_stopping:
break
self.global_step += 1
self.from_step = step
self.from_epoch = epoch
if do_stopping:
break

# With early stopping we want to restore the best model
if self.early_stopping and self.early_stopping.save_dir:
logger.info("Restoring best model so far from {}".format(self.early_stopping.save_dir))
lm_name = model.language_model.name
lm_name = self.model.language_model.name
            self.model = AdaptiveModel.load(self.early_stopping.save_dir, self.device, lm_name=lm_name)
            self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)

# Eval on test set
if self.evaluator_test:
result = self.evaluator_test.eval(model)
result = self.evaluator_test.eval(self.model)
self.evaluator_test.log_results(result, "Test", self.global_step)
return model
return self.model

def backward_propagate(self, loss, step):
loss = self.adjust_loss(loss)
11 changes: 11 additions & 0 deletions farm/utils.py
@@ -2,6 +2,7 @@
import json
import logging
import random
import signal

import numpy as np
import torch
@@ -252,3 +253,13 @@ def get_dict_checksum(payload_dict):
"""
checksum = hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()
return checksum


class GracefulKiller:
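    # kill_now is flipped by the SIGTERM handler and polled in Trainer.train() when save_on_sigkill is enabled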
kill_now = False

def __init__(self):
signal.signal(signal.SIGTERM, self.exit_gracefully)

def exit_gracefully(self, signum, frame):
self.kill_now = True

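With save_on_sigkill enabled, a running training job can be asked to checkpoint and exit cleanly by sending it a SIGTERM. A sketch, assuming the worker's process id is known from your job runner; training_pid is a placeholder:

import os
import signal

# GracefulKiller catches the SIGTERM, Trainer.train() saves a checkpoint and calls sys.exit(0)
os.kill(training_pid, signal.SIGTERM)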