diff --git a/acme/environment_loop.py b/acme/environment_loop.py
index a65e28bccd..3b1426efac 100644
--- a/acme/environment_loop.py
+++ b/acme/environment_loop.py
@@ -60,6 +60,52 @@ def __init__(
     self._counter = counter or counting.Counter()
     self._logger = logger or loggers.make_default_logger(label)
 
+  def run_episode(self) -> loggers.LoggingData:
+    """Run one episode.
+
+    Each episode is a loop which interacts first with the environment to get
+    an observation and then gives that observation to the agent in order to
+    retrieve an action.
+
+    Returns:
+      An instance of `loggers.LoggingData`.
+    """
+    # Reset any counts and start the environment.
+    start_time = time.time()
+    episode_steps = 0
+    episode_return = 0
+    timestep = self._environment.reset()
+
+    # Make the first observation.
+    self._actor.observe_first(timestep)
+
+    # Run an episode.
+    while not timestep.last():
+      # Generate an action from the agent's policy and step the environment.
+      action = self._actor.select_action(timestep.observation)
+      timestep = self._environment.step(action)
+
+      # Have the agent observe the timestep and let the actor update itself.
+      self._actor.observe(action, next_timestep=timestep)
+      self._actor.update()
+
+      # Book-keeping.
+      episode_steps += 1
+      episode_return += timestep.reward
+
+    # Record counts.
+    counts = self._counter.increment(episodes=1, steps=episode_steps)
+
+    # Collect the results and combine with counts.
+    steps_per_second = episode_steps / (time.time() - start_time)
+    result = {
+        'episode_length': episode_steps,
+        'episode_return': episode_return,
+        'steps_per_second': steps_per_second,
+    }
+    result.update(counts)
+    return result
+
   def run(self,
           num_episodes: Optional[int] = None,
           num_steps: Optional[int] = None):
@@ -69,12 +115,10 @@ def run(self,
     least `num_steps` steps (the last episode is always run until completion,
     so the total number of steps may be slightly more than `num_steps`). At
     least one of these two arguments has to be None.
-    Each episode is itself a loop which interacts first with the environment to
-    get an observation and then give that observation to the agent in order to
-    retrieve an action. Upon termination of an episode a new episode will be
-    started. If the number of episodes and the number of steps are not given
-    then this will interact with the environment infinitely.
-    If both num_episodes and num_steps are `None` (default), runs without limit.
+
+    Upon termination of an episode a new episode will be started. If the number
+    of episodes and the number of steps are not given, then this will interact
+    with the environment infinitely.
 
     Args:
       num_episodes: number of episodes to run the loop for.
@@ -93,43 +137,9 @@ def should_terminate(episode_count: int, step_count: int) -> bool:
 
     episode_count, step_count = 0, 0
     while not should_terminate(episode_count, step_count):
-      # Reset any counts and start the environment.
-      start_time = time.time()
-      episode_steps = 0
-      episode_return = 0
-      timestep = self._environment.reset()
-
-      # Make the first observation.
-      self._actor.observe_first(timestep)
-
-      # Run an episode.
-      while not timestep.last():
-        # Generate an action from the agent's policy and step the environment.
-        action = self._actor.select_action(timestep.observation)
-        timestep = self._environment.step(action)
-
-        # Have the agent observe the timestep and let the actor update itself.
-        self._actor.observe(action, next_timestep=timestep)
-        self._actor.update()
-
-        # Book-keeping.
-        episode_steps += 1
-        episode_return += timestep.reward
-
-      # Record counts.
-      counts = self._counter.increment(episodes=1, steps=episode_steps)
-
-      # Collect the results and combine with counts.
-      steps_per_second = episode_steps / (time.time() - start_time)
-      result = {
-          'episode_length': episode_steps,
-          'episode_return': episode_return,
-          'steps_per_second': steps_per_second,
-      }
-      result.update(counts)
+      result = self.run_episode()
       episode_count += 1
-      step_count += episode_steps
-
+      step_count += result['episode_length']
       # Log the given results.
       self._logger.write(result)
 
diff --git a/acme/environment_loop_test.py b/acme/environment_loop_test.py
index ff3e709fa5..bd94d549a1 100644
--- a/acme/environment_loop_test.py
+++ b/acme/environment_loop_test.py
@@ -21,18 +21,34 @@
 from acme import specs
 from acme.testing import fakes
 
+EPISODE_LENGTH = 10
+
 
 class EnvironmentLoopTest(absltest.TestCase):
 
-  def test_environment_loop(self):
+  def setUp(self):
+    super().setUp()
     # Create the actor/environment and stick them in a loop.
-    environment = fakes.DiscreteEnvironment(episode_length=10)
-    actor = fakes.Actor(specs.make_environment_spec(environment))
-    loop = environment_loop.EnvironmentLoop(environment, actor)
-
-    # Run the loop. There should be episode_length+1 update calls per episode.
-    loop.run(num_episodes=10)
-    self.assertEqual(actor.num_updates, 100)
+    environment = fakes.DiscreteEnvironment(episode_length=EPISODE_LENGTH)
+    self.actor = fakes.Actor(specs.make_environment_spec(environment))
+    self.loop = environment_loop.EnvironmentLoop(environment, self.actor)
+
+  def test_one_episode(self):
+    result = self.loop.run_episode()
+    self.assertDictContainsSubset({'episode_length': EPISODE_LENGTH}, result)
+    self.assertIn('episode_return', result)
+    self.assertIn('steps_per_second', result)
+
+  def test_run_episodes(self):
+    # Run the loop. There should be EPISODE_LENGTH update calls per episode.
+    self.loop.run(num_episodes=10)
+    self.assertEqual(self.actor.num_updates, 10 * EPISODE_LENGTH)
+
+  def test_run_steps(self):
+    # Run the loop. This will run 2 episodes so that the total number of steps
+    # is at least EPISODE_LENGTH + 5 (i.e. 15).
+    self.loop.run(num_steps=EPISODE_LENGTH + 5)
+    self.assertEqual(self.actor.num_updates, 2 * EPISODE_LENGTH)
 
 
 if __name__ == '__main__':
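With this change, a caller can drive a single episode and inspect its metrics directly instead of going through `run`. A minimal usage sketch of the new `run_episode` entry point, reusing the same `fakes` helpers exercised in the test above (the available result keys are the ones built inside `run_episode`):

from acme import environment_loop
from acme import specs
from acme.testing import fakes

# Build a fake environment/actor pair and wrap them in a loop.
environment = fakes.DiscreteEnvironment(episode_length=10)
actor = fakes.Actor(specs.make_environment_spec(environment))
loop = environment_loop.EnvironmentLoop(environment, actor)

# Drive a single episode and read its metrics, without calling loop.run().
result = loop.run_episode()
print(result['episode_length'])   # 10 for this fake environment
print(result['episode_return'])
print(result['steps_per_second'])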