From 6b1b413a39cefe1556190600f2a7430cb3b60b26 Mon Sep 17 00:00:00 2001
From: Acme Contributor
Date: Mon, 24 Aug 2020 08:51:51 -0700
Subject: [PATCH] Wrapper that implements action repeats.

PiperOrigin-RevId: 328147500
Change-Id: If62d7fd7a5e1255a478ec1aa30a246000e55100c
---
 acme/wrappers/__init__.py      |  1 +
 acme/wrappers/action_repeat.py | 48 ++++++++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+)
 create mode 100644 acme/wrappers/action_repeat.py

diff --git a/acme/wrappers/__init__.py b/acme/wrappers/__init__.py
index c89802a6db..d113528227 100644
--- a/acme/wrappers/__init__.py
+++ b/acme/wrappers/__init__.py
@@ -14,6 +14,7 @@
 
 """Common environment wrapper classes."""
 
+from acme.wrappers.action_repeat import ActionRepeatWrapper
 from acme.wrappers.atari_wrapper import AtariWrapper
 from acme.wrappers.base import EnvironmentWrapper
 from acme.wrappers.base import wrap_all
diff --git a/acme/wrappers/action_repeat.py b/acme/wrappers/action_repeat.py
new file mode 100644
index 0000000000..dc376c8ac0
--- /dev/null
+++ b/acme/wrappers/action_repeat.py
@@ -0,0 +1,48 @@
+# python3
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wrapper that implements action repeats."""
+
+from acme import types
+from acme.wrappers import base
+import dm_env
+
+
+class ActionRepeatWrapper(base.EnvironmentWrapper):
+  """Action repeat wrapper."""
+
+  def __init__(self, environment: dm_env.Environment, num_repeats: int = 1):
+    super().__init__(environment)
+    self._num_repeats = num_repeats
+
+  def step(self, action: types.NestedArray) -> dm_env.TimeStep:
+    # Initialize accumulated reward and discount.
+    reward = 0.
+    discount = 1.
+
+    # Step the environment by repeating action.
+    for _ in range(self._num_repeats):
+      timestep = self._environment.step(action)
+
+      # Accumulate reward and discount.
+      reward += timestep.reward * discount
+      discount *= timestep.discount
+
+      # Don't go over episode boundaries.
+      if timestep.last():
+        break
+
+    # Replace the final timestep's reward and discount.
+    return timestep._replace(reward=reward, discount=discount)
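
Usage note (not part of the patch): each call to the wrapped step() forwards the same action up to num_repeats times, summing rewards weighted by the running product of intermediate discounts and stopping early at an episode boundary, then returns the last timestep with the accumulated reward and discount. Below is a minimal, hypothetical sketch of how the wrapper might be used via acme.wrappers. It assumes dm_control is installed purely to provide a concrete dm_env.Environment (any dm_env.Environment works), and it uses the action spec's generate_value() as a placeholder action rather than a real agent.

from acme import wrappers
from dm_control import suite

# Any dm_env.Environment works here; dm_control's cartpole is just an example.
environment = suite.load('cartpole', 'swingup')
# Each agent-visible step now repeats the chosen action 4 times.
environment = wrappers.ActionRepeatWrapper(environment, num_repeats=4)

timestep = environment.reset()
while not timestep.last():
  # Placeholder action conforming to the action spec (stands in for an agent).
  action = environment.action_spec().generate_value()
  # Advances the underlying environment up to 4 times and returns the last
  # timestep, with rewards summed and discounts multiplied across the repeats.
  timestep = environment.step(action)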