From 717ebce0d24a344b47a63bdb35b62f8de3e0e06d Mon Sep 17 00:00:00 2001 From: iliaschair Date: Sat, 16 Nov 2024 16:14:00 +0100 Subject: [PATCH] add ema to BaseTrainer init --- src/fairchem/core/trainers/base_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index acd18d6e1..90cdce0e5 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -101,6 +101,7 @@ def __init__( self.cpu = cpu self.epoch = 0 self.step = 0 + self.ema = None if torch.cuda.is_available() and not self.cpu: logging.info(f"local rank base: {local_rank}") @@ -617,7 +618,7 @@ def load_checkpoint( "Loading checkpoint in inference-only mode, not loading keys associated with trainer state!" ) - if "ema" in checkpoint and checkpoint["ema"] is not None: + if "ema" in checkpoint and checkpoint["ema"] is not None and self.ema: self.ema.load_state_dict(checkpoint["ema"]) else: self.ema = None