omnistaging #3370

Merged · 1 commit · Jul 30, 2020
13 changes: 12 additions & 1 deletion .github/workflows/ci-build.yaml
@@ -42,19 +42,28 @@ jobs:
- python-version: 3.6
os: ubuntu-latest
enable-x64: 0
enable-omnistaging: 0
package-overrides: "none"
num_generated_cases: 25
- python-version: 3.7
os: ubuntu-latest
enable-x64: 1
enable-omnistaging: 0
package-overrides: "none"
num_generated_cases: 25
- python-version: 3.6
os: ubuntu-latest
enable-x64: 1
enable-omnistaging: 0
# Test with numpy version that matches Google-internal version
package-overrides: "numpy==1.16.4"
num_generated_cases: 10
- python-version: 3.7
os: ubuntu-latest
enable-x64: 0
enable-omnistaging: 1
package-overrides: "none"
num_generated_cases: 8
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
@@ -73,11 +82,13 @@ jobs:
env:
JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
JAX_OMNISTAGING: ${{ matrix.enable-omnistaging }}
run: |
pip install -e .
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
if [ $JAX_ENABLE_X64 = 0 ]; then
echo "JAX_OMNISTAGING=$JAX_OMNISTAGING"
if [ $JAX_ENABLE_X64 = 0 -a $JAX_OMNISTAGING = 0 ]; then
pytest -n auto jax/experimental/jax2tf/tests
fi
pytest -n auto tests examples
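
For context, a hedged sketch of how the new `JAX_OMNISTAGING` variable set in this workflow reaches the library (it feeds the `jax_omnistaging` flag defined in `jax/config.py` further down; the explicit flag-parsing call here is an assumption, not part of the diff):

```python
# Hedged sketch, assuming a JAX build that includes this PR.
import os
os.environ["JAX_OMNISTAGING"] = "1"   # must be set before importing jax

from jax.config import config
config.parse_flags_with_absl()        # jax_omnistaging defaults to the env var,
                                      # so this calls enable_omnistaging()
print(config.omnistaging_enabled)     # True
```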
6 changes: 5 additions & 1 deletion jax/ad_util.py
@@ -13,6 +13,7 @@
# limitations under the License.


from jax import core
from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
valid_jaxtype, raise_to_shaped, get_aval)
from .tree_util import register_pytree_node
@@ -27,7 +28,10 @@
jaxval_adders[Unit] = lambda _, __: unit

def add_jaxvals(x, y):
return add_jaxvals_p.bind(x, y)
if core.get_aval(x) is core.abstract_unit is core.get_aval(y):
return core.unit
else:
return add_jaxvals_p.bind(x, y)
Comment from the PR author (collaborator):
I added logic to skip the bind here just because it made jaxprs cleaner to look at in some cases.
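
A tiny hedged illustration of that short-circuit (assuming a JAX build with this change):

```python
# When both operands are the unit value, add_jaxvals now returns unit directly
# instead of binding the add_any primitive, so no add_any equation shows up in
# traced jaxprs.
from jax import core
from jax.ad_util import add_jaxvals

assert add_jaxvals(core.unit, core.unit) is core.unit
```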


add_jaxvals_p = Primitive('add_any')

157 changes: 52 additions & 105 deletions jax/api.py
@@ -47,7 +47,7 @@
tree_transpose, tree_leaves, tree_multimap,
treedef_is_leaf, Partial)
from .util import (unzip2, curry, partial, safe_map, safe_zip, prod,
split_list, extend_name_stack, wrap_name)
split_list, extend_name_stack, wrap_name, cache)
from .lib import xla_bridge as xb
from .lib import xla_client as xc
# Unused imports to be exported
@@ -104,12 +104,13 @@ def jit(fun: Callable, static_argnums: Union[int, Iterable[int]] = (),
why hash and equality operators must be defined.
static_argnums: An int or collection of ints specifying which positional
arguments to treat as static (compile-time constant). Operations that only
depend on static arguments will be constant-folded. Calling the jitted
function with different values for these constants will trigger
recompilation. If the jitted function is called with fewer positional
arguments than indicated by ``static_argnums`` then an error is raised.
Arguments that are not arrays or containers thereof must be marked as
static. Defaults to ().
depend on static arguments will be constant-folded in Python (during
tracing), and so the corresponding argument values can be any Python
object. Calling the jitted function with different values for these
constants will trigger recompilation. If the jitted function is called
with fewer positional arguments than indicated by ``static_argnums`` then
an error is raised. Arguments that are not arrays or containers thereof
must be marked as static. Defaults to ().
device: This is an experimental feature and the API is likely to change.
Optional, the Device the jitted function will run on. (Available devices
can be retrieved via :py:func:`jax.devices`.) The default is inherited from
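
A short hedged example of the ``static_argnums`` semantics described in the docstring change above (assumed example, not part of the diff): a static argument may be an arbitrary Python object, Python-level logic on it is folded away during tracing, and a new static value triggers recompilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def scale(x, mode):
    # `mode` is a compile-time constant: ordinary Python branching on it is
    # fine here and is constant-folded during tracing.
    return x * 2.0 if mode == "double" else x * 0.5

print(scale(jnp.arange(3.0), "double"))  # traces and compiles for mode="double"
print(scale(jnp.arange(3.0), "halve"))   # new static value, so it recompiles
```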
@@ -228,7 +229,7 @@ def xla_computation(fun: Callable,
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
backend: Optional[str] = None,
tuple_args: bool = False,
instantiate_const_outputs: bool = True,
instantiate_const_outputs: Optional[bool] = None,
return_shape: bool = False) -> Callable:
"""Creates a function that produces its XLA computation given example args.

@@ -247,20 +248,23 @@
tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting
XLA computation will have a single tuple argument that is unpacked into
the specified function arguments.
instantiate_const_outputs: Optional bool, defaults to ``True``. If
``False``, then :py:func:`xla_computation` does not instantiate
constant-valued outputs in the XLA computation, and so the result is
closer to the computation that :py:func:`jax.jit` produces and may be more
useful for studying :py:func:`jit` behavior. If ``True``, then
constant-valued outputs are instantiated in the XLA computation, which may
be more useful for staging computations out of JAX entirely.
instantiate_const_outputs: Deprecated argument, does nothing.
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.

Returns:
A wrapped version of ``fun`` that when applied to example arguments returns a
built XLA Computation (see xla_client.py), from which representations of the
unoptimized XLA HLO computation can be extracted using methods like
A wrapped version of ``fun`` that when applied to example arguments returns
a built XLA Computation (see xla_client.py), from which representations of
the unoptimized XLA HLO computation can be extracted using methods like
``as_hlo_text``, ``as_serialized_hlo_module_proto``, and
``as_hlo_dot_graph``.
``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
wrapped function returns a pair where the first element is the XLA
Computation and the second element is a pytree representing the structure,
shapes, and dtypes of the output of ``fun``.

For example:

@@ -326,18 +330,20 @@ def xla_computation(fun: Callable,
ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
"""
del instantiate_const_outputs # Unused

_check_callable(fun)
if isinstance(static_argnums, int):
static_argnums = (static_argnums,)
fun_name = getattr(fun, '__name__', 'unknown')

def make_axis_env(nreps):
if axis_env is None:
return xla.AxisEnv(nreps)
return xla.AxisEnv(nreps, (), (), None)
else:
nreps = nreps * prod(size for name, size in axis_env)
names, sizes = zip(*axis_env)
return xla.AxisEnv(nreps, names, sizes)
return xla.AxisEnv(nreps, names, sizes, None)

def abstractify(x):
return ShapedArray(np.shape(x), dtypes.result_type(x))
@@ -351,9 +357,13 @@ def computation_maker(*args, **kwargs):
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
avals = map(abstractify, jax_args)
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, pvals, instantiate=instantiate_const_outputs, stage_out=True)
if config.omnistaging_enabled:
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
else:
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, pvals, instantiate=True, stage_out=True)
out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals]
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
@@ -364,7 +374,6 @@ def computation_maker(*args, **kwargs):
extend_name_stack(wrap_name(fun_name, 'xla_computation')), *xla_args)
built = c.build(xc.ops.Tuple(c, outs))
if return_shape:
out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals]
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
return built, out_shape
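
A hedged usage sketch of the ``return_shape`` path wired up above (assumed example, not from the PR):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

built, out_shape = jax.xla_computation(f, return_shape=True)(jnp.ones(3))
print(built.as_hlo_text())               # unoptimized HLO text for f
print(out_shape.shape, out_shape.dtype)  # (3,) and float32; a ShapeDtypeStruct leaf
```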
@@ -1190,8 +1199,10 @@ def __eq__(self, other):
return type(other) is _TempAxisName and self.obj == other.obj


def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, *,
in_axes=0, backend: Optional[str] = None) -> Callable:
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
) -> Callable:
if not config.omnistaging_enabled:
raise NotImplementedError("soft_pmap requires omnistaging.")
warn("soft_pmap is an experimental feature and probably has bugs!")
_check_callable(fun)
axis_name = _TempAxisName(fun) if axis_name is None else axis_name
@@ -1208,45 +1219,11 @@ def f_pmapped(*args, **kwargs):
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
for arg in args_flat: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)

chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count(backend))
if chunk_size == 0 and leftover:
return pmap(fun, axis_name, backend=backend)(*args) # can map directly onto hardware
elif leftover:
msg = ("soft_pmap mapped axis size must be divisible by the number of "
"XLA devices (or be less than or equal to that number), but got "
"an axis size of {} with {} devices.")
raise ValueError(msg.format(axis_size, pxla.unmapped_device_count()))
num_chunks = axis_size // chunk_size

reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
soft_mapped_fun = pxla.split_axis(flat_fun, axis_name, chunk_size)
# TODO(tomhennigan): soft_pmap should support buffer donation.
donated_invars = (False,) * len(reshaped_args)
reshaped_outs = pxla.xla_pmap(soft_mapped_fun, *reshaped_args, backend=backend,
axis_name=axis_name, axis_size=num_chunks,
global_axis_size=None, devices=None,
name=soft_mapped_fun.__name__,
mapped_invars=mapped_invars,
donated_invars=donated_invars)
outs = [_reshape_merge(out) for out in reshaped_outs]
outs = pxla.soft_pmap(flat_fun, *args_flat, axis_name=axis_name,
axis_size=axis_size, mapped_invars=mapped_invars)
return tree_unflatten(out_tree(), outs)
return f_pmapped
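
A hedged sketch of calling the reworked `soft_pmap` (experimental, and it now requires omnistaging per the check above; assumes running with `JAX_OMNISTAGING=1`):

```python
import jax
import jax.numpy as jnp

# Maps over the leading axis of size 8; the pxla.soft_pmap path chunks the work
# across the available devices instead of requiring one device per element.
double = jax.soft_pmap(lambda x: 2.0 * x, axis_name="i")
print(double(jnp.arange(8.0)))
```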

def _reshape_split(num_chunks, x):
aval = core.get_aval(x)
if aval is core.abstract_unit:
return x
else:
return x.reshape((num_chunks, x.shape[0] // num_chunks) + x.shape[1:])

def _reshape_merge(x):
aval = core.get_aval(x)
if aval is core.abstract_unit:
return x
else:
return x.reshape((-1,) + x.shape[2:])


def _papply(fun):
# This function is for testing purposes.
@@ -1264,37 +1241,6 @@ def papply_fun(*args, **kwargs):
return papply_fun, axis_name


def _parallelize(fun):
axis_name = _TempAxisName(fun)

def pfun(*args):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten(args)
f, out_tree = flatten_fun_nokwargs(f, in_tree)
axis_size = _mapped_axis_size(
in_tree, args_flat, (0,) * len(args_flat), "parallelize")

chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
if chunk_size == 0 and leftover:
return pmap(fun, axis_name)(*args) # can map directly onto hardware
elif leftover:
raise ValueError
num_chunks = axis_size // chunk_size

reshaped_args = [_reshape_split(num_chunks, x) for x in args_flat]
f, out_axes = parallel.papply_transform(f, axis_name, axis_size)
f = pxla.split_axis(f, axis_name, chunk_size)
outs = pxla.xla_pmap(f, *reshaped_args, backend=None, axis_name=axis_name,
axis_size=num_chunks, global_axis_size=None,
devices=None, name=f.__name__)
outs = map(_reshape_merge, outs)
outs = [batching.matchaxis(axis_size, 0, dst, x)
for dst, x in zip(out_axes(), outs)]
return tree_unflatten(out_tree(), outs)

return pfun


def mask(fun: Callable, in_shapes, out_shape) -> Callable:
_check_callable(fun)
unique_ids = masking.UniqueIds()
@@ -1635,10 +1581,6 @@ def make_jaxpr(fun: Callable,
if isinstance(static_argnums, int):
static_argnums = (static_argnums,)

def pv_like(x):
aval = xla.abstractify(x)
return pe.PartialVal.unknown(aval)

@wraps(fun)
def jaxpr_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun)
@@ -1647,11 +1589,14 @@ def jaxpr_maker(*args, **kwargs):
wrapped, _ = argnums_partial(wrapped, dyn_argnums, args)
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
in_pvals = map(pv_like, jax_args)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
in_avals = tuple(raise_to_shaped(in_aval) for in_aval, _ in in_pvals)
in_avals = map(xla.abstractify, jax_args)
if config.omnistaging_enabled:
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
else:
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr
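
To make the behavioral difference concrete, a hedged example (assumed, not from the PR) of what omnistaging changes in `make_jaxpr` output:

```python
import jax
import jax.numpy as jnp

def f(x):
    y = jnp.sin(3.0)   # no data dependence on x
    return x * y

print(jax.make_jaxpr(f)(2.0))
# Legacy tracing: sin(3.0) is evaluated eagerly and enters the jaxpr as a constant.
# With JAX_OMNISTAGING=1: the sin call is staged, so a sin equation appears in the jaxpr.
```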

@@ -1915,13 +1860,15 @@ def __repr__(self):
return '<jax.custom_transforms function {fun}>'.format(fun=self.__name__)

def __call__(self, *args):
# TODO(mattjj): instead of tracing to a jaxpr, use process_call
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_pvals = [pe.PartialVal.unknown(raise_to_shaped(core.get_aval(x)))
for x in args_flat]
with core.initial_style_staging():
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
else:
with core.initial_style_staging():
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
in_tree=in_tree, out_tree=out_tree(),
num_consts=len(consts))
24 changes: 20 additions & 4 deletions jax/config.py
@@ -36,14 +36,17 @@ def bool_env(varname: str, default: bool) -> bool:
raise ValueError("invalid truth value %r for environment %r" % (val, varname))


class Config(object):
class Config:
def __init__(self):
self.values = {}
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False
self.omnistaging_enabled = False

self.omnistaging_enabled = False
self.omnistaging_enablers = []

def update(self, name, val):
if self.use_absl:
setattr(self.absl_flags.FLAGS, name, val)
@@ -113,8 +116,16 @@ def parse_flags_with_absl(self):
self.complete_absl_config(absl.flags)
already_configured_with_absl = True

if FLAGS.jax_omnistaging:
self.enable_omnistaging()

# TODO(mattjj): remove this when omnistaging fully lands
def enable_omnistaging(self):
pass # placeholder
if not self.omnistaging_enabled:
for enabler in self.omnistaging_enablers:
enabler()
self.omnistaging_enabled = True


class NameSpace(object):
def __init__(self, getter):
Expand All @@ -133,6 +144,11 @@ def __getattr__(self, name):
flags.DEFINE_bool(
'jax_enable_checks',
bool_env('JAX_ENABLE_CHECKS', False),
help=
'Turn on invariant checking (core.skip_checks = False)'
help='Turn on invariant checking (core.skip_checks = False)'
)

flags.DEFINE_bool(
'jax_omnistaging',
bool_env('JAX_OMNISTAGING', False),
help='Enable staging based on dynamic context rather than data dependence.'
)
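
A hedged sketch of the enabler hook introduced above (the enabler body is hypothetical; in the PR, JAX submodules register callbacks that swap their internals over to the omnistaging machinery):

```python
from jax.config import config

def my_omnistaging_enabler():
    # Hypothetical enabler: a real one rebinds module-level tracing machinery.
    print("switching to dynamic-context staging")

config.omnistaging_enablers.append(my_omnistaging_enabler)
config.enable_omnistaging()        # runs every registered enabler once, then sets the flag
print(config.omnistaging_enabled)  # True
```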