From 096de3010a46c8556607f92e1370320acca2f503 Mon Sep 17 00:00:00 2001
From: Tanay Soni
Date: Fri, 10 Jan 2020 11:45:53 +0100
Subject: [PATCH] Add save and load of checkpoints during training

---
 farm/data_handler/data_silo.py |   4 +-
 farm/train.py                  | 149 +++++++++++++++++++++++++++++----
 farm/utils.py                  |  11 +++
 3 files changed, 145 insertions(+), 19 deletions(-)

diff --git a/farm/data_handler/data_silo.py b/farm/data_handler/data_silo.py
index 5910430a9..47aa304d4 100644
--- a/farm/data_handler/data_silo.py
+++ b/farm/data_handler/data_silo.py
@@ -41,7 +41,7 @@ def __init__(
         automatic_loading=True,
         max_multiprocessing_chunksize=2000,
         max_processes=128,
-        checkpointing=False,
+        caching=False,
     ):
         """
         :param processor: A dataset specific Processor object which will turn input (file or dict) into a Pytorch Dataset.
@@ -67,7 +67,7 @@ def __init__(
         self.max_multiprocessing_chunksize = max_multiprocessing_chunksize
 
         loaded_from_cache = False
-        if checkpointing:  # Check if DataSets are present in cache
+        if caching:  # Check if DataSets are present in cache
             checksum = self._get_checksum()
             dataset_path = Path(f"cache/data_silo/{checksum}")
 
diff --git a/farm/train.py b/farm/train.py
index 6dae1af6a..d11776120 100644
--- a/farm/train.py
+++ b/farm/train.py
@@ -1,10 +1,12 @@
 from __future__ import absolute_import, division, print_function
 
 import logging
+import sys
 
 import torch
 from tqdm import tqdm
 from farm.utils import MLFlowLogger as MlLogger
+from farm.utils import GracefulKiller
 from farm.eval import Evaluator
 from farm.data_handler.data_silo import DataSilo
 from farm.visual.ascii.images import GROWING_TREE
@@ -130,6 +132,7 @@ class Trainer:
 
     def __init__(
         self,
+        model,
         optimizer,
         data_silo,
         epochs,
@@ -143,6 +146,11 @@ def __init__(
         grad_acc_steps=1,
         local_rank=-1,
         early_stopping=None,
+        save_on_sigkill=False,
+        checkpoint_every=None,
+        checkpointing_dir=None,
+        from_step=1,
+        from_epoch=1
     ):
         """
         :param optimizer: An optimizer object that determines the learning strategy to be used during training
@@ -175,6 +183,7 @@ def __init__(
         :param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
         :type early_stopping: EarlyStopping
         """
+        self.model = model
         self.data_silo = data_silo
         self.epochs = int(epochs)
         self.optimizer = optimizer
@@ -190,6 +199,20 @@ def __init__(
         self.local_rank = local_rank
         self.log_params()
         self.early_stopping = early_stopping
+        self.save_on_sigkill = save_on_sigkill
+        if save_on_sigkill:
+            self.sigkill_handler = GracefulKiller()
+        else:
+            self.sigkill_handler = None
+        self.checkpointing_dir = checkpointing_dir
+        self.checkpoint_every = checkpoint_every
+        if self.checkpoint_every and not checkpointing_dir:
+            raise Exception("checkpointing_dir needs to be supplied when using checkpoint_every.")
+        if save_on_sigkill and not checkpointing_dir:
+            raise Exception("checkpointing_dir needs to be supplied when using save_on_sigkill.")
+
+        self.from_epoch = from_epoch
+        self.from_step = from_step
 
         # evaluator on dev set
         if evaluator_dev is None and self.data_silo.get_data_loader("dev"):
@@ -209,39 +232,129 @@ def __init__(
             )
         self.evaluator_test = evaluator_test
 
-    def train(self, model):
+    @classmethod
+    def create_or_load_from_checkpoint(cls, data_silo, checkpointing_dir, **kwargs):
+        if checkpointing_dir.exists():
+            dirs = [d for d in checkpointing_dir.iterdir() if d.is_dir()]
+            total_steps_at_checkpoints = []
+            for d in dirs:
+                epoch, step = [int(s) for s in str(d).split("_") if s.isdigit()]
+                total_steps_at_checkpoints.append((d, epoch * step))
+            total_steps_at_checkpoints.sort(key=lambda tup: tup[1], reverse=True)
+            latest_checkpoint_path = total_steps_at_checkpoints[0][0]
+
+            trainer = cls.load_from_checkpoint(checkpoint_path=latest_checkpoint_path, data_silo=data_silo)
+            logger.info(f"Resuming training from the latest train checkpoint at {latest_checkpoint_path} ...")
+        else:
+            logger.info("No train checkpoints found. Starting a new training ...")
+            trainer = Trainer(data_silo=data_silo, checkpointing_dir=checkpointing_dir, **kwargs)
+        return trainer
+
+    @classmethod
+    def load_from_checkpoint(cls, checkpoint_path, data_silo):
+        if checkpoint_path.exists():
+
+            trainer_checkpoint = torch.load(checkpoint_path / "trainer")
+            trainer_state_dict = trainer_checkpoint["trainer_state_dict"]
+
+            device = trainer_state_dict["device"]
+            model = AdaptiveModel.load(load_dir=checkpoint_path, device=device)
+
+            optimizer = trainer_checkpoint["optimizer"]
+            optimizer.log_learning_rate = False
+
+            trainer = Trainer(
+                data_silo=data_silo,
+                model=model,
+                **trainer_state_dict
+            )
+            return trainer
+
+    def save(self):
+        checkpoint_dir = self.checkpointing_dir / f"epoch_{self.from_epoch}_step_{self.from_step}"
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        trainer_state_dict = self.get_state_dict()
+        self.model.save(checkpoint_dir)
+        torch.save({
+            "trainer_state_dict": trainer_state_dict,
+            "model_state_dict": self.model.state_dict(),
+            "optimizer": self.optimizer,
+        }, checkpoint_dir / "trainer")
+
+        # TODO custom defined evaluators are not saved in the checkpoint.
+        logger.info(f"Saved a training checkpoint at {checkpoint_dir}")
+
+    def get_state_dict(self):
+        state_dict = {
+            "optimizer": self.optimizer,
+            "warmup_linear": self.warmup_linear,
+            "evaluate_every": self.evaluate_every,
+            "n_gpu": self.n_gpu,
+            "grad_acc_steps": self.grad_acc_steps,
+            "device": self.device,
+            "local_rank": self.local_rank,
+            "early_stopping": self.early_stopping,
+            "fp16": self.fp16,
+            "epochs": self.epochs,
+            "save_on_sigkill": self.save_on_sigkill,
+            "checkpointing_dir": self.checkpointing_dir,
+            "checkpoint_every": self.checkpoint_every,
+            "from_epoch": self.from_epoch,
+            "from_step": self.from_step,
+        }
+
+        return state_dict
+
+    def train(self):
         """
         Perform the training procedure.
         """
         # connect the prediction heads with the right output from processor
-        model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
+        self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
 
         # Check that the tokenizer fits the language model
-        model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer))
+        self.model.verify_vocab_size(vocab_size=len(self.data_silo.processor.tokenizer))
 
         logger.info(f"\n {GROWING_TREE}")
 
-        model.train()
+        self.model.train()
 
         # multi GPU + distributed settings
         if self.fp16:
-            model.half()
+            self.model.half()
         if self.local_rank > -1:
-            model = WrappedDDP(model)
+            self.model = WrappedDDP(self.model)
         elif self.n_gpu > 1:
-            model = WrappedDataParallel(model)
+            self.model = WrappedDataParallel(self.model)
 
         do_stopping = False
         evalnr = 0
         loss = 0
-        for epoch in range(1, self.epochs + 1):
-            progress_bar = tqdm(self.data_loader_train)
-            for step, batch in enumerate(progress_bar):
+
+        resume_from_step = self.from_step
+
+        for epoch in range(self.from_epoch, self.epochs + 1):
+            progress_bar = tqdm(self.data_loader_train)  # the data loader restarts from the beginning; resumed runs fast forward below
+            for step, batch in enumerate(progress_bar, start=1):
+                # when resuming training from a checkpoint, fast forward to the step of the checkpoint
+                if resume_from_step and step <= resume_from_step:
+                    if step == resume_from_step:
+                        resume_from_step = None
+                    continue
+
+                if self.sigkill_handler and self.sigkill_handler.kill_now:  # save the current state as a checkpoint
+                    self.save()
+                    sys.exit(0)
+
+                if self.checkpoint_every and step % self.checkpoint_every == 0:  # save a checkpoint and continue training
+                    self.save()
+
                 progress_bar.set_description(f"Train epoch {epoch}/{self.epochs} (Cur. train loss: {loss:.4f})")
 
                 # Move batch of samples to device
                 batch = {key: batch[key].to(self.device) for key in batch}
 
                 # Forward pass through model
-                logits = model.forward(**batch)
-                per_sample_loss = model.logits_to_loss(logits=logits, **batch)
+                logits = self.model.forward(**batch)
+                per_sample_loss = self.model.logits_to_loss(logits=logits, **batch)
 
                 loss = self.backward_propagate(per_sample_loss, step)
 
@@ -251,7 +364,7 @@ def train(self, model):
                     self.global_step % self.evaluate_every == 0
                 ):
                     evalnr += 1
-                    result = self.evaluator_dev.eval(model)
+                    result = self.evaluator_dev.eval(self.model)
                     self.evaluator_dev.log_results(result, "Dev", self.global_step)
                     if self.early_stopping:
                         do_stopping, save_model, eval_value = self.early_stopping.check_stopping(result)
@@ -259,7 +372,7 @@ def train(self, model):
                             logger.info(
                                 "Saving current best model to {}, eval={}".format(
                                     self.early_stopping.save_dir, eval_value))
-                            model.save(self.early_stopping.save_dir)
+                            self.model.save(self.early_stopping.save_dir)
                             self.data_silo.processor.save(self.early_stopping.save_dir)
                         if do_stopping:
                             # log the stopping
@@ -267,21 +380,23 @@ def train(self, model):
                 if do_stopping:
                     break
                 self.global_step += 1
+                self.from_step = step
+                self.from_epoch = epoch
             if do_stopping:
                 break
 
         # With early stopping we want to restore the best model
         if self.early_stopping and self.early_stopping.save_dir:
             logger.info("Restoring best model so far from {}".format(self.early_stopping.save_dir))
-            lm_name = model.language_model.name
-            model = AdaptiveModel.load(self.early_stopping.save_dir, self.device, lm_name=lm_name)
-            model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
+            lm_name = self.model.language_model.name
+            self.model = AdaptiveModel.load(self.early_stopping.save_dir, self.device, lm_name=lm_name)
+            self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
 
         # Eval on test set
         if self.evaluator_test:
-            result = self.evaluator_test.eval(model)
+            result = self.evaluator_test.eval(self.model)
             self.evaluator_test.log_results(result, "Test", self.global_step)
-        return model
+        return self.model
 
     def backward_propagate(self, loss, step):
         loss = self.adjust_loss(loss)
diff --git a/farm/utils.py b/farm/utils.py
index 2a007a5b7..fe166532d 100644
--- a/farm/utils.py
+++ b/farm/utils.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import random
+import signal
 
 import numpy as np
 import torch
@@ -252,3 +253,13 @@ def get_dict_checksum(payload_dict):
     """
     checksum = hashlib.md5(json.dumps(payload_dict, sort_keys=True).encode("utf-8")).hexdigest()
     return checksum
+
+
+class GracefulKiller:
+    kill_now = False
+
+    def __init__(self):
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        self.kill_now = True
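
Usage note (illustrative, not part of the patch): the sketch below shows one way the checkpointing API introduced above could be driven from a training script. It assumes a processor, AdaptiveModel, optimizer, and device have already been built the usual FARM way and are passed in by the caller; the helper name train_with_checkpoints, the checkpoint directory, and the batch size, epoch, and interval values are placeholders.

# Sketch only: the non-checkpoint setup (processor, model, optimizer, device)
# is deliberately left to the caller and not shown here.
from pathlib import Path

from farm.data_handler.data_silo import DataSilo
from farm.train import Trainer


def train_with_checkpoints(processor, model, optimizer, device, n_gpu=1):
    """Run training that can be resumed after a crash or a SIGTERM."""
    checkpointing_dir = Path("saved_models/train_checkpoints")  # placeholder path

    # caching=True (renamed from `checkpointing` in this patch) reuses
    # preprocessed datasets found in the cache directory
    data_silo = DataSilo(processor=processor, batch_size=32, caching=True)

    # Resumes from the newest checkpoint in checkpointing_dir if one exists,
    # otherwise builds a fresh Trainer from the keyword arguments.
    trainer = Trainer.create_or_load_from_checkpoint(
        data_silo=data_silo,
        checkpointing_dir=checkpointing_dir,
        model=model,
        optimizer=optimizer,
        epochs=2,
        n_gpu=n_gpu,
        device=device,
        checkpoint_every=500,   # save a checkpoint every 500 train steps
        save_on_sigkill=True,   # also checkpoint when a SIGTERM arrives
    )

    # train() now reads the model from the Trainer instead of taking it as an argument
    return trainer.train()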