diff --git a/ml-agents/mlagents/trainers/model_saver/model_saver.py b/ml-agents/mlagents/trainers/model_saver/model_saver.py index c88a358847..c1594ff08e 100644 --- a/ml-agents/mlagents/trainers/model_saver/model_saver.py +++ b/ml-agents/mlagents/trainers/model_saver/model_saver.py @@ -1,6 +1,6 @@ # # Unity ML-Agents Toolkit import abc -from typing import Any +from typing import Any, Tuple, List class BaseModelSaver(abc.ABC): @@ -34,11 +34,14 @@ def _register_optimizer(self, optimizer): pass @abc.abstractmethod - def save_checkpoint(self, behavior_name: str, step: int) -> str: + def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]: """ Checkpoints the policy on disk. :param checkpoint_path: filepath to write the checkpoint :param behavior_name: Behavior name of bevavior to be trained + :return: A Tuple of the path to the exported file, as well as a List of any + auxiliary files that were returned. For instance, an exported file would be Model.onnx, + and the auxiliary files would be [Model.pt] for PyTorch """ pass diff --git a/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py b/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py index eecf65b677..d9e1459eab 100644 --- a/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py +++ b/ml-agents/mlagents/trainers/model_saver/torch_model_saver.py @@ -1,7 +1,7 @@ import os import shutil from mlagents.torch_utils import torch -from typing import Dict, Union, Optional, cast +from typing import Dict, Union, Optional, cast, Tuple, List from mlagents_envs.exception import UnityPolicyException from mlagents_envs.logging_util import get_logger from mlagents.trainers.model_saver.model_saver import BaseModelSaver @@ -45,17 +45,19 @@ def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None: self.policy = module self.exporter = ModelSerializer(self.policy) - def save_checkpoint(self, behavior_name: str, step: int) -> str: + def save_checkpoint(self, behavior_name: 
str, step: int) -> Tuple[str, List[str]]: if not os.path.exists(self.model_path): os.makedirs(self.model_path) checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}") state_dict = { name: module.state_dict() for name, module in self.modules.items() } + pytorch_ckpt_path = f"{checkpoint_path}.pt" + export_ckpt_path = f"{checkpoint_path}.onnx" torch.save(state_dict, f"{checkpoint_path}.pt") torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt")) self.export(checkpoint_path, behavior_name) - return checkpoint_path + return export_ckpt_path, [pytorch_ckpt_path] def export(self, output_filepath: str, behavior_name: str) -> None: if self.exporter is not None: diff --git a/ml-agents/mlagents/trainers/policy/checkpoint_manager.py b/ml-agents/mlagents/trainers/policy/checkpoint_manager.py index d8203a4eb4..5f3e2762b3 100644 --- a/ml-agents/mlagents/trainers/policy/checkpoint_manager.py +++ b/ml-agents/mlagents/trainers/policy/checkpoint_manager.py @@ -14,6 +14,7 @@ class ModelCheckpoint: file_path: str reward: Optional[float] creation_time: float + auxillary_file_paths: List[str] = attr.ib(factory=list) class ModelCheckpointManager: @@ -37,12 +38,14 @@ def remove_checkpoint(checkpoint: Dict[str, Any]) -> None: :param checkpoint: A checkpoint stored in checkpoint_list """ - file_path: str = checkpoint["file_path"] - if os.path.exists(file_path): - os.remove(file_path) - logger.debug(f"Removed checkpoint model {file_path}.") - else: - logger.debug(f"Checkpoint at {file_path} could not be found.") + file_paths: List[str] = [checkpoint["file_path"]] + file_paths.extend(checkpoint["auxillary_file_paths"]) + for file_path in file_paths: + if os.path.exists(file_path): + os.remove(file_path) + logger.debug(f"Removed checkpoint model {file_path}.") + else: + logger.debug(f"Checkpoint at {file_path} could not be found.") return @classmethod diff --git a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py 
b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py index 838ac51cec..fb346245f1 100644 --- a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py +++ b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py @@ -29,7 +29,13 @@ def _update_policy(self): def add_policy(self, mock_behavior_id, mock_policy): def checkpoint_path(brain_name, step): - return os.path.join(self.model_saver.model_path, f"{brain_name}-{step}") + onnx_file_path = os.path.join( + self.model_saver.model_path, f"{brain_name}-{step}.onnx" + ) + other_file_paths = [ + os.path.join(self.model_saver.model_path, f"{brain_name}-{step}.pt") + ] + return onnx_file_path, other_file_paths self.policies[mock_behavior_id] = mock_policy mock_model_saver = mock.Mock() @@ -171,6 +177,9 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary): f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.{export_ext}", None, mock.ANY, + [ + f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.pt" + ], ), trainer.trainer_settings.keep_checkpoints, ) diff --git a/ml-agents/mlagents/trainers/tests/test_training_status.py b/ml-agents/mlagents/trainers/tests/test_training_status.py index 32028befdb..88dde1c5cd 100644 --- a/ml-agents/mlagents/trainers/tests/test_training_status.py +++ b/ml-agents/mlagents/trainers/tests/test_training_status.py @@ -60,18 +60,21 @@ def test_model_management(tmpdir): "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"), "reward": 1.312, "creation_time": time.time(), + "auxillary_file_paths": [], }, { "steps": 2, "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"), "reward": 1.912, "creation_time": time.time(), + "auxillary_file_paths": [], }, { "steps": 3, "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"), "reward": 2.312, "creation_time": time.time(), + "auxillary_file_paths": [], }, ] GlobalTrainingStatus.set_parameter_state( diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py 
b/ml-agents/mlagents/trainers/trainer/rl_trainer.py index 51bb78be94..78a9d7b130 100644 --- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py +++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py @@ -154,13 +154,15 @@ def _checkpoint(self) -> ModelCheckpoint: logger.warning( "Trainer has multiple policies, but default behavior only saves the first." ) - checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self._step) - export_ext = "onnx" + export_path, auxillary_paths = self.model_saver.save_checkpoint( + self.brain_name, self._step + ) new_checkpoint = ModelCheckpoint( int(self._step), - f"{checkpoint_path}.{export_ext}", + export_path, self._policy_mean_reward(), time.time(), + auxillary_file_paths=auxillary_paths, ) ModelCheckpointManager.add_checkpoint( self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints diff --git a/ml-agents/mlagents/trainers/training_status.py b/ml-agents/mlagents/trainers/training_status.py index 6d69093411..06bd73cd23 100644 --- a/ml-agents/mlagents/trainers/training_status.py +++ b/ml-agents/mlagents/trainers/training_status.py @@ -12,7 +12,7 @@ logger = get_logger(__name__) -STATUS_FORMAT_VERSION = "0.2.0" +STATUS_FORMAT_VERSION = "0.3.0" class StatusType(Enum):