[JAX] Fix incorrect type annotations.
An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases.

PiperOrigin-RevId: 557814902
Change-Id: Id3975ca545f6051ecc9526134251a87ca59b483a
hawkinsp authored and Copybara-Service committed Aug 17, 2023
1 parent 76af644 commit c7690d1
Showing 4 changed files with 26 additions and 24 deletions.
5 changes: 3 additions & 2 deletions acme/agents/jax/dqn/config.py
@@ -79,8 +79,9 @@ class DQNConfig:
   num_sgd_steps_per_step: int = 1
 
 
-def logspace_epsilons(num_epsilons: int, epsilon: float = 0.017
-                      ) -> Sequence[float]:
+def logspace_epsilons(
+    num_epsilons: int, epsilon: float = 0.017
+) -> Union[Sequence[float], jnp.ndarray]:
   """`num_epsilons` of logspace-distributed values, with median `epsilon`."""
   if num_epsilons <= 1:
     return (epsilon,)
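
Note: a minimal runnable sketch (not the Acme implementation; the logspace bounds below are illustrative) of why the widened return annotation is needed. The degenerate single-epsilon case returns a plain Python tuple, while the general case returns a jax Array produced by jnp.logspace, which pytype no longer treats as a Sequence[float].

from typing import Sequence, Union

import jax.numpy as jnp


def logspace_epsilons_sketch(
    num_epsilons: int, epsilon: float = 0.017
) -> Union[Sequence[float], jnp.ndarray]:
  # Degenerate case: a Python tuple, i.e. a Sequence[float].
  if num_epsilons <= 1:
    return (epsilon,)
  # General case: jnp.logspace returns a jax.Array, not a Sequence[float],
  # hence the Union in the annotation.
  return jnp.logspace(
      jnp.log10(epsilon) - 1.0, jnp.log10(epsilon) + 1.0, num_epsilons)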
4 changes: 2 additions & 2 deletions acme/agents/jax/r2d2/learning.py
@@ -78,8 +78,8 @@ def loss(
     params: networks_lib.Params,
     target_params: networks_lib.Params,
     key_grad: networks_lib.PRNGKey,
-    sample: reverb.ReplaySample
-) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+    sample: reverb.ReplaySample,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
   """Computes mean transformed N-step loss for a batch of sequences."""
 
   # Get core state & warm it up on observations for a burn-in period.
28 changes: 14 additions & 14 deletions acme/jax/losses/mpo.py
@@ -50,24 +50,24 @@ class MPOParams(NamedTuple):
 
 class MPOStats(NamedTuple):
   """NamedTuple to store loss statistics."""
-  dual_alpha_mean: float
-  dual_alpha_stddev: float
-  dual_temperature: float
+  dual_alpha_mean: Union[float, jnp.ndarray]
+  dual_alpha_stddev: Union[float, jnp.ndarray]
+  dual_temperature: Union[float, jnp.ndarray]
 
-  loss_policy: float
-  loss_alpha: float
-  loss_temperature: float
-  kl_q_rel: float
+  loss_policy: Union[float, jnp.ndarray]
+  loss_alpha: Union[float, jnp.ndarray]
+  loss_temperature: Union[float, jnp.ndarray]
+  kl_q_rel: Union[float, jnp.ndarray]
 
-  kl_mean_rel: float
-  kl_stddev_rel: float
+  kl_mean_rel: Union[float, jnp.ndarray]
+  kl_stddev_rel: Union[float, jnp.ndarray]
 
-  q_min: float
-  q_max: float
+  q_min: Union[float, jnp.ndarray]
+  q_max: Union[float, jnp.ndarray]
 
-  pi_stddev_min: float
-  pi_stddev_max: float
-  pi_stddev_cond: float
+  pi_stddev_min: Union[float, jnp.ndarray]
+  pi_stddev_max: Union[float, jnp.ndarray]
+  pi_stddev_cond: Union[float, jnp.ndarray]
 
   penalty_kl_q_rel: Optional[float] = None
 
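For context, a minimal sketch (hypothetical values, not the MPO loss itself) of why these fields hold jax Arrays rather than Python floats: the statistics are presumably produced by jax.numpy reductions inside the loss computation, and pytype now types those results as jax.Array.

import jax.numpy as jnp

q_values = jnp.array([0.1, 0.5, 0.9])
q_min = jnp.min(q_values)  # A 0-d jax.Array, not a Python float.
q_max = jnp.max(q_values)
assert isinstance(q_min, jnp.ndarray)
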
13 changes: 7 additions & 6 deletions acme/jax/networks/atari.py
@@ -22,7 +22,7 @@
 - X?: X is optional (e.g. optional batch/sequence dimension).
 """
-from typing import Optional, Tuple, Sequence
+from typing import Any, Optional, Sequence, Tuple
 
 from acme.jax.networks import base
 from acme.jax.networks import duelling
@@ -120,9 +120,9 @@ def __init__(self, num_actions: int):
     self._head = policy_value.PolicyValueHead(num_actions)
     self._num_actions = num_actions
 
-  def __call__(self, inputs: observation_action_reward.OAR,
-               state: hk.LSTMState) -> base.LSTMOutputs:
-
+  def __call__(
+      self, inputs: observation_action_reward.OAR, state: hk.LSTMState
+  ) -> Any:
     embeddings = self._embed(inputs)  # [B?, D+A+1]
     embeddings, new_state = self._core(embeddings, state)
     logits, value = self._head(embeddings)  # logits: [B?, A], value: [B?, 1]
@@ -133,8 +133,9 @@ def initial_state(self, batch_size: Optional[int],
                     **unused_kwargs) -> hk.LSTMState:
     return self._core.initial_state(batch_size)
 
-  def unroll(self, inputs: observation_action_reward.OAR,
-             state: hk.LSTMState) -> base.LSTMOutputs:
+  def unroll(
+      self, inputs: observation_action_reward.OAR, state: hk.LSTMState
+  ) -> Any:
     """Efficient unroll that applies embeddings, MLP, & convnet in one pass."""
     embeddings = self._embed(inputs)
     embeddings, new_states = hk.static_unroll(self._core, embeddings, state)
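
As a side note, a standalone sketch (not the Atari network; the layer sizes are arbitrary) of the hk.static_unroll pattern used above: the recurrent core is applied over the leading time axis in one pass, returning the stacked per-step outputs together with the final state.

import haiku as hk
import jax
import jax.numpy as jnp


def unroll_fn(sequence):
  core = hk.LSTM(hidden_size=8)
  initial_state = core.initial_state(batch_size=sequence.shape[1])
  # Applies the core across the leading (time) axis and returns the
  # stacked outputs plus the final recurrent state.
  outputs, final_state = hk.static_unroll(core, sequence, initial_state)
  return outputs, final_state


unroll = hk.without_apply_rng(hk.transform(unroll_fn))
sequence = jnp.zeros((5, 2, 3))  # [T, B, D], time-major.
params = unroll.init(jax.random.PRNGKey(0), sequence)
outputs, final_state = unroll.apply(params, sequence)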
