[fix] float8 should be applied on all model_parts #500

Merged 2 commits on Aug 5, 2024
torchtitan/float8.py (18 changes: 14 additions & 4 deletions)

@@ -13,6 +13,8 @@
 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance

+from typing import List, Union
+
 import torch
 import torch.nn as nn

@@ -103,7 +105,9 @@ def convert_to_float8_training(self, model: nn.Module):
             f"{self.config.enable_fsdp_float8_all_gather}"
         )

-    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
+    def precompute_float8_dynamic_scale_for_fsdp(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return

@@ -112,9 +116,13 @@ def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):

         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

-        precompute_float8_dynamic_scale_for_fsdp(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            precompute_float8_dynamic_scale_for_fsdp(m)

-    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+    def sync_float8_amax_and_scale_history(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return

@@ -136,4 +144,6 @@ def sync_float8_amax_and_scale_history(self, model: nn.Module):
             sync_float8_amax_and_scale_history
         )

-        self._sync_float8_amax_and_scale_history(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            self._sync_float8_amax_and_scale_history(m)
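
Both handler methods now normalize their input the same way: they accept either a single nn.Module or a list of model parts (as produced when the model is split, e.g. for pipeline parallelism), wrap a single module in a one-element list, and iterate. A minimal standalone sketch of that pattern, where apply_to_model_parts and apply_fn are hypothetical names and not part of torchtitan or torchao:

from typing import Callable, List, Union

import torch.nn as nn


def apply_to_model_parts(
    model: Union[nn.Module, List[nn.Module]],
    apply_fn: Callable[[nn.Module], None],
) -> None:
    # Normalize: wrap a single module in a one-element list,
    # use a list of model parts as-is, then apply the function to each part.
    models = [model] if isinstance(model, nn.Module) else model
    for m in models:
        apply_fn(m)


# Usage (illustrative): works the same for a list of split parts or one module.
# apply_to_model_parts(model_parts, precompute_float8_dynamic_scale_for_fsdp)
# apply_to_model_parts(model, precompute_float8_dynamic_scale_for_fsdp)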
train.py (4 changes: 2 additions & 2 deletions)

@@ -307,7 +307,7 @@ def loss_fn(pred, labels):
         )

         # sync float8 amaxes and scales
-        float8_handler.sync_float8_amax_and_scale_history(model)
+        float8_handler.sync_float8_amax_and_scale_history(model_parts)

         # optimizer step
         checkpoint.maybe_wait_for_staging()
@@ -316,7 +316,7 @@ def loss_fn(pred, labels):

         # calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # it issues a single all-reduce for all parameters at once for better performance
-        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

         losses_since_last_log.append(loss)
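
For context, a condensed sketch of where these two calls sit in a training step once the model is split into multiple parts. Names mirror train.py, but the setup of model_parts, optimizers, and float8_handler, the forward/backward code, and the optimizers container interface are assumptions here, not shown in this diff:

from typing import Callable, List

import torch.nn as nn


def train_step(
    model_parts: List[nn.Module],
    optimizers,  # assumed: a container exposing step(), as in train.py
    float8_handler,  # assumed: the Float8Handler from torchtitan/float8.py
    run_forward_backward: Callable[[List[nn.Module]], None],
) -> None:
    # forward/backward on all model parts (elided in this sketch)
    run_forward_backward(model_parts)

    # sync float8 amaxes and scales across all model parts
    float8_handler.sync_float8_amax_and_scale_history(model_parts)

    # optimizer step
    optimizers.step()

    # calculate float8 dynamic amax/scale for FSDP2; each call issues a
    # single all-reduce for all parameters of a part at once
    float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)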
