[Draft] Nemotron in Nemo-UX #10138

Merged · 13 commits · Aug 22, 2024
12 changes: 12 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -40,6 +40,12 @@
MixtralConfig8x7B,
MixtralConfig8x22B,
MixtralModel,
Nemotron3Config4B,
Nemotron3Config8B,
Nemotron4Config15B,
Nemotron4Config340B,
NemotronConfig,
NemotronModel,
gpt_data_step,
gpt_forward_step,
)
@@ -57,6 +63,12 @@
"MixtralConfig8x7B",
"MixtralConfig8x22B",
"MixtralModel",
"NemotronModel",
"Nemotron3Config4B",
"Nemotron3Config8B",
"Nemotron4Config15B",
"Nemotron4Config340B",
"NemotronConfig",
"LlamaConfig",
"Llama2Config7B",
"Llama2Config13B",
6 changes: 6 additions & 0 deletions nemo/collections/llm/fn/activation.py
@@ -9,3 +9,9 @@ def gelu_impl(x):

def openai_gelu(x):
return gelu_impl(x)


@torch.jit.script
def squared_relu(x):
"""Squared ReLU activation function."""
return torch.pow(torch.nn.functional.relu(x), 2)
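
For quick reference, a minimal standalone check of the new activation (plain torch, using the same expression as squared_relu above, without the TorchScript wrapper); it just confirms that negative inputs map to 0 and positive inputs to x**2:

import torch
import torch.nn.functional as F

x = torch.randn(8)
y = torch.pow(F.relu(x), 2)  # same expression used by squared_relu above
assert torch.allclose(y, torch.where(x > 0, x * x, torch.zeros_like(x)))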
7 changes: 7 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -30,6 +30,7 @@
)
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel
from nemo.collections.llm.gpt.model.nemotron import Nemotron3Config8B, NemotronConfig, NemotronModel

__all__ = [
"GPTConfig",
@@ -44,6 +45,12 @@
"Llama2Config70B",
"Llama3Config8B",
"Llama3Config70B",
"NemotronConfig",
"Nemotron3Config4B",
"Nemotron3Config8B",
"Nemotron4Config15B",
"Nemotron4Config340B",
"NemotronModel",
"CodeLlamaConfig7B",
"CodeLlamaConfig13B",
"CodeLlamaConfig34B",
314 changes: 314 additions & 0 deletions nemo/collections/llm/gpt/model/nemotron.py
@@ -0,0 +1,314 @@
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional

import torch
from torch import nn

from nemo.collections.llm.fn.activation import squared_relu
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown

if TYPE_CHECKING:
from transformers import NemotronConfig as HFNemotronConfig
from transformers import NemotronForCausalLM

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


@dataclass
class NemotronConfig(GPTConfig):
# configs that are common across model sizes
normalization: str = "LayerNorm"
activation_func: Callable = squared_relu
add_bias_linear: bool = False
seq_length: int = 4096
position_embedding_type: str = "rope"
rotary_percent: float = 0.5
hidden_dropout: float = 0.0
attention_dropout: float = 0.0
layernorm_zero_centered_gamma: bool = True # layernorm1p
init_method_std: float = 0.01
share_embeddings_and_output_weights: bool = False
kv_channels: Optional[int] = None
num_query_groups: Optional[int] = None


@dataclass
class Nemotron3Config4B(NemotronConfig):
num_layers: int = 32
hidden_size: int = 3072
ffn_hidden_size: int = 9216
kv_channels: int = 128
num_query_groups: int = 8
num_attention_heads: int = 24
init_method_std: float = 0.0134


@dataclass
class Nemotron3Config8B(NemotronConfig):
num_layers: int = 32
hidden_size: int = 4096
ffn_hidden_size: int = 16384
num_attention_heads: int = 32


@dataclass
class Nemotron4Config15B(NemotronConfig):
num_layers: int = 32
hidden_size: int = 6144
ffn_hidden_size: int = 24576
num_attention_heads: int = 48
num_query_groups: int = 8
init_method_std: float = 0.0134


@dataclass
class Nemotron4Config340B(NemotronConfig):
num_layers: int = 96
hidden_size: int = 18432
ffn_hidden_size: int = 73728
num_attention_heads: int = 96
num_query_groups: int = 8
init_method_std: float = 0.0063


class NemotronModel(GPTModel):
def __init__(
self,
config: Annotated[Optional[NemotronConfig], Config[NemotronConfig]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional["TokenizerSpec"] = None,
model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
):
super().__init__(config or NemotronConfig(), optim=optim, tokenizer=tokenizer, model_transform=model_transform)


@io.model_importer(NemotronModel, "hf")
class HFNemotronImporter(io.ModelConnector["NemotronForCausalLM", NemotronModel]):
def init(self) -> NemotronModel:
return NemotronModel(self.config, tokenizer=self.tokenizer)

def apply(self, output_path: Path) -> Path:
from transformers import NemotronForCausalLM

source = NemotronForCausalLM.from_pretrained(str(self))
target = self.init()
trainer = self.nemo_setup(target)
self.convert_state(source, target)
self.nemo_save(output_path, trainer)

print(f"Converted Nemotron model to Nemo, model saved to {output_path}")

teardown(trainer, target)
del trainer, target

return output_path

def convert_state(self, source, target):
mapping = {
"model.embed_tokens.weight": "embedding.word_embeddings.weight",
"model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
"model.layers.*.mlp.up_proj.weight": "decoder.layers.*.mlp.linear_fc1.weight",
"model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight",
"model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
"model.layers.*.input_layernorm.bias": "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias",
"model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight",
"model.layers.*.post_attention_layernorm.bias": "decoder.layers.*.mlp.linear_fc1.layer_norm_bias",
"model.norm.weight": "decoder.final_layernorm.weight",
"model.norm.bias": "decoder.final_layernorm.bias",
"lm_head.weight": "output_layer.weight",
}

return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv])

@property
def tokenizer(self) -> "AutoTokenizer":
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

return AutoTokenizer(str(self))

@property
def config(self) -> NemotronConfig:
from transformers import NemotronConfig as HFNemotronConfig

source = HFNemotronConfig.from_pretrained(str(self))

def make_vocab_size_divisible_by(vocab_size):
base = 128
while vocab_size % base != 0:
base //= 2
return base

output = NemotronConfig(
num_layers=source.num_hidden_layers,
hidden_size=source.hidden_size,
ffn_hidden_size=source.intermediate_size,
num_attention_heads=source.num_attention_heads,
init_method_std=source.initializer_range,
seq_length=source.max_position_embeddings,
layernorm_epsilon=source.norm_eps,
num_query_groups=source.num_key_value_heads,
rotary_base=source.rope_theta,
rotary_percent=source.partial_rotary_factor,
make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
share_embeddings_and_output_weights=False,
)

return output


@io.model_exporter(NemotronModel, "hf")
class HFNemotronExporter(io.ModelConnector[NemotronModel, "NemotronForCausalLM"]):
def init(self) -> "NemotronForCausalLM":
from transformers import AutoModelForCausalLM

return AutoModelForCausalLM.from_config(self.config)

def apply(self, output_path: Path) -> Path:
target = self.init()
source, _ = self.nemo_load(str(self))
target = self.convert_state(source, target)

target = target.cpu()
target.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)

return output_path

def convert_state(self, source, target):
mapping = {
"embedding.word_embeddings.weight": "model.embed_tokens.weight",
"decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
"decoder.layers.*.mlp.linear_fc1.weight": "model.layers.*.mlp.up_proj.weight",
"decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
"decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
"decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.layers.*.input_layernorm.bias",
"decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
"decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.layers.*.post_attention_layernorm.bias",
"decoder.final_layernorm.weight": "model.norm.weight",
"decoder.final_layernorm.bias": "model.norm.bias",
"output_layer.weight": "lm_head.weight",
}

return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv])

@property
def tokenizer(self):
return io.load_context(str(self)).model.tokenizer.tokenizer

@property
def config(self) -> "HFNemotronConfig":
source: NemotronConfig = io.load_context(str(self)).model.config

from transformers import NemotronConfig as HFNemotronConfig

return HFNemotronConfig(
num_hidden_layers=source.num_layers,
hidden_size=source.hidden_size,
intermediate_size=source.ffn_hidden_size,
num_attention_heads=source.num_attention_heads,
head_dim=(
source.kv_channels
if source.kv_channels is not None
else source.hidden_size // source.num_attention_heads
),
tie_word_embeddings=source.share_embeddings_and_output_weights,
max_position_embeddings=source.seq_length,
initializer_range=source.init_method_std,
norm_eps=source.layernorm_epsilon,
num_key_value_heads=source.num_query_groups,
rope_theta=source.rotary_base,
partial_rotary_factor=source.rotary_percent,
vocab_size=self.tokenizer.vocab_size,
)


@io.state_transform(
source_key=(
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
),
target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
megatron_config = ctx.target.config

head_num = megatron_config.num_attention_heads
num_query_groups = megatron_config.num_query_groups
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_size = hidden_size // head_num

old_tensor_shape = q.size()
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

q = q.view(*new_q_tensor_shape)
k = k.view(*new_kv_tensor_shape)
v = v.view(*new_kv_tensor_shape)

qkv_weights_l = []
for i in range(num_query_groups):
qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
qkv_weights_l.append(k[i : i + 1, :, :])
qkv_weights_l.append(v[i : i + 1, :, :])
qkv_weights = torch.cat(qkv_weights_l)
assert qkv_weights.ndim == 3, qkv_weights.shape
assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
assert qkv_weights.shape[1] == head_size, qkv_weights.shape
assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

return qkv_weights


@io.state_transform(
source_key="decoder.layers.*.self_attention.linear_qkv.weight",
target_key=(
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
megatron_config = ctx.source.config

head_num = megatron_config.num_attention_heads
num_query_groups = megatron_config.num_query_groups
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_size = hidden_size // head_num
qkv_total_dim = head_num + 2 * num_query_groups

linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
q_slice = torch.cat(
[
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

return q_proj, k_proj, v_proj


__all__ = [
"NemotronConfig",
"Nemotron3Config4B",
"Nemotron3Config8B",
"Nemotron4Config15B",
"Nemotron4Config340B",
"NemotronModel",
]
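
As a usage sketch outside this diff, the new classes follow the same pattern as the other NeMo-UX GPT models; trainer and data-module wiring are omitted here, and the optim/tokenizer arguments are left at their optional defaults from the signature above:

from nemo.collections.llm import Nemotron4Config15B, NemotronModel

config = Nemotron4Config15B()  # squared-ReLU MLP, zero-centered LayerNorm ("layernorm1p"), rotary_percent=0.5
model = NemotronModel(config)  # optim and tokenizer are optional and default to None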
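
To make the QKV packing convention of _import_qkv/_export_qkv easier to inspect, here is a self-contained sketch with toy sizes (plain torch, no NeMo dependencies) that mirrors the same reshaping and verifies that the export slicing inverts the import packing:

import torch

# Toy sizes; real configs are much larger (e.g. Nemotron3 8B: 32 heads, hidden_size 4096).
head_num, num_query_groups, head_size = 4, 2, 2
hidden_size = head_num * head_size
heads_per_group = head_num // num_query_groups

q = torch.randn(head_num * head_size, hidden_size)
k = torch.randn(num_query_groups * head_size, hidden_size)
v = torch.randn(num_query_groups * head_size, hidden_size)

# Pack per query group: heads_per_group Q heads, then one K head, then one V head,
# mirroring _import_qkv above.
qh = q.view(head_num, head_size, hidden_size)
kh = k.view(num_query_groups, head_size, hidden_size)
vh = v.view(num_query_groups, head_size, hidden_size)
packed = torch.cat(
    [
        t
        for g in range(num_query_groups)
        for t in (qh[g * heads_per_group : (g + 1) * heads_per_group], kh[g : g + 1], vh[g : g + 1])
    ]
).reshape(head_size * (head_num + 2 * num_query_groups), hidden_size)

# Unpack with the same index arithmetic as _export_qkv and confirm the round trip.
qkv_total_dim = head_num + 2 * num_query_groups
grouped = packed.reshape(qkv_total_dim, head_size, hidden_size)
q_idx = torch.cat(
    [torch.arange((heads_per_group + 2) * g, (heads_per_group + 2) * g + heads_per_group) for g in range(num_query_groups)]
)
k_idx = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2)
v_idx = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2)
assert torch.equal(grouped[q_idx].reshape(-1, hidden_size), q)
assert torch.equal(grouped[k_idx].reshape(-1, hidden_size), k)
assert torch.equal(grouped[v_idx].reshape(-1, hidden_size), v)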