Skip to content

Commit

Permalink
[bug-fix] Fix issue with initialize not resetting step count (#3962)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Ervin T authored May 13, 2020
1 parent cd27c30 commit dd8d170
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ and this project adheres to
- Unity Player logs are now written out to the results directory. (#3877)
- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
### Bug Fixes
- Fixed an issue where using `--initialize-from` would resume from the previous run's step count instead of starting training from step 0. (#3962)
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)

Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/policy/tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None
)
)
if reset_global_steps:
self._set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
Expand Down
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_nn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def test_load_save(dummy_config, tmp_path):
trainer_params["output_path"] = path1
policy = create_policy_mock(trainer_params)
policy.initialize_or_load()
policy._set_step(2000)
policy.save_model(2000)

assert len(os.listdir(tmp_path)) > 0
Expand All @@ -93,6 +94,7 @@ def test_load_save(dummy_config, tmp_path):
policy2 = create_policy_mock(trainer_params, load=True, seed=1)
policy2.initialize_or_load()
_compare_two_policies(policy, policy2)
assert policy2.get_current_step() == 2000

# Try initialize from path 1
trainer_params["model_path"] = path2
Expand All @@ -101,6 +103,8 @@ def test_load_save(dummy_config, tmp_path):
policy3.initialize_or_load()

_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0


def _compare_two_policies(policy1: NNPolicy, policy2: NNPolicy) -> None:
Expand Down

0 comments on commit dd8d170

Please sign in to comment.