From eb7073fc7bcbd5b2c26ae5e52ab37e7212a5a1d4 Mon Sep 17 00:00:00 2001
From: Acme Contributor
Date: Wed, 6 Jan 2021 12:52:04 -0800
Subject: [PATCH] Add a JAX Behaviour Cloning learner

PiperOrigin-RevId: 350410049
Change-Id: I635259e24399cc54051bf04f9331afacdea992a4
---
 acme/agents/jax/bc/__init__.py |  18 +++
 acme/agents/jax/bc/learning.py | 131 ++++++++++++++++++++
 acme/agents/tf/bc/learning.py  |   7 +-
 examples/bsuite/run_bc_jax.py  | 210 +++++++++++++++++++++++++++++++++
 4 files changed, 362 insertions(+), 4 deletions(-)
 create mode 100644 acme/agents/jax/bc/__init__.py
 create mode 100644 acme/agents/jax/bc/learning.py
 create mode 100644 examples/bsuite/run_bc_jax.py

diff --git a/acme/agents/jax/bc/__init__.py b/acme/agents/jax/bc/__init__.py
new file mode 100644
index 0000000000..df8ebd3195
--- /dev/null
+++ b/acme/agents/jax/bc/__init__.py
@@ -0,0 +1,18 @@
+# python3
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementation of a behavior cloning (BC) agent."""
+
+from acme.agents.jax.bc.learning import BCLearner
diff --git a/acme/agents/jax/bc/learning.py b/acme/agents/jax/bc/learning.py
new file mode 100644
index 0000000000..2b7356fa28
--- /dev/null
+++ b/acme/agents/jax/bc/learning.py
@@ -0,0 +1,131 @@
+# Lint as: python3
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BC Learner implementation."""
+
+from typing import Dict, List, NamedTuple, Tuple, Callable
+
+import acme
+from acme.jax import utils
+from acme.utils import counting
+from acme.utils import loggers
+from dm_env import specs
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import optax
+import reverb
+import rlax
+import tensorflow as tf
+
+LossFn = Callable[[jnp.DeviceArray, jnp.DeviceArray], jnp.DeviceArray]
+
+
+def _sparse_categorical_cross_entropy(
+    labels: jnp.DeviceArray, logits: jnp.DeviceArray) -> jnp.DeviceArray:
+  """Converts labels to one-hot and computes categorical cross-entropy."""
+  one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
+  cce = jax.vmap(rlax.categorical_cross_entropy)
+  return cce(one_hot_labels, logits)
+
+
+class TrainingState(NamedTuple):
+  """Holds the agent's training state."""
+  params: hk.Params
+  opt_state: optax.OptState
+  steps: int
+
+
+class BCLearner(acme.Learner, acme.Saveable):
+  """BC learner.
+
+  This is the learning component of a BC agent. IE it takes a dataset as input
+  and implements update functionality to learn from this dataset.
+  """
+
+  def __init__(self,
+               network: hk.Transformed,
+               obs_spec: specs.Array,
+               optimizer: optax.GradientTransformation,
+               rng: hk.PRNGSequence,
+               dataset: tf.data.Dataset,
+               loss_fn: LossFn = _sparse_categorical_cross_entropy,
+               counter: counting.Counter = None,
+               logger: loggers.Logger = None):
+    """Initializes the learner."""
+
+    def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray:
+      # Pull out the data needed for updates.
+      o_tm1, a_tm1, r_t, d_t, o_t = sample.data
+      del r_t, d_t, o_t
+      logits = network.apply(params, o_tm1)
+      return jnp.mean(loss_fn(a_tm1, logits))
+
+    def sgd_step(
+        state: TrainingState, sample: reverb.ReplaySample
+    ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
+      """Do a step of SGD."""
+      grad_fn = jax.value_and_grad(loss)
+      loss_value, gradients = grad_fn(state.params, sample)
+      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
+      new_params = optax.apply_updates(state.params, updates)
+
+      steps = state.steps + 1
+
+      new_state = TrainingState(
+          params=new_params, opt_state=new_opt_state, steps=steps)
+
+      # Compute the global norm of the gradients for logging.
+      global_gradient_norm = optax.global_norm(gradients)
+      fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}
+
+      return new_state, fetches
+
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)
+
+    # Get an iterator over the dataset.
+    self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types
+    # TODO(b/155086959): Fix type stubs and remove.
+
+    # Initialise parameters and optimiser state.
+    initial_params = network.init(
+        next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
+    initial_opt_state = optimizer.init(initial_params)
+
+    self._state = TrainingState(
+        params=initial_params, opt_state=initial_opt_state, steps=0)
+
+    self._sgd_step = jax.jit(sgd_step)
+
+  def step(self):
+    batch = next(self._iterator)
+    # Do a batch of SGD.
+    self._state, result = self._sgd_step(self._state, batch)
+
+    # Update our counts and record it.
+    counts = self._counter.increment(steps=1)
+    result.update(counts)
+
+    self._logger.write(result)
+
+  def get_variables(self, names: List[str]) -> List[hk.Params]:
+    return [self._state.params]
+
+  def save(self) -> TrainingState:
+    return self._state
+
+  def restore(self, state: TrainingState):
+    self._state = state
diff --git a/acme/agents/tf/bc/learning.py b/acme/agents/tf/bc/learning.py
index 9bd6ee61cb..4701414b54 100644
--- a/acme/agents/tf/bc/learning.py
+++ b/acme/agents/tf/bc/learning.py
@@ -31,8 +31,7 @@ class BCLearner(acme.Learner, tf2_savers.TFSaveable):
   """BC learner.
 
   This is the learning component of a BC agent. IE it takes a dataset as input
-  and implements update functionality to learn from this dataset. Optionally
-  it takes a replay client as well to allow for updating of priorities.
+  and implements update functionality to learn from this dataset.
   """
 
   def __init__(self,
@@ -45,8 +44,8 @@ def __init__(self,
     """Initializes the learner.
 
     Args:
-      network: the online Q network (the one being optimized)
-      learning_rate: learning rate for the q-network update.
+      network: the BC network (the one being optimized)
+      learning_rate: learning rate for the cross-entropy update.
       dataset: dataset to learn from.
       counter: Counter object for (potentially distributed) counting.
       logger: Logger object for writing logs to.
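
For reference, the following standalone sketch (not part of the patch) shows what the _sparse_categorical_cross_entropy loss above computes: integer demonstration actions are turned into one-hot targets, scored with a softmax cross-entropy, and the learner's loss closure then averages the result over the batch. The toy action and logit values are illustrative assumptions.

import jax
import jax.numpy as jnp
import rlax


def sparse_categorical_cross_entropy(labels, logits):
  # Mirrors the helper in acme/agents/jax/bc/learning.py.
  one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
  return jax.vmap(rlax.categorical_cross_entropy)(one_hot_labels, logits)


actions = jnp.array([0, 2, 1])             # demonstrated actions (toy values)
logits = jnp.array([[2.0, 0.0, 0.0, 0.0],  # policy logits over 4 actions
                    [0.0, 0.0, 3.0, 0.0],
                    [0.0, 1.0, 0.0, 0.0]])
per_example_loss = sparse_categorical_cross_entropy(actions, logits)
print(jnp.mean(per_example_loss))  # Low: the logits already favour the actions.
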
diff --git a/examples/bsuite/run_bc_jax.py b/examples/bsuite/run_bc_jax.py
new file mode 100644
index 0000000000..cc31aaed08
--- /dev/null
+++ b/examples/bsuite/run_bc_jax.py
@@ -0,0 +1,210 @@
+# python3
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An example BC running on BSuite."""
+
+import functools
+import operator
+from typing import Callable
+
+from absl import app
+from absl import flags
+import acme
+from acme import specs
+from acme import types
+from acme.agents.jax import actors
+from acme.agents.jax.bc import learning
+from acme.agents.tf.dqfd import bsuite_demonstrations
+from acme.jax import variable_utils
+from acme.utils import counting
+from acme.utils import loggers
+from acme.wrappers import single_precision
+import bsuite
+import haiku as hk
+import jax.numpy as jnp
+import optax
+import reverb
+import rlax
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import tree
+
+# Bsuite flags
+flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'Bsuite id.')
+flags.DEFINE_string('results_dir', '/tmp/bsuite', 'CSV results directory.')
+flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite csv results.')
+
+# Agent flags
+flags.DEFINE_float('learning_rate', 2e-4, 'Learning rate.')
+flags.DEFINE_integer('batch_size', 16, 'Batch size.')
+flags.DEFINE_float('epsilon', 0., 'Epsilon for the epsilon greedy in the env.')
+flags.DEFINE_integer('evaluate_every', 100, 'Evaluation period.')
+flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.')
+flags.DEFINE_integer('seed', 42, 'Random seed for learner and evaluator.')
+
+FLAGS = flags.FLAGS
+
+
+def make_policy_network(
+    action_spec: specs.DiscreteArray
+) -> Callable[[jnp.DeviceArray], jnp.DeviceArray]:
+  """Make the policy network."""
+
+  def network(x: jnp.DeviceArray) -> jnp.DeviceArray:
+    mlp = hk.Sequential(
+        [hk.Flatten(),
+         hk.nets.MLP([64, 64, action_spec.num_values])])
+    return mlp(x)
+
+  return network
+
+
+# TODO(b/152733199): Move this function to acme utils.
+def _n_step_transition_from_episode(observations: types.NestedTensor,
+                                    actions: tf.Tensor, rewards: tf.Tensor,
+                                    discounts: tf.Tensor, n_step: int,
+                                    additional_discount: float):
+  """Produce Reverb-like N-step transition from a full episode.
+
+  Observations, actions, rewards and discounts have the same length. This
+  function will ignore the first reward and discount and the last action.
+
+  Args:
+    observations: [L, ...] Tensor.
+    actions: [L, ...] Tensor.
+    rewards: [L] Tensor.
+    discounts: [L] Tensor.
+    n_step: number of steps to squash into a single transition.
+    additional_discount: discount to use for TD updates.
+
+  Returns:
+    (o_t, a_t, r_t, d_t, o_tp1) tuple.
+  """
+
+  max_index = tf.shape(rewards)[0] - 1
+  first = tf.random.uniform(
+      shape=(), minval=0, maxval=max_index - 1, dtype=tf.int32)
+  last = tf.minimum(first + n_step, max_index)
+
+  o_t = tree.map_structure(operator.itemgetter(first), observations)
+  a_t = tree.map_structure(operator.itemgetter(first), actions)
+  o_tp1 = tree.map_structure(operator.itemgetter(last), observations)
+
+  # 0, 1, ..., n-1.
+  discount_range = tf.cast(tf.range(last - first), tf.float32)
+  # 1, g, ..., g^{n-1}.
+  additional_discounts = tf.pow(additional_discount, discount_range)
+  # 1, d_t, d_t * d_{t+1}, ..., d_t * ... * d_{t+n-2}.
+  discounts = tf.concat([[1.], tf.math.cumprod(discounts[first:last - 1])], 0)
+  # 1, g * d_t, ..., g^{n-1} * d_t * ... * d_{t+n-2}.
+  discounts *= additional_discounts
+  # r_t + g * d_t * r_{t+1} + ... + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1}
+  # We have to shift rewards by one so last=max_index corresponds to transitions
+  # that include the last reward.
+  r_t = tf.reduce_sum(rewards[first + 1:last + 1] * discounts)
+
+  # g^{n-1} * d_{t} * ... * d_{t+n-1}.
+  d_t = discounts[-1]
+
+  # Reverb requires every sample to be given a key and priority.
+  # In the supervised learning case for BC, neither of those will be used.
+  # We set the key to `0` and the priorities and probabilities to `1`, but
+  # that should not matter much.
+  key = tf.constant(0, tf.uint64)
+  probability = tf.constant(1.0, tf.float64)
+  table_size = tf.constant(1, tf.int64)
+  priority = tf.constant(1.0, tf.float64)
+  info = reverb.SampleInfo(
+      key=key,
+      probability=probability,
+      table_size=table_size,
+      priority=priority,
+  )
+
+  return reverb.ReplaySample(info=info, data=(o_t, a_t, r_t, d_t, o_tp1))
+
+
+def main(_):
+  # Create an environment and grab the spec.
+  raw_environment = bsuite.load_and_record_to_csv(
+      bsuite_id=FLAGS.bsuite_id,
+      results_dir=FLAGS.results_dir,
+      overwrite=FLAGS.overwrite,
+  )
+  environment = single_precision.SinglePrecisionWrapper(raw_environment)
+  environment_spec = specs.make_environment_spec(environment)
+
+  # Build demonstration dataset.
+  if hasattr(raw_environment, 'raw_env'):
+    raw_environment = raw_environment.raw_env
+
+  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
+  # Combine with demonstration dataset.
+  transition = functools.partial(
+      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
+
+  dataset = batch_dataset.map(transition)
+
+  # Batch and prefetch.
+  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+  dataset = tfds.as_numpy(dataset)
+
+  # Create the networks to optimize.
+  policy_network = make_policy_network(environment_spec.actions)
+  policy_network = hk.without_apply_rng(hk.transform(policy_network))
+
+  # If the agent is non-autoregressive use epsilon=0 which will be a greedy
+  # policy.
+  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
+                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
+    action_values = policy_network.apply(params, observation)
+    return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)
+
+  counter = counting.Counter()
+  learner_counter = counting.Counter(counter, prefix='learner')
+
+  # The learner updates the parameters (and initializes them).
+  learner = learning.BCLearner(
+      network=policy_network,
+      optimizer=optax.adam(FLAGS.learning_rate),
+      obs_spec=environment.observation_spec(),
+      dataset=dataset,
+      counter=learner_counter,
+      rng=hk.PRNGSequence(FLAGS.seed))
+
+  # Create the actor which defines how we take actions.
+  variable_client = variable_utils.VariableClient(learner, '')
+  evaluator = actors.FeedForwardActor(
+      evaluator_network,
+      variable_client=variable_client,
+      rng=hk.PRNGSequence(FLAGS.seed))
+
+  eval_loop = acme.EnvironmentLoop(
+      environment=environment,
+      actor=evaluator,
+      counter=counter,
+      logger=loggers.TerminalLogger('evaluation', time_delta=1.))
+
+  # Run the environment loop.
+  while True:
+    for _ in range(FLAGS.evaluate_every):
+      learner.step()
+    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
+    eval_loop.run(FLAGS.evaluation_episodes)
+
+
+if __name__ == '__main__':
+  app.run(main)
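
For a quick local smoke test of the new learner, the Reverb/bsuite pipeline used in run_bc_jax.py can be replaced by synthetic batches. The sketch below is illustrative and not part of the patch: the observation size, action count, batch size and the generator are assumptions, it relies on BCLearner only calling iter()/next() on its dataset argument, and the SampleInfo field names assume the reverb version current at the time of this change.

import haiku as hk
import numpy as np
import optax
import reverb
from acme.agents.jax.bc import learning
from dm_env import specs

OBS_DIM, NUM_ACTIONS, BATCH_SIZE = 8, 4, 16


def policy(x):
  # Small MLP policy, analogous to make_policy_network in run_bc_jax.py.
  return hk.Sequential([hk.Flatten(), hk.nets.MLP([32, 32, NUM_ACTIONS])])(x)


def synthetic_dataset():
  """Yields ReplaySample batches shaped like the (o_tm1, a_tm1, r_t, d_t, o_t) data."""
  rng = np.random.RandomState(0)
  # The info fields mirror the dummy ones built in run_bc_jax.py; the learner
  # never reads them.
  info = reverb.SampleInfo(
      key=np.zeros(BATCH_SIZE, np.uint64),
      probability=np.ones(BATCH_SIZE, np.float64),
      table_size=np.ones(BATCH_SIZE, np.int64),
      priority=np.ones(BATCH_SIZE, np.float64))
  while True:
    o_tm1 = rng.randn(BATCH_SIZE, OBS_DIM).astype(np.float32)
    a_tm1 = rng.randint(NUM_ACTIONS, size=BATCH_SIZE).astype(np.int32)
    r_t = np.zeros(BATCH_SIZE, np.float32)
    d_t = np.ones(BATCH_SIZE, np.float32)
    o_t = rng.randn(BATCH_SIZE, OBS_DIM).astype(np.float32)
    yield reverb.ReplaySample(info=info, data=(o_tm1, a_tm1, r_t, d_t, o_t))


learner = learning.BCLearner(
    network=hk.without_apply_rng(hk.transform(policy)),
    obs_spec=specs.Array(shape=(OBS_DIM,), dtype=np.float32),
    optimizer=optax.adam(1e-3),
    rng=hk.PRNGSequence(0),
    dataset=synthetic_dataset())

for _ in range(10):
  learner.step()  # Logs 'loss' and 'gradient_norm' via the default TerminalLogger.

With real demonstrations, the tf.data pipeline and the evaluator/EnvironmentLoop pieces from run_bc_jax.py take the place of the generator and the bare training loop above.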