Factor out a run_episode method which returns per-episode logged data.
A common use case is to call this method (instead of run()), so the caller has some flexibility to consume the logged data after every episode.

PiperOrigin-RevId: 326652451
Change-Id: Idad1687b736c1e8312f8f28b4fdb66ce290e1f75
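
To make the intended usage concrete, here is a minimal sketch of driving the loop one episode at a time with run_episode() and consuming the returned logging data after each episode. It is not part of the commit: the environment/actor construction simply mirrors the updated test file below, and the episode budget and early-stopping threshold are arbitrary illustrative values.

from acme import environment_loop
from acme import specs
from acme.testing import fakes

# Build an environment/actor pair the same way the updated test does.
environment = fakes.DiscreteEnvironment(episode_length=10)
actor = fakes.Actor(specs.make_environment_spec(environment))
loop = environment_loop.EnvironmentLoop(environment, actor)

# Drive the loop episode by episode and consume the per-episode data.
returns = []
for _ in range(20):
  result = loop.run_episode()
  returns.append(result['episode_return'])
  # Stop early once the mean return over the last five episodes clears an
  # illustrative threshold; plain loop.run() does not expose this.
  recent = returns[-5:]
  if sum(recent) / len(recent) > 0.9:
    break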
Acme Contributor authored and Copybara-Service committed Aug 14, 2020
1 parent 509ff85 commit 1b1d160
Showing 2 changed files with 76 additions and 50 deletions.
94 changes: 52 additions & 42 deletions acme/environment_loop.py
@@ -60,6 +60,52 @@ def __init__(
     self._counter = counter or counting.Counter()
     self._logger = logger or loggers.make_default_logger(label)

+  def run_episode(self) -> loggers.LoggingData:
+    """Run one episode.
+
+    Each episode is a loop which interacts first with the environment to get an
+    observation and then gives that observation to the agent in order to
+    retrieve an action.
+
+    Returns:
+      An instance of `loggers.LoggingData`.
+    """
+    # Reset any counts and start the environment.
+    start_time = time.time()
+    episode_steps = 0
+    episode_return = 0
+    timestep = self._environment.reset()
+
+    # Make the first observation.
+    self._actor.observe_first(timestep)
+
+    # Run an episode.
+    while not timestep.last():
+      # Generate an action from the agent's policy and step the environment.
+      action = self._actor.select_action(timestep.observation)
+      timestep = self._environment.step(action)
+
+      # Have the agent observe the timestep and let the actor update itself.
+      self._actor.observe(action, next_timestep=timestep)
+      self._actor.update()
+
+      # Book-keeping.
+      episode_steps += 1
+      episode_return += timestep.reward
+
+    # Record counts.
+    counts = self._counter.increment(episodes=1, steps=episode_steps)
+
+    # Collect the results and combine with counts.
+    steps_per_second = episode_steps / (time.time() - start_time)
+    result = {
+        'episode_length': episode_steps,
+        'episode_return': episode_return,
+        'steps_per_second': steps_per_second,
+    }
+    result.update(counts)
+    return result
+
   def run(self,
           num_episodes: Optional[int] = None,
           num_steps: Optional[int] = None):
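
For orientation, the mapping returned by run_episode() above contains the three keys built in the method plus whatever the counter returns from increment(episodes=1, steps=episode_steps). The values below are purely illustrative (a hypothetical third episode of a 10-step environment, assuming a default counter with no key prefix); they are not output from a real run.

# Hypothetical result returned by run_episode() after a third 10-step episode.
result = {
    'episode_length': 10,        # steps taken in this episode
    'episode_return': 1.0,       # sum of rewards over the episode
    'steps_per_second': 250.0,   # wall-clock throughput for this episode
    # Merged in from counter.increment(episodes=1, steps=episode_steps);
    # these accumulate across calls.
    'episodes': 3,
    'steps': 30,
}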
@@ -69,12 +115,10 @@ def run(self,
     least `num_steps` steps (the last episode is always run until completion,
     so the total number of steps may be slightly more than `num_steps`).
     At least one of these two arguments has to be None.
-    Each episode is itself a loop which interacts first with the environment to
-    get an observation and then give that observation to the agent in order to
-    retrieve an action. Upon termination of an episode a new episode will be
-    started. If the number of episodes and the number of steps are not given
-    then this will interact with the environment infinitely.
+    If both num_episodes and num_steps are `None` (default), runs without limit.
+    Upon termination of an episode a new episode will be started. If the number
+    of episodes and the number of steps are not given then this will interact
+    with the environment infinitely.

     Args:
       num_episodes: number of episodes to run the loop for.
@@ -93,43 +137,9 @@ def should_terminate(episode_count: int, step_count: int) -> bool:

     episode_count, step_count = 0, 0
     while not should_terminate(episode_count, step_count):
-      # Reset any counts and start the environment.
-      start_time = time.time()
-      episode_steps = 0
-      episode_return = 0
-      timestep = self._environment.reset()
-
-      # Make the first observation.
-      self._actor.observe_first(timestep)
-
-      # Run an episode.
-      while not timestep.last():
-        # Generate an action from the agent's policy and step the environment.
-        action = self._actor.select_action(timestep.observation)
-        timestep = self._environment.step(action)
-
-        # Have the agent observe the timestep and let the actor update itself.
-        self._actor.observe(action, next_timestep=timestep)
-        self._actor.update()
-
-        # Book-keeping.
-        episode_steps += 1
-        episode_return += timestep.reward
-
-      # Record counts.
-      counts = self._counter.increment(episodes=1, steps=episode_steps)
-
-      # Collect the results and combine with counts.
-      steps_per_second = episode_steps / (time.time() - start_time)
-      result = {
-          'episode_length': episode_steps,
-          'episode_return': episode_return,
-          'steps_per_second': steps_per_second,
-      }
-      result.update(counts)
+      result = self.run_episode()
       episode_count += 1
-      step_count += episode_steps
+      step_count += result['episode_length']
       # Log the given results.
       self._logger.write(result)

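
The loop only ever calls four methods on the actor (observe_first, select_action, observe, update) and the standard dm_env methods on the environment (reset, step, plus timestep.last() and timestep.reward). As a rough sketch of that contract, a duck-typed random actor such as the one below would already work with run_episode(); it is illustrative only and does not subclass Acme's actual Actor base class.

import numpy as np

class RandomDiscreteActor:
  """Minimal duck-typed actor: the four methods the loop calls, nothing more."""

  def __init__(self, num_actions: int):
    self._num_actions = num_actions

  def observe_first(self, timestep):
    pass  # Nothing to record at the start of an episode.

  def select_action(self, observation):
    del observation  # A random policy ignores the observation.
    return np.random.randint(self._num_actions)

  def observe(self, action, next_timestep):
    pass  # A learning actor would buffer (action, next_timestep) here.

  def update(self):
    pass  # A learning actor would take a learner/gradient step here.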
32 changes: 24 additions & 8 deletions acme/environment_loop_test.py
@@ -21,18 +21,34 @@
 from acme import specs
 from acme.testing import fakes

+EPISODE_LENGTH = 10
+

 class EnvironmentLoopTest(absltest.TestCase):

-  def test_environment_loop(self):
+  def setUp(self):
+    super().setUp()
     # Create the actor/environment and stick them in a loop.
-    environment = fakes.DiscreteEnvironment(episode_length=10)
-    actor = fakes.Actor(specs.make_environment_spec(environment))
-    loop = environment_loop.EnvironmentLoop(environment, actor)
-
-    # Run the loop. There should be episode_length+1 update calls per episode.
-    loop.run(num_episodes=10)
-    self.assertEqual(actor.num_updates, 100)
+    environment = fakes.DiscreteEnvironment(episode_length=EPISODE_LENGTH)
+    self.actor = fakes.Actor(specs.make_environment_spec(environment))
+    self.loop = environment_loop.EnvironmentLoop(environment, self.actor)
+
+  def test_one_episode(self):
+    result = self.loop.run_episode()
+    self.assertDictContainsSubset({'episode_length': EPISODE_LENGTH}, result)
+    self.assertIn('episode_return', result)
+    self.assertIn('steps_per_second', result)
+
+  def test_run_episodes(self):
+    # Run the loop. There should be EPISODE_LENGTH update calls per episode.
+    self.loop.run(num_episodes=10)
+    self.assertEqual(self.actor.num_updates, 10 * EPISODE_LENGTH)
+
+  def test_run_steps(self):
+    # Run the loop. This will run 2 episodes so that total number of steps is
+    # at least 15.
+    self.loop.run(num_steps=EPISODE_LENGTH + 5)
+    self.assertEqual(self.actor.num_updates, 2 * EPISODE_LENGTH)


 if __name__ == '__main__':

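As a back-of-the-envelope check of test_run_steps above: with EPISODE_LENGTH = 10 and num_steps = 15, episodes always run to completion and the step count only grows in increments of 10, so the loop stops after the second episode at 20 steps and the fake actor sees 2 * EPISODE_LENGTH updates. The small sketch below just replays that arithmetic; the termination rule is inferred from the run() docstring (at least num_steps steps, last episode runs to completion), not copied from the library.

EPISODE_LENGTH = 10
num_steps = EPISODE_LENGTH + 5  # 15, as in test_run_steps

# Episodes run to completion, so step_count only changes in full-episode jumps:
#   after episode 1: step_count = 10 (< 15, keep going)
#   after episode 2: step_count = 20 (>= 15, stop)
step_count, episodes = 0, 0
while step_count < num_steps:
  step_count += EPISODE_LENGTH
  episodes += 1
assert (episodes, step_count) == (2, 20)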