enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (pytorch#413)
We have landed fp8 all-gather optimizations in float8_experimental
(pytorch-labs/float8_experimental#266).

This PR applies the corresponding torchtitan changes and also adds fp8 to CI.
With FSDP2 fp8 all-gather enabled, dynamic scales can be precomputed once per
step right after the optimizer step:
```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
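# a single all-reduce precomputes dynamic fp8 scales for all parameters before the next step's all-gathers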
precompute_float8_dynamic_scale_for_fsdp(model)
```

FSDP2 fp8 all-gather tests are added to CI:
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_fp8_linear --training.enable_fsdp_fp8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp
```

TP fp8 all-gather is tested locally. It will be added to CI after
uploading a new tokenizer with vocab size 2560 (divisible by 16):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 1 --training.tensor_parallel_degree 4
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=4 ./run_llama_train.sh --training.enable_fp8_linear --training.data_parallel_degree 2 --training.tensor_parallel_degree 2
```

Scales are precomputed after optimizer.step:
<img width="319" alt="Screenshot 2024-07-12 at 5 11 14 PM"
src="https://github.com/user-attachments/assets/1c55bd89-9183-42ca-9445-23f3b95e0817">

FSDP2 pre-all-gather hooks do not issue any small all-reduces:
<img width="794" alt="Screenshot 2024-07-12 at 5 13 04 PM"
src="https://github.com/user-attachments/assets/1a00dc70-a8ca-4ce1-a93c-316f22efdb08">

TODO
* upload a tokenizer with vocab size 2560 to enable CI on TP fp8 all-gather
* torch.compile complains about fp8
* add delayed scaling and brainstorm the best config option to express fp8
* compare perf between delayed scaling and dynamic scaling
  https://github.com/pytorch-labs/float8_experimental/pull/312/files
weifengpy authored Jul 16, 2024
1 parent 174c44a commit a4b2ee3
Showing 12 changed files with 114 additions and 20 deletions.
4 changes: 2 additions & 2 deletions estimation.py
@@ -124,8 +124,8 @@ def loss_fn(pred, labels):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)
if job_config.training.enable_fp8_linear:
build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [whole_model]
33 changes: 33 additions & 0 deletions test_runner.py
@@ -273,6 +273,39 @@ def build_test_list():
"fsdp2_mem_tracker",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.enable_fp8_linear",
]
],
"FSDP2 with original dtype",
"fp8_fsdp2_orig_all_gather",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.enable_fp8_linear",
"--training.enable_fsdp_fp8_all_gather",
]
],
"FSDP2 with fp8 all-gather",
"fsdp2_fp8_all_gather",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.enable_fp8_linear",
"--training.enable_fsdp_fp8_all_gather",
"--training.precompute_float8_dynamic_scale_for_fsdp",
]
],
"FSDP2 with fp8 all-gather and precomputed dynamic scales",
"fsdp2_fp8_all_gather_precompute_dynamic_scales",
ngpu=4,
),
]
return integration_tests_flavors

14 changes: 13 additions & 1 deletion torchtitan/config_manager.py
@@ -338,7 +338,7 @@ def __init__(self):
help="Whether to compile the model",
)
self.parser.add_argument(
"--training.fp8_linear",
"--training.enable_fp8_linear",
action="store_true",
help="""
If true, swaps `torch.nn.Linear` with `Float8Linear` with
@@ -347,6 +347,18 @@ def __init__(self):
here: https://github.com/pytorch-labs/float8_experimental
""",
)
self.parser.add_argument(
"--training.enable_fsdp_fp8_all_gather",
action="store_true",
default=False,
help="Whether enable fp8 all-gather in FSDP",
)
self.parser.add_argument(
"--training.precompute_float8_dynamic_scale_for_fsdp",
action="store_true",
default=False,
help="Whether precompute fp8 scales dynamically for FSDP",
)
self.parser.add_argument(
"--training.gc_freq",
type=int,
41 changes: 34 additions & 7 deletions torchtitan/float8_linear.py
@@ -12,31 +12,58 @@

# Note: Performance
# Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
import contextlib
from typing import Optional

import float8_experimental.config as config

import torch
import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType

from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger


def build_fp8_linear(model: nn.Module, job_config: JobConfig):
@contextlib.contextmanager
def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
prev = config.enable_fsdp_fp8_all_gather
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
try:
yield
finally:
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = prev


def build_fp8_linear(
model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
):
"""
This function converts the linear layers to `Float8Linear`. Note that today,
only dynamic tensor scaling (the default) is supported.
This will mutate the model inplace.
"""
use_fp8_linear = job_config.training.fp8_linear
enable_fp8_linear = job_config.training.enable_fp8_linear
enable_fsdp_fp8_all_gather = (
job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
)
try:
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
swap_linear_with_float8_linear(
model, scaling_type_w=TensorScalingType.DYNAMIC
)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_fp8_all_gather=}"
)
except ImportError as exc:
raise ImportError(
"float8_experimental is not installed. Please install it to use fp8 linear layers."
) from exc
if use_fp8_linear:
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
swap_linear_with_float8_linear(model, Float8Linear)
logger.info("Swapped to Float8Linear layers")
16 changes: 14 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -117,12 +117,24 @@ def selective_checkpointing_context_fn():

def get_tp_parallel_strategy(
job_config: JobConfig,
model: nn.Module,
) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
"""Get the parallel strategy for the transformer model.
This function handles the special case of using float8 with tensor parallelism.
"""
if job_config.training.fp8_linear == "dynamic":
if job_config.training.enable_fp8_linear:
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

if any(
isinstance(m, Float8Linear)
and m.scaling_type_w is TensorScalingType.DELAYED
for m in model.modules()
):
raise NotImplementedError(
"1D TP fp8 all-gather only supports dynamic scaling"
)

from float8_experimental.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
@@ -346,7 +358,7 @@ def apply_tp(
rowwise_parallel_weight,
colwise_parallel_weight,
prepare_module_input,
) = get_tp_parallel_strategy(job_config)
) = get_tp_parallel_strategy(job_config, model)
loss_parallel = parallel_dims.loss_parallel_enabled

# 1. Parallelize the embedding and shard its outputs (which are the first
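For context, here is a rough sketch of how the styles returned by `get_tp_parallel_strategy` feed into the TP plan for a transformer block. This is illustrative only, not code from this PR; the module names (`attention.wq`, `feed_forward.w1`, ...) and the `transformer_block`/`tp_mesh` variables are assumptions based on torchtitan's llama layout:
```
from torch.distributed.tensor.parallel import parallelize_module

# fp8-aware styles are substituted when --training.enable_fp8_linear is set;
# prepare_input (not shown) handles redistributing the module inputs
rowwise, colwise, prepare_input = get_tp_parallel_strategy(job_config, model)

layer_plan = {
    "attention.wq": colwise(),
    "attention.wk": colwise(),
    "attention.wv": colwise(),
    "attention.wo": rowwise(),
    "feed_forward.w1": colwise(),
    "feed_forward.w2": rowwise(),
    "feed_forward.w3": colwise(),
}
parallelize_module(transformer_block, tp_mesh, layer_plan)
```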
14 changes: 12 additions & 2 deletions train.py
@@ -19,6 +19,7 @@

import torch
import torch.nn.functional as F
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
@@ -216,8 +217,8 @@ def loss_fn(pred, labels):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.fp8_linear:
build_fp8_linear(whole_model, job_config)
if job_config.training.enable_fp8_linear:
build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

# log model size
model_param_count = get_num_params(whole_model)
@@ -398,6 +399,15 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

if (
job_config.training.enable_fp8_linear
and job_config.training.enable_fsdp_fp8_all_gather
and job_config.training.precompute_float8_dynamic_scale_for_fsdp
):
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
precompute_float8_dynamic_scale_for_fsdp(model)

losses_since_last_log.append(loss)

# log metrics
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)

2 changes: 1 addition & 1 deletion train_configs/llama2_13b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama2_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama2_7b.toml
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama3_70b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 8 # 8-way TP
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

2 changes: 1 addition & 1 deletion train_configs/llama3_8b.toml
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
fp8_linear = false
enable_fp8_linear = false
compile = false
dataset = "c4"

