This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

support delayed scaling of weight in float8 all-gather #312

Closed
wants to merge 3 commits
85 changes: 54 additions & 31 deletions float8_experimental/float8_linear.py
@@ -34,7 +34,10 @@
tensor_to_amax,
)

from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
from float8_experimental.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
WeightWithDynamicFloat8CastTensor,
)


def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -316,28 +319,30 @@ def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
if self.scaling_type_w is TensorScalingType.DELAYED:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
self.fp8_amax_w,
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
self.fp8_amax_w,
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
vkuzo marked this conversation as resolved.
else:
assert self.scaling_type_w is TensorScalingType.DYNAMIC
# TODO(future): also support FSDP integration in delayed scaling path
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
@@ -436,18 +441,36 @@ def from_float(
scaling_type_dL_dY=scaling_type_dL_dY,
emulate=emulate,
)
if (
scaling_type_w == TensorScalingType.DYNAMIC
and config.enable_fsdp_fp8_all_gather
):
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
)
else:
assert not config.enable_fsdp_fp8_all_gather, "unsupported"
new_mod.weight = mod.weight
new_mod.weight = mod.weight
new_mod.bias = mod.bias
# need to create buffers again when moving from meta device to
# real device
new_mod.create_buffers()

# If FSDP float8 all-gather is on, wrap the weight in a float8-aware
# tensor subclass. This must happen last because:
# 1. weight needs to be on the correct device to create the buffers
# 2. buffers need to be already created for the delayed scaling version
# of the weight wrapper to be initialized
if config.enable_fsdp_fp8_all_gather:
if scaling_type_w is TensorScalingType.DYNAMIC:
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.forward_config,
)
)
else:
assert scaling_type_w is TensorScalingType.DELAYED
new_mod.weight = torch.nn.Parameter(
WeightWithDelayedFloat8CastTensor(
new_mod.weight,
new_mod.fp8_amax_w,
new_mod.fp8_amax_history_w,
new_mod.fp8_scale_w,
new_mod.forward_config,
new_mod.is_amax_initialized,
)
)

return new_mod
13 changes: 7 additions & 6 deletions float8_experimental/float8_linear_utils.py
@@ -289,11 +289,10 @@ def inner_func():
), "Mismatched lengths of amax tensors."

if dist.is_initialized():
# Combine all the amax tensors into one tensor and reduce it
# Note: do not reduce the weight values, because FSDP already ensures
# the weight values on all ranks are the same after all-gather.
all_amax_tensors = torch.cat(
fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
fp8_amax_x_tensor_list
+ fp8_amax_w_tensor_list
Contributor:
should we only do this if we are using fp8 all-gather?

Contributor Author:
that could make sense, I'd love to see the data to see if this is going to matter for performance. Focusing on numerics for now; I was hoping for performance to be tackled in future PRs.
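A minimal sketch of the gating being discussed, assuming the `config.enable_fsdp_fp8_all_gather` flag checked in `Float8Linear.from_float` is importable here (an assumption, not part of this PR):

```python
# Hypothetical gating inside inner_func(): only all-reduce the weight amaxes
# when FSDP float8 all-gather is enabled; otherwise every rank already holds
# identical weight values and the extra reduction is redundant.
amax_lists = fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
if config.enable_fsdp_fp8_all_gather:
    amax_lists = (
        fp8_amax_x_tensor_list
        + fp8_amax_w_tensor_list
        + fp8_amax_dL_dY_tensor_list
    )
all_amax_tensors = torch.cat(amax_lists)
# Note: the torch.split(...) and the per-layer fp8_amax_w.copy_(...) further
# down would need the same condition.
```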

+ fp8_amax_dL_dY_tensor_list
)
all_reduced_amax_tensor = all_reduce(
all_amax_tensors, "MAX", list(range(dist.get_world_size()))
@@ -302,12 +301,14 @@ def inner_func():
all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

(
reduced_fp8_amax_tensor,
reduced_fp8_amax_x_tensor,
reduced_fp8_amax_w_tensor,
reduced_fp8_amax_dL_dY_tensor,
) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

for idx, child in enumerate(fp8_layers):
child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx])
child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

# We create two stacked tensor groups, one for the amax history and one for the current scales
181 changes: 180 additions & 1 deletion float8_experimental/fsdp_utils.py
@@ -18,7 +18,7 @@
ScaledMMConfig,
)

from float8_experimental.float8_utils import EPS
from float8_experimental.float8_utils import e4m3_dtype, EPS
from torch._prims_common import suggest_memory_format


@@ -189,3 +189,182 @@ def fsdp_post_all_gather(
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)


class WeightWithDelayedFloat8CastTensor(torch.Tensor):
drisspg (Contributor), Jul 10, 2024:
[no change needed] I wish there was a way to share some more code with the dynamic version

Contributor Author:
yeah, me too. Looking at the code below, really the only code which would be shared is fsdp_post_all_gather, everything else would have to have if/else branches for delayed vs dynamic
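A minimal sketch of what that sharing could look like, assuming a hypothetical mixin and the imports already present in fsdp_utils.py (`Tuple`, `Any`, `Optional`, `Float8Tensor`); only `fsdp_post_all_gather` is factored out, since the pre-all-gather and dispatch logic genuinely differ between dynamic and delayed scaling:

```python
class _Float8PostAllGatherMixin:
    # Shared post-all-gather: rebuild a Float8Tensor from the gathered fp8
    # data plus the scale shipped as all-gather metadata. Both the dynamic
    # and delayed wrappers can reuse this unchanged.
    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            assert isinstance(out, Float8Tensor), f"{type(out)}"
            out._scale = scale
            return
        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)


# Hypothetical usage: each wrapper keeps its own __new__/__init__,
# __torch_dispatch__, and fsdp_pre_all_gather, and inherits only this piece.
class WeightWithDynamicFloat8CastTensor(torch.Tensor, _Float8PostAllGatherMixin):
    ...


class WeightWithDelayedFloat8CastTensor(torch.Tensor, _Float8PostAllGatherMixin):
    ...
```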

@staticmethod
def __new__(
cls,
tensor: torch.Tensor,
amax_buffer: torch.Tensor,
amax_history_buffer: torch.Tensor,
scale_buffer: torch.Tensor,
mm_config: ScaledMMConfig,
is_amax_initialized: bool,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
memory_format=suggest_memory_format(tensor),
dtype=tensor.dtype,
layout=tensor.layout,
device=tensor.device,
pin_memory=tensor.is_pinned(),
requires_grad=tensor.requires_grad,
)

def __init__(
self,
tensor: torch.Tensor,
amax_buffer: torch.Tensor,
amax_history_buffer: torch.Tensor,
scale_buffer: torch.Tensor,
mm_config: ScaledMMConfig,
is_amax_initialized: bool,
):
self._tensor = tensor
self._amax_buffer = amax_buffer
self._amax_history_buffer = amax_history_buffer
self._scale_buffer = scale_buffer
self._mm_config = mm_config

# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
# TODO(future PR): add serialization for this flag
self.is_amax_initialized = is_amax_initialized

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == torch.ops.aten.detach.default:
Contributor:
mostly just a nit, but any reason to special-case detach here? Alternatively, you could set it up so that every view op automatically propagates subclass-ness in the same way

Contributor:
If this is something I wrote, I think it was just something I saw in some other subclasses. Having every view op propagate subclass-ness in the same way sounds good to me.
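A minimal sketch of that alternative, assuming `_ops_to_preserve_subclass` in this file is a plain set (as the membership check further down suggests): drop the dedicated detach branch and let the generic unwrap/rewrap path handle detach like any other preserved op.

```python
# Hypothetical: register detach with the ops whose outputs get re-wrapped
# generically, so __torch_dispatch__ needs no special case for it. The
# final pytree.tree_map_only(...) below then re-wraps the detached tensor
# with the same amax/scale buffers and mm_config as any other preserved op.
_ops_to_preserve_subclass.add(torch.ops.aten.detach.default)
```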

return WeightWithDelayedFloat8CastTensor(
args[0]._tensor,
args[0]._amax_buffer,
args[0]._amax_history_buffer,
args[0]._scale_buffer,
args[0]._mm_config,
args[0].is_amax_initialized,
)
mm_config: Optional[ScaledMMConfig] = None
amax_buffer: Optional[torch.Tensor] = None
amax_history_buffer: Optional[torch.Tensor] = None
scale_buffer: Optional[torch.Tensor] = None
is_amax_initialized: Optional[bool] = None

def unwrap(t):
nonlocal mm_config
if mm_config is None:
mm_config = t._mm_config
else:
mm_config = merge_mm_configs(mm_config, t._mm_config)
nonlocal amax_buffer
if amax_buffer is None:
amax_buffer = t._amax_buffer
nonlocal amax_history_buffer
if amax_history_buffer is None:
amax_history_buffer = t._amax_history_buffer
nonlocal scale_buffer
if scale_buffer is None:
scale_buffer = t._scale_buffer
nonlocal is_amax_initialized
if is_amax_initialized is None:
is_amax_initialized = t.is_amax_initialized
return t._tensor

args, kwargs = pytree.tree_map_only(
WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
)
out = func(*args, **kwargs)
if func not in _ops_to_preserve_subclass:
return out
return pytree.tree_map_only(
torch.Tensor,
lambda x: WeightWithDelayedFloat8CastTensor(
x,
amax_buffer,
amax_history_buffer,
scale_buffer,
mm_config,
is_amax_initialized,
),
out,
)

def __tensor_flatten__(self):
return (
[
"_tensor",
"_amax_buffer",
"_amax_history_buffer",
"_scale_buffer",
],
{
"mm_config": self._mm_config,
"is_amax_initialized": is_amax_initialized,
},
)

@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
return WeightWithDelayedFloat8CastTensor(
inner_tensors["_tensor"],
inner_tensors["_amax_buffer"],
inner_tensors["_amax_history_buffer"],
inner_tensors["_scale_buffer"],
metadata["mm_config"],
metadata["is_amax_initialized"],
)

def __repr__(self):
return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
Contributor:
I'll let @weifengpy confirm this portion

Contributor:
confirming that the FSDP part looks good

# initialize if needed
# TODO(before land): ensure settings are consistent between Float8Linear and here
Contributor:
do we still need to resolve this?

if not self.is_amax_initialized:
from float8_experimental.float8_linear import (
_maybe_initialize_amaxes_scales_for_float8_cast,
)

_maybe_initialize_amaxes_scales_for_float8_cast(
self._tensor,
self._amax_buffer,
self._amax_history_buffer,
self._scale_buffer,
"max", # TODO(before land): read this from parent
Contributor:
ditto

e4m3_dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.is_amax_initialized = True

# this will:
# 1. cast the tensor to float8 using `_scale_buffer`
# 2. populate `_amax_buffer` inplace
# TODO(future PR): clean up all the casting functions and clearly
# separate dynamic vs delayed, tech debt has accumulated
float8_tensor = Float8Tensor.to_float8(
self._tensor,
self._scale_buffer,
e4m3_dtype,
self._amax_buffer,
self._mm_config,
)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
):
(data,) = all_gather_outputs
(scale,) = metadata
if out is not None:
assert isinstance(out, Float8Tensor), f"{type(out)}"
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
18 changes: 16 additions & 2 deletions test/test_fsdp2/test_fsdp2_common.py
@@ -6,6 +6,11 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp


@@ -17,6 +22,7 @@ def check_parity_no_mp(
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
precompute: bool = False,
scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
@@ -28,10 +34,18 @@
for param in model.parameters():
dist.all_reduce(param.grad)
param.grad.div_(dist.get_world_size())
# TODO(future): add amax syncing once delayed scaling is supported

if linear_requires_sync(scaling_type_w=scaling_type_w):
sync_float8_amax_and_scale_history(model)

optim.step()
if model is fsdp_model and precompute:
if (
model is fsdp_model
and precompute
and scaling_type_w is TensorScalingType.DYNAMIC
):
precompute_float8_dynamic_scale_for_fsdp(model)

test_cls.assertEqual(losses[0], losses[1])


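For reference, a minimal end-to-end sketch of how the delayed-scaling weight path in this PR might be exercised. `Float8Linear.from_float`, `TensorScalingType`, and `sync_float8_amax_and_scale_history` come from the diffs above; the `float8_experimental.config` import path, the FSDP2 `fully_shard` setup, and the surrounding training loop are assumptions, not code from this PR:

```python
# Sketch only: assumes an fp8-capable GPU (or emulate=True) plus an
# already-initialized process group (e.g. torchrun + dist.init_process_group).
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

import float8_experimental.config as config  # assumed import path for the flag
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

# from_float checks this flag to decide whether to wrap the weight in
# WeightWithDelayedFloat8CastTensor (see the float8_linear.py diff above)
config.enable_fsdp_fp8_all_gather = True

m = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()
for i, child in enumerate(m):
    m[i] = Float8Linear.from_float(child, scaling_type_w=TensorScalingType.DELAYED)

fully_shard(m)  # weights become fp8-all-gather-aware under FSDP2
optim = torch.optim.SGD(m.parameters(), lr=1e-3)

for _ in range(3):
    loss = m(torch.randn(16, 1024, device="cuda")).sum()
    loss.backward()
    # delayed scaling needs the amax/scale history sync each iteration
    sync_float8_amax_and_scale_history(m)
    optim.step()
    optim.zero_grad()
```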