Skip to content

Commit

Permalink
Fix ZeroDiscountonLifeLoss wrapper and add test covereage.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 318249207
Change-Id: I01c2f0b76d1b15fc42bf83538fbb1bc8c0757776
  • Loading branch information
aslanides authored and Copybara-Service committed Jun 25, 2020
1 parent bd963d9 commit 262c3cf
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion acme/wrappers/atari_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def __init__(self, environment: dm_env.Environment):
self._last_num_lives = None

def reset(self) -> dm_env.TimeStep:
timestep = self._env.reset()
timestep = self._environment.reset()
self._reset_next_step = False
self._last_num_lives = timestep.observation[LIVES_INDEX]
return timestep
Expand Down
9 changes: 6 additions & 3 deletions acme/wrappers/atari_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import unittest
from absl.testing import absltest
from absl.testing import parameterized
from acme.wrappers import atari_wrapper
from dm_env import specs
import numpy as np
Expand All @@ -33,12 +34,14 @@


@unittest.skipIf(SKIP_GYM_TESTS, SKIP_GYM_MESSAGE)
class AtariWrapperTest(absltest.TestCase):
class AtariWrapperTest(parameterized.TestCase):

def test_pong(self):
@parameterized.parameters(True, False)
def test_pong(self, zero_discount_on_life_loss: bool):
env = gym.make('PongNoFrameskip-v4', full_action_space=True)
env = gym_wrapper.GymAtariAdapter(env)
env = atari_wrapper.AtariWrapper(env)
env = atari_wrapper.AtariWrapper(
env, zero_discount_on_life_loss=zero_discount_on_life_loss)

# Test converted observation spec.
observation_spec = env.observation_spec()
Expand Down

0 comments on commit 262c3cf

Please sign in to comment.