From dd8d1704047e3d8de43667c8cbd099c059f81e89 Mon Sep 17 00:00:00 2001 From: Ervin T Date: Wed, 13 May 2020 15:17:48 -0700 Subject: [PATCH] [bug-fix] Fix issue with initialize not resetting step count (#3962) --- com.unity.ml-agents/CHANGELOG.md | 1 + ml-agents/mlagents/trainers/policy/tf_policy.py | 1 + ml-agents/mlagents/trainers/tests/test_nn_policy.py | 4 ++++ 3 files changed, 6 insertions(+) diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index cfc2544051..06a35685f4 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to - Unity Player logs are now written out to the results directory. (#3877) - Run configuration YAML files are written out to the results directory at the end of the run. (#3815) ### Bug Fixes +- An issue was fixed where using `--initialize-from` would resume from the past step count. (#3962) #### com.unity.ml-agents (C#) #### ml-agents / ml-agents-envs / gym-unity (Python) diff --git a/ml-agents/mlagents/trainers/policy/tf_policy.py b/ml-agents/mlagents/trainers/policy/tf_policy.py index f24f7acb25..bc725b9e8d 100644 --- a/ml-agents/mlagents/trainers/policy/tf_policy.py +++ b/ml-agents/mlagents/trainers/policy/tf_policy.py @@ -138,6 +138,7 @@ def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None ) ) if reset_global_steps: + self._set_step(0) logger.info( "Starting training from step 0 and saving to {}.".format( self.model_path diff --git a/ml-agents/mlagents/trainers/tests/test_nn_policy.py b/ml-agents/mlagents/trainers/tests/test_nn_policy.py index 17b15231bd..6cfb0d274a 100644 --- a/ml-agents/mlagents/trainers/tests/test_nn_policy.py +++ b/ml-agents/mlagents/trainers/tests/test_nn_policy.py @@ -85,6 +85,7 @@ def test_load_save(dummy_config, tmp_path): trainer_params["output_path"] = path1 policy = create_policy_mock(trainer_params) policy.initialize_or_load() + policy._set_step(2000) policy.save_model(2000) assert len(os.listdir(tmp_path)) > 0 @@ -93,6 +94,7 @@ def test_load_save(dummy_config, tmp_path): policy2 = create_policy_mock(trainer_params, load=True, seed=1) policy2.initialize_or_load() _compare_two_policies(policy, policy2) + assert policy2.get_current_step() == 2000 # Try initialize from path 1 trainer_params["model_path"] = path2 @@ -101,6 +103,8 @@ def test_load_save(dummy_config, tmp_path): policy3.initialize_or_load() _compare_two_policies(policy2, policy3) + # Assert that the steps are 0. + assert policy3.get_current_step() == 0 def _compare_two_policies(policy1: NNPolicy, policy2: NNPolicy) -> None: