From 762b1ba4601e561b9a56ed852c37e8ac70991b9f Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Mon, 12 Aug 2024 12:03:12 -0700 Subject: [PATCH 01/11] add nemotron --- nemo/collections/llm/__init__.py | 3 + nemo/collections/llm/gpt/model/__init__.py | 8 + nemo/collections/llm/gpt/model/nemotron.py | 171 +++++++++++++++++++++ 3 files changed, 182 insertions(+) create mode 100644 nemo/collections/llm/gpt/model/nemotron.py diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 3ef8f6dd7fe4..12d18d6e282d 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -40,6 +40,9 @@ MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel, + NemotronModel, + NemotronConfig, + Nemotron3Config8B, gpt_data_step, gpt_forward_step, ) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index e63c45ca99cd..5cef77a2821e 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -28,6 +28,11 @@ LlamaConfig, LlamaModel, ) +from nemo.collections.llm.gpt.model.nemotron import ( + NemotronConfig, + Nemotron3Config8B, + NemotronModel, +) from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel @@ -44,6 +49,9 @@ "Llama2Config70B", "Llama3Config8B", "Llama3Config70B", + "NemotronConfig", + "Nemotron3Config8B", + "NemotronModel", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py new file mode 100644 index 000000000000..ae67ff3d8ce3 --- /dev/null +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -0,0 +1,171 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown +from nemo.collections.nlp.modules.common.megatron.utils import squared_relu + +if TYPE_CHECKING: + from transformers import NemotronConfig as HFNemotronConfig + from transformers import NemotronForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + + +@dataclass +class NemotronConfig(GPTConfig): + # configs that are common across model sizes + normalization: str = "LayerNorm" + activation_func: Callable = squared_relu + add_bias_linear: bool = False + seq_length: int = 4096 + position_embedding_type: str = "rope" + rotary_percent: float = 0.5 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + layernorm_zero_centered_gamma: bool = True # layernorm1p + init_method_std: float = 0.01 + # apply_query_key_layer_scaling: bool = True + share_embeddings_and_output_weights: bool = False + + +@dataclass +class Nemotron3Config8B(NemotronConfig): + num_layers: int = 32 + hidden_size: int = 4096 + ffn_hidden_size: int = 16384 + num_attention_heads: int = 32 + + +class NemotronModel(GPTModel): + def __init__( + self, + config: Annotated[Optional[NemotronConfig], Config[NemotronConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], 
nn.Module]] = None, + ): + super().__init__(config or NemotronConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + +@io.model_importer(NemotronModel, "hf") +class HFNemotronImporter(io.ModelConnector["NemotronForCausalLM", NemotronModel]): + def init(self) -> NemotronModel: + return NemotronModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import NemotronForCausalLM + + source = NemotronForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Nemotron model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.up_proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv]) + + @property + def tokenizer(self) -> "AutoTokenizer": + # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + # return get_nmt_tokenizer(model_name='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', + # tokenizer_model='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model') + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> NemotronConfig: + from transformers import NemotronConfig as HFNemotronConfig + + source = HFNemotronConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = NemotronConfig( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + seq_length=source.max_position_embeddings, + layernorm_epsilon=source.norm_eps, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + rotary_percent=source.partial_rotary_factor, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + 
old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights From 6beec49fbfd50769f4965aede4d81a3640093f4b Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Wed, 14 Aug 2024 09:47:53 -0700 Subject: [PATCH 02/11] add nemotron exporter. make converted model identical --- nemo/collections/llm/gpt/model/nemotron.py | 111 ++++++++++++++++++++- 1 file changed, 108 insertions(+), 3 deletions(-) diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index ae67ff3d8ce3..bd49fc0eecd6 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -35,6 +35,7 @@ class NemotronConfig(GPTConfig): init_method_std: float = 0.01 # apply_query_key_layer_scaling: bool = True share_embeddings_and_output_weights: bool = False + kv_channels: int = None @dataclass @@ -83,8 +84,11 @@ def convert_state(self, source, target): "model.layers.*.mlp.up_proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.input_layernorm.bias": "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.layers.*.post_attention_layernorm.bias": "decoder.layers.*.mlp.linear_fc1.layer_norm_bias", "model.norm.weight": "decoder.final_layernorm.weight", + "model.norm.bias": "decoder.final_layernorm.bias", "lm_head.weight": "output_layer.weight", } @@ -92,9 +96,9 @@ def convert_state(self, source, target): @property def tokenizer(self) -> "AutoTokenizer": - # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer - # return get_nmt_tokenizer(model_name='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', - # tokenizer_model='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model') + from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + return get_nmt_tokenizer(model_name='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', + tokenizer_model='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model') from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer return AutoTokenizer(str(self)) @@ -128,6 +132,66 @@ def make_vocab_size_divisible_by(vocab_size): return output +@io.model_exporter(NemotronModel, "hf") +class HFNemotronExporter(io.ModelConnector[NemotronModel, "NemotronForCausalLM"]): + def init(self) -> "NemotronForCausalLM": + from transformers import AutoModelForCausalLM + + 
return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc1.weight": "model.layers.*.mlp.up_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.layers.*.input_layernorm.bias", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.layers.*.post_attention_layernorm.bias", + "decoder.final_layernorm.weight": "model.norm.weight", + "decoder.final_layernorm.bias": "model.norm.bias", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv]) + + @property + def tokenizer(self): + return io.load_context(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFNemotronConfig": + source: NemotronConfig = io.load_context(str(self)).model.config + + from transformers import NemotronConfig as HFNemotronConfig + + return HFNemotronConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + head_dim=source.kv_channels if source.kv_channels is not None else source.hidden_size // source.num_attention_heads, + tie_word_embeddings=source.share_embeddings_and_output_weights, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + partial_rotary_factor=source.rotary_percent, + vocab_size=self.tokenizer.vocab_size, + ) @io.state_transform( source_key=( @@ -169,3 +233,44 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) return qkv_weights + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group 
+ 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + +__all__ = [ + "NemotronConfig", + "Nemotron3Config8B", + "NemotronModel", +] From ee32130d078a073a1f6e235dc2cbdf3811333a28 Mon Sep 17 00:00:00 2001 From: suiyoubi Date: Wed, 14 Aug 2024 16:58:41 +0000 Subject: [PATCH 03/11] Apply isort and black reformatting Signed-off-by: suiyoubi --- nemo/collections/llm/__init__.py | 4 ++-- nemo/collections/llm/gpt/model/__init__.py | 6 +---- nemo/collections/llm/gpt/model/nemotron.py | 27 +++++++++++++++------- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 12d18d6e282d..caa1e5ee5473 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -40,9 +40,9 @@ MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel, - NemotronModel, - NemotronConfig, Nemotron3Config8B, + NemotronConfig, + NemotronModel, gpt_data_step, gpt_forward_step, ) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 5cef77a2821e..ec664728046f 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -28,13 +28,9 @@ LlamaConfig, LlamaModel, ) -from nemo.collections.llm.gpt.model.nemotron import ( - NemotronConfig, - Nemotron3Config8B, - NemotronModel, -) from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel +from nemo.collections.llm.gpt.model.nemotron import Nemotron3Config8B, NemotronConfig, NemotronModel __all__ = [ "GPTConfig", diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index bd49fc0eecd6..a713c148538e 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -8,8 +8,8 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config -from nemo.lightning import OptimizerModule, io, teardown from nemo.collections.nlp.modules.common.megatron.utils import squared_relu +from nemo.lightning import OptimizerModule, io, teardown if TYPE_CHECKING: from transformers import NemotronConfig as HFNemotronConfig @@ -19,7 +19,6 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec - @dataclass class NemotronConfig(GPTConfig): # configs that are common across model sizes @@ -31,12 +30,12 @@ class NemotronConfig(GPTConfig): rotary_percent: float = 0.5 hidden_dropout: float = 0.0 attention_dropout: float = 0.0 - layernorm_zero_centered_gamma: bool = True # layernorm1p + layernorm_zero_centered_gamma: bool = True # layernorm1p init_method_std: float = 0.01 # apply_query_key_layer_scaling: bool = True share_embeddings_and_output_weights: bool = False kv_channels: int = None - + @dataclass class Nemotron3Config8B(NemotronConfig): @@ -44,7 +43,7 @@ class Nemotron3Config8B(NemotronConfig): hidden_size: int = 4096 ffn_hidden_size: int = 16384 num_attention_heads: int = 32 - + class NemotronModel(GPTModel): def __init__( @@ -56,6 +55,7 @@ def __init__( ): super().__init__(config or NemotronConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + 
@io.model_importer(NemotronModel, "hf") class HFNemotronImporter(io.ModelConnector["NemotronForCausalLM", NemotronModel]): def init(self) -> NemotronModel: @@ -97,8 +97,11 @@ def convert_state(self, source, target): @property def tokenizer(self) -> "AutoTokenizer": from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer - return get_nmt_tokenizer(model_name='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', - tokenizer_model='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model') + + return get_nmt_tokenizer( + model_name='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', + tokenizer_model='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', + ) from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer return AutoTokenizer(str(self)) @@ -132,6 +135,7 @@ def make_vocab_size_divisible_by(vocab_size): return output + @io.model_exporter(NemotronModel, "hf") class HFNemotronExporter(io.ModelConnector[NemotronModel, "NemotronForCausalLM"]): def init(self) -> "NemotronForCausalLM": @@ -182,7 +186,11 @@ def config(self) -> "HFNemotronConfig": hidden_size=source.hidden_size, intermediate_size=source.ffn_hidden_size, num_attention_heads=source.num_attention_heads, - head_dim=source.kv_channels if source.kv_channels is not None else source.hidden_size // source.num_attention_heads, + head_dim=( + source.kv_channels + if source.kv_channels is not None + else source.hidden_size // source.num_attention_heads + ), tie_word_embeddings=source.share_embeddings_and_output_weights, max_position_embeddings=source.seq_length, initializer_range=source.init_method_std, @@ -193,6 +201,7 @@ def config(self) -> "HFNemotronConfig": vocab_size=self.tokenizer.vocab_size, ) + @io.state_transform( source_key=( "model.layers.*.self_attn.q_proj.weight", @@ -234,6 +243,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): return qkv_weights + @io.state_transform( source_key="decoder.layers.*.self_attention.linear_qkv.weight", target_key=( @@ -269,6 +279,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv): return q_proj, k_proj, v_proj + __all__ = [ "NemotronConfig", "Nemotron3Config8B", From de5d87bbe4fae589bc5114b9d73489386488e26b Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Thu, 15 Aug 2024 07:28:46 -0700 Subject: [PATCH 04/11] add more config --- nemo/collections/llm/__init__.py | 11 +++- nemo/collections/llm/fn/activation.py | 6 +++ nemo/collections/llm/gpt/model/__init__.py | 3 ++ nemo/collections/llm/gpt/model/nemotron.py | 58 ++++++++++++++-------- 4 files changed, 57 insertions(+), 21 deletions(-) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index caa1e5ee5473..32b32c3667af 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -40,9 +40,12 @@ MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel, + NemotronModel, + Nemotron3Config4B, Nemotron3Config8B, + Nemotron4Config15B, + Nemotron4Config340B, NemotronConfig, - NemotronModel, gpt_data_step, gpt_forward_step, ) @@ -60,6 +63,12 @@ "MixtralConfig8x7B", "MixtralConfig8x22B", "MixtralModel", + "NemotronModel", + "Nemotron3Config4B", + "Nemotron3Config8B", + "Nemotron4Config15B", + "Nemotron4Config340B", + "NemotronConfig", "LlamaConfig", "Llama2Config7B", "Llama2Config13B", diff --git a/nemo/collections/llm/fn/activation.py b/nemo/collections/llm/fn/activation.py index 89b5ba93f0f6..fb638ee31f86 100644 --- a/nemo/collections/llm/fn/activation.py +++ b/nemo/collections/llm/fn/activation.py @@ -9,3 +9,9 @@ def gelu_impl(x): 
def openai_gelu(x): return gelu_impl(x) + + +@torch.jit.script +def squared_relu(x): + """Squared ReLU activation function.""" + return torch.pow(torch.nn.functional.relu(x), 2) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index ec664728046f..5c12a0208b63 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -46,7 +46,10 @@ "Llama3Config8B", "Llama3Config70B", "NemotronConfig", + "Nemotron3Config4B", "Nemotron3Config8B", + "Nemotron4Config15B", + "Nemotron4Config340B", "NemotronModel", "CodeLlamaConfig7B", "CodeLlamaConfig13B", diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index a713c148538e..dda047ca45d8 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -3,13 +3,12 @@ from typing import TYPE_CHECKING, Annotated, Callable, Optional import torch -import torch.nn.functional as F from torch import nn from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config -from nemo.collections.nlp.modules.common.megatron.utils import squared_relu from nemo.lightning import OptimizerModule, io, teardown +from nemo.collections.llm.fn.activation import squared_relu if TYPE_CHECKING: from transformers import NemotronConfig as HFNemotronConfig @@ -19,6 +18,7 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + @dataclass class NemotronConfig(GPTConfig): # configs that are common across model sizes @@ -30,11 +30,22 @@ class NemotronConfig(GPTConfig): rotary_percent: float = 0.5 hidden_dropout: float = 0.0 attention_dropout: float = 0.0 - layernorm_zero_centered_gamma: bool = True # layernorm1p + layernorm_zero_centered_gamma: bool = True # layernorm1p init_method_std: float = 0.01 - # apply_query_key_layer_scaling: bool = True share_embeddings_and_output_weights: bool = False kv_channels: int = None + num_query_groups: int = None + + +@dataclass +class Nemotron3Config4B(NemotronConfig): + num_layers: int = 32 + hidden_size: int = 3072 + ffn_hidden_size: int = 9216 + kv_channels: int = 128 + num_query_groups: int = 8 + num_attention_heads: int = 24 + init_method_std: float = 0.0134 @dataclass @@ -44,6 +55,25 @@ class Nemotron3Config8B(NemotronConfig): ffn_hidden_size: int = 16384 num_attention_heads: int = 32 +@dataclass +class Nemotron4Config15B(NemotronConfig): + num_layers: int = 32 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + num_query_groups: int = 8 + init_method_std: float = 0.0134 + + +@dataclass +class Nemotron4Config340B(NemotronConfig): + num_layers: int = 96 + hidden_size: int = 18432 + ffn_hidden_size: int = 73728 + num_attention_heads: int = 96 + num_query_groups: int = 8 + init_method_std: float = 0.0063 + class NemotronModel(GPTModel): def __init__( @@ -55,7 +85,6 @@ def __init__( ): super().__init__(config or NemotronConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) - @io.model_importer(NemotronModel, "hf") class HFNemotronImporter(io.ModelConnector["NemotronForCausalLM", NemotronModel]): def init(self) -> NemotronModel: @@ -96,12 +125,6 @@ def convert_state(self, source, target): @property def tokenizer(self) -> "AutoTokenizer": - from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer - - return get_nmt_tokenizer( - model_name='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', - 
tokenizer_model='/aot/checkpoints/nemotron/nemotron3-8b/tokenizer.model', - ) from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer return AutoTokenizer(str(self)) @@ -135,7 +158,6 @@ def make_vocab_size_divisible_by(vocab_size): return output - @io.model_exporter(NemotronModel, "hf") class HFNemotronExporter(io.ModelConnector[NemotronModel, "NemotronForCausalLM"]): def init(self) -> "NemotronForCausalLM": @@ -186,11 +208,7 @@ def config(self) -> "HFNemotronConfig": hidden_size=source.hidden_size, intermediate_size=source.ffn_hidden_size, num_attention_heads=source.num_attention_heads, - head_dim=( - source.kv_channels - if source.kv_channels is not None - else source.hidden_size // source.num_attention_heads - ), + head_dim=source.kv_channels if source.kv_channels is not None else source.hidden_size // source.num_attention_heads, tie_word_embeddings=source.share_embeddings_and_output_weights, max_position_embeddings=source.seq_length, initializer_range=source.init_method_std, @@ -201,7 +219,6 @@ def config(self) -> "HFNemotronConfig": vocab_size=self.tokenizer.vocab_size, ) - @io.state_transform( source_key=( "model.layers.*.self_attn.q_proj.weight", @@ -243,7 +260,6 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): return qkv_weights - @io.state_transform( source_key="decoder.layers.*.self_attention.linear_qkv.weight", target_key=( @@ -279,9 +295,11 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv): return q_proj, k_proj, v_proj - __all__ = [ "NemotronConfig", + "Nemotron3Config4B", "Nemotron3Config8B", + "Nemotron4Config15B", + "Nemotron4Config340B", "NemotronModel", ] From 6755455a200193cae240027aa2e61ec6ea29b7aa Mon Sep 17 00:00:00 2001 From: suiyoubi Date: Thu, 15 Aug 2024 14:29:38 +0000 Subject: [PATCH 05/11] Apply isort and black reformatting Signed-off-by: suiyoubi --- nemo/collections/llm/__init__.py | 2 +- nemo/collections/llm/gpt/model/nemotron.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 32b32c3667af..4e0c6aae5f04 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -40,12 +40,12 @@ MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel, - NemotronModel, Nemotron3Config4B, Nemotron3Config8B, Nemotron4Config15B, Nemotron4Config340B, NemotronConfig, + NemotronModel, gpt_data_step, gpt_forward_step, ) diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index dda047ca45d8..cac3c8a9d8ef 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -5,10 +5,10 @@ import torch from torch import nn +from nemo.collections.llm.fn.activation import squared_relu from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.collections.llm.fn.activation import squared_relu if TYPE_CHECKING: from transformers import NemotronConfig as HFNemotronConfig @@ -18,7 +18,6 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec - @dataclass class NemotronConfig(GPTConfig): # configs that are common across model sizes @@ -30,7 +29,7 @@ class NemotronConfig(GPTConfig): rotary_percent: float = 0.5 hidden_dropout: float = 0.0 attention_dropout: float = 0.0 - layernorm_zero_centered_gamma: bool = True # layernorm1p + layernorm_zero_centered_gamma: bool = True # layernorm1p 
init_method_std: float = 0.01 share_embeddings_and_output_weights: bool = False kv_channels: int = None @@ -55,6 +54,7 @@ class Nemotron3Config8B(NemotronConfig): ffn_hidden_size: int = 16384 num_attention_heads: int = 32 + @dataclass class Nemotron4Config15B(NemotronConfig): num_layers: int = 32 @@ -85,6 +85,7 @@ def __init__( ): super().__init__(config or NemotronConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + @io.model_importer(NemotronModel, "hf") class HFNemotronImporter(io.ModelConnector["NemotronForCausalLM", NemotronModel]): def init(self) -> NemotronModel: @@ -158,6 +159,7 @@ def make_vocab_size_divisible_by(vocab_size): return output + @io.model_exporter(NemotronModel, "hf") class HFNemotronExporter(io.ModelConnector[NemotronModel, "NemotronForCausalLM"]): def init(self) -> "NemotronForCausalLM": @@ -208,7 +210,11 @@ def config(self) -> "HFNemotronConfig": hidden_size=source.hidden_size, intermediate_size=source.ffn_hidden_size, num_attention_heads=source.num_attention_heads, - head_dim=source.kv_channels if source.kv_channels is not None else source.hidden_size // source.num_attention_heads, + head_dim=( + source.kv_channels + if source.kv_channels is not None + else source.hidden_size // source.num_attention_heads + ), tie_word_embeddings=source.share_embeddings_and_output_weights, max_position_embeddings=source.seq_length, initializer_range=source.init_method_std, @@ -219,6 +225,7 @@ def config(self) -> "HFNemotronConfig": vocab_size=self.tokenizer.vocab_size, ) + @io.state_transform( source_key=( "model.layers.*.self_attn.q_proj.weight", @@ -260,6 +267,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): return qkv_weights + @io.state_transform( source_key="decoder.layers.*.self_attention.linear_qkv.weight", target_key=( @@ -295,6 +303,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv): return q_proj, k_proj, v_proj + __all__ = [ "NemotronConfig", "Nemotron3Config4B", From f8362cb0fb558ba20f7c6481d7f57c103551243c Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Thu, 15 Aug 2024 09:04:00 -0700 Subject: [PATCH 06/11] add config --- nemo/collections/llm/gpt/model/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 5c12a0208b63..d83de4f62100 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -30,7 +30,14 @@ ) from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel -from nemo.collections.llm.gpt.model.nemotron import Nemotron3Config8B, NemotronConfig, NemotronModel +from nemo.collections.llm.gpt.model.nemotron import ( + Nemotron3Config4B, + Nemotron3Config8B, + Nemotron4Config15B, + Nemotron4Config340B, + NemotronConfig, + NemotronModel +) __all__ = [ "GPTConfig", From c3826106f37810bb798e9b099f38940c27f30a09 Mon Sep 17 00:00:00 2001 From: suiyoubi Date: Thu, 15 Aug 2024 16:04:51 +0000 Subject: [PATCH 07/11] Apply isort and black reformatting Signed-off-by: suiyoubi --- nemo/collections/llm/gpt/model/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index d83de4f62100..74004cb210d4 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -35,8 +35,8 @@ Nemotron3Config8B, 
Nemotron4Config15B, Nemotron4Config340B, - NemotronConfig, - NemotronModel + NemotronConfig, + NemotronModel, ) __all__ = [ From 0037ec8cabb77ed219273c234c5844585a03470c Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Thu, 15 Aug 2024 10:40:57 -0700 Subject: [PATCH 08/11] import refactor --- nemo/collections/llm/gpt/model/nemotron.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index cac3c8a9d8ef..2f33d4ec0c95 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -10,11 +10,9 @@ from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown +from transformers import NemotronForCausalLM, NemotronConfig as HFNemotronConfig +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer if TYPE_CHECKING: - from transformers import NemotronConfig as HFNemotronConfig - from transformers import NemotronForCausalLM - - from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @@ -92,8 +90,6 @@ def init(self) -> NemotronModel: return NemotronModel(self.config, tokenizer=self.tokenizer) def apply(self, output_path: Path) -> Path: - from transformers import NemotronForCausalLM - source = NemotronForCausalLM.from_pretrained(str(self)) target = self.init() trainer = self.nemo_setup(target) @@ -126,14 +122,10 @@ def convert_state(self, source, target): @property def tokenizer(self) -> "AutoTokenizer": - from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer - return AutoTokenizer(str(self)) @property def config(self) -> NemotronConfig: - from transformers import NemotronConfig as HFNemotronConfig - source = HFNemotronConfig.from_pretrained(str(self)) def make_vocab_size_divisible_by(vocab_size): @@ -163,8 +155,6 @@ def make_vocab_size_divisible_by(vocab_size): @io.model_exporter(NemotronModel, "hf") class HFNemotronExporter(io.ModelConnector[NemotronModel, "NemotronForCausalLM"]): def init(self) -> "NemotronForCausalLM": - from transformers import AutoModelForCausalLM - return AutoModelForCausalLM.from_config(self.config) def apply(self, output_path: Path) -> Path: @@ -203,8 +193,6 @@ def tokenizer(self): def config(self) -> "HFNemotronConfig": source: NemotronConfig = io.load_context(str(self)).model.config - from transformers import NemotronConfig as HFNemotronConfig - return HFNemotronConfig( num_hidden_layers=source.num_layers, hidden_size=source.hidden_size, From 3ffa9129194a334169fbf96cb0e5a3c3a9ddbbe7 Mon Sep 17 00:00:00 2001 From: suiyoubi Date: Thu, 15 Aug 2024 17:41:48 +0000 Subject: [PATCH 09/11] Apply isort and black reformatting Signed-off-by: suiyoubi --- nemo/collections/llm/gpt/model/nemotron.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index 2f33d4ec0c95..be305f96929d 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -4,14 +4,15 @@ import torch from torch import nn +from transformers import NemotronConfig as HFNemotronConfig +from transformers import NemotronForCausalLM +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.llm.fn.activation import squared_relu from 
nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from transformers import NemotronForCausalLM, NemotronConfig as HFNemotronConfig -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer if TYPE_CHECKING: from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec From fd64baf053d2035f96c5185468fa1808f7a1b227 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Thu, 15 Aug 2024 10:50:13 -0700 Subject: [PATCH 10/11] refactor config --- nemo/collections/llm/gpt/model/nemotron.py | 46 ++++++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index be305f96929d..4342faed60d9 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -22,55 +22,75 @@ class NemotronConfig(GPTConfig): # configs that are common across model sizes normalization: str = "LayerNorm" activation_func: Callable = squared_relu - add_bias_linear: bool = False - seq_length: int = 4096 position_embedding_type: str = "rope" - rotary_percent: float = 0.5 + share_embeddings_and_output_weights: bool = False + add_bias_linear: bool = False + hidden_dropout: float = 0.0 attention_dropout: float = 0.0 - layernorm_zero_centered_gamma: bool = True # layernorm1p - init_method_std: float = 0.01 - share_embeddings_and_output_weights: bool = False - kv_channels: int = None - num_query_groups: int = None + apply_query_key_layer_scaling: bool = True + rotary_percent: float = 0.5 + masked_softmax_fusion: bool = True + persist_layer_norm: bool = True + bias_dropout_add_fusion: bool = False + layernorm_zero_centered_gamma: bool = True + + # Nemotron3Config4B as default configs + num_layers: int = 32 + seq_length: int = 4096 + hidden_size: int = 3072 + ffn_hidden_size: int = 9216 + num_attention_heads: int = 24 + num_query_groups: Optional[int] = 8 + kv_channels: Optional[int] = 128 + init_method_std: float = 0.0134 @dataclass class Nemotron3Config4B(NemotronConfig): num_layers: int = 32 + seq_length: int = 4096 hidden_size: int = 3072 ffn_hidden_size: int = 9216 - kv_channels: int = 128 - num_query_groups: int = 8 num_attention_heads: int = 24 + num_query_groups: int = 8 + kv_channels: Optional[int] = 128 init_method_std: float = 0.0134 @dataclass class Nemotron3Config8B(NemotronConfig): num_layers: int = 32 + seq_length: int = 4096 hidden_size: int = 4096 ffn_hidden_size: int = 16384 num_attention_heads: int = 32 + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + init_method_std: float = 0.010 @dataclass class Nemotron4Config15B(NemotronConfig): num_layers: int = 32 + seq_length: int = 4096 hidden_size: int = 6144 ffn_hidden_size: int = 24576 num_attention_heads: int = 48 - num_query_groups: int = 8 + num_query_groups: Optional[int] = 8 + kv_channels: Optional[int] = None init_method_std: float = 0.0134 @dataclass class Nemotron4Config340B(NemotronConfig): num_layers: int = 96 + seq_length: int = 4096 hidden_size: int = 18432 ffn_hidden_size: int = 73728 num_attention_heads: int = 96 - num_query_groups: int = 8 + num_query_groups: Optional[int] = 8 + kv_channels: Optional[int] = None init_method_std: float = 0.0063 @@ -156,7 +176,7 @@ def make_vocab_size_divisible_by(vocab_size): @io.model_exporter(NemotronModel, "hf") class HFNemotronExporter(io.ModelConnector[NemotronModel, 
"NemotronForCausalLM"]): def init(self) -> "NemotronForCausalLM": - return AutoModelForCausalLM.from_config(self.config) + return NemotronForCausalLM.from_config(self.config) def apply(self, output_path: Path) -> Path: target = self.init() From fde4d227cb557974c8ed36043473b17c4b812769 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Thu, 15 Aug 2024 11:02:11 -0700 Subject: [PATCH 11/11] add 22B config --- nemo/collections/llm/__init__.py | 2 ++ nemo/collections/llm/gpt/model/__init__.py | 2 ++ nemo/collections/llm/gpt/model/nemotron.py | 13 +++++++++++++ 3 files changed, 17 insertions(+) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 4e0c6aae5f04..7edd2e1204c3 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -43,6 +43,7 @@ Nemotron3Config4B, Nemotron3Config8B, Nemotron4Config15B, + Nemotron4Config22B, Nemotron4Config340B, NemotronConfig, NemotronModel, @@ -67,6 +68,7 @@ "Nemotron3Config4B", "Nemotron3Config8B", "Nemotron4Config15B", + "Nemotron4Config22B", "Nemotron4Config340B", "NemotronConfig", "LlamaConfig", diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 74004cb210d4..562b27725a60 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -34,6 +34,7 @@ Nemotron3Config4B, Nemotron3Config8B, Nemotron4Config15B, + Nemotron4Config22B, Nemotron4Config340B, NemotronConfig, NemotronModel, @@ -56,6 +57,7 @@ "Nemotron3Config4B", "Nemotron3Config8B", "Nemotron4Config15B", + "Nemotron4Config22B", "Nemotron4Config340B", "NemotronModel", "CodeLlamaConfig7B", diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index 4342faed60d9..dd659f7eedf7 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -82,6 +82,18 @@ class Nemotron4Config15B(NemotronConfig): init_method_std: float = 0.0134 +@dataclass +class Nemotron4Config22B(NemotronConfig): + num_layers: int = 40 + seq_length: int = 4096 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + num_query_groups: Optional[int] = None + kv_channels: Optional[int] = None + init_method_std: float = 0.008 + + @dataclass class Nemotron4Config340B(NemotronConfig): num_layers: int = 96 @@ -318,6 +330,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv): "Nemotron3Config4B", "Nemotron3Config8B", "Nemotron4Config15B", + "Nemotron4Config22B", "Nemotron4Config340B", "NemotronModel", ]