This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[DataLoader] Use PlacementSpecs in data loader #581

Merged: 2 commits, Jul 3, 2022
13 changes: 9 additions & 4 deletions alpa/create_state_parallel.py
@@ -6,6 +6,7 @@
from jax.core import ClosedJaxpr, Var
from jax.interpreters import partial_eval as pe, pxla
from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
import numpy as np

from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup
from alpa.mesh_executable import (NormalMeshDriverExecutable,
@@ -31,11 +32,13 @@ def __init__(self,
mesh_group: PhysicalDeviceMeshGroup,
pipeshard_config: PipeshardConfig,
target_placement_specs: Sequence[PlacementSpec],
in_tree: PyTreeDef,
out_tree: Optional[PyTreeDef] = None,
static_argnums: Optional[Sequence[int]] = None):
super().__init__(mesh_group=mesh_group,
pipeshard_config=pipeshard_config,
num_batch=1,
in_tree=in_tree,
out_tree=out_tree,
static_argnums=static_argnums)
self.target_placement_specs = target_placement_specs
@@ -58,7 +61,7 @@ def launch_on_driver(self, *args):
indices = pxla.spec_to_indices(array.shape, sharding_spec)
dis_array = self.mesh_group[mesh_id].shard_args_to_arrays(
(array.aval,), (indices,), (sharding_spec,),
(array,))[0]
(np.asarray(array),))[0]
distributed_arrays.append(dis_array)
outputs[idx] = ReplicatedDistributedArray(
meshes, distributed_arrays)
@@ -104,7 +107,8 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk,

return NormalMeshDriverExecutable(physical_mesh, hlo_module, stage_plan,
avals, out_avals,
[False] * len(avals))
[False] * len(avals), static_argnums,
in_tree, out_tree)
else:
# Construct a new pipelined jaxpr
outvars = jaxpr.outvars
@@ -126,14 +130,15 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk,

# Compile a pipeshard executable with predefined output shardings
pipeshard_config = compile_pipeshard_executable_internal(
new_jaxpr, None, 1, in_tree, [False] * len(avals),
[False] * len(avals), executable.mesh_group.parent, 1, "inference",
new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals),
executable.mesh_group.parent, 1, "inference",
AutoShardingOption(enable_auto_sharding=False),
UniformStageOption(), name, output_shardings)

return CreateStateExecutable(mesh_group=executable.mesh_group,
pipeshard_config=pipeshard_config,
target_placement_specs=placement_specs,
in_tree=in_tree,
out_tree=out_tree_thunk(),
static_argnums=static_argnums)

59 changes: 27 additions & 32 deletions alpa/data_loader.py
@@ -3,50 +3,40 @@
import itertools

import jax
from jax.interpreters import pxla, xla
from jax.interpreters import pxla
import numpy as np
import ray

from alpa.device_mesh import (LocalPhysicalDeviceMesh, DistributedArray,
from alpa.device_mesh import (DistributedArray, get_global_physical_mesh,
create_remote_array_refs)


class DataLoader:
"""A driver-only dataloader that loads data on the driver process and
sends the data to all workers."""

def __init__(self,
input_iter,
sharding_specs,
physical_mesh=None,
prefetch_size=1):
def __init__(self, input_iter, placement_specs, prefetch_size=1):
self.input_iter = input_iter
self.sharding_specs = sharding_specs
self.prefetch_size = prefetch_size

if physical_mesh is None:
self.physical_mesh = LocalPhysicalDeviceMesh()
else:
self.physical_mesh = physical_mesh
self.physical_mesh = get_global_physical_mesh()
self.avals = []
self.indices = []
self.sharding_specs = []
for ps in jax.tree_leaves(placement_specs):
assert len(ps.mesh_ids) == 1
assert ps.mesh_ids[0] == self.physical_mesh.mesh_id

self.avals.append(ps.aval)
self.sharding_specs.append(ps.sharding_specs[0])
self.indices.append(
tuple(ps.sharding_specs[0].indices(ps.aval.shape)))

self.queue = collections.deque()
self.first_iter = True
self.avals = None
self.indices = None

def enqueue(self, num_batches):
for batch in itertools.islice(self.input_iter, num_batches):
flatten_args, tree = jax.tree_flatten(batch)

# Cache meta info
if self.first_iter:
self.first_iter = False
self.avals = [xla.abstractify(a) for a in flatten_args]
self.indices = [
tuple(spec.indices(aval.shape))
for spec, aval in zip(self.sharding_specs, self.avals)
]

new_args = self.physical_mesh.shard_args_to_arrays(
self.avals, self.indices, self.sharding_specs, flatten_args)
self.queue.append(jax.tree_unflatten(tree, new_args))
@@ -112,14 +102,19 @@ def __init__(self,
batch_size,
num_samples,
input_iter_func,
avals,
sharding_specs,
physical_mesh,
placement_specs,
prefetch_size=1):
indices = [
tuple(np.ravel(spec.indices(aval.shape)))
for spec, aval in zip(sharding_specs, avals)
]
physical_mesh = get_global_physical_mesh()
avals = []
sharding_specs = []
indices = []
for ps in jax.tree_leaves(placement_specs):
avals.append(ps.aval)
assert len(ps.mesh_ids) == 1
assert ps.mesh_ids[0] == physical_mesh.mesh_id
sharding_specs.append(ps.sharding_specs[0])
indices.append(np.ravel(ps.sharding_specs[0].indices(
ps.aval.shape)))

self.uuid = next_mesh_data_loader_uuid()
self.physical_mesh = physical_mesh
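A minimal usage sketch of the updated DataLoader API (illustrative only; train_step, state, example_batch, and batch_iter are hypothetical placeholders, not part of this diff):

from alpa.data_loader import DataLoader

# Compile the parallel executable first so its preferred input placements
# are known.
executable = train_step.get_executable(state, example_batch)

# get_input_placement_specs() returns a pytree of PlacementSpec with the
# same structure as the inputs; assuming the inputs are (state, batch),
# the second element describes the desired placement of the batch.
_, batch_placement_specs = executable.get_input_placement_specs()

# The loader now derives avals, sharding specs, and shard indices from the
# placement specs and looks up the global physical mesh internally, so no
# sharding_specs or physical_mesh arguments are needed.
data_loader = DataLoader(batch_iter, batch_placement_specs, prefetch_size=2)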
2 changes: 1 addition & 1 deletion alpa/device_mesh.py
@@ -809,7 +809,7 @@ def __init__(self, devices: Sequence["Device"] = None):
self.devices = devices if devices is not None else xb.local_devices()
self.num_hosts = 1
self.num_devices_per_host = len(self.devices)
self.mesh_id = 0
self.mesh_id = -1
self.device_strs = []
self.operation_executables = {}

1 change: 1 addition & 0 deletions alpa/global_env.py
@@ -78,6 +78,7 @@ def __init__(self):

########## Options of logging ##########
self.print_compilation_time = False
self.print_auto_layer_stats = False


global_config = GlobalConfig()
74 changes: 59 additions & 15 deletions alpa/mesh_executable.py
@@ -61,7 +61,19 @@ def launch_on_driver(self, *args, **kwargs):
raise NotImplementedError()

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
raise NotImplementedError()

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
raise NotImplementedError()

def preshard_dynamic_args(self, *args):
@@ -160,6 +172,15 @@ def sync_func_worker():
return sync_func_worker


def wrap_to_placement_spec_tree(physical_mesh, avals, sharding_specs, pytree):
"""Wrap avals and sharding specs to a pytree of placement specs."""
placement_specs = [
PlacementSpec(aval, (physical_mesh.mesh_id,), (sharding_spec,))
for aval, sharding_spec in zip(avals, sharding_specs)
]
return tree_unflatten(pytree, placement_specs)


class NormalMeshDriverExecutable(MeshDriverExecutable):
"""The driver part of a normal mesh executable."""

@@ -284,13 +305,24 @@ def launch_on_driver(self, *args, **kwargs):
return self.outs_handler(output_bufs)

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
placement_specs = [
PlacementSpec(aval, (self.physical_mesh.mesh_id,),
(sharding_spec,)) for aval, sharding_spec in zip(
self.avals, self.input_sharding_specs)
]
return tree_unflatten(self.in_tree, placement_specs)
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.avals,
self.input_sharding_specs,
self.in_tree)

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals,
self.output_sharding_specs,
self.out_tree)

def preshard_dynamic_args(self, *args):
"""Pre-shard the input arguments."""
@@ -505,6 +537,7 @@ def __init__(self,
out_avals,
physical_mesh.num_devices,
logical_mesh_shape))
self.output_sharding_specs = output_sharding_specs
num_grads = len(grad_avals)
assert accumulate_grad_input_sharding_specs[
-num_grads:] == grad_sharding_specs
@@ -701,13 +734,24 @@ def launch_on_driver(self, *args):
return self.outs_handler(output_bufs)

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
placement_specs = [
PlacementSpec(aval, (self.physical_mesh.mesh_id,),
(sharding_spec,)) for aval, sharding_spec in zip(
self.avals, self.global_arg_sharding_specs)
]
return tree_unflatten(self.in_tree, placement_specs)
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.avals,
self.global_arg_sharding_specs,
self.in_tree)

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals,
self.output_sharding_specs,
self.out_tree)

def get_total_allocation_size(self):
"""Get the total allocated memory size of this executable."""
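Both placement-spec getters now go through the shared wrap_to_placement_spec_tree helper, so inputs and outputs are reported in the same form. An illustrative sketch of inspecting the result (executable stands for any compiled mesh executable; it is not a name from this diff):

import jax

# Each leaf of the returned pytree is a PlacementSpec carrying the abstract
# value, the ids of the meshes holding the array, and the per-mesh sharding
# specs.
output_specs = executable.get_output_placement_specs()
for spec in jax.tree_util.tree_leaves(output_specs):
    print(spec.aval.shape, spec.mesh_ids, spec.sharding_specs)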
19 changes: 9 additions & 10 deletions alpa/pipeline_parallel/compile_executable.py
@@ -74,14 +74,15 @@ def compile_pipeshard_executable(
debug_compilation_time("trace")

pipeshard_config = compile_pipeshard_executable_internal(
closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, in_tree,
donated_invars, batch_invars, virtual_mesh, num_microbatch,
pipeline_schedule, default_as_option, stage_option, name_base, None)
closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars,
batch_invars, virtual_mesh, num_microbatch, pipeline_schedule,
default_as_option, stage_option, name_base, None)

executable = PipeshardDriverExecutable(
mesh_group=virtual_mesh.launched_physical_mesh_group,
pipeshard_config=pipeshard_config,
num_batch=num_microbatch,
in_tree=in_tree,
out_tree=out_tree_thunk(),
static_argnums=static_argnums)
debug_compilation_time("driver executable")
@@ -91,11 +92,10 @@ def compile_pipeshard_executable(
def compile_pipeshard_executable_internal(
closed_jaxpr: ClosedJaxpr,
full_batch_closed_jaxpr: Optional[ClosedJaxpr], micro_batch_size: int,
in_tree: PyTreeDef, donated_invars: Sequence[bool],
batch_invars: Sequence[bool], virtual_mesh: VirtualPhysicalMesh,
num_microbatch: int, pipeline_schedule: str,
default_as_option: AutoShardingOption, stage_option: StageOption,
name_base: str,
donated_invars: Sequence[bool], batch_invars: Sequence[bool],
virtual_mesh: VirtualPhysicalMesh, num_microbatch: int,
pipeline_schedule: str, default_as_option: AutoShardingOption,
stage_option: StageOption, name_base: str,
output_shardings: Optional[Sequence[pxla.ShardingSpec]]):
global_invars = closed_jaxpr.jaxpr.invars
global_outvars = closed_jaxpr.jaxpr.outvars
@@ -224,8 +224,7 @@ def compile_pipeshard_executable_internal(
schedule=schedule,
is_batch=batch_invars,
num_batch=num_microbatch,
flop_count=total_flops,
in_tree=in_tree).compile()
flop_count=total_flops).compile()

debug_compilation_time("runtime emitter")
return pipeshard_config
6 changes: 5 additions & 1 deletion alpa/pipeline_parallel/layer_construction.py
@@ -13,9 +13,11 @@
new_jaxpr_eqn, gensym, raise_to_shaped, get_aval)
from jax.interpreters.partial_eval import remat_call_p

from alpa.global_env import global_config
from alpa.pipeline_parallel.layer_stats import (global_invar_size,
is_nontrivial, eqn_flops,
heavy_count)
heavy_count,
log_layer_slicing_stats)
from alpa.pipeline_parallel.primitive_def import (pipeline_p,
mark_pipeline_jaxpreqn)
from alpa.util import (clone_jaxpr, slices_to_jaxpr, OrderedSet,
@@ -448,6 +450,8 @@ def wrapped(*args):
eps,
costs,
cost_criteria=cost_criteria)
if global_config.print_auto_layer_stats:
log_layer_slicing_stats(jaxpr, sliced_eqns)
else:
sliced_eqns = slice_eqns_by_layer_boundary(jaxpr)

19 changes: 17 additions & 2 deletions alpa/pipeline_parallel/pipeshard_executable.py
@@ -44,20 +44,23 @@ def __init__(self,
mesh_group: PhysicalDeviceMeshGroup,
pipeshard_config: PipeshardConfig,
num_batch: int,
in_tree: PyTreeDef,
out_tree: Optional[PyTreeDef] = None,
static_argnums: Optional[Sequence[int]] = None):
##### Input arguments #####
self.mesh_group = mesh_group
self.num_mesh = len(mesh_group)
self.num_batch = num_batch
self.static_argnums = static_argnums
self.in_tree = in_tree
self.out_tree = out_tree

##### For debugging and serialization #####
self.stages = pipeshard_config.xla_stages
self.schedule = pipeshard_config.schedule
self.flop_count = pipeshard_config.flop_count
self.input_placement_specs = pipeshard_config.input_placement_specs
self.output_placement_specs = pipeshard_config.output_placement_specs
# List[stage_idx -> str]
self.fully_optimized_hlo_texts = []
self.sharding_annotated_hlo_texts = (
@@ -192,8 +195,20 @@ def launch_on_driver(self, *args):
return self.outs_handler(self.mesh_group, output_bufs)

def get_input_placement_specs(self):
"""Return the preferred placement specs for input arguments."""
return self.input_placement_specs
"""
Return the preferred placement specs for input arguments.
The return value is a pytree of PlacementSpec
with the same structure as the input pytree.
"""
return tree_unflatten(self.in_tree, self.input_placement_specs)

def get_output_placement_specs(self):
"""
Return the preferred placement specs for outputs.
The return value is a pytree of PlacementSpec
with the same structure as the output pytree.
"""
return tree_unflatten(self.out_tree, self.output_placement_specs)

def __call__(self, *args):
"""Fast call without signature matching."""
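The pipeshard executable now keeps the flat input and output placement specs from pipeshard_config and unflattens them against the saved in_tree and out_tree, so the returned pytrees mirror the structure of the parallelized function's inputs and outputs. A small illustrative check (executable is a hypothetical compiled PipeshardDriverExecutable):

import jax

input_specs = executable.get_input_placement_specs()
# By construction (tree_unflatten over in_tree), the placement-spec pytree
# has the same structure as the executable's input pytree.
assert jax.tree_util.tree_structure(input_specs) == executable.in_tree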