diff --git a/torchtitan/float8.py b/torchtitan/float8.py
index 4dc7122b..043b1832 100644
--- a/torchtitan/float8.py
+++ b/torchtitan/float8.py
@@ -13,6 +13,8 @@
 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
 
+from typing import List, Union
+
 import torch
 import torch.nn as nn
 
@@ -103,7 +105,9 @@ def convert_to_float8_training(self, model: nn.Module):
             f"{self.config.enable_fsdp_float8_all_gather}"
         )
 
-    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
+    def precompute_float8_dynamic_scale_for_fsdp(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -112,9 +116,13 @@ def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
 
         from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 
-        precompute_float8_dynamic_scale_for_fsdp(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            precompute_float8_dynamic_scale_for_fsdp(m)
 
-    def sync_float8_amax_and_scale_history(self, model: nn.Module):
+    def sync_float8_amax_and_scale_history(
+        self, model: Union[nn.Module, List[nn.Module]]
+    ):
         if not self.enabled:
             return
 
@@ -136,4 +144,6 @@ def sync_float8_amax_and_scale_history(self, model: nn.Module):
                 sync_float8_amax_and_scale_history
             )
 
-        self._sync_float8_amax_and_scale_history(model)
+        models = [model] if isinstance(model, nn.Module) else model
+        for m in models:
+            self._sync_float8_amax_and_scale_history(m)
diff --git a/train.py b/train.py
index 5c62debf..58d23307 100644
--- a/train.py
+++ b/train.py
@@ -307,7 +307,7 @@ def loss_fn(pred, labels):
         )
 
         # sync float8 amaxes and scales
-        float8_handler.sync_float8_amax_and_scale_history(model)
+        float8_handler.sync_float8_amax_and_scale_history(model_parts)
 
         # optimizer step
         checkpoint.maybe_wait_for_staging()
@@ -316,7 +316,7 @@ def loss_fn(pred, labels):
 
         # calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # it issues a single all-reduce for all parameters at once for better performance
-        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
 
         losses_since_last_log.append(loss)
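
Below is a minimal, self-contained sketch (not part of the patch) of the normalization idiom both handler methods now use: accept either a single nn.Module or a list of nn.Module, as produced by pipeline-parallel model splitting, and iterate uniformly. The helper name normalize_model_parts is hypothetical, not torchtitan API.

    # Illustrative sketch of the single-module-or-list normalization idiom.
    # `normalize_model_parts` is a hypothetical name for demonstration only.
    from typing import List, Union

    import torch.nn as nn


    def normalize_model_parts(
        model: Union[nn.Module, List[nn.Module]]
    ) -> List[nn.Module]:
        # A bare nn.Module is wrapped in a singleton list; a list passes through.
        return [model] if isinstance(model, nn.Module) else model


    if __name__ == "__main__":
        single = nn.Linear(4, 4)
        parts = [nn.Linear(4, 4), nn.Linear(4, 4)]
        assert normalize_model_parts(single) == [single]
        assert normalize_model_parts(parts) is parts

This keeps the non-pipeline call sites (a single model) working unchanged while letting train.py pass the model_parts list directly.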