Skip to content

Commit

Permalink
Add multi-head critic network.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 327418649
Change-Id: I3d46f5340926cd7c9d37fbfd65786ebb8e9c4159
  • Loading branch information
Acme Contributor authored and copybara-github committed Aug 19, 2020
1 parent 7eeeb24 commit 61bad92
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
1 change: 1 addition & 0 deletions acme/tf/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from acme.tf.networks.distributional import UnivariateGaussianMixture
from acme.tf.networks.distributions import DiscreteValuedDistribution
from acme.tf.networks.duelling import DuellingMLP
from acme.tf.networks.multihead import Multihead
from acme.tf.networks.multiplexers import CriticMultiplexer
from acme.tf.networks.noise import ClippedGaussian
from acme.tf.networks.policy_value import PolicyValueHead
Expand Down
53 changes: 53 additions & 0 deletions acme/tf/networks/multihead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# python3
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multihead networks apply separate networks to the input."""

from typing import Callable, Union, Sequence

from acme import types

import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor],
tf.Tensor]]


class Multihead(snt.Module):
"""Multi-head network module.
This takes as input a list of N `network_heads`, and returns another network
whose output is the stacked outputs of each of these network heads separately
applied to the module input. The dimension of the output is [..., N].
"""

def __init__(self,
network_heads: Sequence[TensorTransformation]):
if not network_heads:
raise ValueError('Must specify non-empty, non-None critic_network_heads.')
self._network_heads = network_heads
super().__init__(name='multihead')

def __call__(self,
inputs: tf.Tensor) -> Union[tf.Tensor, Sequence[tf.Tensor]]:
outputs = [network_head(inputs) for network_head in self._network_heads]
if isinstance(outputs[0], tfd.Distribution):
# Cannot stack distributions
return outputs
outputs = tf.stack(outputs, axis=-1)
return outputs
7 changes: 7 additions & 0 deletions acme/tf/networks/multiplexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor],
tf.Tensor]]

Expand Down Expand Up @@ -63,6 +65,11 @@ def __call__(self,
if self._action_network:
action = self._action_network(action)

if hasattr(observation, 'dtype') and hasattr(action, 'dtype'):
if observation.dtype != action.dtype:
# Observation and action must be the same type for concat to work
action = tf.cast(action, observation.dtype)

# Concat observations and actions, with one batch dimension.
outputs = tf2_utils.batch_concat([observation, action])

Expand Down

0 comments on commit 61bad92

Please sign in to comment.