Skip to content

Commit

Permalink
Don't save model twice, copy instead (#4302)
Browse files: browse the repository at this point in the history
* Don't save model twice, copy instead

* narrower exception
  • Loading branch information
Chris Elion authored Aug 5, 2020
1 parent 9d92e8a commit b68beb4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
18 changes: 18 additions & 0 deletions ml-agents/mlagents/model_serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from distutils.util import strtobool
import os
import shutil
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion

Expand Down Expand Up @@ -227,3 +228,20 @@ def _enforce_onnx_conversion() -> bool:
return strtobool(val)
except Exception:
return False


def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None:
    """
    Copy the .nn file at the given source to the destination.
    Also copies the corresponding .onnx file if it exists.

    The .onnx paths are derived from the .nn paths by replacing the final
    file extension with ".onnx".

    :param source_nn_path: Path of the existing .nn file to copy.
    :param destination_nn_path: Path the .nn file is copied to.
    :raises OSError: If the .nn file itself cannot be copied. Failures while
        copying the optional .onnx file are never raised.
    """
    # The .nn copy is mandatory: let any error propagate to the caller.
    shutil.copyfile(source_nn_path, destination_nn_path)
    logger.info(f"Copied {source_nn_path} to {destination_nn_path}.")

    # Copy the onnx file if it exists
    source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx"
    destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx"
    try:
        shutil.copyfile(source_onnx_path, destination_onnx_path)
    except FileNotFoundError:
        # Expected case: no .onnx file was exported next to the .nn file.
        # (The destination directory is known to exist, because the .nn copy
        # above just wrote a sibling file into it.)
        pass
    except OSError as err:
        # Still best-effort, but do not hide real problems such as permission
        # errors or a full disk; report them without failing the caller.
        logger.warning(
            f"Could not copy {source_onnx_path} to {destination_onnx_path}: {err}"
        )
    else:
        logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.")
8 changes: 5 additions & 3 deletions ml-agents/mlagents/trainers/trainer/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import abc
import time
import attr
from mlagents.model_serialization import SerializationSettings
from mlagents.model_serialization import SerializationSettings, copy_model_files
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,
Expand Down Expand Up @@ -131,12 +131,14 @@ def save_model(self) -> None:
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
settings = SerializationSettings(policy.model_path, self.brain_name)
model_checkpoint = self._checkpoint()

# Copy the checkpointed model files to the final output location
copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")

final_checkpoint = attr.evolve(
model_checkpoint, file_path=f"{policy.model_path}.nn"
)
policy.save(policy.model_path, settings)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)

@abc.abstractmethod
Expand Down

0 comments on commit b68beb4

Please sign in to comment.