diff --git a/farm/train.py b/farm/train.py
index f810facbc..3f8a5e28b 100644
--- a/farm/train.py
+++ b/farm/train.py
@@ -120,8 +120,8 @@ def __init__(
         checkpoint_on_sigkill=False,
         checkpoint_every=None,
         checkpoint_root_dir=None,
-        from_epoch=1,
-        from_step=1,
+        from_epoch=0,
+        from_step=0,
     ):
         """
         :param optimizer: An optimizer object that determines the learning strategy to be used during training
@@ -270,6 +270,9 @@ def _load_checkpoint(cls, path, data_silo):
         trainer_checkpoint = torch.load(path / "trainer")
         trainer_state_dict = trainer_checkpoint["trainer_state_dict"]
 
+        # Just setting seeds is not sufficient to have deterministic results when resuming
+        # training from a checkpoint. Additionally, the previous states of Random Number
+        # Generators also need to be restored from the saved checkpoint.
         numpy_rng_state = trainer_checkpoint["numpy_rng_state"]
         numpy.random.set_state(numpy_rng_state)
         rng_state = trainer_checkpoint["rng_state"]
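
The restore logic above only yields deterministic resumption if the checkpointing side captured the generator states in the first place. Below is a minimal standalone sketch of that save/restore pairing: the checkpoint keys "numpy_rng_state" and "rng_state" match the diff, but the helper names (save_rng_states, restore_rng_states) and the surrounding structure are illustrative assumptions, not FARM's actual code.

import numpy
import torch

def save_rng_states(trainer_state_dict):
    """Bundle the trainer state with the current RNG states before torch.save().

    Keys mirror the ones read back in _load_checkpoint in the diff above.
    """
    return {
        "trainer_state_dict": trainer_state_dict,
        "numpy_rng_state": numpy.random.get_state(),  # NumPy's global generator state
        "rng_state": torch.get_rng_state(),           # torch's CPU generator state
    }

def restore_rng_states(trainer_checkpoint):
    """Counterpart of the restore in _load_checkpoint: seeds alone are not
    enough, so the exact saved generator states are put back in place.
    """
    numpy.random.set_state(trainer_checkpoint["numpy_rng_state"])
    torch.set_rng_state(trainer_checkpoint["rng_state"])

# Usage sketch: capture states at checkpoint time, restore them after
# torch.load() when resuming, so the data order and dropout masks replay
# exactly as they would have in an uninterrupted run.
checkpoint = save_rng_states(trainer_state_dict={"global_step": 0})
restore_rng_states(checkpoint)

Note that on GPU runs the CUDA generators would also need the same treatment (e.g. torch.cuda.get_rng_state_all() / torch.cuda.set_rng_state_all()); the hunk shown here only covers the NumPy and torch CPU generators.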