diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 30317e3c3..b71419c61 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -10,6 +10,8 @@ import re import shutil import time +from dataclasses import dataclass, field +from io import BytesIO from multiprocessing import get_context from typing import Any, Dict, List, Union @@ -27,7 +29,7 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import DataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP -from torchtitan.logging_utils import init_logger, logger +from torchtitan.logging import init_logger, logger class IntervalType(enum.Enum): @@ -41,6 +43,43 @@ class AsyncMode(str, enum.Enum): ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" +@dataclass +class TrainState(Stateful): + step: int = 0 + global_avg_losses: List[float] = field(default_factory=list) + global_max_losses: List[float] = field(default_factory=list) + log_steps: List[int] = field(default_factory=list) + + def state_dict(self) -> Dict[str, Any]: + # Only checkpoint global_avg_losses and global_max_losses per log frequency + # to avoid sync overhead in every iteration. + global_avg_losses_bytes = BytesIO() + torch.save(self.global_avg_losses, global_avg_losses_bytes) + global_max_losses_bytes = BytesIO() + torch.save(self.global_max_losses, global_max_losses_bytes) + log_steps_bytes = BytesIO() + torch.save(self.log_steps, log_steps_bytes) + return { + "step": torch.tensor(self.step, dtype=torch.int32), + "global_avg_losses": global_avg_losses_bytes, + "global_max_losses": global_max_losses_bytes, + "log_steps": log_steps_bytes, + } + + def load_state_dict(self, state_dict) -> None: + self.step = state_dict["step"].item() + state_dict["global_avg_losses"].seek(0) + self.global_avg_losses = torch.load( + state_dict["global_avg_losses"], weights_only=False + ) + state_dict["global_max_losses"].seek(0) + self.global_max_losses = torch.load( + state_dict["global_max_losses"], weights_only=False + ) + state_dict["log_steps"].seek(0) + self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) + + class ModelWrapper(Stateful): def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None: self.model = [model] if isinstance(model, nn.Module) else model @@ -124,10 +163,10 @@ def checkpoint_mp(recv, send): class CheckpointManager: def __init__( self, + dataloader: DataLoader, model_parts: List[nn.Module], optimizers: List[torch.optim.Optimizer], lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler], - dataloader: DataLoader, states: Dict[str, Any], job_config: JobConfig, ) -> None: @@ -390,7 +429,7 @@ def save(self, curr_step: int, force: bool = False) -> None: f"in {time.monotonic() - begin:.2f} seconds." 
) - def wait_for_staging(self) -> None: + def maybe_wait_for_staging(self) -> None: if ( self.enable_checkpoint and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 330701203..dd5ba7cde 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -16,7 +16,7 @@ except ModuleNotFoundError: import tomli as tomllib -from torchtitan.logging_utils import logger +from torchtitan.logging import logger TORCH_DTYPE_MAP = { "float16": torch.float16, diff --git a/torchtitan/datasets/__init__.py b/torchtitan/datasets/__init__.py index e9a149c64..75ea6b663 100644 --- a/torchtitan/datasets/__init__.py +++ b/torchtitan/datasets/__init__.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. from torchtitan.datasets.hf_datasets import build_hf_data_loader -from torchtitan.datasets.tokenizer import create_tokenizer +from torchtitan.datasets.tokenizer import build_tokenizer __all__ = [ "build_hf_data_loader", - "create_tokenizer", + "build_tokenizer", ] diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index d8cd5d83e..0b894e24c 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -20,7 +20,7 @@ ) from e from torchtitan.datasets.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from datasets import load_dataset from datasets.distributed import split_dataset_by_node diff --git a/torchtitan/datasets/tokenizer/__init__.py b/torchtitan/datasets/tokenizer/__init__.py index 346caf83b..7ff747228 100644 --- a/torchtitan/datasets/tokenizer/__init__.py +++ b/torchtitan/datasets/tokenizer/__init__.py @@ -8,10 +8,10 @@ from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger -def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer: +def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer: logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}") if tokenizer_type == "sentencepiece": return SentencePieceTokenizer(tokenizer_path) diff --git a/torchtitan/datasets/tokenizer/sentencepiece.py b/torchtitan/datasets/tokenizer/sentencepiece.py index 7229daa3d..c71afddd9 100644 --- a/torchtitan/datasets/tokenizer/sentencepiece.py +++ b/torchtitan/datasets/tokenizer/sentencepiece.py @@ -11,7 +11,7 @@ from sentencepiece import SentencePieceProcessor from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger class SentencePieceTokenizer(Tokenizer): diff --git a/torchtitan/datasets/tokenizer/tiktoken.py b/torchtitan/datasets/tokenizer/tiktoken.py index 1ec5de203..c879e7f3f 100644 --- a/torchtitan/datasets/tokenizer/tiktoken.py +++ b/torchtitan/datasets/tokenizer/tiktoken.py @@ -26,7 +26,7 @@ from tiktoken.load import load_tiktoken_bpe from torchtitan.datasets.tokenizer.tokenizer import Tokenizer -from torchtitan.logging_utils import logger +from torchtitan.logging import logger class TikTokenizer(Tokenizer): diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index 658a41cc3..fa311061d 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -20,7 +20,7 @@ from torch._logging import warning_once from torchtitan.config_manager import JobConfig 
-from torchtitan.logging_utils import logger
+from torchtitan.logging import logger
 
 
 @functools.lru_cache(None)
diff --git a/torchtitan/logging_utils.py b/torchtitan/logging.py
similarity index 100%
rename from torchtitan/logging_utils.py
rename to torchtitan/logging.py
diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py
index 1717439b5..87c082a3e 100644
--- a/torchtitan/metrics.py
+++ b/torchtitan/metrics.py
@@ -12,7 +12,8 @@ import torch
 from torch.utils.tensorboard import SummaryWriter
 
 from torchtitan.config_manager import JobConfig
-from torchtitan.logging_utils import logger
+from torchtitan.logging import logger
+from torchtitan.parallelisms import ParallelDims
 
 # named tuple for passing GPU memory stats for logging
 GPUMemStats = namedtuple(
@@ -110,16 +111,29 @@ def close(self):
         self.writer.close()
 
 
+def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
+    """
+    Returns global rank 0 in non-pipeline-parallel configs, and returns the global
+    rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
+    """
+    if parallel_dims.pp_enabled:
+        world_size = parallel_dims.world_size
+        pp_size = parallel_dims.pp
+        metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
+    else:
+        metrics_log_rank = 0
+
+    return metrics_log_rank
+
+
 def build_metric_logger(
-    config: JobConfig, metrics_log_rank: int = 0, tag: Optional[str] = None
+    config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
 ):
     """
-    metrics_log_rank controls which rank acts as 'rank 0' for logging metrics.
-
-    If 'tb_config.rank_0_only' is set, then `metrics_log_rank` will be used as the rank to log metrics.
-    This is intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
-    parallelism is enabled, without forcing logging from all ranks to capture loss information when using pipeline
-    parallelism.
-    """
+    parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
+    In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
+    intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
+    parallelism is enabled, without forcing logging from all ranks to capture loss information.
+    """
     dump_dir = config.job.dump_folder
     tb_config = config.metrics
@@ -134,7 +148,7 @@ def build_metric_logger(
         f"Metrics logging active. 
Tensorboard logs will be saved at {log_dir}" ) if tb_config.rank_0_only: - enable_tb = torch.distributed.get_rank() == metrics_log_rank + enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims) else: rank_str = f"rank_{torch.distributed.get_rank()}" log_dir = os.path.join(log_dir, rank_str) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 49cda6241..e47d0fb9f 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -14,7 +14,7 @@ import torch import torch.nn.functional as F from torch import nn -from torchtitan.models.norms import create_norm +from torchtitan.models.norms import build_norm @dataclass @@ -291,10 +291,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs): self.layer_id = layer_id self.num_layers = model_args.n_layers - self.attention_norm = create_norm( + self.attention_norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) - self.ffn_norm = create_norm( + self.ffn_norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) @@ -370,7 +370,7 @@ def __init__(self, model_args: ModelArgs): for layer_id in range(model_args.n_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) - self.norm = create_norm( + self.norm = build_norm( model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps ) diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 10a6b8531..c0ef6a803 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -18,18 +18,18 @@ from torch.distributed._tensor.experimental import local_map -def create_norm(norm_type: str, dim: int, eps: float = 1e-6): +def build_norm(norm_type: str, dim: int, eps: float = 1e-6): """ - Creates the specified normalization layer based on the norm_type. + Builds the specified normalization layer based on the norm_type. Args: - norm_type (str): The type of normalization layer to create. + norm_type (str): The type of normalization layer to build. Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm dim (int): The dimension of the normalization layer. eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. Returns: - The created normalization layer. + The built normalization layer. Raises: NotImplementedError: If an unknown norm_type is provided. diff --git a/torchtitan/lr_scheduling.py b/torchtitan/optimizer.py similarity index 51% rename from torchtitan/lr_scheduling.py rename to torchtitan/optimizer.py index 9f7662680..3f9eb3a8f 100644 --- a/torchtitan/lr_scheduling.py +++ b/torchtitan/optimizer.py @@ -6,10 +6,57 @@ import functools +import torch from torch.optim.lr_scheduler import LambdaLR from torchtitan.config_manager import JobConfig +# consider split between PP and non-PP +def build_optimizers(model_parts, job_config: JobConfig): + """Wrap one optimizer per model part in an OptimizersContainer which provides a single + step() and zero_grad() method for all the child optimizers. 
+ """ + + def _build_optimizer(model): + name = job_config.optimizer.name + lr = job_config.optimizer.lr + fused = job_config.optimizer.fused + + # Common parameters for both optimizers + optimizer_kwargs = { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1, + "fused": fused, + "foreach": not fused, + } + if name == "Adam": + # TODO: make the optimizer options configurable by toml/cmd args + optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) + elif name == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs) + else: + raise NotImplementedError(f"Optimizer {name} not added.") + + return optimizer + + class OptimizersContainer: + """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" + + def __init__(self, optimizers): + self.optimizers = optimizers + + def step(self): + for optimizer in self.optimizers: + optimizer.step() + + def zero_grad(self): + for optimizer in self.optimizers: + optimizer.zero_grad() + + return OptimizersContainer([_build_optimizer(model) for model in model_parts]) + + def linear_warmup_linear_decay( warmup_steps: int, decay_steps: int, current_step: int ) -> float: @@ -32,8 +79,8 @@ def linear_warmup_linear_decay( return curr_adjustment -def get_lr_schedulers(optimizers, job_config: JobConfig): - def _get_lr_scheduler(optimizer): +def build_lr_schedulers(optimizers, job_config: JobConfig): + def _build_lr_scheduler(optimizer): """Build a linear warmup and linear decay scheduler""" warmup_steps = int(job_config.training.warmup_steps) decay_steps = float(max(1, job_config.training.steps - warmup_steps)) @@ -54,5 +101,5 @@ def step(self): schedulers.step() return SchedulersContainer( - [_get_lr_scheduler(optimizer) for optimizer in optimizers] + [_build_lr_scheduler(optimizer) for optimizer in optimizers] ) diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index 2fdba316f..7188474dd 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -8,8 +8,17 @@ from functools import cached_property from torch.distributed.device_mesh import init_device_mesh -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama +from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule + + +__all__ = [ + "build_pipeline_schedule", + "models_parallelize_fns", + "models_pipelining_fns", + "ParallelDims", +] models_parallelize_fns = { "llama2": parallelize_llama, diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index e3c6fc80d..11a8188fd 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -32,7 +32,7 @@ ) from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP -from torchtitan.logging_utils import logger +from torchtitan.logging import logger from torchtitan.models.llama.model import ModelArgs from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py index adf9eb090..aafe70faf 100644 --- a/torchtitan/parallelisms/pipelining_utils.py +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -11,12 +11,12 @@ ScheduleGPipe, ScheduleInterleaved1F1B, ) -from torchtitan.logging_utils import logger +from torchtitan.logging import logger def 
build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn): - looped_schedule = False + if job_config.experimental.pipeline_parallel_schedule == "1f1b": schedule_class = Schedule1F1B elif job_config.experimental.pipeline_parallel_schedule == "gpipe": diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index 662b64f8c..9da5c8fb9 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -11,7 +11,7 @@ import torch from torchtitan.config_manager import JobConfig -from torchtitan.logging_utils import logger +from torchtitan.logging import logger # the number of warmup steps before the active step in each profiling cycle WARMUP = 3 diff --git a/torchtitan/utils.py b/torchtitan/utils.py index c29836601..3ed74d133 100644 --- a/torchtitan/utils.py +++ b/torchtitan/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import gc import os from dataclasses import dataclass from datetime import timedelta @@ -13,18 +14,17 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch.distributed.device_mesh import DeviceMesh -from torchtitan.logging_utils import logger -from torchtitan.parallelisms import ParallelDims +from torchtitan.logging import logger def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item() def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item() def _warn_overwrite_env(env, val): @@ -35,24 +35,6 @@ def _warn_overwrite_env(env, val): os.environ[env] = val -def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int: - """ - Returns global rank 0 in non-pipeline-parallel configs, and returns the global - rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. - """ - if parallel_dims.pp_enabled: - assert ( - world_mesh.mesh_dim_names[0] == "pp" - ), "get_metrics_rank assumes pp is the outer mesh dim" - pp_mesh = world_mesh["pp"] - pp_size = pp_mesh.size() - metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1)) - else: - metrics_log_rank = 0 - - return metrics_log_rank - - def set_pg_timeouts(timeout, world_mesh): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. @@ -80,6 +62,19 @@ def set_pg_timeouts(timeout, world_mesh): torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) +# used to avoid stragglers in garbage collection +class GarbageCollection: + def __init__(self, gc_freq=1000): + assert gc_freq > 0, "gc_freq must be a positive integer" + self.gc_freq = gc_freq + gc.disable() + gc.collect(1) + + def run(self, step_count): + if step_count > 1 and step_count % self.gc_freq == 0: + gc.collect(1) + + TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" diff --git a/train.py b/train.py index 5a637f46f..56b78de7d 100644 --- a/train.py +++ b/train.py @@ -5,138 +5,33 @@ # LICENSE file in the root directory of this source tree. 
import contextlib -import gc import os import time - -from dataclasses import dataclass, field from datetime import timedelta -from io import BytesIO -from timeit import default_timer as timer -from typing import Any, Dict, List - -import numpy as np import torch -import torch.nn.functional as F -from torch.distributed import destroy_process_group -from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.tensor.parallel import loss_parallel -from torchtitan.checkpoint import CheckpointManager +import torchtitan.utils as utils +from torch.distributed.elastic.multiprocessing.errors import record +from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig -from torchtitan.datasets import build_hf_data_loader, create_tokenizer +from torchtitan.datasets import build_hf_data_loader, build_tokenizer from torchtitan.float8_linear import ( maybe_build_fp8_linear, maybe_precompute_fp8_dynamic_scale_for_fsdp, maybe_sync_float8_amax_and_scale_history, ) -from torchtitan.logging_utils import init_logger, logger -from torchtitan.lr_scheduling import get_lr_schedulers +from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.optimizer import build_lr_schedulers, build_optimizers from torchtitan.parallelisms import ( + build_pipeline_schedule, models_parallelize_fns, models_pipelining_fns, ParallelDims, ) -from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling -from torchtitan.utils import ( - Color, - dist_max, - dist_mean, - get_metrics_rank, - get_num_flop_per_token, - get_num_params, - get_peak_flops, - init_distributed, - NoColor, - set_pg_timeouts, -) - - -@dataclass -class TrainState(Stateful): - step: int = 0 - global_avg_losses: List[float] = field(default_factory=list) - global_max_losses: List[float] = field(default_factory=list) - log_steps: List[int] = field(default_factory=list) - - def state_dict(self) -> Dict[str, Any]: - # Only checkpoint global_avg_losses and global_max_losses per log frequency - # to avoid sync overhead in every iteration. 
- global_avg_losses_bytes = BytesIO() - torch.save(self.global_avg_losses, global_avg_losses_bytes) - global_max_losses_bytes = BytesIO() - torch.save(self.global_max_losses, global_max_losses_bytes) - log_steps_bytes = BytesIO() - torch.save(self.log_steps, log_steps_bytes) - return { - "step": torch.tensor(self.step, dtype=torch.int32), - "global_avg_losses": global_avg_losses_bytes, - "global_max_losses": global_max_losses_bytes, - "log_steps": log_steps_bytes, - } - - def load_state_dict(self, state_dict) -> None: - self.step = state_dict["step"].item() - state_dict["global_avg_losses"].seek(0) - self.global_avg_losses = torch.load( - state_dict["global_avg_losses"], weights_only=False - ) - state_dict["global_max_losses"].seek(0) - self.global_max_losses = torch.load( - state_dict["global_max_losses"], weights_only=False - ) - state_dict["log_steps"].seek(0) - self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) - - -def build_optimizers(model_parts, job_config: JobConfig): - """Wrap one optimizer per model part in an OptimizersContainer which provides a single - step() and zero_grad() method for all the child optimizers. - """ - - def _build_optimizer(model): - name = job_config.optimizer.name - lr = job_config.optimizer.lr - fused = job_config.optimizer.fused - - # Common parameters for both optimizers - optimizer_kwargs = { - "lr": lr, - "betas": (0.9, 0.95), - "weight_decay": 0.1, - "fused": fused, - "foreach": not fused, - } - if name == "Adam": - # TODO: make the optimizer options configurable by toml/cmd args - optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs) - elif name == "AdamW": - optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs) - else: - raise NotImplementedError(f"Optimizer {name} not added.") - - return optimizer - - class OptimizersContainer: - """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages""" - - def __init__(self, optimizers): - self.optimizers = optimizers - - def step(self): - for optimizer in self.optimizers: - optimizer.step() - - def zero_grad(self): - for optimizer in self.optimizers: - optimizer.zero_grad() - - return OptimizersContainer([_build_optimizer(model) for model in model_parts]) def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool): @@ -144,12 +39,11 @@ def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool def context(): with contextlib.ExitStack() as stack: if enable_loss_parallel: - stack.enter_context(loss_parallel()) + stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) if enable_compiled_autograd: stack.enter_context( torch._dynamo.utils.maybe_enable_compiled_autograd(True) ) - yield return context @@ -162,12 +56,10 @@ def main(job_config: JobConfig): logger.info(f"Starting job: {job_config.job.description}") # used for colorful printing - color = Color if job_config.metrics.enable_color_printing else NoColor + color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor # take control of garbage collection to avoid stragglers - _gc_freq = job_config.training.gc_freq - gc.disable() - gc.collect(1) + gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) # init distributed world_size = int(os.environ["WORLD_SIZE"]) @@ -181,14 +73,16 @@ def main(job_config: JobConfig): ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") torch.cuda.set_device(device) - init_distributed(job_config) + utils.init_distributed(job_config) + 
# initialize GPU memory monitor and get peak flops for MFU calculation + gpu_memory_monitor = build_gpu_memory_monitor() + gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name) # build meshes world_mesh = parallel_dims.build_mesh(device_type="cuda") if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] - dp_degree = dp_mesh.size() - dp_rank = dp_mesh.get_local_rank() + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 @@ -199,7 +93,7 @@ def main(job_config: JobConfig): # build tokenizer tokenizer_type = model_name_to_tokenizer[model_name] - tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path) + tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) # build dataloader data_loader = build_hf_data_loader( @@ -212,15 +106,6 @@ def main(job_config: JobConfig): dp_rank, ) - train_context = get_train_context( - parallel_dims.loss_parallel_enabled, - job_config.experimental.enable_compiled_autograd, - ) - - # loss fn can be shared by pipeline-parallel or non-pp execution - def loss_fn(pred, labels): - return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) - # build model (using meta init) model_cls = model_name_to_cls[model_name] model_config = models_config[model_name][job_config.model.flavor] @@ -240,9 +125,9 @@ def loss_fn(pred, labels): maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) # log model size - model_param_count = get_num_params(whole_model) - num_flop_per_token = get_num_flop_per_token( - get_num_params(whole_model, exclude_embedding=True), + model_param_count = utils.get_num_params(whole_model) + num_flop_per_token = utils.get_num_flop_per_token( + utils.get_num_params(whole_model, exclude_embedding=True), model_config, job_config.training.seq_len, ) @@ -251,11 +136,6 @@ def loss_fn(pred, labels): f"{color.red}size: {model_param_count:,} total parameters{color.reset}" ) - # initialize GPU memory monitor before applying parallelisms to the model - gpu_memory_monitor = build_gpu_memory_monitor() - # obtain the peak flops of bf16 type for MFU calculation - gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - if parallel_dims.pp_enabled: stages, model_parts = models_pipelining_fns[model_name]( whole_model, world_mesh, parallel_dims, job_config, device, model_config @@ -276,6 +156,12 @@ def loss_fn(pred, labels): for model in model_parts: model.to_empty(device=init_device) + # loss fn can be shared by pipeline-parallel or non-pp execution + def loss_fn(pred, labels): + return torch.nn.functional.cross_entropy( + pred.flatten(0, 1), labels.flatten(0, 1) + ) + if parallel_dims.pp_enabled: pp_schedule = build_pipeline_schedule( job_config, parallel_dims, stages, loss_fn @@ -295,11 +181,7 @@ def loss_fn(pred, labels): # build optimizer after applying parallelisms to the model optimizers = build_optimizers(model_parts, job_config) - lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config) - - metric_logger = build_metric_logger( - job_config, metrics_log_rank=get_metrics_rank(world_mesh, parallel_dims) - ) + lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config) train_state = TrainState() @@ -309,10 +191,10 @@ def loss_fn(pred, labels): # load initial checkpoint checkpoint = CheckpointManager( + dataloader=data_loader, model_parts=model_parts, optimizers=optimizers.optimizers, lr_schedulers=lr_schedulers.schedulers, - dataloader=data_loader, states={"train_state": train_state}, job_config=job_config, ) @@ 
-333,6 +215,8 @@ def loss_fn(pred, labels): "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" ) + metric_logger = build_metric_logger(job_config, parallel_dims) + # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq @@ -346,21 +230,28 @@ def loss_fn(pred, labels): data_iterator = iter(data_loader) - checkpoint.reset() + train_context = get_train_context( + parallel_dims.loss_parallel_enabled, + job_config.experimental.enable_compiled_autograd, + ) # variables used to keep info for metrics logging - losses_since_last_log: List[float] = [] + losses_since_last_log = [] ntokens_since_last_log = 0 - data_loading_times: List[float] = [] - time_last_log = timer() + data_loading_times = [] + time_last_log = time.perf_counter() gpu_memory_monitor.reset_peak_stats() + checkpoint.reset() + # train loop logger.info( f"Training starts at step {train_state.step + 1}, " - f"with local batch size: {job_config.training.batch_size}, " - f"sequence length: {job_config.training.seq_len}, " - f"total steps: {job_config.training.steps}({job_config.training.warmup_steps}), " + f"with local batch size {job_config.training.batch_size}, " + f"global batch size {job_config.training.batch_size * dp_degree}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.training.warmup_steps})" ) with maybe_enable_profiling( job_config, global_step=train_state.step @@ -369,15 +260,14 @@ def loss_fn(pred, labels): ) as memory_profiler: while train_state.step < job_config.training.steps: train_state.step += 1 - if train_state.step > 1 and train_state.step % _gc_freq == 0: - gc.collect(1) + gc_handler.run(train_state.step) # get batch - data_load_start = timer() + data_load_start = time.perf_counter() batch = next(data_iterator) input_ids, labels = batch ntokens_since_last_log += labels.numel() - data_loading_times.append(timer() - data_load_start) + data_loading_times.append(time.perf_counter() - data_load_start) input_ids = input_ids.cuda() labels = labels.cuda() @@ -422,7 +312,7 @@ def loss_fn(pred, labels): maybe_sync_float8_amax_and_scale_history(model, job_config) # optimizer step - checkpoint.wait_for_staging() + checkpoint.maybe_wait_for_staging() optimizers.step() lr_schedulers.step() @@ -439,23 +329,21 @@ def loss_fn(pred, labels): or train_state.step % job_config.metrics.log_freq == 0 ): losses = [loss.item() for loss in losses_since_last_log] - avg_loss, max_loss = ( - np.mean(losses), - np.max(losses), - ) + avg_loss, max_loss = sum(losses) / len(losses), max(losses) if parallel_dims.dp_enabled: global_avg_loss, global_max_loss = ( - dist_mean(avg_loss, dp_mesh).item(), - dist_max(max_loss, dp_mesh).item(), + utils.dist_mean(avg_loss, dp_mesh), + utils.dist_max(max_loss, dp_mesh), ) else: global_avg_loss, global_max_loss = avg_loss, max_loss + # update train state train_state.log_steps.append(train_state.step) train_state.global_avg_losses.append(global_avg_loss) train_state.global_max_losses.append(global_max_loss) - time_delta = timer() - time_last_log + time_delta = time.perf_counter() - time_last_log # tokens per second, abbr. 
as wps by convention wps = ntokens_since_last_log / ( @@ -467,8 +355,8 @@ def loss_fn(pred, labels): mfu = 100 * num_flop_per_token * wps / gpu_peak_flops time_end_to_end = time_delta / job_config.metrics.log_freq - time_data_loading = np.mean(data_loading_times) - time_data_loading_pct = 100 * np.sum(data_loading_times) / time_delta + time_data_loading = sum(data_loading_times) / len(data_loading_times) + time_data_loading_pct = 100 * sum(data_loading_times) / time_delta gpu_mem_stats = gpu_memory_monitor.get_peak_stats() @@ -501,7 +389,7 @@ def loss_fn(pred, labels): losses_since_last_log.clear() ntokens_since_last_log = 0 data_loading_times.clear() - time_last_log = timer() + time_last_log = time.perf_counter() gpu_memory_monitor.reset_peak_stats() checkpoint.save( @@ -517,7 +405,7 @@ def loss_fn(pred, labels): # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished) if train_state.step == 1: - set_pg_timeouts( + utils.set_pg_timeouts( timeout=timedelta(seconds=job_config.comm.train_timeout_seconds), world_mesh=world_mesh, ) @@ -534,4 +422,4 @@ def loss_fn(pred, labels): config = JobConfig() config.parse_args() main(config) - destroy_process_group() + torch.distributed.destroy_process_group()
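
Note: the `TrainState` moved into `torchtitan/checkpoint.py` serializes its loss/step histories through in-memory `BytesIO` buffers, so the distributed checkpoint layer treats each history as a single opaque blob instead of per-element tensors. A minimal round-trip sketch of that pattern, standalone and with illustrative values (not part of the diff):

```python
from io import BytesIO

import torch

# Save a small Python list the way TrainState.state_dict does: torch.save into
# an in-memory buffer, so the checkpoint layer sees one byte blob per history.
losses = [2.31, 1.98, 1.74]
buf = BytesIO()
torch.save(losses, buf)

# Restore it the way TrainState.load_state_dict does: rewind, then torch.load.
buf.seek(0)
assert torch.load(buf, weights_only=False) == losses
```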
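Note: `_get_metrics_rank` picks the first global rank of the last pipeline stage, i.e. `(world_size // pp) * (pp - 1)`, so the logged loss comes from a rank that actually holds the final model chunk. A quick numeric check with hypothetical sizes:

```python
# Hypothetical layout: 16 ranks, pipeline degree pp = 4, so each stage spans
# 16 // 4 = 4 ranks and the last stage owns global ranks 12..15.
world_size, pp = 16, 4
metrics_log_rank = (world_size // pp) * (pp - 1)
assert metrics_log_rank == 12  # first rank of the last pipeline stage
```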
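Note: `build_optimizers` now returns an `OptimizersContainer` holding one child optimizer per pipeline model part behind a single `step()`/`zero_grad()` surface, and `build_lr_schedulers` mirrors this for schedulers. A standalone sketch of the container pattern with toy two-layer "model parts" (illustrative only; it bypasses `JobConfig` and reuses hypothetical hyperparameters):

```python
import torch
import torch.nn as nn


class OptimizersContainer:
    """One child optimizer per model part, fronted by a single step()/zero_grad()."""

    def __init__(self, optimizers):
        self.optimizers = optimizers

    def step(self):
        for optimizer in self.optimizers:
            optimizer.step()

    def zero_grad(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()


# Two toy "model parts" standing in for pipeline stages of a real model.
model_parts = [nn.Linear(8, 8), nn.Linear(8, 8)]
optimizers = OptimizersContainer(
    [torch.optim.AdamW(m.parameters(), lr=8e-4, foreach=True) for m in model_parts]
)

# One training step drives every child optimizer through the single container.
for part in model_parts:
    part(torch.randn(2, 8)).sum().backward()
optimizers.step()
optimizers.zero_grad()
```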
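Note: the new `utils.GarbageCollection` helper disables automatic collection up front and then forces a generation-1 collection every `gc_freq` steps, so all ranks pause for GC at the same iterations instead of stalling at random points. A usage sketch against the class as added in `torchtitan/utils.py` (the loop length and frequency are illustrative):

```python
from torchtitan.utils import GarbageCollection

# Construction disables automatic gc and runs an initial gen-1 collection.
gc_handler = GarbageCollection(gc_freq=50)

# Inside the training loop, run() only collects when step > 1 and
# step % gc_freq == 0, keeping GC pauses aligned across ranks.
for step in range(1, 201):
    gc_handler.run(step)  # collects at steps 50, 100, 150, 200
```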