diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 154b9f9f15..515d33c9e2 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -6,7 +6,6 @@ The format is based on
 [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
 and this project adheres to
 [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
-
 ## [2.0.0-exp.1] - 2021-04-22
 ### Major Changes
 #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
@@ -77,6 +76,9 @@ or actuators on your system. (#5194)
 #### ml-agents / ml-agents-envs / gym-unity (Python)
 - Fixed a bug where --results-dir has no effect. (#5269)
 - Fixed a bug where old `.pt` checkpoints were not deleted during training. (#5271)
+- The `UnityToGymWrapper` initializer now accepts an optional `action_space_seed` parameter. If this is specified, it will
+be used to set the random seed on the resulting action space. (#5303)
+
 
 ## [1.9.1-preview] - 2021-04-13
 ### Major Changes
diff --git a/gym-unity/gym_unity/envs/__init__.py b/gym-unity/gym_unity/envs/__init__.py
index ab23ee2c75..0f70712947 100644
--- a/gym-unity/gym_unity/envs/__init__.py
+++ b/gym-unity/gym_unity/envs/__init__.py
@@ -1,6 +1,6 @@
 import itertools
 import numpy as np
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import gym
 from gym import error, spaces
@@ -35,6 +35,7 @@ def __init__(
         uint8_visual: bool = False,
         flatten_branched: bool = False,
         allow_multiple_obs: bool = False,
+        action_space_seed: Optional[int] = None,
     ):
         """
         Environment initialization
@@ -46,6 +47,7 @@ def __init__(
             containing the visual observations and the last element containing the array of vector observations.
             If False, returns a single np.ndarray containing either only a single visual observation or the array of
            vector observations.
+        :param action_space_seed: If non-None, will be used to set the random seed on created gym.Space instances.
         """
         self._env = unity_env
 
@@ -130,6 +132,9 @@ def __init__(
                 "and continuous actions."
             )
 
+        if action_space_seed is not None:
+            self._action_space.seed(action_space_seed)
+
         # Set observations space
         list_spaces: List[gym.Space] = []
         shapes = self._get_vis_obs_shape()
@@ -305,7 +310,7 @@ def reward_range(self) -> Tuple[float, float]:
         return -float("inf"), float("inf")
 
     @property
-    def action_space(self):
+    def action_space(self) -> gym.Space:
         return self._action_space
 
     @property
diff --git a/gym-unity/gym_unity/tests/test_gym.py b/gym-unity/gym_unity/tests/test_gym.py
index c86ce3dee7..6928faa0fd 100644
--- a/gym-unity/gym_unity/tests/test_gym.py
+++ b/gym-unity/gym_unity/tests/test_gym.py
@@ -22,7 +22,6 @@ def test_gym_wrapper():
         mock_env, mock_spec, mock_decision_step, mock_terminal_step
     )
     env = UnityToGymWrapper(mock_env)
-    assert isinstance(env, UnityToGymWrapper)
     assert isinstance(env.reset(), np.ndarray)
     actions = env.action_space.sample()
     assert actions.shape[0] == 2
@@ -78,6 +77,21 @@ def test_action_space():
     assert env.action_space.n == 5
 
 
+def test_action_space_seed():
+    mock_env = mock.MagicMock()
+    mock_spec = create_mock_group_spec()
+    mock_decision_step, mock_terminal_step = create_mock_vector_steps(mock_spec)
+    setup_mock_unityenvironment(
+        mock_env, mock_spec, mock_decision_step, mock_terminal_step
+    )
+    actions = []
+    for _ in range(0, 2):
+        env = UnityToGymWrapper(mock_env, action_space_seed=1337)
+        env.reset()
+        actions.append(env.action_space.sample())
+    assert (actions[0] == actions[1]).all()
+
+
 @pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
 def test_gym_wrapper_visual(use_uint8):
     mock_env = mock.MagicMock()
@@ -93,7 +107,6 @@ def test_gym_wrapper_visual(use_uint8):
 
     env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8)
     assert isinstance(env.observation_space, spaces.Box)
-    assert isinstance(env, UnityToGymWrapper)
     assert isinstance(env.reset(), np.ndarray)
     actions = env.action_space.sample()
     assert actions.shape[0] == 2
@@ -121,7 +134,6 @@ def test_gym_wrapper_single_visual_and_vector(use_uint8):
     )
 
     env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
-    assert isinstance(env, UnityToGymWrapper)
     assert isinstance(env.observation_space, spaces.Tuple)
     assert len(env.observation_space) == 2
     reset_obs = env.reset()
@@ -143,7 +155,6 @@ def test_gym_wrapper_single_visual_and_vector(use_uint8):
 
     # check behavior for allow_multiple_obs = False
     env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
-    assert isinstance(env, UnityToGymWrapper)
     assert isinstance(env.observation_space, spaces.Box)
     reset_obs = env.reset()
     assert isinstance(reset_obs, np.ndarray)
@@ -170,7 +181,6 @@ def test_gym_wrapper_multi_visual_and_vector(use_uint8):
     )
 
     env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
-    assert isinstance(env, UnityToGymWrapper)
     assert isinstance(env.observation_space, spaces.Tuple)
     assert len(env.observation_space) == 3
     reset_obs = env.reset()
@@ -188,7 +198,6 @@ def test_gym_wrapper_multi_visual_and_vector(use_uint8):
 
     # check behavior for allow_multiple_obs = False
     env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
-    assert isinstance(env, UnityToGymWrapper)
     assert isinstance(env.observation_space, spaces.Box)
     reset_obs = env.reset()
     assert isinstance(reset_obs, np.ndarray)
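
For context, a minimal usage sketch of the new `action_space_seed` option introduced by this diff. The `UnityEnvironment` construction and the build path are illustrative assumptions for a locally built environment binary, not part of this change:

from mlagents_envs.environment import UnityEnvironment
from gym_unity.envs import UnityToGymWrapper

# Hypothetical path to a Unity environment build; substitute your own binary.
unity_env = UnityEnvironment("./builds/3DBall")

# Passing action_space_seed seeds the wrapper's action space, so
# env.action_space.sample() returns the same sequence across runs.
# This mirrors what test_action_space_seed verifies with a mocked environment.
env = UnityToGymWrapper(unity_env, action_space_seed=1337)

obs = env.reset()
action = env.action_space.sample()  # reproducible for a fixed seed
obs, reward, done, info = env.step(action)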