From b68beb47e22e3feeec76452295b7199c6cf9ded9 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 5 Aug 2020 11:34:31 -0700 Subject: [PATCH] Don't save model twice, copy instead (#4302) * Don't save model twice, copy instead * narrower exception --- ml-agents/mlagents/model_serialization.py | 18 ++++++++++++++++++ .../mlagents/trainers/trainer/rl_trainer.py | 8 +++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/ml-agents/mlagents/model_serialization.py b/ml-agents/mlagents/model_serialization.py index edc7a5f6ee..11714c3ec2 100644 --- a/ml-agents/mlagents/model_serialization.py +++ b/ml-agents/mlagents/model_serialization.py @@ -1,5 +1,6 @@ from distutils.util import strtobool import os +import shutil from typing import Any, List, Set, NamedTuple from distutils.version import LooseVersion @@ -227,3 +228,20 @@ def _enforce_onnx_conversion() -> bool: return strtobool(val) except Exception: return False + + +def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None: + """ + Copy the .nn file at the given source to the destination. + Also copies the corresponding .onnx file if it exists. + """ + shutil.copyfile(source_nn_path, destination_nn_path) + logger.info(f"Copied {source_nn_path} to {destination_nn_path}.") + # Copy the onnx file if it exists + source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx" + destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx" + try: + shutil.copyfile(source_onnx_path, destination_onnx_path) + logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.") + except OSError: + pass diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py index 2ab44bcf1f..9765ab17be 100644 --- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py +++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py @@ -5,7 +5,7 @@ import abc import time import attr -from mlagents.model_serialization import SerializationSettings +from mlagents.model_serialization import SerializationSettings, copy_model_files from mlagents.trainers.policy.checkpoint_manager import ( NNCheckpoint, NNCheckpointManager, @@ -131,12 +131,14 @@ def save_model(self) -> None: "Trainer has multiple policies, but default behavior only saves the first." ) policy = list(self.policies.values())[0] - settings = SerializationSettings(policy.model_path, self.brain_name) model_checkpoint = self._checkpoint() + + # Copy the checkpointed model files to the final output location + copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn") + final_checkpoint = attr.evolve( model_checkpoint, file_path=f"{policy.model_path}.nn" ) - policy.save(policy.model_path, settings) NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint) @abc.abstractmethod