feat(mlx_lm): Nemotron (#949)
* feat: Nemotron

https://huggingface.co/nvidia/Minitron-4B-Base

This is basically Llama with partial RoPE and LayerNorm instead of
RMSNorm. Also they add 1 to the LayerNorm weight for some reason.
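
As a quick illustration of that "+1" LayerNorm convention, here is a minimal sketch (not part of the diff; it only assumes mlx and shows that a zero-initialized "1p" gain behaves like the usual unit gain):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 8))
ln = nn.LayerNorm(8)             # standard LayerNorm, gain initialized to 1
w_1p = mx.zeros((8,))            # a "1p" gain stored as an offset from 1
y_std = mx.fast.layer_norm(x, ln.weight, ln.bias, 1e-5)
y_1p = mx.fast.layer_norm(x, w_1p + 1, mx.zeros((8,)), 1e-5)
assert mx.allclose(y_std, y_1p)  # adding 1 recovers the standard unit gain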

* fixup! feat: Nemotron

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
llllvvuu and awni authored Aug 30, 2024
1 parent b1186e2 commit fc93c55
Showing 2 changed files with 228 additions and 0 deletions.
227 changes: 227 additions & 0 deletions llms/mlx_lm/models/nemotron.py
@@ -0,0 +1,227 @@
# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, KVCache, create_attention_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    hidden_act: str
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    partial_rotary_factor: float = 0.5
    rope_theta: float = 10000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    def __post_init__(self):
        if self.rope_scaling:
            if "factor" not in self.rope_scaling:
                raise ValueError("rope_scaling must contain 'factor'")
            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
                "rope_type"
            )
            if rope_type is None:
                raise ValueError(
                    "rope_scaling must contain either 'type' or 'rope_type'"
                )
            if rope_type not in ["linear"]:
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


# Nemotron uses a squared-ReLU activation in the MLP (see MLP below).
@partial(mx.compile, shapeless=True)
def relu_squared(x):
    return nn.relu(x).square()


# LayerNorm with the gain parameterized as (1 + weight); the checkpoint
# stores the gain as an offset from one.
class NemotronLayerNorm1P(nn.LayerNorm):
    def __call__(self, x):
        weight = self.weight + 1 if "weight" in self else None
        bias = self.bias if "bias" in self else None
        return mx.fast.layer_norm(x, weight, bias, self.eps)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
        self.partial_rotary_factor = args.partial_rotary_factor

        self.scale = head_dim**-0.5
        if hasattr(args, "attention_bias"):
            attention_bias = args.attention_bias
        else:
            attention_bias = False

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = 1.0
        if args.rope_scaling and args.rope_scaling["type"] == "linear":
            assert isinstance(args.rope_scaling["factor"], float)
            rope_scale = 1 / args.rope_scaling["factor"]
        # Partial RoPE: only the first partial_rotary_factor * head_dim
        # dimensions of each head are rotated.
        self.rope = nn.RoPE(
            int(self.partial_rotary_factor * self.head_dim),
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        mlp_bias = args.mlp_bias

        # No gate projection: Nemotron uses up_proj -> squared ReLU -> down_proj.
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(relu_squared(self.up_proj(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
        self.post_attention_layernorm = NemotronLayerNorm1P(
            args.hidden_size, eps=args.norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out


class NemotronModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = NemotronModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return (
            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
        )

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads
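
Not part of the diff, but for context: once a Nemotron/Minitron checkpoint has been converted for MLX (for example with mlx_lm.convert), generation follows the usual mlx_lm flow. A minimal sketch, with the repo name and prompt purely illustrative:

from mlx_lm import load, generate

model, tokenizer = load("nvidia/Minitron-4B-Base")  # local path or Hugging Face repo
text = generate(model, tokenizer, prompt="The capital of France is", max_tokens=32)
print(text)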
1 change: 1 addition & 0 deletions llms/mlx_lm/tuner/utils.py
@@ -93,6 +93,7 @@ def to_lora(layer):
"llama",
"phi",
"mixtral",
"nemotron",
"stablelm",
"qwen2",
"qwen2_moe",
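
With "nemotron" added to this list of LoRA-capable model types, a converted Nemotron model can be fine-tuned the same way as the other listed architectures. An illustrative invocation (paths are placeholders):

python -m mlx_lm.lora --model <path-to-converted-model> --train --data <path-to-data>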
