[transformer] Allow for skipping stream synch #1505

Merged
9 commits, merged Oct 12, 2022
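This PR adds a sync_batch_comm keyword argument (default True, i.e. the existing behavior) to _communicate and threads it through every p2p helper and through the interleaved pipeline schedule. When the flag is False, the torch.cuda.synchronize() that normally follows the batched communication is skipped; per the docstrings below, skipping it safely requires https://github.com/pytorch/pytorch/pull/82450. Below is a minimal sketch of the guarded synchronization, simplified from the diff; the wrapper name batched_p2p is illustrative, not apex's API.

# Minimal sketch of the guarded synchronization introduced here. batched_p2p is a
# simplified stand-in; the real logic lives in _communicate in
# apex/transformer/pipeline_parallel/p2p_communication.py.
import torch
import torch.distributed as dist

def batched_p2p(p2p_ops, sync_batch_comm: bool = True) -> None:
    # Launch all sends/receives in one batched call and wait for the requests.
    reqs = dist.batch_isend_irecv(p2p_ops)
    for req in reqs:
        req.wait()
    if sync_batch_comm:
        # Pre-existing default: protect against a race condition when using
        # batch_isend_irecv(). Passing sync_batch_comm=False skips this sync,
        # which requires https://github.com/pytorch/pytorch/pull/82450.
        torch.cuda.synchronize()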
26 changes: 24 additions & 2 deletions apex/transformer/pipeline_parallel/p2p_communication.py
@@ -128,6 +128,7 @@ def _communicate(
fp32_residual_connection: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
"""Base function for communication of tensors between stages.

@@ -159,6 +160,8 @@ def _communicate(
sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled.
This argument is here for consistency with Megatron-LM.
This argument has an effect on the communication optimization, not on tensor_shape update.
sync_batch_comm: If :obj:`False`, disable CUDA synchronization after the batched communication.
Disabling it requires https://github.com/pytorch/pytorch/pull/82450.

Returns:
tuple containing
@@ -267,8 +270,9 @@ def tensor_recv_next_wait():
torch.cuda.synchronize()
tensor_recv_next_waitfunc = tensor_recv_next_wait
else:
-# To protect against race condition when using batch_isend_irecv().
-torch.cuda.synchronize()
+if sync_batch_comm:
+    # To protect against race condition when using batch_isend_irecv().
+    torch.cuda.synchronize()

# If using scatter-gather optimization, gather smaller chunks.
if scatter_gather_optimization_doable:
@@ -325,6 +329,7 @@ def recv_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from previous rank in pipeline (forward receive)."""
@@ -342,6 +347,7 @@ def recv_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-recv").stop()
@@ -354,6 +360,7 @@ def recv_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from next rank in pipeline (backward receive)."""
@@ -370,6 +377,7 @@ def recv_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-recv").stop()
@@ -384,6 +392,7 @@ def send_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
@@ -401,6 +410,7 @@ def send_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-send").stop()
@@ -413,6 +423,7 @@ def send_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
@@ -429,6 +440,7 @@ def send_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-send").stop()
@@ -441,6 +453,7 @@ def send_forward_recv_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with next rank in pipeline."""
@@ -457,6 +470,7 @@ def send_forward_recv_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-send-backward-recv").stop()
@@ -470,6 +484,7 @@ def send_backward_recv_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with previous rank in pipeline."""
@@ -486,6 +501,7 @@ def send_backward_recv_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-send-forward-recv").stop()
@@ -500,6 +516,7 @@ def send_forward_recv_forward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from previous rank and send to next rank in pipeline."""
@@ -514,6 +531,7 @@ def send_forward_recv_forward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-send-forward-recv").stop()
@@ -528,6 +546,7 @@ def send_backward_recv_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from next rank and send to previous rank in pipeline."""
@@ -542,6 +561,7 @@ def send_backward_recv_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("backward-send-backward-recv").stop()
@@ -558,6 +578,7 @@ def send_forward_backward_recv_forward_backward(
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
timers: _Timers = None,
) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
"""Batched send and recv with previous and next ranks in pipeline."""
@@ -572,6 +593,7 @@ def send_forward_backward_recv_forward_backward(
dtype_=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
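Every helper above forwards sync_batch_comm unchanged to _communicate; the remaining hunks thread the same flag through the interleaved pipeline schedule. Because the docstring only states that https://github.com/pytorch/pytorch/pull/82450 is needed before the sync can be skipped, a caller might gate the flag on the installed PyTorch version. The sketch below does that; the ">= 1.13" threshold is an assumption (the fix landed after 1.12), not something this PR encodes.

# Hedged sketch: keep sync_batch_comm=True unless the running PyTorch is new enough
# to contain pytorch/pytorch#82450. The 1.13.0 threshold is an assumption.
from packaging import version
import torch

def default_sync_batch_comm() -> bool:
    torch_version = version.parse(torch.__version__.split("+")[0])
    has_batch_isend_irecv_fix = torch_version >= version.parse("1.13.0")
    # True keeps the old always-synchronize behavior; False skips the sync.
    return not has_batch_isend_irecv_fix

The returned value would then be passed as sync_batch_comm= to any of the helpers above or to the schedule below.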
@@ -36,6 +36,7 @@ def _forward_backward_pipelining_with_interleaving(
deallocate_pipeline_outputs: bool = False,
async_comm: bool = False,
sequence_parallel_enabled: bool = False,
sync_batch_comm: bool = True,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -70,6 +71,8 @@ def _forward_backward_pipelining_with_interleaving(
sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length.
When :obj:`True`, the sequence length on each tensor model parallel rank is updated
to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`.
sync_batch_comm: If :obj:`False`, disable CUDA synchronization after the batched communication.
Disabling it requires https://github.com/pytorch/pytorch/pull/82450.

Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
@@ -221,6 +224,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
)
_logger.info("Warmup phase")
@@ -269,6 +273,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
@@ -280,6 +285,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
@@ -365,6 +371,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)

@@ -387,6 +394,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
)
for k in range(num_microbatches_remaining, num_microbatches):
@@ -409,6 +417,7 @@ def backward_step_helper(microbatch_id: int) -> torch.Tensor:
dtype=dtype,
async_comm=async_comm,
sequence_parallel_enabled=sequence_parallel_enabled,
sync_batch_comm=sync_batch_comm,
)
)

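As the hunks above show, the schedule passes one shared keyword set, now including sync_batch_comm, to every p2p call it makes, so choosing the flag in one place affects every pipeline communication. Below is a self-contained sketch of that threading pattern; recv_forward_stub is a hypothetical stand-in, not apex's recv_forward.

# Stand-in illustration of the threading pattern; recv_forward_stub is hypothetical.
from functools import partial

def recv_forward_stub(*, dtype=None, async_comm: bool = False,
                      sequence_parallel_enabled: bool = False,
                      sync_batch_comm: bool = True) -> str:
    # Report which flag value reached the (stubbed) p2p call.
    return f"recv_forward(sync_batch_comm={sync_batch_comm})"

# The schedule binds the flag it was given once and reuses it for every call.
recv_forward = partial(recv_forward_stub, dtype=None, async_comm=False,
                       sequence_parallel_enabled=False, sync_batch_comm=False)
print(recv_forward())  # -> recv_forward(sync_batch_comm=False)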