From 329aae50695ca0290d6c151b7bbd4cefb1ccf83d Mon Sep 17 00:00:00 2001
From: "Matthew W. Hoffman"
Date: Thu, 11 Feb 2021 08:04:51 -0800
Subject: [PATCH] Add a Builder class which encapsulates a full agent.

This class allows the constituent components to be broken apart so that the
same specification can be used for both distributed and non-distributed
variants. For the time being this is only incorporated into the TF D4PG
agent, to allow for minimal disruption and experimentation, but it should be
rolled out to all agents soon.

PiperOrigin-RevId: 356975846
Change-Id: I00ead33da40f4f98052ae3beb218c23788ada206
---
 acme/agents/builders.py         | 103 +++++++++
 acme/agents/tf/d4pg/agent.py    | 382 +++++++++++++++++++++++---------
 acme/agents/tf/d4pg/learning.py |  12 +-
 acme/core.py                    |  17 --
 acme/jax/variable_utils.py      |   5 +-
 acme/tf/variable_utils.py       |   2 +-
 6 files changed, 395 insertions(+), 126 deletions(-)
 create mode 100644 acme/agents/builders.py

diff --git a/acme/agents/builders.py b/acme/agents/builders.py
new file mode 100644
index 0000000000..c99760783a
--- /dev/null
+++ b/acme/agents/builders.py
@@ -0,0 +1,103 @@
+# python3
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""RL agent Builder interface."""
+
+import abc
+from typing import Iterator, List, Optional
+
+from acme import adders
+from acme import core
+from acme import specs
+from acme.utils import counting
+from acme.utils import loggers
+import reverb
+
+
+class ActorLearnerBuilder(abc.ABC):
+  """Defines an interface for specifying the components of an RL agent.
+
+  Implementations of this interface contain a complete specification of a
+  concrete RL agent. An instance of this class can be used to build an
+  RL agent which interacts with the environment either locally or in a
+  distributed setup.
+  """
+
+  @abc.abstractmethod
+  def make_replay_tables(
+      self,
+      environment_spec: specs.EnvironmentSpec,
+  ) -> List[reverb.Table]:
+    """Create tables to insert data into."""
+
+  @abc.abstractmethod
+  def make_dataset_iterator(
+      self,
+      replay_client: reverb.Client,
+  ) -> Iterator[reverb.ReplaySample]:
+    """Create a dataset iterator to use for learning/updating the agent."""
+
+  @abc.abstractmethod
+  def make_adder(
+      self,
+      replay_client: reverb.Client,
+  ) -> Optional[adders.Adder]:
+    """Create an adder which records data generated by the actor/environment.
+
+    Args:
+      replay_client: Reverb Client which points to the replay server.
+    """
+
+  @abc.abstractmethod
+  def make_actor(
+      self,
+      policy_network,
+      adder: Optional[adders.Adder] = None,
+      variable_source: Optional[core.VariableSource] = None,
+  ) -> core.Actor:
+    """Create an actor instance.
+
+    Args:
+      policy_network: Instance of a policy network; this should be a callable
+        which takes observations as input and returns actions.
+      adder: How data is recorded (e.g. added to replay).
+      variable_source: A source providing the necessary actor parameters.
+ """ + + @abc.abstractmethod + def make_learner( + self, + networks, + dataset: Iterator[reverb.ReplaySample], + replay_client: Optional[reverb.Client] = None, + counter: Optional[counting.Counter] = None, + # TODO(mwhoffman): consider eliminating logger and log return values. + # TODO(mwhoffman): eliminate checkpoint and move it outside. + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ) -> core.Learner: + """Creates an instance of the learner. + + Args: + networks: struct describing the networks needed by the learner; this can + be specific to the learner in question. + dataset: iterator over samples from replay. + replay_client: client which allows communication with replay, e.g. in + order to update priorities. + counter: a Counter which allows for recording of counts (learner steps, + actor steps, etc.) distributed throughout the agent. + logger: Logger object for logging metadata. + checkpoint: bool controlling whether the learner checkpoints itself. + """ diff --git a/acme/agents/tf/d4pg/agent.py b/acme/agents/tf/d4pg/agent.py index 8c98a17ef2..473a028e41 100644 --- a/acme/agents/tf/d4pg/agent.py +++ b/acme/agents/tf/d4pg/agent.py @@ -16,24 +16,227 @@ """D4PG agent implementation.""" import copy +from typing import Iterator, List, Optional, Tuple +from acme import adders +from acme import core from acme import datasets from acme import specs from acme import types -from acme.adders import reverb as adders +from acme.adders import reverb as reverb_adders from acme.agents import agent +from acme.agents import builders from acme.agents.tf import actors from acme.agents.tf.d4pg import learning -from acme.tf import networks -from acme.tf import utils as tf2_utils +from acme.tf import networks as network_utils +from acme.tf import utils +from acme.tf import variable_utils from acme.utils import counting from acme.utils import loggers + +import dataclasses import reverb import sonnet as snt import tensorflow as tf -# TODO(b/145531941): make the naming of this agent consistent. +@dataclasses.dataclass +class D4PGConfig: + """Configuration options for the D4PG agent.""" + + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_update_period: int = 100 + policy_optimizer: Optional[snt.Optimizer] = None + critic_optimizer: Optional[snt.Optimizer] = None + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + n_step: int = 5 + sigma: float = 0.3 + clipping: bool = True + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + + +@dataclasses.dataclass +class D4PGNetworks: + """Structure containing the networks for D4PG.""" + + policy_network: snt.Module + critic_network: snt.Module + observation_network: snt.Module + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation, + ): + # This method is implemented (rather than added by the dataclass decorator) + # in order to allow observation network to be passed as an arbitrary tensor + # transformation rather than as a snt Module. + # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. + self.policy_network = policy_network + self.critic_network = critic_network + self.observation_network = utils.to_sonnet_module(observation_network) + + def init(self, environment_spec: specs.EnvironmentSpec): + """Initialize the networks given an environment spec.""" + # Get observation and action specs. 
+ act_spec = environment_spec.actions + obs_spec = environment_spec.observations + + # Create variables for the observation net and, as a side-effect, get a + # spec describing the embedding space. + emb_spec = utils.create_variables(self.observation_network, [obs_spec]) + + # Create variables for the policy and critic nets. + _ = utils.create_variables(self.policy_network, [emb_spec]) + _ = utils.create_variables(self.critic_network, [emb_spec, act_spec]) + + def make_policy( + self, + environment_spec: specs.EnvironmentSpec, + sigma: float = 0.0, + ) -> snt.Module: + """Create a single network which evaluates the policy.""" + # Stack the observation and policy networks. + stack = [ + self.observation_network, + self.policy_network, + ] + + # If a stochastic/non-greedy policy is requested, add Gaussian noise on + # top to enable a simple form of exploration. + # TODO(mwhoffman): Refactor this to remove it from the class. + if sigma > 0.0: + stack += [ + network_utils.ClippedGaussian(sigma), + network_utils.ClipToSpec(environment_spec.actions), + ] + + # Return a network which sequentially evaluates everything in the stack. + return snt.Sequential(stack) + + +class D4PGBuilder(builders.ActorLearnerBuilder): + """Builder for D4PG which constructs individual components of the agent.""" + + def __init__(self, config: D4PGConfig): + self._config = config + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + ) -> List[reverb.Table]: + if self._config.samples_per_insert is None: + # We will take a samples_per_insert ratio of None to mean that there is + # no limit, i.e. this only implies a min size limit. + limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + + else: + # Create enough of an error buffer to give a 10% tolerance in rate. + samples_per_insert_tolerance = 0.1 * self._config.samples_per_insert + error_buffer = self._config.min_replay_size * samples_per_insert_tolerance + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + + replay_table = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=reverb_adders.NStepTransitionAdder.signature( + environment_spec)) + + return [replay_table] + + def make_dataset_iterator( + self, + reverb_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=reverb_client.server_address, + batch_size=self._config.batch_size, + prefetch_size=self._config.prefetch_size) + + # TODO(b/155086959): Fix type stubs and remove. + return iter(dataset) # pytype: disable=wrong-arg-types + + def make_adder( + self, + replay_client: reverb.Client, + ) -> adders.Adder: + return reverb_adders.NStepTransitionAdder( + priority_fns={self._config.replay_table_name: lambda x: 1.}, + client=replay_client, + n_step=self._config.n_step, + discount=self._config.discount) + + def make_actor( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_source: Optional[core.VariableSource] = None, + ): + if variable_source: + # Create the variable client responsible for keeping the actor up-to-date. 
+ variable_client = variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.variables}, + update_period=1000, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + else: + variable_client = None + + # Create the actor which defines how we take actions. + return actors.FeedForwardActor( + policy_network=policy_network, + adder=adder, + variable_client=variable_client, + ) + + def make_learner( + self, + networks: Tuple[D4PGNetworks, D4PGNetworks], + dataset: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + online_networks, target_networks = networks + + # The learner updates the parameters (and initializes them). + return learning.D4PGLearner( + policy_network=online_networks.policy_network, + critic_network=online_networks.critic_network, + observation_network=online_networks.observation_network, + target_policy_network=target_networks.policy_network, + target_critic_network=target_networks.critic_network, + target_observation_network=target_networks.observation_network, + policy_optimizer=self._config.policy_optimizer, + critic_optimizer=self._config.critic_optimizer, + clipping=self._config.clipping, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + dataset_iterator=dataset, + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) + + class D4PG(agent.Agent): """D4PG Agent. @@ -43,27 +246,29 @@ class D4PG(agent.Agent): behavior) by sampling uniformly from this buffer. """ - def __init__(self, - environment_spec: specs.EnvironmentSpec, - policy_network: snt.Module, - critic_network: snt.Module, - observation_network: types.TensorTransformation = tf.identity, - discount: float = 0.99, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - policy_optimizer: snt.Optimizer = None, - critic_optimizer: snt.Optimizer = None, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: float = 32.0, - n_step: int = 5, - sigma: float = 0.3, - clipping: bool = True, - logger: loggers.Logger = None, - counter: counting.Counter = None, - checkpoint: bool = True, - replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + observation_network: types.TensorTransformation = tf.identity, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + policy_optimizer: snt.Optimizer = None, + critic_optimizer: snt.Optimizer = None, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + n_step: int = 5, + sigma: float = 0.3, + clipping: bool = True, + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, + counter: counting.Counter = None, + logger: loggers.Logger = None, + checkpoint: bool = True, + ): """Initialize the agent. Args: @@ -86,95 +291,72 @@ def __init__(self, n_step: number of steps to squash into a single transition. sigma: standard deviation of zero-mean, Gaussian exploration noise. clipping: whether to clip gradients by global norm. - logger: logger object to be used by learner. + replay_table_name: string indicating what name to give the replay table. 
counter: counter object used to keep track of steps. + logger: logger object to be used by learner. checkpoint: boolean indicating whether to checkpoint the learner. - replay_table_name: string indicating what name to give the replay table. """ - # Create a replay server to add data to. This uses no limiter behavior in - # order to allow the Agent interface to handle it. - replay_table = reverb.Table( - name=replay_table_name, - sampler=reverb.selectors.Uniform(), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(1), - signature=adders.NStepTransitionAdder.signature(environment_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - priority_fns={replay_table_name: lambda x: 1.}, - client=reverb.Client(address), - n_step=n_step, - discount=discount) + # Create the Builder object which will internally create agent components. + builder = D4PGBuilder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + # Right now this modifies min_replay_size and samples_per_insert so that + # they are not controlled by a limiter and are instead handled by the + # Agent base class (the above TODO directly references this behavior). + D4PGConfig( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + min_replay_size=1, # Let the Agent class handle this. + max_replay_size=max_replay_size, + samples_per_insert=None, # Let the Agent class handle this. + n_step=n_step, + sigma=sigma, + clipping=clipping, + replay_table_name=replay_table_name, + )) - # The dataset provides an interface to sample from replay. - dataset = datasets.make_reverb_dataset( - table=replay_table_name, - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) + # TODO(mwhoffman): pass the network dataclass in directly. + online_networks = D4PGNetworks(policy_network=policy_network, + critic_network=critic_network, + observation_network=observation_network) - # Make sure observation network is a Sonnet Module. - observation_network = tf2_utils.to_sonnet_module(observation_network) + # Target networks are just a copy of the online networks. + target_networks = copy.deepcopy(online_networks) - # Create target networks. - target_policy_network = copy.deepcopy(policy_network) - target_critic_network = copy.deepcopy(critic_network) - target_observation_network = copy.deepcopy(observation_network) + # Initialize the networks. + online_networks.init(environment_spec) + target_networks.init(environment_spec) - # Get observation and action specs. - act_spec = environment_spec.actions - obs_spec = environment_spec.observations - emb_spec = tf2_utils.create_variables(observation_network, [obs_spec]) + # TODO(mwhoffman): either make this Dataclass or pass only one struct. + # The network struct passed to make_learner is just a tuple for the + # time-being (for backwards compatibility). + networks = (online_networks, target_networks) # Create the behavior policy. - behavior_network = snt.Sequential([ - observation_network, - policy_network, - networks.ClippedGaussian(sigma), - networks.ClipToSpec(act_spec), - ]) - - # Create variables. 
- tf2_utils.create_variables(policy_network, [emb_spec]) - tf2_utils.create_variables(critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_policy_network, [emb_spec]) - tf2_utils.create_variables(target_critic_network, [emb_spec, act_spec]) - tf2_utils.create_variables(target_observation_network, [obs_spec]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor(behavior_network, adder=adder) + policy_network = online_networks.make_policy(environment_spec, sigma) - # Create optimizers. - policy_optimizer = policy_optimizer or snt.optimizers.Adam( - learning_rate=1e-4) - critic_optimizer = critic_optimizer or snt.optimizers.Adam( - learning_rate=1e-4) + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f'localhost:{replay_server.port}') - # The learner updates the parameters (and initializes them). - learner = learning.D4PGLearner( - policy_network=policy_network, - critic_network=critic_network, - observation_network=observation_network, - target_policy_network=target_policy_network, - target_critic_network=target_critic_network, - target_observation_network=target_observation_network, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - clipping=clipping, - discount=discount, - target_update_period=target_update_period, - dataset=dataset, - counter=counter, - logger=logger, - checkpoint=checkpoint, - ) + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client) + actor = builder.make_actor(policy_network, adder) + dataset = builder.make_dataset_iterator(replay_client) + learner = builder.make_learner(networks, dataset, counter, logger, + checkpoint) super().__init__( actor=actor, learner=learner, min_observations=max(batch_size, min_replay_size), observations_per_step=float(batch_size) / samples_per_insert) + + # Save the replay so we don't garbage collect it. + self._replay_server = replay_server diff --git a/acme/agents/tf/d4pg/learning.py b/acme/agents/tf/d4pg/learning.py index 2366dccdf5..ba817ad2d9 100644 --- a/acme/agents/tf/d4pg/learning.py +++ b/acme/agents/tf/d4pg/learning.py @@ -16,7 +16,7 @@ """D4PG learner implementation.""" import time -from typing import Dict, List +from typing import Dict, Iterator, List import acme from acme import types @@ -27,6 +27,7 @@ from acme.utils import counting from acme.utils import loggers import numpy as np +import reverb import sonnet as snt import tensorflow as tf import tree @@ -47,7 +48,7 @@ def __init__( target_critic_network: snt.Module, discount: float, target_update_period: int, - dataset: tf.data.Dataset, + dataset_iterator: Iterator[reverb.ReplaySample], observation_network: types.TensorTransformation = lambda x: x, target_observation_network: types.TensorTransformation = lambda x: x, policy_optimizer: snt.Optimizer = None, @@ -68,8 +69,8 @@ def __init__( discount: discount to use for TD updates. target_update_period: number of learner steps to perform before updating the target networks. - dataset: dataset to learn from, whether fixed or from a replay buffer - (see `acme.datasets.reverb.make_dataset` documentation). + dataset_iterator: dataset to learn from, whether fixed or from a replay + buffer (see `acme.datasets.reverb.make_dataset` documentation). 
observation_network: an optional online network to process observations before the policy and the critic. target_observation_network: the target observation network. @@ -106,8 +107,7 @@ def __init__( self._target_update_period = target_update_period # Batch dataset and create iterator. - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._iterator = dataset_iterator # Create optimizers if they aren't given. self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) diff --git a/acme/core.py b/acme/core.py index 8f9e27554a..637a7fb469 100644 --- a/acme/core.py +++ b/acme/core.py @@ -113,23 +113,6 @@ def get_variables(self, names: Sequence[str]) -> List[types.NestedArray]: """ -class VariableClient(abc.ABC): - """A variable client for updating variables from a remote source.""" - - @abc.abstractmethod - def update(self, wait: bool = False): - """Periodically updates the variables with the latest copy from the source. - - Args: - wait: if True, executes blocking update. - """ - - # TODO(b/178587027): Eliminate this method. - @abc.abstractmethod - def update_and_wait(self): - """Immediately update and block until we get the result.""" - - @metrics.record_class_usage class Worker(abc.ABC): """An interface for (potentially) distributed workers.""" diff --git a/acme/jax/variable_utils.py b/acme/jax/variable_utils.py index 2e7fafc66c..375e1405f9 100644 --- a/acme/jax/variable_utils.py +++ b/acme/jax/variable_utils.py @@ -24,7 +24,7 @@ import jax -class VariableClient(core.VariableClient): +class VariableClient: """A variable client for updating variables from a remote source.""" def __init__(self, @@ -77,7 +77,8 @@ def update(self, wait: bool = False) -> None: if wait: if self._future is not None: - if self._future.running(): self._future.cancel() + if self._future.running(): + self._future.cancel() self._future = None self._call_counter = 0 self.update_and_wait() diff --git a/acme/tf/variable_utils.py b/acme/tf/variable_utils.py index 53e4dd20f6..cdd9bf0e0e 100644 --- a/acme/tf/variable_utils.py +++ b/acme/tf/variable_utils.py @@ -24,7 +24,7 @@ import tree -class VariableClient(core.VariableClient): +class VariableClient: """A variable client for updating variables from a remote source.""" def __init__(self,
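A minimal usage sketch (not part of the patch) may help show how the new pieces fit together. The snippet below simply mirrors what the rewritten D4PG.__init__ now does in the single-process case, using the module path from this patch (acme.agents.tf.d4pg.agent); the names environment_spec, policy_network, critic_network and observation_network are assumed to be defined elsewhere.

import copy

import reverb

from acme.agents.tf.d4pg import agent as d4pg

# Configuration and builder; all defaults come from the D4PGConfig dataclass.
config = d4pg.D4PGConfig()
builder = d4pg.D4PGBuilder(config)

# Online networks plus a deep-copied target set, both initialized against the
# environment spec (assumed: environment_spec, policy_network, critic_network,
# observation_network).
online_networks = d4pg.D4PGNetworks(
    policy_network=policy_network,
    critic_network=critic_network,
    observation_network=observation_network)
target_networks = copy.deepcopy(online_networks)
online_networks.init(environment_spec)
target_networks.init(environment_spec)

# Replay: a Reverb server built from the tables the builder describes, plus a
# client pointing at it.
replay_tables = builder.make_replay_tables(environment_spec)
replay_server = reverb.Server(replay_tables, port=None)
replay_client = reverb.Client(f'localhost:{replay_server.port}')

# Actor side: an adder that writes transitions to replay, and a behavior
# policy with exploration noise added on top of the online policy.
adder = builder.make_adder(replay_client)
behavior_network = online_networks.make_policy(
    environment_spec, sigma=config.sigma)
actor = builder.make_actor(behavior_network, adder)

# Learner side: a dataset iterator over replay and the learner that consumes it.
dataset = builder.make_dataset_iterator(replay_client)
learner = builder.make_learner((online_networks, target_networks), dataset)

In a distributed variant the same make_* calls would run in separate processes (the replay tables on a Reverb server, make_learner on the learner node, and make_actor on each actor with variable_source pointing back at the learner), which is the decoupling the ActorLearnerBuilder interface is intended to enable.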