Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug-fix] Delete .pt checkpoints past keep-checkpoints #5271

Merged
merged 4 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ml-agents/mlagents/trainers/model_saver/model_saver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# # Unity ML-Agents Toolkit
import abc
from typing import Any
from typing import Any, Tuple, List


class BaseModelSaver(abc.ABC):
Expand Down Expand Up @@ -34,11 +34,14 @@ def _register_optimizer(self, optimizer):
pass

@abc.abstractmethod
def save_checkpoint(self, behavior_name: str, step: int) -> str:
def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param behavior_name: Behavior name of behavior to be trained
:return: A Tuple of the path to the exported file, as well as a List of any
auxiliary files that were returned. For instance, an exported file would be Model.onnx,
and the auxiliary files would be [Model.pt] for PyTorch
"""
pass

Expand Down
8 changes: 5 additions & 3 deletions ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import shutil
from mlagents.torch_utils import torch
from typing import Dict, Union, Optional, cast
from typing import Dict, Union, Optional, cast, Tuple, List
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.model_saver.model_saver import BaseModelSaver
Expand Down Expand Up @@ -45,17 +45,19 @@ def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None:
self.policy = module
self.exporter = ModelSerializer(self.policy)

def save_checkpoint(self, behavior_name: str, step: int) -> str:
def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
state_dict = {
name: module.state_dict() for name, module in self.modules.items()
}
pytorch_ckpt_path = f"{checkpoint_path}.pt"
export_ckpt_path = f"{checkpoint_path}.onnx"
torch.save(state_dict, f"{checkpoint_path}.pt")
torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
self.export(checkpoint_path, behavior_name)
return checkpoint_path
return export_ckpt_path, [pytorch_ckpt_path]

def export(self, output_filepath: str, behavior_name: str) -> None:
if self.exporter is not None:
Expand Down
15 changes: 9 additions & 6 deletions ml-agents/mlagents/trainers/policy/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ModelCheckpoint:
file_path: str
reward: Optional[float]
creation_time: float
auxillary_file_paths: List[str] = attr.ib(factory=list)


class ModelCheckpointManager:
Expand All @@ -37,12 +38,14 @@ def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:

:param checkpoint: A checkpoint stored in checkpoint_list
"""
file_path: str = checkpoint["file_path"]
if os.path.exists(file_path):
os.remove(file_path)
logger.debug(f"Removed checkpoint model {file_path}.")
else:
logger.debug(f"Checkpoint at {file_path} could not be found.")
file_paths: List[str] = [checkpoint["file_path"]]
file_paths.extend(checkpoint["auxillary_file_paths"])
for file_path in file_paths:
if os.path.exists(file_path):
os.remove(file_path)
logger.debug(f"Removed checkpoint model {file_path}.")
else:
logger.debug(f"Checkpoint at {file_path} could not be found.")
return

@classmethod
Expand Down
11 changes: 10 additions & 1 deletion ml-agents/mlagents/trainers/tests/test_rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ def _update_policy(self):

def add_policy(self, mock_behavior_id, mock_policy):
def checkpoint_path(brain_name, step):
return os.path.join(self.model_saver.model_path, f"{brain_name}-{step}")
onnx_file_path = os.path.join(
self.model_saver.model_path, f"{brain_name}-{step}.onnx"
)
other_file_paths = [
os.path.join(self.model_saver.model_path, f"{brain_name}-{step}.pt")
]
return onnx_file_path, other_file_paths

self.policies[mock_behavior_id] = mock_policy
mock_model_saver = mock.Mock()
Expand Down Expand Up @@ -171,6 +177,9 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.{export_ext}",
None,
mock.ANY,
[
f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.pt"
],
),
trainer.trainer_settings.keep_checkpoints,
)
Expand Down
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_training_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,21 @@ def test_model_management(tmpdir):
"file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
"reward": 1.312,
"creation_time": time.time(),
"auxillary_file_paths": [],
},
{
"steps": 2,
"file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"),
"reward": 1.912,
"creation_time": time.time(),
"auxillary_file_paths": [],
},
{
"steps": 3,
"file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"),
"reward": 2.312,
"creation_time": time.time(),
"auxillary_file_paths": [],
},
]
GlobalTrainingStatus.set_parameter_state(
Expand Down
8 changes: 5 additions & 3 deletions ml-agents/mlagents/trainers/trainer/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,15 @@ def _checkpoint(self) -> ModelCheckpoint:
logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self._step)
export_ext = "onnx"
export_path, auxillary_paths = self.model_saver.save_checkpoint(
self.brain_name, self._step
)
new_checkpoint = ModelCheckpoint(
int(self._step),
f"{checkpoint_path}.{export_ext}",
export_path,
self._policy_mean_reward(),
time.time(),
auxillary_file_paths=auxillary_paths,
)
ModelCheckpointManager.add_checkpoint(
self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
Expand Down
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/training_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

logger = get_logger(__name__)

STATUS_FORMAT_VERSION = "0.2.0"
STATUS_FORMAT_VERSION = "0.3.0"


class StatusType(Enum):
Expand Down