First draft Auto-SAC workflow #710

Status: Draft. Wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions .ci/docker/requirements.txt
@@ -6,3 +6,5 @@ sentencepiece
tiktoken
blobfile
tabulate
pwlf
pulp
155 changes: 109 additions & 46 deletions scripts/estimate/estimation.py
@@ -7,25 +7,40 @@
import contextlib
import gc
import os
from typing import Any, Set, Union

import torch
from torch import nn, optim
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.auto_sac import (
AutoSACResult,
get_auto_sac_policies,
get_module_name_dict,
SACAlgorithm,
)
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.logging import init_logger, logger, logging
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims


def estimate_memory(job_config: JobConfig):
init_logger()
def estimate(job_config: JobConfig) -> Union[AutoSACResult, None]:
assert not (
job_config.memory_estimation.enabled and job_config.sac_estimation.enabled
), "Enabling SAC estimation and FSDP memory estimation together is not permitted."
if job_config.memory_estimation.enabled:
init_logger()
else:
logging.disable()

logger.info("Estimating memory usage...")
gc.disable()
gc.collect(1)
@@ -37,10 +52,9 @@ def estimate_memory(job_config: JobConfig):
if (
job_config.training.tensor_parallel_degree > 1
or job_config.experimental.pipeline_parallel_degree > 1
or job_config.experimental.context_parallel_degree > 1
):
logger.info(
"Tensor parallelism and pipeline parallelism are not supported yet."
)
logger.info("Tensor, Context and Pipeline parallelism are not supported yet.")
return

# fake tensor doesn't work with fused rmsnorm
@@ -76,25 +90,12 @@ def estimate_memory(job_config: JobConfig):
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)

# init fake pg
store = FakeStore()
torch.distributed.init_process_group(
"fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
)

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")

if not parallel_dims.dp_enabled:
logger.info("Data parallelism is not enabled. Skipping memory estimation.")
return

model_name = job_config.model.name

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

train_context = utils.get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
@@ -111,11 +112,28 @@ def loss_fn(pred, labels):
model_config = models_config[model_name][job_config.model.flavor]
# set the model configs from training inputs:
# 1. norm type to decide which norm layer to use
# 2. vocab size from tokenizer
# 3. max_seq_len base on inputs
# 2. max_seq_len base on inputs
# 3. vocab size from tokenizer

model_config.norm_type = job_config.model.norm_type
model_config.vocab_size = tokenizer.n_words
model_config.max_seq_len = job_config.training.seq_len
model_config.vocab_size = 128256
if not job_config.sac_estimation.enabled:
# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
model_config.vocab_size = tokenizer.n_words
# init fake pg
store = FakeStore()
torch.distributed.init_process_group(
"fake",
rank=int(os.environ["LOCAL_RANK"]),
world_size=world_size,
store=store,
)

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")

with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():

@@ -130,8 +148,11 @@ def loss_fn(pred, labels):
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(model)

# apply PT-D DP/TP parallelisms and activation checkpointing
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
if not job_config.sac_estimation.enabled:
# apply PT-D DP/TP parallelisms and activation checkpointing
models_parallelize_fns[model_name](
model, world_mesh, parallel_dims, job_config
)

model.to_empty(device="cuda")
if not active_fake_mode():
@@ -158,32 +179,74 @@ def loss_fn(pred, labels):
device="cuda",
),
)

def train_step(models: nn.Module, optims: optim.Optimizer, batch: Any):
# train step
input_ids, labels = batch
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
del pred
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model)
# optimizer step
optimizers.step()
lr_schedulers.step()
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
optimizers.zero_grad()

if job_config.sac_estimation.enabled:
logging.disable(logging.NOTSET)
gib = 1024**3
budget = float(job_config.activation_checkpoint.auto_sac_budget)
recommended_budget = (
0.85 * torch.cuda.get_device_properties(device).total_memory / gib
)
if budget > recommended_budget:
logger.warning(
"It is recommended to set Auto-SAC memory budget to 85 percent of device memory.\n"
"Current budget is %.2f GiB, reducing it to %.2f GiB.",
budget,
recommended_budget,
)
budget = recommended_budget

mod_fqns = get_module_name_dict(model)
fsdp_unit_fqns: Set[str] = set()
for transformer_block in model.layers.values():
fsdp_unit_fqns.add(mod_fqns[transformer_block])
fsdp_unit_fqns.add(mod_fqns[model])

sac_algorithm = SACAlgorithm(
job_config.activation_checkpoint.auto_sac_algorithm
)
auto_sac_result = get_auto_sac_policies(
train_step=train_step,
models=[model],
optimizers=optimizers.optimizers,
inputs=batch,
dev=device,
memory_budget=budget,
sac_algo=sac_algorithm,
shard_degree=parallel_dims.dp_shard,
fsdp_units=fsdp_unit_fqns,
)
return auto_sac_result

fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
fsdp_memtracker.track_inputs(batch)

with fsdp_memtracker:
for iter_idx in range(2):
input_ids, labels = batch
# train step
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
del pred
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model)
# optimizer step
optimizers.step()
lr_schedulers.step()
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
optimizers.zero_grad()
train_step([model], optimizers.optimizers, batch)
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
if iter_idx == 0:
@@ -214,7 +277,7 @@ def loss_fn(pred, labels):
config = JobConfig()
config.parse_args()
try:
estimate_memory(config)
estimate(config)
Contributor: Is the plan to call estimate(config) from train.py?

Contributor Author: Yes. The model we have at that point will not have been initialized in FakeTensorMode, and we do not have the entire train_step or optimizer yet either. Hence we call estimate with the correct config and then obtain the SAC result from it.

Contributor: Got it. Is the plan to include train.py changes in this PR as well?

Contributor Author: There are no changes to train.py. estimate is called from parallelize_llama.py; please see the diff. This PR is self-contained.

finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
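
Not part of this diff: a minimal sketch of driving the new SAC-estimation path directly. The flag names come from the config_manager.py changes below, and the AutoSACResult fields (peak_mem, sac_policies) are the ones consumed later in parallelize_llama.py. A CUDA device and the LOCAL_RANK/WORLD_SIZE environment variables are assumed, and all other job options are left at their defaults for brevity.

import os

from torchtitan.config_manager import JobConfig
from scripts.estimate.estimation import estimate

# Single-GPU assumption; estimation.py reads LOCAL_RANK (and, presumably,
# WORLD_SIZE) from the environment.
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

config = JobConfig()
config.parse_args(
    [
        "--sac_estimation.enabled",
        "--activation_checkpoint.auto_sac_budget", "68.0",
        "--activation_checkpoint.auto_sac_algorithm", "optimal",
    ]
)

result = estimate(config)
if result is not None and result.peak_mem != -1:
    # result.sac_policies holds per-module SAC decisions; the parallelize_llama.py
    # changes below apply them with apply_auto_sac_policies.
    print("Auto-SAC estimation succeeded; estimated peak memory:", result.peak_mem)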
7 changes: 4 additions & 3 deletions scripts/generate/test_generate.py
@@ -26,14 +26,14 @@
)

from torchtitan import utils
from torchtitan.utils import device_module, device_type

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
from torchtitan.metrics import build_device_memory_monitor
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import ParallelDims
from torchtitan.utils import device_module, device_type

# support running w/o installing as package
wd = Path(__file__).parent.parent.resolve()
@@ -143,7 +143,8 @@ def test_generate(
# Build world mesh for parallelism
world_mesh = parallel_dims.build_mesh(device_type=device_type)

# apply_tp (with Sequence Parallel) on unevenly sharded sequences would require https://github.com/pytorch/torchtitan/pull/686
# apply_tp (with Sequence Parallel) on unevenly sharded sequences would require
# https://github.com/pytorch/torchtitan/pull/686
apply_tp_minus_sp(model, world_mesh["tp"])

# materialize model
28 changes: 27 additions & 1 deletion torchtitan/config_manager.py
@@ -479,7 +479,7 @@ def __init__(self):
"--activation_checkpoint.mode",
type=str,
default="selective",
help="Type of activation checkpointing to use ['none', 'full', 'selective']",
help="Type of activation checkpointing to use ['none', 'full', 'selective', 'auto']",
)
self.parser.add_argument(
"--activation_checkpoint.selective_ac_option",
@@ -490,6 +490,25 @@
'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
""",
)
self.parser.add_argument(
"--activation_checkpoint.auto_sac_budget",
type=str,
default="65.0",
help="""
Auto-SAC Memory Budget in GiB.
Recommended to set to 85 percent of total device memory.
""",
)
self.parser.add_argument(
"--activation_checkpoint.auto_sac_algorithm",
type=str,
default="optimal",
choices=["greedy", "optimal"],
help="""
Algorithm to use for determining SAC policies.
`greedy` runs in linear time, while `optimal` solves an ILP.
""",
)

# float8 configs
self.parser.add_argument(
@@ -570,6 +589,13 @@ def __init__(self):
action="store_true",
)

self.parser.add_argument(
"--sac_estimation.enabled",
help="Whether to calculate SAC (Selective Activation Checkpointing) policies",
default=False,
action="store_true",
)

def parse_args(self, args_list: list = sys.argv[1:]):
args, cmd_args = self.parse_args_from_command_line(args_list)
config_file = getattr(args, "job.config_file", None)
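
Not part of this diff: a condensed sketch of how these two knobs are consumed by the estimation path earlier in this PR. The budget string is parsed as GiB and clamped to roughly 85 percent of device memory, and the algorithm string maps onto the SACAlgorithm enum.

import torch
from torch.distributed._tools.auto_sac import SACAlgorithm


def resolve_auto_sac_settings(job_config, device: torch.device):
    # Mirrors the budget clamp and algorithm selection in scripts/estimate/estimation.py.
    gib = 1024**3
    budget = float(job_config.activation_checkpoint.auto_sac_budget)
    recommended = 0.85 * torch.cuda.get_device_properties(device).total_memory / gib
    budget = min(budget, recommended)  # estimation.py logs a warning before clamping
    algo = SACAlgorithm(job_config.activation_checkpoint.auto_sac_algorithm)  # "greedy" or "optimal"
    return budget, algo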
45 changes: 44 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -8,6 +8,7 @@
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict
from copy import deepcopy

import torch
import torch.nn as nn
@@ -20,6 +21,7 @@
)
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tools.auto_sac import apply_auto_sac_policies
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
@@ -66,7 +68,13 @@ def parallelize_llama(
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
if job_config.activation_checkpoint.mode == "auto":
if not apply_auto_sac(model, job_config):
logger.info("Auto-SAC failed, falling back to full AC mode.")
job_config.activation_checkpoint.mode = "full"
apply_ac(model, job_config.activation_checkpoint)
else:
apply_ac(model, job_config.activation_checkpoint)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if job_config.training.compile:
@@ -314,6 +322,41 @@ def apply_ac(model: nn.Module, ac_config):
logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_auto_sac(model: nn.Module, job_config: JobConfig) -> bool:
if (
job_config.training.tensor_parallel_degree > 1
or job_config.experimental.pipeline_parallel_degree > 1
or job_config.experimental.context_parallel_degree > 1
or job_config.training.enable_cpu_offload
):
logger.info(
"Tensor, Context and Pipeline parallelism or FSDP with CPU Offload option"
" are not supported yet with Auto-SAC."
)
return False
est_job_config = deepcopy(job_config)
est_job_config.memory_estimation.disable_fake_mode = False
est_job_config.memory_estimation.enabled = False
est_job_config.sac_estimation.enabled = True
est_job_config.training.compile = False
est_job_config.experimental.enable_compiled_autograd = False
if (
est_job_config.model.norm_type == "compiled_rmsnorm"
or est_job_config.model.norm_type == "fused_rmsnorm"
):
est_job_config.model.norm_type = "rmsnorm"
from scripts.estimate.estimation import estimate

auto_sac_result = estimate(est_job_config)
assert auto_sac_result is not None
if auto_sac_result.peak_mem == -1:
return False
apply_auto_sac_policies(
model, auto_sac_result.sac_policies, preserve_rng_state=False
)
return True


def apply_compile(model: nn.Module):
"""
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to