[fix] float8 should be applied on all model_parts
ghstack-source-id: 52ed6836de39e82c4c5824a40ecfc1d9ec7ed2bd
Pull Request resolved: #500
tianyu-l committed Aug 5, 2024
1 parent ce08308 commit f9e114b
Showing 2 changed files with 16 additions and 6 deletions.
18 changes: 14 additions & 4 deletions torchtitan/float8.py
@@ -13,6 +13,8 @@
 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance
 
+from typing import List, Union
+
 import torch
 import torch.nn as nn
 
@@ -103,7 +105,9 @@ def convert_to_float8_training(self, model: nn.Module):
             f"{self.config.enable_fsdp_float8_all_gather}"
         )
 
-    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
+    def precompute_float8_dynamic_scale_for_fsdp(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -112,9 +116,13 @@ def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
 
         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 
-        precompute_float8_dynamic_scale_for_fsdp(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            precompute_float8_dynamic_scale_for_fsdp(m)
 
-    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+    def sync_float8_amax_and_scale_history(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -136,4 +144,6 @@ def sync_float8_amax_and_scale_history(self, model: nn.Module):
                 sync_float8_amax_and_scale_history
             )
 
-        self._sync_float8_amax_and_scale_history(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            self._sync_float8_amax_and_scale_history(m)
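
For illustration, here is a minimal runnable sketch (not part of the commit) of the single-module-or-list normalization that both handler methods now perform; the helper name handle_parts and the toy Linear modules are hypothetical:

from typing import List, Union

import torch.nn as nn


def handle_parts(model: Union[nn.Module, List[nn.Module]]) -> None:
    # Normalize to a list so a single module and a list of model parts
    # (e.g. pipeline stages) share one code path, as in the diff above.
    models = [model] if isinstance(model, nn.Module) else model
    for m in models:
        # Stand-in for precompute_float8_dynamic_scale_for_fsdp(m) or
        # self._sync_float8_amax_and_scale_history(m).
        print(f"handled {type(m).__name__} ({sum(p.numel() for p in m.parameters())} params)")


handle_parts(nn.Linear(4, 4))                     # a single module still works
handle_parts([nn.Linear(4, 4), nn.Linear(4, 4)])  # a list of model parts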
4 changes: 2 additions & 2 deletions train.py
@@ -307,7 +307,7 @@ def loss_fn(pred, labels):
         )
 
         # sync float8 amaxes and scales
-        float8_handler.sync_float8_amax_and_scale_history(model)
+        float8_handler.sync_float8_amax_and_scale_history(model_parts)
 
         # optimizer step
         checkpoint.maybe_wait_for_staging()
@@ -316,7 +316,7 @@ def loss_fn(pred, labels):
 
         # calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # it issues a single all-reduce for all parameters at once for better performance
-        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
 
         losses_since_last_log.append(loss)
 
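
For context, a hedged sketch of the call-site pattern the train.py change targets: under pipeline parallelism the model is split into several parts, and each float8 hook has to see every part. The toy NoopFloat8Handler, the two-Linear "pipeline", and the SGD optimizer below are assumptions for illustration; only the call names and their ordering mirror the diff:

import torch
import torch.nn as nn


class NoopFloat8Handler:
    # Toy handler with the same call shape as the handler in torchtitan/float8.py;
    # the real bookkeeping is replaced by no-ops here.
    def sync_float8_amax_and_scale_history(self, model):
        pass  # delayed-scaling amax/scale sync would go here

    def precompute_float8_dynamic_scale_for_fsdp(self, model):
        pass  # FSDP2 all-gather scale precompute would go here


model_parts = [nn.Linear(8, 8), nn.Linear(8, 8)]  # stand-ins for pipeline stages
optimizer = torch.optim.SGD((p for m in model_parts for p in m.parameters()), lr=1e-3)
float8_handler = NoopFloat8Handler()

x = torch.randn(2, 8)
loss = model_parts[1](model_parts[0](x)).sum()
loss.backward()

# Same ordering as the diff: sync before the optimizer step, precompute after,
# and both calls now receive the full list of model parts.
float8_handler.sync_float8_amax_and_scale_history(model_parts)
optimizer.step()
optimizer.zero_grad()
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)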