diff --git a/acme/agents/tf/svg0_prior/README.md b/acme/agents/tf/svg0_prior/README.md
new file mode 100644
index 0000000000..1464932285
--- /dev/null
+++ b/acme/agents/tf/svg0_prior/README.md
@@ -0,0 +1,26 @@
+# Stochastic Value Gradients (SVG) with Behavior Prior.
+
+This folder contains a version of the SVG-0 agent introduced in
+[Heess et al., 2015], extended with an entropy bonus, RETRACE
+([Munos et al., 2016]) for off-policy correction, and code to learn behavior
+priors ([Tirumala et al., 2020], [Galashov et al., 2019]).
+
+The base SVG-0 algorithm is similar to DPG and DDPG
+([Silver et al., 2014], [Lillicrap et al., 2015]) but uses the
+reparameterization trick to learn stochastic rather than deterministic
+policies. In addition, RETRACE is used to learn value functions from multiple
+timesteps of data, with importance sampling for off-policy correction.
+
+Optionally, a behavior prior with an information asymmetry can be learnt in
+this setup, which has been shown to boost performance in some domains. Example
+code for running with and without behavior priors on the DeepMind Control
+Suite and Locomotion tasks is provided in the `examples` folder.
+
+
+[Heess et al., 2015]: https://arxiv.org/abs/1510.09142
+[Munos et al., 2016]: https://arxiv.org/abs/1606.02647
+[Lillicrap et al., 2015]: https://arxiv.org/abs/1509.02971
+[Silver et al., 2014]: http://proceedings.mlr.press/v32/silver14
+[Tirumala et al., 2020]: https://arxiv.org/abs/2010.14274
+[Galashov et al., 2019]: https://arxiv.org/abs/1905.01240
+
diff --git a/acme/agents/tf/svg0_prior/__init__.py b/acme/agents/tf/svg0_prior/__init__.py
new file mode 100644
index 0000000000..b4218db22c
--- /dev/null
+++ b/acme/agents/tf/svg0_prior/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementations of an SVG0 agent with prior."""
+
+from acme.agents.tf.svg0_prior.agent import SVG0
+from acme.agents.tf.svg0_prior.agent_distributed import DistributedSVG0
+from acme.agents.tf.svg0_prior.learning import SVG0Learner
+from acme.agents.tf.svg0_prior.networks import make_default_networks
+from acme.agents.tf.svg0_prior.networks import make_network_with_prior
diff --git a/acme/agents/tf/svg0_prior/acting.py b/acme/agents/tf/svg0_prior/acting.py
new file mode 100644
index 0000000000..f044a14e4b
--- /dev/null
+++ b/acme/agents/tf/svg0_prior/acting.py
@@ -0,0 +1,67 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG0 actor implementation.""" + +from typing import Optional + +from acme import adders +from acme import types + +from acme.agents.tf import actors +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + +import dm_env +import sonnet as snt + + +class SVG0Actor(actors.FeedForwardActor): + """An actor that also returns `log_prob`.""" + + def __init__( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + deterministic_policy: Optional[bool] = False, + ): + super().__init__(policy_network, adder, variable_client) + self._log_prob = None + self._deterministic_policy = deterministic_policy + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) + + # Compute the policy, conditioned on the observation. + policy = self._policy_network(batched_observation) + if self._deterministic_policy: + action = policy.mean() + else: + action = policy.sample() + self._log_prob = policy.log_prob(action) + return tf2_utils.to_numpy_squeeze(action) + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + if not self._adder: + return + + extras = {'log_prob': self._log_prob} + extras = tf2_utils.to_numpy_squeeze(extras) + self._adder.add(action, next_timestep, extras) diff --git a/acme/agents/tf/svg0_prior/agent.py b/acme/agents/tf/svg0_prior/agent.py new file mode 100644 index 0000000000..99b2b92baf --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent.py @@ -0,0 +1,371 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SVG0 agent implementation.""" + +import copy +import dataclasses +from typing import Iterator, List, Optional, Tuple + +from acme import adders +from acme import core +from acme import datasets +from acme import specs +from acme.adders import reverb as reverb_adders +from acme.agents import agent +from acme.agents.tf.svg0_prior import acting +from acme.agents.tf.svg0_prior import learning +from acme.tf import utils +from acme.tf import variable_utils +from acme.utils import counting +from acme.utils import loggers +import reverb +import sonnet as snt +import tensorflow as tf + + +@dataclasses.dataclass +class SVG0Config: + """Configuration options for the agent.""" + + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_update_period: int = 100 + policy_optimizer: Optional[snt.Optimizer] = None + critic_optimizer: Optional[snt.Optimizer] = None + prior_optimizer: Optional[snt.Optimizer] = None + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + sequence_length: int = 10 + sigma: float = 0.3 + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + distillation_cost: Optional[float] = 1e-3 + entropy_regularizer_cost: Optional[float] = 1e-3 + + +@dataclasses.dataclass +class SVG0Networks: + """Structure containing the networks for SVG0.""" + + policy_network: snt.Module + critic_network: snt.Module + prior_network: Optional[snt.Module] + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + prior_network: Optional[snt.Module] = None + ): + # This method is implemented (rather than added by the dataclass decorator) + # in order to allow observation network to be passed as an arbitrary tensor + # transformation rather than as a snt Module. + # TODO(mwhoffman): use Protocol rather than Module/TensorTransformation. + self.policy_network = policy_network + self.critic_network = critic_network + self.prior_network = prior_network + + def init(self, environment_spec: specs.EnvironmentSpec): + """Initialize the networks given an environment spec.""" + # Get observation and action specs. + act_spec = environment_spec.actions + obs_spec = environment_spec.observations + + # Create variables for the policy and critic nets. + _ = utils.create_variables(self.policy_network, [obs_spec]) + _ = utils.create_variables(self.critic_network, [obs_spec, act_spec]) + if self.prior_network is not None: + _ = utils.create_variables(self.prior_network, [obs_spec]) + + def make_policy( + self, + ) -> snt.Module: + """Create a single network which evaluates the policy.""" + return self.policy_network + + def make_prior( + self, + ) -> snt.Module: + """Create a single network which evaluates the prior.""" + behavior_prior = self.prior_network + return behavior_prior + + +class SVG0Builder: + """Builder for SVG0 which constructs individual components of the agent.""" + + def __init__(self, config: SVG0Config): + self._config = config + + def make_replay_tables( + self, + environment_spec: specs.EnvironmentSpec, + sequence_length: int, + ) -> List[reverb.Table]: + """Create tables to insert data into.""" + if self._config.samples_per_insert is None: + # We will take a samples_per_insert ratio of None to mean that there is + # no limit, i.e. this only implies a min size limit. 
+ limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + + else: + error_buffer = max(1, self._config.samples_per_insert) + limiter = reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer) + + extras_spec = { + 'log_prob': tf.ones( + shape=(), dtype=tf.float32) + } + replay_table = reverb.Table( + name=self._config.replay_table_name, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter, + signature=reverb_adders.SequenceAdder.signature( + environment_spec, + extras_spec=extras_spec, + sequence_length=sequence_length + 1)) + + return [replay_table] + + def make_dataset_iterator( + self, + reverb_client: reverb.Client, + ) -> Iterator[reverb.ReplaySample]: + """Create a dataset iterator to use for learning/updating the agent.""" + # The dataset provides an interface to sample from replay. + dataset = datasets.make_reverb_dataset( + table=self._config.replay_table_name, + server_address=reverb_client.server_address, + batch_size=self._config.batch_size, + prefetch_size=self._config.prefetch_size) + + # TODO(b/155086959): Fix type stubs and remove. + return iter(dataset) # pytype: disable=wrong-arg-types + + def make_adder( + self, + replay_client: reverb.Client, + ) -> adders.Adder: + """Create an adder which records data generated by the actor/environment.""" + return reverb_adders.SequenceAdder( + client=replay_client, + sequence_length=self._config.sequence_length+1, + priority_fns={self._config.replay_table_name: lambda x: 1.}, + period=self._config.sequence_length, + end_of_episode_behavior=reverb_adders.EndBehavior.CONTINUE, + ) + + def make_actor( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_source: Optional[core.VariableSource] = None, + deterministic_policy: Optional[bool] = False, + ): + """Create an actor instance.""" + if variable_source: + # Create the variable client responsible for keeping the actor up-to-date. + variable_client = variable_utils.VariableClient( + client=variable_source, + variables={'policy': policy_network.variables}, + update_period=1000, + ) + + # Make sure not to use a random policy after checkpoint restoration by + # assigning variables before running the environment loop. + variable_client.update_and_wait() + + else: + variable_client = None + + # Create the actor which defines how we take actions. + return acting.SVG0Actor( + policy_network=policy_network, + adder=adder, + variable_client=variable_client, + deterministic_policy=deterministic_policy + ) + + def make_learner( + self, + networks: Tuple[SVG0Networks, SVG0Networks], + dataset: Iterator[reverb.ReplaySample], + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = False, + ): + """Creates an instance of the learner.""" + online_networks, target_networks = networks + + # The learner updates the parameters (and initializes them). 
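+    # Note: when no prior network is configured, the prior-related arguments
+    # below are simply None and the learner skips the distillation/prior losses.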
+ return learning.SVG0Learner( + policy_network=online_networks.policy_network, + critic_network=online_networks.critic_network, + target_policy_network=target_networks.policy_network, + target_critic_network=target_networks.critic_network, + prior_network=online_networks.prior_network, + target_prior_network=target_networks.prior_network, + policy_optimizer=self._config.policy_optimizer, + critic_optimizer=self._config.critic_optimizer, + prior_optimizer=self._config.prior_optimizer, + distillation_cost=self._config.distillation_cost, + entropy_regularizer_cost=self._config.entropy_regularizer_cost, + discount=self._config.discount, + target_update_period=self._config.target_update_period, + dataset_iterator=dataset, + counter=counter, + logger=logger, + checkpoint=checkpoint, + ) + + +class SVG0(agent.Agent): + """SVG0 Agent with prior. + + This implements a single-process SVG0 agent. This is an actor-critic algorithm + that generates data via a behavior policy, inserts N-step transitions into + a replay buffer, and periodically updates the policy (and as a result the + behavior) by sampling uniformly from this buffer. + """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + policy_network: snt.Module, + critic_network: snt.Module, + discount: float = 0.99, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + prior_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: float = 32.0, + sequence_length: int = 10, + sigma: float = 0.3, + replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + policy_network: the online (optimized) policy. + critic_network: the online critic. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_update_period: number of learner steps to perform before updating + the target networks. + prior_network: an optional `behavior prior` to regularize against. + policy_optimizer: optimizer for the policy network updates. + critic_optimizer: optimizer for the critic network updates. + prior_optimizer: optimizer for the prior network updates. + distillation_cost: a multiplier to be used when adding distillation + against the prior to the losses. + entropy_regularizer_cost: a multiplier used for per state sample based + entropy added to the actor loss. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert + that is made. + sequence_length: number of timesteps to store for each trajectory. + sigma: standard deviation of zero-mean, Gaussian exploration noise. + replay_table_name: string indicating what name to give the replay table. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. 
+ """ + # Create the Builder object which will internally create agent components. + builder = SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + # Right now this modifies min_replay_size and samples_per_insert so that + # they are not controlled by a limiter and are instead handled by the + # Agent base class (the above TODO directly references this behavior). + SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + min_replay_size=1, # Let the Agent class handle this. + max_replay_size=max_replay_size, + samples_per_insert=None, # Let the Agent class handle this. + sequence_length=sequence_length, + sigma=sigma, + replay_table_name=replay_table_name, + )) + + # TODO(mwhoffman): pass the network dataclass in directly. + online_networks = SVG0Networks(policy_network=policy_network, + critic_network=critic_network, + prior_network=prior_network,) + + # Target networks are just a copy of the online networks. + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(environment_spec) + target_networks.init(environment_spec) + + # TODO(mwhoffman): either make this Dataclass or pass only one struct. + # The network struct passed to make_learner is just a tuple for the + # time-being (for backwards compatibility). + networks = (online_networks, target_networks) + + # Create the behavior policy. + policy_network = online_networks.make_policy() + + # Create the replay server and grab its address. + replay_tables = builder.make_replay_tables(environment_spec, + sequence_length) + replay_server = reverb.Server(replay_tables, port=None) + replay_client = reverb.Client(f'localhost:{replay_server.port}') + + # Create actor, dataset, and learner for generating, storing, and consuming + # data respectively. + adder = builder.make_adder(replay_client) + actor = builder.make_actor(policy_network, adder) + dataset = builder.make_dataset_iterator(replay_client) + learner = builder.make_learner(networks, dataset, counter, logger, + checkpoint) + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert) + + # Save the replay so we don't garbage collect it. + self._replay_server = replay_server diff --git a/acme/agents/tf/svg0_prior/agent_distributed.py b/acme/agents/tf/svg0_prior/agent_distributed.py new file mode 100644 index 0000000000..8ac546a7dc --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent_distributed.py @@ -0,0 +1,252 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Defines the SVG0 agent class.""" + +import copy +from typing import Callable, Dict, Optional + +import acme +from acme import specs +from acme.agents.tf.svg0_prior import agent +from acme.tf import savers as tf2_savers +from acme.utils import counting +from acme.utils import loggers +from acme.utils import lp_utils +import dm_env +import launchpad as lp +import reverb +import sonnet as snt + + +class DistributedSVG0: + """Program definition for SVG0.""" + + def __init__( + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[specs.BoundedArray], Dict[str, snt.Module]], + num_actors: int = 1, + num_caches: int = 0, + environment_spec: Optional[specs.EnvironmentSpec] = None, + batch_size: int = 256, + prefetch_size: int = 4, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + samples_per_insert: Optional[float] = 32.0, + sequence_length: int = 10, + sigma: float = 0.3, + discount: float = 0.99, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + target_update_period: int = 100, + max_actor_steps: Optional[int] = None, + log_every: float = 10.0, + ): + + if not environment_spec: + environment_spec = specs.make_environment_spec(environment_factory(False)) + + # TODO(mwhoffman): Make network_factory directly return the struct. + # TODO(mwhoffman): Make the factory take the entire spec. + def wrapped_network_factory(action_spec): + networks_dict = network_factory(action_spec) + networks = agent.SVG0Networks( + policy_network=networks_dict.get('policy'), + critic_network=networks_dict.get('critic'), + prior_network=networks_dict.get('prior', None),) + return networks + + self._environment_factory = environment_factory + self._network_factory = wrapped_network_factory + self._environment_spec = environment_spec + self._sigma = sigma + self._num_actors = num_actors + self._num_caches = num_caches + self._max_actor_steps = max_actor_steps + self._log_every = log_every + self._sequence_length = sequence_length + + self._builder = agent.SVG0Builder( + # TODO(mwhoffman): pass the config dataclass in directly. + # TODO(mwhoffman): use the limiter rather than the workaround below. + agent.SVG0Config( + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_update_period=target_update_period, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + prior_optimizer=prior_optimizer, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + sequence_length=sequence_length, + sigma=sigma, + distillation_cost=distillation_cost, + entropy_regularizer_cost=entropy_regularizer_cost, + )) + + def replay(self): + """The replay storage.""" + return self._builder.make_replay_tables(self._environment_spec, + self._sequence_length) + + def counter(self): + return tf2_savers.CheckpointingRunner(counting.Counter(), + time_delta_minutes=1, + subdirectory='counter') + + def coordinator(self, counter: counting.Counter): + return lp_utils.StepsLimiter(counter, self._max_actor_steps) + + def learner( + self, + replay: reverb.Client, + counter: counting.Counter, + ): + """The Learning part of the agent.""" + + # Create the networks to optimize (online) and target networks. 
+ online_networks = self._network_factory(self._environment_spec.actions) + target_networks = copy.deepcopy(online_networks) + + # Initialize the networks. + online_networks.init(self._environment_spec) + target_networks.init(self._environment_spec) + + dataset = self._builder.make_dataset_iterator(replay) + counter = counting.Counter(counter, 'learner') + logger = loggers.make_default_logger( + 'learner', time_delta=self._log_every, steps_key='learner_steps') + + return self._builder.make_learner( + networks=(online_networks, target_networks), + dataset=dataset, + counter=counter, + logger=logger, + ) + + def actor( + self, + replay: reverb.Client, + variable_source: acme.VariableSource, + counter: counting.Counter, + ) -> acme.EnvironmentLoop: + """The actor process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + ) + + # Create the environment. + environment = self._environment_factory(False) + + # Create logger and counter; actors will not spam bigtable. + counter = counting.Counter(counter, 'actor') + logger = loggers.make_default_logger( + 'actor', + save_data=False, + time_delta=self._log_every, + steps_key='actor_steps') + + # Create the loop to connect environment and agent. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def evaluator( + self, + variable_source: acme.VariableSource, + counter: counting.Counter, + logger: Optional[loggers.Logger] = None, + ): + """The evaluation process.""" + + # Create the behavior policy. + networks = self._network_factory(self._environment_spec.actions) + networks.init(self._environment_spec) + policy_network = networks.make_policy() + + # Create the agent. + actor = self._builder.make_actor( + policy_network=policy_network, + variable_source=variable_source, + deterministic_policy=True, + ) + + # Make the environment. + environment = self._environment_factory(True) + + # Create logger and counter. + counter = counting.Counter(counter, 'evaluator') + logger = logger or loggers.make_default_logger( + 'evaluator', + time_delta=self._log_every, + steps_key='evaluator_steps', + ) + + # Create the run loop and return it. + return acme.EnvironmentLoop(environment, actor, counter, logger) + + def build(self, name='svg0'): + """Build the distributed agent topology.""" + program = lp.Program(name=name) + + with program.group('replay'): + replay = program.add_node(lp.ReverbNode(self.replay)) + + with program.group('counter'): + counter = program.add_node(lp.CourierNode(self.counter)) + + if self._max_actor_steps: + with program.group('coordinator'): + _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + + with program.group('learner'): + learner = program.add_node(lp.CourierNode(self.learner, replay, counter)) + + with program.group('evaluator'): + program.add_node(lp.CourierNode(self.evaluator, learner, counter)) + + if not self._num_caches: + # Use our learner as a single variable source. + sources = [learner] + else: + with program.group('cacher'): + # Create a set of learner caches. 
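+        # Each cacher holds a periodically refreshed copy of the learner's
+        # variables so that actors do not all query the learner directly.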
+ sources = [] + for _ in range(self._num_caches): + cacher = program.add_node( + lp.CacherNode( + learner, refresh_interval_ms=2000, stale_after_ms=4000)) + sources.append(cacher) + + with program.group('actor'): + # Add actors which pull round-robin from our variable sources. + for actor_id in range(self._num_actors): + source = sources[actor_id % len(sources)] + program.add_node(lp.CourierNode(self.actor, replay, source, counter)) + + return program diff --git a/acme/agents/tf/svg0_prior/agent_distributed_test.py b/acme/agents/tf/svg0_prior/agent_distributed_test.py new file mode 100644 index 0000000000..cdbc2599e4 --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent_distributed_test.py @@ -0,0 +1,95 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for the distributed agent.""" + +from typing import Sequence + +from absl.testing import absltest +import acme +from acme import specs +from acme.agents.tf import svg0_prior +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils +import launchpad as lp +import numpy as np +import sonnet as snt + + +def make_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (10, 10), + critic_layer_sizes: Sequence[int] = (10, 10), +): + """Simple networks for testing..""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
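+  # Unlike `make_default_networks`, no ClipToSpec transformation is applied to
+  # the actions in this small test network.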
+ multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class DistributedAgentTest(absltest.TestCase): + """Simple integration/smoke test for the distributed agent.""" + + def test_control_suite(self): + """Tests that the agent can run on the control suite without crashing.""" + + agent = svg0_prior.DistributedSVG0( + environment_factory=lambda x: fakes.ContinuousEnvironment(), + network_factory=make_networks, + num_actors=2, + batch_size=32, + min_replay_size=32, + max_replay_size=1000, + ) + program = agent.build() + + (learner_node,) = program.groups['learner'] + learner_node.disable_run() + + lp.launch(program, launch_type='test_mt') + + learner: acme.Learner = learner_node.create_handle().dereference() + + for _ in range(5): + learner.step() + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/agents/tf/svg0_prior/agent_test.py b/acme/agents/tf/svg0_prior/agent_test.py new file mode 100644 index 0000000000..55670f1768 --- /dev/null +++ b/acme/agents/tf/svg0_prior/agent_test.py @@ -0,0 +1,97 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the SVG agent.""" + +import sys +from typing import Dict, Sequence + +from absl.testing import absltest +import acme +from acme import specs +from acme import types +from acme.agents.tf import svg0_prior +from acme.testing import fakes +from acme.tf import networks +from acme.tf import utils as tf2_utils + +import numpy as np +import sonnet as snt + + +def make_networks( + action_spec: types.NestedSpec, + policy_layer_sizes: Sequence[int] = (10, 10), + critic_layer_sizes: Sequence[int] = (10, 10), +) -> Dict[str, snt.Module]: + """Creates networks used by the agent.""" + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer() + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + 'policy': policy_network, + 'critic': critic_network, + } + + +class SVG0Test(absltest.TestCase): + + def test_svg0(self): + # Create a fake environment to test with. + environment = fakes.ContinuousEnvironment(episode_length=10) + spec = specs.make_environment_spec(environment) + + # Create the networks. 
+ agent_networks = make_networks(spec.actions) + + # Construct the agent. + agent = svg0_prior.SVG0( + environment_spec=spec, + policy_network=agent_networks['policy'], + critic_network=agent_networks['critic'], + batch_size=10, + samples_per_insert=2, + min_replay_size=10, + ) + + # Try running the environment loop. We have no assertions here because all + # we care about is that the agent runs without raising any errors. + loop = acme.EnvironmentLoop(environment, agent) + loop.run(num_episodes=2) + # Make sure Acme doesn't import excessive number of modules. + self.assertLess(len(sys.modules), 4500) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/agents/tf/svg0_prior/learning.py b/acme/agents/tf/svg0_prior/learning.py new file mode 100644 index 0000000000..2baa1feac9 --- /dev/null +++ b/acme/agents/tf/svg0_prior/learning.py @@ -0,0 +1,387 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG learner implementation.""" + +import time +from typing import Dict, Iterator, List, Optional + +import acme +from acme.agents.tf.svg0_prior import utils as svg0_utils +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +from trfl import continuous_retrace_ops + +_MIN_LOG_VAL = 1e-20 + + +class SVG0Learner(acme.Learner): + """SVG0 learner with optional prior. + + This is the learning component of an SVG0 agent. IE it takes a dataset as + input and implements update functionality to learn from this dataset. + """ + + def __init__( + self, + policy_network: snt.Module, + critic_network: snt.Module, + target_policy_network: snt.Module, + target_critic_network: snt.Module, + discount: float, + target_update_period: int, + dataset_iterator: Iterator[reverb.ReplaySample], + prior_network: Optional[snt.Module] = None, + target_prior_network: Optional[snt.Module] = None, + policy_optimizer: Optional[snt.Optimizer] = None, + critic_optimizer: Optional[snt.Optimizer] = None, + prior_optimizer: Optional[snt.Optimizer] = None, + distillation_cost: Optional[float] = 1e-3, + entropy_regularizer_cost: Optional[float] = 1e-3, + num_action_samples: int = 10, + lambda_: float = 1.0, + counter: Optional[counting.Counter] = None, + logger: Optional[loggers.Logger] = None, + checkpoint: bool = True, + ): + """Initializes the learner. + + Args: + policy_network: the online (optimized) policy. + critic_network: the online critic. + target_policy_network: the target policy (which lags behind the online + policy). + target_critic_network: the target critic. + discount: discount to use for TD updates. + target_update_period: number of learner steps to perform before updating + the target networks. + dataset_iterator: dataset to learn from, whether fixed or from a replay + buffer (see `acme.datasets.reverb.make_dataset` documentation). 
+ prior_network: the online (optimized) prior. + target_prior_network: the target prior (which lags behind the online + prior). + policy_optimizer: the optimizer to be applied to the SVG-0 (policy) loss. + critic_optimizer: the optimizer to be applied to the distributional + Bellman loss. + prior_optimizer: the optimizer to be applied to the prior (distillation) + loss. + distillation_cost: a multiplier to be used when adding distillation + against the prior to the losses. + entropy_regularizer_cost: a multiplier used for per state sample based + entropy added to the actor loss. + num_action_samples: the number of action samples to use for estimating the + value function and sample based entropy. + lambda_: the `lambda` value to be used with retrace. + counter: counter object used to keep track of steps. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + + # Store online and target networks. + self._policy_network = policy_network + self._critic_network = critic_network + self._target_policy_network = target_policy_network + self._target_critic_network = target_critic_network + + self._prior_network = prior_network + self._target_prior_network = target_prior_network + + self._lambda = lambda_ + self._num_action_samples = num_action_samples + self._distillation_cost = distillation_cost + self._entropy_regularizer_cost = entropy_regularizer_cost + + # General learner book-keeping and loggers. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger('learner') + + # Other learner parameters. + self._discount = discount + + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_update_period = target_update_period + + # Batch dataset and create iterator. + self._iterator = dataset_iterator + + # Create optimizers if they aren't given. + self._critic_optimizer = critic_optimizer or snt.optimizers.Adam(1e-4) + self._policy_optimizer = policy_optimizer or snt.optimizers.Adam(1e-4) + self._prior_optimizer = prior_optimizer or snt.optimizers.Adam(1e-4) + + # Expose the variables. + self._variables = { + 'critic': self._critic_network.variables, + 'policy': self._policy_network.variables, + } + if self._prior_network is not None: + self._variables['prior'] = self._prior_network.variables + + # Create a checkpointer and snapshotter objects. + self._checkpointer = None + self._snapshotter = None + + if checkpoint: + objects_to_save = { + 'counter': self._counter, + 'policy': self._policy_network, + 'critic': self._critic_network, + 'target_policy': self._target_policy_network, + 'target_critic': self._target_critic_network, + 'policy_optimizer': self._policy_optimizer, + 'critic_optimizer': self._critic_optimizer, + 'num_steps': self._num_steps, + } + if self._prior_network is not None: + objects_to_save['prior'] = self._prior_network + objects_to_save['target_prior'] = self._target_prior_network + objects_to_save['prior_optimizer'] = self._prior_optimizer + + self._checkpointer = tf2_savers.Checkpointer( + subdirectory='svg0_learner', + objects_to_save=objects_to_save) + objects_to_snapshot = { + 'policy': self._policy_network, + 'critic': self._critic_network, + } + if self._prior_network is not None: + objects_to_snapshot['prior'] = self._prior_network + + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save=objects_to_snapshot) + + # Do not record timestamps until after the first learning step is done. 
+ # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + # Update target network + online_variables = [ + *self._critic_network.variables, + *self._policy_network.variables, + ] + if self._prior_network is not None: + online_variables += [*self._prior_network.variables] + online_variables = tuple(online_variables) + + target_variables = [ + *self._target_critic_network.variables, + *self._target_policy_network.variables, + ] + if self._prior_network is not None: + target_variables += [*self._target_prior_network.variables] + target_variables = tuple(target_variables) + + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Get data from replay (dropping extras if any) and flip to `[T, B, ...]`. + sample: reverb.ReplaySample = next(self._iterator) + data = tf2_utils.batch_to_sequence(sample.data) + observations, actions, rewards, discounts, extra = (data.observation, + data.action, + data.reward, + data.discount, + data.extras) + online_target_pi_q = svg0_utils.OnlineTargetPiQ( + online_pi=self._policy_network, + online_q=self._critic_network, + target_pi=self._target_policy_network, + target_q=self._target_critic_network, + num_samples=self._num_action_samples, + online_prior=self._prior_network, + target_prior=self._target_prior_network, + ) + with tf.GradientTape(persistent=True) as tape: + step_outputs = svg0_utils.static_rnn( + core=online_target_pi_q, + inputs=(observations, actions), + unroll_length=rewards.shape[0]) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + target_pi_samples = tf2_utils.batch_to_sequence( + step_outputs.target_samples) + # Tile observations to have shape [S, T+1, B,..]. + tiled_observations = tf2_utils.tile_nested(observations, + self._num_action_samples) + + # Finally compute target Q values on the new action samples. + # Shape: [S, T+1, B, 1] + target_q_target_pi_samples = snt.BatchApply(self._target_critic_network, + 3)(tiled_observations, + target_pi_samples) + # Compute the value estimate by averaging over the action dimension. + # Shape: [T+1, B, 1]. + target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples, axis=0) + + # Split the target V's into the target for learning + # `value_function_target` and the bootstrap value. Shape: [T, B]. + value_function_target = tf.squeeze(target_v_target_pi[:-1], axis=-1) + # Shape: [B]. + bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1) + + # When learning with a prior, add entropy terms to value targets. + if self._prior_network is not None: + value_function_target -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[:-1] + ) + bootstrap_value -= self._distillation_cost * tf.stop_gradient( + step_outputs.analytic_kl_to_target[-1]) + + # Get target log probs and behavior log probs from rollout. + # Shape: [T+1, B]. + target_log_probs_behavior_actions = ( + step_outputs.target_log_probs_behavior_actions) + behavior_log_probs = extra['log_prob'] + # Calculate importance weights. Shape: [T+1, B]. + rhos = tf.exp(target_log_probs_behavior_actions - behavior_log_probs) + + # Filter the importance weights to mask out episode restarts. 
Ignore the + # last action and consider the step type of the next step for masking. + # Shape: [T, B]. + episode_start_mask = tf2_utils.batch_to_sequence( + sample.data.start_of_episode)[1:] + + rhos = svg0_utils.mask_out_restarting(rhos[:-1], episode_start_mask) + + # rhos = rhos[:-1] + # Compute the log importance weights with a small value added for + # stability. + # Shape: [T, B] + log_rhos = tf.math.log(rhos + _MIN_LOG_VAL) + + # Retrieve the target and online Q values and throw away the last action. + # Shape: [T, B]. + target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1) + online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1) + + # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the + # number of action samples taken. + online_pi_samples = tf2_utils.batch_to_sequence( + step_outputs.online_samples) + target_q_online_pi_samples = snt.BatchApply(self._target_critic_network, + 3)(tiled_observations, + online_pi_samples) + expected_q = tf.reduce_mean( + tf.squeeze(target_q_online_pi_samples, -1), axis=0) + + # Flip online_log_probs to be of shape [S, T+1, B] and then compute + # entropy by averaging over num samples. Final shape: [T+1, B]. + online_log_probs = tf2_utils.batch_to_sequence( + step_outputs.online_log_probs) + sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0) + retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights( + log_rhos=log_rhos, + discounts=self._discount * discounts[:-1], + rewards=rewards[:-1], + q_values=target_q_values, + values=value_function_target, + bootstrap_value=bootstrap_value, + lambda_=self._lambda, + ) + + # Critic loss. Shape: [T, B]. + critic_loss = 0.5 * tf.math.squared_difference( + tf.stop_gradient(retrace_outputs.qs), online_q_values) + + # Policy loss- SVG0 with sample based entropy. Shape: [T, B] + policy_loss = -( + expected_q + self._entropy_regularizer_cost * sample_based_entropy) + policy_loss = policy_loss[:-1] + + if self._prior_network is not None: + # When training the prior, also add the per-timestep KL cost. + policy_loss += ( + self._distillation_cost * step_outputs.analytic_kl_to_target[:-1]) + + # Ensure episode restarts are masked out when computing the losses. + critic_loss = svg0_utils.mask_out_restarting(critic_loss, + episode_start_mask) + critic_loss = tf.reduce_mean(critic_loss) + + policy_loss = svg0_utils.mask_out_restarting(policy_loss, + episode_start_mask) + policy_loss = tf.reduce_mean(policy_loss) + + if self._prior_network is not None: + prior_loss = step_outputs.analytic_kl_divergence[:-1] + prior_loss = svg0_utils.mask_out_restarting(prior_loss, + episode_start_mask) + prior_loss = tf.reduce_mean(prior_loss) + + # Get trainable variables. + policy_variables = self._policy_network.trainable_variables + critic_variables = self._critic_network.trainable_variables + + # Compute gradients. + policy_gradients = tape.gradient(policy_loss, policy_variables) + critic_gradients = tape.gradient(critic_loss, critic_variables) + if self._prior_network is not None: + prior_variables = self._prior_network.trainable_variables + prior_gradients = tape.gradient(prior_loss, prior_variables) + + # Delete the tape manually because of the persistent=True flag. + del tape + + # Apply gradients. 
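+    # (Prior gradients are applied further below, only if a prior network is
+    # being learnt.)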
+ self._policy_optimizer.apply(policy_gradients, policy_variables) + self._critic_optimizer.apply(critic_gradients, critic_variables) + losses = { + 'critic_loss': critic_loss, + 'policy_loss': policy_loss, + } + + if self._prior_network is not None: + self._prior_optimizer.apply(prior_gradients, prior_variables) + losses['prior_loss'] = prior_loss + + # Losses to track. + return losses + + def step(self): + # Run the learning step. + fetches = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + fetches.update(counts) + + # Checkpoint and attempt to write the logs. + if self._checkpointer is not None: + self._checkpointer.save() + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(fetches) + + def get_variables(self, names: List[str]) -> List[List[np.ndarray]]: + return [tf2_utils.to_numpy(self._variables[name]) for name in names] diff --git a/acme/agents/tf/svg0_prior/networks.py b/acme/agents/tf/svg0_prior/networks.py new file mode 100644 index 0000000000..f7911a804d --- /dev/null +++ b/acme/agents/tf/svg0_prior/networks.py @@ -0,0 +1,119 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for different experiment flavours.""" + +import functools +from typing import Mapping, Sequence, Optional + +from acme import specs +from acme import types +from acme.agents.tf.svg0_prior import utils as svg0_utils +from acme.tf import networks +from acme.tf import utils as tf2_utils + +import numpy as np +import sonnet as snt + + +def make_default_networks( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (256, 256, 256), + critic_layer_sizes: Sequence[int] = (512, 512, 256), +) -> Mapping[str, types.TensorTransformation]: + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + + policy_network = snt.Sequential([ + tf2_utils.batch_concat, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.3, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. 
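+  # Actions are first clipped to `action_spec` by ClipToSpec before being
+  # concatenated with the observations.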
+ multiplexer = networks.CriticMultiplexer( + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + + return { + "policy": policy_network, + "critic": critic_network, + } + + +def make_network_with_prior( + action_spec: specs.BoundedArray, + policy_layer_sizes: Sequence[int] = (200, 100), + critic_layer_sizes: Sequence[int] = (400, 300), + prior_layer_sizes: Sequence[int] = (200, 100), + policy_keys: Optional[Sequence[str]] = None, + prior_keys: Optional[Sequence[str]] = None, +) -> Mapping[str, types.TensorTransformation]: + """Creates networks used by the agent.""" + + # Get total number of action dimensions from action spec. + num_dimensions = np.prod(action_spec.shape, dtype=int) + flatten_concat_policy = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=policy_keys) + flatten_concat_prior = functools.partial( + svg0_utils.batch_concat_selection, concat_keys=prior_keys) + + policy_network = snt.Sequential([ + flatten_concat_policy, + networks.LayerNormMLP(policy_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + # The multiplexer concatenates the (maybe transformed) observations/actions. + multiplexer = networks.CriticMultiplexer( + observation_network=flatten_concat_policy, + action_network=networks.ClipToSpec(action_spec)) + critic_network = snt.Sequential([ + multiplexer, + networks.LayerNormMLP(critic_layer_sizes, activate_final=True), + networks.NearZeroInitializedLinear(1), + ]) + prior_network = snt.Sequential([ + flatten_concat_prior, + networks.LayerNormMLP(prior_layer_sizes, activate_final=True), + networks.MultivariateNormalDiagHead( + num_dimensions, + tanh_mean=True, + min_scale=0.1, + init_scale=0.7, + fixed_scale=False, + use_tfd_independent=False) + ]) + return { + "policy": policy_network, + "critic": critic_network, + "prior": prior_network, + } diff --git a/acme/agents/tf/svg0_prior/utils.py b/acme/agents/tf/svg0_prior/utils.py new file mode 100644 index 0000000000..8474fea6cc --- /dev/null +++ b/acme/agents/tf/svg0_prior/utils.py @@ -0,0 +1,157 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for SVG0 algorithm with priors.""" + +import collections +from typing import Tuple, Optional, Dict, Iterable + +from acme import types +from acme.tf import utils as tf2_utils + +import sonnet as snt +import tensorflow as tf +import tree + + +class OnlineTargetPiQ(snt.Module): + """Core to unroll online and target policies and Q functions at once. + + A core that runs online and target policies and Q functions. This can be more + efficient if the core needs to be unrolled across time and called many times. 
+ """ + + def __init__(self, + online_pi: snt.Module, + online_q: snt.Module, + target_pi: snt.Module, + target_q: snt.Module, + num_samples: int, + online_prior: Optional[snt.Module] = None, + target_prior: Optional[snt.Module] = None, + name='OnlineTargetPiQ'): + super().__init__(name) + + self._online_pi = online_pi + self._target_pi = target_pi + self._online_q = online_q + self._target_q = target_q + self._online_prior = online_prior + self._target_prior = target_prior + + self._num_samples = num_samples + output_list = [ + 'online_samples', 'target_samples', 'target_log_probs_behavior_actions', + 'online_log_probs', 'online_q', 'target_q' + ] + if online_prior is not None: + output_list += ['analytic_kl_divergence', 'analytic_kl_to_target'] + self._output_tuple = collections.namedtuple( + 'OnlineTargetPiQ', output_list) + + def __call__(self, input_obs_and_action: Tuple[tf.Tensor, tf.Tensor]): + (obs, action) = input_obs_and_action + online_pi_dist = self._online_pi(obs) + target_pi_dist = self._target_pi(obs) + + online_samples = online_pi_dist.sample(self._num_samples) + target_samples = target_pi_dist.sample(self._num_samples) + target_log_probs_behavior_actions = target_pi_dist.log_prob(action) + + online_log_probs = online_pi_dist.log_prob(tf.stop_gradient(online_samples)) + + online_q_out = self._online_q(obs, action) + target_q_out = self._target_q(obs, action) + + output_list = [ + online_samples, target_samples, target_log_probs_behavior_actions, + online_log_probs, online_q_out, target_q_out + ] + + if self._online_prior is not None: + prior_dist = self._online_prior(obs) + target_prior_dist = self._target_prior(obs) + analytic_kl_divergence = online_pi_dist.kl_divergence(prior_dist) + analytic_kl_to_target = online_pi_dist.kl_divergence(target_prior_dist) + + output_list += [analytic_kl_divergence, analytic_kl_to_target] + output = self._output_tuple(*output_list) + return output + + +def static_rnn(core: snt.Module, inputs: types.NestedTensor, + unroll_length: int): + """Unroll core along inputs for unroll_length steps. + + Note: for time-major input tensors whose leading dimension is less than + unroll_length, `None` would be provided instead. + + Args: + core: an instance of snt.Module. + inputs: a `nest` of time-major input tensors. + unroll_length: number of time steps to unroll. + + Returns: + step_outputs: a `nest` of time-major stacked output tensors of length + `unroll_length`. + """ + step_outputs = [] + for time_dim in range(unroll_length): + inputs_t = tree.map_structure( + lambda t, i_=time_dim: t[i_] if i_ < t.shape[0] else None, inputs) + step_output = core(inputs_t) + step_outputs.append(step_output) + + step_outputs = _nest_stack(step_outputs) + return step_outputs + + +def mask_out_restarting(tensor: tf.Tensor, start_of_episode: tf.Tensor): + """Mask out `tensor` taken on the step that resets the environment. + + Args: + tensor: a time-major 2-D `Tensor` of shape [T, B]. + start_of_episode: a 2-D `Tensor` of shape [T, B] that contains the points + where the episode restarts. 
+ + Returns: + tensor of shape [T, B] with elements are masked out according to step_types, + restarting weights of shape [T, B] + """ + tensor.get_shape().assert_has_rank(2) + start_of_episode.get_shape().assert_has_rank(2) + weights = tf.cast(~start_of_episode, dtype=tf.float32) + masked_tensor = tensor * weights + return masked_tensor + + +def batch_concat_selection(observation_dict: Dict[str, types.NestedTensor], + concat_keys: Optional[Iterable[str]] = None, + output_dtype=tf.float32) -> tf.Tensor: + """Concatenate a dict of observations into 2-D tensors.""" + concat_keys = concat_keys or sorted(observation_dict.keys()) + to_concat = [] + for obs in concat_keys: + if obs not in observation_dict: + raise KeyError( + 'Missing observation. Requested: {} (available: {})'.format( + obs, list(observation_dict.keys()))) + to_concat.append(tf.cast(observation_dict[obs], output_dtype)) + + return tf2_utils.batch_concat(to_concat) + + +def _nest_stack(list_of_nests, axis=0): + """Convert a list of nests to a nest of stacked lists.""" + return tree.map_structure(lambda *ts: tf.stack(ts, axis=axis), *list_of_nests) diff --git a/examples/control/lp_local_svg0.py b/examples/control/lp_local_svg0.py new file mode 100644 index 0000000000..2a03620bbf --- /dev/null +++ b/examples/control/lp_local_svg0.py @@ -0,0 +1,58 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running SVG0 on the control suite.""" + +from absl import app +from absl import flags +from acme.agents.tf import svg0_prior +import helpers + +from acme.utils import lp_utils + +import launchpad as lp + +FLAGS = flags.FLAGS +flags.DEFINE_string('domain', 'cartpole', 'Control suite domain name (str).') +flags.DEFINE_string('task', 'balance', 'Control suite task name (str).') + + +def main(_): + environment_factory = lp_utils.partial_kwargs( + helpers.make_environment, domain_name=FLAGS.domain, task_name=FLAGS.task) + + batch_size = 32 + sequence_length = 20 + gradient_steps_per_actor_step = 1.0 + samples_per_insert = ( + gradient_steps_per_actor_step * batch_size * sequence_length) + num_actors = 1 + + program = svg0_prior.DistributedSVG0( + environment_factory=environment_factory, + network_factory=lp_utils.partial_kwargs(svg0_prior.make_default_networks), + batch_size=batch_size, + sequence_length=sequence_length, + samples_per_insert=samples_per_insert, + entropy_regularizer_cost=1e-4, + max_replay_size=int(2e6), + target_update_period=250, + num_actors=num_actors).build() + + lp.launch(program, lp.LaunchType.LOCAL_MULTI_PROCESSING) + + +if __name__ == '__main__': + app.run(main)
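For reference, a minimal single-process usage sketch distilled from `agent_test.py` above; the fake environment from `acme.testing.fakes` stands in for a real control task and is used purely for illustration:

```python
import acme
from acme import specs
from acme.agents.tf import svg0_prior
from acme.testing import fakes

# A toy bounded continuous-control environment, as used in the unit test.
environment = fakes.ContinuousEnvironment(episode_length=10)
spec = specs.make_environment_spec(environment)

# Default policy/critic networks exported by this package.
networks = svg0_prior.make_default_networks(spec.actions)

# Single-process agent with small replay settings for a quick smoke run.
agent = svg0_prior.SVG0(
    environment_spec=spec,
    policy_network=networks['policy'],
    critic_network=networks['critic'],
    batch_size=10,
    samples_per_insert=2,
    min_replay_size=10,
)

# Run a couple of episodes to exercise the actor/replay/learner loop.
loop = acme.EnvironmentLoop(environment, agent)
loop.run(num_episodes=2)
```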