Llama3.1 70b Prefill - MLP and Attention (#11724)
* #0: Llama3.1 MLP prefill

* #0: Llama3.1 Attention implementation

* #0: Llama3.1 galaxy - Update kv cache for specific user_id

* #0: Address comments and fix PCC
djordje-tt authored Aug 23, 2024
1 parent e3ee890 commit 384d6c1
Showing 5 changed files with 368 additions and 17 deletions.
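
For context on the third commit bullet ("Update kv cache for specific user_id"): during prefill, the KV cache is written only at the slot of the user being prefilled rather than for the whole batch. A minimal torch-only sketch of the idea (shapes and names here are assumptions for illustration; the actual implementation uses ttnn and lives in model files that are not all shown below):

import torch

def fill_kv_cache_for_user(k_cache, v_cache, k_new, v_new, user_id):
    # Assumed cache layout: [max_batch, n_kv_heads, max_seq_len, head_dim]
    # k_new / v_new: [1, n_kv_heads, seq_len, head_dim] produced by a prefill pass
    seq_len = k_new.shape[2]
    k_cache[user_id, :, :seq_len, :] = k_new[0]
    v_cache[user_id, :, :seq_len, :] = v_new[0]
    return k_cache, v_cache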
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tt/model_config.py
@@ -28,7 +28,7 @@ def get_model_config(
llm_mode = "decode" if seq_len == 1 else "prefill"
assert num_devices == 8
assert batch in (1, 16, 32)
assert seq_len in (1, 128, 2048, 8192)
assert seq_len in (1, 128, 256, 2048, 8192)

# Supported values, TODO update for larger TT chips
if max_context_len > 4096:
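The practical effect of this change: 256 joins the accepted prefill lengths, and any seq_len other than 1 selects prefill mode via the llm_mode line above. For example:

seq_len = 256
llm_mode = "decode" if seq_len == 1 else "prefill"  # -> "prefill"
assert seq_len in (1, 128, 256, 2048, 8192)         # passes after this change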
105 changes: 92 additions & 13 deletions models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py
@@ -6,14 +6,12 @@
from loguru import logger
import torch
import ttnn
from ttnn import ReplicateTensorToMesh, ListMeshToTensor, ConcatMeshToTensor
from ttnn import ReplicateTensorToMesh
import gc

from models.demos.t3000.llama2_70b.reference.llama.llama import Llama
from models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy
from models.demos.t3000.llama2_70b.reference.llama.llama.model import precompute_freqs_cis

from models.utility_functions import skip_for_grayskull
from models.demos.t3000.llama2_70b.tt.llama_common import (
setup_llama_env,
check_device_mesh,
@@ -33,18 +31,37 @@
get_rot_transformation_mat,
should_skip_model_load,
check_kv_cache,
)

from models.utility_functions import skip_for_grayskull
from models.demos.t3000.llama2_70b.tt.llama_common import (
setup_llama_env,
check_device_mesh,
extract_pcc_from_log,
generate_rot_emb,
get_rotation_mat,
MAX_SEQ_LEN,
MAX_SEQ_LEN_LLAMA3,
BASE_URL,
UNIT_TEST_N_LAYER,
UNIT_TEST_LAYER_NUM,
UNIT_TEST_START_POS,
UNIT_TEST_GENERATION_LENGTH,
comp_pcc,
get_rot_transformation_mat,
should_skip_model_load,
check_kv_cache,
num_to_corerange,
ConcatMesh2DToTensor,
ShardTensor2dMesh,
)
from models.utility_functions import nearest_32


class PytorchLlamaAttentionModel(torch.nn.Module):
def __init__(self, hf_reference_model, layer_num):
def __init__(self, hf_reference_model, layer_num, rope_theta):
super().__init__()
self.attention = hf_reference_model.layers[layer_num].attention

self.rope_theta = rope_theta
# Disable dropout
self.attention.eval()
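
rope_theta is the RoPE base frequency that the reference wrapper now receives explicitly instead of relying on a default. As a sketch of where it enters the math (mirroring what precompute_freqs_cis does internally; illustration only, not the repo's code):

import torch

def rope_inv_freq(head_dim: int, rope_theta: float) -> torch.Tensor:
    # One inverse frequency per pair of head dimensions: theta_i = rope_theta ** (-2i / head_dim)
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

A larger base slows the per-position rotation of each dimension pair, which is one ingredient in how the Llama 3.x family handles longer contexts.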

@@ -76,7 +93,7 @@ def prepare_inputs_prefill(self, x, start_pos):
"""
batch = x.size(0)
seq_len = x.size(1)
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2)
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta)
freqs_cis = freqs_cis[start_pos : start_pos + seq_len]

attn_mask = torch.full((seq_len, seq_len), float("-inf"))
@@ -103,7 +120,7 @@ def forward(self, x, start_pos, freqs_cis, mask):
return result


def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos):
def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_theta):
assert len(x.size()) == 3
batch, seq_len, _ = x.shape

@@ -134,7 +151,7 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos):

batch_size_per_group = llama_attention_model.batch_size_per_device_group

rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2)
rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta)
rot_mat = get_rotation_mat(rot_emb, start_pos, seq_len, batch=batch_size_per_group)
assert rot_mat.size() == (
1,
@@ -167,6 +184,66 @@ def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos):
)

attn_masks = None

elif llama_attention_model.model_config["LLM_MODE"] == "prefill":
assert (
seq_len % 256 == 0 and seq_len > 0 and seq_len <= 8192
), "Prefill mode only supports seqlen as a multiple of 256 up to 8k"
assert batch == 1, "prefill mode only supports batch size 1"
x = x.unsqueeze(0)
assert x.shape == (1, batch, seq_len, llama_attention_model.hidden_size)
xs = ttnn.as_tensor(
x,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
device=llama_attention_model.device_mesh,
mesh_mapper=ShardTensor2dMesh(
llama_attention_model.device_mesh, dims=(3, None), cluster_shape=llama_attention_model.cluster_shape
),
)

cos, sin = precompute_freqs(
llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta, use_scaled=True
)
cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin)
assert cos_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim)
assert sin_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim)

cos_gathereds = ttnn.as_tensor(
cos_gathered,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"),
memory_config=ttnn.DRAM_MEMORY_CONFIG,
device=llama_attention_model.device_mesh,
mesh_mapper=ReplicateTensorToMesh(llama_attention_model.device_mesh),
)
sin_gathereds = ttnn.as_tensor(
sin_gathered,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"),
memory_config=ttnn.DRAM_MEMORY_CONFIG,
device=llama_attention_model.device_mesh,
mesh_mapper=ReplicateTensorToMesh(llama_attention_model.device_mesh),
)

rot_mats = [cos_gathereds, sin_gathereds]

attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min)
attn_mask = torch.triu(attn_mask, diagonal=1)
attn_mask = attn_mask.expand(1, batch, -1, -1)
attn_masks = ttnn.as_tensor(
attn_mask,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"),
mesh_mapper=ReplicateTensorToMesh(llama_attention_model.device_mesh),
memory_config=ttnn.DRAM_MEMORY_CONFIG,
device=llama_attention_model.device_mesh,
)

return (
xs,
start_pos,
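
Two things worth noting in the prefill branch above: precompute_freqs is called with use_scaled=True, which appears to select the Llama 3.1 long-context frequency scaling, and the attention mask is the standard causal mask (large negative values strictly above the diagonal, zeros elsewhere). A torch-only restatement of the mask logic and how such a mask is typically consumed (illustration only, not the ttnn kernel):

import torch

def causal_mask(batch: int, seq_len: int) -> torch.Tensor:
    mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min)
    mask = torch.triu(mask, diagonal=1)            # block strictly-future positions
    return mask.expand(1, batch, seq_len, seq_len)

# Typical use inside attention:
# scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
# probs = (scores + causal_mask(batch, seq_len)).softmax(dim=-1)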
@@ -206,7 +283,9 @@ def run_test_LlamaAttention_inference(
configuration = hugging_face_reference_model.params

# PyTorch model --------------------------------------------------------------------
pytorch_LlamaAttention_model = PytorchLlamaAttentionModel(hugging_face_reference_model, UNIT_TEST_LAYER_NUM)
pytorch_LlamaAttention_model = PytorchLlamaAttentionModel(
hugging_face_reference_model, UNIT_TEST_LAYER_NUM, configuration.rope_theta
)
# TT model -------------------------------------------------------------------------
transformation_mat_torch = get_rot_transformation_mat(32) # 32 for tile size

@@ -266,7 +345,7 @@ def run_test_LlamaAttention_inference(

# TT hardware execution -------------------------------------------------------------
attention_input, start_pos, rot_mat, attn_mask = tt_llama_attention_prepare_inputs(
tt_LlamaAttention_model, tt_input, start_pos
tt_LlamaAttention_model, tt_input, start_pos, configuration.rope_theta
)
tt_out = tt_LlamaAttention_model(
attention_input,
@@ -350,8 +429,8 @@ def run_test_LlamaAttention_inference(
)
@pytest.mark.parametrize(
"batch, seq_len, pcc",
[(32, 1, 0.9995)],
ids=["decode"],
[(32, 1, 0.9995), (1, 256, 0.999)],
ids=["decode", "prefill"],
)
@pytest.mark.parametrize(
"max_batch_size, max_context_len",
18 changes: 15 additions & 3 deletions models/demos/tg/llama3_70b/tests/test_llama_mlp_galaxy.py
@@ -63,6 +63,17 @@ def tt_llama_mlp_prepare_inputs(llama_mlp_model, x):
llama_mlp_model.device_mesh, dims=(3, None), cluster_shape=llama_mlp_model.cluster_shape
),
)
elif llama_mlp_model.model_config["LLM_MODE"] == "prefill":
x_multichip = ttnn.from_torch(
x,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
device=llama_mlp_model.device_mesh,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ShardTensor2dMesh(
llama_mlp_model.device_mesh, dims=(3, None), cluster_shape=llama_mlp_model.cluster_shape
),
)

return x_multichip
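
As in the attention test, the prefill activations are distributed with ShardTensor2dMesh using dims=(3, None): the hidden dimension (dim 3) is split across one axis of the 2D device mesh and the tensor is replicated along the other axis (ReplicateTensorToMesh, by contrast, puts a full copy on every device). A conceptual torch sketch, assuming mesh columns receive the shards and mesh rows the replicas; in the real code this mapping comes from cluster_shape:

import torch

def shard_2d_on_hidden(x: torch.Tensor, mesh_rows: int, mesh_cols: int):
    # x: [1, batch, seq_len, hidden]; one chunk of the hidden dim per mesh column,
    # replicated down every mesh row.
    shards = torch.chunk(x, mesh_cols, dim=3)
    return [[s.clone() for s in shards] for _ in range(mesh_rows)]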

@@ -103,7 +114,8 @@ def run_test_LlamaMLP_inference(
if model_config["LLM_MODE"] == "decode":
# shape should be (1, seq_len, batch, dim)
pt_inp_normed = pt_inp_normed.unsqueeze(1).permute(2, 1, 0, 3)
else:
else: # prefill
# shape should be (1, batch, seq_len, dim)
pt_inp_normed = pt_inp_normed.unsqueeze(0)

tt_inp = pt_inp_normed.clone()
@@ -154,8 +166,8 @@ def run_test_LlamaMLP_inference(
)
@pytest.mark.parametrize(
"batch, seq_len, pcc",
[(32, 1, 0.9997)],
ids=["decode"],
[(32, 1, 0.9997), (1, 256, 0.9995)],
ids=["decode", "prefill"],
)
@pytest.mark.parametrize(
"max_batch_size, max_context_len",
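With both decode and prefill parametrized, the new cases can presumably be selected by their test ids, assuming the usual tt-metal multi-device test environment, e.g. by passing -k prefill to pytest for models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py or test_llama_mlp_galaxy.py.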
(Diffs for the remaining changed files are not shown.)
