Skip to content

Commit

Permalink
Correct multi-unshard stream patching for torch 2.2.0dev, and stream waiting correctness. (#2817)
Browse files Browse the repository at this point in the history

* patched torch

* fixed torch imports

* fixed torch imports

* fixed torch imports

* patching through composer

* patching through composer

* patching typing

* comment added

* don't patch torch 2.1.0

* patch torch 2.1.1 and 2.2.0

* linting fix

* waiting on computation stream from unshard stream

* waiting on computation stream from unshard stream

* less waiting

* no waiting

* all unshard streams wait on computation stream now

* 2.2.0 dev change
  • Loading branch information
snarayan21 authored Jan 4, 2024
1 parent 206a9ea commit e5240d2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
4 changes: 2 additions & 2 deletions composer/trainer/mosaic_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

elif version.parse(torch.__version__) < version.parse('2.2.0'):
# Monkey patch for torch < 2.2.0 ie torch == 2.1.1, 2.1.2
elif version.parse(torch.__version__) < version.parse('2.1.3'):
# Monkey patch for torch < 2.1.3 ie torch == 2.1.1, 2.1.2

# Allow 2D HSDP
from torch.distributed.fsdp import _runtime_utils
Expand Down
32 changes: 28 additions & 4 deletions composer/trainer/mosaic_fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,8 @@ def _share_state_and_init_handle_attrs_t2p1(
been modified to assign a different unshard stream to each process group.
"""
from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
_validate_and_get_hybrid_shard_state)
_validate_and_get_hybrid_shard_state,
_wait_for_computation_stream)
from torch.distributed.utils import _p_assert

handle = root_state._handle
Expand All @@ -824,7 +825,11 @@ def _share_state_and_init_handle_attrs_t2p1(
# Patching so that _FSDPStates with different process groups have separate unshard streams.
# Keep track of any new unshard streams we may have to add for specific process groups.
fsdp_pg_unshard_streams = {}
unshard_priority = root_state._unshard_stream.priority
try:
unshard_priority = root_state._unshard_stream.priority
except AttributeError:
# Use the default priority of 0 if the stream has no assigned priority.
unshard_priority = 0
for fsdp_state in root_state._all_fsdp_states:
for attr_name in HOMOGENEOUS_ATTR_NAMES:
_p_assert(
Expand Down Expand Up @@ -870,6 +875,13 @@ def _share_state_and_init_handle_attrs_t2p1(
handle = fsdp_state._handle
if handle:
handle.init_flat_param_attributes()
# Ensure that all unshard streams wait on the default computation stream
for pg_unshard_stream in fsdp_pg_unshard_streams.values():
_wait_for_computation_stream(
root_state._device_handle.current_stream(),
pg_unshard_stream,
root_state._pre_unshard_stream,
)
for attr_name, attr_values in attr_name_to_values.items():
if len(attr_values) != 1:
raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
Expand All @@ -887,7 +899,8 @@ def _share_state_and_init_handle_attrs_t2p2(
done together to require a single loop over the states. This function has
been modified to assign a different unshard stream to each process group.
"""
from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state
from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state,
_wait_for_computation_stream)
from torch.distributed.utils import _p_assert

handle = root_state._handle
Expand All @@ -911,7 +924,11 @@ def _share_state_and_init_handle_attrs_t2p2(
# Patching so that _FSDPStates with different process groups have separate unshard streams.
# Keep track of any new unshard streams we may have to add for specific process groups.
fsdp_pg_unshard_streams = {}
unshard_priority = root_state._unshard_stream.priority
try:
unshard_priority = root_state._unshard_stream.priority
except AttributeError:
# Use the default priority of 0 if the stream has no assigned priority.
unshard_priority = 0
for fsdp_state in root_state._all_fsdp_states:
for attr_name in HOMOGENEOUS_ATTR_NAMES:
_p_assert(
Expand Down Expand Up @@ -956,6 +973,13 @@ def _share_state_and_init_handle_attrs_t2p2(
handle = fsdp_state._handle
if handle:
handle.init_flat_param_attributes()
# Ensure that all unshard streams wait on the default computation stream
for pg_unshard_stream in fsdp_pg_unshard_streams.values():
_wait_for_computation_stream(
root_state._device_handle.current_stream(),
pg_unshard_stream,
root_state._pre_unshard_stream,
)
for attr_name, attr_values in attr_name_to_values.items():
if len(attr_values) != 1:
raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')

0 comments on commit e5240d2

Please sign in to comment.