From abafcb920dae6a7e7d14db753938def4f00d6670 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Jan 2024 10:22:56 +0000 Subject: [PATCH 01/57] fix: llama import --- src/nanotron/models/llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index cc80579f..adb7551d 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -30,9 +30,10 @@ from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs, RecomputeGranularity -from nanotron.fused.layer_norm import TritonRMSNorm +from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.logging import log_rank -from nanotron.models import AttachableStore, NanotronModel +from nanotron.models import NanotronModel +from nanotron.generation.generate_store import AttachableStore from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import ( From e8b8314d61812b8d1233158d58a3aa21a7c5efde Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Jan 2024 10:29:17 +0000 Subject: [PATCH 02/57] feat: at leading loading mamba properly --- examples/mamba/configs/make_config_mamba.py | 106 ++++ examples/mamba/run.sh | 29 + run_train.py | 7 +- src/nanotron/config/config.py | 4 +- src/nanotron/config/models_config.py | 31 +- src/nanotron/models/mamba.py | 562 ++++++++++++++++++++ src/nanotron/trainer.py | 26 + 7 files changed, 759 insertions(+), 6 deletions(-) create mode 100644 examples/mamba/configs/make_config_mamba.py create mode 100755 examples/mamba/run.sh create mode 100644 src/nanotron/models/mamba.py diff --git a/examples/mamba/configs/make_config_mamba.py b/examples/mamba/configs/make_config_mamba.py new file mode 100644 index 00000000..d92f8513 --- /dev/null +++ b/examples/mamba/configs/make_config_mamba.py @@ -0,0 +1,106 @@ +""" Example python script to generate a YAML config file which can be used to run a training with nanotron. 
Refer to "examples" section in the `/README.md` for more information.""" +import os +import torch + +from nanotron.config import ( + CheckpointsArgs, + Config, + DataArgs, + GeneralArgs, + MambaConfig, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + PretrainDatasetsArgs, + MambaInit, + TokenizerArgs, + TokensArgs, +) +from nanotron.logging import human_format + +model_config = MambaConfig( + d_model=256, + num_hidden_layers=1, + vocab_size=50277, + ssm_cfg={}, + rms_norm=True, + fused_add_norm=True, + residual_in_fp32=True, + pad_vocab_size_multiple=8, + # Custom + dtype=torch.float32, + rms_norm_eps=1e-5, +) + + +#TODO(fmom): do something similar +# num_params = human_format( +# model_config.vocab_size * model_config.d_model * 2 +# + model_config.num_hidden_layers +# * ( +# 3 * model_config.d_model * model_config.intermediate_size +# + 4 * model_config.d_model * model_config.d_model +# ) +# ).replace(".", "p") + +# print(f"Model has {num_params} parameters") + +seed = 42 + +learning_rate = LRSchedulerArgs( + learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 +) + +optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + learning_rate_scheduler=learning_rate, +) + +parallelism = ParallelismArgs( + dp=1, + pp=1, + tp=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, + recompute_granularity="selective", +) + +tokens = TokensArgs(sequence_length=32, train_steps=10, micro_batch_size=2, batch_accumulation_per_replica=1) + +dataset = PretrainDatasetsArgs( + hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text" +) + +checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" +os.makedirs(checkpoints_path, exist_ok=True) + +config = Config( + general=GeneralArgs(project="test", run="mamba", seed=seed), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + parallelism=parallelism, + model=ModelArgs(init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), model_config=model_config), + tokenizer=TokenizerArgs("gpt2"), + optimizer=optimizer, + logging=LoggingArgs(), + tokens=tokens, + data=DataArgs(dataset=dataset, seed=seed), + profiler=None, +) + +if __name__ == "__main__": + dir = os.path.dirname(__file__) + + # Save config as YAML file + config.save_as_yaml(f"{dir}/config_mamba.yaml") + + # You can now train a model with this config using `/run_train.py` diff --git a/examples/mamba/run.sh b/examples/mamba/run.sh new file mode 100755 index 00000000..c5cbc5a7 --- /dev/null +++ b/examples/mamba/run.sh @@ -0,0 +1,29 @@ +#!/bin/sh + +if [ "$1" = "debug" ]; then + python configs/make_config_mamba.py && \ + FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 1234 -m torch.distributed.launch \ + -- \ + --nproc_per_node=1 \ + --master_port=29600 \ + --rdzv_endpoint=localhost:6000 \ + --use_env \ + --tee=3 \ + ../../run_train.py \ + --config-file=configs/config_mamba.yaml +elif [ "$1" = "eval" ]; then + FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \ + --nproc_per_node=1 \ + --master_port 29600 \ + ../generate.py \ + --pp 1 \ + --tp 1 \ + --ckpt-path /fsx/ferdinandmom/github/mamba/checkpoints/mamba-1p62M-stas-openwebtext-10k/7 +else + python configs/make_config_mamba.py && \ + FI_PROVIDER="efa" 
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \ + --nproc_per_node=1 \ + --master_port=29600 \ + ../../run_train.py \ + --config-file=configs/config_mamba.yaml +fi \ No newline at end of file diff --git a/run_train.py b/run_train.py index d3276f6b..f0830547 100644 --- a/run_train.py +++ b/run_train.py @@ -131,6 +131,9 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) dataloader = get_dataloader(trainer) - + + print("HELLOOOOOOOOO") + print(trainer.model) + # Train - trainer.train(dataloader) + # trainer.train(dataloader) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index b4ec938b..9885b952 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -10,7 +10,7 @@ from dacite import from_dict from yaml.loader import SafeLoader -from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit +from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, MambaInit from nanotron.config.utils_config import ( RecomputeGranularity, cast_str_to_pipeline_engine, @@ -210,7 +210,7 @@ class ModelArgs: """Arguments related to model architecture""" model_config: NanotronConfigs - init_method: Union[RandomInit, ExistingCheckpointInit] + init_method: Union[RandomInit, MambaInit, ExistingCheckpointInit] dtype: Optional[torch.dtype] = None make_vocab_size_divisible_by: int = 1 ddp_bucket_cap_mb: int = 25 diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index b85749f2..673f44e2 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,12 +1,18 @@ from dataclasses import dataclass, field from pathlib import Path from typing import List, Optional, Union - +import torch @dataclass class RandomInit: std: float +@dataclass +class MambaInit: + # mamba_ssm.models.mixer_seq_simple._init_weights + initializer_range: float = 0.02 + rescale_prenorm_residual: bool = True, + n_residuals_per_layer: int = 1, # Change to 2 if we have MLP @dataclass class ExistingCheckpointInit: @@ -14,6 +20,25 @@ class ExistingCheckpointInit: path: Path +@dataclass +class MambaConfig: + """Configuration for a Mamba model + + Be careful on having a coherent typing as we use it to reconstruct the model from yaml + """ + + d_model: int = 2560 + num_hidden_layers: int = 64 + vocab_size: int = 50277 + ssm_cfg: Optional[dict] = None + rms_norm: bool = True + fused_add_norm: bool = True + residual_in_fp32: bool = True + pad_vocab_size_multiple: int = 8 + # ==== Custom ====== + dtype: torch.dtype = torch.float32 + rms_norm_eps: float = 1e-5 + pad_token_id: Optional[int] = None @dataclass class LlamaConfig: @@ -116,4 +141,6 @@ def n_inner(self): return self.intermediate_size -NanotronConfigs = Union[LlamaConfig, Starcoder2Config] +#TODO(fmom): check why MambaConfig won't load if it's not the first one in the union +NanotronConfigs = Union[MambaConfig, LlamaConfig, Starcoder2Config] + diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py new file mode 100644 index 00000000..e6d91e0c --- /dev/null +++ b/src/nanotron/models/mamba.py @@ -0,0 +1,562 @@ +# coding=utf-8 +# Copyright 2018 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mamba model. +""" +from typing import Dict, Optional, Union +import math +import torch +from flash_attn import bert_padding +from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, +) +from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding +from torch import nn +from transformers.activations import ACT2FN +from functools import partial + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ParallelismArgs, RecomputeGranularity +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.generation.generate_store import AttachableStore +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import ( + PipelineBlock, + TensorPointer, +) +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.random import RandomStates +from nanotron.utils import checkpoint_method +from nanotron.config.models_config import MambaConfig + +from mamba_ssm.models.mixer_seq_simple import create_block, Mamba, _init_weights + +try: + from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + +class Embedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: MambaConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.d_model, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length] + store = self.get_local_store() + if store is not None: + if "past_length" in store: + past_length = store["past_length"] + else: + past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) + + cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) + # Store new past_length in store + store["past_length"] = past_length + cumsum_mask[:, -1] + + # Format input in `[seq_length, batch_size]` to support high TP with low batch_size + input_ids = input_ids.transpose(0, 1) + input_embeds = self.token_embedding(input_ids) + return {"input_embeds": input_embeds} + +class MambaDecoderLayer(nn.Module): + def __init__( + self, + config: MambaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + + super().__init__() + self.block = 
create_block( + config.d_model, + ssm_cfg=config.ssm_cfg, + norm_epsilon=config.rms_norm_eps, + rms_norm=config.rms_norm, + residual_in_fp32=config.residual_in_fp32, + fused_add_norm=config.fused_add_norm, + layer_idx=layer_idx, + **factory_kwargs, + ) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + residual: Optional[Union[torch.Tensor, TensorPointer]] = None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + hidden_states, residual = self.block(hidden_states) + + return { + "hidden_states": hidden_states, + "sequence_mask": sequence_mask, # NOTE(fmom): dunno how to use it for now. Just keep it + "residual": residual, + } + + +class MambaModel(nn.Module): + def __init__( + self, + config: MambaConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.config = config + self.parallel_config = parallel_config + self.parallel_context = parallel_context + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.token_position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids", "input_mask"}, + module_output_keys={"input_embeds"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=MambaDecoderLayer, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "layer_idx": layer_idx, + "device": self.p2p.device, + "dtype": config.dtype, + }, + module_input_keys={"hidden_states", "sequence_mask", "residual"}, # TODO(fmom): is this correct? + module_output_keys={"hidden_states", "sequence_mask", "residual"}, # TODO(fmom): is this correct? + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=RMSNorm, + module_kwargs={"hidden_size": config.d_model, "eps": config.rms_norm_eps}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) # TODO + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.d_model, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + # TODO(fmom): call tied weights here + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + + def forward_with_hidden_states( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. + + output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) + + residual = None + + hidden_encoder_states = { + "hidden_states": output["input_embeds"], + "sequence_mask": input_mask, + "residual": residual, + } + + for block in self.decoder: + hidden_encoder_states = block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits, hidden_states + + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + # model_config = self.config + # d_ff = model_config.intermediate_size + # d_qkv = model_config.d_model // model_config.num_attention_heads + # block_compute_costs = { + # # CausalSelfAttention (qkv proj + attn out) + MLP + # LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.d_model + # + 3 * d_ff * model_config.d_model, + # # This is the last lm_head + # TensorParallelColumnLinear: model_config.vocab_size * model_config.d_model, + # } + model_config = self.config + + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + MambaDecoderLayer: 4096, + # This is the last lm_head + TensorParallelColumnLinear: model_config.vocab_size * model_config.d_model, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + # world_size = self.parallel_context.world_pg.size() + # try: + # num_key_values_heads = self.config.num_key_value_heads + # except AttributeError: + # num_key_values_heads = self.config.num_attention_heads + + # model_flops, hardware_flops = get_flops( + # num_layers=self.config.num_hidden_layers, + # hidden_size=self.config.d_model, + # num_heads=self.config.num_attention_heads, + # num_key_value_heads=num_key_values_heads, + # vocab_size=self.config.vocab_size, + # ffn_hidden_size=self.config.intermediate_size, + # seq_len=sequence_length, + # batch_size=global_batch_size, + # recompute_granularity=self.parallel_config.recompute_granularity, + # ) + + # model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + # hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + + # TODO(fmom): undo hardcoding of model_flops_per_s and hardware_flops_per_s + model_flops_per_s = 0.000681 + hardware_flops_per_s = 
0.000681 + return model_flops_per_s, hardware_flops_per_s + + +torch.jit.script +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [seq_length, batch_size, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. + # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + loss = sharded_cross_entropy( + sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + ).transpose(0, 1) + # TODO @thomasw21: It's unclear what kind of normalization we want to do. + loss = masked_mean(loss, label_mask, dtype=torch.float) + # I think indexing causes a sync we don't actually want + # loss = loss[label_mask].sum() + return {"loss": loss} + + +class MambaForTraining(NanotronModel): + def __init__( + self, + config: MambaConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + + self.model = MambaModel( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=random_states, + ) + + self.loss = PipelineBlock( + p2p=self.model.p2p, + module_builder=Loss, + module_kwargs={"tp_pg": parallel_context.tp_pg}, + module_input_keys={ + "sharded_logits", + "label_ids", + "label_mask", + }, + module_output_keys={"loss"}, + ) + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + @torch.no_grad() + def init_model_randomly(self, init_method, scaled_init_method): + """Initialize model parameters randomly. + Args: + init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ + scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ + + Note: + Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` + """ + model = self + initialized_parameters = set() + + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + for module_name, module in model.named_modules(): + if isinstance(module, TensorParallelColumnLinear): + # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 + # What it does: + # - instantiate a buffer of the `full size` in fp32 + # - run init method on it + # - shard result to get only a specific shard + # Instead I'm lazy and just going to run init_method, since they are scalar independent + assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { + name for name, _ in module.named_parameters() + } + for param_name, param in module.named_parameters(): + assert isinstance(param, NanotronParameter) + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + if "weight" == param_name: + init_method(param) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + elif isinstance(module, TensorParallelRowLinear): + # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 + # What it does: + # - instantiate a buffer of the `full size` in fp32 + # - run init method on it + # - shard result to get only a specific shard + # Instead I'm lazy and just going to run init_method, since they are scalar independent + assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { + name for name, _ in module.named_parameters() + } + for param_name, param in module.named_parameters(): + assert isinstance(param, NanotronParameter) + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + if "weight" == param_name: + scaled_init_method(param) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + elif isinstance(module, RMSNorm): + assert {"weight"} == {name for name, _ in module.named_parameters()} + for param_name, param in module.named_parameters(): + assert isinstance(param, NanotronParameter) + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + param.fill_(1) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + + assert full_param_name not in initialized_parameters + 
initialized_parameters.add(full_param_name) + elif isinstance(module, TensorParallelEmbedding): + # TODO @thomasw21: Handle tied embeddings + # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 + # What it does: + # - instantiate a buffer of the `full size` in fp32 + # - run init method on it + # - shard result to get only a specific shard + # Instead I'm lazy and just going to run init_method, since they are scalar independent + assert {"weight"} == {name for name, _ in module.named_parameters()} + + assert isinstance(module.weight, NanotronParameter) + if module.weight.is_tied: + tied_info = module.weight.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.weight" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + init_method(module.weight) + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + @torch.no_grad() + def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residual, n_residuals_per_layer): + + model = self + initialized_parameters = set() + + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + #TODO(fmom): make it compatible with TensorParralel, TensorEmbedding + for module_name, module in model.named_modules(): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + #TODO(fmom): perform check + # assert initialized_parameters == { + # param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + # if param.is_tied + # else name + # for name, param in model.named_parameters() + # }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + # TODO(fmom): implement get_block_compute_costs + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + return self.model.get_block_compute_costs() + + # TODO(fmom): implement get_flops_per_sec + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 119f658c..d33f4187 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -18,6 +18,7 @@ ExistingCheckpointInit, ParallelismArgs, RandomInit, + MambaInit, get_config_from_file, ) from nanotron.dataloader import sanity_check_dataloader @@ -34,6 +35,7 @@ from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining +from nanotron.models.mamba import MambaForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -82,6 +84,7 @@ CONFIG_TO_MODEL_CLASS = { "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, + "MambaConfig": MambaForTraining, } @@ -557,6 +560,29 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: parallel_context=self.parallel_context, root_folder=self.config.model.init_method.path, ) + elif isinstance(self.config.model.init_method, MambaInit): + # Initialize model randomly + normalized_model.init_mamba_weights( + n_layer=self.model_config.num_hidden_layers, + initializer_range=self.config.model.init_method.initializer_range, + rescale_prenorm_residual=self.config.model.init_method.rescale_prenorm_residual, + n_residuals_per_layer=self.config.model.init_method.n_residuals_per_layer + ) + # Synchronize parameters so that the model is consistent + # sync all params across dp + for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg) + + # sync tied params across tied groups + for (_, group_ranks), param in sorted( + get_tied_id_to_param( + parameters=model.parameters(), + root_module=normalized_model, + ).items(), + key=lambda x: 
x[0], + ): + group = self.parallel_context.world_ranks_to_pg[group_ranks] + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) elif isinstance(self.config.model.init_method, RandomInit): # Initialize model randomly normalized_model.init_model_randomly( From d9513b093275a41db77b3f5b6e87dbb2d04d88dd Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Jan 2024 11:08:01 +0000 Subject: [PATCH 03/57] loss going down --- examples/mamba/configs/make_config_mamba.py | 6 ++--- run_train.py | 2 +- src/nanotron/config/models_config.py | 1 + src/nanotron/models/mamba.py | 29 ++++++++++++++++++--- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/examples/mamba/configs/make_config_mamba.py b/examples/mamba/configs/make_config_mamba.py index d92f8513..4fa7442e 100644 --- a/examples/mamba/configs/make_config_mamba.py +++ b/examples/mamba/configs/make_config_mamba.py @@ -57,7 +57,7 @@ zero_stage=0, weight_decay=0.01, clip_grad=1.0, - accumulate_grad_in_fp32=True, + accumulate_grad_in_fp32=False, #NOTE(fmom): because we are using PP=TP=DP=1 adam_eps=1e-08, adam_beta1=0.9, adam_beta2=0.95, @@ -75,7 +75,7 @@ recompute_granularity="selective", ) -tokens = TokensArgs(sequence_length=32, train_steps=10, micro_batch_size=2, batch_accumulation_per_replica=1) +tokens = TokensArgs(sequence_length=1024, train_steps=40, micro_batch_size=2, batch_accumulation_per_replica=1) dataset = PretrainDatasetsArgs( hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text" @@ -86,7 +86,7 @@ config = Config( general=GeneralArgs(project="test", run="mamba", seed=seed), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=243232232232323332), parallelism=parallelism, model=ModelArgs(init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), model_config=model_config), tokenizer=TokenizerArgs("gpt2"), diff --git a/run_train.py b/run_train.py index f0830547..4d31ff0c 100644 --- a/run_train.py +++ b/run_train.py @@ -136,4 +136,4 @@ def get_args(): print(trainer.model) # Train - # trainer.train(dataloader) + trainer.train(dataloader) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 673f44e2..a46d8374 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -142,5 +142,6 @@ def n_inner(self): #TODO(fmom): check why MambaConfig won't load if it's not the first one in the union +# NanotronConfigs = Union[LlamaConfig, MambaConfig, Starcoder2Config] NanotronConfigs = Union[MambaConfig, LlamaConfig, Starcoder2Config] diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py index e6d91e0c..f66daf4d 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba.py @@ -51,6 +51,7 @@ from nanotron.utils import checkpoint_method from nanotron.config.models_config import MambaConfig +#NOTE(fmom): mamba_ssm=1.1.1 from mamba_ssm.models.mixer_seq_simple import create_block, Mamba, _init_weights try: @@ -58,6 +59,8 @@ except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None +logger = logging.get_logger(__name__) + class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: MambaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() @@ -182,7 +185,7 @@ def __init__( p2p=self.p2p, module_builder=RMSNorm, module_kwargs={"hidden_size": config.d_model, "eps": 
config.rms_norm_eps}, - module_input_keys={"input"}, + module_input_keys={"x"}, module_output_keys={"hidden_states"}, ) # TODO @@ -240,7 +243,7 @@ def forward_with_hidden_states( for block in self.decoder: hidden_encoder_states = block(**hidden_encoder_states) - hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + hidden_states = self.final_layer_norm(x=hidden_encoder_states["hidden_states"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] @@ -359,6 +362,24 @@ def __init__( self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + sharded_logits = self.model( + input_ids=input_ids, + input_mask=input_mask, + ) + loss = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + )["loss"] + return {"loss": loss} @torch.no_grad() def init_model_randomly(self, init_method, scaled_init_method): @@ -542,7 +563,9 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): p /= math.sqrt(n_residuals_per_layer * n_layer) - + + log_rank(f"Initialized {module_name} (NEED TO CHECK IF DONE PROPERLY)", logger=logger, level=logging.INFO, rank=0) + #TODO(fmom): perform check # assert initialized_parameters == { # param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) From 147fd2553c91d5918e124d949b2257f0cec529cc Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 24 Jan 2024 14:22:54 +0000 Subject: [PATCH 04/57] refacto: init mamba weights --- src/nanotron/models/mamba.py | 87 ++++++++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 18 deletions(-) diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py index f66daf4d..3778b7a4 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba.py @@ -174,8 +174,8 @@ def __init__( "device": self.p2p.device, "dtype": config.dtype, }, - module_input_keys={"hidden_states", "sequence_mask", "residual"}, # TODO(fmom): is this correct? - module_output_keys={"hidden_states", "sequence_mask", "residual"}, # TODO(fmom): is this correct? 
+ module_input_keys={"hidden_states", "sequence_mask", "residual"}, + module_output_keys={"hidden_states", "sequence_mask", "residual"}, ) for layer_idx in range(config.num_hidden_layers) ] @@ -187,7 +187,7 @@ def __init__( module_kwargs={"hidden_size": config.d_model, "eps": config.rms_norm_eps}, module_input_keys={"x"}, module_output_keys={"hidden_states"}, - ) # TODO + ) self.lm_head = PipelineBlock( p2p=self.p2p, @@ -268,10 +268,11 @@ def get_block_compute_costs(self): block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - MambaDecoderLayer: 4096, + MambaDecoderLayer: 0, # This is the last lm_head - TensorParallelColumnLinear: model_config.vocab_size * model_config.d_model, + TensorParallelColumnLinear: 0, } + log_rank(f"get_block_compute_costs() Not implemented yet", logger=logger, level=logging.INFO, rank=0) return block_compute_costs def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): @@ -298,8 +299,9 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch # hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) # TODO(fmom): undo hardcoding of model_flops_per_s and hardware_flops_per_s - model_flops_per_s = 0.000681 - hardware_flops_per_s = 0.000681 + model_flops_per_s = 0 + hardware_flops_per_s = 0 + log_rank(f"get_flops_per_sec() Not implemented yet", logger=logger, level=logging.INFO, rank=0) return model_flops_per_s, hardware_flops_per_s @@ -537,15 +539,53 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua # Fix the root_model module_id_to_prefix[id(model)] = "" - #TODO(fmom): make it compatible with TensorParralel, TensorEmbedding + #TODO(fmom): port initiliaztion from mamba_ssm.mamba_simple.Mamba to here + for module_name, module in model.named_modules(): if isinstance(module, nn.Linear): - if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) - - elif isinstance(module, nn.Embedding): + + for param_name, param in module.named_parameters(): + assert isinstance(param, NanotronParameter) + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + if "weight" == param_name: + pass + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + elif isinstance(module, TensorParallelEmbedding): + assert {"weight"} == {name for name, _ in module.named_parameters()} + + assert isinstance(module.weight, NanotronParameter) + if module.weight.is_tied: + tied_info = module.weight.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.weight" + + if full_param_name in initialized_parameters: + # Already initialized + continue + nn.init.normal_(module.weight, std=initializer_range) + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) if rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: @@ -555,7 +595,17 @@ def init_mamba_weights(self, n_layer, initializer_range, 
rescale_prenorm_residua # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: + if name in ["out_proj.weight"]: + # get fullname + assert isinstance(p, NanotronParameter) + if p.is_tied: + tied_info = p.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) # We need to reinit p since this code could be called multiple times @@ -563,10 +613,11 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): p /= math.sqrt(n_residuals_per_layer * n_layer) - - log_rank(f"Initialized {module_name} (NEED TO CHECK IF DONE PROPERLY)", logger=logger, level=logging.INFO, rank=0) - - #TODO(fmom): perform check + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + # #TODO(fmom): perform check # assert initialized_parameters == { # param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) # if param.is_tied From 3e4dda1479f15fe446b8aeb9538ff5760065adcf Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Sun, 28 Jan 2024 14:46:55 +0000 Subject: [PATCH 05/57] fix: dtype for brrr compatibility --- examples/mamba/run.sh | 4 ++-- src/nanotron/config/models_config.py | 25 +++++++++++++++++++++++-- src/nanotron/models/mamba.py | 3 ++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/examples/mamba/run.sh b/examples/mamba/run.sh index c5cbc5a7..b8fc12b8 100755 --- a/examples/mamba/run.sh +++ b/examples/mamba/run.sh @@ -1,7 +1,7 @@ #!/bin/sh if [ "$1" = "debug" ]; then - python configs/make_config_mamba.py && \ + python configs/make_config_mamba_fast.py && \ FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 1234 -m torch.distributed.launch \ -- \ --nproc_per_node=1 \ @@ -20,7 +20,7 @@ elif [ "$1" = "eval" ]; then --tp 1 \ --ckpt-path /fsx/ferdinandmom/github/mamba/checkpoints/mamba-1p62M-stas-openwebtext-10k/7 else - python configs/make_config_mamba.py && \ + python configs/make_config_mamba_fast.py && \ FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \ --nproc_per_node=1 \ --master_port=29600 \ diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index a46d8374..62975c7b 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -36,7 +36,28 @@ class MambaConfig: residual_in_fp32: bool = True pad_vocab_size_multiple: int = 8 # ==== Custom ====== - dtype: torch.dtype = torch.float32 + dtype: str = "float32" + rms_norm_eps: float = 1e-5 + pad_token_id: Optional[int] = None + + +@dataclass +class MambaFastConfig: + """Configuration for a Mamba model + + Be careful on having a coherent typing as we use it to reconstruct the model from yaml + """ + + d_model: int = 2560 + num_hidden_layers: int = 64 + vocab_size: int = 50277 + ssm_cfg: Optional[dict] = None + rms_norm: bool = True + fused_add_norm: bool = True + residual_in_fp32: bool = True + pad_vocab_size_multiple: int = 8 + # ==== Custom ====== + dtype: str = "float32" rms_norm_eps: float = 1e-5 pad_token_id: Optional[int] = None @@ -143,5 +164,5 @@ def 
n_inner(self): #TODO(fmom): check why MambaConfig won't load if it's not the first one in the union # NanotronConfigs = Union[LlamaConfig, MambaConfig, Starcoder2Config] -NanotronConfigs = Union[MambaConfig, LlamaConfig, Starcoder2Config] +NanotronConfigs = Union[MambaFastConfig, LlamaConfig, MambaConfig, Starcoder2Config] diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py index 3778b7a4..69c74795 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba.py @@ -27,6 +27,7 @@ from transformers.activations import ACT2FN from functools import partial +from nanotron.config.utils_config import str_to_dtype from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs, RecomputeGranularity @@ -172,7 +173,7 @@ def __init__( "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, "device": self.p2p.device, - "dtype": config.dtype, + "dtype": str_to_dtype[config.dtype], }, module_input_keys={"hidden_states", "sequence_mask", "residual"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, From bebbf6fb1bbd224c31089870dbfaf0c4b186b738 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 29 Jan 2024 10:25:24 +0000 Subject: [PATCH 06/57] fix: bring back strict union to fix NanotronConfigs + use torch.dtype in config --- src/nanotron/config/config.py | 4 ++-- src/nanotron/config/models_config.py | 9 ++++----- src/nanotron/models/mamba.py | 3 +-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 9885b952..521cfe7d 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -392,8 +392,8 @@ def get_config_from_file(config_path: str, config_class: Type[Config] = Config) RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], }, - # strict_unions_match=True, - # strict=True, + strict_unions_match=True, + strict=True, ), ) except Exception as e: diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 62975c7b..13b280ee 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -27,6 +27,7 @@ class MambaConfig: Be careful on having a coherent typing as we use it to reconstruct the model from yaml """ + is_mamba_config: bool = True # We use this help differentiate models in yaml/python conversion d_model: int = 2560 num_hidden_layers: int = 64 vocab_size: int = 50277 @@ -36,7 +37,7 @@ class MambaConfig: residual_in_fp32: bool = True pad_vocab_size_multiple: int = 8 # ==== Custom ====== - dtype: str = "float32" + dtype: torch.dtype = torch.float32 rms_norm_eps: float = 1e-5 pad_token_id: Optional[int] = None @@ -47,7 +48,7 @@ class MambaFastConfig: Be careful on having a coherent typing as we use it to reconstruct the model from yaml """ - + is_mamba_fast_config: bool = True # We use this help differentiate models in yaml/python conversion d_model: int = 2560 num_hidden_layers: int = 64 vocab_size: int = 50277 @@ -57,7 +58,7 @@ class MambaFastConfig: residual_in_fp32: bool = True pad_vocab_size_multiple: int = 8 # ==== Custom ====== - dtype: str = "float32" + dtype: torch.dtype = torch.float32 rms_norm_eps: float = 1e-5 pad_token_id: Optional[int] = None @@ -162,7 +163,5 @@ def n_inner(self): return self.intermediate_size -#TODO(fmom): check why MambaConfig won't load if it's not the first one in the union -# NanotronConfigs = Union[LlamaConfig, MambaConfig, 
Starcoder2Config] NanotronConfigs = Union[MambaFastConfig, LlamaConfig, MambaConfig, Starcoder2Config] diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py index 69c74795..3778b7a4 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba.py @@ -27,7 +27,6 @@ from transformers.activations import ACT2FN from functools import partial -from nanotron.config.utils_config import str_to_dtype from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs, RecomputeGranularity @@ -173,7 +172,7 @@ def __init__( "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, "device": self.p2p.device, - "dtype": str_to_dtype[config.dtype], + "dtype": config.dtype, }, module_input_keys={"hidden_states", "sequence_mask", "residual"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, From ad384b9c9af4c6d10c9151a17affe3f34d89a797 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 29 Jan 2024 10:42:14 +0000 Subject: [PATCH 07/57] fix: revert back to casting str to torch dtype --- src/nanotron/config/models_config.py | 4 ++-- src/nanotron/models/mamba.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 13b280ee..84374676 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -37,7 +37,7 @@ class MambaConfig: residual_in_fp32: bool = True pad_vocab_size_multiple: int = 8 # ==== Custom ====== - dtype: torch.dtype = torch.float32 + dtype: str = "float32" rms_norm_eps: float = 1e-5 pad_token_id: Optional[int] = None @@ -58,7 +58,7 @@ class MambaFastConfig: residual_in_fp32: bool = True pad_vocab_size_multiple: int = 8 # ==== Custom ====== - dtype: torch.dtype = torch.float32 + dtype: str = "float32" rms_norm_eps: float = 1e-5 pad_token_id: Optional[int] = None diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py index 3778b7a4..626147d6 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba.py @@ -29,6 +29,7 @@ from nanotron import distributed as dist from nanotron import logging +from nanotron.config.utils_config import cast_str_to_torch_dtype from nanotron.config import ParallelismArgs, RecomputeGranularity from nanotron.logging import log_rank from nanotron.models import NanotronModel @@ -172,7 +173,7 @@ def __init__( "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, "device": self.p2p.device, - "dtype": config.dtype, + "dtype": cast_str_to_torch_dtype[config.dtype], }, module_input_keys={"hidden_states", "sequence_mask", "residual"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, From 02d21de21d21c472f07a31c47f95641cac5ca43a Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 29 Jan 2024 11:14:01 +0000 Subject: [PATCH 08/57] fix --- src/nanotron/models/mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py index 626147d6..b94a04f5 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba.py @@ -173,7 +173,7 @@ def __init__( "tp_pg": parallel_context.tp_pg, "layer_idx": layer_idx, "device": self.p2p.device, - "dtype": cast_str_to_torch_dtype[config.dtype], + "dtype": cast_str_to_torch_dtype(config.dtype), }, module_input_keys={"hidden_states", "sequence_mask", "residual"}, module_output_keys={"hidden_states", "sequence_mask", "residual"}, From 1168bfecf430b065b4de147fddde90a86371170d Mon Sep 17 
00:00:00 2001 From: "ferdinand.mom" Date: Mon, 29 Jan 2024 13:59:07 +0000 Subject: [PATCH 09/57] fix: mismatch params counting due to no tie embedding weight --- src/nanotron/models/mamba.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba.py index b94a04f5..5ec2b6d5 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba.py @@ -215,7 +215,6 @@ def __init__( module_output_keys={"output"}, ) - # TODO(fmom): call tied weights here def forward( self, @@ -626,6 +625,13 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua # for name, param in model.named_parameters() # }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + @staticmethod + def get_embeddings_lm_head_tied_names(): + return [ + "model.token_position_embeddings.pp_block.token_embedding.weight", + "model.lm_head.pp_block.weight", + ] + # TODO(fmom): implement get_block_compute_costs def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" From 7bfc56cbd35d301258773acfd738d6533117c961 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 30 Jan 2024 13:24:04 +0000 Subject: [PATCH 10/57] fix: run_generate is now compatible with Brrr --- run_generate.py | 4 ++-- src/nanotron/config/config.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/run_generate.py b/run_generate.py index 725c9b8d..2d0d09ab 100644 --- a/run_generate.py +++ b/run_generate.py @@ -60,7 +60,7 @@ def main(): assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist" - config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix()) + config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(), is_run_generate=True) model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path @@ -69,7 +69,7 @@ def main(): pp=args.pp or config.parallelism.pp, tp=args.tp or config.parallelism.tp, pp_engine=OneForwardOneBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_mode=TensorParallelLinearMode.REDUCE_SCATTER, recompute_granularity=None, tp_linear_async_communication=True, ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 521cfe7d..3bb9cb15 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -78,7 +78,7 @@ def __post_init__(self): @dataclass class PretrainDatasetsArgs: - hf_dataset_or_datasets: Union[str, list, dict] + hf_dataset_mixer: Union[str, list, dict] hf_dataset_splits: Optional[Union[str, list]] = None hf_dataset_config_name: Optional[str] = None dataset_processing_num_proc_per_process: Optional[int] = 1 @@ -364,7 +364,7 @@ def as_dict(self) -> dict: return serialize(self) -def get_config_from_file(config_path: str, config_class: Type[Config] = Config) -> Config: +def get_config_from_file(config_path: str, config_class: Type[Config] = Config, is_run_generate: bool = False) -> Config: """Get a config objet from a file (python or YAML) Args: @@ -377,6 +377,22 @@ def get_config_from_file(config_path: str, config_class: Type[Config] = Config) with open(config_path) as f: args = yaml.load(f, Loader=SafeLoader) + # To run generate with Nanotron, we have to remove unused arguments in the config (s3 etc...) 
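+    # The checkpoint config may come from BRRR and contain extra top-level sections (s3, etc.)
+    # as well as extra fields inside shared dataclasses (e.g. `is_brrr_data`). Both are dropped
+    # below so that `from_dict(..., strict=True)` can still build a Nanotron `Config` from the
+    # remaining keys.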
+ if is_run_generate: + # Remove BRRR dataclasses that are not used in Nanotron + exclude_keys = set(args.keys()).difference(set(config_class.__dataclass_fields__.keys())) + for key in exclude_keys: + args.pop(key) + print(f"Removed '{key}' dataclass from config") + + # Remove keys from each Brrr dataclasses that are not used in Nanotron (i.e: is_brrr_data) + for key, value in args.items(): + if isinstance(value, dict): + exclude_keys = set(value.keys()).difference(set(config_class.__dataclass_fields__[key].type.__dataclass_fields__.keys())) + for key2 in exclude_keys: + args[key].pop(key2) + print(f"Removed '{key2}' from '{key}' dataclass from config") + print(args) # Make a nice dataclass from our yaml try: From 1fe1cd4111e128b4f0e249753bd7634f4448b43f Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 30 Jan 2024 17:47:32 +0000 Subject: [PATCH 11/57] chore: single file mamba slow (for residual fix) --- src/nanotron/models/{ => mamba_slow}/mamba.py | 384 +++++++++++++++++- src/nanotron/trainer.py | 2 +- 2 files changed, 384 insertions(+), 2 deletions(-) rename src/nanotron/models/{ => mamba_slow}/mamba.py (65%) diff --git a/src/nanotron/models/mamba.py b/src/nanotron/models/mamba_slow/mamba.py similarity index 65% rename from src/nanotron/models/mamba.py rename to src/nanotron/models/mamba_slow/mamba.py index 5ec2b6d5..fc7ebe95 100644 --- a/src/nanotron/models/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -53,15 +53,397 @@ from nanotron.config.models_config import MambaConfig #NOTE(fmom): mamba_ssm=1.1.1 -from mamba_ssm.models.mixer_seq_simple import create_block, Mamba, _init_weights +# from mamba_ssm.models.mixer_seq_simple import create_block, Mamba, _init_weights + +# try: +# from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +# except ImportError: +# RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + +# Copyright (c) 2023, Albert Gu, Tri Dao. 
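+# NOTE: the Mamba mixer below is copied from mamba_ssm 1.1.1 (mamba_simple / mixer_seq_simple)
+# so that the residual handling can be adjusted locally instead of importing `create_block`
+# from `mamba_ssm.models.mixer_seq_simple` (see the commented-out imports above).
+# `causal_conv1d` and the Triton layernorm kernels remain optional and fall back to None
+# when they are not installed.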
+import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from einops import rearrange, repeat + +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None + +try: + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None try: from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + logger = logging.get_logger(__name__) + +class Mamba(nn.Module): + def __init__( + self, + d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", + dt_scale=1.0, + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + + self.activation = "silu" + self.act = nn.SiLU() + + self.x_proj = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + self.dt_proj.bias._no_reinit = True + + # S4D real initialization + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + print() + + def forward(self, hidden_states, inference_params=None): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + batch, seqlen, dim = 
hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out + + # We do matmul and transpose BLH -> HBL at the same time + xz = rearrange( + self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), + "d (b l) -> b d l", + l=seqlen, + ) + if self.in_proj.bias is not None: + xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") + + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + # In the backward pass we write dx and dz next to each other to avoid torch.cat + if self.use_fast_path and inference_params is None: # Doesn't support outputting the states + out = mamba_inner_fn( + xz, + self.conv1d.weight, + self.conv1d.bias, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias, + A, + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + else: + x, z = xz.chunk(2, dim=1) + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
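+            # x_dbl packs the input-dependent SSM parameters along its last axis:
+            # dt (dt_rank columns), then B and C (d_state columns each); dt is projected
+            # through dt_proj and laid out as (b, d, l), while B and C become (b, dstate, l)
+            # below, which is the layout selective_scan_fn expects.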
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + assert self.activation in ["silu", "swish"] + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=ssm_state is not None, + ) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, "b d l -> b l d") + out = self.out_proj(y) + return out + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + x, z = xz.chunk(2, dim=-1) # (B D) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + # Don't add dt_bias here + dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # SSM step + if selective_state_update is None: + # Discretize A and B + dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + dB = torch.einsum("bd,bn->bdn", dt, B) + ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) + y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) + y = y + self.D.to(dtype) * x + y = y * self.act(z) # (B D) + else: + y = selective_state_update( + ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype + ) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (batch_size,) + conv_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_state, + device=self.dt_proj.weight.device, + dtype=self.dt_proj.weight.dtype, + # 
dtype=torch.float32, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state + +class Block(nn.Module): + def __init__( + self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(dim) + self.norm = norm_cls(dim) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + return hidden_states, residual + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + +def create_block( + d_model, + ssm_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None, +): + if ssm_cfg is None: + ssm_cfg = {} + factory_kwargs = {"device": device, "dtype": dtype} + mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + block = Block( + d_model, + mixer_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + ) + block.layer_idx = layer_idx + return block + class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: MambaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index d33f4187..4946b094 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -35,7 +35,7 @@ 
from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining -from nanotron.models.mamba import MambaForTraining +from nanotron.models.mamba_slow.mamba import MambaForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp From 720df8ebf1ff9b239c236068e2778c6be2718235 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 2 Feb 2024 10:02:44 +0000 Subject: [PATCH 12/57] fix: residual for slow mamba --- src/nanotron/models/mamba_slow/mamba.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index fc7ebe95..55419fd4 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -179,7 +179,6 @@ def __init__( self.D._no_weight_decay = True self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - print() def forward(self, hidden_states, inference_params=None): """ @@ -395,7 +394,9 @@ def forward( residual: hidden_states = Mixer(LN(residual)) """ if not self.fused_add_norm: - residual = (hidden_states + residual) if residual is not None else hidden_states + # self.layer_idx was assigned when calling create_block + # residual=None happens only at the first block + residual = hidden_states if (self.layer_idx == 0) else hidden_states + residual hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) if self.residual_in_fp32: residual = residual.to(torch.float32) @@ -501,9 +502,9 @@ def forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], - residual: Optional[Union[torch.Tensor, TensorPointer]] = None, + residual: Optional[Union[torch.Tensor, TensorPointer]], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - hidden_states, residual = self.block(hidden_states) + hidden_states, residual = self.block(hidden_states, residual) return { "hidden_states": hidden_states, @@ -614,12 +615,10 @@ def forward_with_hidden_states( output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) - residual = None - hidden_encoder_states = { "hidden_states": output["input_embeds"], "sequence_mask": input_mask, - "residual": residual, + "residual": output["input_embeds"], } for block in self.decoder: From 7902959dbf4e6f28cd8fbf87c881904408e6789c Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Sat, 3 Feb 2024 13:16:50 +0000 Subject: [PATCH 13/57] fix(mamba-slow): transpose embedding --- src/nanotron/models/mamba_slow/mamba.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 55419fd4..3086dbac 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -470,7 +470,8 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ store["past_length"] = past_length + cumsum_mask[:, -1] # Format input in `[seq_length, batch_size]` to support high TP with low batch_size - input_ids = input_ids.transpose(0, 1) + #NOTE(fmom): undo transpose for now since Mamba is not using TP + # input_ids = input_ids.transpose(0, 1) input_embeds = self.token_embedding(input_ids) return {"input_embeds": input_embeds} @@ -704,9 +705,16 
@@ def forward( ) -> Dict[str, torch.Tensor]: # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + + #NOTE(fmom): undo transpose for now since Mamba is not using TP + # loss = sharded_cross_entropy( + # sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + # ).transpose(0, 1) + loss = sharded_cross_entropy( - sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float - ).transpose(0, 1) + sharded_logits, label_ids, group=self.tp_pg, dtype=torch.float + ) + # TODO @thomasw21: It's unclear what kind of normalization we want to do. loss = masked_mean(loss, label_mask, dtype=torch.float) # I think indexing causes a sync we don't actually want From b2ca553e86e8eaabf5aba8163292c63289498b0b Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 8 Feb 2024 10:16:27 +0000 Subject: [PATCH 14/57] update(mamba-slow) --- src/nanotron/models/mamba_slow/mamba.py | 154 +++++++++++++++++------- 1 file changed, 108 insertions(+), 46 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 3086dbac..b3678166 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -69,6 +69,7 @@ import torch.nn.functional as F from torch import Tensor +from nanotron.utils import init_method_normal, scaled_init_method_normal from einops import rearrange, repeat from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn @@ -76,7 +77,7 @@ try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: - causal_conv1d_fn, causal_conv1d_update = None + causal_conv1d_fn, causal_conv1d_update = None, None try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update @@ -88,6 +89,7 @@ except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None +import lovely_tensors as lt; lt.monkey_patch() logger = logging.get_logger(__name__) @@ -96,6 +98,8 @@ class Mamba(nn.Module): def __init__( self, d_model, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, d_state=16, d_conv=4, expand=2, @@ -123,8 +127,24 @@ def __init__( self.use_fast_path = use_fast_path self.layer_idx = layer_idx - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) - + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + # Get current tensor parallel rank + self.tp_pg = tp_pg + self.tp_rank = dist.get_rank(self.tp_pg) + + self.in_proj = TensorParallelColumnLinear( + in_features=self.d_model, + out_features=self.d_inner * 2, + pg=tp_pg, + mode=tp_mode, + bias=bias, + async_communication=False, + contiguous_chunks=None + ) self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, @@ -138,9 +158,16 @@ def __init__( self.activation = "silu" self.act = nn.SiLU() - self.x_proj = nn.Linear( - self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + self.x_proj = TensorParallelRowLinear( + in_features=self.d_inner, + out_features=self.dt_rank + self.d_state * 2, + pg=tp_pg, + mode=tp_mode, + bias=False, + 
async_communication=tp_linear_async_communication, + contiguous_chunks=None ) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -178,7 +205,20 @@ def __init__( self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 self.D._no_weight_decay = True - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.out_proj = TensorParallelRowLinear( + in_features=self.d_inner, + out_features=self.d_model, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=None + ) + + def _split_weight(self, data: torch.Tensor, dim: int) -> torch.Tensor: + chunks = torch.chunk(data, self.tp_pg.size(), dim=dim) + return chunks[self.tp_rank].contiguous() def forward(self, hidden_states, inference_params=None): """ @@ -196,47 +236,58 @@ def forward(self, hidden_states, inference_params=None): return out # We do matmul and transpose BLH -> HBL at the same time - xz = rearrange( - self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), - "d (b l) -> b d l", - l=seqlen, - ) - if self.in_proj.bias is not None: - xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") + xz = self.in_proj(hidden_states).transpose(1, 2) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + A_shard = self._split_weight(A, dim=0) + conv1d_weight_shard = self._split_weight(self.conv1d.weight, dim=0) + conv1d_bias_shard = self._split_weight(self.conv1d.bias, dim=0) + dt_proj_weight_shard = self._split_weight(self.dt_proj.weight, dim=0) + + D_shard = self._split_weight(self.D, dim=0) + dt_proj_bias_shard = self._split_weight(self.dt_proj.bias, dim=0) + # In the backward pass we write dx and dz next to each other to avoid torch.cat if self.use_fast_path and inference_params is None: # Doesn't support outputting the states + raise NotImplementedError("Fast path not implemented, need to add xz.view(...) into mamba_inner_fn") out = mamba_inner_fn( + self.d_inner, + self.tp_pg, xz, - self.conv1d.weight, - self.conv1d.bias, + conv1d_weight_shard, + conv1d_bias_shard, self.x_proj.weight, - self.dt_proj.weight, + dt_proj_weight_shard, self.out_proj.weight, self.out_proj.bias, - A, + A_shard, None, # input-dependent B None, # input-dependent C - self.D.float(), - delta_bias=self.dt_proj.bias.float(), + D_shard.float(), + delta_bias=dt_proj_bias_shard.float(), delta_softplus=True, ) else: - x, z = xz.chunk(2, dim=1) + assert self.d_inner % self.tp_pg.size() == 0 + x, z = xz.view(batch, self.d_inner // self.tp_pg.size(), 2, seqlen).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) if causal_conv1d_fn is None: + # TODO(fmom): do split tp x = self.act(self.conv1d(x)[..., :seqlen]) else: + assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( x=x, - weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, + weight=rearrange(conv1d_weight_shard, "d 1 w -> d w"), + bias=conv1d_bias_shard, activation=self.activation, ) @@ -245,7 +296,7 @@ def forward(self, hidden_states, inference_params=None): # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() + dt = dt_proj_weight_shard @ dt.t() dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() @@ -253,12 +304,12 @@ def forward(self, hidden_states, inference_params=None): y = selective_scan_fn( x, dt, - A, + A_shard, B, C, - self.D.float(), + D_shard.float(), z=z, - delta_bias=self.dt_proj.bias.float(), + delta_bias=dt_proj_bias_shard.float(), delta_softplus=True, return_last_state=ssm_state is not None, ) @@ -376,7 +427,7 @@ def __init__( super().__init__() self.residual_in_fp32 = residual_in_fp32 self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(dim) + self.mixer = mixer_cls(d_model=dim) self.norm = norm_cls(dim) if self.fused_add_norm: assert RMSNorm is not None, "RMSNorm import fails" @@ -418,6 +469,8 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) def create_block( + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, d_model, ssm_cfg=None, norm_epsilon=1e-5, @@ -431,7 +484,7 @@ def create_block( if ssm_cfg is None: ssm_cfg = {} factory_kwargs = {"device": device, "dtype": dtype} - mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + mixer_cls = partial(Mamba, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) norm_cls = partial( nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs ) @@ -489,7 +542,9 @@ def __init__( super().__init__() self.block = create_block( - config.d_model, + parallel_config=parallel_config, + tp_pg=tp_pg, + d_model=config.d_model, ssm_cfg=config.ssm_cfg, norm_epsilon=config.rms_norm_eps, rms_norm=config.rms_norm, @@ -570,7 +625,7 @@ def __init__( p2p=self.p2p, module_builder=RMSNorm, module_kwargs={"hidden_size": config.d_model, "eps": config.rms_norm_eps}, - module_input_keys={"x"}, + module_input_keys={"x", "residual"}, module_output_keys={"hidden_states"}, ) @@ -625,14 +680,32 @@ def forward_with_hidden_states( for block in self.decoder: hidden_encoder_states = block(**hidden_encoder_states) - hidden_states = self.final_layer_norm(x=hidden_encoder_states["hidden_states"])["hidden_states"] + hidden_states = self.final_layer_norm(x=hidden_encoder_states["hidden_states"], residual=hidden_encoder_states["residual"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] - fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] return fp32_sharded_logits, hidden_states + def _print_param_stat(self, param): + print(f"\tmin={param.min().item()}") + 
print(f"\tmean={param.mean().item()}") + print(f"\tmedian={param.median().item()}") + print(f"\tmax={param.max().item()}") + + + def _print_all_param_stats(self, msg): + print(msg) + named_parameters = list(self.named_parameters()) + named_parameters.sort(key=lambda x: x[0]) + for name, param in named_parameters: + print(name) + print(f"\tmin={param.min().item()}") + print(f"\tmean={param.mean().item()}") + print(f"\tmedian={param.median().item()}") + print(f"\tmax={param.max().item()}") + print("================") + def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" @@ -790,6 +863,8 @@ def init_model_randomly(self, init_method, scaled_init_method): # Fix the root_model module_id_to_prefix[id(model)] = "" + #TODO(fmom): clean this + for module_name, module in model.named_modules(): if isinstance(module, TensorParallelColumnLinear): # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 @@ -985,26 +1060,13 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py for name, p in module.named_parameters(): if name in ["out_proj.weight"]: - # get fullname - assert isinstance(p, NanotronParameter) - if p.is_tied: - tied_info = p.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) # We need to reinit p since this code could be called multiple times # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): - p /= math.sqrt(n_residuals_per_layer * n_layer) - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + p /= math.sqrt(n_residuals_per_layer * n_layer) # #TODO(fmom): perform check # assert initialized_parameters == { From 3c4ef42e215cfbcc4ecc57a7d71e67713a5ee165 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 8 Feb 2024 13:06:51 +0000 Subject: [PATCH 15/57] fix: unified random unit --- src/nanotron/models/base.py | 2 +- src/nanotron/trainer.py | 22 ++++++++++++++-------- src/nanotron/utils.py | 6 +++--- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index f2ab8ed5..ba528a68 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -35,7 +35,7 @@ def __init__(self, *args, **kwargs) -> None: self.output_pp_rank: int @abstractmethod - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, init_method, scaled_init_method, **kwargs): ... 
@staticmethod diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 4946b094..df4c5298 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -36,6 +36,7 @@ from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.models.mamba_slow.mamba import MambaForTraining +from brrr.models.mamba_fast.mamba import MambaFastForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -85,6 +86,7 @@ "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, "MambaConfig": MambaForTraining, + "MambaFastConfig": MambaFastForTraining, } @@ -561,12 +563,15 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: root_folder=self.config.model.init_method.path, ) elif isinstance(self.config.model.init_method, MambaInit): - # Initialize model randomly - normalized_model.init_mamba_weights( - n_layer=self.model_config.num_hidden_layers, - initializer_range=self.config.model.init_method.initializer_range, + + normalized_model.init_model_randomly( + init_method=init_method_normal(self.config.model.init_method.initializer_range), + scaled_init_method=scaled_init_method_normal( + sigma=self.config.model.init_method.initializer_range, + num_layers=self.model_config.num_hidden_layers, + scale=self.config.model.init_method.n_residuals_per_layer + ), rescale_prenorm_residual=self.config.model.init_method.rescale_prenorm_residual, - n_residuals_per_layer=self.config.model.init_method.n_residuals_per_layer ) # Synchronize parameters so that the model is consistent # sync all params across dp @@ -587,9 +592,10 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: # Initialize model randomly normalized_model.init_model_randomly( init_method=init_method_normal(self.config.model.init_method.std), - scaled_init_method=scaled_init_method_normal( - self.config.model.init_method.std, self.model_config.num_hidden_layers - ), + scaled_init_method_normal=scaled_init_method_normal( + sigma=self.config.model.init_method.std, + num_layers=self.model_config.num_hidden_layers, + ) ) # Synchronize parameters so that the model is consistent # sync all params across dp diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 4eaf8a9f..6ea115a9 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -131,9 +131,9 @@ def init_(tensor: torch.Tensor): return init_ -def scaled_init_method_normal(sigma: float, num_layers: int) -> Callable[[torch.Tensor], None]: - """Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(2.0 * num_layers) +def scaled_init_method_normal(sigma: float, num_layers: int, scale: int = 2) -> Callable[[torch.Tensor], None]: + """Default: Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(scale * num_layers) def init_(tensor: torch.Tensor): torch.nn.init.normal_(tensor, mean=0.0, std=std) From afcbb10c620f8eb64b8e9e910fe5ac3194619b12 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 8 Feb 2024 16:43:10 +0000 Subject: [PATCH 16/57] feat(optimizer): can now apply weight decay to specifcs parameters --- src/nanotron/helpers.py | 33 ++++++++++++------- src/nanotron/optim/named_optimizer.py | 47 +++++++++++++++------------ 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 
6c76e8b9..ac5c712a 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -148,28 +148,37 @@ def init_optimizer_and_grad_accumulator( module_id_to_prefix[root_model_id] = "" # named parameters - named_parameters = [ - ( - param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - if param.is_tied - else name, - param, - ) - for name, param in normalized_model.named_parameters() - ] + named_parameters = { + "decay": [], + "no_decay": [] + } + + # NOTE(fmom): Separate parameters who have weight decay and those who don't + # (based on _no_weight_decay attribute that is set in init_model_randomly of each model) + for name, param in normalized_model.named_parameters(): + if param.is_tied: + full_name = param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + else: + full_name = name + + if optimizer_args.weight_decay == 0.0 or (hasattr(param, "_no_weight_decay") and param._no_weight_decay): + named_parameters["no_decay"].append((full_name, param)) + else: + named_parameters["decay"].append((full_name, param)) + # Basic optimizer builder def basic_optimizer_builder(named_param_groups): return NamedOptimizer( named_params_or_groups=named_param_groups, + weight_decay=optimizer_args.weight_decay, optimizer_builder=lambda param_groups: AdamW( # pylint: disable=E0601 param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, - weight_decay=optimizer_args.weight_decay, eps=optimizer_args.adam_eps, betas=(optimizer_args.adam_beta1, optimizer_args.adam_beta2), - fused=optimizer_args.torch_adam_is_fused, - ), + fused=optimizer_args.torch_adam_is_fused + ) ) optimizer_builder = basic_optimizer_builder diff --git a/src/nanotron/optim/named_optimizer.py b/src/nanotron/optim/named_optimizer.py index 23614b05..0d1f2188 100644 --- a/src/nanotron/optim/named_optimizer.py +++ b/src/nanotron/optim/named_optimizer.py @@ -11,31 +11,36 @@ class NamedOptimizer(InheritFromOtherOptimizer): def __init__( self, named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]], + weight_decay: float, optimizer_builder: Callable[[Iterable[Dict[str, Any]]], torch.optim.Optimizer], - ): - named_param_groups = list(named_params_or_groups) - if len(named_param_groups) == 0 or not isinstance(named_param_groups[0], dict): - named_param_groups = [{"named_params": named_param_groups}] - - id_to_name = {} - params = [] - for named_param_group in named_param_groups: - assert "named_params" in named_param_group - # Don't need to check that param_groups are overlapping since the optimizer will do it for me. - # https://github.com/pytorch/pytorch/blob/88b3810c94b45f5982df616e2bc4c471d173f491/torch/optim/optimizer.py#L473 - id_to_name.update( - {id(param): name for name, param in named_param_group["named_params"] if id(param) not in id_to_name} - ) - params.append( - { - **{k: v for k, v in named_param_group.items() if k != "named_params"}, - "params": [param for _, param in named_param_group["named_params"]], - } - ) - + ): + id_to_name_decay, id_to_name_no_decay = {}, {} + + # Don't need to check that param_groups are overlapping since the optimizer will do it for me. 
+ # https://github.com/pytorch/pytorch/blob/88b3810c94b45f5982df616e2bc4c471d173f491/torch/optim/optimizer.py#L473 + id_to_name_decay.update( + {id(param): name for name, param in named_params_or_groups["decay"] if id(param) not in id_to_name_decay} + ) + id_to_name_no_decay.update( + {id(param): name for name, param in named_params_or_groups["no_decay"] if id(param) not in id_to_name_no_decay} + ) + + id_to_name = {**id_to_name_decay, **id_to_name_no_decay} name_to_id = {v: k for k, v in id_to_name.items()} assert len(id_to_name) == len(name_to_id) + #TODO(fmom) Pass weight decay value from config here + params = [ + { + "params": [param for _, param in named_params_or_groups["decay"]], + "weight_decay": weight_decay + }, + { + "params": [param for _, param in named_params_or_groups["no_decay"]], + "weight_decay": 0.0 + } + ] + # Sanity check for param_group in params: _params = param_group["params"] From a6a163f53c24e3b19b8167380d21290af1094d59 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 12 Feb 2024 19:07:39 +0000 Subject: [PATCH 17/57] refacto: mamba1 --- src/nanotron/models/mamba_slow/mamba.py | 199 +++++++++--------------- 1 file changed, 75 insertions(+), 124 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index b3678166..763edf3c 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -14,6 +14,7 @@ # limitations under the License. """ PyTorch Mamba model. """ +import os from typing import Dict, Optional, Union import math import torch @@ -249,25 +250,25 @@ def forward(self, hidden_states, inference_params=None): dt_proj_bias_shard = self._split_weight(self.dt_proj.bias, dim=0) # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states + if self.use_fast_path and inference_params is None and os.environ.get("FAST_PATH", "0") == "1": # Doesn't support outputting the states raise NotImplementedError("Fast path not implemented, need to add xz.view(...) into mamba_inner_fn") - out = mamba_inner_fn( - self.d_inner, - self.tp_pg, - xz, - conv1d_weight_shard, - conv1d_bias_shard, - self.x_proj.weight, - dt_proj_weight_shard, - self.out_proj.weight, - self.out_proj.bias, - A_shard, - None, # input-dependent B - None, # input-dependent C - D_shard.float(), - delta_bias=dt_proj_bias_shard.float(), - delta_softplus=True, - ) + # out = mamba_inner_fn( + # self.d_inner, + # self.tp_pg, + # xz, + # conv1d_weight_shard, + # conv1d_bias_shard, + # self.x_proj.weight, + # dt_proj_weight_shard, + # self.out_proj.weight, + # self.out_proj.bias, + # A_shard, + # None, # input-dependent B + # None, # input-dependent C + # D_shard.float(), + # delta_bias=dt_proj_bias_shard.float(), + # delta_softplus=True, + # ) else: assert self.d_inner % self.tp_pg.size() == 0 x, z = xz.view(batch, self.d_inner // self.tp_pg.size(), 2, seqlen).chunk(2, dim=2) @@ -408,96 +409,6 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states ssm_state.zero_() return conv_state, ssm_state -class Block(nn.Module): - def __init__( - self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False - ): - """ - Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" - - This Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA/MLP -> Add. 
- [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Add -> LN -> Mixer, returning both - the hidden_states (output of the mixer) and the residual. - This is purely for performance reasons, as we can fuse add and LayerNorm. - The residual needs to be provided (except for the very first block). - """ - super().__init__() - self.residual_in_fp32 = residual_in_fp32 - self.fused_add_norm = fused_add_norm - self.mixer = mixer_cls(d_model=dim) - self.norm = norm_cls(dim) - if self.fused_add_norm: - assert RMSNorm is not None, "RMSNorm import fails" - assert isinstance( - self.norm, (nn.LayerNorm, RMSNorm) - ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - - def forward( - self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None - ): - r"""Pass the input through the encoder layer. - - Args: - hidden_states: the sequence to the encoder layer (required). - residual: hidden_states = Mixer(LN(residual)) - """ - if not self.fused_add_norm: - # self.layer_idx was assigned when calling create_block - # residual=None happens only at the first block - residual = hidden_states if (self.layer_idx == 0) else hidden_states + residual - hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - else: - fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn - hidden_states, residual = fused_add_norm_fn( - hidden_states, - self.norm.weight, - self.norm.bias, - residual=residual, - prenorm=True, - residual_in_fp32=self.residual_in_fp32, - eps=self.norm.eps, - ) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - return hidden_states, residual - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - -def create_block( - parallel_config: Optional[ParallelismArgs], - tp_pg: dist.ProcessGroup, - d_model, - ssm_cfg=None, - norm_epsilon=1e-5, - rms_norm=False, - residual_in_fp32=False, - fused_add_norm=False, - layer_idx=None, - device=None, - dtype=None, -): - if ssm_cfg is None: - ssm_cfg = {} - factory_kwargs = {"device": device, "dtype": dtype} - mixer_cls = partial(Mamba, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) - norm_cls = partial( - nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs - ) - block = Block( - d_model, - mixer_cls, - norm_cls=norm_cls, - fused_add_norm=fused_add_norm, - residual_in_fp32=residual_in_fp32, - ) - block.layer_idx = layer_idx - return block - class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: MambaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() @@ -537,37 +448,77 @@ def __init__( layer_idx: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + + if config.ssm_cfg is None: + ssm_cfg = {} + else: + ssm_cfg = config.ssm_cfg + + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.fused_add_norm = config.fused_add_norm + + self.mixer = Mamba( + d_model=config.d_model, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx, + **ssm_cfg, + **factory_kwargs + ) + + self.norm = partial( + nn.LayerNorm if not config.rms_norm + else RMSNorm, eps=config.rms_norm_eps, 
**factory_kwargs + )(config.d_model) + + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - super().__init__() - self.block = create_block( - parallel_config=parallel_config, - tp_pg=tp_pg, - d_model=config.d_model, - ssm_cfg=config.ssm_cfg, - norm_epsilon=config.rms_norm_eps, - rms_norm=config.rms_norm, - residual_in_fp32=config.residual_in_fp32, - fused_add_norm=config.fused_add_norm, - layer_idx=layer_idx, - **factory_kwargs, - ) def forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], residual: Optional[Union[torch.Tensor, TensorPointer]], + inference_params=None, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - hidden_states, residual = self.block(hidden_states, residual) + if not self.fused_add_norm: + # self.layer_idx was assigned when calling create_block + # residual=None happens only at the first block + residual = hidden_states if (self.layer_idx == 0) else hidden_states + residual + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + return { "hidden_states": hidden_states, "sequence_mask": sequence_mask, # NOTE(fmom): dunno how to use it for now. Just keep it "residual": residual, } - + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) class MambaModel(nn.Module): def __init__( From ad17887619efc517f7bc7c28055d748ca208da32 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 12 Feb 2024 20:19:56 +0000 Subject: [PATCH 18/57] feat: add selective scan interface --- .../mamba_slow/selective_scan_interface.py | 345 ++++++++++++++++++ 1 file changed, 345 insertions(+) create mode 100644 src/nanotron/models/mamba_slow/selective_scan_interface.py diff --git a/src/nanotron/models/mamba_slow/selective_scan_interface.py b/src/nanotron/models/mamba_slow/selective_scan_interface.py new file mode 100644 index 00000000..b8f14dd0 --- /dev/null +++ b/src/nanotron/models/mamba_slow/selective_scan_interface.py @@ -0,0 +1,345 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. 
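+# NOTE: local copy of mamba_ssm.ops.selective_scan_interface (mamba_ssm==1.1.1) so that
+# mamba_inner_fn can be adapted to the tensor-parallel xz layout in the next commit without
+# patching the installed package.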
+ +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +from einops import rearrange, repeat + +from causal_conv1d import causal_conv1d_fn +import causal_conv1d_cuda +import selective_scan_cuda + + +class SelectiveScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None) + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +class MambaInnerFn(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + """ + xz: (batch, dim, seqlen) + """ + assert checkpoint_lvl in [0, 1] + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + if torch.is_autocast_enabled(): + x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) + if out_proj_bias is not None else None) + if xz.stride(-1) != 1: + xz = xz.contiguous() + conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") + x, z = xz.chunk(2, dim=1) + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + # We're being very careful here about the layout, to avoid extra 
transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None + ctx.is_variable_C = C is None + ctx.B_proj_bias_is_None = B_proj_bias is None + ctx.C_proj_bias_is_None = C_proj_bias is None + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if B.stride(-1) != 1: + B = B.contiguous() + if C is None: # variable C + C = x_dbl[:, -d_state:] # (bl dstate) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if C.stride(-1) != 1: + C = C.contiguous() + if D is not None: + D = D.contiguous() + out, scan_intermediates, out_z = selective_scan_cuda.fwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + ) + ctx.delta_softplus = delta_softplus + ctx.out_proj_bias_is_None = out_proj_bias is None + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + conv1d_out, delta = None, None + ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, + delta_proj_weight, out_proj_weight, conv1d_out, delta, + A, B, C, D, delta_bias, scan_intermediates, out) + return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) + + @staticmethod + @custom_bwd + def backward(ctx, dout): + # dout: (batch, seqlen, dim) + (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + if dout.stride(-1) != 1: + dout = dout.contiguous() + if ctx.checkpoint_lvl == 1: + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), + "d (b l) -> b d l", l = L) + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). 
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen) + dx, dz = dxz.chunk(2, dim=1) + dout = rearrange(dout, "b l e -> e (b l)") + dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) + dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, + ctx.delta_softplus, + True # option to recompute out_z + ) + dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) + dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None + dD = dD if D is not None else None + dx_dbl = torch.empty_like(x_dbl) + dB_proj_bias = None + if ctx.is_variable_B: + if not A.is_complex(): + dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None + dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) + dB = None + dC_proj_bias = None + if ctx.is_variable_C: + if not A.is_complex(): + dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None + dx_dbl[:, -d_state:] = dC # (bl d) + dC = None + ddelta = rearrange(ddelta, "b d l -> d (b l)") + ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) + dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) + dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( + x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True + ) + dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None + dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") + return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, + dout_proj_weight, dout_proj_bias, + dA, dB, dC, dD, + ddelta_bias if delta_bias is not None else None, + dB_proj_bias, dC_proj_bias, None) + + +def mamba_inner_fn( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + + +def mamba_inner_ref( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") + # We're being very careful here about the layout, to avoid extra transposes. 
+ # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() + delta = rearrange(delta, "d (b l) -> b d l", l=L) + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + if C is None: # variable B + C = x_dbl[:, -d_state:] # (bl d) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) + return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) From 6bfd6aaff3266540e22414be042502a8b74eef8f Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 12 Feb 2024 20:20:25 +0000 Subject: [PATCH 19/57] feat: make fast path compatible with TP --- src/nanotron/models/mamba_slow/mamba.py | 42 ++++----- .../mamba_slow/selective_scan_interface.py | 91 +++++++++++++------ 2 files changed, 83 insertions(+), 50 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 763edf3c..96caa1e2 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -73,7 +73,7 @@ from nanotron.utils import init_method_normal, scaled_init_method_normal from einops import rearrange, repeat -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn +from nanotron.models.mamba_slow.selective_scan_interface import selective_scan_fn, mamba_inner_fn try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -251,24 +251,21 @@ def forward(self, hidden_states, inference_params=None): # In the backward pass we write dx and dz next to each other to avoid torch.cat if self.use_fast_path and inference_params is None and os.environ.get("FAST_PATH", "0") == "1": # Doesn't support outputting the states - raise NotImplementedError("Fast path not implemented, need to add xz.view(...) 
into mamba_inner_fn") - # out = mamba_inner_fn( - # self.d_inner, - # self.tp_pg, - # xz, - # conv1d_weight_shard, - # conv1d_bias_shard, - # self.x_proj.weight, - # dt_proj_weight_shard, - # self.out_proj.weight, - # self.out_proj.bias, - # A_shard, - # None, # input-dependent B - # None, # input-dependent C - # D_shard.float(), - # delta_bias=dt_proj_bias_shard.float(), - # delta_softplus=True, - # ) + y = mamba_inner_fn( + d_inner=self.d_inner, + tp_pg=self.tp_pg, + xz=xz, + conv1d_weight=conv1d_weight_shard, + conv1d_bias=conv1d_bias_shard, + x_proj_weight=self.x_proj.weight, + delta_proj_weight=dt_proj_weight_shard, + A=A_shard, + B=None, # input-dependent B + C=None, # input-dependent C + D=D_shard.float(), + delta_bias=dt_proj_bias_shard.float(), + delta_softplus=True, + ) else: assert self.d_inner % self.tp_pg.size() == 0 x, z = xz.view(batch, self.d_inner // self.tp_pg.size(), 2, seqlen).chunk(2, dim=2) @@ -318,7 +315,8 @@ def forward(self, hidden_states, inference_params=None): y, last_state = y ssm_state.copy_(last_state) y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) + + out = self.out_proj(y) return out def step(self, hidden_states, conv_state, ssm_state): @@ -674,7 +672,7 @@ def get_block_compute_costs(self): block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - MambaDecoderLayer: 0, + MambaDecoderLayer: 1, # This is the last lm_head TensorParallelColumnLinear: 0, } @@ -1017,7 +1015,7 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): - p /= math.sqrt(n_residuals_per_layer * n_layer) + p /= math.sqrt(n_residuals_per_layer * n_layer) # #TODO(fmom): perform check # assert initialized_parameters == { diff --git a/src/nanotron/models/mamba_slow/selective_scan_interface.py b/src/nanotron/models/mamba_slow/selective_scan_interface.py index b8f14dd0..3aff7589 100644 --- a/src/nanotron/models/mamba_slow/selective_scan_interface.py +++ b/src/nanotron/models/mamba_slow/selective_scan_interface.py @@ -156,27 +156,30 @@ class MambaInnerFn(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, + def forward(ctx, d_inner, tp_pg, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): """ xz: (batch, dim, seqlen) """ assert checkpoint_lvl in [0, 1] - L = xz.shape[-1] + batch, L = xz.shape[0], xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) if torch.is_autocast_enabled(): x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) - if out_proj_bias is not None else None) + if xz.stride(-1) != 1: xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") - x, z = xz.chunk(2, dim=1) + + # x, z = xz.chunk(2, dim=1) + assert d_inner % tp_pg.size() == 0 + x, z = xz.view(batch, d_inner // tp_pg.size(), 2, L).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = 
causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) # We're being very careful here about the layout, to avoid extra transposes. @@ -218,27 +221,36 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus ) ctx.delta_softplus = delta_softplus - ctx.out_proj_bias_is_None = out_proj_bias is None + # ctx.out_proj_bias_is_None = out_proj_bias is None ctx.checkpoint_lvl = checkpoint_lvl if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass conv1d_out, delta = None, None + + ctx.d_inner = d_inner + ctx.tp_pg = tp_pg + ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, - delta_proj_weight, out_proj_weight, conv1d_out, delta, + delta_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) - return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) + + return rearrange(out_z, "b d l -> b l d") @staticmethod @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) - (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, + (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors - L = xz.shape[-1] + batch, L = xz.shape[0], xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - x, z = xz.chunk(2, dim=1) - if dout.stride(-1) != 1: - dout = dout.contiguous() + + # x, z = xz.chunk(2, dim=1) + assert ctx.d_inner % ctx.tp_pg.size() == 0 + x, z = xz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2) + x = x.squeeze(2) + z = z.squeeze(2) + if ctx.checkpoint_lvl == 1: conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), @@ -246,16 +258,24 @@ def backward(ctx, dout): # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). 
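A quick aside on the xz layout used in both the forward above and the backward below: with a column-parallel `in_proj`, each rank is assumed to hold an interleaved x/z channel shard, which is why the plain `xz.chunk(2, dim=1)` from upstream mamba_ssm is replaced by the `view(batch, d_inner // tp_pg.size(), 2, L).chunk(2, dim=2)` pattern. A minimal sketch of what that view/chunk recovers (the interleaving rationale is my reading of the patch; shapes are hypothetical):

```python
import torch

batch, d_local, L = 2, 4, 8              # hypothetical per-rank sizes
xz = torch.randn(batch, 2 * d_local, L)  # local shard of the in_proj output

# view/chunk as in the patch: adjacent channel pairs are (x_i, z_i)
x, z = xz.view(batch, d_local, 2, L).chunk(2, dim=2)
x, z = x.squeeze(2), z.squeeze(2)

# equivalent to strided indexing over interleaved channels
assert torch.equal(x, xz[:, 0::2])
assert torch.equal(z, xz[:, 1::2])
```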
dxz = torch.empty_like(xz) # (batch, dim, seqlen) - dx, dz = dxz.chunk(2, dim=1) - dout = rearrange(dout, "b l e -> e (b l)") - dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) + + # dx, dz = dxz.chunk(2, dim=1) + assert ctx.d_inner % ctx.tp_pg.size() == 0 + dx, dz = dxz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2) + dx = dx.squeeze(2) + dz = dz.squeeze(2) + + dout = rearrange(dout, "b l e -> b e l") + + if dout.stride(-1) != 1: + dout = dout.contiguous() + dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, + conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz, ctx.delta_softplus, True # option to recompute out_z ) - dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) - dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None + dD = dD if D is not None else None dx_dbl = torch.empty_like(x_dbl) dB_proj_bias = None @@ -290,22 +310,37 @@ def backward(ctx, dout): ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, - dout_proj_weight, dout_proj_bias, + return (None, # d_inner + None, # tp_pg + dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, dB_proj_bias, dC_proj_bias, None) def mamba_inner_fn( - xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, + d_inner, tp_pg, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): - return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + + return MambaInnerFn.apply( + d_inner, + tp_pg, + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + A, + B, + C, + D, + delta_bias, + B_proj_bias, + C_proj_bias, + delta_softplus + ) def mamba_inner_ref( From f0e219b3bef38253b2baab88804adcc5575a005f Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 13 Feb 2024 08:55:43 +0000 Subject: [PATCH 20/57] refacto: no split weight --- src/nanotron/models/mamba_slow/mamba.py | 58 ++++++++++--------------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 96caa1e2..5d66dfac 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -53,6 +53,7 @@ from nanotron.utils import checkpoint_method from nanotron.config.models_config import MambaConfig +#TODO(fmom): Need to clean imports #NOTE(fmom): mamba_ssm=1.1.1 # from mamba_ssm.models.mixer_seq_simple import create_block, Mamba, _init_weights @@ -146,12 +147,15 @@ def __init__( async_communication=False, contiguous_chunks=None ) + + assert self.d_inner % self.tp_pg.size() == 0 + self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, + in_channels=self.d_inner // self.tp_pg.size(), + out_channels=self.d_inner // self.tp_pg.size(), bias=conv_bias, kernel_size=d_conv, - groups=self.d_inner, + groups=self.d_inner // self.tp_pg.size(), padding=d_conv - 1, 
**factory_kwargs, ) @@ -169,7 +173,7 @@ def __init__( contiguous_chunks=None ) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner // self.tp_pg.size(), bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = self.dt_rank**-0.5 * dt_scale @@ -182,7 +186,7 @@ def __init__( # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( - torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + torch.rand(self.d_inner // self.tp_pg.size(), **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -196,14 +200,14 @@ def __init__( A = repeat( torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), "n -> d n", - d=self.d_inner, + d=self.d_inner // self.tp_pg.size(), ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + self.D = nn.Parameter(torch.ones(self.d_inner // self.tp_pg.size(), device=device)) # Keep in fp32 self.D._no_weight_decay = True # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) @@ -216,10 +220,6 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=None ) - - def _split_weight(self, data: torch.Tensor, dim: int) -> torch.Tensor: - chunks = torch.chunk(data, self.tp_pg.size(), dim=dim) - return chunks[self.tp_rank].contiguous() def forward(self, hidden_states, inference_params=None): """ @@ -238,32 +238,23 @@ def forward(self, hidden_states, inference_params=None): # We do matmul and transpose BLH -> HBL at the same time xz = self.in_proj(hidden_states).transpose(1, 2) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - A_shard = self._split_weight(A, dim=0) - conv1d_weight_shard = self._split_weight(self.conv1d.weight, dim=0) - conv1d_bias_shard = self._split_weight(self.conv1d.bias, dim=0) - dt_proj_weight_shard = self._split_weight(self.dt_proj.weight, dim=0) - - D_shard = self._split_weight(self.D, dim=0) - dt_proj_bias_shard = self._split_weight(self.dt_proj.bias, dim=0) - # In the backward pass we write dx and dz next to each other to avoid torch.cat if self.use_fast_path and inference_params is None and os.environ.get("FAST_PATH", "0") == "1": # Doesn't support outputting the states y = mamba_inner_fn( d_inner=self.d_inner, tp_pg=self.tp_pg, xz=xz, - conv1d_weight=conv1d_weight_shard, - conv1d_bias=conv1d_bias_shard, + conv1d_weight=self.conv1d.weight, + conv1d_bias=self.conv1d.bias, x_proj_weight=self.x_proj.weight, - delta_proj_weight=dt_proj_weight_shard, - A=A_shard, + delta_proj_weight=self.dt_proj.weight, + A=A, B=None, # input-dependent B C=None, # input-dependent C - D=D_shard.float(), - delta_bias=dt_proj_bias_shard.float(), + D=self.D.float(), + delta_bias=self.dt_proj.bias.float(), delta_softplus=True, ) else: @@ -277,15 +268,14 @@ def forward(self, hidden_states, inference_params=None): # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
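As a side note on the `F.pad` call that follows: `torch.nn.functional.pad` with a negative pad amount drops leading positions, which is what lets the single call below both zero-pad short sequences and keep only the last `d_conv` positions of long ones. A small, self-contained illustration (values are arbitrary):

```python
import torch
import torch.nn.functional as F

d_conv = 4
short = torch.arange(3.0).view(1, 1, 3)  # seqlen < d_conv
long = torch.arange(6.0).view(1, 1, 6)   # seqlen > d_conv

print(F.pad(short, (d_conv - short.shape[-1], 0)))  # left zero-pad   -> [0., 0., 1., 2.]
print(F.pad(long, (d_conv - long.shape[-1], 0)))    # negative pad truncates -> [2., 3., 4., 5.]
```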
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) if causal_conv1d_fn is None: - # TODO(fmom): do split tp x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( x=x, - weight=rearrange(conv1d_weight_shard, "d 1 w -> d w"), - bias=conv1d_bias_shard, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, activation=self.activation, ) @@ -294,7 +284,7 @@ def forward(self, hidden_states, inference_params=None): # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = dt_proj_weight_shard @ dt.t() + dt = self.dt_proj.weight @ dt.t() dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() @@ -302,12 +292,12 @@ def forward(self, hidden_states, inference_params=None): y = selective_scan_fn( x, dt, - A_shard, + A, B, C, - D_shard.float(), + self.D.float(), z=z, - delta_bias=dt_proj_bias_shard.float(), + delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, ) @@ -432,7 +422,6 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ store["past_length"] = past_length + cumsum_mask[:, -1] # Format input in `[seq_length, batch_size]` to support high TP with low batch_size - #NOTE(fmom): undo transpose for now since Mamba is not using TP # input_ids = input_ids.transpose(0, 1) input_embeds = self.token_embedding(input_ids) return {"input_embeds": input_embeds} @@ -636,6 +625,7 @@ def forward_with_hidden_states( return fp32_sharded_logits, hidden_states + #TODO(fmom): clean this def _print_param_stat(self, param): print(f"\tmin={param.min().item()}") print(f"\tmean={param.mean().item()}") From ae312fe9e65aab73e96efe55996a417c93f48ec7 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Tue, 13 Feb 2024 10:45:19 +0000 Subject: [PATCH 21/57] refacto: cleaner init weights --- src/nanotron/models/mamba_slow/mamba.py | 302 ++++++++---------------- 1 file changed, 97 insertions(+), 205 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 5d66dfac..0ac3ded7 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -18,6 +18,7 @@ from typing import Dict, Optional, Union import math import torch +from torch.nn import init from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( flash_attn_varlen_func, @@ -143,7 +144,7 @@ def __init__( out_features=self.d_inner * 2, pg=tp_pg, mode=tp_mode, - bias=bias, + bias=False, async_communication=False, contiguous_chunks=None ) @@ -794,142 +795,7 @@ def init_model_randomly(self, init_method, scaled_init_method): Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ - model = self - initialized_parameters = set() - - # Handle tensor parallelism - module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} - # Fix the root_model - module_id_to_prefix[id(model)] = "" - - #TODO(fmom): clean this - - for module_name, module in model.named_modules(): - if isinstance(module, TensorParallelColumnLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelRowLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, RMSNorm): - assert {"weight"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in 
initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.weight" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - - assert initialized_parameters == { - param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - if param.is_tied - else name - for name, param in model.named_parameters() - }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + raise NotImplementedError("init_model_randomly not implemented for MambaForTraining") @torch.no_grad() def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residual, n_residuals_per_layer): @@ -942,78 +808,104 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua # Fix the root_model module_id_to_prefix[id(model)] = "" - #TODO(fmom): port initiliaztion from mamba_ssm.mamba_simple.Mamba to here - for module_name, module in model.named_modules(): - if isinstance(module, nn.Linear): - - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - pass - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) + def _get_module(module, path): + # Get module name out of param_name + attrs = path.split('.') + for attr in attrs: + # Case: model.decoder.0.pp_block.mixer.in_proj + # Need to convert to model.decoder[0].pp_block.mixer.in_proj + if '[' in attr and ']' in attr: + # Split the attribute and index part + attr, index = attr[:-1].split('[') + # Convert index to integer + index = int(index) + # First get the attribute until the 
index + module = getattr(module, attr) + # Then use the index to get the final module + module = module[index] else: - full_param_name = f"{module_name}.weight" - - if full_param_name in initialized_parameters: - # Already initialized - continue + # Regular attribute access + module = getattr(module, attr) + return module + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit('.', 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + if full_param_name in initialized_parameters: + # Already initialized + continue + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + module = _get_module(model, module_name) + print(module_name, param_name, module) + + if isinstance(module, TensorParallelColumnLinear) or isinstance(module, TensorParallelRowLinear): + if "weight" == param_name: + init.kaiming_uniform_(param, a=math.sqrt(5)) + elif "bias" == param_name: + raise ValueError("We don't use bias for TensorParallelColumnLinear and TensorParallelRow") + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, nn.Linear): + fan_in = None + + if "weight" == param_name: + fan_in, _ = init._calculate_fan_in_and_fan_out(param) + init.kaiming_uniform_(param, a=math.sqrt(5)) + elif "bias" == param_name: + bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + init.uniform_(param, -bound, bound) + else: + raise ValueError(f"Who the fuck is {param_name}?") + + if rescale_prenorm_residual and param_name in ["out_proj.weight"]: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + elif isinstance(module, TensorParallelEmbedding): nn.init.normal_(module.weight, std=initializer_range) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - - if rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - with torch.no_grad(): - p /= math.sqrt(n_residuals_per_layer * n_layer) - - # #TODO(fmom): perform check - # assert initialized_parameters == { - # param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - # if param.is_tied - # else name - # for name, param in model.named_parameters() - # }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + elif isinstance(module, RMSNorm) or isinstance(module, nn.LayerNorm): + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + param.fill_(1) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, Mamba): + # NOTE(fmom): nn.Parameter are initialized in Mamba __init__ + # In Mamba, only those 3 parameters don't have weight decay. + if param_name in ["dt_bias", "A_log", "D"]: + param._no_weight_decay = True + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" @staticmethod def get_embeddings_lm_head_tied_names(): From ff26ff50f8e9057e544a07053679f6f6eb9f10f1 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 14 Feb 2024 10:48:21 +0000 Subject: [PATCH 22/57] refacto: init weights --- src/nanotron/models/base.py | 2 +- src/nanotron/models/llama.py | 174 +++++++----------------- src/nanotron/models/mamba_slow/mamba.py | 87 ++++++------ src/nanotron/trainer.py | 21 +-- 4 files changed, 96 insertions(+), 188 deletions(-) diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index ba528a68..357800e1 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -35,7 +35,7 @@ def __init__(self, *args, **kwargs) -> None: self.output_pp_rank: int @abstractmethod - def init_model_randomly(self, init_method, scaled_init_method, **kwargs): + def init_model_randomly(self, config): ... @staticmethod diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 3148bd67..6b92e643 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -15,7 +15,7 @@ """ PyTorch LLaMa model. """ from typing import Dict, Optional, Union - +import math import torch from flash_attn import bert_padding from flash_attn.flash_attn_interface import ( @@ -882,14 +882,10 @@ def forward( label_mask=label_mask, )["loss"] return {"loss": loss} - + @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): """Initialize model parameters randomly. 
- Args: - init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ - scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ - Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ @@ -900,126 +896,60 @@ def init_model_randomly(self, init_method, scaled_init_method): # Fix the root_model module_id_to_prefix[id(model)] = "" - for module_name, module in model.named_modules(): + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + #TODO(fmom): Make sure it works with PP=2 + module_name, param_name = param_name.rsplit('.', 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + if isinstance(module, TensorParallelColumnLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TensorParallelRowLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - 
module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + torch.nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TritonRMSNorm): - assert {"weight"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() else: - full_param_name = f"{module_name}.weight" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + raise Exception(f"Parameter {full_param_name} was not intialized") + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + assert initialized_parameters == { param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) if param.is_tied diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 0ac3ded7..b8ed2401 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -786,20 +786,7 @@ def forward( return {"loss": loss} @torch.no_grad() - 
def init_model_randomly(self, init_method, scaled_init_method): - """Initialize model parameters randomly. - Args: - init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ - scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ - - Note: - Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` - """ - raise NotImplementedError("init_model_randomly not implemented for MambaForTraining") - - @torch.no_grad() - def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residual, n_residuals_per_layer): - + def init_model_randomly(self, config): model = self initialized_parameters = set() @@ -808,26 +795,11 @@ def init_mamba_weights(self, n_layer, initializer_range, rescale_prenorm_residua # Fix the root_model module_id_to_prefix[id(model)] = "" + initializer_range = config.model.init_method.initializer_range + n_residuals_per_layer = config.model.init_method.n_residuals_per_layer + num_hidden_layers = config.model.model_config.num_hidden_layers + rescale_prenorm_residual = config.model.init_method.rescale_prenorm_residual - def _get_module(module, path): - # Get module name out of param_name - attrs = path.split('.') - for attr in attrs: - # Case: model.decoder.0.pp_block.mixer.in_proj - # Need to convert to model.decoder[0].pp_block.mixer.in_proj - if '[' in attr and ']' in attr: - # Split the attribute and index part - attr, index = attr[:-1].split('[') - # Convert index to integer - index = int(index) - # First get the attribute until the index - module = getattr(module, attr) - # Then use the index to get the final module - module = module[index] - else: - # Regular attribute access - module = getattr(module, attr) - return module for param_name, param in model.named_parameters(): assert isinstance(param, NanotronParameter) @@ -845,33 +817,44 @@ def _get_module(module, path): if full_param_name in initialized_parameters: # Already initialized continue - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - - module = _get_module(model, module_name) - print(module_name, param_name, module) + + module = model.get_submodule(module_name) if isinstance(module, TensorParallelColumnLinear) or isinstance(module, TensorParallelRowLinear): if "weight" == param_name: - init.kaiming_uniform_(param, a=math.sqrt(5)) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) elif "bias" == param_name: raise ValueError("We don't use bias for TensorParallelColumnLinear and TensorParallelRow") else: raise ValueError(f"Who the fuck is {param_name}?") + + if rescale_prenorm_residual and full_param_name.endswith("out_proj.weight"): + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers) + elif isinstance(module, nn.Linear): fan_in = None if "weight" == param_name: - fan_in, _ = init._calculate_fan_in_and_fan_out(param) - init.kaiming_uniform_(param, a=math.sqrt(5)) + fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) elif "bias" == param_name: bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 - init.uniform_(param, -bound, bound) + init.uniform_(module.bias, -bound, bound) else: raise ValueError(f"Who the fuck is {param_name}?") - if rescale_prenorm_residual and param_name in ["out_proj.weight"]: + if rescale_prenorm_residual and full_param_name.endswith("out_proj.weight"): # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. @@ -883,15 +866,18 @@ def _get_module(module, path): # We need to reinit p since this code could be called multiple times # Having just p *= scale would repeatedly scale it down with torch.no_grad(): - p /= math.sqrt(n_residuals_per_layer * n_layer) + module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers) + elif isinstance(module, nn.Conv1d): + print("TODO: handle covn1d. For now, it is initialiazed in Mamba constructor") + pass elif isinstance(module, TensorParallelEmbedding): nn.init.normal_(module.weight, std=initializer_range) elif isinstance(module, RMSNorm) or isinstance(module, nn.LayerNorm): if "weight" == param_name: # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) + module.weight.fill_(1) elif "bias" == param_name: - param.zero_() + module.bias.zero_() else: raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, Mamba): @@ -899,7 +885,12 @@ def _get_module(module, path): # In Mamba, only those 3 parameters don't have weight decay. 
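For context on the `_no_weight_decay` flag set just below (and serialized in a later patch in this series): such a flag is commonly consumed by splitting parameters into decay / no-decay optimizer groups. This is only a hedged sketch of that generic pattern, not nanotron's actual optimizer-building code:

```python
import torch

def build_param_groups(model: torch.nn.Module, weight_decay: float):
    # Hypothetical helper: route any parameter tagged `_no_weight_decay`
    # (e.g. Mamba's A_log and D) into a zero-weight-decay group.
    decay, no_decay = [], []
    for param in model.parameters():
        (no_decay if getattr(param, "_no_weight_decay", False) else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```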
if param_name in ["dt_bias", "A_log", "D"]: param._no_weight_decay = True - + else: + raise Exception(f"Parameter {full_param_name} was not intialized") + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + assert initialized_parameters == { param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) if param.is_tied diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 86e72156..a9607ecd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -583,15 +583,7 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: ) elif isinstance(self.config.model.init_method, MambaInit): - normalized_model.init_model_randomly( - init_method=init_method_normal(self.config.model.init_method.initializer_range), - scaled_init_method=scaled_init_method_normal( - sigma=self.config.model.init_method.initializer_range, - num_layers=self.model_config.num_hidden_layers, - scale=self.config.model.init_method.n_residuals_per_layer - ), - rescale_prenorm_residual=self.config.model.init_method.rescale_prenorm_residual, - ) + normalized_model.init_model_randomly(config=self.config) # Synchronize parameters so that the model is consistent # sync all params across dp for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): @@ -608,14 +600,9 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: group = self.parallel_context.world_ranks_to_pg[group_ranks] dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) elif isinstance(self.config.model.init_method, RandomInit): - # Initialize model randomly - normalized_model.init_model_randomly( - init_method=init_method_normal(self.config.model.init_method.std), - scaled_init_method_normal=scaled_init_method_normal( - sigma=self.config.model.init_method.std, - num_layers=self.model_config.num_hidden_layers, - ) - ) + + normalized_model.init_model_randomly(config=self.config) + # Synchronize parameters so that the model is consistent # sync all params across dp for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): From 0ab7c9b74c067f22a20524171c179fb5ec3eba4d Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 16 Feb 2024 17:22:28 +0000 Subject: [PATCH 23/57] feat: save/load _no_weight_decay attribute for Mamba model --- src/nanotron/serialize/weights.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index aab0d531..27a36146 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -5,6 +5,8 @@ import torch from packaging.version import Version from safetensors.torch import safe_open, save_file +from safetensors import SafetensorError + from torch import nn from tqdm import tqdm @@ -88,7 +90,13 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde ) path.parent.mkdir(exist_ok=True, parents=True) try: - save_file(tensors={"data": param_or_buffer}, filename=path, metadata=metadata) + tensors = {"data": param_or_buffer} + + # Mamba has some parameters that should not be weight decayed + if hasattr(model.get_parameter(name), "_no_weight_decay"): + tensors.update({"_no_weight_decay": torch.tensor(model.get_parameter(name)._no_weight_decay)}) + + save_file(tensors=tensors, filename=path, metadata=metadata) except Exception as e: log_rank( f"Error saving {path} with {metadata}", @@ -251,6 +259,13 @@ def load_weights( with 
safe_open(path, framework="pt", device=str(param.device)) as fi: # TODO @thomasw21: Choose only a slice if we switch the TP topology param_or_buffer[:] = fi.get_tensor("data") + + # Only Mamba params has this attribute + try: + param._no_weight_decay = fi.get_tensor("_no_weight_decay") + except SafetensorError: + pass + elif not path.parent.exists(): raise ValueError( f"Checkpoint is empty or checkpoint structure is not matching the model architecture." From 3caf7fef8c571ed038d3b260b34d2817edeb3628 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 28 Feb 2024 10:18:08 +0000 Subject: [PATCH 24/57] clean import --- src/nanotron/models/mamba_slow/mamba.py | 27 ++----------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index b8ed2401..748ee766 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -16,23 +16,13 @@ """ import os from typing import Dict, Optional, Union -import math -import torch from torch.nn import init -from flash_attn import bert_padding -from flash_attn.flash_attn_interface import ( - flash_attn_varlen_func, - flash_attn_with_kvcache, -) -from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding -from torch import nn -from transformers.activations import ACT2FN from functools import partial from nanotron import distributed as dist from nanotron import logging from nanotron.config.utils_config import cast_str_to_torch_dtype -from nanotron.config import ParallelismArgs, RecomputeGranularity +from nanotron.config import ParallelismArgs from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.generation.generate_store import AttachableStore @@ -51,19 +41,7 @@ TensorParallelRowLinear, ) from nanotron.random import RandomStates -from nanotron.utils import checkpoint_method from nanotron.config.models_config import MambaConfig - -#TODO(fmom): Need to clean imports -#NOTE(fmom): mamba_ssm=1.1.1 -# from mamba_ssm.models.mixer_seq_simple import create_block, Mamba, _init_weights - -# try: -# from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn -# except ImportError: -# RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None - -# Copyright (c) 2023, Albert Gu, Tri Dao. 
import math from typing import Optional @@ -72,7 +50,6 @@ import torch.nn.functional as F from torch import Tensor -from nanotron.utils import init_method_normal, scaled_init_method_normal from einops import rearrange, repeat from nanotron.models.mamba_slow.selective_scan_interface import selective_scan_fn, mamba_inner_fn @@ -92,7 +69,7 @@ except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None -import lovely_tensors as lt; lt.monkey_patch() +# import lovely_tensors as lt; lt.monkey_patch() logger = logging.get_logger(__name__) From ec6f9541884817c50c0622daa39428e6af5a73c1 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 28 Feb 2024 10:18:48 +0000 Subject: [PATCH 25/57] clean utils in mamba --- src/nanotron/models/mamba_slow/mamba.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 748ee766..242fcacd 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -603,26 +603,6 @@ def forward_with_hidden_states( return fp32_sharded_logits, hidden_states - #TODO(fmom): clean this - def _print_param_stat(self, param): - print(f"\tmin={param.min().item()}") - print(f"\tmean={param.mean().item()}") - print(f"\tmedian={param.median().item()}") - print(f"\tmax={param.max().item()}") - - - def _print_all_param_stats(self, msg): - print(msg) - named_parameters = list(self.named_parameters()) - named_parameters.sort(key=lambda x: x[0]) - for name, param in named_parameters: - print(name) - print(f"\tmin={param.min().item()}") - print(f"\tmean={param.mean().item()}") - print(f"\tmedian={param.median().item()}") - print(f"\tmax={param.max().item()}") - print("================") - def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" From 177342052fd63d53dce30b7130a3495f2689b270 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 28 Feb 2024 10:19:17 +0000 Subject: [PATCH 26/57] cleaning initialization for mamba --- src/nanotron/models/mamba_slow/mamba.py | 67 ++++++++++++++++--------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba_slow/mamba.py index 242fcacd..8eed7d2e 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba_slow/mamba.py @@ -154,13 +154,14 @@ def __init__( self.dt_proj = nn.Linear(self.dt_rank, self.d_inner // self.tp_pg.size(), bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization - dt_init_std = self.dt_rank**-0.5 * dt_scale - if dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) - else: - raise NotImplementedError + # Perform in `def init_model_randomly` + # dt_init_std = self.dt_rank**-0.5 * dt_scale + # if dt_init == "constant": + # nn.init.constant_(self.dt_proj.weight, dt_init_std) + # elif dt_init == "random": + # nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + # else: + # raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( @@ -757,6 +758,11 @@ def init_model_randomly(self, config): num_hidden_layers = config.model.model_config.num_hidden_layers rescale_prenorm_residual = config.model.init_method.rescale_prenorm_residual + if 
config.model.model_config.ssm_cfg is not None: + dt_init = config.model.model_config.ssm_cfg["dt_init"] + d_model = config.model.model_config.ssm_cfg["d_model"] + dt_rank = config.model.model_config.ssm_cfg["dt_rank"] + dt_scale = config.model.model_config.ssm_cfg["dt_scale"] for param_name, param in model.named_parameters(): assert isinstance(param, NanotronParameter) @@ -798,7 +804,18 @@ def init_model_randomly(self, config): # Having just p *= scale would repeatedly scale it down with torch.no_grad(): module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers) - + + elif isinstance(module, nn.Conv1d): + fan_in = None + if "weight" == param_name: + fan_in, _ = init._calculate_fan_in_and_fan_out(param) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + elif "bias" == param_name: + bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + init.uniform_(module.bias, -bound, bound) + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, nn.Linear): fan_in = None @@ -811,24 +828,24 @@ def init_model_randomly(self, config): else: raise ValueError(f"Who the fuck is {param_name}?") - if rescale_prenorm_residual and full_param_name.endswith("out_proj.weight"): - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - with torch.no_grad(): - module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers) - elif isinstance(module, nn.Conv1d): - print("TODO: handle covn1d. For now, it is initialiazed in Mamba constructor") - pass + + if config.model.model_config.ssm_cfg is not None: + + if dt_rank == "auto": + dt_init_std = math.ceil(d_model / 16)**-0.5 * dt_scale + else: + dt_init_std = dt_rank**-0.5 * dt_scale + + if dt_init == "constant": + nn.init.constant_(module.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(module.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + elif isinstance(module, TensorParallelEmbedding): nn.init.normal_(module.weight, std=initializer_range) + elif isinstance(module, RMSNorm) or isinstance(module, nn.LayerNorm): if "weight" == param_name: # TODO @thomasw21: Sometimes we actually want 0 @@ -837,11 +854,13 @@ def init_model_randomly(self, config): module.bias.zero_() else: raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, Mamba): # NOTE(fmom): nn.Parameter are initialized in Mamba __init__ # In Mamba, only those 3 parameters don't have weight decay. 
if param_name in ["dt_bias", "A_log", "D"]: param._no_weight_decay = True + else: raise Exception(f"Parameter {full_param_name} was not intialized") From 6fc0d9ab1d924e0a588b1a343dd63eebdd620cf9 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 28 Feb 2024 10:37:05 +0000 Subject: [PATCH 27/57] no sync conv and rmsnorm for mamba --- src/nanotron/serialize/main.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index aeb0c027..ecae4597 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import Optional +import os import torch from torch import nn @@ -131,9 +132,12 @@ def save( tied_info = tied_param.get_tied_info() group_ranks = tied_info.global_ranks group = parallel_context.world_ranks_to_pg[group_ranks] - assert_tensor_synced_across_pg( - tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" - ) + + # Conv1d and RMSNorm don't need to be synced for mamba + if not hasattr(config.model.model_config, "is_mamba_config"): + assert_tensor_synced_across_pg( + tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" + ) if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): check_optim_state_in_sync(optimizer, parallel_context.dp_pg) @@ -178,13 +182,16 @@ def save( src=get_global_rank(group=group, group_rank=reference_rank), group=group, ) - torch.testing.assert_close( - tensor, - reference_tensor, - atol=0, - rtol=0, - msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}", - ) + + + if not hasattr(config.model.model_config, "is_mamba_config"): + torch.testing.assert_close( + tensor, + reference_tensor, + atol=0, + rtol=0, + msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. 
Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}", + ) ### dist.barrier(parallel_context.world_pg) From fa6953d803b843e090a7adf8b5fc9de99af6a488 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 28 Feb 2024 15:06:37 +0000 Subject: [PATCH 28/57] clean run_train --- run_train.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/run_train.py b/run_train.py index 5e7ac702..0c5482df 100644 --- a/run_train.py +++ b/run_train.py @@ -135,9 +135,5 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) dataloader = get_dataloader(trainer) - - print("HELLOOOOOOOOO") - print(trainer.model) - # Train trainer.train(dataloader) From 20dc9ec7f448c9f8b8cc413605c8406bd66e47a3 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 28 Feb 2024 15:07:57 +0000 Subject: [PATCH 29/57] clean run_train --- run_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/run_train.py b/run_train.py index 0c5482df..7874e424 100644 --- a/run_train.py +++ b/run_train.py @@ -135,5 +135,6 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) dataloader = get_dataloader(trainer) + # Train trainer.train(dataloader) From beb73ee09248398fbea8b8358534c05123bbcc96 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 12:31:03 +0000 Subject: [PATCH 30/57] rename folder --- .../models/{mamba_slow => mamba}/mamba.py | 27 +++++++++++++++++-- .../selective_scan_interface.py | 0 src/nanotron/trainer.py | 2 +- 3 files changed, 26 insertions(+), 3 deletions(-) rename src/nanotron/models/{mamba_slow => mamba}/mamba.py (97%) rename src/nanotron/models/{mamba_slow => mamba}/selective_scan_interface.py (100%) diff --git a/src/nanotron/models/mamba_slow/mamba.py b/src/nanotron/models/mamba/mamba.py similarity index 97% rename from src/nanotron/models/mamba_slow/mamba.py rename to src/nanotron/models/mamba/mamba.py index 8eed7d2e..0dad7521 100644 --- a/src/nanotron/models/mamba_slow/mamba.py +++ b/src/nanotron/models/mamba/mamba.py @@ -52,7 +52,7 @@ from einops import rearrange, repeat -from nanotron.models.mamba_slow.selective_scan_interface import selective_scan_fn, mamba_inner_fn +from nanotron.models.mamba.selective_scan_interface import selective_scan_fn, mamba_inner_fn try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -486,7 +486,30 @@ def forward( def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) -class MambaModel(nn.Module): +class GenerationMixin: + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + raise NotImplementedError + + def generate( + self, + input_ids, + max_length, + top_k=1, + top_p=0.0, + min_p=0.0, + temperature=1.0, + return_dict_in_generate=False, + output_scores=False, + **kwargs, + ): + output = decode( + input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs + ) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences + +class MambaModel(nn.Module, GenerationMixin): def __init__( self, config: MambaConfig, diff --git a/src/nanotron/models/mamba_slow/selective_scan_interface.py b/src/nanotron/models/mamba/selective_scan_interface.py similarity index 100% rename from src/nanotron/models/mamba_slow/selective_scan_interface.py rename to src/nanotron/models/mamba/selective_scan_interface.py diff --git a/src/nanotron/trainer.py 
b/src/nanotron/trainer.py index f5e5dd0f..74ce1940 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -52,7 +52,7 @@ from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining -from nanotron.models.mamba_slow.mamba import MambaForTraining +from nanotron.models.mamba.mamba import MambaForTraining from brrr.models.mamba_fast.mamba import MambaFastForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext From d6ae6042b98fa55522d56850e877c82a684eabdb Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 12:54:47 +0000 Subject: [PATCH 31/57] add assert only ALL_REDUCE mode --- src/nanotron/models/mamba/mamba.py | 31 +++++++----------------------- 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/src/nanotron/models/mamba/mamba.py b/src/nanotron/models/mamba/mamba.py index 0dad7521..57d1d423 100644 --- a/src/nanotron/models/mamba/mamba.py +++ b/src/nanotron/models/mamba/mamba.py @@ -108,6 +108,8 @@ def __init__( self.layer_idx = layer_idx tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + assert tp_mode == TensorParallelLinearMode.REDUCE_SCATTER or not parallel_config.tp_linear_async_communication; "Only ALL_REDUCE and tp_linear_async_communication=False are supported" + tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) @@ -205,6 +207,10 @@ def forward(self, hidden_states, inference_params=None): hidden_states: (B, L, D) Returns: same shape as hidden_states """ + + if inference_params is not None: + raise NotImplementedError("Inference params not tested yet.") + batch, seqlen, dim = hidden_states.shape conv_state, ssm_state = None, None @@ -486,30 +492,7 @@ def forward( def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) -class GenerationMixin: - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - raise NotImplementedError - - def generate( - self, - input_ids, - max_length, - top_k=1, - top_p=0.0, - min_p=0.0, - temperature=1.0, - return_dict_in_generate=False, - output_scores=False, - **kwargs, - ): - output = decode( - input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs - ) - if not output_scores: - output.scores = None - return output if return_dict_in_generate else output.sequences - -class MambaModel(nn.Module, GenerationMixin): +class MambaModel(nn.Module): def __init__( self, config: MambaConfig, From 5d161b195f0a2e0d3e7b6107dab47e2d0021bf72 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 13:07:28 +0000 Subject: [PATCH 32/57] add typing to mamba --- src/nanotron/models/mamba/mamba.py | 40 +++++++++++++++--------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/nanotron/models/mamba/mamba.py b/src/nanotron/models/mamba/mamba.py index 57d1d423..714ee727 100644 --- a/src/nanotron/models/mamba/mamba.py +++ b/src/nanotron/models/mamba/mamba.py @@ -77,24 +77,24 @@ class Mamba(nn.Module): def __init__( self, - d_model, + d_model: int, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, - d_state=16, - d_conv=4, - expand=2, - dt_rank="auto", - dt_min=0.001, - dt_max=0.1, - 
dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - conv_bias=True, - bias=False, - use_fast_path=True, # Fused kernel options - layer_idx=None, - device=None, - dtype=None, + d_state: int=16, + d_conv: int=4, + expand: int=2, + dt_rank: str="auto", + dt_min: float=0.001, + dt_max: float=0.1, + dt_init: str="random", + dt_scale: float=1.0, + dt_init_floor: float=1e-4, + conv_bias: bool=True, + bias: bool=False, + use_fast_path: bool=True, # Fused kernel options + layer_idx: Optional[int]=None, + device: Optional[torch.device]=None, + dtype: Optional[torch.dtype]=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -202,7 +202,7 @@ def __init__( contiguous_chunks=None ) - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_params=None): """ hidden_states: (B, L, D) Returns: same shape as hidden_states @@ -294,7 +294,7 @@ def forward(self, hidden_states, inference_params=None): out = self.out_proj(y) return out - def step(self, hidden_states, conv_state, ssm_state): + def step(self, hidden_states: Union[torch.Tensor, TensorPointer], conv_state: torch.Tensor, ssm_state: torch.Tensor): dtype = hidden_states.dtype assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) @@ -341,7 +341,7 @@ def step(self, hidden_states, conv_state, ssm_state): out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype=None, **kwargs): device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype conv_state = torch.zeros( @@ -354,7 +354,7 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) ) return conv_state, ssm_state - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + def _get_states_from_cache(self, inference_params, batch_size: int, initialize_states: bool=False): assert self.layer_idx is not None if self.layer_idx not in inference_params.key_value_memory_dict: batch_shape = (batch_size,) From befb66dc71945f1200c34b721c5d1de83612b974 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 13:14:03 +0000 Subject: [PATCH 33/57] remove init method in favor of new initialization method --- src/nanotron/models/llama.py | 1 - src/nanotron/utils.py | 20 -------------------- 2 files changed, 21 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 66b260c4..b930e0eb 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -902,7 +902,6 @@ def init_model_randomly(self, config): for param_name, param in model.named_parameters(): assert isinstance(param, NanotronParameter) - #TODO(fmom): Make sure it works with PP=2 module_name, param_name = param_name.rsplit('.', 1) if param.is_tied: diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 24c27d2d..14fe1ca8 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -123,26 +123,6 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage: else: return tensor.storage().untyped() - -def init_method_normal(sigma: float) -> Callable[[torch.Tensor], None]: - """Init method based on N(0, sigma).""" - - def init_(tensor: torch.Tensor): - torch.nn.init.normal_(tensor, 
mean=0.0, std=sigma) - - return init_ - - -def scaled_init_method_normal(sigma: float, num_layers: int, scale: int = 2) -> Callable[[torch.Tensor], None]: - """Default: Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(scale * num_layers) - - def init_(tensor: torch.Tensor): - torch.nn.init.normal_(tensor, mean=0.0, std=std) - - return init_ - - def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype): # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage. device = untyped_storage.device From 35370b75a1d9ea70f6e77a0da72337b29f25ab67 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 10:34:52 +0000 Subject: [PATCH 34/57] fix test serialize --- src/nanotron/optim/named_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/optim/named_optimizer.py b/src/nanotron/optim/named_optimizer.py index 0d1f2188..db7630cd 100644 --- a/src/nanotron/optim/named_optimizer.py +++ b/src/nanotron/optim/named_optimizer.py @@ -11,8 +11,8 @@ class NamedOptimizer(InheritFromOtherOptimizer): def __init__( self, named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]], - weight_decay: float, optimizer_builder: Callable[[Iterable[Dict[str, Any]]], torch.optim.Optimizer], + weight_decay: float = 0.0, ): id_to_name_decay, id_to_name_no_decay = {}, {} From 790eddcfffe1e88c01f94267f41f6bc8056c47bf Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 14:16:40 +0000 Subject: [PATCH 35/57] small fix --- src/nanotron/config/config.py | 15 +- src/nanotron/models/mamba/mamba.py | 272 ++++++------ .../models/mamba/selective_scan_interface.py | 393 +++++++++++++----- src/nanotron/trainer.py | 13 +- 4 files changed, 460 insertions(+), 233 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index b91817e8..ad402699 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -11,7 +11,12 @@ from yaml.loader import SafeLoader from nanotron.config.lighteval_config import LightEvalConfig -from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, MambaInit +from nanotron.config.models_config import ( + ExistingCheckpointInit, + MambaInit, + NanotronConfigs, + RandomInit, +) from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( RecomputeGranularity, @@ -21,9 +26,7 @@ ) from nanotron.generation.sampler import SamplerType from nanotron.logging import get_logger -from nanotron.parallel.pipeline_parallel.engine import ( - PipelineEngine, -) +from nanotron.parallel.pipeline_parallel.engine import PipelineEngine from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode logger = get_logger(__name__) @@ -79,7 +82,7 @@ def __post_init__(self): @dataclass class PretrainDatasetsArgs: - hf_dataset_mixer: Union[str, list, dict] + hf_dataset_or_datasets: Union[str, list, dict] hf_dataset_splits: Optional[Union[str, list]] = None hf_dataset_config_name: Optional[str] = None dataset_processing_num_proc_per_process: Optional[int] = 1 @@ -384,7 +387,7 @@ def get_config_from_file( skip_unused_config_keys: bool = False, skip_null_keys: bool = False, ) -> Config: - """Get a config objet from a file (python or YAML) + """Get a config object from a file (python or YAML) Args: config_path: path to the config file diff --git a/src/nanotron/models/mamba/mamba.py b/src/nanotron/models/mamba/mamba.py 
index 714ee727..5f1d5046 100644 --- a/src/nanotron/models/mamba/mamba.py +++ b/src/nanotron/models/mamba/mamba.py @@ -14,24 +14,32 @@ # limitations under the License. """ PyTorch Mamba model. """ +import math import os -from typing import Dict, Optional, Union -from torch.nn import init from functools import partial +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch.nn import init from nanotron import distributed as dist from nanotron import logging -from nanotron.config.utils_config import cast_str_to_torch_dtype from nanotron.config import ParallelismArgs +from nanotron.config.models_config import MambaConfig +from nanotron.config.utils_config import cast_str_to_torch_dtype +from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel -from nanotron.generation.generate_store import AttachableStore +from nanotron.models.mamba.selective_scan_interface import ( + mamba_inner_fn, + selective_scan_fn, +) from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.parallel.pipeline_parallel.block import ( - PipelineBlock, - TensorPointer, -) +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -41,18 +49,6 @@ TensorParallelRowLinear, ) from nanotron.random import RandomStates -from nanotron.config.models_config import MambaConfig -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - -from einops import rearrange, repeat - -from nanotron.models.mamba.selective_scan_interface import selective_scan_fn, mamba_inner_fn try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -80,21 +76,21 @@ def __init__( d_model: int, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, - d_state: int=16, - d_conv: int=4, - expand: int=2, - dt_rank: str="auto", - dt_min: float=0.001, - dt_max: float=0.1, - dt_init: str="random", - dt_scale: float=1.0, - dt_init_floor: float=1e-4, - conv_bias: bool=True, - bias: bool=False, - use_fast_path: bool=True, # Fused kernel options - layer_idx: Optional[int]=None, - device: Optional[torch.device]=None, - dtype: Optional[torch.dtype]=None, + d_state: int = 16, + d_conv: int = 4, + expand: int = 2, + dt_rank: str = "auto", + dt_min: float = 0.001, + dt_max: float = 0.1, + dt_init: str = "random", + dt_scale: float = 1.0, + dt_init_floor: float = 1e-4, + conv_bias: bool = True, + bias: bool = False, + use_fast_path: bool = True, # Fused kernel options + layer_idx: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -108,16 +104,17 @@ def __init__( self.layer_idx = layer_idx tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - assert tp_mode == TensorParallelLinearMode.REDUCE_SCATTER or not parallel_config.tp_linear_async_communication; "Only ALL_REDUCE and tp_linear_async_communication=False are supported" - + assert tp_mode == TensorParallelLinearMode.REDUCE_SCATTER or not 
parallel_config.tp_linear_async_communication + "Only ALL_REDUCE and tp_linear_async_communication=False are supported" + tp_linear_async_communication = ( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - + # Get current tensor parallel rank self.tp_pg = tp_pg self.tp_rank = dist.get_rank(self.tp_pg) - + self.in_proj = TensorParallelColumnLinear( in_features=self.d_model, out_features=self.d_inner * 2, @@ -125,11 +122,11 @@ def __init__( mode=tp_mode, bias=False, async_communication=False, - contiguous_chunks=None + contiguous_chunks=None, ) - + assert self.d_inner % self.tp_pg.size() == 0 - + self.conv1d = nn.Conv1d( in_channels=self.d_inner // self.tp_pg.size(), out_channels=self.d_inner // self.tp_pg.size(), @@ -150,9 +147,9 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - contiguous_chunks=None + contiguous_chunks=None, ) - + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner // self.tp_pg.size(), bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -199,7 +196,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - contiguous_chunks=None + contiguous_chunks=None, ) def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_params=None): @@ -207,7 +204,7 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p hidden_states: (B, L, D) Returns: same shape as hidden_states """ - + if inference_params is not None: raise NotImplementedError("Inference params not tested yet.") @@ -224,9 +221,11 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p # We do matmul and transpose BLH -> HBL at the same time xz = self.in_proj(hidden_states).transpose(1, 2) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - + # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None and os.environ.get("FAST_PATH", "0") == "1": # Doesn't support outputting the states + if ( + self.use_fast_path and inference_params is None and os.environ.get("FAST_PATH", "0") == "1" + ): # Doesn't support outputting the states y = mamba_inner_fn( d_inner=self.d_inner, tp_pg=self.tp_pg, @@ -236,8 +235,8 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p x_proj_weight=self.x_proj.weight, delta_proj_weight=self.dt_proj.weight, A=A, - B=None, # input-dependent B - C=None, # input-dependent C + B=None, # input-dependent B + C=None, # input-dependent C D=self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, @@ -255,7 +254,6 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p if causal_conv1d_fn is None: x = self.act(self.conv1d(x)[..., :seqlen]) else: - assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( x=x, @@ -290,11 +288,16 @@ def forward(self, hidden_states: Union[torch.Tensor, TensorPointer], inference_p y, last_state = y ssm_state.copy_(last_state) y = rearrange(y, "b d l -> b l d") - + out = self.out_proj(y) return out - def step(self, hidden_states: Union[torch.Tensor, TensorPointer], conv_state: torch.Tensor, ssm_state: torch.Tensor): + def step( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): dtype = hidden_states.dtype assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" xz = 
self.in_proj(hidden_states.squeeze(1)) # (B 2D) @@ -335,29 +338,45 @@ def step(self, hidden_states: Union[torch.Tensor, TensorPointer], conv_state: to y = y * self.act(z) # (B D) else: y = selective_state_update( - ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ssm_state, + x, + dt, + A, + B, + C, + self.D, + z=z, + dt_bias=self.dt_proj.bias, + dt_softplus=True, ) out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state - def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype=None, **kwargs): + def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = None, **kwargs): device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype conv_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype + batch_size, + self.d_model * self.expand, + self.d_conv, + device=device, + dtype=conv_dtype, ) ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype # ssm_dtype = torch.float32 ssm_state = torch.zeros( - batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype + batch_size, + self.d_model * self.expand, + self.d_state, + device=device, + dtype=ssm_dtype, ) return conv_state, ssm_state - def _get_states_from_cache(self, inference_params, batch_size: int, initialize_states: bool=False): + def _get_states_from_cache(self, inference_params, batch_size: int, initialize_states: bool = False): assert self.layer_idx is not None if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (batch_size,) conv_state = torch.zeros( batch_size, self.d_model * self.expand, @@ -373,7 +392,10 @@ def _get_states_from_cache(self, inference_params, batch_size: int, initialize_s dtype=self.dt_proj.weight.dtype, # dtype=torch.float32, ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + inference_params.key_value_memory_dict[self.layer_idx] = ( + conv_state, + ssm_state, + ) else: conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] # TODO: What if batch size changes between generation, and we reuse the same states? 
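[Illustrative aside, not part of the diff.] The single-token `step()` path above maintains two per-layer caches, a rolling `conv_state` of shape `(batch, d_inner, d_conv)` and an `ssm_state` of shape `(batch, d_inner, d_state)`. As a standalone reference, here is a minimal sketch of the discretized SSM update that the fused `selective_state_update` kernel performs for one token; the tensor names and toy shapes are assumptions for illustration, not the fused kernel itself:

```python
import torch

batch, d_inner, d_state = 2, 3072, 16          # toy shapes (790m-style layer)
x  = torch.randn(batch, d_inner)               # conv1d output for the current token
dt = torch.rand(batch, d_inner)                # softplus(dt_proj(...) + dt_bias)
A  = -torch.rand(d_inner, d_state)             # A = -exp(A_log)
B  = torch.randn(batch, d_state)               # input-dependent B for this token
C  = torch.randn(batch, d_state)               # input-dependent C for this token
D  = torch.randn(d_inner)
ssm_state = torch.zeros(batch, d_inner, d_state)

dA = torch.exp(dt.unsqueeze(-1) * A)                    # discretized A: (batch, d_inner, d_state)
dB = dt.unsqueeze(-1) * B.unsqueeze(1)                  # discretized B: (batch, d_inner, d_state)
ssm_state = ssm_state * dA + dB * x.unsqueeze(-1)       # h_t = dA * h_{t-1} + dB * x_t
y = torch.einsum("bdn,bn->bd", ssm_state, C) + D * x    # y_t = C_t h_t + D x_t
# the real step() then gates with silu(z) and applies out_proj
```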
@@ -382,8 +404,14 @@ def _get_states_from_cache(self, inference_params, batch_size: int, initialize_s ssm_state.zero_() return conv_state, ssm_state + class Embedding(nn.Module, AttachableStore): - def __init__(self, tp_pg: dist.ProcessGroup, config: MambaConfig, parallel_config: Optional[ParallelismArgs]): + def __init__( + self, + tp_pg: dist.ProcessGroup, + config: MambaConfig, + parallel_config: Optional[ParallelismArgs], + ): super().__init__() self.token_embedding = TensorParallelEmbedding( num_embeddings=config.vocab_size, @@ -411,6 +439,7 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ input_embeds = self.token_embedding(input_ids) return {"input_embeds": input_embeds} + class MambaDecoderLayer(nn.Module): def __init__( self, @@ -420,11 +449,11 @@ def __init__( layer_idx: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ): super().__init__() - + factory_kwargs = {"device": device, "dtype": dtype} - + if config.ssm_cfg is None: ssm_cfg = {} else: @@ -433,28 +462,28 @@ def __init__( self.layer_idx = layer_idx self.residual_in_fp32 = config.residual_in_fp32 self.fused_add_norm = config.fused_add_norm - + self.mixer = Mamba( d_model=config.d_model, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx, **ssm_cfg, - **factory_kwargs + **factory_kwargs, ) - + self.norm = partial( - nn.LayerNorm if not config.rms_norm - else RMSNorm, eps=config.rms_norm_eps, **factory_kwargs + nn.LayerNorm if not config.rms_norm else RMSNorm, + eps=config.rms_norm_eps, + **factory_kwargs, )(config.d_model) - + if self.fused_add_norm: assert RMSNorm is not None, "RMSNorm import fails" assert isinstance( self.norm, (nn.LayerNorm, RMSNorm) ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" - def forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -462,7 +491,6 @@ def forward( residual: Optional[Union[torch.Tensor, TensorPointer]], inference_params=None, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - if not self.fused_add_norm: # self.layer_idx was assigned when calling create_block # residual=None happens only at the first block @@ -482,16 +510,17 @@ def forward( eps=self.norm.eps, ) hidden_states = self.mixer(hidden_states, inference_params=inference_params) - + return { "hidden_states": hidden_states, "sequence_mask": sequence_mask, # NOTE(fmom): dunno how to use it for now. 
Just keep it "residual": residual, } - + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + class MambaModel(nn.Module): def __init__( self, @@ -501,7 +530,7 @@ def __init__( random_states: Optional[RandomStates] = None, ): super().__init__() - + # Declare all the nodes self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) self.config = config @@ -577,7 +606,6 @@ def __init__( module_output_keys={"output"}, ) - def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] @@ -603,14 +631,16 @@ def forward_with_hidden_states( for block in self.decoder: hidden_encoder_states = block(**hidden_encoder_states) - hidden_states = self.final_layer_norm(x=hidden_encoder_states["hidden_states"], residual=hidden_encoder_states["residual"])["hidden_states"] + hidden_states = self.final_layer_norm( + x=hidden_encoder_states["hidden_states"], + residual=hidden_encoder_states["residual"], + )["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] return fp32_sharded_logits, hidden_states - def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" # model_config = self.config @@ -623,7 +653,6 @@ def get_block_compute_costs(self): # # This is the last lm_head # TensorParallelColumnLinear: model_config.vocab_size * model_config.d_model, # } - model_config = self.config block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP @@ -631,7 +660,12 @@ def get_block_compute_costs(self): # This is the last lm_head TensorParallelColumnLinear: 0, } - log_rank(f"get_block_compute_costs() Not implemented yet", logger=logger, level=logging.INFO, rank=0) + log_rank( + "get_block_compute_costs() Not implemented yet", + logger=logger, + level=logging.INFO, + rank=0, + ) return block_compute_costs def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): @@ -656,19 +690,27 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch # model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) # hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) - + # TODO(fmom): undo hardcoding of model_flops_per_s and hardware_flops_per_s model_flops_per_s = 0 hardware_flops_per_s = 0 - log_rank(f"get_flops_per_sec() Not implemented yet", logger=logger, level=logging.INFO, rank=0) + log_rank( + "get_flops_per_sec() Not implemented yet", + logger=logger, + level=logging.INFO, + rank=0, + ) return model_flops_per_s, hardware_flops_per_s torch.jit.script + + def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + class Loss(nn.Module): def __init__(self, tp_pg: dist.ProcessGroup): super().__init__() @@ -682,16 +724,14 @@ def forward( ) -> Dict[str, torch.Tensor]: # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. 
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 - - #NOTE(fmom): undo transpose for now since Mamba is not using TP + + # NOTE(fmom): undo transpose for now since Mamba is not using TP # loss = sharded_cross_entropy( # sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float # ).transpose(0, 1) - - loss = sharded_cross_entropy( - sharded_logits, label_ids, group=self.tp_pg, dtype=torch.float - ) - + + loss = sharded_cross_entropy(sharded_logits, label_ids, group=self.tp_pg, dtype=torch.float) + # TODO @thomasw21: It's unclear what kind of normalization we want to do. loss = masked_mean(loss, label_mask, dtype=torch.float) # I think indexing causes a sync we don't actually want @@ -708,14 +748,14 @@ def __init__( random_states: Optional[RandomStates] = None, ): super().__init__() - + self.model = MambaModel( config=config, parallel_context=parallel_context, parallel_config=parallel_config, random_states=random_states, ) - + self.loss = PipelineBlock( p2p=self.model.p2p, module_builder=Loss, @@ -730,7 +770,7 @@ def __init__( self.parallel_context = parallel_context self.config = config self.parallel_config = parallel_config - + def forward( self, input_ids: Union[torch.Tensor, TensorPointer], @@ -753,7 +793,7 @@ def forward( def init_model_randomly(self, config): model = self initialized_parameters = set() - + # Handle tensor parallelism module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} # Fix the root_model @@ -763,18 +803,18 @@ def init_model_randomly(self, config): n_residuals_per_layer = config.model.init_method.n_residuals_per_layer num_hidden_layers = config.model.model_config.num_hidden_layers rescale_prenorm_residual = config.model.init_method.rescale_prenorm_residual + d_model = config.model.model_config.d_model if config.model.model_config.ssm_cfg is not None: dt_init = config.model.model_config.ssm_cfg["dt_init"] - d_model = config.model.model_config.ssm_cfg["d_model"] dt_rank = config.model.model_config.ssm_cfg["dt_rank"] dt_scale = config.model.model_config.ssm_cfg["dt_scale"] for param_name, param in model.named_parameters(): assert isinstance(param, NanotronParameter) - - module_name, param_name = param_name.rsplit('.', 1) - + + module_name, param_name = param_name.rsplit(".", 1) + if param.is_tied: tied_info = param.get_tied_info() full_param_name = tied_info.get_full_name_from_module_id_to_prefix( @@ -788,7 +828,7 @@ def init_model_randomly(self, config): continue module = model.get_submodule(module_name) - + if isinstance(module, TensorParallelColumnLinear) or isinstance(module, TensorParallelRowLinear): if "weight" == param_name: init.kaiming_uniform_(module.weight, a=math.sqrt(5)) @@ -796,7 +836,7 @@ def init_model_randomly(self, config): raise ValueError("We don't use bias for TensorParallelColumnLinear and TensorParallelRow") else: raise ValueError(f"Who the fuck is {param_name}?") - + if rescale_prenorm_residual and full_param_name.endswith("out_proj.weight"): # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale @@ -810,38 +850,36 @@ def init_model_randomly(self, config): # Having just p *= scale would repeatedly scale it down with torch.no_grad(): module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers) - + elif isinstance(module, nn.Conv1d): - fan_in = None + fan_in = None if "weight" == param_name: fan_in, _ = init._calculate_fan_in_and_fan_out(param) init.kaiming_uniform_(module.weight, a=math.sqrt(5)) - elif "bias" == param_name: + elif "bias" == param_name: bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 init.uniform_(module.bias, -bound, bound) else: raise ValueError(f"Who the fuck is {param_name}?") - + elif isinstance(module, nn.Linear): fan_in = None - + if "weight" == param_name: fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) init.kaiming_uniform_(module.weight, a=math.sqrt(5)) - elif "bias" == param_name: + elif "bias" == param_name: bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 init.uniform_(module.bias, -bound, bound) else: raise ValueError(f"Who the fuck is {param_name}?") - if config.model.model_config.ssm_cfg is not None: - if dt_rank == "auto": - dt_init_std = math.ceil(d_model / 16)**-0.5 * dt_scale + dt_init_std = math.ceil(d_model / 16) ** -0.5 * dt_scale else: - dt_init_std = dt_rank**-0.5 * dt_scale - + dt_init_std = dt_rank**-0.5 * dt_scale + if dt_init == "constant": nn.init.constant_(module.weight, dt_init_std) elif dt_init == "random": @@ -860,19 +898,19 @@ def init_model_randomly(self, config): module.bias.zero_() else: raise ValueError(f"Who the fuck is {param_name}?") - + elif isinstance(module, Mamba): - # NOTE(fmom): nn.Parameter are initialized in Mamba __init__ + # NOTE(fmom): nn.Parameter are initialized in Mamba __init__ # In Mamba, only those 3 parameters don't have weight decay. if param_name in ["dt_bias", "A_log", "D"]: param._no_weight_decay = True - + else: - raise Exception(f"Parameter {full_param_name} was not intialized") - + raise Exception(f"Parameter {full_param_name} was not initialized") + assert full_param_name not in initialized_parameters initialized_parameters.add(full_param_name) - + assert initialized_parameters == { param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) if param.is_tied @@ -886,7 +924,7 @@ def get_embeddings_lm_head_tied_names(): "model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight", ] - + # TODO(fmom): implement get_block_compute_costs def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" diff --git a/src/nanotron/models/mamba/selective_scan_interface.py b/src/nanotron/models/mamba/selective_scan_interface.py index 3aff7589..fab856b0 100644 --- a/src/nanotron/models/mamba/selective_scan_interface.py +++ b/src/nanotron/models/mamba/selective_scan_interface.py @@ -1,21 +1,29 @@ # Copyright (c) 2023, Tri Dao, Albert Gu. 
+import causal_conv1d_cuda +import selective_scan_cuda import torch import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd - -from einops import rearrange, repeat - from causal_conv1d import causal_conv1d_fn -import causal_conv1d_cuda -import selective_scan_cuda +from einops import rearrange, repeat +from torch.cuda.amp import custom_bwd, custom_fwd class SelectiveScanFn(torch.autograd.Function): - @staticmethod - def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + def forward( + ctx, + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + ): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -34,7 +42,9 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd( + u, delta, A, B, C, D, z, delta_bias, delta_softplus + ) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) @@ -60,31 +70,71 @@ def backward(ctx, dout, *args): # backward of selective_scan_cuda with the backward of chunk). # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( - u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + dout, + x, + out, + None, + ctx.delta_softplus, + False, # option to recompute out_z, not used here ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC - return (du, ddelta, dA, dB, dC, - dD if D is not None else None, - dz, - ddelta_bias if delta_bias is not None else None, - None, - None) + return ( + du, + ddelta, + dA, + dB, + dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None, + ) -def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): +def selective_scan_fn( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, +): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. """ - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply( + u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state + ) -def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): +def selective_scan_ref( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, +): """ u: r(B D L) delta: r(B D L) @@ -110,41 +160,45 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta is_variable_C = C.dim() >= 3 if A.is_complex(): if is_variable_B: - B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... 
L two", two=2)) + B = torch.view_as_complex( + rearrange(B.float(), "... (L two) -> ... L two", two=2) + ) if is_variable_C: - C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + C = torch.view_as_complex( + rearrange(C.float(), "... (L two) -> ... L two", two=2) + ) else: B = B.float() C = C.float() x = A.new_zeros((batch, dim, dstate)) ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u) else: if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) + y = torch.einsum("bdn,dn->bd", x, C) else: if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) if i == u.shape[2] - 1: last_state = x if y.is_complex(): y = y.real * 2 ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) + y = torch.stack(ys, dim=2) # (batch dim L) out = y if D is None else y + u * rearrange(D, "d -> d 1") if z is not None: out = out * F.silu(z) @@ -153,14 +207,29 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta class MambaInnerFn(torch.autograd.Function): - @staticmethod @custom_fwd - def forward(ctx, d_inner, tp_pg, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + def forward( + ctx, + d_inner, + tp_pg, + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, + checkpoint_lvl=1, + ): """ - xz: (batch, dim, seqlen) + xz: (batch, dim, seqlen) """ assert checkpoint_lvl in [0, 1] batch, L = xz.shape[0], xz.shape[-1] @@ -168,38 +237,48 @@ def forward(ctx, d_inner, tp_pg, xz, conv1d_weight, conv1d_bias, x_proj_weight, d_state = A.shape[-1] * (1 if not A.is_complex() else 2) if torch.is_autocast_enabled(): x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to( + dtype=torch.get_autocast_gpu_dtype() + ) if xz.stride(-1) != 1: xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") - + # x, z = xz.chunk(2, dim=1) assert d_inner % tp_pg.size() == 0 x, z = xz.view(batch, d_inner // tp_pg.size(), 2, L).chunk(2, dim=2) x = x.squeeze(2) z = z.squeeze(2) - + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, True + ) # We're being very 
careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + x_dbl = F.linear( + rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight + ) # (bl d) + delta = rearrange( + delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L + ) ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None ctx.C_proj_bias_is_None = C_proj_bias is None if B is None: # variable B - B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) + B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: - B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + B = rearrange( + B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2 + ).contiguous() else: if B.stride(-1) != 1: B = B.contiguous() @@ -211,7 +290,9 @@ def forward(ctx, d_inner, tp_pg, xz, conv1d_weight, conv1d_bias, x_proj_weight, # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: - C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + C = rearrange( + C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2 + ).contiguous() else: if C.stride(-1) != 1: C = C.contiguous() @@ -223,59 +304,114 @@ def forward(ctx, d_inner, tp_pg, xz, conv1d_weight, conv1d_bias, x_proj_weight, ctx.delta_softplus = delta_softplus # ctx.out_proj_bias_is_None = out_proj_bias is None ctx.checkpoint_lvl = checkpoint_lvl - if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + if ( + checkpoint_lvl >= 1 + ): # Will recompute conv1d_out and delta in the backward pass conv1d_out, delta = None, None - + ctx.d_inner = d_inner ctx.tp_pg = tp_pg - ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, - delta_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, out) - + ctx.save_for_backward( + xz, + conv1d_weight, + conv1d_bias, + x_dbl, + x_proj_weight, + delta_proj_weight, + conv1d_out, + delta, + A, + B, + C, + D, + delta_bias, + scan_intermediates, + out, + ) + return rearrange(out_z, "b d l -> b l d") @staticmethod @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) - (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + ( + xz, + conv1d_weight, + conv1d_bias, + x_dbl, + x_proj_weight, + delta_proj_weight, + conv1d_out, + delta, + A, + B, + C, + D, + delta_bias, + scan_intermediates, + out, + ) = ctx.saved_tensors batch, L = xz.shape[0], xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) - + # x, z = xz.chunk(2, dim=1) assert ctx.d_inner % ctx.tp_pg.size() == 0 x, z = xz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2) x = x.squeeze(2) z = z.squeeze(2) - + if ctx.checkpoint_lvl == 1: - conv1d_out = 
causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) - delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), - "d (b l) -> b d l", l = L) + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, True + ) + delta = rearrange( + delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L + ) # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). dxz = torch.empty_like(xz) # (batch, dim, seqlen) - + # dx, dz = dxz.chunk(2, dim=1) assert ctx.d_inner % ctx.tp_pg.size() == 0 dx, dz = dxz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2) dx = dx.squeeze(2) dz = dz.squeeze(2) - + dout = rearrange(dout, "b l e -> b e l") - + if dout.stride(-1) != 1: dout = dout.contiguous() - - dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz, + + ( + dconv1d_out, + ddelta, + dA, + dB, + dC, + dD, + ddelta_bias, + dz, + out_z, + ) = selective_scan_cuda.bwd( + conv1d_out, + delta, + A, + B, + C, + D, + z, + delta_bias, + dout, + scan_intermediates, + out, + dz, ctx.delta_softplus, - True # option to recompute out_z + True, # option to recompute out_z ) - + dD = dD if D is not None else None dx_dbl = torch.empty_like(x_dbl) dB_proj_bias = None @@ -283,16 +419,20 @@ def backward(ctx, dout): if not A.is_complex(): dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() else: - dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB = rearrange( + dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2 + ).contiguous() dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None - dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) + dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d) dB = None dC_proj_bias = None if ctx.is_variable_C: if not A.is_complex(): dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() else: - dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC = rearrange( + dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2 + ).contiguous() dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None dx_dbl[:, -d_state:] = dC # (bl d) dC = None @@ -300,9 +440,15 @@ def backward(ctx, dout): ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") - dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) - dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) - dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + dx_proj_weight = torch.einsum( + "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d") + ) + dconv1d_out = torch.addmm( + dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out + ) + dconv1d_out = rearrange( + dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1] + ) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). 
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( @@ -310,64 +456,101 @@ def backward(ctx, dout): ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (None, # d_inner - None, # tp_pg - dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, - dA, dB, dC, dD, - ddelta_bias if delta_bias is not None else None, - dB_proj_bias, dC_proj_bias, None) + return ( + None, # d_inner + None, # tp_pg + dxz, + dconv1d_weight, + dconv1d_bias, + dx_proj_weight, + ddelta_proj_weight, + dA, + dB, + dC, + dD, + ddelta_bias if delta_bias is not None else None, + dB_proj_bias, + dC_proj_bias, + None, + ) def mamba_inner_fn( - d_inner, tp_pg, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + d_inner, + tp_pg, + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, ): - return MambaInnerFn.apply( d_inner, tp_pg, xz, - conv1d_weight, + conv1d_weight, conv1d_bias, - x_proj_weight, + x_proj_weight, delta_proj_weight, - A, - B, - C, - D, + A, + B, + C, + D, delta_bias, - B_proj_bias, - C_proj_bias, - delta_softplus + B_proj_bias, + C_proj_bias, + delta_softplus, ) def mamba_inner_ref( - xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, - out_proj_weight, out_proj_bias, - A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + xz, + conv1d_weight, + conv1d_bias, + x_proj_weight, + delta_proj_weight, + out_proj_weight, + out_proj_bias, + A, + B=None, + C=None, + D=None, + delta_bias=None, + B_proj_bias=None, + C_proj_bias=None, + delta_softplus=True, ): L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") + x = causal_conv1d_fn( + x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu" + ) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
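[Illustrative aside, not part of the diff.] Both `MambaInnerFn` and `mamba_inner_ref` rely on `x_proj` packing delta, B and C into a single matmul output and slicing it afterwards. A short sketch of that bookkeeping, with values assumed from the 790m-style config used elsewhere in this series (not taken from the kernels themselves):

```python
import math

d_model, expand, d_state = 1536, 2, 16         # assumed config values
d_inner = expand * d_model                     # 3072
dt_rank = math.ceil(d_model / 16)              # 96 (the "auto" setting)

# x_proj maps d_inner -> dt_rank + 2 * d_state columns, sliced as:
#   x_dbl[:, :dt_rank]                   -> delta, fed through dt_proj
#   x_dbl[:, dt_rank:dt_rank + d_state]  -> input-dependent B
#   x_dbl[:, -d_state:]                  -> input-dependent C
print(dt_rank + 2 * d_state)                   # 128
```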
- x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) if B is None: # variable B - B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) + B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: - B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + B = rearrange( + B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2 + ).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: @@ -375,6 +558,10 @@ def mamba_inner_ref( if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: - C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) + C = rearrange( + C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2 + ).contiguous() + y = selective_scan_fn( + x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True + ) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 74ce1940..51192603 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -19,6 +19,7 @@ ) import torch +from brrr.models.mamba_fast.mamba import MambaFastForTraining from torch.nn.parallel import DistributedDataParallel from nanotron import distributed as dist @@ -26,9 +27,9 @@ from nanotron.config import ( Config, ExistingCheckpointInit, + MambaInit, ParallelismArgs, RandomInit, - MambaInit, get_config_from_file, ) from nanotron.dataloader import sanity_check_dataloader @@ -51,9 +52,8 @@ from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding -from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.models.mamba.mamba import MambaForTraining -from brrr.models.mamba_fast.mamba import MambaFastForTraining +from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -87,7 +87,6 @@ save, save_random_states, ) -from nanotron.utils import init_method_normal, scaled_init_method_normal logger = logging.get_logger(__name__) @@ -584,7 +583,7 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: root_folder=self.config.model.init_method.path, ) elif isinstance(self.config.model.init_method, MambaInit): - + unwrapped_model.init_model_randomly(config=self.config) # Synchronize parameters so that the model is consistent # sync all params across dp @@ -602,9 +601,9 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: group = self.parallel_context.world_ranks_to_pg[group_ranks] dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) elif isinstance(self.config.model.init_method, RandomInit): - + unwrapped_model.init_model_randomly(config=self.config) - + # Synchronize parameters so that the model is consistent # sync all params across dp for name, param in 
sorted(model.named_parameters(), key=lambda x: x[0]): From 37867d12418a77474738bc9a96e9b88122668acf Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 14:17:43 +0000 Subject: [PATCH 36/57] add mamba example --- examples/mamba/README.md | 25 ++++ examples/mamba/config_mamba.py | 151 ++++++++++++++++++++ examples/mamba/config_mamba.yaml | 101 +++++++++++++ examples/mamba/configs/make_config_mamba.py | 106 -------------- examples/mamba/run.sh | 29 ---- examples/mamba/train_mamba.sh | 24 ++++ 6 files changed, 301 insertions(+), 135 deletions(-) create mode 100644 examples/mamba/README.md create mode 100644 examples/mamba/config_mamba.py create mode 100644 examples/mamba/config_mamba.yaml delete mode 100644 examples/mamba/configs/make_config_mamba.py delete mode 100755 examples/mamba/run.sh create mode 100755 examples/mamba/train_mamba.sh diff --git a/examples/mamba/README.md b/examples/mamba/README.md new file mode 100644 index 00000000..c63452ab --- /dev/null +++ b/examples/mamba/README.md @@ -0,0 +1,25 @@ +--- +library_name: nanotron +--- + +# Mamba + +Modeling code for Mamba to use with [Nanotron](https://github.com/huggingface/nanotron/) + +## 🚀 Quickstart + +```bash +# Generate a config file +python examples/mamba/config_mamba.py + +pip install einops +pip install "causal-conv1d>=1.1.0,<1.2.0" +pip install mamba-ssm + +# Run training +./examples/mamba/train_mamba.sh +``` + +## Credits +Credits to the following repositories from which the code was adapted: +- https://github.com/state-spaces/mamba diff --git a/examples/mamba/config_mamba.py b/examples/mamba/config_mamba.py new file mode 100644 index 00000000..1052d34e --- /dev/null +++ b/examples/mamba/config_mamba.py @@ -0,0 +1,151 @@ +""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.""" +import math +import os + +from nanotron.config import ( + CheckpointsArgs, + Config, + DataArgs, + GeneralArgs, + LoggingArgs, + LRSchedulerArgs, + MambaConfig, + MambaInit, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + PretrainDatasetsArgs, + TokenizerArgs, + TokensArgs, +) +from nanotron.logging import human_format + +ssm_cfg_dtype = "bfloat16" +ssm_cfg = { + "d_state": 16, + "d_conv": 4, + "expand": 2, + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, + "bias": False, + "use_fast_path": True, +} +# https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json +model_config = MambaConfig( + d_model=1536, + num_hidden_layers=48, + vocab_size=50277, + ssm_cfg=ssm_cfg, + rms_norm=True, + fused_add_norm=True, + residual_in_fp32=True, + pad_vocab_size_multiple=8, + # Custom + dtype=ssm_cfg_dtype, + rms_norm_eps=1e-5, +) + +# NOTE: vocab_size is normally rounded up to the nearest multiple of 10.
But here, we don't really care +tie_embedding = model_config.vocab_size * model_config.d_model # model_config.vocab_size * model_config.d_model +expand = 2 if ("expand" not in ssm_cfg) else ssm_cfg["expand"] +ngroups = 1 if ("ngroups" not in ssm_cfg) else ssm_cfg["ngroups"] +d_state = 16 if ("d_state" not in ssm_cfg) else ssm_cfg["d_state"] +d_conv = 4 if ("d_conv" not in ssm_cfg) else ssm_cfg["d_conv"] +dt_rank = ( + math.ceil(model_config.d_model / 16) + if ("dt_rank" not in ssm_cfg or ssm_cfg["dt_rank"] == "auto") + else ssm_cfg["dt_rank"] +) + +d_inner = int(expand * model_config.d_model) +in_proj = model_config.d_model * d_inner * 2 + +# conv1d.weight = out_channels * (in_channels // groups) * kernel_size +# conv1d.bias = out_channels +conv1d = d_inner * int(d_inner / d_inner) * d_conv + d_inner +# linear.weight = out_features * in_features +in_proj = model_config.d_model * d_inner * 2 + 0 +x_proj = d_inner * (dt_rank + d_state * 2) + 0 +out_proj = d_inner * model_config.d_model + 0 +dt_proj = dt_rank * d_inner + d_inner +A_log = d_inner * d_state +D = d_inner +norm = model_config.d_model +norm_f = model_config.d_model + +num_params = human_format( + ( + tie_embedding + + model_config.num_hidden_layers * (A_log + D + in_proj + conv1d + x_proj + dt_proj + out_proj + norm + norm_f) + ) +).replace(".", "p") + +print(f"Model has {num_params} parameters") + +seed = 42 + +optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=False, # NOTE(fmom): because we are using PP=TP=DP=1 + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + learning_rate_scheduler=LRSchedulerArgs( + learning_rate=3e-4, lr_warmup_steps=10, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + ), +) + +parallelism = ParallelismArgs( + dp=2, + pp=2, + tp=2, + pp_engine="1f1b", + tp_mode="ALL_REDUCE", + tp_linear_async_communication=False, +) + +tokens = TokensArgs(sequence_length=2048, train_steps=100, micro_batch_size=2, batch_accumulation_per_replica=1) + +dataset = PretrainDatasetsArgs( + hf_dataset_or_datasets={"roneneldan/TinyStories": 1.0}, + hf_dataset_config_name=None, + hf_dataset_splits="train", + dataset_processing_num_proc_per_process=24, + dataset_overwrite_cache=False, + text_column_name="text", +) + +checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" +os.makedirs(checkpoints_path, exist_ok=True) + +config = Config( + general=GeneralArgs(project="test", run="mamba", seed=seed), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + parallelism=parallelism, + model=ModelArgs( + init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), + model_config=model_config, + ), + tokenizer=TokenizerArgs("gpt2"), + optimizer=optimizer, + logging=LoggingArgs(), + tokens=tokens, + data=DataArgs(dataset=dataset, seed=seed), + profiler=None, +) + +if __name__ == "__main__": + dir = os.path.dirname(__file__) + + # Save config as YAML file + config.save_as_yaml(f"{dir}/config_mamba.yaml") + + # You can now train a model with this config using `/run_train.py` diff --git a/examples/mamba/config_mamba.yaml b/examples/mamba/config_mamba.yaml new file mode 100644 index 00000000..85492e25 --- /dev/null +++ b/examples/mamba/config_mamba.yaml @@ -0,0 +1,101 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/checkpoints + checkpoints_path_is_shared_file_system: 
false + resume_checkpoint_path: null + save_initial_state: false +data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 24 + hf_dataset_config_name: null + hf_dataset_or_datasets: + roneneldan/TinyStories: 1.0 + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: test + run: mamba + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + initializer_range: 0.02 + n_residuals_per_layer: 1 + rescale_prenorm_residual: true + make_vocab_size_divisible_by: 1 + model_config: + d_model: 1536 + dtype: bfloat16 + fused_add_norm: true + is_mamba_config: true + num_hidden_layers: 48 + pad_token_id: null + pad_vocab_size_multiple: 8 + residual_in_fp32: true + rms_norm: true + rms_norm_eps: 1.0e-05 + ssm_cfg: + bias: false + conv_bias: true + d_conv: 4 + d_state: 16 + dt_init: random + dt_init_floor: 0.0001 + dt_max: 0.1 + dt_min: 0.001 + dt_rank: auto + dt_scale: 1.0 + expand: 2 + use_fast_path: true + vocab_size: 50277 +optimizer: + accumulate_grad_in_fp32: false + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 90 + lr_decay_style: cosine + lr_warmup_steps: 10 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 2 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: gpt2 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 2048 + train_steps: 100 + val_check_interval: -1 diff --git a/examples/mamba/configs/make_config_mamba.py b/examples/mamba/configs/make_config_mamba.py deleted file mode 100644 index 4fa7442e..00000000 --- a/examples/mamba/configs/make_config_mamba.py +++ /dev/null @@ -1,106 +0,0 @@ -""" Example python script to generate a YAML config file which can be used to run a training with nanotron. 
Refer to "examples" section in the `/README.md` for more information.""" -import os -import torch - -from nanotron.config import ( - CheckpointsArgs, - Config, - DataArgs, - GeneralArgs, - MambaConfig, - LoggingArgs, - LRSchedulerArgs, - ModelArgs, - OptimizerArgs, - ParallelismArgs, - PretrainDatasetsArgs, - MambaInit, - TokenizerArgs, - TokensArgs, -) -from nanotron.logging import human_format - -model_config = MambaConfig( - d_model=256, - num_hidden_layers=1, - vocab_size=50277, - ssm_cfg={}, - rms_norm=True, - fused_add_norm=True, - residual_in_fp32=True, - pad_vocab_size_multiple=8, - # Custom - dtype=torch.float32, - rms_norm_eps=1e-5, -) - - -#TODO(fmom): do something similar -# num_params = human_format( -# model_config.vocab_size * model_config.d_model * 2 -# + model_config.num_hidden_layers -# * ( -# 3 * model_config.d_model * model_config.intermediate_size -# + 4 * model_config.d_model * model_config.d_model -# ) -# ).replace(".", "p") - -# print(f"Model has {num_params} parameters") - -seed = 42 - -learning_rate = LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 -) - -optimizer = OptimizerArgs( - zero_stage=0, - weight_decay=0.01, - clip_grad=1.0, - accumulate_grad_in_fp32=False, #NOTE(fmom): because we are using PP=TP=DP=1 - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, - learning_rate_scheduler=learning_rate, -) - -parallelism = ParallelismArgs( - dp=1, - pp=1, - tp=1, - pp_engine="1f1b", - tp_mode="REDUCE_SCATTER", - tp_linear_async_communication=True, - recompute_granularity="selective", -) - -tokens = TokensArgs(sequence_length=1024, train_steps=40, micro_batch_size=2, batch_accumulation_per_replica=1) - -dataset = PretrainDatasetsArgs( - hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text" -) - -checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" -os.makedirs(checkpoints_path, exist_ok=True) - -config = Config( - general=GeneralArgs(project="test", run="mamba", seed=seed), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=243232232232323332), - parallelism=parallelism, - model=ModelArgs(init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1), model_config=model_config), - tokenizer=TokenizerArgs("gpt2"), - optimizer=optimizer, - logging=LoggingArgs(), - tokens=tokens, - data=DataArgs(dataset=dataset, seed=seed), - profiler=None, -) - -if __name__ == "__main__": - dir = os.path.dirname(__file__) - - # Save config as YAML file - config.save_as_yaml(f"{dir}/config_mamba.yaml") - - # You can now train a model with this config using `/run_train.py` diff --git a/examples/mamba/run.sh b/examples/mamba/run.sh deleted file mode 100755 index b8fc12b8..00000000 --- a/examples/mamba/run.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/sh - -if [ "$1" = "debug" ]; then - python configs/make_config_mamba_fast.py && \ - FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 1234 -m torch.distributed.launch \ - -- \ - --nproc_per_node=1 \ - --master_port=29600 \ - --rdzv_endpoint=localhost:6000 \ - --use_env \ - --tee=3 \ - ../../run_train.py \ - --config-file=configs/config_mamba.yaml -elif [ "$1" = "eval" ]; then - FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \ - --nproc_per_node=1 \ - --master_port 29600 \ - ../generate.py \ - --pp 1 \ - --tp 1 \ - --ckpt-path /fsx/ferdinandmom/github/mamba/checkpoints/mamba-1p62M-stas-openwebtext-10k/7 
-else - python configs/make_config_mamba_fast.py && \ - FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \ - --nproc_per_node=1 \ - --master_port=29600 \ - ../../run_train.py \ - --config-file=configs/config_mamba.yaml -fi \ No newline at end of file diff --git a/examples/mamba/train_mamba.sh b/examples/mamba/train_mamba.sh new file mode 100755 index 00000000..6f675c3e --- /dev/null +++ b/examples/mamba/train_mamba.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Simple script to create a tiny llama model and train it + +set -e -x + +# Create the YAML config file + +EXAMPLE_PATH=$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P) +REPO_PATH=$(dirname $EXAMPLE_PATH) +python $EXAMPLE_PATH/config_mamba.py + +# Setup from environment variables + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export FI_PROVIDER="efa" + +python -u -m torch.distributed.run \ + --nproc_per_node 8 \ + --nnodes 1 \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + $REPO_PATH/../run_train.py --config-file $EXAMPLE_PATH/config_mamba.yaml \ No newline at end of file From 54a9c6324fb3f1745755a5fd4e0a4d34fcb011df Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 1 Mar 2024 13:18:25 +0000 Subject: [PATCH 37/57] fix logger --- src/nanotron/trainer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 51192603..713fcf0e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -19,7 +19,6 @@ ) import torch -from brrr.models.mamba_fast.mamba import MambaFastForTraining from torch.nn.parallel import DistributedDataParallel from nanotron import distributed as dist @@ -52,7 +51,6 @@ from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding -from nanotron.models.mamba.mamba import MambaForTraining from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext @@ -97,8 +95,6 @@ CONFIG_TO_MODEL_CLASS = { "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, - "MambaConfig": MambaForTraining, - "MambaFastConfig": MambaFastForTraining, } try: @@ -247,7 +243,7 @@ def post_init(self): def pre_training(self, *args, **kwargs): current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") - if dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks and wandb is not None: wandb.init( project=self.config.general.project, name=f"{current_time}_{self.config.general.project}_{self.config.general.run}", @@ -486,7 +482,7 @@ def train_step_logs( ] ) - if wandb is not None: + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks and wandb is not None: wandb.log( {**{log_item.tag: log_item.scalar_value for log_item in log_entries}, "step": self.iteration_step} ) From 3953a5c38e1635835421ac9a1015d349b8442f47 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 18:13:43 +0000 Subject: [PATCH 38/57] update README --- examples/mamba/README.md | 4 ++++ examples/mamba/assets/loss_mamba.png | Bin 0 -> 40017 bytes 2 files changed, 4 insertions(+) create mode 100644 examples/mamba/assets/loss_mamba.png diff --git a/examples/mamba/README.md b/examples/mamba/README.md index c63452ab..bd06c9c5 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -20,6 +20,10 @@ pip install mamba-ssm 
./examples/mamba/train_mamba.sh ``` +![mamba](./assets/loss_mamba.png) + +> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5 + ## Credits Credits to the following repositories from which the code was adapted: - https://github.com/state-spaces/mamba diff --git a/examples/mamba/assets/loss_mamba.png b/examples/mamba/assets/loss_mamba.png new file mode 100644 index 0000000000000000000000000000000000000000..f2bfc0408712d9885b2cee970498cbb235d61c7e GIT binary patch literal 40017
zby740sj4;PqbbhQykHRh>xz$3O0fQexkm^)hg!A<+1cLS^Fpbfu5Kz}{>;oPFy-sm{SNk*`Nx(O zHI6t7C{OQEG2mE}oLxL4ZgDa9@@4JBgu|mvkre_6h?cLf@7VSqsNQq(U@#b^hp8B# zN2H}*vanQ3N$I;&lj+nzGyFAp0>~r;#+UGL*Q1;j5z)v=YggM)5+W?Dk&Mo9zui<) z{~*MqixWTaYSR?FF-VOKS(_;>EzNSP<&>dsEPnwT!Xs7-uen!nGZ7FVj4iV#192~5 z`TDEE+^sj`G6Of^2L) zGUwGC-juI0NW66lWdwFm)nsIOWkvb-YR9_Q=M3OweLBFpuM75EJey&d`}}z;oI`PS z&0o@7;=8J-t!)c^V_!w2C)C-_gsZ>(gWC3W!UY4z`Y`YEq?$NaVWwc|B8o7X-;PomY#Y9D`-AV7_ z-xijZ`U2I;_IA#}?kQY-w14SKd6p99VIU&Iy(5}l@1-_cYG+c?!m)|n=3-%(<5K3X zXJpYwu7!l)5Eu21aCj{MKPcntfT$Z7bT6>N!^6|T!_s=^O{{?&-r05feD-RgoJ%YA zX`Q)+#Xx_*gpAAuG@*4}OBJYd`}^ONl*|!{->f6~*7t%1Fr%jCQk${NAWc%>zLR_P z)`5Rfi$Y(?)_qp<=&N?>SpOdeU*4hsd>IP0{s9IYd!%_Q zvTKye{7xK#x~+uio16psSRobn$kcaC3~VAQRf z7#~^;L+R}Gq&hIm+$$FvNTex$U3kYgk@i_=#y;OeMBn*%9Jt}%gzrDH!T)BECu>@C qtFO+~*$6Nzo4Rp9sj<^yp3Y>NWU2h2^Zpfy)Lp8OBH<}6+S literal 0 HcmV?d00001 From 3fa21945f8361bfe8bb2e0294a1d3ffd2d8c144b Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Thu, 29 Feb 2024 18:15:01 +0000 Subject: [PATCH 39/57] update yaml --- examples/mamba/config_mamba.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/mamba/config_mamba.yaml b/examples/mamba/config_mamba.yaml index 85492e25..3f04aea3 100644 --- a/examples/mamba/config_mamba.yaml +++ b/examples/mamba/config_mamba.yaml @@ -15,6 +15,13 @@ data: text_column_name: text num_loading_workers: 1 seed: 42 +experiment_logger: + tensorboard_logger: + flush_secs: 30 + tensorboard_dir: /fsx/ferdinandmom/logs/tb_logs + wandb_logger: + wandb_entity: bouteille + wandb_project: test general: benchmark_csv_path: null consumed_train_samples: null From d4bba7c55eaaf600f9b47d5bee88d7a9e87a590d Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 1 Mar 2024 13:56:19 +0000 Subject: [PATCH 40/57] move to examples --- examples/mamba/config_mamba.yaml | 7 -- examples/mamba/{ => mamba}/config_mamba.py | 26 ++++- examples/mamba/mamba/config_mamba.yaml | 101 ++++++++++++++++++ .../models => examples/mamba}/mamba/mamba.py | 8 +- .../mamba}/mamba/selective_scan_interface.py | 100 ++++------------- examples/mamba/mamba/trainer.py | 22 ++++ examples/mamba/train_mamba.py | 33 ++++++ examples/mamba/train_mamba.sh | 6 +- src/nanotron/config/models_config.py | 50 +-------- 9 files changed, 215 insertions(+), 138 deletions(-) rename examples/mamba/{ => mamba}/config_mamba.py (86%) create mode 100644 examples/mamba/mamba/config_mamba.yaml rename {src/nanotron/models => examples/mamba}/mamba/mamba.py (99%) rename {src/nanotron/models => examples/mamba}/mamba/selective_scan_interface.py (83%) create mode 100644 examples/mamba/mamba/trainer.py create mode 100644 examples/mamba/train_mamba.py diff --git a/examples/mamba/config_mamba.yaml b/examples/mamba/config_mamba.yaml index 3f04aea3..85492e25 100644 --- a/examples/mamba/config_mamba.yaml +++ b/examples/mamba/config_mamba.yaml @@ -15,13 +15,6 @@ data: text_column_name: text num_loading_workers: 1 seed: 42 -experiment_logger: - tensorboard_logger: - flush_secs: 30 - tensorboard_dir: /fsx/ferdinandmom/logs/tb_logs - wandb_logger: - wandb_entity: bouteille - wandb_project: test general: benchmark_csv_path: null consumed_train_samples: null diff --git a/examples/mamba/config_mamba.py b/examples/mamba/mamba/config_mamba.py similarity index 86% rename from examples/mamba/config_mamba.py rename to examples/mamba/mamba/config_mamba.py index 1052d34e..be4803b7 100644 --- a/examples/mamba/config_mamba.py +++ b/examples/mamba/mamba/config_mamba.py @@ -1,6 +1,8 @@ """ Example python script to generate a YAML config 
file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.""" import math import os +from dataclasses import dataclass +from typing import Optional from nanotron.config import ( CheckpointsArgs, @@ -9,7 +11,6 @@ GeneralArgs, LoggingArgs, LRSchedulerArgs, - MambaConfig, MambaInit, ModelArgs, OptimizerArgs, @@ -20,6 +21,29 @@ ) from nanotron.logging import human_format + +@dataclass +class MambaConfig: + """Configuration for a Mamba model + + Be careful on having a coherent typing as we use it to reconstruct the model from yaml + """ + + is_mamba_config: bool = True # We use this help differentiate models in yaml/python conversion + d_model: int = 2560 + num_hidden_layers: int = 64 + vocab_size: int = 50277 + ssm_cfg: Optional[dict] = None + rms_norm: bool = True + fused_add_norm: bool = True + residual_in_fp32: bool = True + pad_vocab_size_multiple: int = 8 + # ==== Custom ====== + dtype: str = "float32" + rms_norm_eps: float = 1e-5 + pad_token_id: Optional[int] = None + + ssm_cfg_dtype = "bfloat16" ssm_cfg = { "d_state": 16, diff --git a/examples/mamba/mamba/config_mamba.yaml b/examples/mamba/mamba/config_mamba.yaml new file mode 100644 index 00000000..67ddd708 --- /dev/null +++ b/examples/mamba/mamba/config_mamba.yaml @@ -0,0 +1,101 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/mamba/checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 24 + hf_dataset_config_name: null + hf_dataset_or_datasets: + roneneldan/TinyStories: 1.0 + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: test + run: mamba + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + initializer_range: 0.02 + n_residuals_per_layer: 1 + rescale_prenorm_residual: true + make_vocab_size_divisible_by: 1 + model_config: + d_model: 1536 + dtype: bfloat16 + fused_add_norm: true + is_mamba_config: true + num_hidden_layers: 48 + pad_token_id: null + pad_vocab_size_multiple: 8 + residual_in_fp32: true + rms_norm: true + rms_norm_eps: 1.0e-05 + ssm_cfg: + bias: false + conv_bias: true + d_conv: 4 + d_state: 16 + dt_init: random + dt_init_floor: 0.0001 + dt_max: 0.1 + dt_min: 0.001 + dt_rank: auto + dt_scale: 1.0 + expand: 2 + use_fast_path: true + vocab_size: 50277 +optimizer: + accumulate_grad_in_fp32: false + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 90 + lr_decay_style: cosine + lr_warmup_steps: 10 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 2 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: gpt2 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 2048 + train_steps: 100 + 
val_check_interval: -1 diff --git a/src/nanotron/models/mamba/mamba.py b/examples/mamba/mamba/mamba.py similarity index 99% rename from src/nanotron/models/mamba/mamba.py rename to examples/mamba/mamba/mamba.py index 5f1d5046..1ded37ea 100644 --- a/src/nanotron/models/mamba/mamba.py +++ b/examples/mamba/mamba/mamba.py @@ -28,15 +28,10 @@ from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs -from nanotron.config.models_config import MambaConfig from nanotron.config.utils_config import cast_str_to_torch_dtype from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel -from nanotron.models.mamba.selective_scan_interface import ( - mamba_inner_fn, - selective_scan_fn, -) from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer @@ -50,6 +45,9 @@ ) from nanotron.random import RandomStates +from .config_mamba import MambaConfig +from .selective_scan_interface import mamba_inner_fn, selective_scan_fn + try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: diff --git a/src/nanotron/models/mamba/selective_scan_interface.py b/examples/mamba/mamba/selective_scan_interface.py similarity index 83% rename from src/nanotron/models/mamba/selective_scan_interface.py rename to examples/mamba/mamba/selective_scan_interface.py index fab856b0..123641c8 100644 --- a/src/nanotron/models/mamba/selective_scan_interface.py +++ b/examples/mamba/mamba/selective_scan_interface.py @@ -42,9 +42,7 @@ def forward( if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd( - u, delta, A, B, C, D, z, delta_bias, delta_softplus - ) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) @@ -118,9 +116,7 @@ def selective_scan_fn( last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. """ - return SelectiveScanFn.apply( - u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state - ) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) def selective_scan_ref( @@ -160,13 +156,9 @@ def selective_scan_ref( is_variable_C = C.dim() >= 3 if A.is_complex(): if is_variable_B: - B = torch.view_as_complex( - rearrange(B.float(), "... (L two) -> ... L two", two=2) - ) + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) if is_variable_C: - C = torch.view_as_complex( - rearrange(C.float(), "... (L two) -> ... L two", two=2) - ) + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) else: B = B.float() C = C.float() @@ -237,9 +229,7 @@ def forward( d_state = A.shape[-1] * (1 if not A.is_complex() else 2) if torch.is_autocast_enabled(): x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) - delta_proj_weight = delta_proj_weight.to( - dtype=torch.get_autocast_gpu_dtype() - ) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) if xz.stride(-1) != 1: xz = xz.contiguous() @@ -252,18 +242,12 @@ def forward( z = z.squeeze(2) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( - x, conv1d_weight, conv1d_bias, None, True - ) + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = F.linear( - rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight - ) # (bl d) - delta = rearrange( - delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L - ) + x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None @@ -276,9 +260,7 @@ def forward( # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: - B = rearrange( - B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2 - ).contiguous() + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if B.stride(-1) != 1: B = B.contiguous() @@ -290,9 +272,7 @@ def forward( # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: - C = rearrange( - C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2 - ).contiguous() + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if C.stride(-1) != 1: C = C.contiguous() @@ -304,9 +284,7 @@ def forward( ctx.delta_softplus = delta_softplus # ctx.out_proj_bias_is_None = out_proj_bias is None ctx.checkpoint_lvl = checkpoint_lvl - if ( - checkpoint_lvl >= 1 - ): # Will recompute conv1d_out and delta in the backward pass + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass conv1d_out, delta = None, None ctx.d_inner = d_inner @@ -364,12 +342,8 @@ def backward(ctx, dout): z = z.squeeze(2) if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( - x, conv1d_weight, conv1d_bias, None, True - ) - delta = rearrange( - delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L - ) + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L) # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). 
dxz = torch.empty_like(xz) # (batch, dim, seqlen) @@ -385,17 +359,7 @@ def backward(ctx, dout): if dout.stride(-1) != 1: dout = dout.contiguous() - ( - dconv1d_out, - ddelta, - dA, - dB, - dC, - dD, - ddelta_bias, - dz, - out_z, - ) = selective_scan_cuda.bwd( + (dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z,) = selective_scan_cuda.bwd( conv1d_out, delta, A, @@ -419,9 +383,7 @@ def backward(ctx, dout): if not A.is_complex(): dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() else: - dB = rearrange( - dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2 - ).contiguous() + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d) dB = None @@ -430,9 +392,7 @@ def backward(ctx, dout): if not A.is_complex(): dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() else: - dC = rearrange( - dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2 - ).contiguous() + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None dx_dbl[:, -d_state:] = dC # (bl d) dC = None @@ -440,15 +400,9 @@ def backward(ctx, dout): ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") - dx_proj_weight = torch.einsum( - "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d") - ) - dconv1d_out = torch.addmm( - dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out - ) - dconv1d_out = rearrange( - dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1] - ) + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( @@ -532,9 +486,7 @@ def mamba_inner_ref( delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - x = causal_conv1d_fn( - x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu" - ) + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
@@ -548,9 +500,7 @@ def mamba_inner_ref( if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: - B = rearrange( - B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2 - ).contiguous() + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: @@ -558,10 +508,6 @@ def mamba_inner_ref( if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: - C = rearrange( - C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2 - ).contiguous() - y = selective_scan_fn( - x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True - ) + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/examples/mamba/mamba/trainer.py b/examples/mamba/mamba/trainer.py new file mode 100644 index 00000000..41d12e7b --- /dev/null +++ b/examples/mamba/mamba/trainer.py @@ -0,0 +1,22 @@ +from typing import Type, Union + +from nanotron import logging +from nanotron.config import Config, get_config_from_file +from nanotron.trainer import DistributedTrainer + +try: + import wandb +except ImportError: + wandb = None + +logger = logging.get_logger(__name__) + + +class MambaTrainer(DistributedTrainer): + def __init__( + self, + config_or_config_file: Union[Config, str], + config_class: Type[Config] = Config, + ): + get_config_from_file(config_or_config_file, config_class=config_class) + super().__init__(config_or_config_file, config_class) diff --git a/examples/mamba/train_mamba.py b/examples/mamba/train_mamba.py new file mode 100644 index 00000000..82b19f23 --- /dev/null +++ b/examples/mamba/train_mamba.py @@ -0,0 +1,33 @@ +import argparse +import os +import sys + +from mamba.config_mamba import MambaConfig +from mamba.mamba import MambaForTraining + +from nanotron import logging +from nanotron.trainer import DistributedTrainer + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) + +from run_train import get_dataloader # noqa + +logger = logging.get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + config_file = args.config_file + + # Load trainer and data + trainer = DistributedTrainer(config_file, model_config_class=MambaConfig, model_class=MambaForTraining) + dataloader = get_dataloader(trainer) + + # Train + trainer.train(dataloader) diff --git a/examples/mamba/train_mamba.sh b/examples/mamba/train_mamba.sh index 6f675c3e..08de0ef6 100755 --- a/examples/mamba/train_mamba.sh +++ b/examples/mamba/train_mamba.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Simple script to create a tiny llama model and train it +# Simple script to create a tiny mamba model and train it set -e -x @@ -8,7 +8,7 @@ set -e -x EXAMPLE_PATH=$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P) REPO_PATH=$(dirname $EXAMPLE_PATH) -python $EXAMPLE_PATH/config_mamba.py +python $EXAMPLE_PATH/mamba/config_mamba.py # Setup from environment variables @@ -21,4 +21,4 @@ python -u -m torch.distributed.run \ --rdzv_backend c10d \ --max_restarts 0 \ --tee 3 \ - $REPO_PATH/../run_train.py --config-file $EXAMPLE_PATH/config_mamba.yaml 
\ No newline at end of file + $REPO_PATH/mamba/train_mamba.py --config-file $EXAMPLE_PATH/mamba/config_mamba.yaml diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index e6a07a7a..1d3df0c6 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -7,12 +7,14 @@ class RandomInit: std: float + @dataclass class MambaInit: # mamba_ssm.models.mixer_seq_simple._init_weights initializer_range: float = 0.02 - rescale_prenorm_residual: bool = True, - n_residuals_per_layer: int = 1, # Change to 2 if we have MLP + rescale_prenorm_residual: bool = (True,) + n_residuals_per_layer: int = (1,) # Change to 2 if we have MLP + @dataclass class ExistingCheckpointInit: @@ -20,47 +22,6 @@ class ExistingCheckpointInit: path: Path -@dataclass -class MambaConfig: - """Configuration for a Mamba model - - Be careful on having a coherent typing as we use it to reconstruct the model from yaml - """ - - is_mamba_config: bool = True # We use this help differentiate models in yaml/python conversion - d_model: int = 2560 - num_hidden_layers: int = 64 - vocab_size: int = 50277 - ssm_cfg: Optional[dict] = None - rms_norm: bool = True - fused_add_norm: bool = True - residual_in_fp32: bool = True - pad_vocab_size_multiple: int = 8 - # ==== Custom ====== - dtype: str = "float32" - rms_norm_eps: float = 1e-5 - pad_token_id: Optional[int] = None - - -@dataclass -class MambaFastConfig: - """Configuration for a Mamba model - - Be careful on having a coherent typing as we use it to reconstruct the model from yaml - """ - is_mamba_fast_config: bool = True # We use this help differentiate models in yaml/python conversion - d_model: int = 2560 - num_hidden_layers: int = 64 - vocab_size: int = 50277 - ssm_cfg: Optional[dict] = None - rms_norm: bool = True - fused_add_norm: bool = True - residual_in_fp32: bool = True - pad_vocab_size_multiple: int = 8 - # ==== Custom ====== - dtype: str = "float32" - rms_norm_eps: float = 1e-5 - pad_token_id: Optional[int] = None @dataclass class LlamaConfig: @@ -160,5 +121,4 @@ def n_inner(self): return self.intermediate_size -NanotronConfigs = Union[MambaFastConfig, LlamaConfig, MambaConfig, Starcoder2Config, Any] - +NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] From a52c37fd4901939522f6ad716b84f7b5cde3c6c6 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 1 Mar 2024 13:57:18 +0000 Subject: [PATCH 41/57] update readme --- examples/mamba/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/mamba/README.md b/examples/mamba/README.md index bd06c9c5..d3c0669c 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -9,8 +9,6 @@ Modeling code for Mamba to use with [Nanotron](https://github.com/huggingface/na ## 🚀 Quickstart ```bash -# Generate a config file -python examples/moe/config_mamba.py pip install einops pip install causal-conv1d>=1.1.0,<1.2.0 From 1f54ae2ee78887dced1c642d8bd2ab66d7d98412 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 1 Mar 2024 13:58:27 +0000 Subject: [PATCH 42/57] discard run_generate for now --- run_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_generate.py b/run_generate.py index 99da5c7a..e280a5fd 100644 --- a/run_generate.py +++ b/run_generate.py @@ -60,7 +60,7 @@ def main(): assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist" - config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix(), is_run_generate=True) + config = get_config_from_file((args.ckpt_path 
/ "config.yaml").as_posix()) model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path From 7616ce75e025efbcdaa64e4970dd448ee5fe539e Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 10:26:59 +0000 Subject: [PATCH 43/57] revert dynamic weight decay for now --- src/nanotron/helpers.py | 35 +++----------------- src/nanotron/optim/named_optimizer.py | 47 ++++++++++++--------------- 2 files changed, 26 insertions(+), 56 deletions(-) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 72210fed..feeff76c 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -18,12 +18,7 @@ from nanotron import distributed as dist from nanotron import logging -from nanotron.config import ( - Config, - LRSchedulerArgs, - OptimizerArgs, - ParallelismArgs, -) +from nanotron.config import Config, LRSchedulerArgs, OptimizerArgs, ParallelismArgs from nanotron.distributed import ProcessGroup from nanotron.logging import LogItem, log_rank from nanotron.models.base import NanotronModel @@ -40,9 +35,7 @@ ) from nanotron.optim.zero import ZeroDistributedOptimizer from nanotron.parallel import ParallelContext -from nanotron.parallel.tensor_parallel.nn import ( - TensorParallelLinearMode, -) +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode from nanotron.random import ( RandomStates, get_current_random_state, @@ -166,25 +159,7 @@ def init_optimizer_and_grad_accumulator( # Fix the root_model module_id_to_prefix[id(unwrapped_model)] = "" - # named parameters - named_parameters = { - "decay": [], - "no_decay": [] - } - - # NOTE(fmom): Separate parameters who have weight decay and those who don't - # (based on _no_weight_decay attribute that is set in init_model_randomly of each model) - for name, param in unwrapped_model.named_parameters(): - if param.is_tied: - full_name = param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - else: - full_name = name - - if optimizer_args.weight_decay == 0.0 or (hasattr(param, "_no_weight_decay") and param._no_weight_decay): - named_parameters["no_decay"].append((full_name, param)) - else: - named_parameters["decay"].append((full_name, param)) - + named_parameters = list(unwrapped_model.get_named_params_with_correct_tied()) # Basic optimizer builder def basic_optimizer_builder(named_param_groups): @@ -196,8 +171,8 @@ def basic_optimizer_builder(named_param_groups): lr=optimizer_args.learning_rate_scheduler.learning_rate, eps=optimizer_args.adam_eps, betas=(optimizer_args.adam_beta1, optimizer_args.adam_beta2), - fused=optimizer_args.torch_adam_is_fused - ) + fused=optimizer_args.torch_adam_is_fused, + ), ) optimizer_builder = basic_optimizer_builder diff --git a/src/nanotron/optim/named_optimizer.py b/src/nanotron/optim/named_optimizer.py index db7630cd..23614b05 100644 --- a/src/nanotron/optim/named_optimizer.py +++ b/src/nanotron/optim/named_optimizer.py @@ -12,35 +12,30 @@ def __init__( self, named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]], optimizer_builder: Callable[[Iterable[Dict[str, Any]]], torch.optim.Optimizer], - weight_decay: float = 0.0, - ): - id_to_name_decay, id_to_name_no_decay = {}, {} - - # Don't need to check that param_groups are overlapping since the optimizer will do it for me. 
- # https://github.com/pytorch/pytorch/blob/88b3810c94b45f5982df616e2bc4c471d173f491/torch/optim/optimizer.py#L473 - id_to_name_decay.update( - {id(param): name for name, param in named_params_or_groups["decay"] if id(param) not in id_to_name_decay} - ) - id_to_name_no_decay.update( - {id(param): name for name, param in named_params_or_groups["no_decay"] if id(param) not in id_to_name_no_decay} - ) - - id_to_name = {**id_to_name_decay, **id_to_name_no_decay} + ): + named_param_groups = list(named_params_or_groups) + if len(named_param_groups) == 0 or not isinstance(named_param_groups[0], dict): + named_param_groups = [{"named_params": named_param_groups}] + + id_to_name = {} + params = [] + for named_param_group in named_param_groups: + assert "named_params" in named_param_group + # Don't need to check that param_groups are overlapping since the optimizer will do it for me. + # https://github.com/pytorch/pytorch/blob/88b3810c94b45f5982df616e2bc4c471d173f491/torch/optim/optimizer.py#L473 + id_to_name.update( + {id(param): name for name, param in named_param_group["named_params"] if id(param) not in id_to_name} + ) + params.append( + { + **{k: v for k, v in named_param_group.items() if k != "named_params"}, + "params": [param for _, param in named_param_group["named_params"]], + } + ) + name_to_id = {v: k for k, v in id_to_name.items()} assert len(id_to_name) == len(name_to_id) - #TODO(fmom) Pass weight decay value from config here - params = [ - { - "params": [param for _, param in named_params_or_groups["decay"]], - "weight_decay": weight_decay - }, - { - "params": [param for _, param in named_params_or_groups["no_decay"]], - "weight_decay": 0.0 - } - ] - # Sanity check for param_group in params: _params = param_group["params"] From e944f6366f5ba9186e6d68258e0003686ff03c69 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 1 Mar 2024 18:44:32 +0000 Subject: [PATCH 44/57] unifying init_model_randomly for other models --- examples/doremi/doremi/llama.py | 162 ++++++-------------- examples/moe/llamoe.py | 195 +++++++++--------------- src/nanotron/models/starcoder2.py | 245 ++++++++++-------------------- 3 files changed, 204 insertions(+), 398 deletions(-) diff --git a/examples/doremi/doremi/llama.py b/examples/doremi/doremi/llama.py index aa9f47eb..8ae9202a 100644 --- a/examples/doremi/doremi/llama.py +++ b/examples/doremi/doremi/llama.py @@ -1,6 +1,8 @@ +import math from typing import Dict, Optional, Union import torch +import torch.nn as nn from transformers import LlamaConfig from nanotron import logging @@ -26,15 +28,12 @@ class BaseLLaMa(NanotronModel): @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): """Initialize model parameters randomly. 
- Args: - init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ - scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ - Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ + model = self initialized_parameters = set() # Handle tensor parallelism @@ -42,125 +41,58 @@ def init_model_randomly(self, init_method, scaled_init_method): # Fix the root_model module_id_to_prefix[id(model)] = "" - for module_name, module in model.named_modules(): - if isinstance(module, TensorParallelColumnLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers - if full_param_name in initialized_parameters: - # Already initialized - continue + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") + module_name, param_name = param_name.rsplit(".", 1) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelRowLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" - if full_param_name in initialized_parameters: - # Already initialized - continue + if full_param_name in initialized_parameters: + # Already initialized + continue - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == 
param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") + module = model.get_submodule(module_name) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if isinstance(module, TensorParallelColumnLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, TensorParallelRowLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TritonRMSNorm): - assert {"weight"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() else: - full_param_name = f"{module_name}.weight" - - if full_param_name in initialized_parameters: - # Already initialized - continue + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + raise Exception(f"Parameter {full_param_name} was not intialized") - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) assert initialized_parameters == { param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) diff --git a/examples/moe/llamoe.py b/examples/moe/llamoe.py index fb500cc2..83f23ab6 100644 --- a/examples/moe/llamoe.py +++ b/examples/moe/llamoe.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # 
limitations under the License. """ PyTorch LLaMa MoE model.""" +import math from typing import Dict, Optional, Union import torch @@ -24,6 +25,9 @@ ) from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding from moe import dMoE +from torch import nn +from torch.nn import init + from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs @@ -33,10 +37,7 @@ from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.parallel.pipeline_parallel.block import ( - PipelineBlock, - TensorPointer, -) +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( @@ -47,7 +48,6 @@ ) from nanotron.random import RandomStates from nanotron.utils import checkpoint_method -from torch import nn logger = logging.get_logger(__name__) @@ -834,12 +834,8 @@ def forward( return {"loss": loss} @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): """Initialize model parameters randomly. - Args: - init_method (callable): Used for embedding/position/qkv weight in attention/first layer weight of mlp/ /lm_head/ - scaled_init_method (callable): Used for o weight in attention/second layer weight of mlp/ - Note: Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` """ @@ -850,122 +846,77 @@ def init_model_randomly(self, init_method, scaled_init_method): # Fix the root_model module_id_to_prefix[id(model)] = "" - # TODO @nouamane: initialization for dmoe - for module_name, module in model.named_modules(): + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + if isinstance(module, TensorParallelColumnLinear): - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise 
ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TensorParallelRowLinear): - assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TritonRMSNorm): - assert {"weight"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() else: - full_param_name = f"{module_name}.weight" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - - # assert initialized_parameters == { - # param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) - # if param.is_tied - # else name - # for name, param in model.named_parameters() - # }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ 
in model.named_parameters()} }\n - Got: {initialized_parameters}" - # TODO @nouamane: init dMoE + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, nn.Linear): + fan_in = None + + if "weight" == param_name: + fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight) + init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + elif "bias" == param_name: + bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0 + init.uniform_(module.bias, -bound, bound) + else: + raise ValueError(f"Who the fuck is {param_name}?") + + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + raise Exception(f"Parameter {full_param_name} was not intialized") + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" diff --git a/src/nanotron/models/starcoder2.py b/src/nanotron/models/starcoder2.py index d67361d5..81b5bca6 100644 --- a/src/nanotron/models/starcoder2.py +++ b/src/nanotron/models/starcoder2.py @@ -30,8 +30,9 @@ flash_attn_with_kvcache, ) from torch import nn -from torch.nn import LayerNorm, init +from torch.nn import LayerNorm from torch.nn import functional as F +from torch.nn import init from nanotron import distributed as dist from nanotron.config import ParallelismArgs, Starcoder2Config @@ -44,9 +45,15 @@ from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer -from nanotron.parallel.sharded_parameters import SplitConfig, mark_all_parameters_in_module_as_sharded +from nanotron.parallel.sharded_parameters import ( + SplitConfig, + mark_all_parameters_in_module_as_sharded, +) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.tensor_parallel.functional import column_linear, sharded_cross_entropy +from nanotron.parallel.tensor_parallel.functional import ( + column_linear, + sharded_cross_entropy, +) from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -1458,169 +1465,85 @@ def tie_custom_params(self) -> None: ) @torch.no_grad() - def init_model_randomly(self, init_method, scaled_init_method): + def init_model_randomly(self, config): + """Initialize model parameters randomly. + Note: + Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` + """ model = self - # Set to 0: LayerNorm bias / all bias initialized_parameters = set() # Handle tensor parallelism - with torch.no_grad(): - module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} - # Fix the root_model - module_id_to_prefix[id(model)] = "" - - for module_name, module in model.named_modules(): - if isinstance(module, TensorParallelColumnLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight", "bias"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelRowLinear): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight", "bias"} == {name for name, _ in module.named_parameters()} or {"weight"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - scaled_init_method(param) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, LayerNorm): - assert {"weight", "bias"} == {name for name, _ in module.named_parameters()} - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 - param.fill_(1) - elif "bias" == param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in 
initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, MQAColumnLinears): - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - # TODO @thomasw21: handle the case there's no bias - assert {"q.weight", "q.bias", "kv.weight", "kv.bias"} == { - name for name, _ in module.named_parameters() - } - for param_name, param in module.named_parameters(): - assert isinstance(param, NanotronParameter) - if param.is_tied: - tied_info = param.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.{param_name}" - - if full_param_name in initialized_parameters: - # Already initialized - continue - - if ".weight" in param_name: - init_method(param) - elif ".bias" in param_name: - param.zero_() - else: - raise ValueError(f"Who the fuck is {param_name}?") - - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) - elif isinstance(module, TensorParallelEmbedding): - # TODO @thomasw21: Handle tied embeddings - # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96 - # What it does: - # - instantiate a buffer of the `full size` in fp32 - # - run init method on it - # - shard result to get only a specific shard - # Instead I'm lazy and just going to run init_method, since they are scalar independent - assert {"weight"} == {name for name, _ in module.named_parameters()} - - assert isinstance(module.weight, NanotronParameter) - if module.weight.is_tied: - tied_info = module.weight.get_tied_info() - full_param_name = tied_info.get_full_name_from_module_id_to_prefix( - module_id_to_prefix=module_id_to_prefix - ) - else: - full_param_name = f"{module_name}.weight" + module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + std = config.model.init_method.std + sigma = config.model.init_method.std + num_layers = config.model.model_config.num_hidden_layers + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + + if isinstance(module, TensorParallelColumnLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, TensorParallelRowLinear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) + elif "bias" == param_name: + param.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, LayerNorm): + if "weight" == param_name: + # TODO @thomasw21: Sometimes we actually want 0 + module.weight.fill_(1) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, MQAColumnLinears): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + + elif isinstance(module, TensorParallelEmbedding): + nn.init.normal_(module.weight, mean=0.0, std=std) + else: + raise Exception(f"Parameter {full_param_name} was not intialized") - if full_param_name in initialized_parameters: - # Already initialized - continue + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) - init_method(module.weight) - assert full_param_name not in initialized_parameters - initialized_parameters.add(full_param_name) + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" @staticmethod def get_embeddings_lm_head_tied_names() -> List[str]: From 8224221b131da469bdbbca371b57e13c08353fc4 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Fri, 1 Mar 2024 19:10:19 +0000 Subject: [PATCH 45/57] decouple mamba logic from core --- examples/mamba/mamba/trainer.py | 68 ++++++++++++++++++++++++++++----- examples/mamba/train_mamba.py | 4 +- src/nanotron/config/config.py | 3 +- src/nanotron/trainer.py | 19 --------- 4 files changed, 62 insertions(+), 32 deletions(-) diff --git a/examples/mamba/mamba/trainer.py b/examples/mamba/mamba/trainer.py index 41d12e7b..1a5a5988 100644 --- a/examples/mamba/mamba/trainer.py +++ b/examples/mamba/mamba/trainer.py @@ -1,22 +1,72 @@ -from typing import Type, Union +from typing import Optional, Type, Union + +from torch.nn.parallel import DistributedDataParallel from nanotron import logging -from nanotron.config import Config, 
get_config_from_file +from nanotron.config import Config from nanotron.trainer import DistributedTrainer -try: - import wandb -except ImportError: - wandb = None - logger = logging.get_logger(__name__) +from nanotron import distributed as dist +from nanotron.config import Config, ExistingCheckpointInit, MambaInit +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.parallel.tied_parameters import get_tied_id_to_param +from nanotron.serialize import load_weights, parse_ckpt_path + class MambaTrainer(DistributedTrainer): def __init__( self, config_or_config_file: Union[Config, str], config_class: Type[Config] = Config, + model_config_class: Optional[Type] = None, + model_class: Type[NanotronModel] = None, ): - get_config_from_file(config_or_config_file, config_class=config_class) - super().__init__(config_or_config_file, config_class) + super().__init__(config_or_config_file, config_class, model_config_class, model_class) + + def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: + unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model + + # Load or initialize model weights + self.init_checkpoint_path = parse_ckpt_path(config=self.config) + reloaded_from_checkpoint = False + if self.init_checkpoint_path is not None: + # Reload from a training checkpoint + log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) + self.param_shard_metadata = load_weights( + model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + ) + reloaded_from_checkpoint = True + if not reloaded_from_checkpoint: + log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO) + if isinstance(self.config.model.init_method, ExistingCheckpointInit): + # Initialize model from an pretrained model checkpoint + self.param_shard_metadata = load_weights( + model=unwrapped_model, + parallel_context=self.parallel_context, + root_folder=self.config.model.init_method.path, + ) + elif isinstance(self.config.model.init_method, MambaInit): + + unwrapped_model.init_model_randomly(config=self.config) + # Synchronize parameters so that the model is consistent + # sync all params across dp + for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg) + + # sync tied params across tied groups + for (_, group_ranks), param in sorted( + get_tied_id_to_param( + parameters=model.parameters(), + root_module=unwrapped_model, + ).items(), + key=lambda x: x[0], + ): + group = self.parallel_context.world_ranks_to_pg[group_ranks] + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) + else: + raise ValueError(f"Unsupported {self.config.model.init_method}") + + return model diff --git a/examples/mamba/train_mamba.py b/examples/mamba/train_mamba.py index 82b19f23..50e09d3a 100644 --- a/examples/mamba/train_mamba.py +++ b/examples/mamba/train_mamba.py @@ -4,9 +4,9 @@ from mamba.config_mamba import MambaConfig from mamba.mamba import MambaForTraining +from mamba.trainer import MambaTrainer from nanotron import logging -from nanotron.trainer import DistributedTrainer sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) @@ -26,7 +26,7 @@ def get_args(): config_file = args.config_file # Load trainer and data - trainer = DistributedTrainer(config_file, model_config_class=MambaConfig, model_class=MambaForTraining) + trainer = 
MambaTrainer(config_file, model_config_class=MambaConfig, model_class=MambaForTraining) dataloader = get_dataloader(trainer) # Train diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index ad402699..ec48f8bc 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -13,7 +13,6 @@ from nanotron.config.lighteval_config import LightEvalConfig from nanotron.config.models_config import ( ExistingCheckpointInit, - MambaInit, NanotronConfigs, RandomInit, ) @@ -177,7 +176,7 @@ class ModelArgs: """Arguments related to model architecture""" model_config: NanotronConfigs - init_method: Union[RandomInit, MambaInit, ExistingCheckpointInit] + init_method: Union[RandomInit, ExistingCheckpointInit] dtype: Optional[torch.dtype] = None make_vocab_size_divisible_by: int = 1 ddp_bucket_cap_mb: int = 25 diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 713fcf0e..cdc24092 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -26,7 +26,6 @@ from nanotron.config import ( Config, ExistingCheckpointInit, - MambaInit, ParallelismArgs, RandomInit, get_config_from_file, @@ -578,24 +577,6 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: parallel_context=self.parallel_context, root_folder=self.config.model.init_method.path, ) - elif isinstance(self.config.model.init_method, MambaInit): - - unwrapped_model.init_model_randomly(config=self.config) - # Synchronize parameters so that the model is consistent - # sync all params across dp - for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): - dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg) - - # sync tied params across tied groups - for (_, group_ranks), param in sorted( - get_tied_id_to_param( - parameters=model.parameters(), - root_module=unwrapped_model, - ).items(), - key=lambda x: x[0], - ): - group = self.parallel_context.world_ranks_to_pg[group_ranks] - dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group) elif isinstance(self.config.model.init_method, RandomInit): unwrapped_model.init_model_randomly(config=self.config) From d28c968156b5b5a92b190c0b0f06cfb9c63090c4 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 10:21:40 +0000 Subject: [PATCH 46/57] fix logging + init_method --- examples/mamba/mamba/config_mamba.py | 100 +++++++++++++++++++++++-- examples/mamba/mamba/config_mamba.yaml | 2 +- examples/mamba/mamba/trainer.py | 3 +- src/nanotron/config/models_config.py | 8 -- src/nanotron/trainer.py | 4 +- 5 files changed, 97 insertions(+), 20 deletions(-) diff --git a/examples/mamba/mamba/config_mamba.py b/examples/mamba/mamba/config_mamba.py index be4803b7..1b20a843 100644 --- a/examples/mamba/mamba/config_mamba.py +++ b/examples/mamba/mamba/config_mamba.py @@ -1,27 +1,113 @@ """ Example python script to generate a YAML config file which can be used to run a training with nanotron. 
Refer to "examples" section in the `/README.md` for more information.""" import math import os -from dataclasses import dataclass -from typing import Optional +from dataclasses import dataclass, fields +from typing import Optional, Union + +import torch +import yaml from nanotron.config import ( CheckpointsArgs, - Config, DataArgs, + ExistingCheckpointInit, GeneralArgs, LoggingArgs, LRSchedulerArgs, - MambaInit, - ModelArgs, + NanotronConfigs, OptimizerArgs, ParallelismArgs, PretrainDatasetsArgs, + ProfilerArgs, TokenizerArgs, TokensArgs, + get_config_from_file, ) +from nanotron.config.lighteval_config import LightEvalConfig +from nanotron.config.utils_config import cast_str_to_torch_dtype, serialize from nanotron.logging import human_format +@dataclass +class MambaInit: + # mamba_ssm.models.mixer_seq_simple._init_weights + initializer_range: float = 0.02 + rescale_prenorm_residual: bool = (True,) + n_residuals_per_layer: int = (1,) # Change to 2 if we have MLP + + +@dataclass +class ModelArgs: + """Arguments related to model architecture""" + + model_config: NanotronConfigs + init_method: Union[MambaInit, ExistingCheckpointInit] + dtype: Optional[torch.dtype] = None + make_vocab_size_divisible_by: int = 1 + ddp_bucket_cap_mb: int = 25 + + def __post_init__(self): + if self.dtype is None: + self.dtype = torch.bfloat16 + if isinstance(self.dtype, str): + self.dtype = cast_str_to_torch_dtype(self.dtype) + + # if self.model_config.max_position_embeddings is None: + # self.model_config.max_position_embeddings = 0 + + +@dataclass +class Config: + """Main configuration class""" + + general: GeneralArgs + parallelism: ParallelismArgs + model: ModelArgs + tokenizer: TokenizerArgs + checkpoints: Optional[CheckpointsArgs] = None + logging: Optional[LoggingArgs] = None + tokens: Optional[TokensArgs] = None + optimizer: Optional[OptimizerArgs] = None + data: Optional[DataArgs] = None + profiler: Optional[ProfilerArgs] = None + lighteval: Optional[LightEvalConfig] = None + + @classmethod + def create_empty(cls): + cls_fields = fields(cls) + return cls(**{f.name: None for f in cls_fields}) + + def __post_init__(self): + # Some final sanity checks across separate arguments sections: + if self.profiler is not None and self.profiler.profiler_export_path is not None: + assert self.tokens.train_steps < 10 + + if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None: + self.optimizer.learning_rate_scheduler.lr_decay_steps = ( + self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps + ) + + # # if lighteval, we need tokenizer to be defined + # if self.checkpoints.lighteval is not None: + # assert self.tokenizer.tokenizer_name_or_path is not None + + @property + def global_batch_size(self): + return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp + + def save_as_yaml(self, file_path: str): + config_dict = serialize(self) + file_path = str(file_path) + with open(file_path, "w") as f: + yaml.dump(config_dict, f) + + # Sanity test config can be reloaded + _ = get_config_from_file(file_path, config_class=self.__class__) + + def as_dict(self) -> dict: + return serialize(self) + + @dataclass class MambaConfig: """Configuration for a Mamba model @@ -117,7 +203,7 @@ class MambaConfig: zero_stage=0, weight_decay=0.01, clip_grad=1.0, - accumulate_grad_in_fp32=False, # NOTE(fmom): because we are using PP=TP=DP=1 + accumulate_grad_in_fp32=True, # NOTE(fmom): because we are using PP=TP=DP=1 adam_eps=1e-08, 
adam_beta1=0.9, adam_beta2=0.95, @@ -171,5 +257,3 @@ class MambaConfig: # Save config as YAML file config.save_as_yaml(f"{dir}/config_mamba.yaml") - - # You can now train a model with this config using `/run_train.py` diff --git a/examples/mamba/mamba/config_mamba.yaml b/examples/mamba/mamba/config_mamba.yaml index 67ddd708..300ffdff 100644 --- a/examples/mamba/mamba/config_mamba.yaml +++ b/examples/mamba/mamba/config_mamba.yaml @@ -62,7 +62,7 @@ model: use_fast_path: true vocab_size: 50277 optimizer: - accumulate_grad_in_fp32: false + accumulate_grad_in_fp32: true adam_beta1: 0.9 adam_beta2: 0.95 adam_eps: 1.0e-08 diff --git a/examples/mamba/mamba/trainer.py b/examples/mamba/mamba/trainer.py index 1a5a5988..38bb3990 100644 --- a/examples/mamba/mamba/trainer.py +++ b/examples/mamba/mamba/trainer.py @@ -9,12 +9,13 @@ logger = logging.get_logger(__name__) from nanotron import distributed as dist -from nanotron.config import Config, ExistingCheckpointInit, MambaInit from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.parallel.tied_parameters import get_tied_id_to_param from nanotron.serialize import load_weights, parse_ckpt_path +from .config_mamba import ExistingCheckpointInit, MambaInit + class MambaTrainer(DistributedTrainer): def __init__( diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 1d3df0c6..610acc06 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -8,14 +8,6 @@ class RandomInit: std: float -@dataclass -class MambaInit: - # mamba_ssm.models.mixer_seq_simple._init_weights - initializer_range: float = 0.02 - rescale_prenorm_residual: bool = (True,) - n_residuals_per_layer: int = (1,) # Change to 2 if we have MLP - - @dataclass class ExistingCheckpointInit: """This is used to initialize from an already existing model (without optimizer, lr_scheduler...)""" diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index cdc24092..bfbb3696 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -242,7 +242,7 @@ def post_init(self): def pre_training(self, *args, **kwargs): current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") - if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks and wandb is not None: + if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: wandb.init( project=self.config.general.project, name=f"{current_time}_{self.config.general.project}_{self.config.general.run}", @@ -481,7 +481,7 @@ def train_step_logs( ] ) - if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks and wandb is not None: + if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: wandb.log( {**{log_item.tag: log_item.scalar_value for log_item in log_entries}, "step": self.iteration_step} ) From 576eb560f05938c27c929a21f220aad8827fbc25 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 10:22:43 +0000 Subject: [PATCH 47/57] delete yaml file --- examples/mamba/config_mamba.yaml | 101 ------------------------------- 1 file changed, 101 deletions(-) delete mode 100644 examples/mamba/config_mamba.yaml diff --git a/examples/mamba/config_mamba.yaml b/examples/mamba/config_mamba.yaml deleted file mode 100644 index 85492e25..00000000 --- a/examples/mamba/config_mamba.yaml +++ /dev/null @@ -1,101 +0,0 @@ -checkpoints: - checkpoint_interval: 10 - checkpoints_path: 
/fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/checkpoints - checkpoints_path_is_shared_file_system: false - resume_checkpoint_path: null - save_initial_state: false -data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 24 - hf_dataset_config_name: null - hf_dataset_or_datasets: - roneneldan/TinyStories: 1.0 - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 -general: - benchmark_csv_path: null - consumed_train_samples: null - ignore_sanity_checks: true - project: test - run: mamba - seed: 42 - step: null -lighteval: null -logging: - iteration_step_info_interval: 1 - log_level: info - log_level_replica: info -model: - ddp_bucket_cap_mb: 25 - dtype: bfloat16 - init_method: - initializer_range: 0.02 - n_residuals_per_layer: 1 - rescale_prenorm_residual: true - make_vocab_size_divisible_by: 1 - model_config: - d_model: 1536 - dtype: bfloat16 - fused_add_norm: true - is_mamba_config: true - num_hidden_layers: 48 - pad_token_id: null - pad_vocab_size_multiple: 8 - residual_in_fp32: true - rms_norm: true - rms_norm_eps: 1.0e-05 - ssm_cfg: - bias: false - conv_bias: true - d_conv: 4 - d_state: 16 - dt_init: random - dt_init_floor: 0.0001 - dt_max: 0.1 - dt_min: 0.001 - dt_rank: auto - dt_scale: 1.0 - expand: 2 - use_fast_path: true - vocab_size: 50277 -optimizer: - accumulate_grad_in_fp32: false - adam_beta1: 0.9 - adam_beta2: 0.95 - adam_eps: 1.0e-08 - clip_grad: 1.0 - learning_rate_scheduler: - learning_rate: 0.0003 - lr_decay_starting_step: null - lr_decay_steps: 90 - lr_decay_style: cosine - lr_warmup_steps: 10 - lr_warmup_style: linear - min_decay_lr: 1.0e-05 - torch_adam_is_fused: true - weight_decay: 0.01 - zero_stage: 0 -parallelism: - dp: 2 - expert_parallel_size: 1 - pp: 2 - pp_engine: 1f1b - tp: 2 - tp_linear_async_communication: false - tp_mode: ALL_REDUCE -profiler: null -tokenizer: - tokenizer_max_length: null - tokenizer_name_or_path: gpt2 - tokenizer_revision: null -tokens: - batch_accumulation_per_replica: 1 - limit_test_batches: 0 - limit_val_batches: 0 - micro_batch_size: 2 - sequence_length: 2048 - train_steps: 100 - val_check_interval: -1 From ea2fa96aeae68829f7cb8e5114d6f87a3aff0028 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 10:37:42 +0000 Subject: [PATCH 48/57] small fix --- examples/mamba/mamba/trainer.py | 2 +- src/nanotron/helpers.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/mamba/mamba/trainer.py b/examples/mamba/mamba/trainer.py index 38bb3990..ff339794 100644 --- a/examples/mamba/mamba/trainer.py +++ b/examples/mamba/mamba/trainer.py @@ -3,7 +3,7 @@ from torch.nn.parallel import DistributedDataParallel from nanotron import logging -from nanotron.config import Config +from .config_mamba import Config from nanotron.trainer import DistributedTrainer logger = logging.get_logger(__name__) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index feeff76c..ad6f2418 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -165,7 +165,6 @@ def init_optimizer_and_grad_accumulator( def basic_optimizer_builder(named_param_groups): return NamedOptimizer( named_params_or_groups=named_param_groups, - weight_decay=optimizer_args.weight_decay, optimizer_builder=lambda param_groups: AdamW( # pylint: disable=E0601 param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, From a8448bb901c6e78279d1b2ec6eb075ee26654aa8 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 10:58:49 +0000 Subject: 
[PATCH 49/57] fix tp assert --- examples/mamba/mamba/mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mamba/mamba/mamba.py b/examples/mamba/mamba/mamba.py index 1ded37ea..a1659dc2 100644 --- a/examples/mamba/mamba/mamba.py +++ b/examples/mamba/mamba/mamba.py @@ -102,7 +102,7 @@ def __init__( self.layer_idx = layer_idx tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - assert tp_mode == TensorParallelLinearMode.REDUCE_SCATTER or not parallel_config.tp_linear_async_communication + assert tp_mode == TensorParallelLinearMode.ALL_REDUCE or parallel_config.tp_linear_async_communication == False "Only ALL_REDUCE and tp_linear_async_communication=False are supported" tp_linear_async_communication = ( From 106d59ce02229eefe8d3fc1b6d9fdd886005d5c0 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 13:16:25 +0000 Subject: [PATCH 50/57] separate config mamba --- examples/mamba/mamba/config.py | 126 +++++++++++++++++ examples/mamba/mamba/config_mamba.yaml | 6 +- ...config_mamba.py => create_config_mamba.py} | 127 ++---------------- examples/mamba/mamba/mamba.py | 4 +- examples/mamba/mamba/trainer.py | 5 +- examples/mamba/train_mamba.py | 2 +- examples/mamba/train_mamba.sh | 4 +- 7 files changed, 146 insertions(+), 128 deletions(-) create mode 100644 examples/mamba/mamba/config.py rename examples/mamba/mamba/{config_mamba.py => create_config_mamba.py} (53%) diff --git a/examples/mamba/mamba/config.py b/examples/mamba/mamba/config.py new file mode 100644 index 00000000..a11d23bc --- /dev/null +++ b/examples/mamba/mamba/config.py @@ -0,0 +1,126 @@ +from dataclasses import dataclass, fields +from typing import Optional, Union + +import torch +import yaml + +from nanotron.config import ( + CheckpointsArgs, + DataArgs, + ExistingCheckpointInit, + GeneralArgs, + LoggingArgs, + LRSchedulerArgs, + PretrainDatasetsArgs, + NanotronConfigs, + OptimizerArgs, + ParallelismArgs, + ProfilerArgs, + TokenizerArgs, + TokensArgs, + get_config_from_file, +) +from nanotron.config.lighteval_config import LightEvalConfig +from nanotron.config.utils_config import cast_str_to_torch_dtype, serialize + + +@dataclass +class MambaInit: + # mamba_ssm.models.mixer_seq_simple._init_weights + initializer_range: float = 0.02 + rescale_prenorm_residual: bool = (True,) + n_residuals_per_layer: int = (1,) # Change to 2 if we have MLP + + +@dataclass +class ModelArgs: + """Arguments related to model architecture""" + + model_config: NanotronConfigs + init_method: Union[MambaInit, ExistingCheckpointInit] + dtype: Optional[torch.dtype] = None + make_vocab_size_divisible_by: int = 1 + ddp_bucket_cap_mb: int = 25 + + def __post_init__(self): + if self.dtype is None: + self.dtype = torch.bfloat16 + if isinstance(self.dtype, str): + self.dtype = cast_str_to_torch_dtype(self.dtype) + + # if self.model_config.max_position_embeddings is None: + # self.model_config.max_position_embeddings = 0 + + +@dataclass +class Config: + """Main configuration class""" + + general: GeneralArgs + parallelism: ParallelismArgs + model: ModelArgs + tokenizer: TokenizerArgs + checkpoints: Optional[CheckpointsArgs] = None + logging: Optional[LoggingArgs] = None + tokens: Optional[TokensArgs] = None + optimizer: Optional[OptimizerArgs] = None + data: Optional[DataArgs] = None + profiler: Optional[ProfilerArgs] = None + lighteval: Optional[LightEvalConfig] = None + + @classmethod + def create_empty(cls): + cls_fields = fields(cls) + return cls(**{f.name: None for 
f in cls_fields}) + + def __post_init__(self): + # Some final sanity checks across separate arguments sections: + if self.profiler is not None and self.profiler.profiler_export_path is not None: + assert self.tokens.train_steps < 10 + + if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None: + self.optimizer.learning_rate_scheduler.lr_decay_steps = ( + self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps + ) + + # # if lighteval, we need tokenizer to be defined + # if self.checkpoints.lighteval is not None: + # assert self.tokenizer.tokenizer_name_or_path is not None + + @property + def global_batch_size(self): + return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp + + def save_as_yaml(self, file_path: str): + config_dict = serialize(self) + file_path = str(file_path) + with open(file_path, "w") as f: + yaml.dump(config_dict, f) + + # Sanity test config can be reloaded + _ = get_config_from_file(file_path, config_class=self.__class__) + + def as_dict(self) -> dict: + return serialize(self) + + +@dataclass +class MambaConfig: + """Configuration for a Mamba model + + Be careful on having a coherent typing as we use it to reconstruct the model from yaml + """ + + is_mamba_config: bool = True # We use this help differentiate models in yaml/python conversion + d_model: int = 2560 + num_hidden_layers: int = 64 + vocab_size: int = 50277 + ssm_cfg: Optional[dict] = None + rms_norm: bool = True + fused_add_norm: bool = True + residual_in_fp32: bool = True + pad_vocab_size_multiple: int = 8 + # ==== Custom ====== + dtype: str = "float32" + rms_norm_eps: float = 1e-5 + pad_token_id: Optional[int] = None diff --git a/examples/mamba/mamba/config_mamba.yaml b/examples/mamba/mamba/config_mamba.yaml index 300ffdff..d896f70f 100644 --- a/examples/mamba/mamba/config_mamba.yaml +++ b/examples/mamba/mamba/config_mamba.yaml @@ -79,11 +79,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 2 + dp: 1 expert_parallel_size: 1 - pp: 2 + pp: 1 pp_engine: 1f1b - tp: 2 + tp: 1 tp_linear_async_communication: false tp_mode: ALL_REDUCE profiler: null diff --git a/examples/mamba/mamba/config_mamba.py b/examples/mamba/mamba/create_config_mamba.py similarity index 53% rename from examples/mamba/mamba/config_mamba.py rename to examples/mamba/mamba/create_config_mamba.py index 1b20a843..f2026375 100644 --- a/examples/mamba/mamba/config_mamba.py +++ b/examples/mamba/mamba/create_config_mamba.py @@ -1,134 +1,25 @@ """ Example python script to generate a YAML config file which can be used to run a training with nanotron. 
Refer to "examples" section in the `/README.md` for more information.""" import math import os -from dataclasses import dataclass, fields -from typing import Optional, Union -import torch -import yaml - -from nanotron.config import ( +from config import ( CheckpointsArgs, + Config, DataArgs, - ExistingCheckpointInit, GeneralArgs, LoggingArgs, LRSchedulerArgs, - NanotronConfigs, + MambaConfig, + MambaInit, + ModelArgs, OptimizerArgs, ParallelismArgs, PretrainDatasetsArgs, - ProfilerArgs, TokenizerArgs, TokensArgs, - get_config_from_file, ) -from nanotron.config.lighteval_config import LightEvalConfig -from nanotron.config.utils_config import cast_str_to_torch_dtype, serialize -from nanotron.logging import human_format - - -@dataclass -class MambaInit: - # mamba_ssm.models.mixer_seq_simple._init_weights - initializer_range: float = 0.02 - rescale_prenorm_residual: bool = (True,) - n_residuals_per_layer: int = (1,) # Change to 2 if we have MLP - - -@dataclass -class ModelArgs: - """Arguments related to model architecture""" - - model_config: NanotronConfigs - init_method: Union[MambaInit, ExistingCheckpointInit] - dtype: Optional[torch.dtype] = None - make_vocab_size_divisible_by: int = 1 - ddp_bucket_cap_mb: int = 25 - - def __post_init__(self): - if self.dtype is None: - self.dtype = torch.bfloat16 - if isinstance(self.dtype, str): - self.dtype = cast_str_to_torch_dtype(self.dtype) - - # if self.model_config.max_position_embeddings is None: - # self.model_config.max_position_embeddings = 0 - - -@dataclass -class Config: - """Main configuration class""" - - general: GeneralArgs - parallelism: ParallelismArgs - model: ModelArgs - tokenizer: TokenizerArgs - checkpoints: Optional[CheckpointsArgs] = None - logging: Optional[LoggingArgs] = None - tokens: Optional[TokensArgs] = None - optimizer: Optional[OptimizerArgs] = None - data: Optional[DataArgs] = None - profiler: Optional[ProfilerArgs] = None - lighteval: Optional[LightEvalConfig] = None - - @classmethod - def create_empty(cls): - cls_fields = fields(cls) - return cls(**{f.name: None for f in cls_fields}) - - def __post_init__(self): - # Some final sanity checks across separate arguments sections: - if self.profiler is not None and self.profiler.profiler_export_path is not None: - assert self.tokens.train_steps < 10 - - if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None: - self.optimizer.learning_rate_scheduler.lr_decay_steps = ( - self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps - ) - - # # if lighteval, we need tokenizer to be defined - # if self.checkpoints.lighteval is not None: - # assert self.tokenizer.tokenizer_name_or_path is not None - - @property - def global_batch_size(self): - return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp - - def save_as_yaml(self, file_path: str): - config_dict = serialize(self) - file_path = str(file_path) - with open(file_path, "w") as f: - yaml.dump(config_dict, f) - - # Sanity test config can be reloaded - _ = get_config_from_file(file_path, config_class=self.__class__) - - def as_dict(self) -> dict: - return serialize(self) - - -@dataclass -class MambaConfig: - """Configuration for a Mamba model - - Be careful on having a coherent typing as we use it to reconstruct the model from yaml - """ - - is_mamba_config: bool = True # We use this help differentiate models in yaml/python conversion - d_model: int = 2560 - num_hidden_layers: int = 64 - vocab_size: int = 50277 - 
ssm_cfg: Optional[dict] = None - rms_norm: bool = True - fused_add_norm: bool = True - residual_in_fp32: bool = True - pad_vocab_size_multiple: int = 8 - # ==== Custom ====== - dtype: str = "float32" - rms_norm_eps: float = 1e-5 - pad_token_id: Optional[int] = None +from nanotron.logging import human_format ssm_cfg_dtype = "bfloat16" ssm_cfg = { @@ -214,9 +105,9 @@ class MambaConfig: ) parallelism = ParallelismArgs( - dp=2, - pp=2, - tp=2, + dp=1, + pp=1, + tp=1, pp_engine="1f1b", tp_mode="ALL_REDUCE", tp_linear_async_communication=False, diff --git a/examples/mamba/mamba/mamba.py b/examples/mamba/mamba/mamba.py index a1659dc2..c458791d 100644 --- a/examples/mamba/mamba/mamba.py +++ b/examples/mamba/mamba/mamba.py @@ -45,7 +45,7 @@ ) from nanotron.random import RandomStates -from .config_mamba import MambaConfig +from .config import MambaConfig from .selective_scan_interface import mamba_inner_fn, selective_scan_fn try: @@ -102,7 +102,7 @@ def __init__( self.layer_idx = layer_idx tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - assert tp_mode == TensorParallelLinearMode.ALL_REDUCE or parallel_config.tp_linear_async_communication == False + assert tp_mode == TensorParallelLinearMode.ALL_REDUCE or parallel_config.tp_linear_async_communication is False "Only ALL_REDUCE and tp_linear_async_communication=False are supported" tp_linear_async_communication = ( diff --git a/examples/mamba/mamba/trainer.py b/examples/mamba/mamba/trainer.py index ff339794..cf07de13 100644 --- a/examples/mamba/mamba/trainer.py +++ b/examples/mamba/mamba/trainer.py @@ -3,9 +3,10 @@ from torch.nn.parallel import DistributedDataParallel from nanotron import logging -from .config_mamba import Config from nanotron.trainer import DistributedTrainer +from .config import Config + logger = logging.get_logger(__name__) from nanotron import distributed as dist @@ -14,7 +15,7 @@ from nanotron.parallel.tied_parameters import get_tied_id_to_param from nanotron.serialize import load_weights, parse_ckpt_path -from .config_mamba import ExistingCheckpointInit, MambaInit +from .config import ExistingCheckpointInit, MambaInit class MambaTrainer(DistributedTrainer): diff --git a/examples/mamba/train_mamba.py b/examples/mamba/train_mamba.py index 50e09d3a..5a665fec 100644 --- a/examples/mamba/train_mamba.py +++ b/examples/mamba/train_mamba.py @@ -2,7 +2,7 @@ import os import sys -from mamba.config_mamba import MambaConfig +from mamba.config import MambaConfig from mamba.mamba import MambaForTraining from mamba.trainer import MambaTrainer diff --git a/examples/mamba/train_mamba.sh b/examples/mamba/train_mamba.sh index 08de0ef6..fa53e922 100755 --- a/examples/mamba/train_mamba.sh +++ b/examples/mamba/train_mamba.sh @@ -8,7 +8,7 @@ set -e -x EXAMPLE_PATH=$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P) REPO_PATH=$(dirname $EXAMPLE_PATH) -python $EXAMPLE_PATH/mamba/config_mamba.py +python $EXAMPLE_PATH/mamba/create_config_mamba.py # Setup from environment variables @@ -16,7 +16,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export FI_PROVIDER="efa" python -u -m torch.distributed.run \ - --nproc_per_node 8 \ + --nproc_per_node 1 \ --nnodes 1 \ --rdzv_backend c10d \ --max_restarts 0 \ From db889b94bebc8c27b38c092c29e015739ace9cdc Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 14:22:16 +0000 Subject: [PATCH 51/57] update requirements readme --- examples/mamba/README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/mamba/README.md 
b/examples/mamba/README.md index d3c0669c..63c0db1e 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -9,10 +9,11 @@ Modeling code for Mamba to use with [Nanotron](https://github.com/huggingface/na ## 🚀 Quickstart ```bash - +pip install torch==2.1.0 pip install einops -pip install causal-conv1d>=1.1.0,<1.2.0 -pip install mamba-ssm +pip install causal-conv1d==1.1.0 +pip install mamba-ssm==1.1.4 +pip install flash-attn==2.5.0 # Run training ./examples/mamba/train_mamba.sh From f4816ee5e781779759b77cf6ee3e3f74df2b0295 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 14:40:51 +0000 Subject: [PATCH 52/57] various fix --- examples/mamba/mamba/config.py | 76 ++------------------- examples/mamba/mamba/create_config_mamba.py | 6 +- examples/mamba/mamba/mamba.py | 10 +-- examples/mamba/mamba/trainer.py | 7 +- examples/mamba/train_mamba.py | 4 +- 5 files changed, 21 insertions(+), 82 deletions(-) diff --git a/examples/mamba/mamba/config.py b/examples/mamba/mamba/config.py index a11d23bc..c7bdc7b9 100644 --- a/examples/mamba/mamba/config.py +++ b/examples/mamba/mamba/config.py @@ -1,27 +1,10 @@ -from dataclasses import dataclass, fields +from dataclasses import dataclass from typing import Optional, Union import torch -import yaml - -from nanotron.config import ( - CheckpointsArgs, - DataArgs, - ExistingCheckpointInit, - GeneralArgs, - LoggingArgs, - LRSchedulerArgs, - PretrainDatasetsArgs, - NanotronConfigs, - OptimizerArgs, - ParallelismArgs, - ProfilerArgs, - TokenizerArgs, - TokensArgs, - get_config_from_file, -) -from nanotron.config.lighteval_config import LightEvalConfig -from nanotron.config.utils_config import cast_str_to_torch_dtype, serialize + +from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs +from nanotron.config.utils_config import cast_str_to_torch_dtype @dataclass @@ -52,60 +35,15 @@ def __post_init__(self): # self.model_config.max_position_embeddings = 0 -@dataclass -class Config: +@dataclass(kw_only=True) # pylint: disable=unexpected-keyword-arg +class MambaConfig(Config): """Main configuration class""" - general: GeneralArgs - parallelism: ParallelismArgs model: ModelArgs - tokenizer: TokenizerArgs - checkpoints: Optional[CheckpointsArgs] = None - logging: Optional[LoggingArgs] = None - tokens: Optional[TokensArgs] = None - optimizer: Optional[OptimizerArgs] = None - data: Optional[DataArgs] = None - profiler: Optional[ProfilerArgs] = None - lighteval: Optional[LightEvalConfig] = None - - @classmethod - def create_empty(cls): - cls_fields = fields(cls) - return cls(**{f.name: None for f in cls_fields}) - - def __post_init__(self): - # Some final sanity checks across separate arguments sections: - if self.profiler is not None and self.profiler.profiler_export_path is not None: - assert self.tokens.train_steps < 10 - - if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None: - self.optimizer.learning_rate_scheduler.lr_decay_steps = ( - self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps - ) - - # # if lighteval, we need tokenizer to be defined - # if self.checkpoints.lighteval is not None: - # assert self.tokenizer.tokenizer_name_or_path is not None - - @property - def global_batch_size(self): - return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp - - def save_as_yaml(self, file_path: str): - config_dict = serialize(self) - file_path = str(file_path) - with open(file_path, "w") as f: - 
yaml.dump(config_dict, f) - - # Sanity test config can be reloaded - _ = get_config_from_file(file_path, config_class=self.__class__) - - def as_dict(self) -> dict: - return serialize(self) @dataclass -class MambaConfig: +class MambaModelConfig: """Configuration for a Mamba model Be careful on having a coherent typing as we use it to reconstruct the model from yaml diff --git a/examples/mamba/mamba/create_config_mamba.py b/examples/mamba/mamba/create_config_mamba.py index f2026375..9103ef8f 100644 --- a/examples/mamba/mamba/create_config_mamba.py +++ b/examples/mamba/mamba/create_config_mamba.py @@ -4,13 +4,13 @@ from config import ( CheckpointsArgs, - Config, DataArgs, GeneralArgs, LoggingArgs, LRSchedulerArgs, MambaConfig, MambaInit, + MambaModelConfig, ModelArgs, OptimizerArgs, ParallelismArgs, @@ -37,7 +37,7 @@ "use_fast_path": True, } # https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json -model_config = MambaConfig( +model_config = MambaModelConfig( d_model=1536, num_hidden_layers=48, vocab_size=50277, @@ -127,7 +127,7 @@ checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints" os.makedirs(checkpoints_path, exist_ok=True) -config = Config( +config = MambaConfig( general=GeneralArgs(project="test", run="mamba", seed=seed), checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), parallelism=parallelism, diff --git a/examples/mamba/mamba/mamba.py b/examples/mamba/mamba/mamba.py index c458791d..a2bb0db8 100644 --- a/examples/mamba/mamba/mamba.py +++ b/examples/mamba/mamba/mamba.py @@ -45,7 +45,7 @@ ) from nanotron.random import RandomStates -from .config import MambaConfig +from .config import MambaModelConfig from .selective_scan_interface import mamba_inner_fn, selective_scan_fn try: @@ -407,7 +407,7 @@ class Embedding(nn.Module, AttachableStore): def __init__( self, tp_pg: dist.ProcessGroup, - config: MambaConfig, + config: MambaModelConfig, parallel_config: Optional[ParallelismArgs], ): super().__init__() @@ -441,7 +441,7 @@ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_ class MambaDecoderLayer(nn.Module): def __init__( self, - config: MambaConfig, + config: MambaModelConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int, @@ -522,7 +522,7 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) class MambaModel(nn.Module): def __init__( self, - config: MambaConfig, + config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, @@ -740,7 +740,7 @@ def forward( class MambaForTraining(NanotronModel): def __init__( self, - config: MambaConfig, + config: MambaModelConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], random_states: Optional[RandomStates] = None, diff --git a/examples/mamba/mamba/trainer.py b/examples/mamba/mamba/trainer.py index cf07de13..d080b7f2 100644 --- a/examples/mamba/mamba/trainer.py +++ b/examples/mamba/mamba/trainer.py @@ -5,7 +5,7 @@ from nanotron import logging from nanotron.trainer import DistributedTrainer -from .config import Config +from .config import MambaConfig logger = logging.get_logger(__name__) @@ -21,11 +21,12 @@ class MambaTrainer(DistributedTrainer): def __init__( self, - config_or_config_file: Union[Config, str], - config_class: Type[Config] = Config, + config_or_config_file: Union[MambaConfig, str], + config_class: Type[MambaConfig] = MambaConfig, 
model_config_class: Optional[Type] = None, model_class: Type[NanotronModel] = None, ): + assert config_class == MambaConfig super().__init__(config_or_config_file, config_class, model_config_class, model_class) def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: diff --git a/examples/mamba/train_mamba.py b/examples/mamba/train_mamba.py index 5a665fec..3df6ead2 100644 --- a/examples/mamba/train_mamba.py +++ b/examples/mamba/train_mamba.py @@ -2,7 +2,7 @@ import os import sys -from mamba.config import MambaConfig +from mamba.config import MambaModelConfig from mamba.mamba import MambaForTraining from mamba.trainer import MambaTrainer @@ -26,7 +26,7 @@ def get_args(): config_file = args.config_file # Load trainer and data - trainer = MambaTrainer(config_file, model_config_class=MambaConfig, model_class=MambaForTraining) + trainer = MambaTrainer(config_file, model_config_class=MambaModelConfig, model_class=MambaForTraining) dataloader = get_dataloader(trainer) # Train From 1098e19041d95ad58e537b5900ef8e832f895d58 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 15:04:47 +0000 Subject: [PATCH 53/57] fix import --- examples/mamba/mamba/create_config_mamba.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/mamba/mamba/create_config_mamba.py b/examples/mamba/mamba/create_config_mamba.py index 9103ef8f..17113fe4 100644 --- a/examples/mamba/mamba/create_config_mamba.py +++ b/examples/mamba/mamba/create_config_mamba.py @@ -2,15 +2,14 @@ import math import os -from config import ( +from config import MambaConfig, MambaInit, MambaModelConfig + +from nanotron.config import ( CheckpointsArgs, DataArgs, GeneralArgs, LoggingArgs, LRSchedulerArgs, - MambaConfig, - MambaInit, - MambaModelConfig, ModelArgs, OptimizerArgs, ParallelismArgs, @@ -18,7 +17,6 @@ TokenizerArgs, TokensArgs, ) - from nanotron.logging import human_format ssm_cfg_dtype = "bfloat16" From ae6af18addbba52f22c942b44a29572941ecde8d Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 15:21:04 +0000 Subject: [PATCH 54/57] change directory level --- examples/mamba/{mamba => }/config.py | 0 examples/mamba/{mamba => }/config_mamba.yaml | 2 +- examples/mamba/{mamba => }/create_config_mamba.py | 0 examples/mamba/{mamba => }/mamba.py | 5 ++--- examples/mamba/{mamba => }/selective_scan_interface.py | 0 examples/mamba/train_mamba.py | 6 +++--- examples/mamba/train_mamba.sh | 4 ++-- examples/mamba/{mamba => }/trainer.py | 5 +---- 8 files changed, 9 insertions(+), 13 deletions(-) rename examples/mamba/{mamba => }/config.py (100%) rename examples/mamba/{mamba => }/config_mamba.yaml (98%) rename examples/mamba/{mamba => }/create_config_mamba.py (100%) rename examples/mamba/{mamba => }/mamba.py (99%) rename examples/mamba/{mamba => }/selective_scan_interface.py (100%) rename examples/mamba/{mamba => }/trainer.py (97%) diff --git a/examples/mamba/mamba/config.py b/examples/mamba/config.py similarity index 100% rename from examples/mamba/mamba/config.py rename to examples/mamba/config.py diff --git a/examples/mamba/mamba/config_mamba.yaml b/examples/mamba/config_mamba.yaml similarity index 98% rename from examples/mamba/mamba/config_mamba.yaml rename to examples/mamba/config_mamba.yaml index d896f70f..dd15c79f 100644 --- a/examples/mamba/mamba/config_mamba.yaml +++ b/examples/mamba/config_mamba.yaml @@ -1,6 +1,6 @@ checkpoints: checkpoint_interval: 10 - checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/mamba/checkpoints + checkpoints_path: 
/fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null save_initial_state: false diff --git a/examples/mamba/mamba/create_config_mamba.py b/examples/mamba/create_config_mamba.py similarity index 100% rename from examples/mamba/mamba/create_config_mamba.py rename to examples/mamba/create_config_mamba.py diff --git a/examples/mamba/mamba/mamba.py b/examples/mamba/mamba.py similarity index 99% rename from examples/mamba/mamba/mamba.py rename to examples/mamba/mamba.py index a2bb0db8..523c94c7 100644 --- a/examples/mamba/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -22,7 +22,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from config import MambaModelConfig from einops import rearrange, repeat +from selective_scan_interface import mamba_inner_fn, selective_scan_fn from torch.nn import init from nanotron import distributed as dist @@ -45,9 +47,6 @@ ) from nanotron.random import RandomStates -from .config import MambaModelConfig -from .selective_scan_interface import mamba_inner_fn, selective_scan_fn - try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: diff --git a/examples/mamba/mamba/selective_scan_interface.py b/examples/mamba/selective_scan_interface.py similarity index 100% rename from examples/mamba/mamba/selective_scan_interface.py rename to examples/mamba/selective_scan_interface.py diff --git a/examples/mamba/train_mamba.py b/examples/mamba/train_mamba.py index 3df6ead2..4d587fcf 100644 --- a/examples/mamba/train_mamba.py +++ b/examples/mamba/train_mamba.py @@ -2,9 +2,9 @@ import os import sys -from mamba.config import MambaModelConfig -from mamba.mamba import MambaForTraining -from mamba.trainer import MambaTrainer +from config import MambaModelConfig +from mamba import MambaForTraining +from trainer import MambaTrainer from nanotron import logging diff --git a/examples/mamba/train_mamba.sh b/examples/mamba/train_mamba.sh index fa53e922..6410d708 100755 --- a/examples/mamba/train_mamba.sh +++ b/examples/mamba/train_mamba.sh @@ -8,7 +8,7 @@ set -e -x EXAMPLE_PATH=$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P) REPO_PATH=$(dirname $EXAMPLE_PATH) -python $EXAMPLE_PATH/mamba/create_config_mamba.py +python $EXAMPLE_PATH/create_config_mamba.py # Setup from environment variables @@ -21,4 +21,4 @@ python -u -m torch.distributed.run \ --rdzv_backend c10d \ --max_restarts 0 \ --tee 3 \ - $REPO_PATH/mamba/train_mamba.py --config-file $EXAMPLE_PATH/mamba/config_mamba.yaml + $REPO_PATH/mamba/train_mamba.py --config-file $EXAMPLE_PATH/config_mamba.yaml diff --git a/examples/mamba/mamba/trainer.py b/examples/mamba/trainer.py similarity index 97% rename from examples/mamba/mamba/trainer.py rename to examples/mamba/trainer.py index d080b7f2..e0c010be 100644 --- a/examples/mamba/mamba/trainer.py +++ b/examples/mamba/trainer.py @@ -1,12 +1,11 @@ from typing import Optional, Type, Union +from config import ExistingCheckpointInit, MambaConfig, MambaInit from torch.nn.parallel import DistributedDataParallel from nanotron import logging from nanotron.trainer import DistributedTrainer -from .config import MambaConfig - logger = logging.get_logger(__name__) from nanotron import distributed as dist @@ -15,8 +14,6 @@ from nanotron.parallel.tied_parameters import get_tied_id_to_param from nanotron.serialize import load_weights, parse_ckpt_path -from .config import ExistingCheckpointInit, MambaInit - class MambaTrainer(DistributedTrainer): def __init__( From 
baba00e046f8aed88639dd379cb36214e50acd34 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 16:33:21 +0000 Subject: [PATCH 55/57] decouple Mamba logic from sync weight --- examples/mamba/README.md | 7 +-- examples/mamba/config_mamba.yaml | 6 +- examples/mamba/create_config_mamba.py | 8 +-- examples/mamba/requirements.txt | 5 ++ examples/mamba/train_mamba.sh | 2 +- examples/mamba/trainer.py | 85 ++++++++++++++++++++++++++- src/nanotron/serialize/main.py | 40 +++++++------ src/nanotron/trainer.py | 9 ++- 8 files changed, 126 insertions(+), 36 deletions(-) create mode 100644 examples/mamba/requirements.txt diff --git a/examples/mamba/README.md b/examples/mamba/README.md index 63c0db1e..5c31d07f 100644 --- a/examples/mamba/README.md +++ b/examples/mamba/README.md @@ -9,12 +9,7 @@ Modeling code for Mamba to use with [Nanotron](https://github.com/huggingface/na ## 🚀 Quickstart ```bash -pip install torch==2.1.0 -pip install einops -pip install causal-conv1d==1.1.0 -pip install mamba-ssm==1.1.4 -pip install flash-attn==2.5.0 - +pip install -r requirements.txt # Run training ./examples/mamba/train_mamba.sh ``` diff --git a/examples/mamba/config_mamba.yaml b/examples/mamba/config_mamba.yaml index dd15c79f..2d79880d 100644 --- a/examples/mamba/config_mamba.yaml +++ b/examples/mamba/config_mamba.yaml @@ -79,11 +79,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 1 + dp: 2 expert_parallel_size: 1 - pp: 1 + pp: 2 pp_engine: 1f1b - tp: 1 + tp: 2 tp_linear_async_communication: false tp_mode: ALL_REDUCE profiler: null diff --git a/examples/mamba/create_config_mamba.py b/examples/mamba/create_config_mamba.py index 17113fe4..40c211f9 100644 --- a/examples/mamba/create_config_mamba.py +++ b/examples/mamba/create_config_mamba.py @@ -103,9 +103,9 @@ ) parallelism = ParallelismArgs( - dp=1, - pp=1, - tp=1, + dp=2, + pp=2, + tp=2, pp_engine="1f1b", tp_mode="ALL_REDUCE", tp_linear_async_communication=False, @@ -126,7 +126,7 @@ os.makedirs(checkpoints_path, exist_ok=True) config = MambaConfig( - general=GeneralArgs(project="test", run="mamba", seed=seed), + general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True), checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), parallelism=parallelism, model=ModelArgs( diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt new file mode 100644 index 00000000..abc3bd62 --- /dev/null +++ b/examples/mamba/requirements.txt @@ -0,0 +1,5 @@ +torch==2.1.0 +einops +causal-conv1d==1.1.0 +mamba-ssm==1.1.4 +flash-attn==2.5.0 diff --git a/examples/mamba/train_mamba.sh b/examples/mamba/train_mamba.sh index 6410d708..36384c8c 100755 --- a/examples/mamba/train_mamba.sh +++ b/examples/mamba/train_mamba.sh @@ -16,7 +16,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export FI_PROVIDER="efa" python -u -m torch.distributed.run \ - --nproc_per_node 1 \ + --nproc_per_node 8 \ --nnodes 1 \ --rdzv_backend c10d \ --max_restarts 0 \ diff --git a/examples/mamba/trainer.py b/examples/mamba/trainer.py index e0c010be..e3dec27a 100644 --- a/examples/mamba/trainer.py +++ b/examples/mamba/trainer.py @@ -9,9 +9,21 @@ logger = logging.get_logger(__name__) from nanotron import distributed as dist +from nanotron.config import ParallelismArgs from nanotron.logging import log_rank from nanotron.models import NanotronModel -from nanotron.parallel.tied_parameters import get_tied_id_to_param +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter 
+from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.parallel.tied_parameters import ( + create_pg_for_tied_weights, + get_tied_id_to_param, + tie_parameters, +) from nanotron.serialize import load_weights, parse_ckpt_path @@ -26,6 +38,77 @@ def __init__( assert config_class == MambaConfig super().__init__(config_or_config_file, config_class, model_config_class, model_class) + def _mark_tied_parameters( + self, + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, + ): + # Tie embeddings + embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names() + if len(embeddings_lm_head_tied_names) > 0: + shared_embeddings = [ + ( + target, + ( + parallel_context.world_rank_matrix[ + dist.get_rank(parallel_context.expert_pg), + get_pp_rank_of(target, module=model), + dist.get_rank(parallel_context.dp_pg), + dist.get_rank(parallel_context.tp_pg), + ], + ), + ) + for target in embeddings_lm_head_tied_names + ] + tie_parameters( + root_module=model, + ties=shared_embeddings, + parallel_context=parallel_context, + reduce_op=dist.ReduceOp.SUM, + ) + + # Tie custom params + model.tie_custom_params() + + # Sync all parameters that have the same name and that are not sharded + assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point" + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + name = f"{module_name}.{param_name}" + + if isinstance(param, NanotronParameter) and (param.is_sharded or param.is_tied): + continue + + if isinstance(module, TensorParallelRowLinear) and "bias" == param_name: + # bias for TensorParallelRowLinear only exists on TP=0 so we don't need to tie it + continue + + shared_weights = [ + ( + name, + # sync across TP group + tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))), + ) + ] + + if ( + parallel_config is None + or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE + or hasattr(model.config.model.model_config, "is_mamba_config") + ): + # We add `reduce_op=None` in order to signal that the weight are synced by design without needing to reduce + # when TP=2 we have LN that is duplicated across TP, so by design it's tied + reduce_op = None + else: + reduce_op = dist.ReduceOp.SUM + + tie_parameters( + root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op + ) + + create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context) + def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index ecae4597..a2a3d4aa 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -1,6 +1,5 @@ from pathlib import Path from typing import Optional -import os import torch from torch import nn @@ -14,9 +13,17 @@ from nanotron.logging import log_rank from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter -from nanotron.sanity_checks import assert_tensor_synced_across_pg, check_optim_state_in_sync +from nanotron.sanity_checks import ( + assert_tensor_synced_across_pg, + check_optim_state_in_sync, +) from nanotron.serialize.metadata import 
CheckpointMetadata, load_meta, save_meta -from nanotron.serialize.optimizer import load_lr_scheduler, load_optimizer, save_lr_scheduler, save_optimizer +from nanotron.serialize.optimizer import ( + load_lr_scheduler, + load_optimizer, + save_lr_scheduler, + save_optimizer, +) from nanotron.serialize.weights import load_weights, save_weights """ @@ -132,13 +139,10 @@ def save( tied_info = tied_param.get_tied_info() group_ranks = tied_info.global_ranks group = parallel_context.world_ranks_to_pg[group_ranks] - - # Conv1d and RMSNorm don't need to be synced for mamba - if not hasattr(config.model.model_config, "is_mamba_config"): - assert_tensor_synced_across_pg( - tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" - ) + assert_tensor_synced_across_pg( + tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" + ) if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): check_optim_state_in_sync(optimizer, parallel_context.dp_pg) @@ -182,16 +186,14 @@ def save( src=get_global_rank(group=group, group_rank=reference_rank), group=group, ) - - - if not hasattr(config.model.model_config, "is_mamba_config"): - torch.testing.assert_close( - tensor, - reference_tensor, - atol=0, - rtol=0, - msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}", - ) + + torch.testing.assert_close( + tensor, + reference_tensor, + atol=0, + rtol=0, + msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}", + ) ### dist.barrier(parallel_context.world_pg) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 81c0732e..3b7534ad 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -633,8 +633,8 @@ def _init_model( module.init_rotary_embeddings() # Mark some parameters as tied - mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) - + self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + # count number of parameters num_params = sum(p.numel() for p in model.parameters()) size_params = sum(p.numel() * p.element_size() for p in model.parameters()) @@ -770,6 +770,11 @@ def save_checkpoint(self) -> Path: return checkpoint_path + def _mark_tied_parameters(self, model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None): + mark_tied_parameters( + model=self.model, parallel_context=self.parallel_context, parallel_config=self.config.parallelism + ) + def mark_tied_parameters( model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None ): From 7b513e8c0542266cc4b6c8b3c146799f53c60c9a Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 16:36:27 +0000 Subject: [PATCH 56/57] fix decorator --- examples/mamba/mamba.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index 523c94c7..c8a222f8 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -700,9 +700,6 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch return model_flops_per_s, hardware_flops_per_s -torch.jit.script - - def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor return (loss * 
label_mask).sum(dtype=dtype) / label_mask.sum() From dbefcc75192faf56b4f9ce3dba7ce25a3046e3c2 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 4 Mar 2024 16:50:29 +0000 Subject: [PATCH 57/57] small fixes --- .pre-commit-config.yaml | 16 ++++++++-------- examples/mamba/mamba.py | 5 ++--- src/nanotron/helpers.py | 1 + src/nanotron/trainer.py | 15 +++++++++------ 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5141302e..523d5ef1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,14 +19,14 @@ repos: args: - --fix - --exit-non-zero-on-fix - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: - - --profile=black - - --skip-glob=wandb/**/* - - --thirdparty=wandb + # - repo: https://github.com/PyCQA/isort + # rev: 5.12.0 + # hooks: + # - id: isort + # args: + # - --profile=black + # - --skip-glob=wandb/**/* + # - --thirdparty=wandb - repo: https://github.com/codespell-project/codespell rev: v2.1.0 hooks: diff --git a/examples/mamba/mamba.py b/examples/mamba/mamba.py index c8a222f8..5065ed53 100644 --- a/examples/mamba/mamba.py +++ b/examples/mamba/mamba.py @@ -24,9 +24,6 @@ import torch.nn.functional as F from config import MambaModelConfig from einops import rearrange, repeat -from selective_scan_interface import mamba_inner_fn, selective_scan_fn -from torch.nn import init - from nanotron import distributed as dist from nanotron import logging from nanotron.config import ParallelismArgs @@ -46,6 +43,8 @@ TensorParallelRowLinear, ) from nanotron.random import RandomStates +from selective_scan_interface import mamba_inner_fn, selective_scan_fn +from torch.nn import init try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index ad6f2418..17a859fe 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -167,6 +167,7 @@ def basic_optimizer_builder(named_param_groups): named_params_or_groups=named_param_groups, optimizer_builder=lambda param_groups: AdamW( # pylint: disable=E0601 param_groups, + weight_decay=optimizer_args.weight_decay, lr=optimizer_args.learning_rate_scheduler.learning_rate, eps=optimizer_args.adam_eps, betas=(optimizer_args.adam_beta1, optimizer_args.adam_beta2), diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 3b7534ad..dc4a44dd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -634,7 +634,7 @@ def _init_model( # Mark some parameters as tied self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) - + # count number of parameters num_params = sum(p.numel() for p in model.parameters()) size_params = sum(p.numel() * p.element_size() for p in model.parameters()) @@ -769,12 +769,15 @@ def save_checkpoint(self) -> Path: return checkpoint_path + def _mark_tied_parameters( + self, + model: NanotronModel, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs] = None, + ): + mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + - def _mark_tied_parameters(self, model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None): - mark_tied_parameters( - model=self.model, parallel_context=self.parallel_context, parallel_config=self.config.parallelism - ) - def mark_tied_parameters( model: NanotronModel, parallel_context: ParallelContext, parallel_config: 
Optional[ParallelismArgs] = None ):
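
Taken together, PATCH 55-57 above replace the Mamba-specific branches that previously lived in `src/nanotron/serialize/main.py` with a trainer-level hook: `DistributedTrainer._mark_tied_parameters` simply delegates to the module-level `mark_tied_parameters`, and `examples/mamba/trainer.py` overrides it to tie the embedding/LM-head pair, mark TP-duplicated weights (e.g. the norm layers) as tied with `reduce_op=None`, and finish with `create_pg_for_tied_weights`. The sketch below only illustrates that override pattern, assuming the nanotron names shown in the diffs above (`DistributedTrainer`, `mark_tied_parameters`, `NanotronModel`, `ParallelContext`, `ParallelismArgs`) keep the signatures they have there; it is not part of the patch series.

```python
# Minimal sketch (not part of the patches): a custom trainer reusing the hook
# added in PATCH 55/57. Assumes the nanotron imports below exist with the
# signatures shown in the diffs above.
from typing import Optional

from nanotron.config import ParallelismArgs
from nanotron.models import NanotronModel
from nanotron.parallel import ParallelContext
from nanotron.trainer import DistributedTrainer, mark_tied_parameters


class MyModelTrainer(DistributedTrainer):
    def _mark_tied_parameters(
        self,
        model: NanotronModel,
        parallel_context: ParallelContext,
        parallel_config: Optional[ParallelismArgs] = None,
    ):
        # Fall back to nanotron's generic tying logic. A model such as Mamba
        # overrides this method instead, tying embeddings/lm_head itself and
        # marking weights duplicated across TP as tied with reduce_op=None,
        # so the checkpoint serializer's sync checks pass without
        # model-specific branches in src/nanotron/serialize/main.py.
        mark_tied_parameters(
            model=model,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
        )
```

With this hook in place, `MambaTrainer` only needs to be constructed with `model_config_class=MambaModelConfig` and `model_class=MambaForTraining` (see `examples/mamba/train_mamba.py`); the tying policy travels with the example instead of leaking into the core serializer.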