This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[FEATURE] Support overlapping pipeline communication and computation #773

Merged 20 commits on Dec 1, 2022
5 changes: 5 additions & 0 deletions alpa/api.py
@@ -213,6 +213,11 @@ def _compile_parallel_executable(
     for store in fun.stores:
         if store:
             store.reset()
+    batch_invars = list(batch_invars)
+    for idx, aval in enumerate(avals):
+        if len(aval.shape) == 0:
+            batch_invars[idx] = False
+    batch_invars = tuple(batch_invars)
 
     # Compile a callable
     return method.compile_executable(fun, in_tree, out_tree_thunk,
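
The hunk above marks scalar (rank-0) arguments as non-batch inputs, so the pipeline never tries to split them along the batch dimension. Below is a minimal standalone sketch of that filtering step; the helper name and the example avals are illustrative stand-ins, not code from the PR.

from jax.core import ShapedArray
import numpy as np


def mask_scalar_batch_invars(batch_invars, avals):
    """Force the batch flag to False for scalar (rank-0) arguments."""
    batch_invars = list(batch_invars)
    for idx, aval in enumerate(avals):
        if len(aval.shape) == 0:
            batch_invars[idx] = False
    return tuple(batch_invars)


# Hypothetical inputs: a batched activation and a scalar learning rate.
avals = (ShapedArray((32, 128), np.float32), ShapedArray((), np.float32))
print(mask_scalar_batch_invars((True, True), avals))  # -> (True, False)
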
10 changes: 6 additions & 4 deletions alpa/collective/__init__.py
@@ -6,14 +6,16 @@
     get_collective_group_size, allreduce, allreduce_multigpu, barrier, reduce,
     reduce_multigpu, broadcast, broadcast_partialgpu, broadcast_multigpu,
     allgather, allgather_multigpu, reducescatter, reducescatter_multigpu, send,
-    send_multigpu, recv, recv_multigpu, check_and_get_group)
+    send_multigpu, recv, recv_multigpu, check_and_get_group, record_events,
+    wait_events, comm_wait_compute, compute_wait_comm)
 
 __all__ = [
     "nccl_available", "gloo_available", "is_group_initialized",
     "init_collective_group", "destroy_collective_group",
     "create_collective_group", "get_rank", "get_collective_group_size",
     "allreduce", "allreduce_multigpu", "barrier", "reduce", "reduce_multigpu",
-    "broadcast", "broadcast_multigpu", "allgather", "allgather_multigpu",
-    "reducescatter", "reducescatter_multigpu", "send", "send_multigpu", "recv",
-    "recv_multigpu", "check_and_get_group"
+    "broadcast", "broadcast_partialgpu", "broadcast_multigpu", "allgather",
+    "allgather_multigpu", "reducescatter", "reducescatter_multigpu", "send",
+    "send_multigpu", "recv", "recv_multigpu", "check_and_get_group",
+    "record_events", "wait_events", "comm_wait_compute", "compute_wait_comm"
 ]
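
With the re-exports above, callers can reach the new event and stream synchronization helpers through the package namespace instead of the internal collective module, and a wildcard import stays in sync with __all__. A quick sanity-check sketch, assuming an environment where alpa is importable:

import alpa.collective as col

for name in ("record_events", "wait_events",
             "comm_wait_compute", "compute_wait_comm"):
    # Each new helper is both exported in __all__ and bound on the package.
    assert name in col.__all__ and hasattr(col, name)
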
20 changes: 20 additions & 0 deletions alpa/collective/collective.py
@@ -764,6 +764,26 @@ def _check_and_get_group(group_name):
 check_and_get_group = _check_and_get_group
 
 
+def record_events(group_name, uuids, num_devices, is_send):
+    g = _check_and_get_group(group_name)
+    g.record_events(uuids, num_devices, is_send)
+
+
+def wait_events(group_name, uuids, num_devices, is_send):
+    g = _check_and_get_group(group_name)
+    g.wait_events(uuids, num_devices, is_send)
+
+
+def comm_wait_compute(group_name, is_send, is_compute, device_id):
+    g = _check_and_get_group(group_name)
+    g.comm_wait_compute(is_send, is_compute, device_id)
+
+
+def compute_wait_comm(group_name, is_send, is_compute, device_id):
+    g = _check_and_get_group(group_name)
+    g.compute_wait_comm(is_send, is_compute, device_id)
+
+
 def _check_single_tensor_input(tensor):
     """Check if the tensor is with a supported type."""
     if isinstance(tensor, (np.ndarray, xe.DeviceArray)):
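
These wrappers follow the pattern of the existing module-level collective API: look up the named group, then forward to the corresponding method on the group object. A hedged usage sketch follows; the group name, buffer uuids, and device ids are placeholders, and the call sequence is an assumption inferred from the signatures above rather than code taken from the PR.

import alpa.collective as col

group_name = "pipeline_send_recv"      # hypothetical, must already be initialized
send_uuids = ["buffer-0", "buffer-1"]  # hypothetical per-buffer identifiers

# Record one event per device on the send stream after the sends are enqueued,
# then have later work wait on those events instead of blocking the host.
col.record_events(group_name, send_uuids, num_devices=2, is_send=True)
col.wait_events(group_name, send_uuids, num_devices=2, is_send=True)

# Stream-level ordering on a single device: the compute stream waits for the
# receive stream (set is_send=True to wait on the send stream instead).
col.compute_wait_comm(group_name, is_send=False, is_compute=True, device_id=0)
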
28 changes: 23 additions & 5 deletions alpa/collective/collective_group/nccl_collective_group.py
@@ -37,6 +37,8 @@ def __init__(self, world_size, rank, group_name):
 
         # TODO(Fu): might need an event map
         self._dev_event_map = {}
+        # This is only for cross-mesh all-reduce to use
+        self.xla_comm_group = None
 
         if nccl_util.get_nccl_build_version() < 2000:
             raise RuntimeError("NCCL in Ray requires NCCL >= 2.0.")
@@ -449,11 +451,14 @@ def _get_nccl_collective_communicator(self,
 
         if lib == "xla":
             # FIXME: pass the start rank at the initial point
+            if self.xla_comm_group is not None:
+                return self.xla_comm_group
             start_rank = self.rank * len(device_list)
             actual_ranks = [start_rank + i for i in range(len(device_list))]
-            local_ids = list(range(len(device_list)))
-            comms = xla_extension.nccl_create_communicators_no_stream(
-                actual_world_size, actual_ranks, local_ids, nccl_uid)
+            xla_extension.create_cross_mesh_communicator(
+                actual_world_size, actual_ranks, len(device_list), nccl_uid)
+            self.xla_comm_group = xla_extension.CommGroup(None)
+            return self.xla_comm_group
         nccl_util.groupStart()
         for i, device in enumerate(device_list):
             actual_rank = self.rank * len(device_list) + i
@@ -473,9 +478,13 @@
         self._dev_event_map[comm_key] = events
         return comms
 
-    def get_nccl_collective_communicator(self, devices, lib="cupy"):
+    def create_nccl_collective_communicator(self, devices):
         key = _get_comm_key_from_devices(devices)
-        return self._get_nccl_collective_communicator(key, devices, lib)
+        self._get_nccl_collective_communicator(key, devices)
+
+    def create_and_set_xla_communicators(self, devices):
+        key = _get_comm_key_from_devices(devices)
+        self._get_nccl_collective_communicator(key, devices, "xla")
 
     @staticmethod
     def _sync_streams(device_list, events, streams):
@@ -688,6 +697,15 @@ def create_p2p_communicator(self,
         self._get_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank,
                                         peer_gpu_idx, nccl_uid)
 
+    def create_nccl_broadcast_communicator(self,
+                                           comm_key,
+                                           world_size,
+                                           devices_ids,
+                                           devices_global_rank,
+                                           nccl_uid=None):
+        self._get_nccl_broadcast_communicator(comm_key, world_size, devices_ids,
+                                              devices_global_rank, nccl_uid)
+
     def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: int):
         """A method to encapsulate all peer-to-peer calls (i.e., send/recv).
 
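
Taken together, the communicator caching in xla_comm_group and the event hooks (record_events, wait_events, comm_wait_compute, compute_wait_comm) order the send/recv streams against the compute stream with GPU events rather than host-side synchronization, which is what lets pipeline communication overlap with computation. The sketch below illustrates only that general event pattern with CuPy streams; it is a conceptual example under stated assumptions, not the PR's XLA/NCCL implementation.

import cupy as cp

compute_stream = cp.cuda.Stream(non_blocking=True)
comm_stream = cp.cuda.Stream(non_blocking=True)
compute_done = cp.cuda.Event()

x = cp.random.rand(4096, 4096, dtype=cp.float32)

# Enqueue compute work, then record an event on the compute stream
# (analogous to recording a per-device event after a pipeline stage).
with compute_stream:
    y = x @ x
compute_done.record(compute_stream)

# The communication stream waits on that event on the device (analogous to
# comm_wait_compute), so the host never blocks and other streams keep running.
comm_stream.wait_event(compute_done)
with comm_stream:
    y_staged = y.copy()  # stand-in for the send/recv issued on the comm stream

comm_stream.synchronize()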