Skip to content

Commit

Permalink
Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 691929385
  • Loading branch information
dougalm authored and Google-ML-Automation committed Oct 31, 2024
1 parent 8536eca commit 48f24b6
Show file tree
Hide file tree
Showing 22 changed files with 83 additions and 168 deletions.
9 changes: 5 additions & 4 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

UnshapedArray = core.UnshapedArray
ShapedArray = core.ShapedArray
ConcreteArray = core.ConcreteArray
AbstractToken = core.AbstractToken
abstract_token = core.abstract_token
canonicalize_shape = core.canonicalize_shape
Expand All @@ -47,8 +45,11 @@
array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic

def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
weak_type=weak_type)
weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type
dtype = dtypes.canonicalize_dtype(np.result_type(val))
dtypes.check_valid_dtype(dtype)
sharding = core._get_abstract_sharding(val)
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)

def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from jax._src import traceback_util
from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray
from jax._src.core import eval_jaxpr, ShapedArray
from jax._src.api_util import (
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, donation_vector,
Expand Down Expand Up @@ -2188,9 +2188,9 @@ def _infer_src_sharding(src, x) -> Sharding | None:
if isinstance(x, array.ArrayImpl):
return x.sharding
elif isinstance(x, core.Tracer):
aval = core.get_aval(x)
if isinstance(aval, ConcreteArray) and isinstance(aval.val, array.ArrayImpl):
return aval.val.sharding
val = x.to_concrete_value()
if val is not None and isinstance(val, array.ArrayImpl):
return val.sharding
return None


Expand Down
2 changes: 0 additions & 2 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
global_aval, out_sharding, committed=committed, _skip_checks=True
)
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler

# Only used for Arrays that come out of pmap.
def _array_local_result_handler(aval, sharding, indices):
Expand All @@ -1197,7 +1196,6 @@ def _array_local_result_handler(aval, sharding, indices):
aval, sharding, committed=True, _skip_checks=True
)
pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler


# Token handlers
Expand Down
108 changes: 26 additions & 82 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
import functools
from functools import partial, partialmethod, total_ordering
from functools import partial, total_ordering
import gc
import inspect
import itertools as it
Expand Down Expand Up @@ -696,6 +696,10 @@ def __reversed__(self):
def __len__(self):
return self.aval._len(self)

def to_concrete_value(self):
# Should return the concrete value if there is one, or else None.
return None

@property
def sharding(self):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
Expand Down Expand Up @@ -739,10 +743,12 @@ def get_referent(self) -> Any:
return self # Override for object equivalence checking

def __bool__(self):
if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_bool_conversion(self)
return self.aval._bool(self)

def __int__(self):
if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_scalar_conversion(self)
return self.aval._int(self)

Expand All @@ -755,14 +761,17 @@ def __complex__(self):
return self.aval._complex(self)

def __hex__(self):
if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._hex(self)

def __oct__(self):
if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._oct(self)

def __index__(self):
if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types
check_integer_conversion(self)
return self.aval._index(self)

Expand Down Expand Up @@ -1393,12 +1402,16 @@ def get_aval(x):
else:
return concrete_aval(x)

def get_type(x):
aval = get_aval(x)
if isinstance(aval, ConcreteArray):
return raise_to_shaped(aval)
get_type = get_aval

def is_concrete(x):
return to_concrete_value(x) is not None

def to_concrete_value(x):
if isinstance(x, Tracer):
return x.to_concrete_value()
else:
return aval
return x

def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
Expand All @@ -1423,10 +1436,11 @@ def concrete_or_error(force: Any, val: Any, context=""):
if force is None:
force = lambda x: x
if isinstance(val, Tracer):
if isinstance(val.aval, ConcreteArray):
return force(val.aval.val)
else:
maybe_concrete = val.to_concrete_value()
if maybe_concrete is None:
raise ConcretizationTypeError(val, context)
else:
return force(maybe_concrete)
else:
return force(val)

Expand Down Expand Up @@ -1578,7 +1592,7 @@ def _invalid_shape_error(shape: Shape, context: str=""):
msg += f" {context}."
if not config.dynamic_shapes.value and any(
isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not isinstance(get_aval(x), ConcreteArray) for x in shape):
and not is_concrete(x) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
for x in shape:
Expand Down Expand Up @@ -1677,10 +1691,6 @@ def _get_shape_sharding_str(shape, spec):
else:
yield f"{s1}@{s2}"


def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)

def _get_abstract_sharding(val):
from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error

Expand All @@ -1690,59 +1700,6 @@ def _get_abstract_sharding(val):
val.sharding._normalized_spec(val.ndim))
return None

class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0

def __init__(self, dtype, val, weak_type=None):
super().__init__(
np.shape(val), dtype,
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type,
sharding=_get_abstract_sharding(val))
dtypes.check_valid_dtype(self.dtype)
# Note: canonicalized self.dtype doesn't necessarily match self.val
assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
self.val = val

def update(self, dtype=None, val=None, weak_type=None):
dtype = self.dtype if dtype is None else dtype
val = self.val if val is None else val
weak_type = self.weak_type if weak_type is None else weak_type
return ConcreteArray(dtype, val, weak_type)

def __eq__(self, other):
if (type(self) is type(other) and self.dtype == other.dtype
and self.shape == other.shape and self.weak_type == other.weak_type):
with eval_context(): # in case self.val is an Array
return (self.val == other.val).all()
else:
return False

def __hash__(self):
return id(self.val)

def join(self, other) -> AbstractValue:
if self == other:
return self
elif self.shape == other.shape and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
else:
raise TypeError(self, other)

def str_short(self, short_dtypes=False) -> str:
dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
return f'{self.val}, dtype={dt_str}'

_bool = partialmethod(_forward_to_value, bool)
_int = partialmethod(_forward_to_value, int)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
_index = partialmethod(_forward_to_value, operator.index)

_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)

def primal_dtype_to_tangent_dtype(primal_dtype):
if isinstance(primal_dtype, dtypes.ExtendedDType):
return primal_dtype._rules.tangent_dtype(primal_dtype)
Expand Down Expand Up @@ -1817,14 +1774,6 @@ def to_tangent_aval(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)

class DConcreteArray(DShapedArray):
__slots__ = ['val']
array_abstraction_level = 1
def __init__(self, shape, dtype, weak_type, val):
super().__init__(shape, dtype, weak_type)
self.val = val


pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}


Expand Down Expand Up @@ -1881,8 +1830,7 @@ def data(self):


pytype_aval_mappings[DArray] = \
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
x._data)
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)

@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
Expand Down Expand Up @@ -1984,10 +1932,7 @@ def _shaped_array_mapping(aval, weak_type):
AbstractToken: lambda aval, _: aval,
Bot: lambda aval, _: aval,
ShapedArray: _shaped_array_mapping,
DShapedArray: lambda aval, _: aval,
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type
),
DShapedArray: lambda aval, _: aval
}

### Operations on shapes and dimension sizes.
Expand Down Expand Up @@ -2323,7 +2268,6 @@ def _unmap_dshaped_array(
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
ConcreteArray: (_map_shaped_array, _unmap_shaped_array),
AbstractToken: (lambda _, __, a: a, lambda _, __, ___, a: a)
}

Expand Down
3 changes: 3 additions & 0 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,9 @@ def full_lower(self):
else:
return self

def to_concrete_value(self):
return core.to_concrete_value(self.primal)

def _primal_tangent_shapes_match(primal, tangent):
if type(tangent) is not Zero:
primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
Expand Down
1 change: 0 additions & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes:
raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err

ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get()
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types

Expand Down
3 changes: 1 addition & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
fun_sourceinfo)
from jax._src.core import (Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval,
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, Var, DropVar, raise_to_shaped, Atom,
Var, DropVar, raise_to_shaped, Atom,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent, JaxprEqnContext)
Expand Down Expand Up @@ -299,7 +299,6 @@ def process_call(self, primitive, f, tracers, params):
# With dynamic shapes, we may need to substitute Tracers into avals.
out_tracers = []
for aval, _ in out_type:
assert not isinstance(aval, ConcreteArray)
if type(aval) is DShapedArray:
shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val]
if type(d) is InDBIdx else d for d in aval.shape]
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.core import ShapedArray
from jax._src.util import safe_zip, safe_map

from jax._src.typing import Shape
Expand Down Expand Up @@ -101,7 +101,6 @@ def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]:
_xla_shape_handlers: dict[type[core.AbstractValue],
Callable[[Any], Sequence[xc.Shape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)

Expand Down
7 changes: 3 additions & 4 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from jax._src import util
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.core import raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand Down Expand Up @@ -130,8 +130,7 @@ def switch(index, branches, *operands):
hi = np.array(len(branches) - 1, np.int32)
index = lax.clamp(lo, index, hi)

if (config.disable_jit.value and
isinstance(core.get_aval(index), ConcreteArray)):
if (config.disable_jit.value and core.is_concrete(index)):
return branches[int(index)](*operands)

ops, ops_tree = tree_flatten(operands)
Expand Down Expand Up @@ -220,7 +219,7 @@ def cond(pred, true_fun, false_fun, *operands):
msg = ("Pred type must be either boolean or number, got {}.")
raise TypeError(msg.format(pred_dtype))

if config.disable_jit.value and isinstance(core.get_aval(pred), ConcreteArray):
if config.disable_jit.value and core.is_concrete(pred):
if pred:
return true_fun(*operands)
else:
Expand Down
7 changes: 3 additions & 4 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from jax._src import state
from jax._src import util
from jax._src.api_util import shaped_abstractify
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand Down Expand Up @@ -2015,12 +2015,11 @@ def fori_loop(lower, upper, body_fun, init_val):

# If we can specialize on the trip count, call scan instead of a while_loop
# to enable efficient reverse-mode differentiation.
if (isinstance(core.get_aval(lower), ConcreteArray) and
isinstance(core.get_aval(upper), ConcreteArray)):
if core.is_concrete(lower) and core.is_concrete(upper):
try:
lower_ = int(lower)
upper_ = int(upper)
except TypeError:
except (TypeError, core.InconclusiveDimensionOperation):
use_scan = False
else:
use_scan = True
Expand Down
Loading

0 comments on commit 48f24b6

Please sign in to comment.