Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing quantize_base to llama 3.1 #1485

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions torchtune/models/llama3_1/_component_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchtune.modules import (
MultiHeadAttention,
FeedForward,
FrozenNF4Linear,
KVCache,
RMSNorm,
TransformerDecoder,
Expand Down Expand Up @@ -118,17 +119,18 @@ def llama3_1(
output=output_proj,
)

def llama3_mlp(dim: int, hidden_dim: int) -> FeedForward:
def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward:
"""
Build the MLP layer associated with the Llama model.
"""
gate_proj = nn.Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False)
gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False)
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)




# ------------------ LoRA Llama3 ------------------


Expand Down Expand Up @@ -223,7 +225,7 @@ def lora_llama3_1(
use_dora=use_dora,
)
else:
mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base)

layer = TransformerSelfAttentionLayer(
attn=self_attn,
Expand Down Expand Up @@ -328,7 +330,11 @@ def lora_llama3_1_self_attention(
quantize_base=quantize_base,
)
if "q_proj" in lora_modules
else nn.Linear(embed_dim, num_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
)
)
k_proj = (
adapter_cls(
Expand All @@ -340,7 +346,11 @@ def lora_llama3_1_self_attention(
quantize_base=quantize_base,
)
if "k_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
)
v_proj = (
adapter_cls(
Expand All @@ -352,7 +362,11 @@ def lora_llama3_1_self_attention(
quantize_base=quantize_base,
)
if "v_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
)
output_proj = (
adapter_cls(
Expand All @@ -364,7 +378,11 @@ def lora_llama3_1_self_attention(
quantize_base=quantize_base,
)
if "output_proj" in lora_modules
else nn.Linear(embed_dim, embed_dim, bias=False)
else (
nn.Linear(embed_dim, embed_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, embed_dim, bias=False)
)
)
rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
self_attn = MultiHeadAttention(
Expand Down
Loading