This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[RFC][FEATURE] support manual parallelization strategy in shard parallel #816

Merged 18 commits on Dec 27, 2022
1 change: 1 addition & 0 deletions alpa/__init__.py
@@ -45,6 +45,7 @@
AutoStageOption,
UniformStageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from alpa.shard_parallel.manual_sharding import ManualShardingOption
from alpa.serialization import save_checkpoint, restore_checkpoint
from alpa.timer import timers
from alpa.version import __version__
2 changes: 1 addition & 1 deletion alpa/create_state_parallel.py
@@ -138,7 +138,7 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk,
new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals),
executable.mesh_group.parent, 1, "inference",
AutoShardingOption(enable_auto_sharding=False),
UniformStageOption(), name, None, output_shardings, None)
UniformStageOption(), name, None, output_shardings, None, None)

return CreateStateExecutable(mesh_group=executable.mesh_group,
pipeshard_config=pipeshard_config,
2 changes: 1 addition & 1 deletion alpa/follow_parallel.py
@@ -88,4 +88,4 @@ def is_leave(x):
fun, in_tree, out_tree_thunk, static_argnums, donated_invars,
batch_invars, mesh, num_micro_batches, pipeline_schedule,
AutoShardingOption(enable_auto_sharding=False), layer_option,
UniformStageOption(), input_shardings, None, *avals)
UniformStageOption(), input_shardings, None, None, *avals)
2 changes: 1 addition & 1 deletion alpa/mesh_executable.py
@@ -1106,7 +1106,7 @@ def __del__(self):

class UtilMeshWorkerExecutable(MeshWorkerExecutable):
"""Worker executable that runs a manually generated function. It is lighter
than NoralMeshWorkerExecutable as it does not have a StagePlan.
than NormalMeshWorkerExecutable as it does not have a StagePlan.

Currently, it is used for concatenate (will be deprecated after we move it
to apply_grad) and allgather.
31 changes: 19 additions & 12 deletions alpa/parallel_method.py
@@ -38,6 +38,7 @@
UniformStageOption)
from alpa.shard_parallel.auto_sharding import AutoShardingOption, LogicalDeviceMesh
from alpa.shard_parallel.compile_executable import compile_shard_executable
from alpa.shard_parallel.manual_sharding import ManualShardingOption

traceback_util.register_exclusion(__file__)

@@ -75,10 +76,12 @@ def __init__(self,
devices: Optional[Union[LogicalDeviceMesh,
PhysicalDeviceMesh]] = None,
num_micro_batches: Optional[int] = None,
auto_sharding_option: Optional[AutoShardingOption] = None):
auto_sharding_option: Optional[AutoShardingOption] = None,
manual_sharding_option: Optional[ManualShardingOption] = None):
self.devices = devices
self.num_micro_batches = num_micro_batches
self.as_option = auto_sharding_option or AutoShardingOption()
self.ms_option = manual_sharding_option

def compile_executable(
self,
@@ -106,7 +109,7 @@ def compile_executable(
static_argnums, donated_invars,
batch_invars, mesh,
self.num_micro_batches, self.as_option,
*avals)
self.ms_option, *avals)


class DataParallel(ShardParallel):
@@ -179,15 +182,16 @@ class PipeshardParallel(ParallelMethod):
"""

def __init__(
self,
devices: Optional[VirtualPhysicalMesh] = None,
num_micro_batches: int = 1,
default_auto_sharding_option: Optional[AutoShardingOption] = None,
pipeline_schedule: str = "1f1b",
layer_option: Optional[Union[LayerOption, str]] = None,
stage_option: Optional[Union[StageOption, str]] = None,
stage_input_shardings: Optional[Sequence[Sequence[
pxla.ShardingSpec]]] = None):
self,
devices: Optional[VirtualPhysicalMesh] = None,
num_micro_batches: int = 1,
default_auto_sharding_option: Optional[AutoShardingOption] = None,
pipeline_schedule: str = "1f1b",
layer_option: Optional[Union[LayerOption, str]] = None,
stage_option: Optional[Union[StageOption, str]] = None,
stage_input_shardings: Optional[Sequence[Sequence[
pxla.ShardingSpec]]] = None,
manual_sharding_option: ManualShardingOption = None):
self.devices = devices
self.num_micro_batches = num_micro_batches
self.as_option = (default_auto_sharding_option or
@@ -209,6 +213,9 @@ def __init__(
stage_option = UniformStageOption()
self.stage_option = stage_option or UniformStageOption()
self.stage_input_shardings = stage_input_shardings
assert not (stage_input_shardings is not None and
manual_sharding_option is not None)
self.manual_sharding_option = manual_sharding_option

def compile_executable(
self,
@@ -234,7 +241,7 @@
fun, in_tree, out_tree_thunk, static_argnums, donated_invars,
batch_invars, mesh, self.num_micro_batches, self.pipeline_schedule,
self.as_option, self.layer_option, self.stage_option, None,
self.stage_input_shardings, *avals)
self.stage_input_shardings, self.manual_sharding_option, *avals)


def get_3d_parallel_method(num_micro_batches: int,
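For readers skimming the diff: the sketch below shows how the new keyword might be passed once this lands. Only the manual_sharding_option argument, the ManualShardingOption class, and its top-level export come from this PR; the option's field names and the PartitionSpec import path are illustrative assumptions.

# Hypothetical usage sketch. Field names on ManualShardingOption are
# assumptions, not taken from this PR.
import alpa
from alpa import PipeshardParallel, ManualShardingOption
from jax.experimental.pjit import PartitionSpec as P  # import path varies by JAX version

method = PipeshardParallel(
    num_micro_batches=16,
    # pjit-style constraints on the flattened global inputs/outputs
    # (assumed field names):
    manual_sharding_option=ManualShardingOption(
        mesh_axis_names=("data", "model"),
        in_axis_resources=(P("data", None), P(None, "model")),
        out_axis_resources=P("data", None),
    ),
)

def train_step(state, batch):
    ...  # model code elided

p_train_step = alpa.parallelize(train_step, method=method)

Per the new assertion in PipeshardParallel.__init__, stage_input_shardings must stay unset whenever a manual_sharding_option is supplied.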
110 changes: 104 additions & 6 deletions alpa/pipeline_parallel/compile_executable.py
@@ -5,6 +5,7 @@
from typing import Callable, Sequence, Optional

from jax import linear_util as lu
from jax._src.lib import xla_client as xc
from jax.core import gensym, AbstractValue, ClosedJaxpr
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef
@@ -32,7 +33,12 @@
from alpa.pipeline_parallel.stage_construction import (
cluster_layers_and_slice_mesh, StageOption)
from alpa.pipeline_parallel.stage_profiling import CompileWorkerPool
from alpa.shard_parallel.auto_sharding import AutoShardingOption
from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
hlo_sharding_to_sharding_spec)
from alpa.shard_parallel.manual_sharding import (ManualShardingOption,
ParsedManualShardingOption,
get_flatten_axis_resources,
parsed_spec_to_opsharding)
from alpa.util import (get_var_mapping, trace_jaxpr_with_micro_batch,
OrderedSet, GradFuncTransformContext)

@@ -49,6 +55,7 @@ def compile_pipeshard_executable(
layer_option: LayerOption, stage_option: StageOption,
global_input_shardings: Optional[Sequence[pxla.ShardingSpec]],
stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]],
manual_shard_options: Optional[ManualShardingOption],
*avals: Sequence[AbstractValue]):
"""
Compile a callable for pipeshard parallel which combines
Expand All @@ -60,6 +67,8 @@ def compile_pipeshard_executable(
input vars.
stage_input_shardings: Forcibly set sharding specs of input vars of
each stage.
manual_shard_options: pjit-style sharding constraints of the global
input vars.
"""
if global_config.backend == "tpu":
raise NotImplementedError("Pipeshard Parallel for tpu is not supported")
@@ -92,19 +101,27 @@
fun.f = f_backup
debug_compilation_time("trace")

# flatten manual sharding axis resources
out_tree = out_tree_thunk()
if manual_shard_options is not None:
assert global_input_shardings is None
parsed_ms_option = get_flatten_axis_resources(manual_shard_options,
in_tree, out_tree)
else:
parsed_ms_option = None
pipeshard_config = compile_pipeshard_executable_internal(
closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars,
batch_invars, virtual_mesh, num_microbatch, pipeline_schedule,
default_as_option, stage_option, name_base, global_input_shardings,
None, stage_input_shardings)
None, stage_input_shardings, parsed_ms_option)

executable = PipeshardDriverExecutable(
mesh_group=virtual_mesh.launched_physical_mesh_group,
pipeshard_config=pipeshard_config,
num_batch=num_microbatch,
layer_option=layer_option,
in_tree=in_tree,
out_tree=out_tree_thunk(),
out_tree=out_tree,
static_argnums=static_argnums)
debug_compilation_time("driver executable")
return executable
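A side note on the get_flatten_axis_resources call above: the user's pjit-style axis resources form a pytree that must be flattened in the same order as the function's flattened inputs before they can be zipped with global_invars. A toy illustration of that ordering, using plain strings in place of real PartitionSpecs (the pytree below is made up and is not Alpa code):

import jax
import jax.numpy as jnp

# A made-up input pytree and a matching pytree of stand-in "specs".
params = {"w": jnp.ones((8, 4)), "b": jnp.ones((4,))}
in_axis_resources = {"w": "P('data', 'model')", "b": "P(None)"}

flat_args, in_tree = jax.tree_util.tree_flatten(params)
# Flatten the resources against the same tree structure so that
# flat_specs[i] annotates flat_args[i].
flat_specs = in_tree.flatten_up_to(in_axis_resources)
print(flat_specs)  # dict keys are visited in sorted order: 'b', then 'w'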
@@ -119,7 +136,8 @@ def compile_pipeshard_executable_internal(
stage_option: StageOption, name_base: str,
global_input_shardings: Optional[Sequence[pxla.ShardingSpec]],
global_output_shardings: Optional[Sequence[pxla.ShardingSpec]],
stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]]):
stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]],
parsed_manual_sharding_option: Optional[ParsedManualShardingOption]):
"""
Args:
fun: The function to be parallelized.
@@ -207,6 +225,8 @@
raise ValueError(f"Invalid schedule: {pipeline_schedule}")

# Forcibly set the sharding specs of global invars and outvars.
# FIXME(yonghao): an invar can appear on multiple meshes and thus have
# different sharding specs.
if global_input_shardings:
assert len(global_input_shardings) == len(global_invars)
input_sharding_dict = dict(zip(global_invars, global_input_shardings))
@@ -218,6 +238,17 @@
global_output_shardings))
else:
output_sharding_dict = {}
if parsed_manual_sharding_option is not None:
assert (global_input_shardings is None and
global_output_shardings is None)
(input_sharding_dicts,
output_sharding_dicts) = get_manual_input_output_sharding_specs(
jax_all_stages, [mesh.shape for mesh in sliced_virtual_meshes],
parsed_manual_sharding_option, global_invars, global_outvars,
schedule.stage_mesh_mapping)
else:
input_sharding_dicts = [input_sharding_dict] * num_meshes
output_sharding_dicts = [output_sharding_dict] * num_meshes

# Call auto-sharding pass to shard each stage
xla_stages, total_flops = shard_each_stage(
@@ -226,7 +257,7 @@
donate_invars_dict, num_microbatch,
manual_stage_option.submesh_logical_shapes,
manual_stage_option.submesh_autosharding_option_dicts,
default_as_option, input_sharding_dict, output_sharding_dict,
default_as_option, input_sharding_dicts, output_sharding_dicts,
stage_input_shardings, name_base, gensym_func)
total_flops *= num_microbatch
debug_compilation_time("shard stages")
@@ -320,11 +351,76 @@ def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr,
accumulator_mapping, acc_grad_invars, acc_grad_outvars)


def get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option,
global_invars, global_outvars,
stage_to_mesh):
"""
Split the user-assigned input and output PartitionSpecs into sharding specs
for each pipeline stage.
"""
invar_set = set(global_invars)
outvar_set = set(global_outvars)
var_to_pspec = {}
handle_invar = False
handle_outvar = False
if ms_option.in_parsed_pspec is not None:
var_to_pspec.update(dict(zip(global_invars, ms_option.in_parsed_pspec)))
handle_invar = True
if ms_option.out_parsed_pspec is not None:
var_to_pspec.update(
dict(zip(global_outvars, ms_option.out_parsed_pspec)))
handle_outvar = True
submesh_axis_names = ms_option.submesh_axis_names
if submesh_axis_names is None:
submesh_axis_names = [ms_option.mesh_axis_names] * len(mesh_shapes)

def get_vars_to_sharding_specs(variables, mesh_shape, mesh_axis_names):
parsed_specs = [var_to_pspec[v] for v in variables]
avals = [v.aval for v in variables]
var_op_shardings = parsed_spec_to_opsharding(parsed_specs, avals,
mesh_shape,
mesh_axis_names)
var_sharding_specs = [
hlo_sharding_to_sharding_spec(xc.HloSharding.from_proto(ops), aval,
mesh_shape)
for ops, aval in zip(var_op_shardings, avals)
]
return dict(zip(variables, var_sharding_specs))

invar_shardings = [{}] * len(mesh_shapes)
outvar_shardings = [{}] * len(mesh_shapes)
for stage_idx, stage in enumerate(stages):
mesh_idx = stage_to_mesh[stage_idx]
assert len(mesh_idx) == 1
mesh_idx = list(mesh_idx)[0]
mesh_shape = mesh_shapes[mesh_idx]
mesh_axis_names = submesh_axis_names[mesh_idx]
# invars
if handle_invar:
invar_in_global = [var for var in stage.invars if var in invar_set]
stage_invar_shardings = get_vars_to_sharding_specs(
invar_in_global, mesh_shape, mesh_axis_names)
else:
stage_invar_shardings = {}
# outvars
if handle_outvar:
outvar_in_global = [
var for var in stage.outvars if var in outvar_set
]
stage_outvar_shardings = get_vars_to_sharding_specs(
outvar_in_global, mesh_shape, mesh_axis_names)
else:
stage_outvar_shardings = {}
invar_shardings[mesh_idx].update(stage_invar_shardings)
outvar_shardings[mesh_idx].update(stage_outvar_shardings)
return invar_shardings, outvar_shardings


def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes,
accumulator_mapping, global_invars, acc_grad_outvars,
donate_invars_dict, num_microbatch, logical_mesh_shapes,
autosharding_option_dicts, default_as_option,
input_sharding_dict, output_sharding_dict,
input_sharding_dicts, output_sharding_dicts,
stage_input_shardings, name_base, gensym_func):
"""Run intra-op parallelism compilation for a stage."""
# Initialize donation mapping
@@ -362,6 +458,8 @@ def shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes,
compile_intermediate = [None] * num_meshes
total_flops = 0
for mesh_idx in range(num_meshes):
input_sharding_dict = input_sharding_dicts[mesh_idx]
output_sharding_dict = output_sharding_dicts[mesh_idx]
virtual_mesh = virtual_meshes[mesh_idx]
logical_mesh = virtual_mesh.get_logical_mesh(
logical_mesh_shapes[mesh_idx])
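To make the routing step of the new get_manual_input_output_sharding_specs easier to follow, here is a stripped-down toy version of it: every global variable carries one user-given spec, and each mesh only receives entries for the variables touched by the stages placed on it. The PartitionSpec-to-OpSharding/ShardingSpec conversion is omitted, and all names below are made up rather than Alpa APIs.

def route_specs_to_meshes(stages, stage_to_mesh, num_meshes,
                          global_invars, in_specs, global_outvars, out_specs):
    """Toy routing: map each mesh to the subset of global specs it needs."""
    var_to_spec = dict(zip(global_invars, in_specs))
    var_to_spec.update(zip(global_outvars, out_specs))
    invar_set, outvar_set = set(global_invars), set(global_outvars)
    # Build one fresh dict per mesh so updates stay per-mesh.
    in_shardings = [dict() for _ in range(num_meshes)]
    out_shardings = [dict() for _ in range(num_meshes)]
    for stage_idx, (stage_invars, stage_outvars) in enumerate(stages):
        mesh_idx = stage_to_mesh[stage_idx]
        for v in stage_invars:
            if v in invar_set:
                in_shardings[mesh_idx][v] = var_to_spec[v]
        for v in stage_outvars:
            if v in outvar_set:
                out_shardings[mesh_idx][v] = var_to_spec[v]
    return in_shardings, out_shardings

# Two stages on two meshes: "x" feeds stage 0, "w" feeds stage 1, "y" leaves stage 1.
stages = [(["x"], ["h"]), (["h", "w"], ["y"])]
print(route_specs_to_meshes(stages, {0: 0, 1: 1}, 2,
                            ["x", "w"], ["spec_x", "spec_w"],
                            ["y"], ["spec_y"]))
# -> ([{'x': 'spec_x'}, {'w': 'spec_w'}], [{}, {'y': 'spec_y'}])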
2 changes: 0 additions & 2 deletions alpa/pipeline_parallel/computation.py
@@ -741,14 +741,12 @@ def generate_sharded_xla_computations_arguments(

if input_sharding_dict:
sharding_protos = []
sharding_specs = []
for x in invars:
spec = input_sharding_dict.get(x, None)
if spec is None:
sharding_protos.append(undefined_sharding_spec_proto())
else:
sharding_protos.append(spec.sharding_proto())
sharding_specs.append(spec)
hlo.set_input_shardings(sharding_protos)

if output_sharding_dict:
6 changes: 3 additions & 3 deletions alpa/pipeline_parallel/pipeshard_executable.py
@@ -104,7 +104,7 @@ def __init__(self,
task.create_resharding_communicators()

self.exec_uuid = next_mesh_executable_uuid()
# Create a PipeshardMeshWorkerExecuable for each MeshHostWorker
# Create a PipeshardMeshWorkerExecutable for each MeshHostWorker
for mesh_idx, physical_mesh in enumerate(self.mesh_group):
mesh_grad_uuids = pipeshard_config.grad_uuids[mesh_idx]
for worker in physical_mesh.workers:
@@ -119,7 +119,7 @@
pipeshard_config.reduced_var_uuid_lists[mesh_idx],
self.donate_invars[mesh_idx])
worker.put_executable.remote(self.exec_uuid,
PipeshardMeshWorkerExecuable,
PipeshardMeshWorkerExecutable,
*args)

##### Compilation Related Functions #####
@@ -425,7 +425,7 @@ def __del__(self):
mesh.delete_remote_executable(self.exec_uuid)


class PipeshardMeshWorkerExecuable:
class PipeshardMeshWorkerExecutable:
"""
An executable that executes static pipeline runtime instructions on a
worker.
5 changes: 3 additions & 2 deletions alpa/shard_parallel/auto_sharding.py
@@ -218,7 +218,8 @@ def run_auto_sharding_pass(
device_assignment=np.arange(num_devices).reshape((1, -1)),
use_spmd_partitioning=True,
parameter_is_tupled_arguments=False,
build_random_seed=build_random_seed)
build_random_seed=build_random_seed,
spmd_propagation_to_outputs=hlo.is_manually_annotated)

# Set configs for force_zero_stage_3
if as_option.force_zero_stage_3:
@@ -381,7 +382,7 @@ def run_spmd_partitioner_pass(
rewrite_grad_acc_indices: The indices of tensors in output that are
gradients.
"""
assert hlo.is_sharding_annotated(), f"{hlo.status}"
assert hlo.is_sharding_annotated(), hlo.status
compile_options = get_compile_options(
num_replicas=1,
num_partitions=num_devices,