From 1ecabdbf14a50e04660733d9376b5679ce12fe08 Mon Sep 17 00:00:00 2001
From: Dhruva Tirumala
Date: Wed, 11 Aug 2021 10:06:40 -0700
Subject: [PATCH] SVG-0 agent with entropy and behavior prior.

The stochastic value gradients (SVG-0) agent bears similarity to DPG but uses
the reparameterization trick to learn stochastic policies. This version uses a
continuous-action variant of RETRACE to learn the value function across
multiple steps and adds an entropy bonus to the actor loss. Additionally, the
code supports learning behavior priors, with an example of running on the
locomotion go-to-target task.

PiperOrigin-RevId: 390159206
Change-Id: Ide9d8fd3dca9d6513ac571becd748a701e3c0173
---
 acme/agents/tf/svg0_prior/README.md           |  26 ++
 acme/agents/tf/svg0_prior/__init__.py         |  21 +
 acme/agents/tf/svg0_prior/acting.py           |  67 +++
 acme/agents/tf/svg0_prior/agent.py            | 371 +++++++++++++++++
 .../agents/tf/svg0_prior/agent_distributed.py | 252 ++++++++++++
 .../tf/svg0_prior/agent_distributed_test.py   |  95 +++++
 acme/agents/tf/svg0_prior/agent_test.py       |  97 +++++
 acme/agents/tf/svg0_prior/learning.py         | 387 ++++++++++++++++++
 acme/agents/tf/svg0_prior/networks.py         | 119 ++++++
 acme/agents/tf/svg0_prior/utils.py            | 157 +++++++
 examples/control/lp_local_svg0.py             |  58 +++
 11 files changed, 1650 insertions(+)
 create mode 100644 acme/agents/tf/svg0_prior/README.md
 create mode 100644 acme/agents/tf/svg0_prior/__init__.py
 create mode 100644 acme/agents/tf/svg0_prior/acting.py
 create mode 100644 acme/agents/tf/svg0_prior/agent.py
 create mode 100644 acme/agents/tf/svg0_prior/agent_distributed.py
 create mode 100644 acme/agents/tf/svg0_prior/agent_distributed_test.py
 create mode 100644 acme/agents/tf/svg0_prior/agent_test.py
 create mode 100644 acme/agents/tf/svg0_prior/learning.py
 create mode 100644 acme/agents/tf/svg0_prior/networks.py
 create mode 100644 acme/agents/tf/svg0_prior/utils.py
 create mode 100644 examples/control/lp_local_svg0.py

diff --git a/acme/agents/tf/svg0_prior/README.md b/acme/agents/tf/svg0_prior/README.md
new file mode 100644
index 0000000000..1464932285
--- /dev/null
+++ b/acme/agents/tf/svg0_prior/README.md
@@ -0,0 +1,26 @@
+# Stochastic Value Gradients (SVG) with Behavior Prior.
+
+This folder contains a version of the SVG-0 agent introduced in
+([Heess et al., 2015]) that has been extended with an entropy bonus, RETRACE
+([Munos et al., 2016]) for off-policy correction, and code to learn behavior
+priors ([Tirumala et al., 2020], [Galashov et al., 2019]).
+
+The base SVG-0 algorithm is similar to DPG and DDPG
+([Silver et al., 2014], [Lillicrap et al., 2015]) but uses the
+reparameterization trick to learn stochastic rather than deterministic policies.
+In addition, the RETRACE algorithm is used to learn value functions from
+multiple timesteps of data, with importance sampling for off-policy correction.
+
+Additionally, an optional behavior prior can be learnt with this setup, using an
+information asymmetry that has been shown to boost performance in some domains.
+Example code to run with and without behavior priors on the DeepMind Control
+Suite and Locomotion tasks is provided in the `examples` folder.
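+
+As a minimal sketch of single-process usage (mirroring the settings in
+`agent_test.py`, and using a fake bounded-action environment purely for
+illustration):
+
+```python
+import acme
+from acme import specs
+from acme.agents.tf import svg0_prior
+from acme.testing import fakes
+
+# Any dm_env continuous-control environment with bounded actions can be used.
+environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True)
+spec = specs.make_environment_spec(environment)
+
+networks = svg0_prior.make_default_networks(spec.actions)
+agent = svg0_prior.SVG0(
+    environment_spec=spec,
+    policy_network=networks['policy'],
+    critic_network=networks['critic'],
+    batch_size=10,
+    samples_per_insert=2,
+    min_replay_size=10)
+
+acme.EnvironmentLoop(environment, agent).run(num_episodes=10)
+```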
+ + +[Heess et al., 2015]: https://arxiv.org/abs/1510.09142 +[Munos et al., 2016]: https://arxiv.org/abs/1606.02647 +[Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971 +[Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14 +[Tirumala et al., 2020]: https://arxiv.org/abs/2010.14274 +[Galashov et al., 2019]: https://arxiv.org/abs/1905.01240 + diff --git a/acme/agents/tf/svg0_prior/__init__.py b/acme/agents/tf/svg0_prior/__init__.py new file mode 100644 index 0000000000..b4218db22c --- /dev/null +++ b/acme/agents/tf/svg0_prior/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of a SVG0 agent with prior.""" + +from acme.agents.tf.svg0_prior.agent import SVG0 +from acme.agents.tf.svg0_prior.agent_distributed import DistributedSVG0 +from acme.agents.tf.svg0_prior.learning import SVG0Learner +from acme.agents.tf.svg0_prior.networks import make_default_networks +from acme.agents.tf.svg0_prior.networks import make_network_with_prior diff --git a/acme/agents/tf/svg0_prior/acting.py b/acme/agents/tf/svg0_prior/acting.py new file mode 100644 index 0000000000..f044a14e4b --- /dev/null +++ b/acme/agents/tf/svg0_prior/acting.py @@ -0,0 +1,67 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG0 actor implementation.""" + +from typing import Optional + +from acme import adders +from acme import types + +from acme.agents.tf import actors +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + +import dm_env +import sonnet as snt + + +class SVG0Actor(actors.FeedForwardActor): + """An actor that also returns `log_prob`.""" + + def __init__( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + deterministic_policy: Optional[bool] = False, + ): + super().__init__(policy_network, adder, variable_client) + self._log_prob = None + self._deterministic_policy = deterministic_policy + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) + + # Compute the policy, conditioned on the observation. 
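+    # The policy head returns a tfd.Distribution: in evaluation mode we act
+    # greedily through its mean, otherwise we sample and cache the
+    # log-probability so it can be written to replay and used later for the
+    # off-policy (RETRACE) correction in the learner.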
+ policy = self._policy_network(batched_observation) + if self._deterministic_policy: + action = policy.mean() + else: + action = policy.sample() + self._log_prob = policy.log_prob(action) + return tf2_utils.to_numpy_squeeze(action) + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + if not self._adder: + return + + extras = {'log_prob': self._log_prob} + extras = tf2_utils.to_numpy_squeeze(extras) + self._adder.add(action, next_timestep, extras) diff --git a/acme/agents/tf/svg0_prior/agent.py b/acme/agents/tf/svg0_prior/agent.py new file mode 100644 index 0000000000..99b2b92baf --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent.py @@ -0,0 +1,371 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG0 agent implementation.""" + +import copy +import dataclasses +from typing import Iterator, List, Optional, Tuple + +from acme import adders +from acme import core +from acme import datasets +from acme import specs +from acme.adders import reverb as reverb_adders +from acme.agents import agent +from acme.agents.tf.svg0_prior import acting +from acme.agents.tf.svg0_prior import learning +from acme.tf import utils +from acme.tf import variable_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + + +@dataclasses.dataclass +class SVG0Config: + """Configuration options for the agent.""" + + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_update_period: int = 100 + policy_optimizer: Optional[snt.Optimizer] = None + critic_optimizer: Optional[snt.Optimizer] = None + prior_optimizer: Optional[snt.Optimizer] = None + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + sequence_length: int = 10 + sigma: float = 0.3 + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + distillation_cost: Optional[float] = 1e-3 + entropy_regularizer_cost: Optional[float] = 1e-3 + + +@dataclasses.dataclass +class SVG0Networks: + """Structure containing the networks for SVG0.""" + + policy_network: snt.Module + critic_network: snt.Module + prior_network: Optional[snt.Module] + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + prior_network: Optional[snt.Module] = None + ): + # This method is implemented (rather than added by the dataclass decorator) + # in order to allow observation network to be passed as an arbitrary tensor + # transformation rather than as a snt Module. + # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. + self.policy_network = policy_network + self.critic_network = critic_network + self.prior_network = prior_network + + def init(self, environment_spec: specs.EnvironmentSpec): + """Initialize the networks given an environment spec.""" + # Get observation and action specs. 
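+    # `create_variables` passes dummy inputs shaped like these specs through
+    # each network so that the Sonnet variables are built eagerly; without this
+    # the variables would not exist when the learner and variable client first
+    # read them.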
+ act_spec = environment_spec.actions + obs_spec = environment_spec.observations + + # Create variables for the policy and critic nets. + _ = utils.create_variables(self.policy_network, [obs_spec]) + _ = utils.create_variables(self.critic_network, [obs_spec, act_spec]) + if self.prior_network is not None: + _ = utils.create_variables(self.prior_network, [obs_spec]) + + def make_policy( + self, + ) -> snt.Module: + """Create a single network which evaluates the policy.""" + return self.policy_network + + def make_prior( + self, + ) -> snt.Module: + """Create a single network which evaluates the prior.""" + behavior_prior = self.prior_network + return behavior_prior + + +class SVG0Builder: + """Builder for SVG0 which constructs individual components of the agent.""" + + def __init__(self, config: SVG0Config): + self._config = config + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + sequence_length: int, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + if self._config.samples_per_insert is None: + # We will take a samples_per_insert ratio of None to mean that there is + # no limit, i.e. this only implies a min size limit. + limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + + else: + error_buffer = max(1, self._config.samples_per_insert) + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + + extras_spec = { + 'log_prob': tf.ones( + shape=(), dtype=tf.float32) + } + replay_table = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=reverb_adders.SequenceAdder.signature( + environment_spec, + extras_spec=extras_spec, + sequence_length=sequence_length + 1)) + + return [replay_table] + + def make_dataset_iterator( + self, + reverb_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=reverb_client.server_address, + batch_size=self._config.batch_size, + prefetch_size=self._config.prefetch_size) + + # TODO(b/155086959): Fix type stubs and remove. + return iter(dataset) # pytype: disable=wrong-arg-types + + def make_adder( + self, + replay_client: reverb.Client, + ) -> adders.Adder: + """Create an adder which records data generated by the actor/environment.""" + return reverb_adders.SequenceAdder( + client=replay_client, + sequence_length=self._config.sequence_length+1, + priority_fns={self._config.replay_table_name: lambda x: 1.}, + period=self._config.sequence_length, + end_of_episode_behavior=reverb_adders.EndBehavior.CONTINUE, + ) + + def make_actor( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_source: Optional[core.VariableSource] = None, + deterministic_policy: Optional[bool] = False, + ): + """Create an actor instance.""" + if variable_source: + # Create the variable client responsible for keeping the actor up-to-date. 
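+      # The variable client fetches the latest policy weights from the
+      # variable source (the learner in the distributed setting) and copies
+      # them into the actor's policy network every `update_period` requests.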
+ variable_client = variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.variables}, + update_period=1000, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + else: + variable_client = None + + # Create the actor which defines how we take actions. + return acting.SVG0Actor( + policy_network=policy_network, + adder=adder, + variable_client=variable_client, + deterministic_policy=deterministic_policy + ) + + def make_learner( + self, + networks: Tuple[SVG0Networks, SVG0Networks], + dataset: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Creates an instance of the learner.""" + online_networks, target_networks = networks + + # The learner updates the parameters (and initializes them). + return learning.SVG0Learner( + policy_network=online_networks.policy_network, + critic_network=online_networks.critic_network, + target_policy_network=target_networks.policy_network, + target_critic_network=target_networks.critic_network, + prior_network=online_networks.prior_network, + target_prior_network=target_networks.prior_network, + policy_optimizer=self._config.policy_optimizer, + critic_optimizer=self._config.critic_optimizer, + prior_optimizer=self._config.prior_optimizer, + distillation_cost=self._config.distillation_cost, + entropy_regularizer_cost=self._config.entropy_regularizer_cost, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + dataset_iterator=dataset, + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) + + +class SVG0(agent.Agent): + """SVG0 Agent with prior. + + This implements a single-process SVG0 agent. This is an actor-critic algorithm + that generates data via a behavior policy, inserts N-step transitions into + a replay buffer, and periodically updates the policy (and as a result the + behavior) by sampling uniformly from this buffer. + """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + prior_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + sequence_length: int = 10, + sigma: float = 0.3, + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + policy_network: the online (optimized) policy. + critic_network: the online critic. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_update_period: number of learner steps to perform before updating + the target networks. + prior_network: an optional `behavior prior` to regularize against. + policy_optimizer: optimizer for the policy network updates. 
+ critic_optimizer: optimizer for the critic network updates. + prior_optimizer: optimizer for the prior network updates. + distillation_cost: a multiplier to be used when adding distillation + against the prior to the losses. + entropy_regularizer_cost: a multiplier used for per state sample based + entropy added to the actor loss. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert + that is made. + sequence_length: number of timesteps to store for each trajectory. + sigma: standard deviation of zero-mean, Gaussian exploration noise. + replay_table_name: string indicating what name to give the replay table. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + # Create the Builder object which will internally create agent components. + builder = SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + # Right now this modifies min_replay_size and samples_per_insert so that + # they are not controlled by a limiter and are instead handled by the + # Agent base class (the above TODO directly references this behavior). + SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + min_replay_size=1, # Let the Agent class handle this. + max_replay_size=max_replay_size, + samples_per_insert=None, # Let the Agent class handle this. + sequence_length=sequence_length, + sigma=sigma, + replay_table_name=replay_table_name, + )) + + # TODO(mwhoffman): pass the network dataclass in directly. + online_networks = SVG0Networks(policy_network=policy_network, + critic_network=critic_network, + prior_network=prior_network,) + + # Target networks are just a copy of the online networks. + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(environment_spec) + target_networks.init(environment_spec) + + # TODO(mwhoffman): either make this Dataclass or pass only one struct. + # The network struct passed to make_learner is just a tuple for the + # time-being (for backwards compatibility). + networks = (online_networks, target_networks) + + # Create the behavior policy. + policy_network = online_networks.make_policy() + + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec, + sequence_length) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f'localhost:{replay_server.port}') + + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client) + actor = builder.make_actor(policy_network, adder) + dataset = builder.make_dataset_iterator(replay_client) + learner = builder.make_learner(networks, dataset, counter, logger, + checkpoint) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) + + # Save the replay so we don't garbage collect it. 
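+    # The in-process Reverb server lives exactly as long as this agent object;
+    # the actor and learner created above both talk to it via the local client.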
+ self._replay_server = replay_server diff --git a/acme/agents/tf/svg0_prior/agent_distributed.py b/acme/agents/tf/svg0_prior/agent_distributed.py new file mode 100644 index 0000000000..8ac546a7dc --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent_distributed.py @@ -0,0 +1,252 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the SVG0 agent class.""" + +import copy +from typing import Callable, Dict, Optional + +import acme +from acme import specs +from acme.agents.tf.svg0_prior import agent +from acme.tf import savers as tf2_savers +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import reverb +import sonnet as snt + + +class DistributedSVG0: + """Program definition for SVG0.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + sequence_length: int = 10, + sigma: float = 0.3, + discount: float = 0.99, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + target_update_period: int = 100, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + # TODO(mwhoffman): Make network_factory directly return the struct. + # TODO(mwhoffman): Make the factory take the entire spec. + def wrapped_network_factory(action_spec): + networks_dict = network_factory(action_spec) + networks = agent.SVG0Networks( + policy_network=networks_dict.get('policy'), + critic_network=networks_dict.get('critic'), + prior_network=networks_dict.get('prior', None),) + return networks + + self._environment_factory = environment_factory + self._network_factory = wrapped_network_factory + self._environment_spec = environment_spec + self._sigma = sigma + self._num_actors = num_actors + self._num_caches = num_caches + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._sequence_length = sequence_length + + self._builder = agent.SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. 
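+        # Unlike the single-process agent, min_replay_size and
+        # samples_per_insert are passed through unchanged here, so throttling
+        # is handled by the Reverb rate limiter built in `make_replay_tables`.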
+ agent.SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + sequence_length=sequence_length, + sigma=sigma, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + )) + + def replay(self): + """The replay storage.""" + return self._builder.make_replay_tables(self._environment_spec, + self._sequence_length) + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter): + return lp_utils.StepsLimiter(counter, self._max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # Create the networks to optimize (online) and target networks. + online_networks = self._network_factory(self._environment_spec.actions) + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(self._environment_spec) + target_networks.init(self._environment_spec) + + dataset = self._builder.make_dataset_iterator(replay) + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + return self._builder.make_learner( + networks=(online_networks, target_networks), + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + ) + + # Create the environment. + environment = self._environment_factory(False) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + logger: Optional[loggers.Logger] = None, + ): + """The evaluation process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + variable_source=variable_source, + deterministic_policy=True, + ) + + # Make the environment. + environment = self._environment_factory(True) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = logger or loggers.make_default_logger( + 'evaluator', + time_delta=self._log_every, + steps_key='evaluator_steps', + ) + + # Create the run loop and return it. 
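+    # Unlike the actors, the evaluator has no adder (so it never writes to
+    # replay) and acts deterministically through the policy mean.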
+ return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name='svg0'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + with program.group('coordinator'): + _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. + sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/svg0_prior/agent_distributed_test.py b/acme/agents/tf/svg0_prior/agent_distributed_test.py new file mode 100644 index 0000000000..cdbc2599e4 --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent_distributed_test.py @@ -0,0 +1,95 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for the distributed agent.""" + +from typing import Sequence + +from absl.testing import absltest +import acme +from acme import specs +from acme.agents.tf import svg0_prior +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils +import launchpad as lp +import numpy as np +import sonnet as snt + + +def make_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (10, 10), + critic_layer_sizes: Sequence[int] = (10, 10), +): + """Simple networks for testing..""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
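+  # These networks mirror `make_default_networks` but with much smaller layers
+  # so the integration test stays fast; the critic again consumes concatenated
+  # observation/action pairs.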
+ multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_control_suite(self): + """Tests that the agent can run on the control suite without crashing.""" + + agent = svg0_prior.DistributedSVG0( + environment_factory=lambda x: fakes.ContinuousEnvironment(), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/agents/tf/svg0_prior/agent_test.py b/acme/agents/tf/svg0_prior/agent_test.py new file mode 100644 index 0000000000..55670f1768 --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent_test.py @@ -0,0 +1,97 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the SVG agent.""" + +import sys +from typing import Dict, Sequence + +from absl.testing import absltest +import acme +from acme import specs +from acme import types +from acme.agents.tf import svg0_prior +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils + +import numpy as np +import sonnet as snt + + +def make_networks( + action_spec: types.NestedSpec, + policy_layer_sizes: Sequence[int] = (10, 10), + critic_layer_sizes: Sequence[int] = (10, 10), +) -> Dict[str, snt.Module]: + """Creates networks used by the agent.""" + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class SVG0Test(absltest.TestCase): + + def test_svg0(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create the networks. 
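+    # Only policy and critic networks are required; the behavior prior is
+    # optional and omitted in this test.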
+ agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = svg0_prior.SVG0( + environment_spec=spec, + policy_network=agent_networks['policy'], + critic_network=agent_networks['critic'], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + # Make sure Acme doesn't import excessive number of modules. + self.assertLess(len(sys.modules), 4500) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/agents/tf/svg0_prior/learning.py b/acme/agents/tf/svg0_prior/learning.py new file mode 100644 index 0000000000..2baa1feac9 --- /dev/null +++ b/acme/agents/tf/svg0_prior/learning.py @@ -0,0 +1,387 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG learner implementation.""" + +import time +from typing import Dict, Iterator, List, Optional + +import acme +from acme.agents.tf.svg0_prior import utils as svg0_utils +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +from trfl import continuous_retrace_ops + +_MIN_LOG_VAL = 1e-20 + + +class SVG0Learner(acme.Learner): + """SVG0 learner with optional prior. + + This is the learning component of an SVG0 agent. IE it takes a dataset as + input and implements update functionality to learn from this dataset. + """ + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + target_update_period: int, + dataset_iterator: Iterator[reverb.ReplaySample], + prior_network: Optional[snt.Module] = None, + target_prior_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + num_action_samples: int = 10, + lambda_: float = 1.0, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initializes the learner. + + Args: + policy_network: the online (optimized) policy. + critic_network: the online critic. + target_policy_network: the target policy (which lags behind the online + policy). + target_critic_network: the target critic. + discount: discount to use for TD updates. + target_update_period: number of learner steps to perform before updating + the target networks. + dataset_iterator: dataset to learn from, whether fixed or from a replay + buffer (see `acme.datasets.reverb.make_dataset` documentation). 
+ prior_network: the online (optimized) prior. + target_prior_network: the target prior (which lags behind the online + prior). + policy_optimizer: the optimizer to be applied to the SVG-0 (policy) loss. + critic_optimizer: the optimizer to be applied to the distributional + Bellman loss. + prior_optimizer: the optimizer to be applied to the prior (distillation) + loss. + distillation_cost: a multiplier to be used when adding distillation + against the prior to the losses. + entropy_regularizer_cost: a multiplier used for per state sample based + entropy added to the actor loss. + num_action_samples: the number of action samples to use for estimating the + value function and sample based entropy. + lambda_: the `lambda` value to be used with retrace. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + self._prior_network = prior_network + self._target_prior_network = target_prior_network + + self._lambda = lambda_ + self._num_action_samples = num_action_samples + self._distillation_cost = distillation_cost + self._entropy_regularizer_cost = entropy_regularizer_cost + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner') + + # Other learner parameters. + self._discount = discount + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_update_period = target_update_period + + # Batch dataset and create iterator. + self._iterator = dataset_iterator + + # Create optimizers if they aren't given. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._prior_optimizer = prior_optimizer or snt.optimizers.Adam(1e-4) + + # Expose the variables. + self._variables = { + 'critic': self._critic_network.variables, + 'policy': self._policy_network.variables, + } + if self._prior_network is not None: + self._variables['prior'] = self._prior_network.variables + + # Create a checkpointer and snapshotter objects. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + objects_to_save = { + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'num_steps': self._num_steps, + } + if self._prior_network is not None: + objects_to_save['prior'] = self._prior_network + objects_to_save['target_prior'] = self._target_prior_network + objects_to_save['prior_optimizer'] = self._prior_optimizer + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='svg0_learner', + objects_to_save=objects_to_save) + objects_to_snapshot = { + 'policy': self._policy_network, + 'critic': self._critic_network, + } + if self._prior_network is not None: + objects_to_snapshot['prior'] = self._prior_network + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save=objects_to_snapshot) + + # Do not record timestamps until after the first learning step is done. 
+ # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + # Update target network + online_variables = [ + *self._critic_network.variables, + *self._policy_network.variables, + ] + if self._prior_network is not None: + online_variables += [*self._prior_network.variables] + online_variables = tuple(online_variables) + + target_variables = [ + *self._target_critic_network.variables, + *self._target_policy_network.variables, + ] + if self._prior_network is not None: + target_variables += [*self._target_prior_network.variables] + target_variables = tuple(target_variables) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any) and flip to `[T, B, ...]`. + sample: reverb.ReplaySample = next(self._iterator) + data = tf2_utils.batch_to_sequence(sample.data) + observations, actions, rewards, discounts, extra = (data.observation, + data.action, + data.reward, + data.discount, + data.extras) + online_target_pi_q = svg0_utils.OnlineTargetPiQ( + online_pi=self._policy_network, + online_q=self._critic_network, + target_pi=self._target_policy_network, + target_q=self._target_critic_network, + num_samples=self._num_action_samples, + online_prior=self._prior_network, + target_prior=self._target_prior_network, + ) + with tf.GradientTape(persistent=True) as tape: + step_outputs = svg0_utils.static_rnn( + core=online_target_pi_q, + inputs=(observations, actions), + unroll_length=rewards.shape[0]) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + target_pi_samples = tf2_utils.batch_to_sequence( + step_outputs.target_samples) + # Tile observations to have shape [S, T+1, B,..]. + tiled_observations = tf2_utils.tile_nested(observations, + self._num_action_samples) + + # Finally compute target Q values on the new action samples. + # Shape: [S, T+1, B, 1] + target_q_target_pi_samples = snt.BatchApply(self._target_critic_network, + 3)(tiled_observations, + target_pi_samples) + # Compute the value estimate by averaging over the action dimension. + # Shape: [T+1, B, 1]. + target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples, axis=0) + + # Split the target V's into the target for learning + # `value_function_target` and the bootstrap value. Shape: [T, B]. + value_function_target = tf.squeeze(target_v_target_pi[:-1], axis=-1) + # Shape: [B]. + bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1) + + # When learning with a prior, add entropy terms to value targets. + if self._prior_network is not None: + value_function_target -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[:-1] + ) + bootstrap_value -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[-1]) + + # Get target log probs and behavior log probs from rollout. + # Shape: [T+1, B]. + target_log_probs_behavior_actions = ( + step_outputs.target_log_probs_behavior_actions) + behavior_log_probs = extra['log_prob'] + # Calculate importance weights. Shape: [T+1, B]. + rhos = tf.exp(target_log_probs_behavior_actions - behavior_log_probs) + + # Filter the importance weights to mask out episode restarts. 
Ignore the + # last action and consider the step type of the next step for masking. + # Shape: [T, B]. + episode_start_mask = tf2_utils.batch_to_sequence( + sample.data.start_of_episode)[1:] + + rhos = svg0_utils.mask_out_restarting(rhos[:-1], episode_start_mask) + + # rhos = rhos[:-1] + # Compute the log importance weights with a small value added for + # stability. + # Shape: [T, B] + log_rhos = tf.math.log(rhos + _MIN_LOG_VAL) + + # Retrieve the target and online Q values and throw away the last action. + # Shape: [T, B]. + target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1) + online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + online_pi_samples = tf2_utils.batch_to_sequence( + step_outputs.online_samples) + target_q_online_pi_samples = snt.BatchApply(self._target_critic_network, + 3)(tiled_observations, + online_pi_samples) + expected_q = tf.reduce_mean( + tf.squeeze(target_q_online_pi_samples, -1), axis=0) + + # Flip online_log_probs to be of shape [S, T+1, B] and then compute + # entropy by averaging over num samples. Final shape: [T+1, B]. + online_log_probs = tf2_utils.batch_to_sequence( + step_outputs.online_log_probs) + sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0) + retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights( + log_rhos=log_rhos, + discounts=self._discount * discounts[:-1], + rewards=rewards[:-1], + q_values=target_q_values, + values=value_function_target, + bootstrap_value=bootstrap_value, + lambda_=self._lambda, + ) + + # Critic loss. Shape: [T, B]. + critic_loss = 0.5 * tf.math.squared_difference( + tf.stop_gradient(retrace_outputs.qs), online_q_values) + + # Policy loss- SVG0 with sample based entropy. Shape: [T, B] + policy_loss = -( + expected_q + self._entropy_regularizer_cost * sample_based_entropy) + policy_loss = policy_loss[:-1] + + if self._prior_network is not None: + # When training the prior, also add the per-timestep KL cost. + policy_loss += ( + self._distillation_cost * step_outputs.analytic_kl_to_target[:-1]) + + # Ensure episode restarts are masked out when computing the losses. + critic_loss = svg0_utils.mask_out_restarting(critic_loss, + episode_start_mask) + critic_loss = tf.reduce_mean(critic_loss) + + policy_loss = svg0_utils.mask_out_restarting(policy_loss, + episode_start_mask) + policy_loss = tf.reduce_mean(policy_loss) + + if self._prior_network is not None: + prior_loss = step_outputs.analytic_kl_divergence[:-1] + prior_loss = svg0_utils.mask_out_restarting(prior_loss, + episode_start_mask) + prior_loss = tf.reduce_mean(prior_loss) + + # Get trainable variables. + policy_variables = self._policy_network.trainable_variables + critic_variables = self._critic_network.trainable_variables + + # Compute gradients. + policy_gradients = tape.gradient(policy_loss, policy_variables) + critic_gradients = tape.gradient(critic_loss, critic_variables) + if self._prior_network is not None: + prior_variables = self._prior_network.trainable_variables + prior_gradients = tape.gradient(prior_loss, prior_variables) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Apply gradients. 
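+    # Each optimizer only updates its own network's variables; the policy
+    # gradient flows into the policy through the reparameterised action
+    # samples that were scored by the target critic above.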
+ self._policy_optimizer.apply(policy_gradients, policy_variables) + self._critic_optimizer.apply(critic_gradients, critic_variables) + losses = { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + + if self._prior_network is not None: + self._prior_optimizer.apply(prior_gradients, prior_variables) + losses['prior_loss'] = prior_loss + + # Losses to track. + return losses + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/agents/tf/svg0_prior/networks.py b/acme/agents/tf/svg0_prior/networks.py new file mode 100644 index 0000000000..f7911a804d --- /dev/null +++ b/acme/agents/tf/svg0_prior/networks.py @@ -0,0 +1,119 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for different experiment flavours.""" + +import functools +from typing import Mapping, Sequence, Optional + +from acme import specs +from acme import types +from acme.agents.tf.svg0_prior import utils as svg0_utils +from acme.tf import networks +from acme.tf import utils as tf2_utils + +import numpy as np +import sonnet as snt + + +def make_default_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (256, 256, 256), + critic_layer_sizes: Sequence[int] = (512, 512, 256), +) -> Mapping[str, types.TensorTransformation]: + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
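+  # `ClipToSpec` bounds the sampled actions to the valid action range before
+  # they are concatenated with the observations and fed to the critic.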
+ multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + "policy": policy_network, + "critic": critic_network, + } + + +def make_network_with_prior( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (200, 100), + critic_layer_sizes: Sequence[int] = (400, 300), + prior_layer_sizes: Sequence[int] = (200, 100), + policy_keys: Optional[Sequence[str]] = None, + prior_keys: Optional[Sequence[str]] = None, +) -> Mapping[str, types.TensorTransformation]: + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + flatten_concat_policy = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=policy_keys) + flatten_concat_prior = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=prior_keys) + + policy_network = snt.Sequential([ + flatten_concat_policy, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer( + observation_network=flatten_concat_policy, + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + prior_network = snt.Sequential([ + flatten_concat_prior, + networks.LayerNormMLP(prior_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + return { + "policy": policy_network, + "critic": critic_network, + "prior": prior_network, + } diff --git a/acme/agents/tf/svg0_prior/utils.py b/acme/agents/tf/svg0_prior/utils.py new file mode 100644 index 0000000000..8474fea6cc --- /dev/null +++ b/acme/agents/tf/svg0_prior/utils.py @@ -0,0 +1,157 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for SVG0 algorithm with priors.""" + +import collections +from typing import Tuple, Optional, Dict, Iterable + +from acme import types +from acme.tf import utils as tf2_utils + +import sonnet as snt +import tensorflow as tf +import tree + + +class OnlineTargetPiQ(snt.Module): + """Core to unroll online and target policies and Q functions at once. + + A core that runs online and target policies and Q functions. This can be more + efficient if the core needs to be unrolled across time and called many times. 
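+  Bundling all of the networks into one module lets `static_rnn` below unroll
+  them jointly with a single Python loop over the time dimension.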
+ """ + + def __init__(self, + online_pi: snt.Module, + online_q: snt.Module, + target_pi: snt.Module, + target_q: snt.Module, + num_samples: int, + online_prior: Optional[snt.Module] = None, + target_prior: Optional[snt.Module] = None, + name='OnlineTargetPiQ'): + super().__init__(name) + + self._online_pi = online_pi + self._target_pi = target_pi + self._online_q = online_q + self._target_q = target_q + self._online_prior = online_prior + self._target_prior = target_prior + + self._num_samples = num_samples + output_list = [ + 'online_samples', 'target_samples', 'target_log_probs_behavior_actions', + 'online_log_probs', 'online_q', 'target_q' + ] + if online_prior is not None: + output_list += ['analytic_kl_divergence', 'analytic_kl_to_target'] + self._output_tuple = collections.namedtuple( + 'OnlineTargetPiQ', output_list) + + def __call__(self, input_obs_and_action: Tuple[tf.Tensor, tf.Tensor]): + (obs, action) = input_obs_and_action + online_pi_dist = self._online_pi(obs) + target_pi_dist = self._target_pi(obs) + + online_samples = online_pi_dist.sample(self._num_samples) + target_samples = target_pi_dist.sample(self._num_samples) + target_log_probs_behavior_actions = target_pi_dist.log_prob(action) + + online_log_probs = online_pi_dist.log_prob(tf.stop_gradient(online_samples)) + + online_q_out = self._online_q(obs, action) + target_q_out = self._target_q(obs, action) + + output_list = [ + online_samples, target_samples, target_log_probs_behavior_actions, + online_log_probs, online_q_out, target_q_out + ] + + if self._online_prior is not None: + prior_dist = self._online_prior(obs) + target_prior_dist = self._target_prior(obs) + analytic_kl_divergence = online_pi_dist.kl_divergence(prior_dist) + analytic_kl_to_target = online_pi_dist.kl_divergence(target_prior_dist) + + output_list += [analytic_kl_divergence, analytic_kl_to_target] + output = self._output_tuple(*output_list) + return output + + +def static_rnn(core: snt.Module, inputs: types.NestedTensor, + unroll_length: int): + """Unroll core along inputs for unroll_length steps. + + Note: for time-major input tensors whose leading dimension is less than + unroll_length, `None` would be provided instead. + + Args: + core: an instance of snt.Module. + inputs: a `nest` of time-major input tensors. + unroll_length: number of time steps to unroll. + + Returns: + step_outputs: a `nest` of time-major stacked output tensors of length + `unroll_length`. + """ + step_outputs = [] + for time_dim in range(unroll_length): + inputs_t = tree.map_structure( + lambda t, i_=time_dim: t[i_] if i_ < t.shape[0] else None, inputs) + step_output = core(inputs_t) + step_outputs.append(step_output) + + step_outputs = _nest_stack(step_outputs) + return step_outputs + + +def mask_out_restarting(tensor: tf.Tensor, start_of_episode: tf.Tensor): + """Mask out `tensor` taken on the step that resets the environment. + + Args: + tensor: a time-major 2-D `Tensor` of shape [T, B]. + start_of_episode: a 2-D `Tensor` of shape [T, B] that contains the points + where the episode restarts. 
+ + Returns: + tensor of shape [T, B] with elements are masked out according to step_types, + restarting weights of shape [T, B] + """ + tensor.get_shape().assert_has_rank(2) + start_of_episode.get_shape().assert_has_rank(2) + weights = tf.cast(~start_of_episode, dtype=tf.float32) + masked_tensor = tensor * weights + return masked_tensor + + +def batch_concat_selection(observation_dict: Dict[str, types.NestedTensor], + concat_keys: Optional[Iterable[str]] = None, + output_dtype=tf.float32) -> tf.Tensor: + """Concatenate a dict of observations into 2-D tensors.""" + concat_keys = concat_keys or sorted(observation_dict.keys()) + to_concat = [] + for obs in concat_keys: + if obs not in observation_dict: + raise KeyError( + 'Missing observation. Requested: {} (available: {})'.format( + obs, list(observation_dict.keys()))) + to_concat.append(tf.cast(observation_dict[obs], output_dtype)) + + return tf2_utils.batch_concat(to_concat) + + +def _nest_stack(list_of_nests, axis=0): + """Convert a list of nests to a nest of stacked lists.""" + return tree.map_structure(lambda *ts: tf.stack(ts, axis=axis), *list_of_nests) diff --git a/examples/control/lp_local_svg0.py b/examples/control/lp_local_svg0.py new file mode 100644 index 0000000000..2a03620bbf --- /dev/null +++ b/examples/control/lp_local_svg0.py @@ -0,0 +1,58 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running SVG0 on the control suite.""" + +from absl import app +from absl import flags +from acme.agents.tf import svg0_prior +import helpers + +from acme.utils import lp_utils + +import launchpad as lp + +FLAGS = flags.FLAGS +flags.DEFINE_string('domain', 'cartpole', 'Control suite domain name (str).') +flags.DEFINE_string('task', 'balance', 'Control suite task name (str).') + + +def main(_): + environment_factory = lp_utils.partial_kwargs( + helpers.make_environment, domain_name=FLAGS.domain, task_name=FLAGS.task) + + batch_size = 32 + sequence_length = 20 + gradient_steps_per_actor_step = 1.0 + samples_per_insert = ( + gradient_steps_per_actor_step * batch_size * sequence_length) + num_actors = 1 + + program = svg0_prior.DistributedSVG0( + environment_factory=environment_factory, + network_factory=lp_utils.partial_kwargs(svg0_prior.make_default_networks), + batch_size=batch_size, + sequence_length=sequence_length, + samples_per_insert=samples_per_insert, + entropy_regularizer_cost=1e-4, + max_replay_size=int(2e6), + target_update_period=250, + num_actors=num_actors).build() + + lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING) + + +if __name__ == '__main__': + app.run(main)
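The example above runs the agent without a prior via `make_default_networks`.
The behavior-prior variant described in the README only needs a different
network factory. The sketch below is illustrative only:
`make_go_to_target_environment` and the observation key names are placeholders
standing in for a dictionary-observation locomotion task, not part of this
change.

```python
import functools

from acme.agents.tf import svg0_prior
import launchpad as lp


def environment_factory(is_eval: bool):
  # Placeholder: return a dm_env environment whose observations are a dict
  # containing (at least) the keys referenced below.
  del is_eval
  return make_go_to_target_environment()


# The policy and critic see every observation; the prior is denied the
# task-specific ones -- the information asymmetry described in the README.
network_factory = functools.partial(
    svg0_prior.make_network_with_prior,
    policy_keys=('proprioception', 'target'),
    prior_keys=('proprioception',))

program = svg0_prior.DistributedSVG0(
    environment_factory=environment_factory,
    network_factory=network_factory,
    num_actors=4,
    distillation_cost=1e-3,
    entropy_regularizer_cost=1e-4).build()

lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING)
```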