diff --git a/slm_lab/agent/memory/onpolicy.py b/slm_lab/agent/memory/onpolicy.py
index 6e77e0991..7f8b228ca 100644
--- a/slm_lab/agent/memory/onpolicy.py
+++ b/slm_lab/agent/memory/onpolicy.py
@@ -144,22 +144,3 @@ def sample(self):
             'dones' : dones}
         '''
         return super().sample()
-
-
-class OnPolicyAtariReplay(OnPolicyReplay):
-    '''
-    Preprocesses an state to be the concatenation of the last four states, after converting the 210 x 160 x 3 image to 84 x 84 x 1 grayscale image, and clips all rewards to [-10, 10] as per "Playing Atari with Deep Reinforcement Learning", Mnih et al, 2013
-    Note: Playing Atari with Deep RL clips the rewards to + / - 1
-    Otherwise the same as OnPolicyReplay memory
-    '''
-
-    def add_experience(self, state, action, reward, next_state, done):
-        # clip reward, done here to minimize change to only training data data
-        super().add_experience(state, action, np.sign(reward), next_state, done)
-
-
-class OnPolicyAtariBatchReplay(OnPolicyBatchReplay, OnPolicyAtariReplay):
-    '''
-    OnPolicyBatchReplay with Atari concat
-    '''
-    pass
diff --git a/slm_lab/agent/memory/prioritized.py b/slm_lab/agent/memory/prioritized.py
index 5533b25ba..695218054 100644
--- a/slm_lab/agent/memory/prioritized.py
+++ b/slm_lab/agent/memory/prioritized.py
@@ -1,4 +1,4 @@
-from slm_lab.agent.memory.replay import Replay, AtariReplay
+from slm_lab.agent.memory.replay import Replay
 from slm_lab.lib import util
 from slm_lab.lib.decorator import lab_api
 import numpy as np
@@ -175,8 +175,3 @@ def update_priorities(self, errors):
             self.priorities[idx] = p
         for p, i in zip(priorities, self.tree_idxs):
             self.tree.update(i, p)
-
-
-class AtariPrioritizedReplay(PrioritizedReplay, AtariReplay):
-    '''Make a Atari PrioritizedReplay via nice multi-inheritance (python magic)'''
-    pass
diff --git a/slm_lab/agent/memory/replay.py b/slm_lab/agent/memory/replay.py
index b67d9c26b..39c7f0322 100644
--- a/slm_lab/agent/memory/replay.py
+++ b/slm_lab/agent/memory/replay.py
@@ -151,14 +151,3 @@ def sample_idxs(self, batch_size):
         if self.use_cer:  # add the latest sample
             batch_idxs[-1] = self.head
         return batch_idxs
-
-
-class AtariReplay(Replay):
-    '''
-    Preprocesses an state to be the concatenation of the last four states, after converting the 210 x 160 x 3 image to 84 x 84 x 1 grayscale image, and clips all rewards to [-10, 10] as per "Playing Atari with Deep Reinforcement Learning", Mnih et al, 2013
-    Note: Playing Atari with Deep RL clips the rewards to + / - 1
-    '''
-
-    def add_experience(self, state, action, reward, next_state, done):
-        # clip reward, done here to minimize change to only training data data
-        super().add_experience(state, action, np.sign(reward), next_state, done)
diff --git a/slm_lab/env/openai.py b/slm_lab/env/openai.py
index 9060b7266..9968681c9 100644
--- a/slm_lab/env/openai.py
+++ b/slm_lab/env/openai.py
@@ -33,9 +33,9 @@ def __init__(self, spec, e=None, env_space=None):
         try_register_env(spec)  # register if it's a custom gym env
         seed = ps.get(spec, 'meta.random_seed')
         if self.is_venv:  # make vector environment
-            self.u_env = make_gym_venv(self.name, seed, self.frame_op, self.frame_op_len, self.num_envs)
+            self.u_env = make_gym_venv(self.name, seed, self.frame_op, self.frame_op_len, self.reward_scale, self.num_envs)
         else:
-            self.u_env = make_gym_env(self.name, seed, self.frame_op, self.frame_op_len)
+            self.u_env = make_gym_env(self.name, seed, self.frame_op, self.frame_op_len, self.reward_scale)
         self._set_attr_from_u_env(self.u_env)
         self.max_t = self.max_t or self.u_env.spec.max_episode_steps
         assert self.max_t is not None
@@ -58,8 +58,6 @@ def step(self, action):
         if not self.is_discrete and self.action_dim == 1:  # guard for continuous with action_dim 1, make array
             action = np.expand_dims(action, axis=-1)
         state, reward, done, info = self.u_env.step(action)
-        if self.reward_scale is not None:
-            reward *= self.reward_scale
         if self.to_render:
             self.u_env.render()
         if not self.is_venv and self.clock.t > self.max_t:
@@ -100,8 +98,6 @@ def space_step(self, action_e):
         state, reward, done, info = self.u_env.step(action)
         if done:
             state = self.u_env.reset()
-        if self.reward_scale is not None:
-            reward *= self.reward_scale
         if self.to_render:
             self.u_env.render()
         if not self.is_venv and self.clock.t > self.max_t:
diff --git a/slm_lab/env/unity.py b/slm_lab/env/unity.py
index 6e1252da1..69a331868 100644
--- a/slm_lab/env/unity.py
+++ b/slm_lab/env/unity.py
@@ -1,6 +1,7 @@
 from gym import spaces
 from slm_lab.env.base import BaseEnv, ENV_DATA_NAMES, set_gym_space_attr
 from slm_lab.env.registration import get_env_path
+from slm_lab.env.wrapper import try_scale_reward
 from slm_lab.lib import logger, util
 from slm_lab.lib.decorator import lab_api
 from unityagents import brain, UnityEnvironment
@@ -141,8 +142,7 @@ def step(self, action):
         env_info_a = self._get_env_info(env_info_dict, a)
         state = env_info_a.states[b]
         reward = env_info_a.rewards[b]
-        if self.reward_scale is not None:
-            reward *= self.reward_scale
+        reward = try_scale_reward(self, reward)
         done = env_info_a.local_done[b]
         if not self.is_venv and self.clock.t > self.max_t:
             done = True
@@ -187,10 +187,9 @@ def space_step(self, action_e):
         for (a, b), body in util.ndenumerate_nonan(self.body_e):
             env_info_a = self._get_env_info(env_info_dict, a)
             state_e[(a, b)] = env_info_a.states[b]
-            reward = env_info_a.rewards[b]
-            if self.reward_scale is not None:
-                reward *= self.reward_scale
-            reward_e[(a, b)] = reward
+            rewards = env_info_a.rewards[b]
+            rewards = try_scale_reward(self, rewards)
+            reward_e[(a, b)] = rewards
             done_e[(a, b)] = env_info_a.local_done[b]
         info_e = env_info_dict
         self.done = (util.nonan_all(done_e) or self.clock.t > self.max_t)
diff --git a/slm_lab/env/vec_env.py b/slm_lab/env/vec_env.py
index b4cfaa3f6..9b10e84a2 100644
--- a/slm_lab/env/vec_env.py
+++ b/slm_lab/env/vec_env.py
@@ -4,7 +4,7 @@
 from collections import OrderedDict
 from functools import partial
 from gym import spaces
-from slm_lab.env.wrapper import make_gym_env
+from slm_lab.env.wrapper import make_gym_env, try_scale_reward
 from slm_lab.lib import logger
 import contextlib
 import ctypes
@@ -450,11 +450,13 @@ def _decode_obses(self, obs):
 class VecFrameStack(VecEnvWrapper):
     '''Frame stack wrapper for vector environment'''

-    def __init__(self, venv, frame_op, frame_op_len):
+    def __init__(self, venv, frame_op, frame_op_len, reward_scale=None):
         self.venv = venv
         assert frame_op == 'concat', 'VecFrameStack only supports concat frame_op for now'
         self.frame_op = frame_op
         self.frame_op_len = frame_op_len
+        self.reward_scale = reward_scale
+        self.sign_reward = self.reward_scale == 'sign'
         self.spec = venv.spec
         wos = venv.observation_space  # wrapped ob space
         self.shape_dim0 = wos.shape[0]
@@ -471,6 +473,7 @@ def step_wait(self):
             if new:
                 self.stackedobs[i] = 0
         self.stackedobs[:, -self.shape_dim0:] = obs
+        rews = try_scale_reward(self, rews)
         return self.stackedobs.copy(), rews, news, infos

     def reset(self):
@@ -480,11 +483,11 @@ def reset(self):
         return self.stackedobs.copy()


-def make_gym_venv(name, seed=0, frame_op=None, frame_op_len=None, num_envs=4):
+def make_gym_venv(name, seed=0, frame_op=None, frame_op_len=None, reward_scale=None, num_envs=4):
     '''General method to create any parallel vectorized Gym env; auto wraps Atari'''
     venv = [
-        # don't stack on individual env, but stack as vector
-        partial(make_gym_env, name, seed + i, frame_op=None, frame_op_len=None)
+        # don't concat frame or clip reward on individual env; do that at vector level
+        partial(make_gym_env, name, seed + i, frame_op=None, frame_op_len=None, reward_scale=None)
         for i in range(num_envs)
     ]
     if len(venv) > 1:
@@ -492,5 +495,5 @@ def make_gym_venv(name, seed=0, frame_op=None, frame_op_len=None, num_envs=4):
     else:
         venv = DummyVecEnv(venv)
     if frame_op is not None:
-        venv = VecFrameStack(venv, frame_op, frame_op_len)
+        venv = VecFrameStack(venv, frame_op, frame_op_len, reward_scale)
     return venv
diff --git a/slm_lab/env/wrapper.py b/slm_lab/env/wrapper.py
index b06b12461..7a60f3a35 100644
--- a/slm_lab/env/wrapper.py
+++ b/slm_lab/env/wrapper.py
@@ -8,6 +8,17 @@
 import numpy as np


+def try_scale_reward(cls, reward):
+    '''Env class to scale reward and set raw_reward'''
+    if cls.reward_scale is not None:
+        cls.raw_reward = reward
+        if cls.sign_reward:
+            reward = np.sign(reward)
+        else:
+            reward *= cls.reward_scale
+    return reward
+
+
 class NoopResetEnv(gym.Wrapper):
     def __init__(self, env, noop_max=30):
         '''
@@ -130,10 +141,19 @@ def reset(self, **kwargs):
         return self.env.reset(**kwargs)


-class ClipRewardEnv(gym.RewardWrapper):
+class ScaleRewardEnv(gym.RewardWrapper):
+    def __init__(self, env, reward_scale):
+        '''
+        Rescale reward
+        @param (str,float):reward_scale If 'sign', use np.sign, else multiply with the specified float scale
+        '''
+        gym.Wrapper.__init__(self, env)
+        self.reward_scale = reward_scale
+        self.sign_reward = self.reward_scale == 'sign'
+
     def reward(self, reward):
-        '''Atari reward, to -1, 0 or +1. Not usually used as SLM Lab memory class does the clipping'''
-        return np.sign(reward)
+        '''Set self.raw_reward for retrieving the original reward'''
+        return try_scale_reward(self, reward)


 class PreprocessImage(gym.ObservationWrapper):
@@ -241,14 +261,12 @@ def wrap_atari(env):
     return env


-def wrap_deepmind(env, episode_life=True, clip_rewards=True, stack_len=None):
+def wrap_deepmind(env, episode_life=True, stack_len=None):
     '''Wrap Atari environment DeepMind-style'''
     if episode_life:
         env = EpisodicLifeEnv(env)
     if 'FIRE' in env.unwrapped.get_action_meanings():
         env = FireResetEnv(env)
-    if clip_rewards:
-        env = ClipRewardEnv(env)
     env = PreprocessImage(env)
     if stack_len is not None:  # use concat for image (1, 84, 84)
         env = FrameStack(env, 'concat', stack_len)
@@ -263,7 +281,7 @@ def wrap_image_env(env, stack_len=None):
     return env


-def make_gym_env(name, seed=None, frame_op=None, frame_op_len=None):
+def make_gym_env(name, seed=None, frame_op=None, frame_op_len=None, reward_scale=None):
     '''General method to create any Gym env; auto wraps Atari'''
     env = gym.make(name)
     if seed is not None:
@@ -271,12 +289,13 @@ def make_gym_env(name, seed=None, frame_op=None, frame_op_len=None):
     if 'NoFrameskip' in env.spec.id:  # Atari
         env = wrap_atari(env)
         # no reward clipping to allow monitoring; Atari memory clips it
-        clip_rewards = False
         episode_life = util.get_lab_mode() != 'eval'
-        env = wrap_deepmind(env, clip_rewards, episode_life, frame_op_len)
+        env = wrap_deepmind(env, episode_life, frame_op_len)
     elif len(env.observation_space.shape) == 3:  # image-state env
         env = wrap_image_env(env, frame_op_len)
     else:  # vector-state env
         if frame_op is not None:
             env = FrameStack(env, frame_op, frame_op_len)
+    if reward_scale is not None:
+        env = ScaleRewardEnv(env, reward_scale)
     return env
diff --git a/slm_lab/experiment/monitor.py b/slm_lab/experiment/monitor.py
index 8c32ef501..045c4a54d 100644
--- a/slm_lab/experiment/monitor.py
+++ b/slm_lab/experiment/monitor.py
@@ -140,6 +140,8 @@ def __init__(self, env, agent_spec, aeb=(0, 0, 0), aeb_space=None):

     def update(self, state, action, reward, next_state, done):
         '''Interface update method for body at agent.update()'''
+        if self.env.reward_scale is not None:
+            reward = self.env.u_env.raw_reward
         if self.ckpt_total_reward is np.nan:  # init
             self.ckpt_total_reward = reward
         else:  # reset on epi_start, else keep adding.
generalized for vec env diff --git a/slm_lab/spec/experimental/a2c.json b/slm_lab/spec/experimental/a2c.json index 03a61ebf4..3f703f20b 100644 --- a/slm_lab/spec/experimental/a2c.json +++ b/slm_lab/spec/experimental/a2c.json @@ -798,7 +798,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay", + "name": "OnPolicyReplay", }, "net": { "type": "ConvNet", @@ -833,6 +833,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000, }], diff --git a/slm_lab/spec/experimental/a2c_pong.json b/slm_lab/spec/experimental/a2c_pong.json index ed218c10e..deb49e8cc 100644 --- a/slm_lab/spec/experimental/a2c_pong.json +++ b/slm_lab/spec/experimental/a2c_pong.json @@ -23,7 +23,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariBatchReplay", + "name": "OnPolicyBatchReplay", }, "net": { "type": "ConvNet", @@ -63,6 +63,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "num_envs": 16, "max_t": null, "max_tick": 1e7 diff --git a/slm_lab/spec/experimental/ddqn.json b/slm_lab/spec/experimental/ddqn.json index cd390fa26..b82a95f9a 100644 --- a/slm_lab/spec/experimental/ddqn.json +++ b/slm_lab/spec/experimental/ddqn.json @@ -379,7 +379,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 250000, "use_cer": true @@ -421,6 +421,7 @@ "name": "BreakoutDeterministic-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 50000, }], @@ -459,7 +460,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 250000, "use_cer": true @@ -501,6 +502,7 @@ "name": "BreakoutDeterministic-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 50000, }], diff --git a/slm_lab/spec/experimental/ddqn_beamrider.json b/slm_lab/spec/experimental/ddqn_beamrider.json index 473348244..6732dab22 100644 --- a/slm_lab/spec/experimental/ddqn_beamrider.json +++ b/slm_lab/spec/experimental/ddqn_beamrider.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false, @@ -55,6 +55,7 @@ "name": "BeamRiderNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_breakout.json b/slm_lab/spec/experimental/ddqn_breakout.json index b8e58b173..3bfc8cba6 100644 --- a/slm_lab/spec/experimental/ddqn_breakout.json +++ b/slm_lab/spec/experimental/ddqn_breakout.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "BreakoutNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_enduro.json b/slm_lab/spec/experimental/ddqn_enduro.json index e0306dc60..fd798b817 100644 --- a/slm_lab/spec/experimental/ddqn_enduro.json +++ b/slm_lab/spec/experimental/ddqn_enduro.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "EnduroNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", 
"max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_mspacman.json b/slm_lab/spec/experimental/ddqn_mspacman.json index 7d32d001b..18228bed0 100644 --- a/slm_lab/spec/experimental/ddqn_mspacman.json +++ b/slm_lab/spec/experimental/ddqn_mspacman.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "MsPacmanNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_per_beamrider.json b/slm_lab/spec/experimental/ddqn_per_beamrider.json index 73c623d9a..bd58b8c46 100644 --- a/slm_lab/spec/experimental/ddqn_per_beamrider.json +++ b/slm_lab/spec/experimental/ddqn_per_beamrider.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "BeamRiderNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_per_breakout.json b/slm_lab/spec/experimental/ddqn_per_breakout.json index 7d3296e37..3b76dfebd 100644 --- a/slm_lab/spec/experimental/ddqn_per_breakout.json +++ b/slm_lab/spec/experimental/ddqn_per_breakout.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "BreakoutNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_per_enduro.json b/slm_lab/spec/experimental/ddqn_per_enduro.json index ffe4d57bf..5b36b1ab2 100644 --- a/slm_lab/spec/experimental/ddqn_per_enduro.json +++ b/slm_lab/spec/experimental/ddqn_per_enduro.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "EnduroNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_per_mspacman.json b/slm_lab/spec/experimental/ddqn_per_mspacman.json index 5c85243d1..7ab49765b 100644 --- a/slm_lab/spec/experimental/ddqn_per_mspacman.json +++ b/slm_lab/spec/experimental/ddqn_per_mspacman.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "MsPacmanNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_per_pong.json b/slm_lab/spec/experimental/ddqn_per_pong.json index 487c5ebdd..d6b382247 100644 --- a/slm_lab/spec/experimental/ddqn_per_pong.json +++ b/slm_lab/spec/experimental/ddqn_per_pong.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git 
a/slm_lab/spec/experimental/ddqn_per_qbert.json b/slm_lab/spec/experimental/ddqn_per_qbert.json index d4cf8c3db..bb123b10f 100644 --- a/slm_lab/spec/experimental/ddqn_per_qbert.json +++ b/slm_lab/spec/experimental/ddqn_per_qbert.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "QbertNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_per_seaquest.json b/slm_lab/spec/experimental/ddqn_per_seaquest.json index 5d7aea017..df391f684 100644 --- a/slm_lab/spec/experimental/ddqn_per_seaquest.json +++ b/slm_lab/spec/experimental/ddqn_per_seaquest.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "SeaquestNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_per_spaceinvaders.json b/slm_lab/spec/experimental/ddqn_per_spaceinvaders.json index 965c8306b..9a2f4fca4 100644 --- a/slm_lab/spec/experimental/ddqn_per_spaceinvaders.json +++ b/slm_lab/spec/experimental/ddqn_per_spaceinvaders.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "SpaceInvadersNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_pong.json b/slm_lab/spec/experimental/ddqn_pong.json index a29af6a68..a9029ba5d 100644 --- a/slm_lab/spec/experimental/ddqn_pong.json +++ b/slm_lab/spec/experimental/ddqn_pong.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_qbert.json b/slm_lab/spec/experimental/ddqn_qbert.json index 8571fac4e..a4962a35d 100644 --- a/slm_lab/spec/experimental/ddqn_qbert.json +++ b/slm_lab/spec/experimental/ddqn_qbert.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "QbertNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_seaquest.json b/slm_lab/spec/experimental/ddqn_seaquest.json index f4add14a4..e1906f1ea 100644 --- a/slm_lab/spec/experimental/ddqn_seaquest.json +++ b/slm_lab/spec/experimental/ddqn_seaquest.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "SeaquestNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ddqn_spaceinvaders.json b/slm_lab/spec/experimental/ddqn_spaceinvaders.json index 
17818e49a..514dac716 100644 --- a/slm_lab/spec/experimental/ddqn_spaceinvaders.json +++ b/slm_lab/spec/experimental/ddqn_spaceinvaders.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "SpaceInvadersNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn.json b/slm_lab/spec/experimental/dqn.json index 21e4941ec..7620d7528 100644 --- a/slm_lab/spec/experimental/dqn.json +++ b/slm_lab/spec/experimental/dqn.json @@ -534,7 +534,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 100000, "use_cer": false @@ -568,6 +568,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000, }], diff --git a/slm_lab/spec/experimental/dqn_beamrider.json b/slm_lab/spec/experimental/dqn_beamrider.json index 37fd83cac..457493348 100644 --- a/slm_lab/spec/experimental/dqn_beamrider.json +++ b/slm_lab/spec/experimental/dqn_beamrider.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "BeamRiderNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_breakout.json b/slm_lab/spec/experimental/dqn_breakout.json index a2a372589..41f3ea3b1 100644 --- a/slm_lab/spec/experimental/dqn_breakout.json +++ b/slm_lab/spec/experimental/dqn_breakout.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "BreakoutNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_enduro.json b/slm_lab/spec/experimental/dqn_enduro.json index 8d2234147..fabc14e3f 100644 --- a/slm_lab/spec/experimental/dqn_enduro.json +++ b/slm_lab/spec/experimental/dqn_enduro.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "EnduroNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_mspacman.json b/slm_lab/spec/experimental/dqn_mspacman.json index ad6aa9a14..a5005543f 100644 --- a/slm_lab/spec/experimental/dqn_mspacman.json +++ b/slm_lab/spec/experimental/dqn_mspacman.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "MsPacmanNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_per_beamrider.json b/slm_lab/spec/experimental/dqn_per_beamrider.json index 3e95c097e..a10c5e6b1 100644 --- a/slm_lab/spec/experimental/dqn_per_beamrider.json +++ b/slm_lab/spec/experimental/dqn_per_beamrider.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": 
"PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "BeamRiderNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_per_breakout.json b/slm_lab/spec/experimental/dqn_per_breakout.json index 3ff03f37b..787c18e3b 100644 --- a/slm_lab/spec/experimental/dqn_per_breakout.json +++ b/slm_lab/spec/experimental/dqn_per_breakout.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "BreakoutNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_per_enduro.json b/slm_lab/spec/experimental/dqn_per_enduro.json index 371ae900d..eaf9f6f83 100644 --- a/slm_lab/spec/experimental/dqn_per_enduro.json +++ b/slm_lab/spec/experimental/dqn_per_enduro.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "EnduroNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_per_mspacman.json b/slm_lab/spec/experimental/dqn_per_mspacman.json index 558483eb0..6c12073f2 100644 --- a/slm_lab/spec/experimental/dqn_per_mspacman.json +++ b/slm_lab/spec/experimental/dqn_per_mspacman.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "MsPacmanNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_per_pong.json b/slm_lab/spec/experimental/dqn_per_pong.json index 11a163b54..e37bbacea 100644 --- a/slm_lab/spec/experimental/dqn_per_pong.json +++ b/slm_lab/spec/experimental/dqn_per_pong.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "num_envs": null, "max_t": null, "max_tick": 1e7 diff --git a/slm_lab/spec/experimental/dqn_per_qbert.json b/slm_lab/spec/experimental/dqn_per_qbert.json index fbb50c646..dc8825c0c 100644 --- a/slm_lab/spec/experimental/dqn_per_qbert.json +++ b/slm_lab/spec/experimental/dqn_per_qbert.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "QbertNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_per_seaquest.json b/slm_lab/spec/experimental/dqn_per_seaquest.json index 252c27301..724a6e59e 100644 --- a/slm_lab/spec/experimental/dqn_per_seaquest.json +++ b/slm_lab/spec/experimental/dqn_per_seaquest.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, 
"batch_size": 32, @@ -57,6 +57,7 @@ "name": "SeaquestNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_per_spaceinvaders.json b/slm_lab/spec/experimental/dqn_per_spaceinvaders.json index 29d541cd5..510a472c2 100644 --- a/slm_lab/spec/experimental/dqn_per_spaceinvaders.json +++ b/slm_lab/spec/experimental/dqn_per_spaceinvaders.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariPrioritizedReplay", + "name": "PrioritizedReplay", "alpha": 0.6, "epsilon": 0.0001, "batch_size": 32, @@ -57,6 +57,7 @@ "name": "SpaceInvadersNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_pong.json b/slm_lab/spec/experimental/dqn_pong.json index 52841e527..322d8dfac 100644 --- a/slm_lab/spec/experimental/dqn_pong.json +++ b/slm_lab/spec/experimental/dqn_pong.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "num_envs": null, "max_t": null, "max_tick": 1e7 diff --git a/slm_lab/spec/experimental/dqn_qbert.json b/slm_lab/spec/experimental/dqn_qbert.json index 9f41f5574..a6e622721 100644 --- a/slm_lab/spec/experimental/dqn_qbert.json +++ b/slm_lab/spec/experimental/dqn_qbert.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "QbertNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_seaquest.json b/slm_lab/spec/experimental/dqn_seaquest.json index 51b3879a9..31c7c4101 100644 --- a/slm_lab/spec/experimental/dqn_seaquest.json +++ b/slm_lab/spec/experimental/dqn_seaquest.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "SeaquestNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dqn_spaceinvaders.json b/slm_lab/spec/experimental/dqn_spaceinvaders.json index 2c5a2c330..41f37e0c6 100644 --- a/slm_lab/spec/experimental/dqn_spaceinvaders.json +++ b/slm_lab/spec/experimental/dqn_spaceinvaders.json @@ -21,7 +21,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 200000, "use_cer": false @@ -55,6 +55,7 @@ "name": "SpaceInvadersNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/dueling_dqn.json b/slm_lab/spec/experimental/dueling_dqn.json index 80bb34fdd..0bc0aff95 100644 --- a/slm_lab/spec/experimental/dueling_dqn.json +++ b/slm_lab/spec/experimental/dueling_dqn.json @@ -269,7 +269,7 @@ "normalize_state": false }, "memory": { - "name": "AtariReplay", + "name": "Replay", "batch_size": 32, "max_size": 250000, "use_cer": true @@ -311,6 +311,7 @@ "name": "BreakoutDeterministic-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 50000, 
}], diff --git a/slm_lab/spec/experimental/ppo_beamrider.json b/slm_lab/spec/experimental/ppo_beamrider.json index f070c49ca..af814dc48 100644 --- a/slm_lab/spec/experimental/ppo_beamrider.json +++ b/slm_lab/spec/experimental/ppo_beamrider.json @@ -29,7 +29,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay" + "name": "OnPolicyReplay" }, "net": { "type": "ConvNet", @@ -60,6 +60,7 @@ "name": "BeamRiderNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ppo_breakout.json b/slm_lab/spec/experimental/ppo_breakout.json index 4c0a54877..46385f447 100644 --- a/slm_lab/spec/experimental/ppo_breakout.json +++ b/slm_lab/spec/experimental/ppo_breakout.json @@ -29,7 +29,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay" + "name": "OnPolicyReplay" }, "net": { "type": "ConvNet", @@ -60,6 +60,7 @@ "name": "BreakoutNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ppo_enduro.json b/slm_lab/spec/experimental/ppo_enduro.json index 9b52c14f5..0b3f108bd 100644 --- a/slm_lab/spec/experimental/ppo_enduro.json +++ b/slm_lab/spec/experimental/ppo_enduro.json @@ -29,7 +29,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay" + "name": "OnPolicyReplay" }, "net": { "type": "ConvNet", @@ -60,6 +60,7 @@ "name": "EnduroNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ppo_mspacman.json b/slm_lab/spec/experimental/ppo_mspacman.json index 5ef13a781..651105230 100644 --- a/slm_lab/spec/experimental/ppo_mspacman.json +++ b/slm_lab/spec/experimental/ppo_mspacman.json @@ -29,7 +29,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay" + "name": "OnPolicyReplay" }, "net": { "type": "ConvNet", @@ -60,6 +60,7 @@ "name": "MsPacmanNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ppo_pong.json b/slm_lab/spec/experimental/ppo_pong.json index 399d60101..365caa75a 100644 --- a/slm_lab/spec/experimental/ppo_pong.json +++ b/slm_lab/spec/experimental/ppo_pong.json @@ -30,7 +30,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariBatchReplay", + "name": "OnPolicyBatchReplay", }, "net": { "type": "ConvNet", @@ -70,6 +70,7 @@ "name": "PongNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "num_envs": 8, "max_t": null, "max_tick": 1e7 diff --git a/slm_lab/spec/experimental/ppo_qbert.json b/slm_lab/spec/experimental/ppo_qbert.json index 0eedb8ffa..71ade7da5 100644 --- a/slm_lab/spec/experimental/ppo_qbert.json +++ b/slm_lab/spec/experimental/ppo_qbert.json @@ -29,7 +29,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay" + "name": "OnPolicyReplay" }, "net": { "type": "ConvNet", @@ -60,6 +60,7 @@ "name": "QbertNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ppo_seaquest.json b/slm_lab/spec/experimental/ppo_seaquest.json index e4b7e092b..709ead8cc 100644 --- a/slm_lab/spec/experimental/ppo_seaquest.json +++ b/slm_lab/spec/experimental/ppo_seaquest.json @@ -29,7 +29,7 @@ "normalize_state": false }, "memory": { - "name": 
"OnPolicyAtariReplay" + "name": "OnPolicyReplay" }, "net": { "type": "ConvNet", @@ -60,6 +60,7 @@ "name": "SeaquestNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/ppo_spaceinvaders.json b/slm_lab/spec/experimental/ppo_spaceinvaders.json index dfc3744b2..cfdb5ccde 100644 --- a/slm_lab/spec/experimental/ppo_spaceinvaders.json +++ b/slm_lab/spec/experimental/ppo_spaceinvaders.json @@ -29,7 +29,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay" + "name": "OnPolicyReplay" }, "net": { "type": "ConvNet", @@ -60,6 +60,7 @@ "name": "SpaceInvadersNoFrameskip-v4", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "max_t": null, "max_tick": 10000000 }], diff --git a/slm_lab/spec/experimental/reinforce.json b/slm_lab/spec/experimental/reinforce.json index 7e1f9d6ea..76b2bc94a 100644 --- a/slm_lab/spec/experimental/reinforce.json +++ b/slm_lab/spec/experimental/reinforce.json @@ -385,7 +385,7 @@ "normalize_state": false }, "memory": { - "name": "OnPolicyAtariReplay", + "name": "OnPolicyReplay", }, "net": { "type": "ConvNet", @@ -418,6 +418,7 @@ "name": "vizdoom-v0", "frame_op": "concat", "frame_op_len": 4, + "reward_scale": "sign", "cfg_name": "basic", "max_t": 400000, "max_tick": 100 diff --git a/test/env/test_vec_env.py b/test/env/test_vec_env.py index 586d5b11a..c70a1aafb 100644 --- a/test/env/test_vec_env.py +++ b/test/env/test_vec_env.py @@ -3,17 +3,17 @@ import pytest -@pytest.mark.parametrize('name,state_shape', [ - ('PongNoFrameskip-v4', (1, 84, 84)), - ('LunarLander-v2', (8,)), - ('CartPole-v0', (4,)), +@pytest.mark.parametrize('name,state_shape,reward_scale', [ + ('PongNoFrameskip-v4', (1, 84, 84), 'sign'), + ('LunarLander-v2', (8,), None), + ('CartPole-v0', (4,), None), ]) @pytest.mark.parametrize('num_envs', (1, 4)) -def test_make_gym_venv_nostack(name, state_shape, num_envs): +def test_make_gym_venv_nostack(name, state_shape, reward_scale, num_envs): seed = 0 frame_op = None frame_op_len = None - venv = make_gym_venv(name, seed, frame_op, frame_op_len, num_envs) + venv = make_gym_venv(name, seed, frame_op, frame_op_len, reward_scale, num_envs) venv.reset() for i in range(5): state, reward, done, info = venv.step([venv.action_space.sample()] * num_envs) @@ -28,17 +28,17 @@ def test_make_gym_venv_nostack(name, state_shape, num_envs): venv.close() -@pytest.mark.parametrize('name,state_shape', [ - ('PongNoFrameskip-v4', (1, 84, 84)), - ('LunarLander-v2', (8,)), - ('CartPole-v0', (4,)), +@pytest.mark.parametrize('name,state_shape, reward_scale', [ + ('PongNoFrameskip-v4', (1, 84, 84), 'sign'), + ('LunarLander-v2', (8,), None), + ('CartPole-v0', (4,), None), ]) @pytest.mark.parametrize('num_envs', (1, 4)) -def test_make_gym_concat(name, state_shape, num_envs): +def test_make_gym_concat(name, state_shape, reward_scale, num_envs): seed = 0 frame_op = 'concat' # used for image, or for concat vector frame_op_len = 4 - venv = make_gym_venv(name, seed, frame_op, frame_op_len, num_envs) + venv = make_gym_venv(name, seed, frame_op, frame_op_len, reward_scale, num_envs) venv.reset() for i in range(5): state, reward, done, info = venv.step([venv.action_space.sample()] * num_envs) @@ -55,16 +55,16 @@ def test_make_gym_concat(name, state_shape, num_envs): @pytest.mark.skip(reason='Not implemented yet') -@pytest.mark.parametrize('name,state_shape', [ - ('LunarLander-v2', (8,)), - ('CartPole-v0', (4,)), 
+@pytest.mark.parametrize('name,state_shape,reward_scale', [ + ('LunarLander-v2', (8,), None), + ('CartPole-v0', (4,), None), ]) @pytest.mark.parametrize('num_envs', (1, 4)) -def test_make_gym_stack(name, state_shape, num_envs): +def test_make_gym_stack(name, state_shape, reward_scale, num_envs): seed = 0 frame_op = 'stack' # used for rnn frame_op_len = 4 - venv = make_gym_venv(name, seed, frame_op, frame_op_len, num_envs) + venv = make_gym_venv(name, seed, frame_op, frame_op_len, reward_scale, num_envs) venv.reset() for i in range(5): state, reward, done, info = venv.step([venv.action_space.sample()] * num_envs) diff --git a/test/env/test_wrapper.py b/test/env/test_wrapper.py index eb69c5e4f..6b237efef 100644 --- a/test/env/test_wrapper.py +++ b/test/env/test_wrapper.py @@ -3,16 +3,16 @@ import pytest -@pytest.mark.parametrize('name,state_shape', [ - ('PongNoFrameskip-v4', (1, 84, 84)), - ('LunarLander-v2', (8,)), - ('CartPole-v0', (4,)), +@pytest.mark.parametrize('name,state_shape,reward_scale', [ + ('PongNoFrameskip-v4', (1, 84, 84), 'sign'), + ('LunarLander-v2', (8,), None), + ('CartPole-v0', (4,), None), ]) -def test_make_gym_env_nostack(name, state_shape): +def test_make_gym_env_nostack(name, state_shape, reward_scale): seed = 0 frame_op = None frame_op_len = None - env = make_gym_env(name, seed, frame_op, frame_op_len) + env = make_gym_env(name, seed, frame_op, frame_op_len, reward_scale) env.reset() for i in range(5): state, reward, done, info = env.step(env.action_space.sample()) @@ -26,16 +26,16 @@ def test_make_gym_env_nostack(name, state_shape): env.close() -@pytest.mark.parametrize('name,state_shape', [ - ('PongNoFrameskip-v4', (1, 84, 84)), - ('LunarLander-v2', (8,)), - ('CartPole-v0', (4,)), +@pytest.mark.parametrize('name,state_shape,reward_scale', [ + ('PongNoFrameskip-v4', (1, 84, 84), 'sign'), + ('LunarLander-v2', (8,), None), + ('CartPole-v0', (4,), None), ]) -def test_make_gym_env_concat(name, state_shape): +def test_make_gym_env_concat(name, state_shape, reward_scale): seed = 0 frame_op = 'concat' # used for image, or for concat vector frame_op_len = 4 - env = make_gym_env(name, seed, frame_op, frame_op_len) + env = make_gym_env(name, seed, frame_op, frame_op_len, reward_scale) env.reset() for i in range(5): state, reward, done, info = env.step(env.action_space.sample()) @@ -53,15 +53,15 @@ def test_make_gym_env_concat(name, state_shape): env.close() -@pytest.mark.parametrize('name,state_shape', [ - ('LunarLander-v2', (8,)), - ('CartPole-v0', (4,)), +@pytest.mark.parametrize('name,state_shape, reward_scale', [ + ('LunarLander-v2', (8,), None), + ('CartPole-v0', (4,), None), ]) -def test_make_gym_env_stack(name, state_shape): +def test_make_gym_env_stack(name, state_shape, reward_scale): seed = 0 frame_op = 'stack' # used for rnn frame_op_len = 4 - env = make_gym_env(name, seed, frame_op, frame_op_len) + env = make_gym_env(name, seed, frame_op, frame_op_len, reward_scale) env.reset() for i in range(5): state, reward, done, info = env.step(env.action_space.sample())
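Note (not part of the diff): a minimal usage sketch of the new reward_scale option, assuming the wrapper changes above are applied in a local SLM Lab checkout. Setting reward_scale to 'sign' reproduces the np.sign clipping that the removed Atari* memory classes used to perform, while the ScaleRewardEnv wrapper keeps the unscaled value in raw_reward so Body.update() can still log true returns via u_env.raw_reward.

    # sketch only; the env name and call mirror the tests above
    from slm_lab.env.wrapper import make_gym_env

    env = make_gym_env('PongNoFrameskip-v4', seed=0, frame_op='concat', frame_op_len=4, reward_scale='sign')
    state = env.reset()
    state, reward, done, info = env.step(env.action_space.sample())
    # the agent sees the scaled (signed) reward
    assert reward in (-1.0, 0.0, 1.0)
    # the wrapper retains the pre-scale reward for monitoring
    print(env.raw_reward)
    env.close()

For a float value of reward_scale, the reward is simply multiplied by that scale instead of being signed, as implemented in try_scale_reward above.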