From 4e358e5ae0249d17472701516e458b292655381f Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Wed, 1 Mar 2023 01:00:03 -0800 Subject: [PATCH] Update pytypes. PiperOrigin-RevId: 513161806 --- chex/_src/pytypes.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/chex/_src/pytypes.py b/chex/_src/pytypes.py index c1d642df..53f4ac46 100644 --- a/chex/_src/pytypes.py +++ b/chex/_src/pytypes.py @@ -14,37 +14,31 @@ # ============================================================================== """Type definitions to use for type annotations.""" -from typing import Any, Iterable, Mapping, Union +from typing import Any, TypeAlias, Union import jax import jax.numpy as jnp import numpy as np # Special types of arrays. -ArrayBatched = jax.interpreters.batching.BatchTracer -ArrayNumpy = np.ndarray -ArraySharded = jax.interpreters.pxla.ShardedDeviceArray +ArrayBatched: TypeAlias = jax.interpreters.batching.BatchTracer +ArrayNumpy: TypeAlias = np.ndarray +ArraySharded: TypeAlias = jax.interpreters.pxla.ShardedDeviceArray # For instance checking, use `isinstance(x, jax.Array)`. -if hasattr(jax, 'Array'): - ArrayDevice = jax.Array # jax >= 0.3.20 -elif hasattr(jax.interpreters.xla, '_DeviceArray'): # 0.2.5 < jax < 0.3.20 - ArrayDevice = jax.interpreters.xla._DeviceArray # pylint:disable=protected-access -else: # jax <= 0.2.5 - ArrayDevice = jax.interpreters.xla.DeviceArray +ArrayDevice: TypeAlias = jax.Array # jax >= 0.3.20 # Generic array type. -Array = Union[ArrayDevice, ArrayNumpy, ArrayBatched, ArraySharded] +Array = Union[jax.Array, np.ndarray] +ArrayLike: TypeAlias = jax.typing.ArrayLike # A tree of generic arrays. -ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] +ArrayTree = Any +# Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] # Other types. Scalar = Union[float, int] Numeric = Union[Array, Scalar] -Shape = jax.core.Shape -PRNGKey = jax.random.KeyArray -PyTreeDef = type(jax.tree_util.tree_structure(None)) -if hasattr(jax, 'Device'): - Device = jax.Device # jax >= 0.4.3 -else: - Device = jax.lib.xla_extension.Device +Shape: TypeAlias = jax.core.Shape +PRNGKey: TypeAlias = jax.random.KeyArray +PyTreeDef: TypeAlias = jax.tree_util.PyTreeDef +Device: TypeAlias = jax.Device # jax >= 0.4.3 ArrayDType = type(jnp.float32)