[transformer] Allow for skipping stream sync (NVIDIA#1505)
* Optionally disable stream synchronization after batched p2p communication

* Add test cases with `sync_batch_comm=False`

only when pytorch/pytorch#82450 is included in PyTorch.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* utilize existing test methods

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* consistent naming

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: Aidyn-A <Aidyn-A@users.noreply.github.com>

* silly boy, to skip the sync, set False

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* cosmetic

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* Test with async_pipelining w/o sync after batch_isend_irecv

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* again, set sync_batch_comm to False

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: Aidyn-A <Aidyn-A@users.noreply.github.com>

* Remove `torch.testing._internal.common_cuda`

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Aidyn-A <Aidyn-A@users.noreply.github.com>
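A minimal usage sketch of the new keyword (not part of this diff; the shape values and dtype are placeholders, and a fully initialized distributed/model-parallel state is assumed): existing callers are unaffected because `sync_batch_comm` defaults to `True`, while passing `False` skips the `torch.cuda.synchronize()` that normally follows the batched p2p communication.

    import torch
    from apex.transformer.pipeline_parallel import p2p_communication

    # Example shape: (sequence length, micro batch size, hidden size) -- values are
    # placeholders for illustration only.
    seq_length, micro_batch_size, hidden_size = 2048, 1, 4096

    # Receive the activation from the previous pipeline stage without the
    # post-batch_isend_irecv synchronization. Only meaningful when the installed
    # PyTorch includes pytorch/pytorch#82450.
    input_tensor = p2p_communication.recv_forward(
        tensor_shape=(seq_length, micro_batch_size, hidden_size),
        dtype=torch.half,
        sync_batch_comm=False,
    )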
3 people authored and hubertlu-tw committed Dec 29, 2022
1 parent 7c3cae3 commit 23ef6ff
Showing 4 changed files with 138 additions and 29 deletions.
26 changes: 24 additions & 2 deletions apex/transformer/pipeline_parallel/p2p_communication.py
@@ -128,6 +128,7 @@ def _communicate(
fp32_residual_connection: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
"""Base function for communication of tensors between stages.
@@ -159,6 +160,8 @@ def _communicate(
sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled.
This argument is here for consistency with Megatron-LM.
This argument has an effect on the communication optimization, not on tensor_shape update.
sync_batch_comm: If :obj:`False`, skip the CUDA synchronization after the batched communication.
Skipping the synchronization requires https://github.com/pytorch/pytorch/pull/82450.
Returns:
tuple containing
@@ -267,8 +270,9 @@ def tensor_recv_next_wait():
torch.cuda.synchronize()
tensor_recv_next_waitfunc = tensor_recv_next_wait
else:
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
if sync_batch_comm:
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()

# If using scatter-gather optimization, gather smaller chunks.
if scatter_gather_optimization_doable:
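For context, the guarded synchronization in the hunk above reduces to the following pattern (an illustrative sketch, not the full `_communicate` body; the helper name `_batched_p2p` is made up for this example):

    import torch

    def _batched_p2p(p2p_ops, sync_batch_comm: bool = True):
        # Launch all sends/receives as one batched call and wait on the requests.
        reqs = torch.distributed.batch_isend_irecv(p2p_ops)
        for req in reqs:
            req.wait()
        if sync_batch_comm:
            # To protect against a race condition when the received tensors are
            # consumed before the communication has actually finished.
            torch.cuda.synchronize()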
@@ -325,6 +329,7 @@ def recv_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from previous rank in pipeline (forward receive)."""
@@ -342,6 +347,7 @@ def recv_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-recv").stop()
@@ -354,6 +360,7 @@ def recv_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from next rank in pipeline (backward receive)."""
@@ -370,6 +377,7 @@ def recv_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-recv").stop()
@@ -384,6 +392,7 @@ def send_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
@@ -401,6 +410,7 @@ def send_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-send").stop()
@@ -413,6 +423,7 @@ def send_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
@@ -429,6 +440,7 @@ def send_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-send").stop()
@@ -441,6 +453,7 @@ def send_forward_recv_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with next rank in pipeline."""
@@ -457,6 +470,7 @@ def send_forward_recv_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-send-backward-recv").stop()
@@ -470,6 +484,7 @@ def send_backward_recv_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with previous rank in pipeline."""
@@ -486,6 +501,7 @@ def send_backward_recv_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-send-forward-recv").stop()
@@ -500,6 +516,7 @@ def send_forward_recv_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from previous rank and send to next rank in pipeline."""
@@ -514,6 +531,7 @@ def send_forward_recv_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-send-forward-recv").stop()
@@ -528,6 +546,7 @@ def send_backward_recv_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from next rank and send to previous rank in pipeline."""
@@ -542,6 +561,7 @@ def send_backward_recv_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-send-backward-recv").stop()
@@ -558,6 +578,7 @@ def send_forward_backward_recv_forward_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
"""Batched send and recv with previous and next ranks in pipeline."""
@@ -572,6 +593,7 @@ def send_forward_backward_recv_forward_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
@@ -36,6 +36,7 @@ def _forward_backward_pipelining_with_interleaving(
deallocate_pipeline_outputs: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -70,6 +71,8 @@ def _forward_backward_pipelining_with_interleaving(
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
sync_batch_comm: If :obj:`False`, skip the CUDA synchronization after the batched communication.
Skipping the synchronization requires https://github.com/pytorch/pytorch/pull/82450.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
@@ -221,6 +224,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
)
_logger.info("Warmup phase")
@@ -269,6 +273,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
@@ -280,6 +285,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
@@ -365,6 +371,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)

@@ -387,6 +394,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
)
for k in range(num_microbatches_remaining, num_microbatches):
@@ -409,6 +417,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
)
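An end-to-end sketch of what the new `sync_batch_comm=False` test cases exercise (the driver pieces — `forward_step_func`, `batch`, `model`, and the parallel sizes — are assumptions standing in for the test harness; only the last two keywords relate to this commit):

    import torch
    from apex.transformer.pipeline_parallel import get_forward_backward_func

    # Placeholder sizes; the real tests derive these from the launched world size.
    virtual_pipeline_model_parallel_size, pipeline_model_parallel_size = 2, 4

    forward_backward_func = get_forward_backward_func(
        virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
    )
    losses = forward_backward_func(
        forward_step_func,             # user-provided step function (assumed)
        batch,                         # microbatches for this iteration (assumed)
        model,                         # list of model chunks for interleaving (assumed)
        forward_only=False,
        tensor_shape=(2048, 1, 4096),  # (sequence, micro batch, hidden) -- placeholder
        dtype=torch.half,
        async_comm=True,               # matches the "async pipelining" test added here
        sync_batch_comm=False,         # skip torch.cuda.synchronize() after batch_isend_irecv
    )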
