Commit

Merge branch 'dev' into jerry/mlflow-objectstore-part1
dakinggg authored Jan 8, 2024
2 parents cd83f64 + a36fb74 commit 836e601
Showing 7 changed files with 189 additions and 26 deletions.
19 changes: 9 additions & 10 deletions composer/callbacks/speed_monitor.py
@@ -107,20 +107,19 @@ def get_gpu_flops_available(state: State):
device_name = 'v100-pcie'
elif 't4' in device_name:
device_name = 't4'
- else:
- device_name = None

- if device_name is not None:
- try:
- gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
- except:
- gpu_flops_available = None
+ if device_name in GPU_AVAILABLE_FLOPS and state.precision.value in GPU_AVAILABLE_FLOPS[device_name]:
+ gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
+ else:
+ gpu_flops_available = None

if gpu_flops_available is None:
warnings.warn(
- f'gpu_flop count not found for {device_name} with precision: {state.precision.value}; ' +\
- f'MFU cannot be calculated and reported. gpu_flops_available can be manually' +\
- f'overridden by setting gpu_flops_available in SpeedMonitor.'
+ f'gpu_flop count not found for {device_name} with precision={state.precision.value} ' +\
+ f'so MFU cannot be calculated and reported. gpu_flops_available can be manually ' +\
+ f'overridden by setting gpu_flops_available in SpeedMonitor or {device_name} can ' +\
+ f'be added to GPU_AVAILABLE_FLOPS in composer/callbacks/speed_monitor.py',
stacklevel=2,
)
# Setting to 0 will disable MFU computation and prevent
# the speed monitor from running this helper every batch
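
The new lookup uses plain dictionary membership instead of a bare `except:`. A minimal, self-contained sketch of that pattern follows; the table is an abbreviated, illustrative stand-in for composer's GPU_AVAILABLE_FLOPS, and the precision keys are assumed.

import warnings

# Abbreviated, illustrative stand-in for composer's GPU_AVAILABLE_FLOPS table;
# values are approximate peak FLOP/s and the precision keys are assumed.
GPU_AVAILABLE_FLOPS = {
    'a100': {'amp_bf16': 312e12, 'amp_fp16': 312e12, 'fp32': 19.5e12},
    't4': {'amp_fp16': 65e12, 'fp32': 8.1e12},
}


def lookup_flops(device_name: str, precision: str):
    # Explicit membership checks replace a bare `except:`, which would also have
    # swallowed unrelated errors and hidden typos in the table.
    if device_name in GPU_AVAILABLE_FLOPS and precision in GPU_AVAILABLE_FLOPS[device_name]:
        return int(GPU_AVAILABLE_FLOPS[device_name][precision])
    warnings.warn(f'gpu_flop count not found for {device_name} with precision={precision}', stacklevel=2)
    return None


print(lookup_flops('a100', 'amp_bf16'))  # 312000000000000
print(lookup_flops('t4', 'amp_bf16'))    # None (after a UserWarning)
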
3 changes: 2 additions & 1 deletion composer/loggers/mosaicml_logger.py
@@ -136,7 +136,8 @@ def predict_end(self, state: State, logger: Logger) -> None:

def close(self, state: State, logger: Logger) -> None:
self._flush_metadata(force_flush=True, future=False)
- wait(self._futures) # Ignore raised errors on close
+ if self._enabled:
+ wait(self._futures) # Ignore raised errors on close

def _log_metadata(self, metadata: Dict[str, Any]) -> None:
"""Buffer metadata and prefix keys with mosaicml."""
10 changes: 5 additions & 5 deletions composer/models/huggingface.py
@@ -184,15 +184,15 @@ def load_huggingface_tokenizer_from_saved_state(

tokenizer_file_path = Path(tokenizer_save_dir) / tokenizer_file_name
if saved_content['file_extension'] == '.json':
- with open(tokenizer_file_path, 'w') as _f:
+ with open(tokenizer_file_path, 'w', encoding='utf-8') as _f:
json.dump(saved_content['content'], _f)
elif saved_content['file_extension'] == '.txt':
- with open(tokenizer_file_path, 'w') as _f:
+ with open(tokenizer_file_path, 'w', encoding='utf-8') as _f:
for line in saved_content['content']:
_f.write(line)
_f.write('\n')
elif saved_content['file_extension'] == '.py':
- with open(tokenizer_file_path, 'w') as _f:
+ with open(tokenizer_file_path, 'w', encoding='utf-8') as _f:
_f.write(saved_content['content'])
elif saved_content['file_extension'] == '.model':
try:
@@ -503,13 +503,13 @@ def get_metadata(self):
tokenizer_file_path = tokenizer_dir / tokenizer_file_name
tokenizer_file_extension = tokenizer_file_path.suffix
if tokenizer_file_extension == '.txt':
- with open(tokenizer_file_path) as _tokenizer_file:
+ with open(tokenizer_file_path, encoding='utf-8') as _tokenizer_file:
tokenizer_file_content = _tokenizer_file.read().split('\n')
elif tokenizer_file_extension == '.json':
with open(tokenizer_file_path, 'rb') as _tokenizer_file:
tokenizer_file_content = json.load(_tokenizer_file)
elif tokenizer_file_extension == '.py':
- with open(tokenizer_file_path) as _tokenizer_file:
+ with open(tokenizer_file_path, encoding='utf-8') as _tokenizer_file:
tokenizer_file_content = _tokenizer_file.read()
elif tokenizer_file_extension == '.model':
try:
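
Passing encoding='utf-8' explicitly matters because open() in text mode otherwise falls back to the locale's preferred encoding (for example cp1252 on some Windows setups), so tokenizer files containing non-ASCII tokens may not round-trip. A minimal sketch of the round-trip; the file name is illustrative.

import json
import locale
from pathlib import Path

# The text-mode default encoding is locale-dependent and is not always UTF-8.
print(locale.getpreferredencoding(False))

payload = {'token': 'café'}           # non-ASCII content, as tokenizer files often hold
path = Path('tokenizer_config.json')  # illustrative file name

with open(path, 'w', encoding='utf-8') as f:
    json.dump(payload, f, ensure_ascii=False)
with open(path, encoding='utf-8') as f:
    assert json.load(f) == payload    # round-trips regardless of platform locale
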
2 changes: 1 addition & 1 deletion composer/profiler/torch_profiler.py
@@ -342,7 +342,7 @@ def batch_start(self, state: State, logger: Logger) -> None:

def close(self, state: State, logger: Logger) -> None:
del state, logger # unused
- if self.profiler is not None:
+ if self.profiler is not None and self.profiler.profiler is not None:
log.info(self.profiler.key_averages().table(sort_by='cpu_time_total', row_limit=20))
if self.profile_memory:
log.info(self.profiler.key_averages().table(sort_by='self_cpu_memory_usage', row_limit=20))
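
The extra check guards against summarizing a profiler that never actually recorded anything: on a torch.profiler.profile object, the inner autograd profiler (its .profiler attribute) stays None until profiling starts. A minimal sketch of the guarded summary, assuming only standard torch.profiler usage.

import torch
from torch.profiler import ProfilerActivity, profile


def summarize(prof) -> None:
    # prof.profiler (the inner autograd profiler) is still None if profiling never
    # started, so guard it before calling key_averages().
    if prof is not None and prof.profiler is not None:
        print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=5))


with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.randn(64, 64) @ torch.randn(64, 64)

summarize(prof)  # prints a short table of CPU ops

never_started = profile(activities=[ProfilerActivity.CPU])  # constructed but never entered
summarize(never_started)  # no-op instead of an error: its inner profiler is still None
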
16 changes: 12 additions & 4 deletions composer/trainer/mosaic_fsdp.py
@@ -54,22 +54,30 @@ def patch_pytorch():
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

- elif version.parse(torch.__version__) < version.parse('2.2.0'):
- # Monkey patch for torch < 2.2.0 ie torch == 2.1.1, 2.1.2
+ elif version.parse(torch.__version__) < version.parse('2.1.3'):
+ # Monkey patch for torch < 2.1.3 ie torch == 2.1.1, 2.1.2

# Allow 2D HSDP
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# Better overlap communication and computation
- from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
+ from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
+ _wait_for_computation_stream, forward)
_runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
+ _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
+ _runtime_utils._root_pre_forward = _root_pre_forward
+ FullyShardedDataParallel.forward = forward

elif version.parse(torch.__version__) < version.parse('2.2.1'):
# Monkey patch for torch < 2.2.1 ie torch == 2.2.0

# Better overlap communication and computation
from torch.distributed.fsdp import _runtime_utils

- from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
+ from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
+ _wait_for_computation_stream, forward)
_runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
+ _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
+ _runtime_utils._root_pre_forward = _root_pre_forward
+ FullyShardedDataParallel.forward = forward
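
The branches key on packaging.version, and pre-releases sort below their final release, so with the old < 2.2.0 bound a torch 2.2.0 nightly would presumably have matched the 2.1.x branch; tightening it to < 2.1.3 lets 2.2.0 builds fall through to the new < 2.2.1 branch. An illustrative sketch of the dispatch only, with the branch bodies elided (the real patches live in composer.trainer.mosaic_fsdp_utils).

import torch
from packaging import version

# Version-gated dispatch in the style of patch_pytorch(); illustrative only.
v = version.parse(torch.__version__)
if version.parse('2.1.1') <= v < version.parse('2.1.3'):
    branch = 'torch 2.1.1 / 2.1.2 patches'
elif v < version.parse('2.2.1'):
    branch = 'torch 2.2.0 patches'
else:
    branch = 'outside the branches shown in this hunk'
print(torch.__version__, '->', branch)

# Pre-releases compare below the final release:
print(version.parse('2.2.0.dev20231213') < version.parse('2.2.0'))  # True
print(version.parse('2.2.0.dev20231213') < version.parse('2.1.3'))  # False
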
159 changes: 157 additions & 2 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -788,6 +788,153 @@ def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
return tuple(get_process_group_ranks(state.process_group))


def _wait_for_computation_stream(
computation_stream: torch.Stream,
root_state: '_FSDPState',
pre_unshard_stream: torch.Stream,
):
"""Unshard and pre-unshard streams wait for computation stream.
Has the unshard and pre-unshard streams wait for the computation stream.
For example, this should be called in the FSDP root's pre-forward to
respect optimizer step computation.
"""
# Tracing does not need to wait
if torch.distributed._functional_collectives.is_torchdynamo_compiling():
return
# Ensure all unshard streams wait for the computation stream.
unshard_streams = set()
for fsdp_state in root_state._all_fsdp_states:
unshard_streams.add(fsdp_state._unshard_stream)
for unshard_stream in unshard_streams:
unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
# Having the pre-all-gather stream wait for the current stream even if we
# do not leverage the pre-all-gather stream is tolerable since this only
# runs once per iteration
pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
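
For background, wait_stream establishes ordering between streams without blocking the host thread: work queued on one stream after the call waits for work already queued on the other. That is what lets the unshard and pre-unshard streams respect the optimizer-step computation. A minimal, CUDA-guarded sketch of the primitive itself, assuming only standard torch.cuda stream APIs.

import torch

if torch.cuda.is_available():
    computation_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()

    x = torch.randn(1024, 1024, device='cuda')
    y = x @ x                                    # queued on the computation stream

    side_stream.wait_stream(computation_stream)  # later side-stream work sees a finished y
    with torch.cuda.stream(side_stream):
        z = y.sum()

    torch.cuda.synchronize()
    print(z.item())
else:
    print('CUDA unavailable; nothing to demonstrate on CPU.')
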


@no_type_check
def _root_pre_forward(
state: '_FSDPState',
module: nn.Module,
args,
kwargs,
) -> None:
"""Runs pre-forward logic specific to the root FSDP instance.
This should run before any individual module's pre-forward. This starts
with an attempt at lazy initialization (which only runs non-vacuously once).
Otherwise, if this is called on a non-root FSDP instance, then it returns
directly.
"""
from torch.distributed.fsdp._common_utils import _is_composable
from torch.distributed.fsdp._runtime_utils import (_cast_buffers_to_dtype_and_device,
_get_buffers_and_dtypes_for_computation, _lazy_init,
_reset_flat_param_grad_info_if_needed, _root_cast_forward_input)
from torch.distributed.utils import _p_assert, _to_kwargs
with torch.profiler.record_function('FullyShardedDataParallel._root_pre_forward'):
_lazy_init(state, module)
_p_assert(state._is_root is not None, 'Expects a root FSDP to have been set')
if not state._is_root:
# Always cast forward inputs in the root of this local FSDP unit for mixed
# precision, as this is where mixed precision could be configed.
# This is more useful for auto wrapping that is recommended in composable path.
# For manual wrapping, cast forward inputs on each local FSDP unit root will
# increase some overhead, so not turned on for model wrapper path right now where
# manual wrapping is more broadly used.
if _is_composable(state):
return _root_cast_forward_input(state, module, args, kwargs)
return args, kwargs

# We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
# are in full precision and if we should cast them back to lower precision, which happens when
# exiting eval() mode.
handle = state._handle
if handle:
should_cast_buffers_to_full_prec = handle._force_full_precision
else:
should_cast_buffers_to_full_prec = True

if should_cast_buffers_to_full_prec:
_cast_buffers_to_dtype_and_device(
buffers=dict(module.named_buffers()).values(),
buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
device=state.compute_device,
)
# This flag is only set when we cast buffers to full precision, to avoid the
# CPU overhead that can stem from retrieving all buffers and their types in the
# following else branch.
state._needs_buffer_dtype_restore_check = True
elif getattr(state, '_needs_buffer_dtype_restore_check', False):
# Check if buffers are in full precision and we need to cast them
# back down.
(
buffers,
buffer_dtypes_for_computation,
) = _get_buffers_and_dtypes_for_computation(state, module)
if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
if any(buffer.dtype != buffer_dtype_for_computation
for buffer, buffer_dtype_for_computation in zip(buffers, buffer_dtypes_for_computation)):
# Assume we have to cast everything if there is one mismatch
_cast_buffers_to_dtype_and_device(buffers, buffer_dtypes_for_computation, state.compute_device)
# We don't have to check this again until we cast buffers to full precision again.
state._needs_buffer_dtype_restore_check = False

if state.forward_prefetch:
handles = []
for fsdp_state in state._all_fsdp_states:
if fsdp_state._handle:
handles.append(fsdp_state._handle)
for handle in handles:
handle._needs_pre_forward_unshard = True
handle._prefetched = False

_wait_for_computation_stream(
state._device_handle.current_stream(),
state,
state._pre_unshard_stream,
)
_reset_flat_param_grad_info_if_needed(state._all_handles)

# Prepares the forward inputs by moving them to ``compute_device``
# TODO: Do not use the side stream for tensor copies for now; investigate
# the perf with/without it.
with torch.profiler.record_function('FullyShardedDataParallel._to_kwargs'):
args_tuple, kwargs_tuple = _to_kwargs(args, kwargs, state.compute_device, False)
args = args_tuple[0]
kwargs = kwargs_tuple[0]

return _root_cast_forward_input(state, module, args, kwargs)


def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
from torch.distributed.fsdp._runtime_utils import (_post_forward, _post_forward_reshard, _pre_forward,
_pre_forward_unshard)
from torch.distributed.utils import _p_assert
handle = self._handle
with torch.autograd.profiler.record_function('FullyShardedDataParallel.forward'):
args, kwargs = _root_pre_forward(self, self, args, kwargs)
unused = None
args, kwargs = _pre_forward(
self,
handle,
_pre_forward_unshard,
self._fsdp_wrapped_module,
args,
kwargs,
)
if handle:
_p_assert(
handle.flat_param.device == self.compute_device,
'Expected `FlatParameter` to be on the compute device '
f'{self.compute_device} but got {handle.flat_param.device}',
)
output = self._fsdp_wrapped_module(*args, **kwargs)
return _post_forward(self, handle, _post_forward_reshard, self, unused, output)


@no_type_check
def _share_state_and_init_handle_attrs_t2p1(
root_state: '_FSDPState',
@@ -824,7 +971,11 @@ def _share_state_and_init_handle_attrs_t2p1(
# Patching so that _FSDPStates with different process groups have separate unshard streams.
# Keep track of any new unshard streams we may have to add for specific process groups.
fsdp_pg_unshard_streams = {}
- unshard_priority = root_state._unshard_stream.priority
+ try:
+ unshard_priority = root_state._unshard_stream.priority
+ except AttributeError:
+ # Use the default priority of 0 if the stream has no assigned priority.
+ unshard_priority = 0
for fsdp_state in root_state._all_fsdp_states:
for attr_name in HOMOGENEOUS_ATTR_NAMES:
_p_assert(
@@ -911,7 +1062,11 @@ def _share_state_and_init_handle_attrs_t2p2(
# Patching so that _FSDPStates with different process groups have separate unshard streams.
# Keep track of any new unshard streams we may have to add for specific process groups.
fsdp_pg_unshard_streams = {}
- unshard_priority = root_state._unshard_stream.priority
+ try:
+ unshard_priority = root_state._unshard_stream.priority
+ except AttributeError:
+ # Use the default priority of 0 if the stream has no assigned priority.
+ unshard_priority = 0
for fsdp_state in root_state._all_fsdp_states:
for attr_name in HOMOGENEOUS_ATTR_NAMES:
_p_assert(
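
The try/except AttributeError fallback in both patched helpers presumably covers stream objects that do not expose a priority attribute, defaulting to 0 rather than failing. A small sketch of the same fallback; the legacy-stream class is a made-up stand-in.

import torch


class LegacyStream:
    """Made-up stand-in for a stream object without a `priority` attribute."""


def unshard_priority_of(stream) -> int:
    try:
        return stream.priority
    except AttributeError:
        return 0  # default priority when the stream has none assigned


print(unshard_priority_of(LegacyStream()))  # 0
if torch.cuda.is_available():
    print(unshard_priority_of(torch.cuda.Stream(priority=-1)))  # -1
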
6 changes: 3 additions & 3 deletions setup.py
@@ -78,8 +78,8 @@ def package_files(prefix: str, directory: str, extension: str):
'tqdm>=4.62.3,<5',
'torchmetrics>=0.10.0,<1.1',
'torch_optimizer>=0.3.0,<0.4',
- 'torchvision>=0.13.1,<0.17',
- 'torch>=1.13.1,<2.1.3',
+ 'torchvision>=0.13.1,<0.19',
+ 'torch>=1.13.1,<2.2.1',
'requests>=2.26.0,<3',
'numpy>=1.21.5,<1.27.0',
'psutil>=5.8.0,<6',
@@ -126,7 +126,7 @@ def package_files(prefix: str, directory: str, extension: str):
'sphinx_panels==0.6.0',
'sphinxcontrib-images==0.9.4',
'pytest_codeblocks==0.17.0',
- 'traitlets==5.13.0',
+ 'traitlets==5.14.1',
'nbsphinx==0.9.1',
'pandoc==2.3',
'pypandoc==1.12',
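
The loosened pins above are ordinary PEP 440 range specifiers, so the new torch>=1.13.1,<2.2.1 bound admits 2.2.0 but still excludes 2.2.1. A quick illustrative check with packaging (not part of setup.py).

from packaging.specifiers import SpecifierSet
from packaging.version import Version

torch_pin = SpecifierSet('>=1.13.1,<2.2.1')  # the widened pin from this diff
for candidate in ('2.1.2', '2.2.0', '2.2.1'):
    print(candidate, Version(candidate) in torch_pin)
# 2.1.2 True
# 2.2.0 True
# 2.2.1 False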