diff --git a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py index 71472227a249..46ea6ab3947a 100644 --- a/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py +++ b/python/ray/train/examples/pytorch/tune_cifar_torch_pbt_example.py @@ -83,7 +83,7 @@ def train_func(config): checkpoint_dict = session.get_checkpoint().to_dict() # Load in model - model_state = checkpoint_dict["model_state_dict"] + model_state = checkpoint_dict["model"] model.load_state_dict(model_state) # Load in optimizer @@ -146,7 +146,7 @@ def train_func(config): checkpoint = Checkpoint.from_dict( { "epoch": epoch, - "model_state_dict": model.state_dict(), + "model": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), } )