Skip to content

Commit

Permalink
Acme: Make D4PG use the tested n-step transition adder.
Browse files Browse the repository at this point in the history
Fixes Issue 292.

PiperOrigin-RevId: 571906732
Change-Id: I7ee0c4952fab2f3eec353e787caeffab799617a9
  • Loading branch information
bshahr authored and Copybara-Service committed Oct 9, 2023
1 parent ac668d5 commit d92e23b
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 180 deletions.
1 change: 1 addition & 0 deletions acme/adders/reverb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
from acme.adders.reverb.sequence import SequenceAdder
from acme.adders.reverb.structured import create_n_step_transition_config
from acme.adders.reverb.structured import create_step_spec
from acme.adders.reverb.structured import n_step_from_trajectory
from acme.adders.reverb.structured import StructuredAdder
from acme.adders.reverb.transition import NStepTransitionAdder
183 changes: 148 additions & 35 deletions acme/adders/reverb/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import itertools
import time

from typing import Callable, List, Optional, Sequence, Sized

from absl import logging
Expand All @@ -25,6 +24,7 @@
from acme.adders import base as adders_base
from acme.adders.reverb import base as reverb_base
from acme.adders.reverb import sequence as sequence_adder
from acme.utils import tree_utils
import dm_env
import numpy as np
import reverb
Expand Down Expand Up @@ -63,8 +63,13 @@ class StructuredAdder(adders_base.Adder):
expected to perform preprocessing in the dataset pipeline on the learner.
"""

def __init__(self, client: reverb.Client, max_in_flight_items: int,
configs: Sequence[sw.Config], step_spec: Step):
def __init__(
self,
client: reverb.Client,
max_in_flight_items: int,
configs: Sequence[sw.Config],
step_spec: Step,
):
"""Initialize a StructuredAdder instance.
Args:
Expand All @@ -86,7 +91,8 @@ def __init__(self, client: reverb.Client, max_in_flight_items: int,
sw.infer_signature(list(table_configs), step_spec)
except ValueError as e:
raise ValueError(
f'Received invalid configs for table {table}: {str(e)}') from e
f'Received invalid configs for table {table}: {str(e)}'
) from e

self._client = client
self._configs = tuple(configs)
Expand All @@ -106,7 +112,9 @@ def __del__(self):
except reverb.DeadlineExceededError as e:
logging.error(
'Timeout (10 s) exceeded when flushing the writer before '
'deleting it. Caught Reverb exception: %s', str(e))
'deleting it. Caught Reverb exception: %s',
str(e),
)

def _make_step(self, **kwargs) -> Step:
"""Complete the step with None in the missing positions."""
Expand All @@ -132,7 +140,8 @@ def add_first(self, timestep: dm_env.TimeStep):
if not timestep.first():
raise ValueError(
'adder.add_first called with a timestep that was not the first of its'
'episode (i.e. one for which timestep.first() is not True)')
'episode (i.e. one for which timestep.first() is not True)'
)

if self._writer is None:
self._writer = self._client.structured_writer(self._configs)
Expand All @@ -142,46 +151,55 @@ def add_first(self, timestep: dm_env.TimeStep):
# passing `partial_step=True`.
self._writer.append(
data=self._make_step(
observation=timestep.observation,
start_of_episode=timestep.first()),
partial_step=True)
observation=timestep.observation, start_of_episode=timestep.first()
),
partial_step=True,
)
self._writer.flush(self._max_in_flight_items)

def add(self,
action: types.NestedArray,
next_timestep: dm_env.TimeStep,
extras: types.NestedArray = ()):
def add(
self,
action: types.NestedArray,
next_timestep: dm_env.TimeStep,
extras: types.NestedArray = (),
):
"""Record an action and the following timestep."""

if self._writer is None or not self._writer.step_is_open:
raise ValueError('adder.add_first must be called before adder.add.')

# Add the timestep to the buffer.
has_extras = (
len(extras) > 0 if isinstance(extras, Sized) # pylint: disable=g-explicit-length-test
else extras is not None)
len(extras) > 0 # pylint: disable=g-explicit-length-test
if isinstance(extras, Sized)
else extras is not None
)

current_step = self._make_step(
action=action,
reward=next_timestep.reward,
discount=next_timestep.discount,
extras=extras if has_extras else self._none_step.extras)
extras=extras if has_extras else self._none_step.extras,
)
self._writer.append(current_step)

# Record the next observation and write.
self._writer.append(
data=self._make_step(
observation=next_timestep.observation,
start_of_episode=next_timestep.first()),
partial_step=True)
start_of_episode=next_timestep.first(),
),
partial_step=True,
)
self._writer.flush(self._max_in_flight_items)

if next_timestep.last():
# Complete the row by appending zeros to remaining open fields.
# TODO(b/183945808): remove this when fields are no longer expected to be
# of equal length on the learner side.
dummy_step = tree.map_structure(
lambda x: None if x is None else np.zeros_like(x), current_step)
lambda x: None if x is None else np.zeros_like(x), current_step
)
self._writer.append(dummy_step)
self.reset()

Expand All @@ -192,7 +210,8 @@ def create_step_spec(
return Step(
*environment_spec,
start_of_episode=tf.TensorSpec([], tf.bool, 'start_of_episode'),
extras=extras_spec)
extras=extras_spec,
)


def _last_n(n: int, step_spec: Step) -> Trajectory:
Expand Down Expand Up @@ -227,8 +246,8 @@ def create_sequence_config(
end_of_episode_behavior: Determines how sequences at the end of the episode
are handled (default `EndOfEpisodeBehavior.TRUNCATE`). See the docstring
of `EndOfEpisodeBehavior` for more information.
sequence_pattern: Transformation to obtain a sequence given the length
and the shape of the step.
sequence_pattern: Transformation to obtain a sequence given the length and
the shape of the step.
Returns:
A list of configs for `StructuredAdder` to produce the described behaviour.
Expand All @@ -242,14 +261,16 @@ def create_sequence_config(

if end_of_episode_behavior == EndBehavior.ZERO_PAD:
raise NotImplementedError(
'Zero-padding is not supported. Please use TRUNCATE instead.')
'Zero-padding is not supported. Please use TRUNCATE instead.'
)

if end_of_episode_behavior == EndBehavior.CONTINUE:
raise NotImplementedError('Merging episodes is not supported.')

def _sequence_pattern(n: int) -> sw.Pattern:
return sw.pattern_from_transform(step_spec,
lambda step: sequence_pattern(n, step))
return sw.pattern_from_transform(
step_spec, lambda step: sequence_pattern(n, step)
)

# The base config is considered for all but the last step in the episode. No
# trajectories are created for the first `sequence_step-1` steps and then a
Expand All @@ -260,7 +281,8 @@ def _sequence_pattern(n: int) -> sw.Pattern:
conditions=[
sw.Condition.step_index() >= sequence_length - 1,
sw.Condition.step_index() % period == (sequence_length - 1) % period,
])
],
)

end_of_episode_configs = []
if end_of_episode_behavior == EndBehavior.WRITE:
Expand All @@ -275,7 +297,8 @@ def _sequence_pattern(n: int) -> sw.Pattern:
conditions=[
sw.Condition.is_end_episode(),
sw.Condition.step_index() >= sequence_length - 1,
])
],
)
end_of_episode_configs.append(config)
elif end_of_episode_behavior == EndBehavior.TRUNCATE:
# The first trajectory is written at step index `sequence_length - 1` and
Expand Down Expand Up @@ -315,7 +338,8 @@ def _sequence_pattern(n: int) -> sw.Pattern:
sw.Condition.is_end_episode(),
sw.Condition.step_index() % period == x,
sw.Condition.step_index() >= sequence_length,
])
],
)
end_of_episode_configs.append(config)

# The above configs will capture the "remainder" of any episode that is at
Expand All @@ -330,19 +354,22 @@ def _sequence_pattern(n: int) -> sw.Pattern:
conditions=[
sw.Condition.is_end_episode(),
sw.Condition.step_index() == x - 1,
])
],
)
end_of_episode_configs.append(config)
else:
raise ValueError(
f'Unexpected `end_of_episod_behavior`: {end_of_episode_behavior}')
f'Unexpected `end_of_episod_behavior`: {end_of_episode_behavior}'
)

return [base_config] + end_of_episode_configs


def create_n_step_transition_config(
step_spec: Step,
n_step: int,
table: str = reverb_base.DEFAULT_PRIORITY_TABLE) -> List[sw.Config]:
table: str = reverb_base.DEFAULT_PRIORITY_TABLE,
) -> List[sw.Config]:
"""Generates configs that replicates the behaviour of NStepTransitionAdder.
Please see the docstring of NStepTransitionAdder for more details.
Expand Down Expand Up @@ -370,9 +397,9 @@ def create_n_step_transition_config(
def _make_pattern(n: int):
ref_step = sw.create_reference_step(step_spec)

get_first = lambda x: x[-(n + 1):-n]
get_all = lambda x: x[-(n + 1):-1]
get_first_and_last = lambda x: x[-(n + 1)::n]
get_first = lambda x: x[-(n + 1) : -n]
get_all = lambda x: x[-(n + 1) : -1]
get_first_and_last = lambda x: x[-(n + 1) :: n]

tmap = tree.map_structure

Expand All @@ -388,7 +415,8 @@ def _make_pattern(n: int):
reward=tmap(get_all, ref_step.reward),
discount=tmap(get_all, ref_step.discount),
start_of_episode=tmap(get_first, ref_step.start_of_episode),
extras=tmap(get_first, ref_step.extras))
extras=tmap(get_first, ref_step.extras),
)

# At the start of the episodes we'll add shorter transitions.
start_of_episode_configs = []
Expand Down Expand Up @@ -422,3 +450,88 @@ def _make_pattern(n: int):
end_of_episode_configs.append(config)

return start_of_episode_configs + [base_config] + end_of_episode_configs


def n_step_from_trajectory(
trajectory: reverb_base.Trajectory,
agent_discount: float,
) -> types.Transition:
"""Converts an (n+1)-step trajectory into an n-step transition."""

rewards, discount = _compute_cumulative_quantities(
rewards=trajectory.reward,
discounts=trajectory.discount,
additional_discount=agent_discount,
)

tmap = tree.map_structure
return types.Transition(
observation=tmap(lambda x: x[0], trajectory.observation),
action=tmap(lambda x: x[0], trajectory.action),
reward=rewards,
discount=discount,
next_observation=tmap(lambda x: x[-1], trajectory.observation),
extras=tmap(lambda x: x[0], trajectory.extras),
)


def _compute_cumulative_quantities(
rewards: types.NestedArray,
discounts: types.NestedArray,
additional_discount: float,
):
"""Stolen from TransitionAdder."""

# Give the same tree structure to the n-step return accumulator,
# n-step discount accumulator, and self.discount, so that they can be
# iterated in parallel using tree.map_structure.
rewards, discounts = tree_utils.broadcast_structures(rewards, discounts)
flat_rewards = tree.flatten(rewards)
flat_discounts = tree.flatten(discounts)
n_step = tf.shape(flat_rewards[0])[0]
# Initialize flat output containers.
flat_total_discounts = []
flat_n_step_returns = []

def scan_body(
state: types.NestedTensor,
discount_and_reward: types.NestedTensor,
) -> types.NestedTensor:
compound_discount, discounted_return = state
discount, reward = discount_and_reward
return (
additional_discount * discount * compound_discount,
discounted_return + additional_discount * compound_discount * reward,
)

for reward, discount in zip(flat_rewards, flat_discounts):
shape = tf.broadcast_static_shape(
tf.TensorShape(reward[0].shape),
tf.TensorShape(discount[0].shape),
)
total_discount = discount[0]
n_step_return = tf.broadcast_to(reward[0], shape)

if n_step > 1:
# NOTE: total_discount will have one less additional_discount applied to
# it (compared to flat_discount). This is so that when the learner/update
# uses an additional discount we don't apply it twice. Inside the
# following loop we will apply this right before summing up the
# n_step_return.
total_discount, n_step_return = tf.scan(
scan_body,
(discount[1:], reward[1:]),
(total_discount, n_step_return),
)

# Add the last return and discount of the scan, which correspond to the
# n-step return and environment discount.
n_step_return = n_step_return[-1]
total_discount = total_discount[-1]

flat_n_step_returns.append(n_step_return)
flat_total_discounts.append(total_discount)

n_step_return = tree.unflatten_as(rewards, flat_n_step_returns)
total_discount = tree.unflatten_as(rewards, flat_total_discounts)
return n_step_return, total_discount
Loading

0 comments on commit d92e23b

Please sign in to comment.