diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 23d6f5c..7850738 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -34,7 +34,10 @@
     tensor_to_amax,
 )
 
-from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.fsdp_utils import (
+    WeightWithDelayedFloat8CastTensor,
+    WeightWithDynamicFloat8CastTensor,
+)
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -316,28 +319,30 @@ def cast_w_to_float8(
         self, w: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         if self.scaling_type_w is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
-            _maybe_initialize_amaxes_scales_for_float8_cast(
-                w,
-                self.fp8_amax_w,
-                self.fp8_amax_history_w,
-                self.fp8_scale_w,
-                scale_fn_name,
-                e4m3_dtype,
-                is_amax_initialized,
-                reduce_amax=False,
-            )
-
-            w_fp8 = Float8Tensor.to_float8(
-                w,
-                self.fp8_scale_w,
-                e4m3_dtype,
-                self.fp8_amax_w,
-                self.forward_config,
-            )
+            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
+                w_fp8 = self.weight
+            else:
+                scale_fn_name = self.recipe.scale_fn_name
+                _maybe_initialize_amaxes_scales_for_float8_cast(
+                    w,
+                    self.fp8_amax_w,
+                    self.fp8_amax_history_w,
+                    self.fp8_scale_w,
+                    scale_fn_name,
+                    e4m3_dtype,
+                    is_amax_initialized,
+                    reduce_amax=False,
+                )
+
+                w_fp8 = Float8Tensor.to_float8(
+                    w,
+                    self.fp8_scale_w,
+                    e4m3_dtype,
+                    self.fp8_amax_w,
+                    self.forward_config,
+                )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
-            # TODO(future): also support FSDP integration in delayed scaling path
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
@@ -436,18 +441,36 @@ def from_float(
             scaling_type_dL_dY=scaling_type_dL_dY,
             emulate=emulate,
         )
-        if (
-            scaling_type_w == TensorScalingType.DYNAMIC
-            and config.enable_fsdp_fp8_all_gather
-        ):
-            new_mod.weight = torch.nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
-            )
-        else:
-            assert not config.enable_fsdp_fp8_all_gather, "unsupported"
-            new_mod.weight = mod.weight
+        new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         # need to create buffers again when moving from meta device to
         # real device
         new_mod.create_buffers()
+
+        # If FSDP float8 all-gather is on, wrap the weight in a float8-aware
+        # tensor subclass. This must happen last because:
+        # 1. weight needs to be on the correct device to create the buffers
+        # 2. buffers need to be already created for the delayed scaling version
+        #    of the weight wrapper to be initialized
+        if config.enable_fsdp_fp8_all_gather:
+            if scaling_type_w is TensorScalingType.DYNAMIC:
+                new_mod.weight = torch.nn.Parameter(
+                    WeightWithDynamicFloat8CastTensor(
+                        new_mod.weight,
+                        new_mod.forward_config,
+                    )
+                )
+            else:
+                assert scaling_type_w is TensorScalingType.DELAYED
+                new_mod.weight = torch.nn.Parameter(
+                    WeightWithDelayedFloat8CastTensor(
+                        new_mod.weight,
+                        new_mod.fp8_amax_w,
+                        new_mod.fp8_amax_history_w,
+                        new_mod.fp8_scale_w,
+                        new_mod.forward_config,
+                        new_mod.is_amax_initialized,
+                    )
+                )
+
         return new_mod
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index e7abe5c..818fef0 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -289,11 +289,10 @@ def inner_func():
         ), "Mismatched lengths of amax tensors."
 
         if dist.is_initialized():
-            # Combine all the amax tensors into one tensor and reduce it
-            # Note: do not reduce the weight values, because FSDP already ensures
-            # the weight values on all ranks are the same after all-gather.
             all_amax_tensors = torch.cat(
-                fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
+                fp8_amax_x_tensor_list
+                + fp8_amax_w_tensor_list
+                + fp8_amax_dL_dY_tensor_list
             )
             all_reduced_amax_tensor = all_reduce(
                 all_amax_tensors, "MAX", list(range(dist.get_world_size()))
@@ -302,12 +301,14 @@ def inner_func():
                 all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
 
             (
-                reduced_fp8_amax_tensor,
+                reduced_fp8_amax_x_tensor,
+                reduced_fp8_amax_w_tensor,
                 reduced_fp8_amax_dL_dY_tensor,
             ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
 
             for idx, child in enumerate(fp8_layers):
-                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+                child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx])
+                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
                 child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
 
     # We create two stacked tensor groups, one for the amax history and one for the current scales
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index 2f23a3b..81d53b5 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -18,7 +18,7 @@
     ScaledMMConfig,
 )
 
-from float8_experimental.float8_utils import EPS
+from float8_experimental.float8_utils import e4m3_dtype, EPS
 
 from torch._prims_common import suggest_memory_format
 
@@ -189,3 +189,182 @@ def fsdp_post_all_gather(
             out._scale = scale
             return
         return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
+
+
+class WeightWithDelayedFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        self._tensor = tensor
+        self._amax_buffer = amax_buffer
+        self._amax_history_buffer = amax_history_buffer
+        self._scale_buffer = scale_buffer
+        self._mm_config = mm_config
+
+        # Note: is_amax_initialized is not a buffer to avoid data dependent
+        # control flow visible to dynamo
+        # TODO(future PR): add serialization for this flag
+        self.is_amax_initialized = is_amax_initialized
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDelayedFloat8CastTensor(
+                args[0]._tensor,
+                args[0]._amax_buffer,
+                args[0]._amax_history_buffer,
+                args[0]._scale_buffer,
+                args[0]._mm_config,
+                args[0].is_amax_initialized,
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+        amax_buffer: Optional[torch.Tensor] = None
+        amax_history_buffer: Optional[torch.Tensor] = None
+        scale_buffer: Optional[torch.Tensor] = None
+        is_amax_initialized: Optional[bool] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            nonlocal amax_buffer
+            if amax_buffer is None:
+                amax_buffer = t._amax_buffer
+            nonlocal amax_history_buffer
+            if amax_history_buffer is None:
+                amax_history_buffer = t._amax_history_buffer
+            nonlocal scale_buffer
+            if scale_buffer is None:
+                scale_buffer = t._scale_buffer
+            nonlocal is_amax_initialized
+            if is_amax_initialized is None:
+                is_amax_initialized = t.is_amax_initialized
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: WeightWithDelayedFloat8CastTensor(
+                x,
+                amax_buffer,
+                amax_history_buffer,
+                scale_buffer,
+                mm_config,
+                is_amax_initialized,
+            ),
+            out,
+        )
+
+    def __tensor_flatten__(self):
+        return (
+            [
+                "_tensor",
+                "_amax_buffer",
+                "_amax_history_buffer",
+                "_scale_buffer",
+            ],
+            {
+                "mm_config": self._mm_config,
+                "is_amax_initialized": self.is_amax_initialized,
+            },
+        )
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
+        return WeightWithDelayedFloat8CastTensor(
+            inner_tensors["_tensor"],
+            inner_tensors["_amax_buffer"],
+            inner_tensors["_amax_history_buffer"],
+            inner_tensors["_scale_buffer"],
+            metadata["mm_config"],
+            metadata["is_amax_initialized"],
+        )
+
+    def __repr__(self):
+        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        # initialize if needed
+        # TODO(before land): ensure settings are consistent between Float8Linear and here
+        if not self.is_amax_initialized:
+            from float8_experimental.float8_linear import (
+                _maybe_initialize_amaxes_scales_for_float8_cast,
+            )
+
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                self._tensor,
+                self._amax_buffer,
+                self._amax_history_buffer,
+                self._scale_buffer,
+                "max",  # TODO(before land): read this from parent
+                e4m3_dtype,
+                self.is_amax_initialized,
+                reduce_amax=True,
+            )
+            self.is_amax_initialized = True
+
+        # this will:
+        # 1. cast the tensor to float8 using `_scale_buffer`
+        # 2. populate `_amax_buffer` inplace
+        # TODO(future PR): clean up all the casting functions and clearly
+        # separate dynamic vs delayed, tech debt has accumulated
+        float8_tensor = Float8Tensor.to_float8(
+            self._tensor,
+            self._scale_buffer,
+            e4m3_dtype,
+            self._amax_buffer,
+            self._mm_config,
+        )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index af57871..2638401 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -6,6 +6,11 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
@@ -17,6 +22,7 @@ def check_parity_no_mp(
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
     precompute: bool = False,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -28,10 +34,18 @@ def check_parity_no_mp(
                 for param in model.parameters():
                     dist.all_reduce(param.grad)
                     param.grad.div_(dist.get_world_size())
-            # TODO(future): add amax syncing once delayed scaling is supported
+
+            if linear_requires_sync(scaling_type_w=scaling_type_w):
+                sync_float8_amax_and_scale_history(model)
+
             optim.step()
-            if model is fsdp_model and precompute:
+            if (
+                model is fsdp_model
+                and precompute
+                and scaling_type_w is TensorScalingType.DYNAMIC
+            ):
                 precompute_float8_dynamic_scale_for_fsdp(model)
+
         test_cls.assertEqual(losses[0], losses[1])
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index b5ae234..91c629f 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -1,4 +1,5 @@
 import copy
+import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -79,22 +80,30 @@ def world_size(self) -> int:
         return min(torch.cuda.device_count(), 2)
 
     @skip_if_lt_x_gpu(2)
-    def test_transformer_parity_dynamic(self):
+    def test_transformer_parity(self):
        self.run_subtests(
            {
                "enable_fsdp_fp8_all_gather": [False, True],
                "precompute": [False, True],
+               "scaling_type_w": [
+                   TensorScalingType.DYNAMIC,
+                   TensorScalingType.DELAYED,
+               ],
            },
-           self._test_transformer_parity_dynamic,
+           self._test_transformer_parity,
        )
 
-    def _test_transformer_parity_dynamic(
+    def _test_transformer_parity(
         self,
         enable_fsdp_fp8_all_gather: bool,
         precompute: bool,
+        scaling_type_w: TensorScalingType,
     ):
         if not enable_fsdp_fp8_all_gather and precompute:
             return
+        elif scaling_type_w is TensorScalingType.DELAYED and precompute:
+            return
+
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -102,9 +111,9 @@ def _test_transformer_parity_dynamic(
         weight_tying = not enable_fsdp_fp8_all_gather
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        swap_linear_with_float8_linear(ref_module)
+        swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            swap_linear_with_float8_linear(module)
+            swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -115,7 +124,14 @@ def _test_transformer_parity_dynamic(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
         check_parity_no_mp(
-            self, ref_module, ref_optim, module, optim, local_inp, precompute
+            self,
+            ref_module,
+            ref_optim,
+            module,
+            optim,
+            local_inp,
+            precompute,
+            scaling_type_w=scaling_type_w,
         )
 
     @skip_if_lt_x_gpu(2)
@@ -376,13 +392,21 @@ def test_fp32_fp8_single_module_parity(self):
         Tests numeric parity for fp32 parameters with fp8 computation with
         a single module/FSDP communication group.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        choices = itertools.product(
+            [False, True],
+            [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED],
+        )
+        for enable_fsdp_fp8_all_gather, scaling_type_w in choices:
             module_fp32 = self.init_single_module()
             ref_module = copy.deepcopy(module_fp32)
-            ref_module = swap_linear_with_float8_linear(ref_module)
+            ref_module = swap_linear_with_float8_linear(
+                ref_module, scaling_type_w=scaling_type_w
+            )
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = swap_linear_with_float8_linear(module_fp32)
+                module = swap_linear_with_float8_linear(
+                    module_fp32, scaling_type_w=scaling_type_w
+                )
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -394,6 +418,7 @@ def test_fp32_fp8_single_module_parity(self):
                 module,
                 optim,
                 local_inp,
+                scaling_type_w=scaling_type_w,
             )
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
@@ -402,12 +427,20 @@ def test_fp32_fp8_multi_module_parity(self):
         """
         Tests numeric parity for fp32 parameters with fp8 computation with
         multiple modules/FSDP communication groups.
""" - for enable_fsdp_fp8_all_gather in [False, True]: + choices = itertools.product( + [False, True], + [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], + ) + for enable_fsdp_fp8_all_gather, scaling_type_w in choices: module = self.init_multi_module().cuda() ref_module = copy.deepcopy(module) - ref_module = swap_linear_with_float8_linear(ref_module) + ref_module = swap_linear_with_float8_linear( + ref_module, scaling_type_w=scaling_type_w + ) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - module = swap_linear_with_float8_linear(module) + module = swap_linear_with_float8_linear( + module, scaling_type_w=scaling_type_w + ) for submodule in module: fully_shard(submodule) fully_shard(module) @@ -421,6 +454,7 @@ def test_fp32_fp8_multi_module_parity(self): module, optim, local_inp, + scaling_type_w=scaling_type_w, ) @unittest.skipIf(not TEST_CUDA, "no cuda") @@ -455,6 +489,23 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): self.get_local_inp(torch.bfloat16), ) + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_delayed_scaling_inplace_update(self): + """ + Verify that `WeightWithDelayedFloat8CastTensor` updates buffers inplace + """ + module = self.init_single_module() + with set_enable_fsdp_fp8_all_gather(True): + m_fp8 = swap_linear_with_float8_linear( + module, + scaling_type_w=TensorScalingType.DELAYED, + ) + + fp8_amax_w_old = m_fp8.fp8_amax_w.clone().detach() + dummy_mesh = None + data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh) + self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item()) + if __name__ == "__main__": run_tests()