
Commit

sanitize_shape refactoring (#14565)
* #11512: sanitize_shape refactoring

* #11512: remove sanitize_topk_shape

* #11512: Remove print statements from the sanitize_shape functions

* #11512: Modify sanitize_shape_rm function

* #11512: Reduce argmax sweep number of parameters

* #11512: Modify invalidate_vector inside the split_query_key_value_and_split_heads.py sweep

* #11512: Modify sanitize_shape function inside utils when method is split_query_key_value_and_split_heads
amalbasaTT authored Nov 4, 2024
1 parent ea9c15d commit 6512ac5
Showing 5 changed files with 117 additions and 116 deletions.
80 changes: 51 additions & 29 deletions tests/sweep_framework/sweep_utils/utils.py
@@ -11,14 +11,49 @@
import math
import itertools
from typing import Optional, List
import copy


def sanitize_shape_rm(input_shape):
if input_shape[-1] % 2 != 0:
input_shape[-1] = input_shape[-1] + input_shape[-1] % 2
input_shape[-1] = input_shape[-1] + 1
return input_shape


def sanitize_shape(shape, method, **kwargs):
    if method == "topk":
        num_dims = len(shape)
        last_dim = shape[num_dims - 1]
        if not (last_dim & (last_dim - 1) == 0) and last_dim != 0:
            last_dim = 2 ** math.ceil(math.log2(last_dim))
        if last_dim < 64:
            last_dim = 64
        shape[num_dims - 1] = last_dim

    if method == "split_query_key_value_and_split_heads":
        assert len(shape) == 3

        hidden_size = shape[2]
        num_heads = kwargs.pop("num_heads")
        num_kv_heads = kwargs.pop("num_kv_heads")

        if num_kv_heads is None:
            min_sum_heads_size = 32 * (3 * num_heads)
        else:
            min_sum_heads_size = 32 * (num_heads + 2 * num_kv_heads)
            hidden_size = 4672

        if hidden_size % min_sum_heads_size != 0:
            if hidden_size < min_sum_heads_size:
                hidden_size = min_sum_heads_size
            else:
                hidden_size = (hidden_size // min_sum_heads_size) * min_sum_heads_size

        shape[2] = hidden_size

    return shape


def tensor_to_dtype(x, dtype):
if x.dtype == torch.bool:
x = x.to(torch.bfloat16)
@@ -144,49 +179,36 @@ def gen_with_zeroes(size, probabilityzeroes=0.5, low=-100, high=100, dtype=torch
return mask


# at the moment, topk only works on last dim
# last dim must be a multiple of 64 and a pow of 2
def santize_topk_shape(input_shape):
num_dims = len(input_shape)
last_dim = input_shape[num_dims - 1]
if not (last_dim & (last_dim - 1) == 0) and last_dim != 0:
last_dim = 2 ** math.ceil(math.log2(last_dim))
last_dim = last_dim + last_dim % 64

input_shape[num_dims - 1] = last_dim

return input_shape


def gen_rand_integers(low, high, num_samples):
for i in range(num_samples):
yield random.randint(low, high)


def gen_split_qkv_heads_spec(
batch_size_list: List[int],
sequence_size_list: List[int],
num_heads_list: List[int],
input_shape_list: List[int],
transpose_key_list: List[bool],
num_heads_list: List[int],
num_kv_heads_list: List[int] = [None],
kv_input_tensor_list: List[bool] = [False],
use_invalid_hidden_size=False,
):
for batch_size, sequence_size, num_heads, num_kv_heads, kv_input_tensor, transpose_key in itertools.product(
batch_size_list, sequence_size_list, num_heads_list, num_kv_heads_list, kv_input_tensor_list, transpose_key_list
for input_shape, num_heads, num_kv_heads, kv_input_tensor, transpose_key in itertools.product(
input_shape_list, num_heads_list, num_kv_heads_list, kv_input_tensor_list, transpose_key_list
):
if num_kv_heads is None:
num_kv_heads = num_heads
input_shape_ = input_shape.copy()

if use_invalid_hidden_size is False:
head_size = 32 * random.randint(1, 3)
hidden_size = head_size * (num_heads + num_kv_heads * 2)
else:
hidden_size = random.randint(num_heads + num_kv_heads * 2, 2048)
input_shape_ = sanitize_shape(
input_shape_,
"split_query_key_value_and_split_heads",
num_heads=num_heads,
num_kv_heads=num_kv_heads,
)

yield {
"batch_size": batch_size,
"sequence_size": sequence_size,
"hidden_size": hidden_size,
"batch_size": input_shape_[0],
"sequence_size": input_shape_[1],
"hidden_size": input_shape_[2],
"num_heads": num_heads,
"num_kv_heads": num_kv_heads,
"kv_input_tensor": kv_input_tensor,
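For reference, a minimal sketch of how the refactored sanitize_shape behaves for the split_query_key_value_and_split_heads method, assuming the function shown above; the shapes and head counts here are illustrative and not taken from any sweep config.

# Sketch only: illustrative shapes and head counts, assuming sanitize_shape as defined above.
from tests.sweep_framework.sweep_utils.utils import sanitize_shape

# With num_kv_heads=None the hidden dim is snapped down to a multiple of
# 32 * (3 * num_heads); here 32 * 12 = 384, so 1000 -> 768.
print(sanitize_shape([2, 384, 1000], "split_query_key_value_and_split_heads",
                     num_heads=4, num_kv_heads=None))   # [2, 384, 768]

# With num_kv_heads given, the hidden dim is forced to 4672, matching the
# 71-head / 1-kv-head configuration that the sweep's invalidate_vector accepts.
print(sanitize_shape([2, 384, 1000], "split_query_key_value_and_split_heads",
                     num_heads=71, num_kv_heads=1))      # [2, 384, 4672]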
6 changes: 3 additions & 3 deletions tests/sweep_framework/sweeps/reduction/argmax/argmax.py
@@ -26,9 +26,9 @@
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
"nightly": {
"input_shape": gen_shapes([1, 1, 1, 1], [2, 6, 128, 128], [1, 1, 1, 1], 128)
+ gen_shapes([1, 1, 1], [6, 128, 128], [1, 1, 1], 128)
+ gen_shapes([1, 1], [128, 128], [1, 1], 128),
"input_shape": gen_shapes([1, 1, 1, 1], [2, 6, 128, 128], [1, 1, 1, 1], 32)
+ gen_shapes([1, 1, 1], [6, 128, 128], [1, 1, 1], 32)
+ gen_shapes([1, 1], [128, 128], [1, 1], 32),
"dim": [0, 1, 2, 3, None],
"input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT],
4 changes: 2 additions & 2 deletions tests/sweep_framework/sweeps/reduction/topk/topk.py
@@ -8,7 +8,7 @@
import torch
import random
import ttnn
from tests.sweep_framework.sweep_utils.utils import gen_shapes, santize_topk_shape
from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt

from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_topk_simmilarity
@@ -89,7 +89,7 @@ def run(
data_seed = random.randint(0, 20000000)
torch.manual_seed(data_seed)

input_shape = santize_topk_shape(input_shape)
input_shape = sanitize_shape(input_shape, "topk")

torch_input_tensor_a = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
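A rough usage sketch of the new call site (illustrative shape, assuming the sanitize_shape shown in utils.py): the topk sweep now rounds the last dimension up to a power of two of at least 64 before generating tensors.

# Sketch only: illustrative shape, not one produced by gen_shapes.
from tests.sweep_framework.sweep_utils.utils import sanitize_shape

input_shape = [6, 96, 200]
input_shape = sanitize_shape(input_shape, "topk")
print(input_shape)  # [6, 96, 256] -- last dim rounded up to a power of two >= 64

The next changed file is the split_query_key_value_and_split_heads sweep, whose input specs come from gen_split_qkv_heads_spec and therefore pass through the same helper.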
@@ -33,11 +33,11 @@
"nightly": {
"input_spec": list(
gen_split_qkv_heads_spec(
batch_size_list=list(gen_rand_integers(1, 10, 2)),
sequence_size_list=list(gen_rand_integers(1, 512, 4)),
input_shape_list=gen_shapes([1, 1, 1], [6, 512, 2048], [1, 1, 1], 16),
num_heads_list=list(gen_rand_integers(1, 20, 6)),
transpose_key_list=[True, False],
num_kv_heads_list=[None, 1],
use_invalid_hidden_size=False,
)
),
"input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
@@ -49,30 +49,30 @@


def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
if (
test_vector["input_spec"]["hidden_size"]
% (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
!= 0
):
return True, "Hidden side must be divisible by the total number of heads"
if (
test_vector["input_spec"]["hidden_size"]
// (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
) % 32 != 0:
return True, "Head size must be a multiple of 32"
if test_vector["input_spec"]["num_kv_heads"] != test_vector["input_spec"]["num_heads"] and (
test_vector["input_spec"]["hidden_size"] != 4672
or test_vector["input_spec"]["num_kv_heads"] != 1
or test_vector["input_spec"]["num_heads"] != 71
):
return (
True,
"When using num_kv_heads, hidden_size must be 4672, num_kv_heads must be 1 and num_heads must be 71",
)
if (test_vector["input_spec"]["num_kv_heads"] != test_vector["input_spec"]["num_heads"]) and test_vector[
"input_spec"
]["transpose_key"] is True:
return True, "Can't transpose key when using separate kv heads"
if test_vector["input_spec"]["num_kv_heads"] is not None:
if (
test_vector["input_spec"]["hidden_size"]
% (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
!= 0
):
return True, "Hidden size must be divisible by the total number of heads"
if (
test_vector["input_spec"]["hidden_size"]
// (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
) % 32 != 0:
return True, "Head size must be a multiple of 32"
if test_vector["input_spec"]["hidden_size"] != 4672:
return True, "When using num_kv_heads, hidden_size must be 4672"
if not (test_vector["input_spec"]["num_kv_heads"] == 1 and test_vector["input_spec"]["num_heads"] == 71):
return True, "When using num_kv_heads, num_kv_heads must be 1, and num_heads must be 71"
if test_vector["input_spec"]["transpose_key"] is True:
return True, "Can't transpose key when using separate kv heads"
else:
if test_vector["input_spec"]["hidden_size"] % (test_vector["input_spec"]["num_heads"] * 3) != 0:
return True, "Hidden size must be divisible by the total number of heads"
if (test_vector["input_spec"]["hidden_size"] // (test_vector["input_spec"]["num_heads"] * 3)) % 32 != 0:
return True, "Head size must be a multiple of 32"

if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT:
return True, "Inputs to eltwise binary must be tilized"
if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT and test_vector["input_a_dtype"] == ttnn.bfloat8_b:
@@ -99,35 +99,19 @@ def run(
batch_size, sequence_size, hidden_size, num_heads, num_kv_heads, _, transpose_key = input_spec.values()
input_shape = (batch_size, sequence_size, hidden_size)

if num_kv_heads == num_heads:
num_kv_heads = num_heads

head_size = hidden_size // (2 * num_kv_heads + num_heads)
if num_kv_heads is not None:
head_size = hidden_size // (2 * num_kv_heads + num_heads)
else:
head_size = hidden_size // (3 * num_heads)

torch_input_tensor = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_shape)
intermediate_result = torch.reshape(
torch_input_tensor, (batch_size, sequence_size, num_heads + (2 * num_kv_heads), head_size)
)
query, key, value = (
intermediate_result[..., :num_heads, :],
intermediate_result[..., num_heads : num_heads + num_kv_heads, :],
intermediate_result[..., num_heads + num_kv_heads :, :],
)

query = torch.reshape(query, (batch_size, sequence_size, num_heads, head_size))
key = torch.reshape(key, (batch_size, sequence_size, num_kv_heads, head_size))
value = torch.reshape(value, (batch_size, sequence_size, num_kv_heads, head_size))

query = torch.permute(query, (0, 2, 1, 3)).contiguous().clone()
key = torch.permute(key, (0, 2, 1, 3)).contiguous().clone()
value = torch.permute(value, (0, 2, 1, 3)).contiguous().clone()

if transpose_key:
key = torch.permute(key, (0, 1, 3, 2)).contiguous().clone()

torch_output_tensors = (query, key, value)
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
torch_output_tensors = golden_function(
torch_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads, transpose_key=transpose_key
)

input_tensor_a = ttnn.from_torch(
torch_input_tensor,
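A small consistency check, sketched under the assumption that sanitize_shape behaves as shown in utils.py and with an illustrative head count: the hidden sizes it produces for this method satisfy the divisibility and head-size rules that the rewritten invalidate_vector enforces.

# Sketch only: num_heads is illustrative; sanitize_shape is assumed as defined in utils.py.
from tests.sweep_framework.sweep_utils.utils import sanitize_shape

num_heads = 4
shape = sanitize_shape([1, 64, 1000], "split_query_key_value_and_split_heads",
                       num_heads=num_heads, num_kv_heads=None)
hidden_size = shape[2]                                 # 768 = 32 * (3 * 4)
assert hidden_size % (num_heads * 3) == 0              # divisible by the total number of heads
assert (hidden_size // (num_heads * 3)) % 32 == 0      # head size is a multiple of 32

The last changed file, the kv-input variant of this sweep, reuses the same spec generator with kv_input_tensor_list=[True].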
@@ -13,9 +13,8 @@
sanitize_shape_rm,
gen_rand_integers,
gen_split_qkv_heads_spec,
tensor_to_dtype,
)

from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

@@ -33,12 +32,12 @@
"nightly": {
"input_spec": list(
gen_split_qkv_heads_spec(
batch_size_list=list(gen_rand_integers(1, 10, 2)),
sequence_size_list=list(gen_rand_integers(1, 512, 4)),
input_shape_list=gen_shapes([1, 1, 1], [6, 512, 2048], [1, 1, 1], 8),
num_heads_list=list(gen_rand_integers(1, 20, 4)),
transpose_key_list=[True, False],
num_kv_heads_list=[None, 1],
kv_input_tensor_list=[True],
use_invalid_hidden_size=False,
)
),
"input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
@@ -54,7 +53,7 @@
def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
if test_vector["input_a_dtype"] != test_vector["input_b_dtype"]:
return True, "KV tensor dtype must be same as Q tensor dtype"
if test_vector["input_spec"]["num_kv_heads"] != test_vector["input_spec"]["num_heads"]:
if test_vector["input_spec"]["num_kv_heads"] is not None:
return True, "Can't use num_kv_heads when using separate kv tensor"
if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT:
return True, "Inputs to eltwise binary must be tilized"
@@ -82,48 +81,44 @@ def run(
torch.manual_seed(data_seed)

batch_size, sequence_size, hidden_size, num_heads, num_kv_heads, _, transpose_key = input_spec.values()
input_shape = (batch_size, sequence_size, hidden_size)

if num_kv_heads == num_heads:
num_kv_heads = num_heads

head_size = hidden_size // (2 * num_kv_heads + num_heads)

torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.float32)
intermediate_result = torch.reshape(
torch_input_tensor, (batch_size, sequence_size, num_heads + (2 * num_kv_heads), head_size)
)
query, key, value = (
intermediate_result[..., :num_heads, :],
intermediate_result[..., num_heads : num_heads + num_kv_heads, :],
intermediate_result[..., num_heads + num_kv_heads :, :],
)
if num_kv_heads is not None:
head_size = hidden_size // (2 * num_kv_heads + num_heads)
else:
head_size = hidden_size // (3 * num_heads)

query = tensor_to_dtype(torch.reshape(query, (batch_size, sequence_size, num_heads, head_size)), input_a_dtype)
key = tensor_to_dtype(torch.reshape(key, (batch_size, sequence_size, num_kv_heads, head_size)), input_b_dtype)
value = tensor_to_dtype(torch.reshape(value, (batch_size, sequence_size, num_kv_heads, head_size)), input_b_dtype)
q_hidden_size = head_size * num_heads

query = torch.permute(query, (0, 2, 1, 3)).contiguous().clone()
key = torch.permute(key, (0, 2, 1, 3)).contiguous().clone()
value = torch.permute(value, (0, 2, 1, 3)).contiguous().clone()
q_input_shape = (batch_size, sequence_size, q_hidden_size)
kv_input_shape = (batch_size, sequence_size, hidden_size - q_hidden_size)

if transpose_key:
key = torch.permute(key, (0, 1, 3, 2)).contiguous().clone()
torch_q_input_tensor = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(q_input_shape)

torch_output_tensors = (query, key, value)
torch_kv_input_tensor = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype
)(kv_input_shape)

q_hiden_size = head_size * num_heads
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
torch_output_tensors = golden_function(
torch_q_input_tensor,
torch_kv_input_tensor,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
transpose_key=transpose_key,
)

input_tensor_a = ttnn.from_torch(
torch_input_tensor[:, :, :q_hiden_size],
torch_q_input_tensor,
dtype=input_a_dtype,
layout=input_layout,
device=device,
memory_config=input_a_memory_config,
)

input_tensor_b = ttnn.from_torch(
torch_input_tensor[:, :, q_hiden_size:],
torch_kv_input_tensor,
dtype=input_b_dtype,
layout=input_layout,
device=device,
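For the kv-input variant, a short sketch of how the run function above now splits the sanitized hidden size between the Q input and the fused KV input; the numbers are illustrative, following the head_size arithmetic shown in the diff.

# Sketch only: illustrative values mirroring the arithmetic in run() above.
num_heads = 4
hidden_size = 768                              # e.g. a hidden size produced by sanitize_shape
head_size = hidden_size // (3 * num_heads)     # 64 (num_kv_heads is None on this path)
q_hidden_size = head_size * num_heads          # 256 -> Q input shape (batch, seq, 256)
kv_hidden_size = hidden_size - q_hidden_size   # 512 -> KV input shape (batch, seq, 512)
print(head_size, q_hidden_size, kv_hidden_size)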
