From 6512ac5cb635c132b444c4f9a4e35b43e6e484fc Mon Sep 17 00:00:00 2001
From: Andrija Malbasa
Date: Mon, 4 Nov 2024 16:05:50 +0100
Subject: [PATCH] sanitize_shape refactoring (#14565)

* #11512: sanitize_shape refactoring
* #11512: remove sanitize_topk_shape
* #11512: Remove print functions from sanitize_shapes
* #11512: Modify sanitize_shape_rm function
* #11512: Reduce argmax sweep number of parameters
* #11512: Modify invalidate_vector function inside split_query_key_value_and_split_heads.py sweep
* #11512: Modify sanitize_shape function inside utils when method is split_query_key_value_and_split_heads
---
 tests/sweep_framework/sweep_utils/utils.py    | 80 +++++++++++-------
 .../sweeps/reduction/argmax/argmax.py         |  6 +-
 .../sweeps/reduction/topk/topk.py             |  4 +-
 .../split_query_key_value_and_split_heads.py  | 84 ++++++++-----------
 ...uery_key_value_and_split_heads_kv_input.py | 59 ++++++-------
 5 files changed, 117 insertions(+), 116 deletions(-)

diff --git a/tests/sweep_framework/sweep_utils/utils.py b/tests/sweep_framework/sweep_utils/utils.py
index c155ef20ca3..ba6ec3bd07b 100644
--- a/tests/sweep_framework/sweep_utils/utils.py
+++ b/tests/sweep_framework/sweep_utils/utils.py
@@ -11,14 +11,49 @@
 import math
 import itertools
 from typing import Optional, List
+import copy


 def sanitize_shape_rm(input_shape):
     if input_shape[-1] % 2 != 0:
-        input_shape[-1] = input_shape[-1] + input_shape[-1] % 2
+        input_shape[-1] = input_shape[-1] + 1
     return input_shape


+def sanitize_shape(shape, method, **kwargs):
+    if method == "topk":
+        num_dims = len(shape)
+        last_dim = shape[num_dims - 1]
+        if not (last_dim & (last_dim - 1) == 0) and last_dim != 0:
+            last_dim = 2 ** math.ceil(math.log2(last_dim))
+        if last_dim < 64:
+            last_dim = 64
+        shape[num_dims - 1] = last_dim
+
+    if method == "split_query_key_value_and_split_heads":
+        assert len(shape) == 3
+
+        hidden_size = shape[2]
+        num_heads = kwargs.pop("num_heads")
+        num_kv_heads = kwargs.pop("num_kv_heads")
+
+        if num_kv_heads is None:
+            min_sum_heads_size = 32 * (3 * num_heads)
+        else:
+            min_sum_heads_size = 32 * (num_heads + 2 * num_kv_heads)
+            hidden_size = 4672
+
+        if hidden_size % min_sum_heads_size != 0:
+            if hidden_size < min_sum_heads_size:
+                hidden_size = min_sum_heads_size
+            else:
+                hidden_size = (hidden_size // min_sum_heads_size) * min_sum_heads_size
+
+        shape[2] = hidden_size
+
+    return shape
+
+
 def tensor_to_dtype(x, dtype):
     if x.dtype == torch.bool:
         x = x.to(torch.bfloat16)
@@ -144,49 +179,36 @@ def gen_with_zeroes(size, probabilityzeroes=0.5, low=-100, high=100, dtype=torch
     return mask


-# at the moment, topk only works on last dim
-# last dim must be a multiple of 64 and a pow of 2
-def santize_topk_shape(input_shape):
-    num_dims = len(input_shape)
-    last_dim = input_shape[num_dims - 1]
-    if not (last_dim & (last_dim - 1) == 0) and last_dim != 0:
-        last_dim = 2 ** math.ceil(math.log2(last_dim))
-        last_dim = last_dim + last_dim % 64
-
-    input_shape[num_dims - 1] = last_dim
-
-    return input_shape
-
-
 def gen_rand_integers(low, high, num_samples):
     for i in range(num_samples):
         yield random.randint(low, high)


 def gen_split_qkv_heads_spec(
-    batch_size_list: List[int],
-    sequence_size_list: List[int],
-    num_heads_list: List[int],
+    input_shape_list: List[int],
     transpose_key_list: List[bool],
+    num_heads_list: List[int],
     num_kv_heads_list: List[int] = [None],
     kv_input_tensor_list: List[bool] = [False],
     use_invalid_hidden_size=False,
 ):
-    for batch_size, sequence_size, num_heads, num_kv_heads, kv_input_tensor, transpose_key in itertools.product(
-        batch_size_list, sequence_size_list, num_heads_list, num_kv_heads_list, kv_input_tensor_list, transpose_key_list
+    for input_shape, num_heads, num_kv_heads, kv_input_tensor, transpose_key in itertools.product(
+        input_shape_list, num_heads_list, num_kv_heads_list, kv_input_tensor_list, transpose_key_list
     ):
-        if num_kv_heads is None:
-            num_kv_heads = num_heads
+        input_shape_ = input_shape.copy()
+
         if use_invalid_hidden_size is False:
-            head_size = 32 * random.randint(1, 3)
-            hidden_size = head_size * (num_heads + num_kv_heads * 2)
-        else:
-            hidden_size = random.randint(num_heads + num_kv_heads * 2, 2048)
+            input_shape_ = sanitize_shape(
+                input_shape_,
+                "split_query_key_value_and_split_heads",
+                num_heads=num_heads,
+                num_kv_heads=num_kv_heads,
+            )

         yield {
-            "batch_size": batch_size,
-            "sequence_size": sequence_size,
-            "hidden_size": hidden_size,
+            "batch_size": input_shape_[0],
+            "sequence_size": input_shape_[1],
+            "hidden_size": input_shape_[2],
             "num_heads": num_heads,
             "num_kv_heads": num_kv_heads,
             "kv_input_tensor": kv_input_tensor,
diff --git a/tests/sweep_framework/sweeps/reduction/argmax/argmax.py b/tests/sweep_framework/sweeps/reduction/argmax/argmax.py
index aa91f3bf26f..913d4178865 100644
--- a/tests/sweep_framework/sweeps/reduction/argmax/argmax.py
+++ b/tests/sweep_framework/sweeps/reduction/argmax/argmax.py
@@ -26,9 +26,9 @@
 # Developers can create their own generator functions and pass them to the parameters as inputs.
 parameters = {
     "nightly": {
-        "input_shape": gen_shapes([1, 1, 1, 1], [2, 6, 128, 128], [1, 1, 1, 1], 128)
-        + gen_shapes([1, 1, 1], [6, 128, 128], [1, 1, 1], 128)
-        + gen_shapes([1, 1], [128, 128], [1, 1], 128),
+        "input_shape": gen_shapes([1, 1, 1, 1], [2, 6, 128, 128], [1, 1, 1, 1], 32)
+        + gen_shapes([1, 1, 1], [6, 128, 128], [1, 1, 1], 32)
+        + gen_shapes([1, 1], [128, 128], [1, 1], 32),
         "dim": [0, 1, 2, 3, None],
         "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
         "input_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT],
diff --git a/tests/sweep_framework/sweeps/reduction/topk/topk.py b/tests/sweep_framework/sweeps/reduction/topk/topk.py
index 087973b1b08..e38a9134e5c 100644
--- a/tests/sweep_framework/sweeps/reduction/topk/topk.py
+++ b/tests/sweep_framework/sweeps/reduction/topk/topk.py
@@ -8,7 +8,7 @@
 import torch
 import random
 import ttnn
-from tests.sweep_framework.sweep_utils.utils import gen_shapes, santize_topk_shape
+from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape
 from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_topk_simmilarity

@@ -89,7 +89,7 @@ def run(
     data_seed = random.randint(0, 20000000)
     torch.manual_seed(data_seed)

-    input_shape = santize_topk_shape(input_shape)
+    input_shape = sanitize_shape(input_shape, "topk")

     torch_input_tensor_a = gen_func_with_cast_tt(
         partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
diff --git a/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.py b/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.py
index 3085854c5e4..90edf93e7fa 100644
--- a/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.py
+++ b/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.py
@@ -33,11 +33,11 @@
     "nightly": {
         "input_spec": list(
             gen_split_qkv_heads_spec(
-                batch_size_list=list(gen_rand_integers(1, 10, 2)),
-                sequence_size_list=list(gen_rand_integers(1, 512, 4)),
+                input_shape_list=gen_shapes([1, 1, 1], [6, 512, 2048], [1, 1, 1], 16),
                 num_heads_list=list(gen_rand_integers(1, 20, 6)),
                 transpose_key_list=[True, False],
                 num_kv_heads_list=[None, 1],
+                use_invalid_hidden_size=False,
             )
         ),
         "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
@@ -49,30 +49,30 @@


 def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
-    if (
-        test_vector["input_spec"]["hidden_size"]
-        % (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
-        != 0
-    ):
-        return True, "Hidden side must be divisible by the total number of heads"
-    if (
-        test_vector["input_spec"]["hidden_size"]
-        // (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
-    ) % 32 != 0:
-        return True, "Head size must be a multiple of 32"
-    if test_vector["input_spec"]["num_kv_heads"] != test_vector["input_spec"]["num_heads"] and (
-        test_vector["input_spec"]["hidden_size"] != 4672
-        or test_vector["input_spec"]["num_kv_heads"] != 1
-        or test_vector["input_spec"]["num_heads"] != 71
-    ):
-        return (
-            True,
-            "When using num_kv_heads, hidden_size must be 4672, num_kv_heads must be 1 and num_heads must be 71",
-        )
-    if (test_vector["input_spec"]["num_kv_heads"] != test_vector["input_spec"]["num_heads"]) and test_vector[
-        "input_spec"
-    ]["transpose_key"] is True:
-        return True, "Can't transpose key when using separate kv heads"
+    if test_vector["input_spec"]["num_kv_heads"] is not None:
+        if (
+            test_vector["input_spec"]["hidden_size"]
+            % (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
+            != 0
+        ):
+            return True, "Hidden size must be divisible by the total number of heads"
+        if (
+            test_vector["input_spec"]["hidden_size"]
+            // (test_vector["input_spec"]["num_heads"] + test_vector["input_spec"]["num_kv_heads"] * 2)
+        ) % 32 != 0:
+            return True, "Head size must be a multiple of 32"
+        if test_vector["input_spec"]["hidden_size"] != 4672:
+            return True, "When using num_kv_heads, hidden_size must be 4672"
+        if not (test_vector["input_spec"]["num_kv_heads"] == 1 and test_vector["input_spec"]["num_heads"] == 71):
+            return True, "When using num_kv_heads, num_kv_heads must be 1, and num_heads must be 71"
+        if test_vector["input_spec"]["transpose_key"] is True:
+            return True, "Can't transpose key when using separate kv heads"
+    else:
+        if test_vector["input_spec"]["hidden_size"] % (test_vector["input_spec"]["num_heads"] * 3) != 0:
+            return True, "Hidden size must be divisible by the total number of heads"
+        if (test_vector["input_spec"]["hidden_size"] // (test_vector["input_spec"]["num_heads"] * 3)) % 32 != 0:
+            return True, "Head size must be a multiple of 32"
+
     if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT:
         return True, "Inputs to eltwise binary must be tilized"
     if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT and test_vector["input_a_dtype"] == ttnn.bfloat8_b:
@@ -99,35 +99,19 @@ def run(
     batch_size, sequence_size, hidden_size, num_heads, num_kv_heads, _, transpose_key = input_spec.values()
     input_shape = (batch_size, sequence_size, hidden_size)

-    if num_kv_heads == num_heads:
-        num_kv_heads = num_heads
-
-    head_size = hidden_size // (2 * num_kv_heads + num_heads)
+    if num_kv_heads is not None:
+        head_size = hidden_size // (2 * num_kv_heads + num_heads)
+    else:
+        head_size = hidden_size // (3 * num_heads)

     torch_input_tensor = gen_func_with_cast_tt(
         partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
     )(input_shape)

-    intermediate_result = torch.reshape(
-        torch_input_tensor, (batch_size, sequence_size, num_heads + (2 * num_kv_heads), head_size)
-    )
-    query, key, value = (
-        intermediate_result[..., :num_heads, :],
-        intermediate_result[..., num_heads : num_heads + num_kv_heads, :],
-        intermediate_result[..., num_heads + num_kv_heads :, :],
-    )
-
-    query = torch.reshape(query, (batch_size, sequence_size, num_heads, head_size))
-    key = torch.reshape(key, (batch_size, sequence_size, num_kv_heads, head_size))
-    value = torch.reshape(value, (batch_size, sequence_size, num_kv_heads, head_size))
-
-    query = torch.permute(query, (0, 2, 1, 3)).contiguous().clone()
-    key = torch.permute(key, (0, 2, 1, 3)).contiguous().clone()
-    value = torch.permute(value, (0, 2, 1, 3)).contiguous().clone()
-
-    if transpose_key:
-        key = torch.permute(key, (0, 1, 3, 2)).contiguous().clone()
-
-    torch_output_tensors = (query, key, value)
+    golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
+    torch_output_tensors = golden_function(
+        torch_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads, transpose_key=transpose_key
+    )

     input_tensor_a = ttnn.from_torch(
         torch_input_tensor,
diff --git a/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads_kv_input.py b/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads_kv_input.py
index f53992beae6..708965d66df 100644
--- a/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads_kv_input.py
+++ b/tests/sweep_framework/sweeps/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads_kv_input.py
@@ -13,9 +13,8 @@
     sanitize_shape_rm,
     gen_rand_integers,
     gen_split_qkv_heads_spec,
-    tensor_to_dtype,
 )
-
+from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
 from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
 from models.utility_functions import torch_random

@@ -33,12 +32,12 @@
     "nightly": {
         "input_spec": list(
             gen_split_qkv_heads_spec(
-                batch_size_list=list(gen_rand_integers(1, 10, 2)),
-                sequence_size_list=list(gen_rand_integers(1, 512, 4)),
+                input_shape_list=gen_shapes([1, 1, 1], [6, 512, 2048], [1, 1, 1], 8),
                 num_heads_list=list(gen_rand_integers(1, 20, 4)),
                 transpose_key_list=[True, False],
                 num_kv_heads_list=[None, 1],
                 kv_input_tensor_list=[True],
+                use_invalid_hidden_size=False,
             )
         ),
         "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
@@ -54,7 +53,7 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
     if test_vector["input_a_dtype"] != test_vector["input_b_dtype"]:
         return True, "KV tensor dtype must be same as Q tensor dtype"
-    if test_vector["input_spec"]["num_kv_heads"] != test_vector["input_spec"]["num_heads"]:
+    if test_vector["input_spec"]["num_kv_heads"] is not None:
         return True, "Can't use num_kv_heads when using separate kv tensor"
     if test_vector["input_layout"] == ttnn.ROW_MAJOR_LAYOUT:
         return True, "Inputs to eltwise binary must be tilized"
@@ -82,40 +81,36 @@ def run(
     torch.manual_seed(data_seed)

     batch_size, sequence_size, hidden_size, num_heads, num_kv_heads, _, transpose_key = input_spec.values()
-    input_shape = (batch_size, sequence_size, hidden_size)
-
-    if num_kv_heads == num_heads:
-        num_kv_heads = num_heads
-
-    head_size = hidden_size // (2 * num_kv_heads + num_heads)
-
-    torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.float32)
-    intermediate_result = torch.reshape(
-        torch_input_tensor, (batch_size, sequence_size, num_heads + (2 * num_kv_heads), head_size)
-    )
-    query, key, value = (
-        intermediate_result[..., :num_heads, :],
-        intermediate_result[..., num_heads : num_heads + num_kv_heads, :],
-        intermediate_result[..., num_heads + num_kv_heads :, :],
-    )
+    if num_kv_heads is not None:
+        head_size = hidden_size // (2 * num_kv_heads + num_heads)
+    else:
+        head_size = hidden_size // (3 * num_heads)

-    query = tensor_to_dtype(torch.reshape(query, (batch_size, sequence_size, num_heads, head_size)), input_a_dtype)
-    key = tensor_to_dtype(torch.reshape(key, (batch_size, sequence_size, num_kv_heads, head_size)), input_b_dtype)
-    value = tensor_to_dtype(torch.reshape(value, (batch_size, sequence_size, num_kv_heads, head_size)), input_b_dtype)
+    q_hidden_size = head_size * num_heads

-    query = torch.permute(query, (0, 2, 1, 3)).contiguous().clone()
-    key = torch.permute(key, (0, 2, 1, 3)).contiguous().clone()
-    value = torch.permute(value, (0, 2, 1, 3)).contiguous().clone()
+    q_input_shape = (batch_size, sequence_size, q_hidden_size)
+    kv_input_shape = (batch_size, sequence_size, hidden_size - q_hidden_size)

-    if transpose_key:
-        key = torch.permute(key, (0, 1, 3, 2)).contiguous().clone()
+    torch_q_input_tensor = gen_func_with_cast_tt(
+        partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
+    )(q_input_shape)

-    torch_output_tensors = (query, key, value)
+    torch_kv_input_tensor = gen_func_with_cast_tt(
+        partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype
+    )(kv_input_shape)

-    q_hiden_size = head_size * num_heads
+    golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
+    torch_output_tensors = golden_function(
+        torch_q_input_tensor,
+        torch_kv_input_tensor,
+        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
+        transpose_key=transpose_key,
+    )

     input_tensor_a = ttnn.from_torch(
-        torch_input_tensor[:, :, :q_hiden_size],
+        torch_q_input_tensor,
         dtype=input_a_dtype,
         layout=input_layout,
         device=device,
     )
@@ -123,7 +118,7 @@ def run(
     input_tensor_b = ttnn.from_torch(
-        torch_input_tensor[:, :, q_hiden_size:],
+        torch_kv_input_tensor,
         dtype=input_b_dtype,
         layout=input_layout,
         device=device,