This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

[DataLoader] Use PlacementSpecs in data loader (#581)
merrymercy authored Jul 3, 2022
1 parent 724a148 commit 4a68171
Showing 13 changed files with 164 additions and 95 deletions.
13 changes: 9 additions & 4 deletions alpa/create_state_parallel.py
@@ -6,6 +6,7 @@
from jax.core import ClosedJaxpr, Var
from jax.interpreters import partial_eval as pe, pxla
from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
import numpy as np

from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup
from alpa.mesh_executable import (NormalMeshDriverExecutable,
@@ -31,11 +32,13 @@ def __init__(self,
mesh_group: PhysicalDeviceMeshGroup,
pipeshard_config: PipeshardConfig,
target_placement_specs: Sequence[PlacementSpec],
in_tree: PyTreeDef,
out_tree: Optional[PyTreeDef] = None,
static_argnums: Optional[Sequence[int]] = None):
super().__init__(mesh_group=mesh_group,
pipeshard_config=pipeshard_config,
num_batch=1,
in_tree=in_tree,
out_tree=out_tree,
static_argnums=static_argnums)
self.target_placement_specs = target_placement_specs
@@ -58,7 +61,7 @@ def launch_on_driver(self, *args):
indices = pxla.spec_to_indices(array.shape, sharding_spec)
dis_array = self.mesh_group[mesh_id].shard_args_to_arrays(
(array.aval,), (indices,), (sharding_spec,),
(array,))[0]
(np.asarray(array),))[0]
distributed_arrays.append(dis_array)
outputs[idx] = ReplicatedDistributedArray(
meshes, distributed_arrays)
@@ -104,7 +107,8 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk,

return NormalMeshDriverExecutable(physical_mesh, hlo_module, stage_plan,
avals, out_avals,
[False] * len(avals))
[False] * len(avals), static_argnums,
in_tree, out_tree)
else:
# Construct a new pipelined jaxpr
outvars = jaxpr.outvars
@@ -126,14 +130,15 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk,

# Compile a pipeshard executable with predefined output shardings
pipeshard_config = compile_pipeshard_executable_internal(
new_jaxpr, None, 1, in_tree, [False] * len(avals),
[False] * len(avals), executable.mesh_group.parent, 1, "inference",
new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals),
executable.mesh_group.parent, 1, "inference",
AutoShardingOption(enable_auto_sharding=False),
UniformStageOption(), name, output_shardings)

return CreateStateExecutable(mesh_group=executable.mesh_group,
pipeshard_config=pipeshard_config,
target_placement_specs=placement_specs,
in_tree=in_tree,
out_tree=out_tree_thunk(),
static_argnums=static_argnums)

59 changes: 27 additions & 32 deletions alpa/data_loader.py
@@ -3,50 +3,40 @@
import itertools

import jax
from jax.interpreters import pxla, xla
from jax.interpreters import pxla
import numpy as np
import ray

from alpa.device_mesh import (LocalPhysicalDeviceMesh, DistributedArray,
from alpa.device_mesh import (DistributedArray, get_global_physical_mesh,
create_remote_array_refs)


class DataLoader:
"""A driver-only dataloader that loads data on the driver process and
sends the data to all workers."""

def __init__(self,
input_iter,
sharding_specs,
physical_mesh=None,
prefetch_size=1):
def __init__(self, input_iter, placement_specs, prefetch_size=1):
self.input_iter = input_iter
self.sharding_specs = sharding_specs
self.prefetch_size = prefetch_size

if physical_mesh is None:
self.physical_mesh = LocalPhysicalDeviceMesh()
else:
self.physical_mesh = physical_mesh
self.physical_mesh = get_global_physical_mesh()
self.avals = []
self.indices = []
self.sharding_specs = []
for ps in jax.tree_leaves(placement_specs):
assert len(ps.mesh_ids) == 1
assert ps.mesh_ids[0] == self.physical_mesh.mesh_id

self.avals.append(ps.aval)
self.sharding_specs.append(ps.sharding_specs[0])
self.indices.append(
tuple(ps.sharding_specs[0].indices(ps.aval.shape)))

self.queue = collections.deque()
self.first_iter = True
self.avals = None
self.indices = None

def enqueue(self, num_batches):
for batch in itertools.islice(self.input_iter, num_batches):
flatten_args, tree = jax.tree_flatten(batch)

# Cache meta info
if self.first_iter:
self.first_iter = False
self.avals = [xla.abstractify(a) for a in flatten_args]
self.indices = [
tuple(spec.indices(aval.shape))
for spec, aval in zip(self.sharding_specs, self.avals)
]

new_args = self.physical_mesh.shard_args_to_arrays(
self.avals, self.indices, self.sharding_specs, flatten_args)
self.queue.append(jax.tree_unflatten(tree, new_args))
@@ -112,14 +102,19 @@ def __init__(self,
batch_size,
num_samples,
input_iter_func,
avals,
sharding_specs,
physical_mesh,
placement_specs,
prefetch_size=1):
indices = [
tuple(np.ravel(spec.indices(aval.shape)))
for spec, aval in zip(sharding_specs, avals)
]
physical_mesh = get_global_physical_mesh()
avals = []
sharding_specs = []
indices = []
for ps in jax.tree_leaves(placement_specs):
avals.append(ps.aval)
assert len(ps.mesh_ids) == 1
assert ps.mesh_ids[0] == physical_mesh.mesh_id
sharding_specs.append(ps.sharding_specs[0])
indices.append(np.ravel(ps.sharding_specs[0].indices(
ps.aval.shape)))

self.uuid = next_mesh_data_loader_uuid()
self.physical_mesh = physical_mesh
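For context, a hedged usage sketch of the reworked driver-only DataLoader, which now consumes PlacementSpecs instead of raw sharding specs. The placement specs come from a compiled executable via the get_input_placement_specs() API touched by this commit; everything else (alpa.init, @alpa.parallelize, get_executable, the iteration protocol, and the placeholder names) is an assumption about the surrounding alpa API rather than part of the diff.

```python
# Hedged usage sketch, not part of this commit. `train_step`, `state`,
# `example_batch`, and `batch_iter` are illustrative placeholders.
import alpa
from alpa.data_loader import DataLoader

alpa.init(cluster="ray")

@alpa.parallelize
def train_step(state, batch):
    ...  # forward, backward, and optimizer update elided

executable = train_step.get_executable(state, example_batch)
# get_input_placement_specs() returns a pytree with the same structure as the
# inputs; here it is assumed to unpack into specs for (state, batch).
_, batch_specs = executable.get_input_placement_specs()

# Feed the driver-only data loader with the placements the executable prefers.
data_loader = DataLoader(batch_iter, batch_specs, prefetch_size=2)
for sharded_batch in data_loader:
    state = train_step(state, sharded_batch)
```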
2 changes: 1 addition & 1 deletion alpa/device_mesh.py
@@ -809,7 +809,7 @@ def __init__(self, devices: Sequence["Device"] = None):
self.devices = devices if devices is not None else xb.local_devices()
self.num_hosts = 1
self.num_devices_per_host = len(self.devices)
self.mesh_id = 0
self.mesh_id = -1
self.device_strs = []
self.operation_executables = {}

1 change: 1 addition & 0 deletions alpa/global_env.py
@@ -78,6 +78,7 @@ def __init__(self):

########## Options of logging ##########
self.print_compilation_time = False
self.print_auto_layer_stats = False


global_config = GlobalConfig()
74 changes: 59 additions & 15 deletions alpa/mesh_executable.py
@@ -61,7 +61,19 @@ def launch_on_driver(self, *args, **kwargs):
raise NotImplementedError()

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
raise NotImplementedError()

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
raise NotImplementedError()

def preshard_dynamic_args(self, *args):
@@ -160,6 +172,15 @@ def sync_func_worker():
return sync_func_worker


def wrap_to_placement_spec_tree(physical_mesh, avals, sharding_specs, pytree):
"""Wrap avals and sharding specs to a pytree of placement specs."""
placement_specs = [
PlacementSpec(aval, (physical_mesh.mesh_id,), (sharding_spec,))
for aval, sharding_spec in zip(avals, sharding_specs)
]
return tree_unflatten(pytree, placement_specs)


class NormalMeshDriverExecutable(MeshDriverExecutable):
"""The driver part of a normal mesh executable."""

@@ -284,13 +305,24 @@ def launch_on_driver(self, *args, **kwargs):
return self.outs_handler(output_bufs)

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
placement_specs = [
PlacementSpec(aval, (self.physical_mesh.mesh_id,),
(sharding_spec,)) for aval, sharding_spec in zip(
self.avals, self.input_sharding_specs)
]
return tree_unflatten(self.in_tree, placement_specs)
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.avals,
self.input_sharding_specs,
self.in_tree)

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals,
self.output_sharding_specs,
self.out_tree)

def preshard_dynamic_args(self, *args):
"""Pre-shard the input arguments."""
@@ -505,6 +537,7 @@ def __init__(self,
out_avals,
physical_mesh.num_devices,
logical_mesh_shape))
self.output_sharding_specs = output_sharding_specs
num_grads = len(grad_avals)
assert accumulate_grad_input_sharding_specs[
-num_grads:] == grad_sharding_specs
@@ -701,13 +734,24 @@ def launch_on_driver(self, *args):
return self.outs_handler(output_bufs)

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
placement_specs = [
PlacementSpec(aval, (self.physical_mesh.mesh_id,),
(sharding_spec,)) for aval, sharding_spec in zip(
self.avals, self.global_arg_sharding_specs)
]
return tree_unflatten(self.in_tree, placement_specs)
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.avals,
self.global_arg_sharding_specs,
self.in_tree)

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals,
self.output_sharding_specs,
self.out_tree)

def get_total_allocation_size(self):
"""Get the total allocated memory size of this executable."""
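As a side note, a small self-contained sketch (not from the commit) of the tree_unflatten pattern that the new wrap_to_placement_spec_tree helper relies on: flat per-leaf specs are rebuilt into the caller's original input or output pytree structure. Plain strings stand in for PlacementSpec objects here.

```python
# Self-contained illustration of the tree_unflatten pattern used by
# wrap_to_placement_spec_tree; strings stand in for PlacementSpec leaves.
import jax

example_inputs = {"params": (1.0, 2.0), "step": 3}
leaves, treedef = jax.tree_util.tree_flatten(example_inputs)
flat_specs = [f"spec_{i}" for i in range(len(leaves))]

specs_tree = jax.tree_util.tree_unflatten(treedef, flat_specs)
# specs_tree == {"params": ("spec_0", "spec_1"), "step": "spec_2"}
```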
19 changes: 9 additions & 10 deletions alpa/pipeline_parallel/compile_executable.py
@@ -74,14 +74,15 @@ def compile_pipeshard_executable(
debug_compilation_time("trace")

pipeshard_config = compile_pipeshard_executable_internal(
closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, in_tree,
donated_invars, batch_invars, virtual_mesh, num_microbatch,
pipeline_schedule, default_as_option, stage_option, name_base, None)
closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars,
batch_invars, virtual_mesh, num_microbatch, pipeline_schedule,
default_as_option, stage_option, name_base, None)

executable = PipeshardDriverExecutable(
mesh_group=virtual_mesh.launched_physical_mesh_group,
pipeshard_config=pipeshard_config,
num_batch=num_microbatch,
in_tree=in_tree,
out_tree=out_tree_thunk(),
static_argnums=static_argnums)
debug_compilation_time("driver executable")
@@ -91,11 +92,10 @@
def compile_pipeshard_executable_internal(
closed_jaxpr: ClosedJaxpr,
full_batch_closed_jaxpr: Optional[ClosedJaxpr], micro_batch_size: int,
in_tree: PyTreeDef, donated_invars: Sequence[bool],
batch_invars: Sequence[bool], virtual_mesh: VirtualPhysicalMesh,
num_microbatch: int, pipeline_schedule: str,
default_as_option: AutoShardingOption, stage_option: StageOption,
name_base: str,
donated_invars: Sequence[bool], batch_invars: Sequence[bool],
virtual_mesh: VirtualPhysicalMesh, num_microbatch: int,
pipeline_schedule: str, default_as_option: AutoShardingOption,
stage_option: StageOption, name_base: str,
output_shardings: Optional[Sequence[pxla.ShardingSpec]]):
global_invars = closed_jaxpr.jaxpr.invars
global_outvars = closed_jaxpr.jaxpr.outvars
@@ -224,8 +224,7 @@ def compile_pipeshard_executable_internal(
schedule=schedule,
is_batch=batch_invars,
num_batch=num_microbatch,
flop_count=total_flops,
in_tree=in_tree).compile()
flop_count=total_flops).compile()

debug_compilation_time("runtime emitter")
return pipeshard_config
6 changes: 5 additions & 1 deletion alpa/pipeline_parallel/layer_construction.py
@@ -13,9 +13,11 @@
new_jaxpr_eqn, gensym, raise_to_shaped, get_aval)
from jax.interpreters.partial_eval import remat_call_p

from alpa.global_env import global_config
from alpa.pipeline_parallel.layer_stats import (global_invar_size,
is_nontrivial, eqn_flops,
heavy_count)
heavy_count,
log_layer_slicing_stats)
from alpa.pipeline_parallel.primitive_def import (pipeline_p,
mark_pipeline_jaxpreqn)
from alpa.util import (clone_jaxpr, slices_to_jaxpr, OrderedSet,
@@ -448,6 +450,8 @@ def wrapped(*args):
eps,
costs,
cost_criteria=cost_criteria)
if global_config.print_auto_layer_stats:
log_layer_slicing_stats(jaxpr, sliced_eqns)
else:
sliced_eqns = slice_eqns_by_layer_boundary(jaxpr)

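A minimal sketch of enabling the new layer-slicing statistics flag introduced above; it assumes the flag is set before the first compilation that triggers automatic layer construction.

```python
# Turn on logging of automatic layer-slicing statistics (new flag in this
# commit). Set it before the first call that triggers layer construction.
from alpa.global_env import global_config

global_config.print_auto_layer_stats = True
```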
19 changes: 17 additions & 2 deletions alpa/pipeline_parallel/pipeshard_executable.py
@@ -44,20 +44,23 @@ def __init__(self,
mesh_group: PhysicalDeviceMeshGroup,
pipeshard_config: PipeshardConfig,
num_batch: int,
in_tree: PyTreeDef,
out_tree: Optional[PyTreeDef] = None,
static_argnums: Optional[Sequence[int]] = None):
##### Input arguments #####
self.mesh_group = mesh_group
self.num_mesh = len(mesh_group)
self.num_batch = num_batch
self.static_argnums = static_argnums
self.in_tree = in_tree
self.out_tree = out_tree

##### For debugging and serialization #####
self.stages = pipeshard_config.xla_stages
self.schedule = pipeshard_config.schedule
self.flop_count = pipeshard_config.flop_count
self.input_placement_specs = pipeshard_config.input_placement_specs
self.output_placement_specs = pipeshard_config.output_placement_specs
# List[stage_idx -> str]
self.fully_optimized_hlo_texts = []
self.sharding_annotated_hlo_texts = (
@@ -192,8 +195,20 @@ def launch_on_driver(self, *args):
return self.outs_handler(self.mesh_group, output_bufs)

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
return self.input_placement_specs
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
return tree_unflatten(self.in_tree, self.input_placement_specs)

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
return tree_unflatten(self.out_tree, self.output_placement_specs)

def __call__(self, *args):
"""Fast call without signature matching."""
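For reference, a hedged sketch of querying the input/output placement-spec APIs on a compiled executable. Only get_input_placement_specs, get_output_placement_specs, and the PlacementSpec fields (aval, mesh_ids, sharding_specs) come from this commit; the `executable` variable and how it is obtained are assumptions.

```python
# Illustrative only; `executable` is assumed to be a PipeshardDriverExecutable
# (or NormalMeshDriverExecutable) obtained from an alpa-parallelized function.
import jax

in_specs = executable.get_input_placement_specs()    # pytree matching the inputs
out_specs = executable.get_output_placement_specs()  # pytree matching the outputs

for spec in jax.tree_util.tree_leaves(out_specs):
    print(spec.aval.shape, spec.mesh_ids, spec.sharding_specs)
```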
(Diffs for the remaining 5 changed files are not shown.)
