From da79d6debe7050b31e63e245094b46b5aca8769c Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sun, 18 Dec 2022 02:47:30 +0000 Subject: [PATCH 01/15] pjit to sharding spec --- alpa/shard_parallel/compile_executable.py | 54 +++++++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/alpa/shard_parallel/compile_executable.py b/alpa/shard_parallel/compile_executable.py index cb2963b6e..4b5e44adf 100644 --- a/alpa/shard_parallel/compile_executable.py +++ b/alpa/shard_parallel/compile_executable.py @@ -1,16 +1,20 @@ """Compile executables for shard parallelism.""" import hashlib import inspect -from typing import Callable, Sequence, Optional, Union +from typing import Callable, Sequence, Optional, OrderedDict, Union import numpy as np -from jax import linear_util as lu +from jax import linear_util as lu, pxla from jax._src import traceback_util -from jax._src.lib import xla_extension as xe +from jax._src.lib import xla_client as xc, xla_extension as xe +from jax._src.tree_util import _replace_nones +from jax._src.util import safe_zip from jax.core import (Jaxpr, ClosedJaxpr, Literal, gensym, get_aval, raise_to_shaped, AbstractValue) +from jax.experimental import pjit +from jax.interpreters import mlir from jax.lax import add_p, div_p -from jax.tree_util import PyTreeDef +from jax.tree_util import PyTreeDef, tree_flatten, tree_map, tree_unflatten from alpa.device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh from alpa.global_env import global_config @@ -393,3 +397,45 @@ def add_gradient_accumulation(raw_jaxpr, num_micro_batches): ] return (combined_jaxpr, accumulate_grad_invar_indices, apply_grad_invar_indices, num_grads) + + +def pjit_to_sharding_spec(mesh_axis_names, logical_mesh, in_axis_resources, + in_tree, avals): + + def _parsed_pspec_to_hlo_sharding( + mesh_shape, + mesh_axis_names, + _parsed_pspec, + num_dimensions: int + ) -> xc.HloSharding: + + array_mapping = pjit.get_array_mapping(_parsed_pspec) + sharding_spec = pxla.new_mesh_sharding_specs( + mesh_shape, mesh_axis_names)(num_dimensions, array_mapping) + # Used in `with_sharding_constraint`. 
+ special_axes = {} + op_sharding = sharding_spec.sharding_proto(special_axes=special_axes) + return xc.HloSharding.from_proto(op_sharding) + + + def flatten_axes(treedef, axis_tree): + proxy = object() + dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) + axes = [] + add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) + tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy) + axes = [None if a is proxy else a for a in axes] + assert len(axes) == treedef.num_leaves + return axes + mesh_shape = OrderedDict( + (name, size) + for name, size in safe_zip(mesh_axis_names, logical_mesh.shape)) + + in_axis_resources, _, _, _ = pjit._prepare_axis_resources( + in_axis_resources, "in_axis_resources") + in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) + canonicalized_shardings = tuple( + _parsed_pspec_to_hlo_sharding(mesh_shape, mesh_axis_names, axis, + len(aval.shape)) + for axis, aval in safe_zip(in_axis_flat, avals)) + return canonicalized_shardings From 25750e583a2a3ed7813b6e805a5cc1030d6e4fa7 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Tue, 20 Dec 2022 19:29:06 +0000 Subject: [PATCH 02/15] move to manual sharding file --- alpa/__init__.py | 1 + alpa/shard_parallel/compile_executable.py | 54 +----------- alpa/shard_parallel/manual_sharding.py | 102 ++++++++++++++++++++++ 3 files changed, 107 insertions(+), 50 deletions(-) create mode 100644 alpa/shard_parallel/manual_sharding.py diff --git a/alpa/__init__.py b/alpa/__init__.py index a51282e24..4aa4325ea 100644 --- a/alpa/__init__.py +++ b/alpa/__init__.py @@ -45,6 +45,7 @@ AutoStageOption, UniformStageOption) from alpa.shard_parallel.auto_sharding import AutoShardingOption +from alpa.shard_parallel.manual_sharding import ManualShardingOption from alpa.serialization import save_checkpoint, restore_checkpoint from alpa.timer import timers from alpa.version import __version__ diff --git a/alpa/shard_parallel/compile_executable.py b/alpa/shard_parallel/compile_executable.py index 4b5e44adf..cb2963b6e 100644 --- a/alpa/shard_parallel/compile_executable.py +++ b/alpa/shard_parallel/compile_executable.py @@ -1,20 +1,16 @@ """Compile executables for shard parallelism.""" import hashlib import inspect -from typing import Callable, Sequence, Optional, OrderedDict, Union +from typing import Callable, Sequence, Optional, Union import numpy as np -from jax import linear_util as lu, pxla +from jax import linear_util as lu from jax._src import traceback_util -from jax._src.lib import xla_client as xc, xla_extension as xe -from jax._src.tree_util import _replace_nones -from jax._src.util import safe_zip +from jax._src.lib import xla_extension as xe from jax.core import (Jaxpr, ClosedJaxpr, Literal, gensym, get_aval, raise_to_shaped, AbstractValue) -from jax.experimental import pjit -from jax.interpreters import mlir from jax.lax import add_p, div_p -from jax.tree_util import PyTreeDef, tree_flatten, tree_map, tree_unflatten +from jax.tree_util import PyTreeDef from alpa.device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh from alpa.global_env import global_config @@ -397,45 +393,3 @@ def add_gradient_accumulation(raw_jaxpr, num_micro_batches): ] return (combined_jaxpr, accumulate_grad_invar_indices, apply_grad_invar_indices, num_grads) - - -def pjit_to_sharding_spec(mesh_axis_names, logical_mesh, in_axis_resources, - in_tree, avals): - - def _parsed_pspec_to_hlo_sharding( - mesh_shape, - mesh_axis_names, - _parsed_pspec, - num_dimensions: int - ) -> xc.HloSharding: - - array_mapping = 
pjit.get_array_mapping(_parsed_pspec) - sharding_spec = pxla.new_mesh_sharding_specs( - mesh_shape, mesh_axis_names)(num_dimensions, array_mapping) - # Used in `with_sharding_constraint`. - special_axes = {} - op_sharding = sharding_spec.sharding_proto(special_axes=special_axes) - return xc.HloSharding.from_proto(op_sharding) - - - def flatten_axes(treedef, axis_tree): - proxy = object() - dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) - axes = [] - add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) - tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy) - axes = [None if a is proxy else a for a in axes] - assert len(axes) == treedef.num_leaves - return axes - mesh_shape = OrderedDict( - (name, size) - for name, size in safe_zip(mesh_axis_names, logical_mesh.shape)) - - in_axis_resources, _, _, _ = pjit._prepare_axis_resources( - in_axis_resources, "in_axis_resources") - in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) - canonicalized_shardings = tuple( - _parsed_pspec_to_hlo_sharding(mesh_shape, mesh_axis_names, axis, - len(aval.shape)) - for axis, aval in safe_zip(in_axis_flat, avals)) - return canonicalized_shardings diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py new file mode 100644 index 000000000..3c079eca6 --- /dev/null +++ b/alpa/shard_parallel/manual_sharding.py @@ -0,0 +1,102 @@ +import dataclasses +from typing import Any, Optional, OrderedDict, Tuple, Union + +from jax._src.lib import xla_client as xc +from jax._src.tree_util import _replace_nones +from jax._src.util import safe_zip +from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources +from jax.interpreters import mlir, xla +from jax.tree_util import tree_unflatten, tree_flatten, tree_map + +from jax import pxla + + +@dataclasses.dataclass +class ManualShardingOption: + """Options to manually set shardings in pjit convention.""" + use_manual_sharding: bool = False + mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None + in_axis_resources: Any = None + out_axis_resources: Any = None + + +def _parsed_pspec_to_hlo_sharding( + mesh_shape, + mesh_axis_names, + _parsed_pspec, + num_dimensions: int, + axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None +) -> xc.OpSharding: + """ + TODO(yonghao): support unspecified and auto + + This function inlines _create_mesh_pspec_sharding_from_parsed_pspec and + _process_in_axis_resources. It skips some checks there including + _is_unspecified_or_from_gda_or_auto, pjit_check_aval_sharding. It also skips + the local-global translation because we always assume alpa handles jaxprs at + the driver side. + """ + + array_mapping = get_array_mapping(_parsed_pspec) + sharding_spec = pxla.new_mesh_sharding_specs(mesh_shape, mesh_axis_names)( + num_dimensions, array_mapping) + # Used in `with_sharding_constraint`. + special_axes = {} + # Manual axes is only used with xmap. + if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext): + axis_names = mesh_axis_names + # Ignore type because mypy doesn't recognize the `hasattr` check above. 
+ for manual_axis in axis_ctx.manual_axes: # type: ignore + special_axes[axis_names.index( + manual_axis)] = xc.OpSharding.Type.MANUAL + op_sharding = sharding_spec.sharding_proto(special_axes=special_axes) + return op_sharding + + +def flatten_axes(treedef, axis_tree): + """Flatten the axis tree and consider None as an effective value.""" + proxy = object() + dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) + axes = [] + add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) + tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy) + axes = [None if a is proxy else a for a in axes] + assert len(axes) == treedef.num_leaves + return axes + + +def get_manual_sharding_spec( + sharding_option: ManualShardingOption, mesh_shape, in_tree, out_tree, + in_avals, out_avals) -> Tuple[Tuple[xc.HloSharding], xc.HloSharding]: + """Create input and output sharding spec from user's in_axis_resources.""" + named_mesh_shape = OrderedDict( + (name, size) + for name, size in safe_zip(sharding_option.mesh_axis_names, mesh_shape)) + + in_axis_resources, _, _, any_auto = _prepare_axis_resources( + sharding_option.in_axis_resources, "in_axis_resources") + out_axis_resources, _, _, _ = _prepare_axis_resources( + sharding_option.out_axis_resources, "out_axis_resources") + if any_auto: + raise NotImplementedError( + "auto mode in manual partition is unsupported.") + + in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) + in_op_shardings = tuple( + _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. + mesh_axis_names, axis, len(aval.shape)) + for axis, aval in safe_zip(in_axis_flat, in_avals)) + out_axis_flat = tuple(flatten_axes(out_tree, out_axis_resources)) + out_op_shardings = tuple( + _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. 
+ mesh_axis_names, axis, len(aval.shape)) + for axis, aval in safe_zip(out_axis_flat, out_avals)) + # Tuple[OpSharding] -> OpSharding w/ type=TUPLE + tuple_output_sharding = xla.tuple_sharding_proto(out_op_shardings) + # OpSharding->HloSharding + in_hlo_shardings = tuple([ + xc.HloSharding.from_proto(op_sharding) + for op_sharding in in_op_shardings + ]) + out_hlo_sharding = xc.HloSharding.from_proto(tuple_output_sharding) + return in_hlo_shardings, out_hlo_sharding From 22a8cc682e432d814addb33764010ceb6ac64ba2 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Tue, 20 Dec 2022 21:55:46 +0000 Subject: [PATCH 03/15] add manual sharding in shard parallel --- alpa/parallel_method.py | 7 ++- alpa/shard_parallel/auto_sharding.py | 2 +- alpa/shard_parallel/compile_executable.py | 25 ++++++++-- alpa/shard_parallel/manual_sharding.py | 57 ++++++++++++----------- 4 files changed, 56 insertions(+), 35 deletions(-) diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index c63ac9742..3c5cf9549 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -38,6 +38,7 @@ UniformStageOption) from alpa.shard_parallel.auto_sharding import AutoShardingOption, LogicalDeviceMesh from alpa.shard_parallel.compile_executable import compile_shard_executable +from alpa.shard_parallel.manual_sharding import ManualShardingOption traceback_util.register_exclusion(__file__) @@ -75,10 +76,12 @@ def __init__(self, devices: Optional[Union[LogicalDeviceMesh, PhysicalDeviceMesh]] = None, num_micro_batches: Optional[int] = None, - auto_sharding_option: Optional[AutoShardingOption] = None): + auto_sharding_option: Optional[AutoShardingOption] = None, + manual_sharding_option: Optional[ManualShardingOption] = None): self.devices = devices self.num_micro_batches = num_micro_batches self.as_option = auto_sharding_option or AutoShardingOption() + self.ms_option = manual_sharding_option def compile_executable( self, @@ -106,7 +109,7 @@ def compile_executable( static_argnums, donated_invars, batch_invars, mesh, self.num_micro_batches, self.as_option, - *avals) + self.ms_option, *avals) class DataParallel(ShardParallel): diff --git a/alpa/shard_parallel/auto_sharding.py b/alpa/shard_parallel/auto_sharding.py index cf140c147..dce55d18f 100644 --- a/alpa/shard_parallel/auto_sharding.py +++ b/alpa/shard_parallel/auto_sharding.py @@ -381,7 +381,7 @@ def run_spmd_partitioner_pass( rewrite_grad_acc_indices: The indices of tensors in output that are gradients. 
""" - assert hlo.is_sharding_annotated(), f"{hlo.status}" + assert hlo.is_sharding_annotated(), hlo.status compile_options = get_compile_options( num_replicas=1, num_partitions=num_devices, diff --git a/alpa/shard_parallel/compile_executable.py b/alpa/shard_parallel/compile_executable.py index cb2963b6e..1da416f0d 100644 --- a/alpa/shard_parallel/compile_executable.py +++ b/alpa/shard_parallel/compile_executable.py @@ -20,6 +20,8 @@ from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass, run_spmd_partitioner_pass, AutoShardingOption) +from alpa.shard_parallel.manual_sharding import (ManualShardingOption, + get_manual_sharding_spec) from alpa.util import (jaxpr_to_hlo, trace_jaxpr_with_micro_batch, setup_computation_alias, OrderedSet, new_jaxpr_eqn) @@ -58,6 +60,7 @@ def compile_shard_executable( device_mesh: Union[PhysicalDeviceMesh, LogicalDeviceMesh], num_micro_batches: Optional[int], as_option: AutoShardingOption, + ms_option: ManualShardingOption, *avals: Sequence[AbstractValue], ): """Compile an executable with auto-sharding pass.""" @@ -74,7 +77,7 @@ def compile_shard_executable( return shard_parallel_internal(fun, in_tree, out_tree_thunk, static_argnums, donated_invars, physical_mesh, logical_mesh_choices, - as_option, *avals) + as_option, ms_option, *avals) else: if global_config.backend == "tpu": raise NotImplementedError( @@ -82,7 +85,7 @@ def compile_shard_executable( return shard_parallel_internal_gradient_accumulation( fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, physical_mesh, logical_mesh_choices, - num_micro_batches, as_option, *avals) + num_micro_batches, as_option, ms_option, *avals) def shard_parallel_internal( @@ -90,7 +93,8 @@ def shard_parallel_internal( static_argnums: Sequence[int], donated_invars: Sequence[bool], physical_mesh: PhysicalDeviceMesh, logical_mesh_choices: Sequence[LogicalDeviceMesh], - as_option: AutoShardingOption, *avals: Sequence[AbstractValue]): + as_option: AutoShardingOption, ms_option: ManualShardingOption, + *avals: Sequence[AbstractValue]): """ Compile an executable with auto-sharding pass. @@ -115,6 +119,17 @@ def shard_parallel_internal( # Convert jaxpr to XLA HLO name = f"{fun.__name__}_shard_parallel" hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars) + # Set user specified sharding specs. 
+ if ms_option: + if as_option.enable_auto_sharding: + raise NotImplementedError("hybrid auto sharding is unsupported") + in_sharding_proto, out_sharding_proto = get_manual_sharding_spec( + ms_option, logical_mesh_choices[0].shape, in_tree, out_tree_thunk(), + avals, out_avals) + if in_sharding_proto is not None: + hlo.set_input_shardings(in_sharding_proto) + if out_sharding_proto is not None: + hlo.set_output_shardings(out_sharding_proto) flop_count = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module()) # Compile a XLA executable @@ -144,10 +159,12 @@ def shard_parallel_internal_gradient_accumulation( batch_invars: Sequence[bool], physical_mesh: PhysicalDeviceMesh, logical_mesh_choices: Sequence[LogicalDeviceMesh], num_micro_batches: int, as_option: AutoShardingOption, - *raw_avals: Sequence[AbstractValue]): + ms_option: ManualShardingOption, *raw_avals: Sequence[AbstractValue]): """Compile a gradient accumulation executable with auto-sharding pass.""" # pylint: disable=unused-argument # Split the batch dimension + if ms_option is not None: + raise NotImplementedError("Unsupported yet.") closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, batch_invars, num_micro_batches, raw_avals) diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index 3c079eca6..d02ffb61d 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -14,7 +14,6 @@ @dataclasses.dataclass class ManualShardingOption: """Options to manually set shardings in pjit convention.""" - use_manual_sharding: bool = False mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None in_axis_resources: Any = None out_axis_resources: Any = None @@ -67,36 +66,38 @@ def flatten_axes(treedef, axis_tree): def get_manual_sharding_spec( sharding_option: ManualShardingOption, mesh_shape, in_tree, out_tree, - in_avals, out_avals) -> Tuple[Tuple[xc.HloSharding], xc.HloSharding]: + in_avals, out_avals) -> Tuple[Tuple[xc.OpSharding], xc.OpSharding]: """Create input and output sharding spec from user's in_axis_resources.""" named_mesh_shape = OrderedDict( (name, size) for name, size in safe_zip(sharding_option.mesh_axis_names, mesh_shape)) - in_axis_resources, _, _, any_auto = _prepare_axis_resources( - sharding_option.in_axis_resources, "in_axis_resources") - out_axis_resources, _, _, _ = _prepare_axis_resources( - sharding_option.out_axis_resources, "out_axis_resources") - if any_auto: - raise NotImplementedError( - "auto mode in manual partition is unsupported.") + # process input + if sharding_option.in_axis_resources is not None: + in_axis_resources, _, _, any_auto = _prepare_axis_resources( + sharding_option.in_axis_resources, "in_axis_resources") + if any_auto: + raise NotImplementedError( + "auto mode in manual partition is unsupported.") + in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) + in_op_shardings = tuple( + _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. + mesh_axis_names, axis, len(aval.shape)) + for axis, aval in safe_zip(in_axis_flat, in_avals)) + else: + in_op_shardings = None - in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) - in_op_shardings = tuple( - _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. - mesh_axis_names, axis, len(aval.shape)) - for axis, aval in safe_zip(in_axis_flat, in_avals)) - out_axis_flat = tuple(flatten_axes(out_tree, out_axis_resources)) - out_op_shardings = tuple( - _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. 
- mesh_axis_names, axis, len(aval.shape)) - for axis, aval in safe_zip(out_axis_flat, out_avals)) - # Tuple[OpSharding] -> OpSharding w/ type=TUPLE - tuple_output_sharding = xla.tuple_sharding_proto(out_op_shardings) - # OpSharding->HloSharding - in_hlo_shardings = tuple([ - xc.HloSharding.from_proto(op_sharding) - for op_sharding in in_op_shardings - ]) - out_hlo_sharding = xc.HloSharding.from_proto(tuple_output_sharding) - return in_hlo_shardings, out_hlo_sharding + # process output + if sharding_option.out_axis_resources is not None: + out_axis_resources, _, _, _ = _prepare_axis_resources( + sharding_option.out_axis_resources, "out_axis_resources") + out_axis_flat = tuple(flatten_axes(out_tree, out_axis_resources)) + out_op_shardings = tuple( + _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. + mesh_axis_names, axis, len(aval.shape)) + for axis, aval in safe_zip(out_axis_flat, out_avals)) + # Tuple[OpSharding] -> OpSharding w/ type=TUPLE + tuple_output_sharding = xla.tuple_sharding_proto(out_op_shardings) + else: + tuple_output_sharding = None + return in_op_shardings, tuple_output_sharding From c0a5fc22e6e5a33ed097f21baf3471d4bea81e32 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Thu, 22 Dec 2022 19:18:54 +0000 Subject: [PATCH 04/15] add format --- alpa/shard_parallel/manual_sharding.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index d02ffb61d..d7fdabffd 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -1,3 +1,4 @@ +"""User specified manual sharding strategy following pjit's api.""" import dataclasses from typing import Any, Optional, OrderedDict, Tuple, Union @@ -22,7 +23,7 @@ class ManualShardingOption: def _parsed_pspec_to_hlo_sharding( mesh_shape, mesh_axis_names, - _parsed_pspec, + parsed_pspec, num_dimensions: int, axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None ) -> xc.OpSharding: @@ -36,7 +37,7 @@ def _parsed_pspec_to_hlo_sharding( the driver side. """ - array_mapping = get_array_mapping(_parsed_pspec) + array_mapping = get_array_mapping(parsed_pspec) sharding_spec = pxla.new_mesh_sharding_specs(mesh_shape, mesh_axis_names)( num_dimensions, array_mapping) # Used in `with_sharding_constraint`. @@ -56,8 +57,12 @@ def flatten_axes(treedef, axis_tree): """Flatten the axis tree and consider None as an effective value.""" proxy = object() dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) + axes = [] - add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0])) + + def add_leaves(i, x): + axes.extend([i] * len(tree_flatten(x))[0]) + tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy) axes = [None if a is proxy else a for a in axes] assert len(axes) == treedef.num_leaves @@ -81,8 +86,9 @@ def get_manual_sharding_spec( "auto mode in manual partition is unsupported.") in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) in_op_shardings = tuple( - _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. 
- mesh_axis_names, axis, len(aval.shape)) + _parsed_pspec_to_hlo_sharding(named_mesh_shape, + sharding_option.mesh_axis_names, axis, + len(aval.shape)) for axis, aval in safe_zip(in_axis_flat, in_avals)) else: in_op_shardings = None @@ -93,8 +99,9 @@ def get_manual_sharding_spec( sharding_option.out_axis_resources, "out_axis_resources") out_axis_flat = tuple(flatten_axes(out_tree, out_axis_resources)) out_op_shardings = tuple( - _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. - mesh_axis_names, axis, len(aval.shape)) + _parsed_pspec_to_hlo_sharding(named_mesh_shape, + sharding_option.mesh_axis_names, axis, + len(aval.shape)) for axis, aval in safe_zip(out_axis_flat, out_avals)) # Tuple[OpSharding] -> OpSharding w/ type=TUPLE tuple_output_sharding = xla.tuple_sharding_proto(out_op_shardings) From 2ddede50f3e2cf34af5702a00dae860b676caa2a Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Thu, 22 Dec 2022 21:30:48 +0000 Subject: [PATCH 05/15] handle unspecified and grad acc in shard parallel --- alpa/shard_parallel/compile_executable.py | 19 ++++++++-- alpa/shard_parallel/manual_sharding.py | 43 +++++++++++++++-------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/alpa/shard_parallel/compile_executable.py b/alpa/shard_parallel/compile_executable.py index 1da416f0d..027dd5f14 100644 --- a/alpa/shard_parallel/compile_executable.py +++ b/alpa/shard_parallel/compile_executable.py @@ -22,8 +22,9 @@ AutoShardingOption) from alpa.shard_parallel.manual_sharding import (ManualShardingOption, get_manual_sharding_spec) -from alpa.util import (jaxpr_to_hlo, trace_jaxpr_with_micro_batch, - setup_computation_alias, OrderedSet, new_jaxpr_eqn) +from alpa.util import (jaxpr_to_hlo, new_jaxpr_eqn, setup_computation_alias, + trace_jaxpr_with_micro_batch, + undefined_sharding_spec_proto, OrderedSet) traceback_util.register_exclusion(__file__) @@ -182,6 +183,20 @@ def shard_parallel_internal_gradient_accumulation( flop_count = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module()) flop_count *= num_micro_batches + # Set user specified sharding specs. 
+ if ms_option: + if as_option.enable_auto_sharding: + raise NotImplementedError("hybrid auto sharding is unsupported") + in_sharding_proto, out_sharding_proto = get_manual_sharding_spec( + ms_option, logical_mesh_choices[0].shape, in_tree, out_tree_thunk(), + in_avals, out_avals) + grad_sharding_proto = [undefined_sharding_spec_proto()] * num_grads + if in_sharding_proto is not None: + in_sharding_proto += tuple(grad_sharding_proto) + hlo.set_input_shardings(in_sharding_proto) + if out_sharding_proto is not None: + hlo.set_output_shardings(out_sharding_proto) + # pylint: disable=unbalanced-tuple-unpacking hlo_stage_names, hlo_stages, stage_plan = run_auto_sharding_pass( hlo, logical_mesh_choices[0], "stages", num_micro_batches, as_option) diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index d7fdabffd..0d694e86f 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -5,19 +5,22 @@ from jax._src.lib import xla_client as xc from jax._src.tree_util import _replace_nones from jax._src.util import safe_zip -from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources -from jax.interpreters import mlir, xla +from jax.experimental.pjit import (_is_unspecified, _is_auto, _is_from_gda, + _prepare_axis_resources, get_array_mapping, + _UNSPECIFIED) +from jax.interpreters import mlir, pxla, xla from jax.tree_util import tree_unflatten, tree_flatten, tree_map -from jax import pxla +from alpa.util import undefined_sharding_spec_proto @dataclasses.dataclass class ManualShardingOption: """Options to manually set shardings in pjit convention.""" mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None - in_axis_resources: Any = None - out_axis_resources: Any = None + # According to pjit, None means replicated. + in_axis_resources: Any = _UNSPECIFIED + out_axis_resources: Any = _UNSPECIFIED def _parsed_pspec_to_hlo_sharding( @@ -28,7 +31,7 @@ def _parsed_pspec_to_hlo_sharding( axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None ) -> xc.OpSharding: """ - TODO(yonghao): support unspecified and auto + TODO(yonghao): support auto(see how pxla.py lowers it) This function inlines _create_mesh_pspec_sharding_from_parsed_pspec and _process_in_axis_resources. It skips some checks there including @@ -36,6 +39,12 @@ def _parsed_pspec_to_hlo_sharding( the local-global translation because we always assume alpa handles jaxprs at the driver side. """ + if _is_unspecified(parsed_pspec): + return undefined_sharding_spec_proto() + if _is_from_gda(parsed_pspec): + raise NotImplementedError("alpa does not support global device array.") + if _is_auto(parsed_pspec): + raise NotImplementedError("") array_mapping = get_array_mapping(parsed_pspec) sharding_spec = pxla.new_mesh_sharding_specs(mesh_shape, mesh_axis_names)( @@ -43,10 +52,11 @@ def _parsed_pspec_to_hlo_sharding( # Used in `with_sharding_constraint`. special_axes = {} # Manual axes is only used with xmap. + # TODO: check whether this manual is conflict with what we use for the + # unspecified type(pjit uses REPLICATED as unspecified) if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext): axis_names = mesh_axis_names - # Ignore type because mypy doesn't recognize the `hasattr` check above. 
- for manual_axis in axis_ctx.manual_axes: # type: ignore + for manual_axis in axis_ctx.manual_axes: special_axes[axis_names.index( manual_axis)] = xc.OpSharding.Type.MANUAL op_sharding = sharding_spec.sharding_proto(special_axes=special_axes) @@ -78,26 +88,33 @@ def get_manual_sharding_spec( for name, size in safe_zip(sharding_option.mesh_axis_names, mesh_shape)) # process input - if sharding_option.in_axis_resources is not None: + if _is_unspecified(sharding_option.in_axis_resources): + in_op_shardings = None + else: + in_op_shardings = None in_axis_resources, _, _, any_auto = _prepare_axis_resources( sharding_option.in_axis_resources, "in_axis_resources") if any_auto: raise NotImplementedError( "auto mode in manual partition is unsupported.") in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) + if any(_is_unspecified(in_axis) for in_axis in in_axis_flat): + assert all(_is_unspecified(in_axis) for in_axis in in_axis_flat) in_op_shardings = tuple( _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option.mesh_axis_names, axis, len(aval.shape)) for axis, aval in safe_zip(in_axis_flat, in_avals)) - else: - in_op_shardings = None # process output - if sharding_option.out_axis_resources is not None: + if _is_unspecified(sharding_option.out_axis_resources): + tuple_output_sharding = None + else: out_axis_resources, _, _, _ = _prepare_axis_resources( sharding_option.out_axis_resources, "out_axis_resources") out_axis_flat = tuple(flatten_axes(out_tree, out_axis_resources)) + if any(_is_unspecified(out_axis) for out_axis in out_axis_flat): + assert all(_is_unspecified(out_axis) for out_axis in out_axis_flat) out_op_shardings = tuple( _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option.mesh_axis_names, axis, @@ -105,6 +122,4 @@ def get_manual_sharding_spec( for axis, aval in safe_zip(out_axis_flat, out_avals)) # Tuple[OpSharding] -> OpSharding w/ type=TUPLE tuple_output_sharding = xla.tuple_sharding_proto(out_op_shardings) - else: - tuple_output_sharding = None return in_op_shardings, tuple_output_sharding From 74d733d5b4016c13ccf90c0c7ea42fb9637d05b6 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Fri, 23 Dec 2022 07:13:20 +0000 Subject: [PATCH 06/15] add testcases --- alpa/shard_parallel/auto_sharding.py | 3 +- alpa/shard_parallel/compile_executable.py | 6 +- alpa/shard_parallel/manual_sharding.py | 9 +- alpa/util.py | 11 +- alpa/wrapped_hlo.py | 1 + tests/shard_parallel/test_manual.py | 127 ++++++++++++++++++++++ 6 files changed, 145 insertions(+), 12 deletions(-) create mode 100644 tests/shard_parallel/test_manual.py diff --git a/alpa/shard_parallel/auto_sharding.py b/alpa/shard_parallel/auto_sharding.py index dce55d18f..b6d2ca9f6 100644 --- a/alpa/shard_parallel/auto_sharding.py +++ b/alpa/shard_parallel/auto_sharding.py @@ -218,7 +218,8 @@ def run_auto_sharding_pass( device_assignment=np.arange(num_devices).reshape((1, -1)), use_spmd_partitioning=True, parameter_is_tupled_arguments=False, - build_random_seed=build_random_seed) + build_random_seed=build_random_seed, + spmd_propagation_to_outputs=hlo.is_manually_annotated) # Set configs for force_zero_stage_3 if as_option.force_zero_stage_3: diff --git a/alpa/shard_parallel/compile_executable.py b/alpa/shard_parallel/compile_executable.py index 027dd5f14..99d713363 100644 --- a/alpa/shard_parallel/compile_executable.py +++ b/alpa/shard_parallel/compile_executable.py @@ -129,8 +129,10 @@ def shard_parallel_internal( avals, out_avals) if in_sharding_proto is not None: hlo.set_input_shardings(in_sharding_proto) + 
hlo.is_manually_annotated = True if out_sharding_proto is not None: hlo.set_output_shardings(out_sharding_proto) + hlo.is_manually_annotated = True flop_count = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module()) # Compile a XLA executable @@ -164,8 +166,6 @@ def shard_parallel_internal_gradient_accumulation( """Compile a gradient accumulation executable with auto-sharding pass.""" # pylint: disable=unused-argument # Split the batch dimension - if ms_option is not None: - raise NotImplementedError("Unsupported yet.") closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, batch_invars, num_micro_batches, raw_avals) @@ -194,8 +194,10 @@ def shard_parallel_internal_gradient_accumulation( if in_sharding_proto is not None: in_sharding_proto += tuple(grad_sharding_proto) hlo.set_input_shardings(in_sharding_proto) + hlo.is_manually_annotated = True if out_sharding_proto is not None: hlo.set_output_shardings(out_sharding_proto) + hlo.is_manually_annotated = True # pylint: disable=unbalanced-tuple-unpacking hlo_stage_names, hlo_stages, stage_plan = run_auto_sharding_pass( diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index 0d694e86f..696990d29 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -71,7 +71,7 @@ def flatten_axes(treedef, axis_tree): axes = [] def add_leaves(i, x): - axes.extend([i] * len(tree_flatten(x))[0]) + axes.extend([i] * len(tree_flatten(x)[0])) tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy) axes = [None if a is proxy else a for a in axes] @@ -91,7 +91,6 @@ def get_manual_sharding_spec( if _is_unspecified(sharding_option.in_axis_resources): in_op_shardings = None else: - in_op_shardings = None in_axis_resources, _, _, any_auto = _prepare_axis_resources( sharding_option.in_axis_resources, "in_axis_resources") if any_auto: @@ -108,7 +107,7 @@ def get_manual_sharding_spec( # process output if _is_unspecified(sharding_option.out_axis_resources): - tuple_output_sharding = None + out_op_shardings = None else: out_axis_resources, _, _, _ = _prepare_axis_resources( sharding_option.out_axis_resources, "out_axis_resources") @@ -120,6 +119,4 @@ def get_manual_sharding_spec( sharding_option.mesh_axis_names, axis, len(aval.shape)) for axis, aval in safe_zip(out_axis_flat, out_avals)) - # Tuple[OpSharding] -> OpSharding w/ type=TUPLE - tuple_output_sharding = xla.tuple_sharding_proto(out_op_shardings) - return in_op_shardings, tuple_output_sharding + return in_op_shardings, out_op_shardings diff --git a/alpa/util.py b/alpa/util.py index 4e21b8a6b..81eb9a193 100644 --- a/alpa/util.py +++ b/alpa/util.py @@ -308,11 +308,13 @@ def cached_property(fn, *args, **kwargs): ######################################## -def get_compile_options(num_replicas: int, num_partitions: int, +def get_compile_options(num_replicas: int, + num_partitions: int, device_assignment: np.ndarray, use_spmd_partitioning: bool, parameter_is_tupled_arguments: int, - build_random_seed: int): + build_random_seed: int, + spmd_propagation_to_outputs: bool = False): """Return CompileOptions for XLA compilation.""" compile_options = xb.get_compile_options( num_replicas=num_replicas, @@ -322,7 +324,10 @@ def get_compile_options(num_replicas: int, num_partitions: int, ) compile_options.parameter_is_tupled_arguments = ( parameter_is_tupled_arguments) - compile_options.executable_build_options.seed = build_random_seed + build_options = compile_options.executable_build_options + build_options.seed = build_random_seed + 
build_options.allow_spmd_sharding_propagation_to_output =\ + spmd_propagation_to_outputs return compile_options diff --git a/alpa/wrapped_hlo.py b/alpa/wrapped_hlo.py index 3190c188e..0a45c53b1 100644 --- a/alpa/wrapped_hlo.py +++ b/alpa/wrapped_hlo.py @@ -33,6 +33,7 @@ def __init__(self, self.module = xe.XlaComputation(module).get_hlo_module() self.name = self.module.name self.status = status + self.is_manually_annotated = False def get_computation(self) -> xe.XlaComputation: return xe.XlaComputation(self.module.as_serialized_hlo_module_proto()) diff --git a/tests/shard_parallel/test_manual.py b/tests/shard_parallel/test_manual.py new file mode 100644 index 000000000..67bc88f4e --- /dev/null +++ b/tests/shard_parallel/test_manual.py @@ -0,0 +1,127 @@ +""" +Test the manual sharding spec. +""" +import unittest + +import jax +from jax.experimental.pjit import PartitionSpec +from jax.tree_util import tree_flatten, tree_map +import jax.numpy as jnp + +import alpa +from alpa import (AutoShardingOption, LocalPhysicalDeviceMesh, + ManualShardingOption, ShardParallel, parallelize) + +class ManualShardingTest(unittest.TestCase): + + def setUp(self): + self.as_option = AutoShardingOption(enable_auto_sharding=False) + self.devices = LocalPhysicalDeviceMesh(jax.local_devices()[:4]) + self.devices = self.devices.get_logical_mesh((2, 2), (1, 1), (1, 1)) + self.mesh_axis_names = ("data", "model") + + def _get_fn_manual_sharding_with(self, + fn, + ms_option, + *args, + num_microbatches=None, + batch_argnums=(1,)): + method = ShardParallel( + devices=self.devices, + num_micro_batches=num_microbatches, + auto_sharding_option=self.as_option, + manual_sharding_option=ms_option, + ) + parallelized = parallelize(fn, method=method, batch_argnums=batch_argnums) + return parallelized.get_executable(*args).get_hlo_text() + + @staticmethod + def _get_param_line(text: str): + text = text[text.find("ENTRY"):] + text = text[:text.find("\n")] + return text + + @staticmethod + def _get_root_line(text:str): + text = text[text.find("ENTRY"):] + text = text[text.find("ROOT"):] + text = text[:text.find("\n")] + return text + + def test_set_input(self): + def fn(a, b): + return a + b + a = jnp.ones((6, 4)) + b = jnp.ones((6, 4)) + in_axis_resources = (PartitionSpec(None, "model"), + PartitionSpec(None, "model")) + ms_option = ManualShardingOption(self.mesh_axis_names, + in_axis_resources) + text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) + text = self._get_param_line(text) + assert "param: f32[6,2]" in text and "param.1: f32[6,2]" in text + in_axis_resources = (PartitionSpec("data", None), + PartitionSpec("data", "model")) + ms_option = ManualShardingOption(self.mesh_axis_names, + in_axis_resources) + text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) + text = self._get_param_line(text) + assert "param: f32[3,4]" in text and "param.1: f32[3,2]" in text + in_axis_resources = (None, PartitionSpec("data", None)) + ms_option = ManualShardingOption(self.mesh_axis_names, + in_axis_resources) + text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) + text = self._get_param_line(text) + assert "param: f32[6,4]" in text and "param.1: f32[3,4]" in text + + def test_set_output(self): + def fn(a): + return a**2, a + 1, a * 2, a / 2 + a = jnp.ones((6, 4)) + out_axis_resources = (PartitionSpec("data", None), None, + PartitionSpec(None, "model"), + PartitionSpec("data", "model")) + ms_option = ManualShardingOption(self.mesh_axis_names, + out_axis_resources=out_axis_resources) + text = 
self._get_fn_manual_sharding_with(fn, ms_option, a) + text = self._get_root_line(text) + assert ("(f32[3,4]{1,0}, f32[6,4]{1,0}, f32[6,2]{1,0}, f32[3,2]{1,0}" + in text) + + def test_grad_acc(self): + def fn(params, x): + def loss_fn(params): + w1, b1, w2, b2 = params + y = jax.nn.relu(x @ w1 + b1) + z = jax.nn.softmax(y @ w2 + b2) + return jnp.mean(z) + + grads = alpa.grad(loss_fn)(params) + new_params = tree_map(lambda p, g: p - g, params, grads) + return new_params + + + x = jnp.ones((2, 6)) + params = (jnp.ones((6, 8)), jnp.ones((8,)), jnp.ones( + (8, 10)), jnp.ones((10,))) + in_axis_resources = None + ms_option = ManualShardingOption(self.mesh_axis_names, + in_axis_resources) + text = self._get_fn_manual_sharding_with(fn, + ms_option, + params, + x, + num_microbatches=2) + # TODO(yonghao): check something here + +def suite(): + suite = unittest.TestSuite() + suite.addTest(ManualShardingTest("test_set_input")) + suite.addTest(ManualShardingTest("test_set_output")) + suite.addTest(ManualShardingTest("test_grad_acc")) + return suite + + +if __name__ == "__main__": + runner = unittest.TextTestRunner() + runner.run(suite()) From 33cb0b05c7f2ae03734fdab5f12d8c3cd30a2a00 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Fri, 23 Dec 2022 17:15:36 +0000 Subject: [PATCH 07/15] add test for accumulate grad --- alpa/shard_parallel/manual_sharding.py | 2 +- tests/shard_parallel/test_manual.py | 71 ++++++++++++++++++++++---- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index 696990d29..f4399ca2b 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -8,7 +8,7 @@ from jax.experimental.pjit import (_is_unspecified, _is_auto, _is_from_gda, _prepare_axis_resources, get_array_mapping, _UNSPECIFIED) -from jax.interpreters import mlir, pxla, xla +from jax.interpreters import mlir, pxla from jax.tree_util import tree_unflatten, tree_flatten, tree_map from alpa.util import undefined_sharding_spec_proto diff --git a/tests/shard_parallel/test_manual.py b/tests/shard_parallel/test_manual.py index 67bc88f4e..a2db45b0b 100644 --- a/tests/shard_parallel/test_manual.py +++ b/tests/shard_parallel/test_manual.py @@ -12,6 +12,7 @@ from alpa import (AutoShardingOption, LocalPhysicalDeviceMesh, ManualShardingOption, ShardParallel, parallelize) + class ManualShardingTest(unittest.TestCase): def setUp(self): @@ -32,7 +33,9 @@ def _get_fn_manual_sharding_with(self, auto_sharding_option=self.as_option, manual_sharding_option=ms_option, ) - parallelized = parallelize(fn, method=method, batch_argnums=batch_argnums) + parallelized = parallelize(fn, + method=method, + batch_argnums=batch_argnums) return parallelized.get_executable(*args).get_hlo_text() @staticmethod @@ -42,15 +45,32 @@ def _get_param_line(text: str): return text @staticmethod - def _get_root_line(text:str): + def _get_root_line(text: str): text = text[text.find("ENTRY"):] text = text[text.find("ROOT"):] text = text[:text.find("\n")] return text + @staticmethod + def _parse_param_shapes(text: str): + # the first one is "ENTRY %xxx (" + params = text.split("param")[1:] + shapes = tuple(map(lambda x: x[x.find("f32"):x.find("]") + 1], params)) + return shapes + + @staticmethod + def _parse_root_shapes(text: str): + tuple_shape = text[text.find("=") + 2:text.find("tuple(")] + # the last one is ')' + shapes = tuple_shape.split("0}")[:-1] + shapes = tuple(map(lambda x: x[x.find("f32"):x.find("{")], shapes)) + return shapes + def 
test_set_input(self): + def fn(a, b): return a + b + a = jnp.ones((6, 4)) b = jnp.ones((6, 4)) in_axis_resources = (PartitionSpec(None, "model"), @@ -75,8 +95,10 @@ def fn(a, b): assert "param: f32[6,4]" in text and "param.1: f32[3,4]" in text def test_set_output(self): + def fn(a): return a**2, a + 1, a * 2, a / 2 + a = jnp.ones((6, 4)) out_axis_resources = (PartitionSpec("data", None), None, PartitionSpec(None, "model"), @@ -89,30 +111,61 @@ def fn(a): in text) def test_grad_acc(self): - def fn(params, x): + + def fn(params, batch): + x, tgt = batch + def loss_fn(params): w1, b1, w2, b2 = params y = jax.nn.relu(x @ w1 + b1) z = jax.nn.softmax(y @ w2 + b2) - return jnp.mean(z) + return jnp.mean((z - tgt)**2) grads = alpa.grad(loss_fn)(params) new_params = tree_map(lambda p, g: p - g, params, grads) return new_params - - x = jnp.ones((2, 6)) + batch_size = 64 + x = jnp.ones((batch_size, 6)) + tgt = jnp.ones((batch_size, 10)) params = (jnp.ones((6, 8)), jnp.ones((8,)), jnp.ones( (8, 10)), jnp.ones((10,))) - in_axis_resources = None + batch = (x, tgt) + in_axis_resources = ((PartitionSpec(None, + "model"), PartitionSpec("model"), + PartitionSpec("model", + None), PartitionSpec(None)), + (PartitionSpec("data", + None), PartitionSpec("data", None))) ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, params, - x, + batch, num_microbatches=2) - # TODO(yonghao): check something here + apply_grad_start = text.find("HloModule", 1) + acc_grad_text = text[:apply_grad_start] + apply_grad_text = text[apply_grad_start:] + # 1. Accumulate grad: + acc_grad_params = self._get_param_line(acc_grad_text) + acc_grad_param_shapes = self._parse_param_shapes(acc_grad_params) + acc_grad_root = self._get_root_line(acc_grad_text) + acc_grad_root_shapes = self._parse_root_shapes(acc_grad_root) + + param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") + # batch_size / num_microbatches / data_parallel + batch_shape = ("f32[16,6]", "f32[16,10]") + assert acc_grad_param_shapes == param_shape + batch_shape + param_shape + assert acc_grad_root_shapes == param_shape + # 2. 
Apply grad: + apply_grad_params = self._get_param_line(apply_grad_text) + apply_grad_param_shapes = self._parse_param_shapes(apply_grad_params) + apply_grad_root = self._get_root_line(apply_grad_text) + apply_grad_root_shapes = self._parse_root_shapes(apply_grad_root) + assert apply_grad_param_shapes == param_shape + param_shape + assert apply_grad_root_shapes == param_shape + def suite(): suite = unittest.TestSuite() From 452b93cb15c588cddf77f890a8e2fc512cc1e6aa Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sat, 24 Dec 2022 02:08:24 +0000 Subject: [PATCH 08/15] manual sharding option for pipeshard parallel --- alpa/create_state_parallel.py | 2 +- alpa/follow_parallel.py | 2 +- alpa/parallel_method.py | 8 +- alpa/pipeline_parallel/compile_executable.py | 89 ++++++++++++++-- alpa/pipeline_parallel/computation.py | 2 - alpa/shard_parallel/manual_sharding.py | 101 ++++++++++++------- tests/shard_parallel/test_manual.py | 2 +- 7 files changed, 159 insertions(+), 47 deletions(-) diff --git a/alpa/create_state_parallel.py b/alpa/create_state_parallel.py index 63a737b52..8c3ffd0d5 100644 --- a/alpa/create_state_parallel.py +++ b/alpa/create_state_parallel.py @@ -138,7 +138,7 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk, new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals), executable.mesh_group.parent, 1, "inference", AutoShardingOption(enable_auto_sharding=False), - UniformStageOption(), name, None, output_shardings, None) + UniformStageOption(), name, None, output_shardings, None, None) return CreateStateExecutable(mesh_group=executable.mesh_group, pipeshard_config=pipeshard_config, diff --git a/alpa/follow_parallel.py b/alpa/follow_parallel.py index d75e8dd22..ae94bfb57 100644 --- a/alpa/follow_parallel.py +++ b/alpa/follow_parallel.py @@ -88,4 +88,4 @@ def is_leave(x): fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, mesh, num_micro_batches, pipeline_schedule, AutoShardingOption(enable_auto_sharding=False), layer_option, - UniformStageOption(), input_shardings, None, *avals) + UniformStageOption(), input_shardings, None, None, *avals) diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index 3c5cf9549..dd257567c 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -190,7 +190,8 @@ def __init__( layer_option: Optional[Union[LayerOption, str]] = None, stage_option: Optional[Union[StageOption, str]] = None, stage_input_shardings: Optional[Sequence[Sequence[ - pxla.ShardingSpec]]] = None): + pxla.ShardingSpec]]] = None, + manual_sharding_option: ManualShardingOption = None): self.devices = devices self.num_micro_batches = num_micro_batches self.as_option = (default_auto_sharding_option or @@ -212,6 +213,9 @@ def __init__( stage_option = UniformStageOption() self.stage_option = stage_option or UniformStageOption() self.stage_input_shardings = stage_input_shardings + assert not (stage_input_shardings is not None and + manual_sharding_option is not None) + self.manual_sharding_option = manual_sharding_option def compile_executable( self, @@ -237,7 +241,7 @@ def compile_executable( fun, in_tree, out_tree_thunk, static_argnums, donated_invars, batch_invars, mesh, self.num_micro_batches, self.pipeline_schedule, self.as_option, self.layer_option, self.stage_option, None, - self.stage_input_shardings, *avals) + self.stage_input_shardings, self.manual_sharding_option, *avals) def get_3d_parallel_method(num_micro_batches: int, diff --git a/alpa/pipeline_parallel/compile_executable.py 
b/alpa/pipeline_parallel/compile_executable.py index b68b2cc47..27f9bbdd9 100644 --- a/alpa/pipeline_parallel/compile_executable.py +++ b/alpa/pipeline_parallel/compile_executable.py @@ -5,6 +5,7 @@ from typing import Callable, Sequence, Optional from jax import linear_util as lu +from jax._src.lib import xla_client as xc from jax.core import gensym, AbstractValue, ClosedJaxpr from jax.interpreters import pxla from jax.tree_util import PyTreeDef @@ -32,7 +33,12 @@ from alpa.pipeline_parallel.stage_construction import ( cluster_layers_and_slice_mesh, StageOption) from alpa.pipeline_parallel.stage_profiling import CompileWorkerPool -from alpa.shard_parallel.auto_sharding import AutoShardingOption +from alpa.shard_parallel.auto_sharding import (AutoShardingOption, + hlo_sharding_to_sharding_spec) +from alpa.shard_parallel.manual_sharding import (ManualShardingOption, + ParsedManualShardingOption, + get_flatten_axis_resources, + parsed_spec_to_opsharding) from alpa.util import (get_var_mapping, trace_jaxpr_with_micro_batch, OrderedSet, GradFuncTransformContext) @@ -49,6 +55,7 @@ def compile_pipeshard_executable( layer_option: LayerOption, stage_option: StageOption, global_input_shardings: Optional[Sequence[pxla.ShardingSpec]], stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]], + manual_shard_options: Optional[ManualShardingOption], *avals: Sequence[AbstractValue]): """ Compile a callable for pipeshard parallel which combines @@ -60,6 +67,8 @@ def compile_pipeshard_executable( input vars. stage_input_shardings: Forcibly set sharding specs of input vars of each stage. + manual_sharding_options: pjit style sharding constraints of global input + vars. """ if global_config.backend == "tpu": raise NotImplementedError("Pipeshard Parallel for tpu is not supported") @@ -92,11 +101,19 @@ def compile_pipeshard_executable( fun.f = f_backup debug_compilation_time("trace") + # flatten manual sharding axis resources + out_tree = out_tree_thunk() + if manual_shard_options is not None: + assert global_input_shardings is None + parsed_ms_option = get_flatten_axis_resources( + manual_shard_options, in_tree, out_tree) + else: + parsed_ms_option = None pipeshard_config = compile_pipeshard_executable_internal( closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars, batch_invars, virtual_mesh, num_microbatch, pipeline_schedule, default_as_option, stage_option, name_base, global_input_shardings, - None, stage_input_shardings) + None, stage_input_shardings, parsed_ms_option) executable = PipeshardDriverExecutable( mesh_group=virtual_mesh.launched_physical_mesh_group, @@ -104,7 +121,7 @@ def compile_pipeshard_executable( num_batch=num_microbatch, layer_option=layer_option, in_tree=in_tree, - out_tree=out_tree_thunk(), + out_tree=out_tree, static_argnums=static_argnums) debug_compilation_time("driver executable") return executable @@ -119,7 +136,8 @@ def compile_pipeshard_executable_internal( stage_option: StageOption, name_base: str, global_input_shardings: Optional[Sequence[pxla.ShardingSpec]], global_output_shardings: Optional[Sequence[pxla.ShardingSpec]], - stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]]): + stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]], + parsed_manual_sharding_option: Optional[ParsedManualShardingOption]): """ Args: fun: The function to be parallelized. 
@@ -207,6 +225,8 @@ def compile_pipeshard_executable_internal( raise ValueError(f"Invalid schedule: {pipeline_schedule}") # Forcibly set the sharding specs of global invars and outvars. + # FIXME(yonghao): the invar can appear on multiple meshes and thus different + # sharding specs if global_input_shardings: assert len(global_input_shardings) == len(global_invars) input_sharding_dict = dict(zip(global_invars, global_input_shardings)) @@ -218,6 +238,16 @@ def compile_pipeshard_executable_internal( global_output_shardings)) else: output_sharding_dict = {} + if parsed_manual_sharding_option is not None: + assert (global_input_shardings is None and + global_output_shardings is None) + (input_sharding_dicts, + output_sharding_dicts) = get_manual_input_output_sharding_specs( + jax_all_stages, [mesh.shape for mesh in sliced_virtual_meshes], + parsed_manual_sharding_option, global_invars, global_outvars) + else: + input_sharding_dicts = [input_sharding_dict] * num_meshes + output_sharding_dicts = [output_sharding_dict] * num_meshes # Call auto-sharding pass to shard each stage xla_stages, total_flops = shard_each_stage( @@ -226,7 +256,7 @@ def compile_pipeshard_executable_internal( donate_invars_dict, num_microbatch, manual_stage_option.submesh_logical_shapes, manual_stage_option.submesh_autosharding_option_dicts, - default_as_option, input_sharding_dict, output_sharding_dict, + default_as_option, input_sharding_dicts, output_sharding_dicts, stage_input_shardings, name_base, gensym_func) total_flops *= num_microbatch debug_compilation_time("shard stages") @@ -320,11 +350,56 @@ def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr, accumulator_mapping, acc_grad_invars, acc_grad_outvars) +def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option, + global_invars, global_outvars, + stage_to_mesh): + """ + Split user assigned input and output PartitionSpec into sharding specs for + each pipeline stage. 
+ """ + invar_set = set(global_invars) + outvar_set = set(global_outvars) + var_to_pspec = {} + if ms_option.in_parsed_pspec is not None: + var_to_pspec.update(dict(zip(global_invars, ms_option.in_parsed_pspec))) + if ms_option.out_parsed_pspec is not None: + var_to_pspec.update( + dict(zip(global_outvars, ms_option.out_parsed_pspec))) + + def get_vars_to_sharding_specs(variables, mesh_shape): + parsed_specs = [var_to_pspec[v] for v in variables] + var_op_shardings = parsed_spec_to_opsharding(parsed_specs, variables, + mesh_shape, ms_option) + var_sharding_specs = [ + hlo_sharding_to_sharding_spec(xc.HloSharding.from_proto(ops), aval, + mesh_shape) + for ops, aval in zip(variables, var_op_shardings) + ] + return dict(zip(variables, var_sharding_specs)) + + invar_shardings = [{}] * len(mesh_shapes) + outvar_shardings = [{}] * len(mesh_shapes) + for stage_idx, stage in stages: + mesh_idx = stage_to_mesh[stage_idx] + mesh_shape = mesh_shapes[stage_to_mesh[stage_idx]] + # invars + invar_in_global = [var for var in stage.invars if var in invar_set] + stage_invar_shardings = get_vars_to_sharding_specs( + invar_in_global, mesh_shape) + # outvars + outvar_in_global = [var for var in stage.outvars if var in outvar_set] + stage_outvar_shardings = get_vars_to_sharding_specs( + outvar_in_global, mesh_shape) + invar_shardings[mesh_idx].update(stage_invar_shardings) + outvar_shardings[mesh_idx].update(stage_outvar_shardings) + return invar_shardings, outvar_shardings + + def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes, accumulator_mapping, global_invars, acc_grad_outvars, donate_invars_dict, num_microbatch, logical_mesh_shapes, autosharding_option_dicts, default_as_option, - input_sharding_dict, output_sharding_dict, + input_sharding_dicts, output_sharding_dicts, stage_input_shardings, name_base, gensym_func): """Run intra-op parallelism compilation for a stage.""" # Initialize donation mapping @@ -362,6 +437,8 @@ def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes, compile_intermediate = [None] * num_meshes total_flops = 0 for mesh_idx in range(num_meshes): + input_sharding_dict = input_sharding_dicts[mesh_idx] + output_sharding_dict = output_sharding_dicts[mesh_idx] virtual_mesh = virtual_meshes[mesh_idx] logical_mesh = virtual_mesh.get_logical_mesh( logical_mesh_shapes[mesh_idx]) diff --git a/alpa/pipeline_parallel/computation.py b/alpa/pipeline_parallel/computation.py index 0be5add14..3dc2c7b68 100644 --- a/alpa/pipeline_parallel/computation.py +++ b/alpa/pipeline_parallel/computation.py @@ -741,14 +741,12 @@ def generate_sharded_xla_computations_arguments( if input_sharding_dict: sharding_protos = [] - sharding_specs = [] for x in invars: spec = input_sharding_dict.get(x, None) if spec is None: sharding_protos.append(undefined_sharding_spec_proto()) else: sharding_protos.append(spec.sharding_proto()) - sharding_specs.append(spec) hlo.set_input_shardings(sharding_protos) if output_sharding_dict: diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index f4399ca2b..b1c21b339 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -7,7 +7,7 @@ from jax._src.util import safe_zip from jax.experimental.pjit import (_is_unspecified, _is_auto, _is_from_gda, _prepare_axis_resources, get_array_mapping, - _UNSPECIFIED) + _UNSPECIFIED, ParsedPartitionSpec) from jax.interpreters import mlir, pxla from jax.tree_util import tree_unflatten, tree_flatten, tree_map @@ -23,6 +23,15 @@ 
class ManualShardingOption: out_axis_resources: Any = _UNSPECIFIED +@dataclasses.dataclass +class ParsedManualShardingOption: + """Options """ + mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None + # Parsed and flatten status + in_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None + out_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None + + def _parsed_pspec_to_hlo_sharding( mesh_shape, mesh_axis_names, @@ -63,7 +72,7 @@ def _parsed_pspec_to_hlo_sharding( return op_sharding -def flatten_axes(treedef, axis_tree): +def _flatten_axes(treedef, axis_tree): """Flatten the axis tree and consider None as an effective value.""" proxy = object() dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves) @@ -79,44 +88,68 @@ def add_leaves(i, x): return axes -def get_manual_sharding_spec( - sharding_option: ManualShardingOption, mesh_shape, in_tree, out_tree, - in_avals, out_avals) -> Tuple[Tuple[xc.OpSharding], xc.OpSharding]: - """Create input and output sharding spec from user's in_axis_resources.""" - named_mesh_shape = OrderedDict( - (name, size) - for name, size in safe_zip(sharding_option.mesh_axis_names, mesh_shape)) +def _prepare_axis_and_flatten(axis_resources, tree, name): + parsed_axis_resources, _, _, any_auto = _prepare_axis_resources( + axis_resources, name) + if any_auto: + raise NotImplementedError( + "auto mode in manual partition is unsupported.") + axis_flat = tuple(_flatten_axes(tree, parsed_axis_resources)) + if any(_is_unspecified(in_axis) for in_axis in axis_flat): + assert all(_is_unspecified(in_axis) for in_axis in axis_flat) + return axis_flat + + +def get_flatten_axis_resources( + sharding_option: ManualShardingOption, in_tree, + out_tree) -> ParsedManualShardingOption: + """Flatten axis resources for pipeline parallel to dispatch.""" + if sharding_option is None: + return None # process input if _is_unspecified(sharding_option.in_axis_resources): - in_op_shardings = None + in_axis_flat = None else: - in_axis_resources, _, _, any_auto = _prepare_axis_resources( - sharding_option.in_axis_resources, "in_axis_resources") - if any_auto: - raise NotImplementedError( - "auto mode in manual partition is unsupported.") - in_axis_flat = tuple(flatten_axes(in_tree, in_axis_resources)) - if any(_is_unspecified(in_axis) for in_axis in in_axis_flat): - assert all(_is_unspecified(in_axis) for in_axis in in_axis_flat) - in_op_shardings = tuple( - _parsed_pspec_to_hlo_sharding(named_mesh_shape, - sharding_option.mesh_axis_names, axis, - len(aval.shape)) - for axis, aval in safe_zip(in_axis_flat, in_avals)) + in_axis_flat = _prepare_axis_and_flatten( + sharding_option.in_axis_resources, in_tree, "in_axis_resources") # process output if _is_unspecified(sharding_option.out_axis_resources): - out_op_shardings = None + out_axis_flat = None else: - out_axis_resources, _, _, _ = _prepare_axis_resources( - sharding_option.out_axis_resources, "out_axis_resources") - out_axis_flat = tuple(flatten_axes(out_tree, out_axis_resources)) - if any(_is_unspecified(out_axis) for out_axis in out_axis_flat): - assert all(_is_unspecified(out_axis) for out_axis in out_axis_flat) - out_op_shardings = tuple( - _parsed_pspec_to_hlo_sharding(named_mesh_shape, - sharding_option.mesh_axis_names, axis, - len(aval.shape)) - for axis, aval in safe_zip(out_axis_flat, out_avals)) + out_axis_flat = _prepare_axis_and_flatten( + sharding_option.out_axis_resources, out_tree, "out_axis_resources") + return ParsedManualShardingOption(sharding_option.mesh_axis_names, + in_axis_flat, out_axis_flat) + + +def 
parsed_spec_to_opsharding(axes, avals, mesh_shape, sharding_option): + """Translate axis(a sequence of ParsedPartitionSpec) into OpShardings""" + if axes is None: + return None + + named_mesh_shape = OrderedDict( + (name, size) + for name, size in safe_zip(sharding_option.mesh_axis_names, mesh_shape)) + op_shardings = tuple( + _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. + mesh_axis_names, axis, len(aval.shape)) + for axis, aval in safe_zip(axes, avals)) + return op_shardings + + +def get_manual_sharding_spec( + sharding_option: ManualShardingOption, mesh_shape, in_tree, out_tree, + in_avals, out_avals) -> Tuple[Tuple[xc.OpSharding, ...], xc.OpSharding]: + """Create input and output sharding spec from user's in_axis_resources.""" + parsed_resources = get_flatten_axis_resources(sharding_option, in_tree, + out_tree) + if parsed_resources is None: + return None, None + in_op_shardings = parsed_spec_to_opsharding( + parsed_resources.in_parsed_pspec, in_avals, mesh_shape, sharding_option) + out_op_shardings = parsed_spec_to_opsharding( + parsed_resources.out_parsed_pspec, out_avals, mesh_shape, + sharding_option) return in_op_shardings, out_op_shardings diff --git a/tests/shard_parallel/test_manual.py b/tests/shard_parallel/test_manual.py index a2db45b0b..1ed943ae5 100644 --- a/tests/shard_parallel/test_manual.py +++ b/tests/shard_parallel/test_manual.py @@ -5,7 +5,7 @@ import jax from jax.experimental.pjit import PartitionSpec -from jax.tree_util import tree_flatten, tree_map +from jax.tree_util import tree_map import jax.numpy as jnp import alpa From f8385420c8256166772190ee16cae8aed18190fd Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sat, 24 Dec 2022 03:46:07 +0000 Subject: [PATCH 09/15] add pipeshard testcase --- alpa/pipeline_parallel/compile_executable.py | 47 ++++-- alpa/shard_parallel/manual_sharding.py | 17 +- .../pipeline_parallel/test_manual_sharding.py | 151 ++++++++++++++++++ tests/shard_parallel/test_manual.py | 8 +- 4 files changed, 200 insertions(+), 23 deletions(-) create mode 100644 tests/pipeline_parallel/test_manual_sharding.py diff --git a/alpa/pipeline_parallel/compile_executable.py b/alpa/pipeline_parallel/compile_executable.py index 27f9bbdd9..c98d08be3 100644 --- a/alpa/pipeline_parallel/compile_executable.py +++ b/alpa/pipeline_parallel/compile_executable.py @@ -244,7 +244,8 @@ def compile_pipeshard_executable_internal( (input_sharding_dicts, output_sharding_dicts) = get_manual_input_output_sharding_specs( jax_all_stages, [mesh.shape for mesh in sliced_virtual_meshes], - parsed_manual_sharding_option, global_invars, global_outvars) + parsed_manual_sharding_option, global_invars, global_outvars, + schedule.stage_mesh_mapping) else: input_sharding_dicts = [input_sharding_dict] * num_meshes output_sharding_dicts = [output_sharding_dict] * num_meshes @@ -360,36 +361,56 @@ def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option, invar_set = set(global_invars) outvar_set = set(global_outvars) var_to_pspec = {} + handle_invar = False + handle_outvar = False if ms_option.in_parsed_pspec is not None: var_to_pspec.update(dict(zip(global_invars, ms_option.in_parsed_pspec))) + handle_invar = True if ms_option.out_parsed_pspec is not None: var_to_pspec.update( dict(zip(global_outvars, ms_option.out_parsed_pspec))) + handle_outvar = True + submesh_axis_names = ms_option.submesh_axis_names + if submesh_axis_names is None: + submesh_axis_names = [ms_option.mesh_axis_names] * len(mesh_shapes) - def get_vars_to_sharding_specs(variables, 
mesh_shape): + def get_vars_to_sharding_specs(variables, mesh_shape, mesh_axis_names): parsed_specs = [var_to_pspec[v] for v in variables] - var_op_shardings = parsed_spec_to_opsharding(parsed_specs, variables, - mesh_shape, ms_option) + avals = [v.aval for v in variables] + var_op_shardings = parsed_spec_to_opsharding(parsed_specs, avals, + mesh_shape, + mesh_axis_names) var_sharding_specs = [ hlo_sharding_to_sharding_spec(xc.HloSharding.from_proto(ops), aval, mesh_shape) - for ops, aval in zip(variables, var_op_shardings) + for ops, aval in zip(var_op_shardings, avals) ] return dict(zip(variables, var_sharding_specs)) invar_shardings = [{}] * len(mesh_shapes) outvar_shardings = [{}] * len(mesh_shapes) - for stage_idx, stage in stages: + for stage_idx, stage in enumerate(stages): mesh_idx = stage_to_mesh[stage_idx] - mesh_shape = mesh_shapes[stage_to_mesh[stage_idx]] + assert len(mesh_idx) == 1 + mesh_idx = list(mesh_idx)[0] + mesh_shape = mesh_shapes[mesh_idx] + mesh_axis_names = submesh_axis_names[mesh_idx] # invars - invar_in_global = [var for var in stage.invars if var in invar_set] - stage_invar_shardings = get_vars_to_sharding_specs( - invar_in_global, mesh_shape) + if handle_invar: + invar_in_global = [var for var in stage.invars if var in invar_set] + stage_invar_shardings = get_vars_to_sharding_specs( + invar_in_global, mesh_shape, mesh_axis_names) + else: + stage_invar_shardings = {} # outvars - outvar_in_global = [var for var in stage.outvars if var in outvar_set] - stage_outvar_shardings = get_vars_to_sharding_specs( - outvar_in_global, mesh_shape) + if handle_outvar: + outvar_in_global = [ + var for var in stage.outvars if var in outvar_set + ] + stage_outvar_shardings = get_vars_to_sharding_specs( + outvar_in_global, mesh_shape, mesh_axis_names) + else: + stage_outvar_shardings = {} invar_shardings[mesh_idx].update(stage_invar_shardings) outvar_shardings[mesh_idx].update(stage_outvar_shardings) return invar_shardings, outvar_shardings diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index b1c21b339..43386ad93 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -18,6 +18,7 @@ class ManualShardingOption: """Options to manually set shardings in pjit convention.""" mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None + submesh_axis_names: Tuple[Tuple[pxla.MeshAxisName, ...], ...] = None # According to pjit, None means replicated. in_axis_resources: Any = _UNSPECIFIED out_axis_resources: Any = _UNSPECIFIED @@ -27,6 +28,7 @@ class ManualShardingOption: class ParsedManualShardingOption: """Options """ mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None + submesh_axis_names: Tuple[Tuple[pxla.MeshAxisName, ...], ...] = None # Parsed and flatten status in_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None out_parsed_pspec: Tuple[ParsedPartitionSpec, ...] 
= None @@ -121,20 +123,21 @@ def get_flatten_axis_resources( out_axis_flat = _prepare_axis_and_flatten( sharding_option.out_axis_resources, out_tree, "out_axis_resources") return ParsedManualShardingOption(sharding_option.mesh_axis_names, + sharding_option.submesh_axis_names, in_axis_flat, out_axis_flat) -def parsed_spec_to_opsharding(axes, avals, mesh_shape, sharding_option): +def parsed_spec_to_opsharding(axes, avals, mesh_shape, mesh_axis_names): """Translate axis(a sequence of ParsedPartitionSpec) into OpShardings""" if axes is None: return None named_mesh_shape = OrderedDict( (name, size) - for name, size in safe_zip(sharding_option.mesh_axis_names, mesh_shape)) + for name, size in safe_zip(mesh_axis_names, mesh_shape)) op_shardings = tuple( - _parsed_pspec_to_hlo_sharding(named_mesh_shape, sharding_option. - mesh_axis_names, axis, len(aval.shape)) + _parsed_pspec_to_hlo_sharding(named_mesh_shape, mesh_axis_names, axis, + len(aval.shape)) for axis, aval in safe_zip(axes, avals)) return op_shardings @@ -147,9 +150,11 @@ def get_manual_sharding_spec( out_tree) if parsed_resources is None: return None, None + assert parsed_resources.mesh_axis_names is not None + mesh_axis_names = sharding_option.mesh_axis_names in_op_shardings = parsed_spec_to_opsharding( - parsed_resources.in_parsed_pspec, in_avals, mesh_shape, sharding_option) + parsed_resources.in_parsed_pspec, in_avals, mesh_shape, mesh_axis_names) out_op_shardings = parsed_spec_to_opsharding( parsed_resources.out_parsed_pspec, out_avals, mesh_shape, - sharding_option) + mesh_axis_names) return in_op_shardings, out_op_shardings diff --git a/tests/pipeline_parallel/test_manual_sharding.py b/tests/pipeline_parallel/test_manual_sharding.py new file mode 100644 index 000000000..dbf33c231 --- /dev/null +++ b/tests/pipeline_parallel/test_manual_sharding.py @@ -0,0 +1,151 @@ +""" +Test the manual sharding spec. 
+""" +import itertools +import unittest + +import jax +from jax.experimental.pjit import PartitionSpec +from jax.tree_util import tree_map +import jax.numpy as jnp + +import alpa +from alpa import (ManualShardingOption, ManualStageOption, PipeshardParallel, + mark_pipeline_boundary, parallelize) + + +class PipeshardManualShardingTest(unittest.TestCase): + + def setUp(self): + alpa.init() + # use (1 * 4) mesh + alpa.set_global_virtual_physical_mesh( + alpa.get_global_cluster().get_virtual_physical_mesh([0], 4)) + + def tearDown(self): + alpa.shutdown() + + def _get_fn_manual_sharding_with(self, + fn, + num_microbatches, + stage_option, + ms_option, + *args,): + method = PipeshardParallel( + num_micro_batches=num_microbatches, + stage_option=stage_option, + manual_sharding_option=ms_option, + ) + parallelized = parallelize(fn, method=method) + return parallelized.get_executable(*args).get_hlo_text() + + @staticmethod + def _get_param_line(text: str): + text = text[text.find("ENTRY"):] + text = text[:text.find("\n")] + return text + + @staticmethod + def _get_root_line(text: str): + text = text[text.find("ENTRY"):] + text = text[text.find("ROOT"):] + text = text[:text.find("\n")] + return text + + @staticmethod + def _parse_param_shapes(text: str): + # the first one is "ENTRY %xxx (" + params = text.split("param")[1:] + shapes = tuple(map(lambda x: x[x.find("f32"):x.find("]") + 1], params)) + return shapes + + @staticmethod + def _parse_root_shapes(text: str): + tuple_shape = text[text.find("=") + 2:text.find("tuple(")] + # the last one is ')' + shapes = tuple_shape.split("0}")[:-1] + shapes = tuple(map(lambda x: x[x.find("f32"):x.find("{")], shapes)) + return shapes + + def test_set_input_output(self): + + def fn(params, batch): + x, tgt = batch + + def loss_fn(params): + w0, b0, w1, b1, w2, b2, w3, b3 = params + y = jax.nn.relu(x @ w0 + b0) + z = jax.nn.relu(y @ w1 + b1) + mark_pipeline_boundary() + u = jax.nn.relu(z @ w2 + b2) + v = jax.nn.softmax(u @ w3 + b3) + return jnp.mean((v - tgt)**2) + + grads = alpa.grad(loss_fn)(params) + new_params = tree_map(lambda p, g: p - g, params, grads) + return new_params + + # data + batch_size = 64 + hiddens = [6, 8, 10, 12, 14] + params = itertools.chain(*[(jnp.ones((hiddens[i], hiddens[i + 1])), + jnp.ones((hiddens[i + 1],))) + for i in range(len(hiddens) - 1)]) + params = tuple(params) + x = jnp.ones((batch_size, hiddens[0])) + tgt = jnp.ones((batch_size, hiddens[-1])) + batch = (x, tgt) + + # partitions + mp_start = PartitionSpec(None, "model") + mp_end = PartitionSpec("model", None) + bias_partitioned = PartitionSpec("model") + replicated = None + dp = PartitionSpec("data", None) + + param_axis_resources = (mp_start, bias_partitioned, mp_end, + replicated) + (replicated, replicated, + replicated, replicated) + batch_axis_resources = (replicated, dp) + in_axis_resources = (param_axis_resources, batch_axis_resources) + + # options + s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2), (2, 1)], + [{}] * 2) + submesh_axis_names = (("dummy", "model"), ("data", "dummy")) + ms_option = ManualShardingOption(None, submesh_axis_names, + in_axis_resources) + text = self._get_fn_manual_sharding_with(fn, 2, s_option, ms_option, + params, batch) + print(text) + # apply_grad_start = text.find("HloModule", 1) + # acc_grad_text = text[:apply_grad_start] + # apply_grad_text = text[apply_grad_start:] + # # 1. 
Accumulate grad: + # acc_grad_params = self._get_param_line(acc_grad_text) + # acc_grad_param_shapes = self._parse_param_shapes(acc_grad_params) + # acc_grad_root = self._get_root_line(acc_grad_text) + # acc_grad_root_shapes = self._parse_root_shapes(acc_grad_root) + + # param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") + # # batch_size / num_microbatches / data_parallel + # batch_shape = ("f32[16,6]", "f32[16,10]") + # assert acc_grad_param_shapes == param_shape + batch_shape + param_shape + # assert acc_grad_root_shapes == param_shape + # # 2. Apply grad: + # apply_grad_params = self._get_param_line(apply_grad_text) + # apply_grad_param_shapes = self._parse_param_shapes(apply_grad_params) + # apply_grad_root = self._get_root_line(apply_grad_text) + # apply_grad_root_shapes = self._parse_root_shapes(apply_grad_root) + # assert apply_grad_param_shapes == param_shape + param_shape + # assert apply_grad_root_shapes == param_shape + +def suite(): + suite = unittest.TestSuite() + suite.addTest(PipeshardManualShardingTest("test_set_input_output")) + return suite + + +if __name__ == "__main__": + runner = unittest.TextTestRunner() + runner.run(suite()) diff --git a/tests/shard_parallel/test_manual.py b/tests/shard_parallel/test_manual.py index 1ed943ae5..34e477c2b 100644 --- a/tests/shard_parallel/test_manual.py +++ b/tests/shard_parallel/test_manual.py @@ -76,20 +76,20 @@ def fn(a, b): in_axis_resources = (PartitionSpec(None, "model"), PartitionSpec(None, "model")) ms_option = ManualShardingOption(self.mesh_axis_names, - in_axis_resources) + in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) text = self._get_param_line(text) assert "param: f32[6,2]" in text and "param.1: f32[6,2]" in text in_axis_resources = (PartitionSpec("data", None), PartitionSpec("data", "model")) ms_option = ManualShardingOption(self.mesh_axis_names, - in_axis_resources) + in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) text = self._get_param_line(text) assert "param: f32[3,4]" in text and "param.1: f32[3,2]" in text in_axis_resources = (None, PartitionSpec("data", None)) ms_option = ManualShardingOption(self.mesh_axis_names, - in_axis_resources) + in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) text = self._get_param_line(text) assert "param: f32[6,4]" in text and "param.1: f32[3,4]" in text @@ -138,7 +138,7 @@ def loss_fn(params): (PartitionSpec("data", None), PartitionSpec("data", None))) ms_option = ManualShardingOption(self.mesh_axis_names, - in_axis_resources) + in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, params, From 32a23c39b21dbcb6eb2f8876b89030a0346c6675 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sat, 24 Dec 2022 05:32:45 +0000 Subject: [PATCH 10/15] update test --- .../pipeline_parallel/test_manual_sharding.py | 83 +++++++++++-------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/tests/pipeline_parallel/test_manual_sharding.py b/tests/pipeline_parallel/test_manual_sharding.py index dbf33c231..2d7d74bb7 100644 --- a/tests/pipeline_parallel/test_manual_sharding.py +++ b/tests/pipeline_parallel/test_manual_sharding.py @@ -10,8 +10,8 @@ import jax.numpy as jnp import alpa -from alpa import (ManualShardingOption, ManualStageOption, PipeshardParallel, - mark_pipeline_boundary, parallelize) +from alpa import (AutoShardingOption, ManualShardingOption, ManualStageOption, + PipeshardParallel, 
mark_pipeline_boundary, parallelize) class PipeshardManualShardingTest(unittest.TestCase): @@ -25,17 +25,13 @@ def setUp(self): def tearDown(self): alpa.shutdown() - def _get_fn_manual_sharding_with(self, - fn, - num_microbatches, - stage_option, - ms_option, - *args,): + def _get_fn_manual_sharding_with(self, fn, num_microbatches, stage_option, + ms_option, *args): method = PipeshardParallel( num_micro_batches=num_microbatches, stage_option=stage_option, manual_sharding_option=ms_option, - ) + default_auto_sharding_option=AutoShardingOption(False)) parallelized = parallelize(fn, method=method) return parallelized.get_executable(*args).get_hlo_text() @@ -56,7 +52,8 @@ def _get_root_line(text: str): def _parse_param_shapes(text: str): # the first one is "ENTRY %xxx (" params = text.split("param")[1:] - shapes = tuple(map(lambda x: x[x.find("f32"):x.find("]") + 1], params)) + shapes = tuple( + map(lambda x: x[x.find(": ") + 2:x.find("]") + 1], params)) return shapes @staticmethod @@ -67,6 +64,14 @@ def _parse_root_shapes(text: str): shapes = tuple(map(lambda x: x[x.find("f32"):x.find("{")], shapes)) return shapes + @staticmethod + def _is_superset_with_x_more(seq1, seq2, x): + set1 = set(seq1) + set2 = set(seq2) + if set1.issuperset(set2) and len(set1) - len(set2) == x: + return True + return False + def test_set_input_output(self): def fn(params, batch): @@ -110,35 +115,45 @@ def loss_fn(params): in_axis_resources = (param_axis_resources, batch_axis_resources) # options - s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2), (2, 1)], + s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2)] * 2, [{}] * 2) - submesh_axis_names = (("dummy", "model"), ("data", "dummy")) + submesh_axis_names = (("dummy", "model"), ("dummy", "data")) ms_option = ManualShardingOption(None, submesh_axis_names, in_axis_resources) text = self._get_fn_manual_sharding_with(fn, 2, s_option, ms_option, params, batch) - print(text) - # apply_grad_start = text.find("HloModule", 1) - # acc_grad_text = text[:apply_grad_start] - # apply_grad_text = text[apply_grad_start:] - # # 1. Accumulate grad: - # acc_grad_params = self._get_param_line(acc_grad_text) - # acc_grad_param_shapes = self._parse_param_shapes(acc_grad_params) - # acc_grad_root = self._get_root_line(acc_grad_text) - # acc_grad_root_shapes = self._parse_root_shapes(acc_grad_root) - - # param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") - # # batch_size / num_microbatches / data_parallel - # batch_shape = ("f32[16,6]", "f32[16,10]") - # assert acc_grad_param_shapes == param_shape + batch_shape + param_shape - # assert acc_grad_root_shapes == param_shape - # # 2. 
Apply grad: - # apply_grad_params = self._get_param_line(apply_grad_text) - # apply_grad_param_shapes = self._parse_param_shapes(apply_grad_params) - # apply_grad_root = self._get_root_line(apply_grad_text) - # apply_grad_root_shapes = self._parse_root_shapes(apply_grad_root) - # assert apply_grad_param_shapes == param_shape + param_shape - # assert apply_grad_root_shapes == param_shape + l0_fwd, l1_fwd, l1_bwd, l0_bwd, l0_apl, l1_apl = text + # layer 0 + l0_param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") + l0_batch_shape = ("f32[32,6]",) + l0_fwd_param = self._parse_param_shapes(self._get_param_line(l0_fwd)) + assert sorted(l0_fwd_param) == sorted(l0_param_shape + l0_batch_shape) + l0_bwd_param = self._parse_param_shapes(self._get_param_line(l0_bwd)) + l0_bwd_root = self._parse_root_shapes(self._get_root_line(l0_bwd)) + # the donated accumulated gradient are at first + assert sorted(l0_bwd_param[:4]) == sorted(l0_param_shape) + assert sorted(l0_bwd_root) == sorted(l0_param_shape) + l0_apl_param = self._parse_param_shapes(self._get_param_line(l0_apl)) + l0_apl_root = self._parse_root_shapes(self._get_root_line(l0_apl)) + assert sorted(l0_apl_param) == sorted(l0_param_shape + l0_param_shape) + assert sorted(l0_apl_root) == sorted(l0_param_shape) + + # layer 1 + l1_param_shape = ("f32[10,12]", "f32[12]", "f32[12,14]", "f32[14]") + l1_batch_shape = ("f32[16,14]",) + l1_fwd_param = self._parse_param_shapes(self._get_param_line(l1_fwd)) + assert self._is_superset_with_x_more(l1_fwd_param, + l1_param_shape + l1_batch_shape, 1) + l1_bwd_param = self._parse_param_shapes(self._get_param_line(l1_bwd)) + l1_bwd_root = self._parse_root_shapes(self._get_root_line(l1_bwd)) + # the donated accumulated gradient are at first + assert sorted(l1_bwd_param[:4]) == sorted(l1_param_shape) + assert self._is_superset_with_x_more(l1_bwd_root, l1_param_shape, 1) + l1_apl_param = self._parse_param_shapes(self._get_param_line(l1_apl)) + l1_apl_root = self._parse_root_shapes(self._get_root_line(l1_apl)) + assert sorted(l1_apl_param) == sorted(l1_param_shape + l1_param_shape) + assert sorted(l1_apl_root) == sorted(l1_param_shape) + def suite(): suite = unittest.TestSuite() From b569df723547446a102d30f35aa3f948188a5ea6 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sat, 24 Dec 2022 17:11:58 +0000 Subject: [PATCH 11/15] format --- alpa/parallel_method.py | 20 ++++++++++---------- alpa/pipeline_parallel/compile_executable.py | 4 ++-- alpa/shard_parallel/manual_sharding.py | 8 +++----- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index dd257567c..3b92902b9 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -182,16 +182,16 @@ class PipeshardParallel(ParallelMethod): """ def __init__( - self, - devices: Optional[VirtualPhysicalMesh] = None, - num_micro_batches: int = 1, - default_auto_sharding_option: Optional[AutoShardingOption] = None, - pipeline_schedule: str = "1f1b", - layer_option: Optional[Union[LayerOption, str]] = None, - stage_option: Optional[Union[StageOption, str]] = None, - stage_input_shardings: Optional[Sequence[Sequence[ - pxla.ShardingSpec]]] = None, - manual_sharding_option: ManualShardingOption = None): + self, + devices: Optional[VirtualPhysicalMesh] = None, + num_micro_batches: int = 1, + default_auto_sharding_option: Optional[AutoShardingOption] = None, + pipeline_schedule: str = "1f1b", + layer_option: Optional[Union[LayerOption, str]] = None, + stage_option: Optional[Union[StageOption, str]] = None, + 
stage_input_shardings: Optional[Sequence[Sequence[ + pxla.ShardingSpec]]] = None, + manual_sharding_option: ManualShardingOption = None): self.devices = devices self.num_micro_batches = num_micro_batches self.as_option = (default_auto_sharding_option or diff --git a/alpa/pipeline_parallel/compile_executable.py b/alpa/pipeline_parallel/compile_executable.py index c98d08be3..8e244072e 100644 --- a/alpa/pipeline_parallel/compile_executable.py +++ b/alpa/pipeline_parallel/compile_executable.py @@ -105,8 +105,8 @@ def compile_pipeshard_executable( out_tree = out_tree_thunk() if manual_shard_options is not None: assert global_input_shardings is None - parsed_ms_option = get_flatten_axis_resources( - manual_shard_options, in_tree, out_tree) + parsed_ms_option = get_flatten_axis_resources(manual_shard_options, + in_tree, out_tree) else: parsed_ms_option = None pipeshard_config = compile_pipeshard_executable_internal( diff --git a/alpa/shard_parallel/manual_sharding.py b/alpa/shard_parallel/manual_sharding.py index 43386ad93..afd8e7304 100644 --- a/alpa/shard_parallel/manual_sharding.py +++ b/alpa/shard_parallel/manual_sharding.py @@ -102,9 +102,8 @@ def _prepare_axis_and_flatten(axis_resources, tree, name): return axis_flat -def get_flatten_axis_resources( - sharding_option: ManualShardingOption, in_tree, - out_tree) -> ParsedManualShardingOption: +def get_flatten_axis_resources(sharding_option: ManualShardingOption, in_tree, + out_tree) -> ParsedManualShardingOption: """Flatten axis resources for pipeline parallel to dispatch.""" if sharding_option is None: return None @@ -133,8 +132,7 @@ def parsed_spec_to_opsharding(axes, avals, mesh_shape, mesh_axis_names): return None named_mesh_shape = OrderedDict( - (name, size) - for name, size in safe_zip(mesh_axis_names, mesh_shape)) + (name, size) for name, size in safe_zip(mesh_axis_names, mesh_shape)) op_shardings = tuple( _parsed_pspec_to_hlo_sharding(named_mesh_shape, mesh_axis_names, axis, len(aval.shape)) From 727703da24e60f4599bb2b5909dbc6ed882d3e93 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Tue, 27 Dec 2022 03:19:23 +0000 Subject: [PATCH 12/15] move tools into testing.py --- alpa/testing.py | 31 ++++++++++ .../pipeline_parallel/test_manual_sharding.py | 60 +++++++------------ tests/shard_parallel/test_manual.py | 54 +++++------------ 3 files changed, 66 insertions(+), 79 deletions(-) diff --git a/alpa/testing.py b/alpa/testing.py index dc5cdd0f9..b0a8d5396 100644 --- a/alpa/testing.py +++ b/alpa/testing.py @@ -361,3 +361,34 @@ def data_loader_input_iter_func(start, end, batch_size): for i in range(num_batches): idx = start + i * batch_size yield dataset_x[idx:idx + batch_size], dataset_y[idx:idx + batch_size] + + +class HloParser: + + @staticmethod + def get_param_line(text: str): + text = text[text.find("ENTRY"):] + text = text[:text.find("\n")] + return text + + @staticmethod + def get_root_line(text: str): + text = text[text.find("ENTRY"):] + text = text[text.find("ROOT"):] + text = text[:text.find("\n")] + return text + + @staticmethod + def parse_param_shapes(text: str): + # the first one is "ENTRY %xxx (" + params = text.split("param")[1:] + shapes = tuple(map(lambda x: x[x.find("f32"):x.find("]") + 1], params)) + return shapes + + @staticmethod + def parse_root_shapes(text: str): + tuple_shape = text[text.find("=") + 2:text.find("tuple(")] + # the last one is ')' + shapes = tuple_shape.split("0}")[:-1] + shapes = tuple(map(lambda x: x[x.find("f32"):x.find("{")], shapes)) + return shapes diff --git 
a/tests/pipeline_parallel/test_manual_sharding.py b/tests/pipeline_parallel/test_manual_sharding.py index 2d7d74bb7..044bab022 100644 --- a/tests/pipeline_parallel/test_manual_sharding.py +++ b/tests/pipeline_parallel/test_manual_sharding.py @@ -12,6 +12,7 @@ import alpa from alpa import (AutoShardingOption, ManualShardingOption, ManualStageOption, PipeshardParallel, mark_pipeline_boundary, parallelize) +from alpa.testing import HloParser class PipeshardManualShardingTest(unittest.TestCase): @@ -35,35 +36,6 @@ def _get_fn_manual_sharding_with(self, fn, num_microbatches, stage_option, parallelized = parallelize(fn, method=method) return parallelized.get_executable(*args).get_hlo_text() - @staticmethod - def _get_param_line(text: str): - text = text[text.find("ENTRY"):] - text = text[:text.find("\n")] - return text - - @staticmethod - def _get_root_line(text: str): - text = text[text.find("ENTRY"):] - text = text[text.find("ROOT"):] - text = text[:text.find("\n")] - return text - - @staticmethod - def _parse_param_shapes(text: str): - # the first one is "ENTRY %xxx (" - params = text.split("param")[1:] - shapes = tuple( - map(lambda x: x[x.find(": ") + 2:x.find("]") + 1], params)) - return shapes - - @staticmethod - def _parse_root_shapes(text: str): - tuple_shape = text[text.find("=") + 2:text.find("tuple(")] - # the last one is ')' - shapes = tuple_shape.split("0}")[:-1] - shapes = tuple(map(lambda x: x[x.find("f32"):x.find("{")], shapes)) - return shapes - @staticmethod def _is_superset_with_x_more(seq1, seq2, x): set1 = set(seq1) @@ -126,31 +98,41 @@ def loss_fn(params): # layer 0 l0_param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") l0_batch_shape = ("f32[32,6]",) - l0_fwd_param = self._parse_param_shapes(self._get_param_line(l0_fwd)) + l0_fwd_param = HloParser.parse_param_shapes( + HloParser.get_param_line(l0_fwd)) assert sorted(l0_fwd_param) == sorted(l0_param_shape + l0_batch_shape) - l0_bwd_param = self._parse_param_shapes(self._get_param_line(l0_bwd)) - l0_bwd_root = self._parse_root_shapes(self._get_root_line(l0_bwd)) + l0_bwd_param = HloParser.parse_param_shapes( + HloParser.get_param_line(l0_bwd)) + l0_bwd_root = HloParser.parse_root_shapes( + HloParser.get_root_line(l0_bwd)) # the donated accumulated gradient are at first assert sorted(l0_bwd_param[:4]) == sorted(l0_param_shape) assert sorted(l0_bwd_root) == sorted(l0_param_shape) - l0_apl_param = self._parse_param_shapes(self._get_param_line(l0_apl)) - l0_apl_root = self._parse_root_shapes(self._get_root_line(l0_apl)) + l0_apl_param = HloParser.parse_param_shapes( + HloParser.get_param_line(l0_apl)) + l0_apl_root = HloParser.parse_root_shapes( + HloParser.get_root_line(l0_apl)) assert sorted(l0_apl_param) == sorted(l0_param_shape + l0_param_shape) assert sorted(l0_apl_root) == sorted(l0_param_shape) # layer 1 l1_param_shape = ("f32[10,12]", "f32[12]", "f32[12,14]", "f32[14]") l1_batch_shape = ("f32[16,14]",) - l1_fwd_param = self._parse_param_shapes(self._get_param_line(l1_fwd)) + l1_fwd_param = HloParser.parse_param_shapes( + HloParser.get_param_line(l1_fwd)) assert self._is_superset_with_x_more(l1_fwd_param, l1_param_shape + l1_batch_shape, 1) - l1_bwd_param = self._parse_param_shapes(self._get_param_line(l1_bwd)) - l1_bwd_root = self._parse_root_shapes(self._get_root_line(l1_bwd)) + l1_bwd_param = HloParser.parse_param_shapes( + HloParser.get_param_line(l1_bwd)) + l1_bwd_root = HloParser.parse_root_shapes( + HloParser.get_root_line(l1_bwd)) # the donated accumulated gradient are at first assert 
sorted(l1_bwd_param[:4]) == sorted(l1_param_shape) assert self._is_superset_with_x_more(l1_bwd_root, l1_param_shape, 1) - l1_apl_param = self._parse_param_shapes(self._get_param_line(l1_apl)) - l1_apl_root = self._parse_root_shapes(self._get_root_line(l1_apl)) + l1_apl_param = HloParser.parse_param_shapes( + HloParser.get_param_line(l1_apl)) + l1_apl_root = HloParser.parse_root_shapes( + HloParser.get_root_line(l1_apl)) assert sorted(l1_apl_param) == sorted(l1_param_shape + l1_param_shape) assert sorted(l1_apl_root) == sorted(l1_param_shape) diff --git a/tests/shard_parallel/test_manual.py b/tests/shard_parallel/test_manual.py index 34e477c2b..673bb5647 100644 --- a/tests/shard_parallel/test_manual.py +++ b/tests/shard_parallel/test_manual.py @@ -11,6 +11,7 @@ import alpa from alpa import (AutoShardingOption, LocalPhysicalDeviceMesh, ManualShardingOption, ShardParallel, parallelize) +from alpa.testing import HloParser class ManualShardingTest(unittest.TestCase): @@ -38,34 +39,6 @@ def _get_fn_manual_sharding_with(self, batch_argnums=batch_argnums) return parallelized.get_executable(*args).get_hlo_text() - @staticmethod - def _get_param_line(text: str): - text = text[text.find("ENTRY"):] - text = text[:text.find("\n")] - return text - - @staticmethod - def _get_root_line(text: str): - text = text[text.find("ENTRY"):] - text = text[text.find("ROOT"):] - text = text[:text.find("\n")] - return text - - @staticmethod - def _parse_param_shapes(text: str): - # the first one is "ENTRY %xxx (" - params = text.split("param")[1:] - shapes = tuple(map(lambda x: x[x.find("f32"):x.find("]") + 1], params)) - return shapes - - @staticmethod - def _parse_root_shapes(text: str): - tuple_shape = text[text.find("=") + 2:text.find("tuple(")] - # the last one is ')' - shapes = tuple_shape.split("0}")[:-1] - shapes = tuple(map(lambda x: x[x.find("f32"):x.find("{")], shapes)) - return shapes - def test_set_input(self): def fn(a, b): @@ -78,20 +51,20 @@ def fn(a, b): ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) - text = self._get_param_line(text) + text = HloParser.get_param_line(text) assert "param: f32[6,2]" in text and "param.1: f32[6,2]" in text in_axis_resources = (PartitionSpec("data", None), PartitionSpec("data", "model")) ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) - text = self._get_param_line(text) + text = HloParser.get_param_line(text) assert "param: f32[3,4]" in text and "param.1: f32[3,2]" in text in_axis_resources = (None, PartitionSpec("data", None)) ms_option = ManualShardingOption(self.mesh_axis_names, in_axis_resources=in_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a, b) - text = self._get_param_line(text) + text = HloParser.get_param_line(text) assert "param: f32[6,4]" in text and "param.1: f32[3,4]" in text def test_set_output(self): @@ -106,7 +79,7 @@ def fn(a): ms_option = ManualShardingOption(self.mesh_axis_names, out_axis_resources=out_axis_resources) text = self._get_fn_manual_sharding_with(fn, ms_option, a) - text = self._get_root_line(text) + text = HloParser.get_root_line(text) assert ("(f32[3,4]{1,0}, f32[6,4]{1,0}, f32[6,2]{1,0}, f32[3,2]{1,0}" in text) @@ -148,10 +121,10 @@ def loss_fn(params): acc_grad_text = text[:apply_grad_start] apply_grad_text = text[apply_grad_start:] # 1. 
Accumulate grad: - acc_grad_params = self._get_param_line(acc_grad_text) - acc_grad_param_shapes = self._parse_param_shapes(acc_grad_params) - acc_grad_root = self._get_root_line(acc_grad_text) - acc_grad_root_shapes = self._parse_root_shapes(acc_grad_root) + acc_grad_params = HloParser.get_param_line(acc_grad_text) + acc_grad_param_shapes = HloParser.parse_param_shapes(acc_grad_params) + acc_grad_root = HloParser.get_root_line(acc_grad_text) + acc_grad_root_shapes = HloParser.parse_root_shapes(acc_grad_root) param_shape = ("f32[6,4]", "f32[4]", "f32[4,10]", "f32[10]") # batch_size / num_microbatches / data_parallel @@ -159,10 +132,11 @@ def loss_fn(params): assert acc_grad_param_shapes == param_shape + batch_shape + param_shape assert acc_grad_root_shapes == param_shape # 2. Apply grad: - apply_grad_params = self._get_param_line(apply_grad_text) - apply_grad_param_shapes = self._parse_param_shapes(apply_grad_params) - apply_grad_root = self._get_root_line(apply_grad_text) - apply_grad_root_shapes = self._parse_root_shapes(apply_grad_root) + apply_grad_params = HloParser.get_param_line(apply_grad_text) + apply_grad_param_shapes = HloParser.parse_param_shapes( + apply_grad_params) + apply_grad_root = HloParser.get_root_line(apply_grad_text) + apply_grad_root_shapes = HloParser.parse_root_shapes(apply_grad_root) assert apply_grad_param_shapes == param_shape + param_shape assert apply_grad_root_shapes == param_shape From 1245c1e9d4e8b7ec4df6f3f4ff1ac4a6ba2d8733 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Tue, 27 Dec 2022 03:30:03 +0000 Subject: [PATCH 13/15] update doc and fix typo --- alpa/mesh_executable.py | 2 +- alpa/pipeline_parallel/pipeshard_executable.py | 6 +++--- alpa/testing.py | 5 ++++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/alpa/mesh_executable.py b/alpa/mesh_executable.py index 31a6e75f1..b12452802 100644 --- a/alpa/mesh_executable.py +++ b/alpa/mesh_executable.py @@ -1106,7 +1106,7 @@ def __del__(self): class UtilMeshWorkerExecutable(MeshWorkerExecutable): """Worker executable that runs a manually generated function. It is lighter - than NoralMeshWorkerExecutable as it does not have a StagePlan. + than NormalMeshWorkerExecutable as it does not have a StagePlan. Currently, it is used for concatenate(will be deprecated after we move it to apply_grad) and allgather. diff --git a/alpa/pipeline_parallel/pipeshard_executable.py b/alpa/pipeline_parallel/pipeshard_executable.py index db9750bcb..f61ecdcc6 100644 --- a/alpa/pipeline_parallel/pipeshard_executable.py +++ b/alpa/pipeline_parallel/pipeshard_executable.py @@ -104,7 +104,7 @@ def __init__(self, task.create_resharding_communicators() self.exec_uuid = next_mesh_executable_uuid() - # Create a PipeshardMeshWorkerExecuable for each MeshHostWorker + # Create a PipeshardMeshWorkerExecutable for each MeshHostWorker for mesh_idx, physical_mesh in enumerate(self.mesh_group): mesh_grad_uuids = pipeshard_config.grad_uuids[mesh_idx] for worker in physical_mesh.workers: @@ -119,7 +119,7 @@ def __init__(self, pipeshard_config.reduced_var_uuid_lists[mesh_idx], self.donate_invars[mesh_idx]) worker.put_executable.remote(self.exec_uuid, - PipeshardMeshWorkerExecuable, + PipeshardMeshWorkerExecutable, *args) ##### Compilation Related Functions ##### @@ -425,7 +425,7 @@ def __del__(self): mesh.delete_remote_executable(self.exec_uuid) -class PipeshardMeshWorkerExecuable: +class PipeshardMeshWorkerExecutable: """ An executable that executes static pipeline runtime instructions on a worker. 
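For reference, the HLO-parsing helpers that the previous patch consolidated into alpa/testing.py can be exercised on a toy dump roughly as below. This is an illustrative sketch, not part of the patch series: the HLO text is a hand-written stand-in rather than real compiler output, mimicking only the ENTRY/param/ROOT fields the parser scans for.

from alpa.testing import HloParser

# Toy HLO module; only the "ENTRY", "param: f32[...]" and "ROOT ... tuple("
# patterns that HloParser looks for are reproduced here.
hlo_text = """HloModule toy
ENTRY %toy (param: f32[3,4], param.1: f32[4,2]) -> (f32[3,2]) {
  ROOT %out = (f32[3,2]{1,0}) tuple(%dot)
}"""

param_line = HloParser.get_param_line(hlo_text)
print(HloParser.parse_param_shapes(param_line))  # ('f32[3,4]', 'f32[4,2]')
root_line = HloParser.get_root_line(hlo_text)
print(HloParser.parse_root_shapes(root_line))    # ('f32[3,2]',)
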
diff --git a/alpa/testing.py b/alpa/testing.py index b0a8d5396..89ae41372 100644 --- a/alpa/testing.py +++ b/alpa/testing.py @@ -364,7 +364,10 @@ def data_loader_input_iter_func(start, end, batch_size): class HloParser: - + """ + Parse Hlo text to check whether the parameter and output has correct + sharding. + """ @staticmethod def get_param_line(text: str): text = text[text.find("ENTRY"):] From eee81bd29da28a7b2f327bf38f1c92b96e282679 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Tue, 27 Dec 2022 04:16:13 +0000 Subject: [PATCH 14/15] format... --- alpa/testing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/alpa/testing.py b/alpa/testing.py index 89ae41372..722a66a9f 100644 --- a/alpa/testing.py +++ b/alpa/testing.py @@ -368,6 +368,7 @@ class HloParser: Parse Hlo text to check whether the parameter and output has correct sharding. """ + @staticmethod def get_param_line(text: str): text = text[text.find("ENTRY"):] From 25a066e37a0925bf808159816189ada630c118bb Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Tue, 27 Dec 2022 04:51:13 +0000 Subject: [PATCH 15/15] fix more typo --- .../alpa/resharding/benchmark_cross_mesh_resharding.py | 6 +++--- tests/pipeline_parallel/test_cross_mesh_resharding.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py b/benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py index 2a37b90f6..ea8ce021a 100644 --- a/benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py +++ b/benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py @@ -20,7 +20,7 @@ SymbolicReshardingTask, SymbolicBroadcastReshardingTask) from alpa.pipeline_parallel.pipeshard_executable import ( AllocateZeroWorkerExecutableConfig, PipelineInstruction, - PipeshardMeshWorkerExecuable) + PipeshardMeshWorkerExecutable) from alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray from alpa.util import get_shard_shape from alpa.timer import timers @@ -179,7 +179,7 @@ def benchmark_one_case_internal( for worker in src_mesh.workers: exec_uuid = next_mesh_executable_uuid() # print(worker, exec_uuid) - worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecuable, + worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [src_uuid], [], [], [], [], [False] * src_mesh.num_devices_per_host) @@ -187,7 +187,7 @@ def benchmark_one_case_internal( for worker in dst_mesh.workers: exec_uuid = next_mesh_executable_uuid() # print(worker, exec_uuid) - worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecuable, + worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [], [dst_uuid], executable_config_lists[worker], [], [], [False] * dst_mesh.num_devices_per_host) diff --git a/tests/pipeline_parallel/test_cross_mesh_resharding.py b/tests/pipeline_parallel/test_cross_mesh_resharding.py index d118a660d..f2b8c2c6b 100644 --- a/tests/pipeline_parallel/test_cross_mesh_resharding.py +++ b/tests/pipeline_parallel/test_cross_mesh_resharding.py @@ -21,7 +21,7 @@ SymbolicReshardingTask, SymbolicBroadcastReshardingTask) from alpa.pipeline_parallel.pipeshard_executable import ( AllocateZeroWorkerExecutableConfig, PipelineInstruction, - PipeshardMeshWorkerExecuable) + PipeshardMeshWorkerExecutable) from alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray from alpa.testing import assert_allclose from alpa.util import get_shard_shape @@ -114,14 +114,14 @@ def test_resharding(var, # Compile Pipeline 
Executable for worker in src_mesh.workers: exec_uuid = next_mesh_executable_uuid() - worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecuable, + worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [src_uuid], [], [], [], [], [False] * src_mesh.num_devices_per_host) exec_uuids[worker] = exec_uuid for worker in dst_mesh.workers: exec_uuid = next_mesh_executable_uuid() - worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecuable, + worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable, instruction_lists[worker], [], [dst_uuid], executable_config_lists[worker], [], [], [False] * dst_mesh.num_devices_per_host)
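
Taken together, the options added in this series can be driven end to end roughly as follows. This is a minimal sketch assuming a Ray cluster with one host and four GPUs; the two-layer model, shapes, and axis names are illustrative and adapted from the new pipeshard test rather than taken from it verbatim.

import alpa
import jax
import jax.numpy as jnp
from jax.experimental.pjit import PartitionSpec
from jax.tree_util import tree_map
from alpa import (AutoShardingOption, ManualShardingOption, ManualStageOption,
                  PipeshardParallel, mark_pipeline_boundary, parallelize)

alpa.init()  # assumes a running Ray cluster with 1 host x 4 GPUs
alpa.set_global_virtual_physical_mesh(
    alpa.get_global_cluster().get_virtual_physical_mesh([0], 4))


def train_step(params, batch):
    x, tgt = batch

    def loss_fn(params):
        w0, b0, w1, b1 = params
        y = jax.nn.relu(x @ w0 + b0)
        mark_pipeline_boundary()  # split into stage 0 | stage 1
        z = jax.nn.softmax(y @ w1 + b1)
        return jnp.mean((z - tgt)**2)

    grads = alpa.grad(loss_fn)(params)
    return tree_map(lambda p, g: p - g, params, grads)


params = (jnp.ones((6, 8)), jnp.ones((8,)), jnp.ones((8, 10)), jnp.ones((10,)))
batch = (jnp.ones((64, 6)), jnp.ones((64, 10)))

# Two manual stages, each placed on a (1, 2) submesh of the 1x4 cluster.
stage_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2)] * 2,
                                 [{}] * 2)
# Stage 0 names its submesh axes ("dummy", "model") and stage 1 uses
# ("dummy", "data"), so the stage-0 weights are model-parallel while the
# stage-1 target is data-parallel; None means replicated, as in pjit.
ms_option = ManualShardingOption(
    submesh_axis_names=(("dummy", "model"), ("dummy", "data")),
    in_axis_resources=((PartitionSpec(None, "model"), PartitionSpec("model"),
                        None, None),
                       (None, PartitionSpec("data", None))))

method = PipeshardParallel(
    num_micro_batches=2,
    stage_option=stage_option,
    manual_sharding_option=ms_option,
    default_auto_sharding_option=AutoShardingOption(False))
executable = parallelize(train_step, method=method).get_executable(
    params, batch)
print(executable.get_hlo_text())  # one HLO text per compiled stage module
alpa.shutdown()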