diff --git a/estimation.py b/estimation.py
index 70fb66cb..13ccd4c1 100644
--- a/estimation.py
+++ b/estimation.py
@@ -122,33 +122,25 @@ def loss_fn(pred, labels):
         f"Building {model_name} {job_config.model.flavor} with {model_config}"
     )
     with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)

     # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)

     # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [whole_model]
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]
-
-    init_device = "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)
+    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+    model.to_empty(device="cuda")

     if not active_fake_mode():
-        whole_model.init_weights()
+        model.init_weights()
+    model.train()

     # build optimizer after applying parallelisms to the model
-    optimizers = build_optimizers(model_parts, job_config)
+    optimizers = build_optimizers([model], job_config)
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

-    for model in model_parts:
-        model.train()
     logger.info(f"Vocab size: {model_config.vocab_size}")
     # Create a dummy batch instead of loading from a dataset
     batch = (
@@ -165,7 +157,7 @@ def loss_fn(pred, labels):
             device="cuda",
         ),
     )
-    fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
+    fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
     fsdp_memtracker.track_inputs(batch)

     with fsdp_memtracker:
@@ -173,16 +165,15 @@ def loss_fn(pred, labels):
             input_ids, labels = batch
             # train step
             with train_context():
-                pred = whole_model(input_ids)
+                pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 del pred
                 loss.backward()

             # clip gradients
-            for model in model_parts:
-                torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), job_config.training.max_norm, foreach=True
-                )
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )
             # sync float8 amaxes and scales
             float8_handler.sync_float8_amax_and_scale_history(model)
             # optimizer step
diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index dc06d572..b75cb336 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -8,11 +8,9 @@
 from torchtitan.parallelisms.parallel_dims import ParallelDims
 from torchtitan.parallelisms.parallelize_llama import parallelize_llama
 from torchtitan.parallelisms.pipeline_llama import pipeline_llama
-from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule


 __all__ = [
-    "build_pipeline_schedule",
     "models_parallelize_fns",
     "models_pipelining_fns",
     "ParallelDims",
diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py
index fa093b6e..67983270 100644
--- a/torchtitan/parallelisms/pipeline_llama.py
+++ b/torchtitan/parallelisms/pipeline_llama.py
@@ -7,7 +7,7 @@
 # This file applies the PT-D pipeline parallelism to the Llama model.

 import copy
-from typing import Union
+from typing import Callable, Union

 import torch
 import torch.nn as nn
@@ -18,7 +18,10 @@
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
+from torchtitan.parallelisms.pipelining_utils import (
+    build_pipeline_schedule,
+    stage_ids_this_rank,
+)


 DeviceType = Union[int, str, torch.device]
@@ -31,6 +34,7 @@ def pipeline_llama(
     job_config: JobConfig,
     device: DeviceType,
     model_config: ModelArgs,
+    loss_fn: Callable[..., torch.Tensor],
 ):
     split_mode = job_config.experimental.pipeline_parallel_split_mode
     valid_split_modes = ("manual", "tracer")
@@ -39,14 +43,18 @@
             f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
         )
     if split_mode == "manual":
-        return pipeline_llama_manual(
+        stages, models = pipeline_llama_manual(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )
     elif split_mode == "tracer":
-        return pipeline_llama_tracer(
+        stages, models = pipeline_llama_tracer(
             model, pp_mesh, parallel_dims, job_config, device, model_config
         )
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    return pp_schedule, models
+


 def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
@@ -218,4 +226,4 @@ def pipeline_llama_tracer(
                 group=pp_mesh.get_group(),
             )
         )
-    return (stages, models)
+    return stages, models
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index aafe70fa..a5c61e62 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -14,7 +14,7 @@
 from torchtitan.logging import logger


-def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
+def build_pipeline_schedule(job_config, stages, loss_fn):
     looped_schedule = False

     if job_config.experimental.pipeline_parallel_schedule == "1f1b":
diff --git a/train.py b/train.py
index afb74408..390263bb 100644
--- a/train.py
+++ b/train.py
@@ -22,7 +22,6 @@
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import (
-    build_pipeline_schedule,
     models_parallelize_fns,
     models_pipelining_fns,
     ParallelDims,
@@ -143,11 +142,8 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        stages, model_parts = models_pipelining_fns[model_name](
-            model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-        pp_schedule = build_pipeline_schedule(
-            job_config, parallel_dims, stages, loss_fn
+        pp_schedule, model_parts = models_pipelining_fns[model_name](
+            model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )

         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
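
For reference, a minimal sketch (not part of the patch) of how a caller consumes the refactored entry point after this change, mirroring the call pattern visible in the train.py hunk above. The setup objects (model_name, model, pp_mesh, parallel_dims, job_config, device, model_config) are assumed to be built exactly as train.py already builds them, and the loss function shown is only a stand-in token-level cross-entropy:

    import torch
    import torch.nn.functional as F

    from torchtitan.parallelisms import models_pipelining_fns


    def loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Stand-in loss: token-level cross-entropy over flattened batch/sequence dims.
        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))


    # Assumes model_name, model, pp_mesh, parallel_dims, job_config, device, and
    # model_config were created by the existing train.py setup code.
    # The pipelining fn now receives loss_fn, builds the schedule internally via
    # build_pipeline_schedule(job_config, stages, loss_fn), and returns the schedule
    # directly, so callers no longer import build_pipeline_schedule themselves.
    pp_schedule, model_parts = models_pipelining_fns[model_name](
        model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
    )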