From bb857dff134aece0026c9ea1fc5ed9436224f37f Mon Sep 17 00:00:00 2001 From: Mohamed Bahnas Date: Sat, 13 Apr 2024 02:39:03 +0000 Subject: [PATCH] #6712: Add ViT model code + tests - Update create_sharded_device_tensor to respect non-4D shapes - Disable all tests (#7527) except sharded ViT --- .../reference/torch_functional_vit.py | 292 ++++++++ .../functional_vit/tt/ttnn_functional_vit.py | 296 ++++++++ .../tt/ttnn_functional_vit_highres.py | 292 ++++++++ .../tt/ttnn_optimized_interleaved_vit.py | 446 ++++++++++++ .../tt/ttnn_optimized_sharded_vit.py | 621 ++++++++++++++++ .../tt/ttnn_optimized_sharded_vit_backup.py | 688 ++++++++++++++++++ .../tt/ttnn_optimized_vit_highres.py | 527 ++++++++++++++ .../vit/test_accuracy_ttnn_functional_vit.py | 151 ++++ ...est_accuracy_ttnn_optim_interleaved_vit.py | 185 +++++ .../test_accuracy_ttnn_optim_sharded_vit.py | 157 ++++ ...rformance_deviceOPs_ttnn_functional_vit.py | 175 +++++ ...ce_deviceOPs_ttnn_optim_interleaved_vit.py | 192 +++++ ...rmance_deviceOPs_ttnn_optim_sharded_vit.py | 319 ++++++++ .../vit/test_performance_highres.py | 288 ++++++++ .../vit/test_performance_highres_deviceOPs.py | 294 ++++++++ .../test_performance_ttnn_functional_vit.py | 210 ++++++ ..._performance_ttnn_optim_interleaved_vit.py | 253 +++++++ ...test_performance_ttnn_optim_sharded_vit.py | 264 +++++++ .../vit/test_torch_functional_vit.py | 270 +++++++ .../vit/test_torch_functional_vit_highres.py | 347 +++++++++ .../vit/test_ttnn_functional_vit.py | 360 +++++++++ .../vit/test_ttnn_functional_vit_highres.py | 399 ++++++++++ .../test_ttnn_optimized_interleaved_vit.py | 496 +++++++++++++ .../vit/test_ttnn_optimized_sharded_vit.py | 542 ++++++++++++++ .../vit/test_ttnn_optimized_vit_highres.py | 465 ++++++++++++ tt_eager/tensor/tensor.cpp | 10 +- .../eltwise_binary/eltwise_binary_op.cpp | 3 +- .../eltwise_binary/eltwise_binary_op.hpp | 7 +- 28 files changed, 8543 insertions(+), 6 deletions(-) create mode 100644 models/experimental/functional_vit/reference/torch_functional_vit.py create mode 100644 models/experimental/functional_vit/tt/ttnn_functional_vit.py create mode 100644 models/experimental/functional_vit/tt/ttnn_functional_vit_highres.py create mode 100644 models/experimental/functional_vit/tt/ttnn_optimized_interleaved_vit.py create mode 100644 models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py create mode 100644 models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py create mode 100644 models/experimental/functional_vit/tt/ttnn_optimized_vit_highres.py create mode 100644 tests/ttnn/integration_tests/vit/test_accuracy_ttnn_functional_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_interleaved_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_sharded_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_functional_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_interleaved_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_sharded_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_performance_highres.py create mode 100644 tests/ttnn/integration_tests/vit/test_performance_highres_deviceOPs.py create mode 100644 tests/ttnn/integration_tests/vit/test_performance_ttnn_functional_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_interleaved_vit.py create mode 100644 
tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_sharded_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_torch_functional_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_torch_functional_vit_highres.py create mode 100644 tests/ttnn/integration_tests/vit/test_ttnn_functional_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_ttnn_functional_vit_highres.py create mode 100644 tests/ttnn/integration_tests/vit/test_ttnn_optimized_interleaved_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py create mode 100644 tests/ttnn/integration_tests/vit/test_ttnn_optimized_vit_highres.py diff --git a/models/experimental/functional_vit/reference/torch_functional_vit.py b/models/experimental/functional_vit/reference/torch_functional_vit.py new file mode 100644 index 00000000000..7ae5b2f30a6 --- /dev/null +++ b/models/experimental/functional_vit/reference/torch_functional_vit.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import transformers + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def fold_torch(input_tensor, stride_h, stride_w): + N, H, W, C = input_tensor.shape + reshaped = input_tensor.reshape(N, H // stride_h, stride_h, W // stride_w, stride_w, C) + transposed = reshaped.permute(0, 1, 3, 2, 4, 5) + return transposed.reshape(N, H // stride_h, W // stride_w, C * stride_h * stride_w) + + +def vit_patch_embeddings( + pixel_values, + *, + parameters, +): + batch_size, img_c, img_h, img_w = pixel_values.shape + patch_size = 16 + patch_count = img_h // patch_size # 14 + patch_size_sq = int(patch_size * patch_size) # 256 + patch_size_sq_trpl = int(patch_size_sq * img_c) # 768 + patch_count_sq = int(patch_count * patch_count) # 196 + stride_h = patch_size + stride_w = 1 + + pixel_values = torch.permute(pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 4 - pixel_values.shape[3], 0, 0, 0, 0)) + pixel_values = pixel_values.reshape( + pixel_values.shape[0], + pixel_values.shape[1], + pixel_values.shape[2] // patch_size, + pixel_values.shape[3] * patch_size, + ) + + pixel_values = fold_torch(pixel_values, stride_h, stride_w) + pixel_values = pixel_values.reshape(1, 1, -1, pixel_values.shape[-1]) + + patch_embedding_output = pixel_values @ parameters.projection.weight + patch_embedding_output = patch_embedding_output + parameters.projection.bias + + patch_embedding_output = patch_embedding_output.reshape(batch_size, patch_count_sq, patch_size_sq_trpl) + + return patch_embedding_output + + +def vit_embeddings( + config, + pixel_values, + position_embeddings, + cls_tokens, + *, + parameters, +): + batch_size, img_c, img_h, img_w = pixel_values.shape + patch_size = 16 + patch_count = img_h // patch_size # 14 + + patch_embeddings = vit_patch_embeddings(pixel_values, parameters=parameters.patch_embeddings) + cls_tokens = cls_tokens.expand(batch_size, -1, -1) + patch_embeddings = torch.cat((cls_tokens, patch_embeddings), dim=1) + embedding_output = patch_embeddings + position_embeddings + + return embedding_output + + +def vit_layernorm_before( + config, + hidden_states, + *, + parameters, +): + batch_size, sequence_size, hidden_size = hidden_states.shape + attention_output = torch.nn.functional.layer_norm( + hidden_states, + (hidden_size,), + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + 
eps=config.layer_norm_eps, + ) + + return attention_output + + +def vit_layernorm_after( + config, + hidden_states, + *, + parameters, +): + batch_size, sequence_size, hidden_size = hidden_states.shape + layernorm_output = torch.nn.functional.layer_norm( + hidden_states, + (hidden_size,), + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + eps=config.layer_norm_eps, + ) + + return layernorm_output + + +def vit_attention( + config, + hidden_states, + attention_mask, + *, + parameters, +): + num_heads = config.num_attention_heads + batch_size, sequence_size, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query = hidden_states @ parameters.attention.query.weight + query = query + parameters.attention.query.bias + query = torch.reshape(query, (batch_size, sequence_size, num_heads, head_size)) + query = torch.permute(query, (0, 2, 1, 3)) + + key = hidden_states @ parameters.attention.key.weight + key = key + parameters.attention.key.bias + key = torch.reshape(key, (batch_size, sequence_size, num_heads, head_size)) + key = torch.permute(key, (0, 2, 3, 1)) + + value = hidden_states @ parameters.attention.value.weight + value = value + parameters.attention.value.bias + value = torch.reshape(value, (batch_size, sequence_size, num_heads, head_size)) + value = torch.permute(value, (0, 2, 1, 3)) + + attention_scores = query @ key + attention_scores = attention_scores / (head_size**0.5) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1) + + context_layer = attention_probs @ value + context_layer = torch.permute(context_layer, (0, 2, 1, 3)) + context_layer = torch.reshape(context_layer, (batch_size, sequence_size, hidden_size)) + + self_output = context_layer + self_output = self_output @ parameters.output.dense.weight + self_output = self_output + parameters.output.dense.bias + + return self_output + + +def vit_intermediate(hidden_states, *, parameters): + hidden_states = hidden_states @ parameters.dense.weight + hidden_states = hidden_states + parameters.dense.bias + hidden_states = torch.nn.functional.gelu(hidden_states) + return hidden_states + + +def vit_output(config, hidden_states, residual, *, parameters): + output = hidden_states @ parameters.dense.weight + output = output + parameters.dense.bias + output = output + residual + + return output + + +def vit_feedforward( + config, + hidden_states, + attention_output, + *, + parameters, +): + intermediate = vit_intermediate(hidden_states, parameters=parameters.intermediate) + hidden_states = vit_output(config, intermediate, attention_output, parameters=parameters.output) + + return hidden_states + + +def vit_layer( + config, + hidden_states, + attention_mask, + *, + parameters, +): + layernorm_before_output = vit_layernorm_before( + config, + hidden_states, + parameters=parameters, + ) + attention_output = vit_attention( + config, + layernorm_before_output, + attention_mask, + parameters=parameters.attention, + ) + attention_output = attention_output + hidden_states + + layernorm_after_output = vit_layernorm_after( + config, + attention_output, + parameters=parameters, + ) + + feedforward_output = vit_feedforward( + config, + layernorm_after_output, + attention_output, + parameters=parameters, + ) + + return feedforward_output + + +def vit_encoder( + config, + hidden_states, + attention_mask, + *, + parameters, +): + encoder_input = hidden_states + encoder_output = None + 
for encoder_parameters in parameters.layer: + encoder_output = vit_layer( + config, + encoder_input, + attention_mask, + parameters=encoder_parameters, + ) + encoder_input = encoder_output + return encoder_output + + +def vit( + config, + pixel_values, + position_embeddings, + cls_tokens, + attention_mask, + *, + parameters, +): + hidden_states = vit_embeddings( + config, pixel_values, position_embeddings, cls_tokens, parameters=parameters.vit.embeddings + ) + + hidden_states = vit_encoder( + config, + hidden_states, + attention_mask, + parameters=parameters.vit.encoder, + ) + + # Final LayerNorm + output = torch.nn.functional.layer_norm( + hidden_states, + (config.hidden_size,), + parameters.vit.layernorm.weight, + parameters.vit.layernorm.bias, + eps=config.layer_norm_eps, + ) + + # Pooler + pooler_output = output[0] @ parameters.classifier.weight + pooler_output = pooler_output + parameters.classifier.bias + # pooler_output = torch.tanh(pooler_output) + + return pooler_output + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.vit.modeling_vit.ViTPatchEmbeddings): + weight = torch_model.projection.weight + bias = torch_model.projection.bias + + three_times_hidden_size, c, _, _ = weight.shape + + pad_value = 4 - c + preprocessed_weight = torch.nn.functional.pad(weight, (0, 0, 0, 0, 0, pad_value)) + preprocessed_weight = torch.permute(preprocessed_weight, (2, 3, 1, 0)) + preprocessed_weight = torch.reshape( + preprocessed_weight, (int(three_times_hidden_size * (4 / c)), three_times_hidden_size) + ) + + parameters = {"projection": {}} + parameters["projection"]["weight"] = preprocessed_weight + parameters["projection"]["bias"] = bias + + return parameters diff --git a/models/experimental/functional_vit/tt/ttnn_functional_vit.py b/models/experimental/functional_vit/tt/ttnn_functional_vit.py new file mode 100644 index 00000000000..197e8c803fa --- /dev/null +++ b/models/experimental/functional_vit/tt/ttnn_functional_vit.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import transformers +import torch + +import ttnn +import tt_lib as ttl +from ttnn.model_preprocessing import ( + preprocess_linear_weight, + preprocess_linear_bias, +) + + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def vit_patch_embeddings(config, pixel_values, *, parameters, unittest_check=False): + # batch_size, img_c, img_h, img_w = pixel_values.shape # NCHW + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + patch_count = img_h // patch_size # 14 + patch_size_sq_trpl = int(patch_size * patch_size * 3) # 768 + patch_count_all = int(patch_count * patch_count) # 196 + stride_h = patch_size + stride_w = 1 + + pixel_values = ttnn.reshape(pixel_values, (batch_size, img_h, img_w // patch_size, 4 * patch_size)) + pixel_values = ttl.tensor.fold(pixel_values, stride_h, stride_w) + pixel_values = ttnn.to_layout(pixel_values, layout=ttnn.TILE_LAYOUT) + + if unittest_check: + parameters = parameters.vit.embeddings.patch_embeddings + + patch_embedding_output = pixel_values @ parameters.projection.weight + patch_embedding_output = patch_embedding_output + parameters.projection.bias + + patch_embedding_output = ttnn.to_layout(patch_embedding_output, layout=ttnn.ROW_MAJOR_LAYOUT) + patch_embedding_output = ttnn.reshape(patch_embedding_output, (batch_size, patch_count_all, patch_size_sq_trpl)) + + return patch_embedding_output + + +def vit_embeddings( + config, + pixel_values, + cls_token, + position_embeddings, + *, + parameters, +): + parameters = parameters.vit.embeddings + + patch_embeddings = vit_patch_embeddings(config, pixel_values, parameters=parameters.patch_embeddings) + embedding_output = ttnn.concat((cls_token, patch_embeddings), dim=1) + embedding_output = embedding_output + position_embeddings + + return embedding_output + + +def vit_layernorm_before( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + ) + + return attention_output + + +def vit_layernorm_after( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + ) + + return attention_output + + +def vit_attention( + config, + hidden_states, + attention_mask, + *, + parameters, +): + num_heads = config.num_attention_heads + batch_size, sequence_size, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query = hidden_states @ parameters.attention.query.weight + query = query + parameters.attention.query.bias + query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT) + query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size)) + query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT) + query = ttnn.permute(query, (0, 2, 1, 3)) + + key = hidden_states @ parameters.attention.key.weight + key = key + parameters.attention.key.bias + key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT) + key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size)) + key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT) + key = ttnn.permute(key, (0, 2, 3, 1)) + + value = hidden_states @ parameters.attention.value.weight + value = value + parameters.attention.value.bias + value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT) + value = ttnn.reshape(value, (batch_size, 
sequence_size, num_heads, head_size)) + value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT) + value = ttnn.permute(value, (0, 2, 1, 3)) + + attention_scores = query @ key + attention_scores = attention_scores * (1 / (head_size**0.5)) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = ttnn.softmax(attention_scores, dim=-1) + + context_layer = attention_probs @ value + context_layer = ttnn.permute(context_layer, (0, 2, 1, 3)) + context_layer = ttnn.to_layout(context_layer, ttnn.ROW_MAJOR_LAYOUT) + context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size)) + context_layer = ttnn.to_layout(context_layer, ttnn.TILE_LAYOUT) + + self_output = context_layer + self_output = self_output @ parameters.output.dense.weight + self_output = self_output + parameters.output.dense.bias + + return self_output + + +def vit_intermediate( + hidden_states, + *, + parameters, +): + output = hidden_states @ parameters.dense.weight + output = output + parameters.dense.bias + output = ttnn.gelu(output) + return output + + +def vit_output( + config, + hidden_states, + residual, + *, + parameters, +): + output = hidden_states @ parameters.dense.weight + output = output + parameters.dense.bias + output = output + residual + + return output + + +def vit_feedforward( + config, + hidden_states, + attention_output, + *, + parameters, +): + intermediate = vit_intermediate(hidden_states, parameters=parameters.intermediate) + hidden_states = vit_output(config, intermediate, attention_output, parameters=parameters.output) + return hidden_states + + +def vit_layer( + config, + hidden_states, + attention_mask, + *, + parameters, +): + print("hhhh", hidden_states.shape) + layernorm_before_output = vit_layernorm_before( + config, + hidden_states, + parameters=parameters, + ) + attention_output = vit_attention( + config, + layernorm_before_output, + attention_mask, + parameters=parameters.attention, + ) + attention_output = attention_output + hidden_states + layernorm_after_output = vit_layernorm_after( + config, + attention_output, + parameters=parameters, + ) + feedforward_output = vit_feedforward( + config, + layernorm_after_output, + attention_output, + parameters=parameters, + ) + + return feedforward_output + + +def vit_encoder( + config, + hidden_states, + attention_mask, + *, + parameters, +): + encoder_input = hidden_states + encoder_output = None + for encoder_parameters in parameters.layer: + encoder_output = vit_layer( + config, + encoder_input, + attention_mask, + parameters=encoder_parameters, + ) + encoder_input = encoder_output + return encoder_output + + +def vit( + config, + pixel_values, + attention_mask, + cls_token, + position_embeddings, + *, + parameters, +): + embeddings_output = vit_embeddings(config, pixel_values, cls_token, position_embeddings, parameters=parameters) + + hidden_states = vit_encoder( + config, + embeddings_output, + attention_mask=None, + parameters=parameters.vit.encoder, + ) + + # Final LayerNorm + output = ttnn.layer_norm( + hidden_states, + weight=parameters.vit.layernorm.weight, + bias=parameters.vit.layernorm.bias, + ) + + # Classifier + classifier_output = output @ parameters.classifier.weight + classifier_output = classifier_output + parameters.classifier.bias + + return classifier_output + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.vit.modeling_vit.ViTEmbeddings): + weight = torch_model.patch_embeddings.projection.weight + 
bias = torch_model.patch_embeddings.projection.bias + + three_times_hidden_size, c, _, _ = weight.shape + pad_value = 4 - c + preprocessed_weight = torch.nn.functional.pad(weight, (0, 0, 0, 0, 0, pad_value)) + preprocessed_weight = torch.permute(preprocessed_weight, (2, 3, 1, 0)) + preprocessed_weight = torch.reshape( + preprocessed_weight, (int(three_times_hidden_size * (4 / c)), three_times_hidden_size) + ) + + parameters = {"patch_embeddings": {}} + parameters["patch_embeddings"] = {"projection": {}} + parameters["patch_embeddings"]["projection"]["weight"] = ttnn.from_torch( + preprocessed_weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT + ) + parameters["patch_embeddings"]["projection"]["bias"] = ttnn.from_torch( + bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT + ) + + parameters["cls_token"] = ttnn.from_torch(torch_model.cls_token, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + parameters["position_embeddings"] = ttnn.from_torch( + torch_model.position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT + ) + + return parameters diff --git a/models/experimental/functional_vit/tt/ttnn_functional_vit_highres.py b/models/experimental/functional_vit/tt/ttnn_functional_vit_highres.py new file mode 100644 index 00000000000..7507fff5272 --- /dev/null +++ b/models/experimental/functional_vit/tt/ttnn_functional_vit_highres.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import transformers +import torch + +import ttnn +import tt_lib as ttl +from ttnn.model_preprocessing import ( + preprocess_linear_weight, + preprocess_linear_bias, +) + + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def vit_patch_embeddings(config, pixel_values, *, parameters, unittest_check=False): + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + patch_count = img_h // patch_size # 14 + patch_size_sq_trpl = int(patch_size * patch_size * 3) # 768 + patch_count_all = int(patch_count * patch_count) # 196 + stride_h = patch_size + stride_w = 1 + + pixel_values = ttnn.reshape(pixel_values, (batch_size, img_h, img_w // patch_size, 4 * patch_size)) + pixel_values = ttl.tensor.fold(pixel_values, stride_h, stride_w) + pixel_values = ttnn.to_layout(pixel_values, layout=ttnn.TILE_LAYOUT) + + if unittest_check: + parameters = parameters.vit.embeddings.patch_embeddings + patch_embedding_output = pixel_values @ parameters.projection.weight + patch_embedding_output = patch_embedding_output + parameters.projection.bias + + patch_embedding_output = ttnn.to_layout(patch_embedding_output, layout=ttnn.ROW_MAJOR_LAYOUT) + patch_embedding_output = ttnn.reshape(patch_embedding_output, (batch_size, patch_count_all, patch_size_sq_trpl)) + + return patch_embedding_output + + +def vit_embeddings( + config, + pixel_values, + position_embeddings_interpolated, + *, + parameters, +): + parameters = parameters.vit.embeddings + + patch_embeddings = vit_patch_embeddings(config, pixel_values, parameters=parameters.patch_embeddings) + embedding_output = ttnn.concat((parameters.cls_token, patch_embeddings), dim=1) + embedding_output = embedding_output + position_embeddings_interpolated + # embedding_output = ttnn.pad(embedding_output, ((0, 0), (0, 31), (0, 0)), 0) + + return embedding_output + + +def vit_layernorm_before( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + 
bias=parameters.layernorm_before.bias, + ) + + return attention_output + + +def vit_layernorm_after( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + ) + + return attention_output + + +def vit_attention( + config, + hidden_states, + attention_mask, + *, + parameters, +): + num_heads = config.num_attention_heads + batch_size, sequence_size, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query = hidden_states @ parameters.attention.query.weight + query = query + parameters.attention.query.bias + query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT) + query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size)) + query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT) + query = ttnn.permute(query, (0, 2, 1, 3)) + + key = hidden_states @ parameters.attention.key.weight + key = key + parameters.attention.key.bias + key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT) + key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size)) + key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT) + key = ttnn.permute(key, (0, 2, 3, 1)) + + value = hidden_states @ parameters.attention.value.weight + value = value + parameters.attention.value.bias + value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT) + value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size)) + value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT) + value = ttnn.permute(value, (0, 2, 1, 3)) + + attention_scores = query @ key + attention_scores = attention_scores * (1 / (head_size**0.5)) + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = ttnn.softmax(attention_scores, dim=-1) + + context_layer = attention_probs @ value + context_layer = ttnn.permute(context_layer, (0, 2, 1, 3)) + context_layer = ttnn.to_layout(context_layer, ttnn.ROW_MAJOR_LAYOUT) + context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size)) + context_layer = ttnn.to_layout(context_layer, ttnn.TILE_LAYOUT) + + self_output = context_layer + self_output = self_output @ parameters.output.dense.weight + self_output = self_output + parameters.output.dense.bias + + return self_output + + +def vit_intermediate( + hidden_states, + *, + parameters, +): + output = hidden_states @ parameters.dense.weight + output = output + parameters.dense.bias + output = ttnn.gelu(output) + return output + + +def vit_output( + config, + hidden_states, + residual, + *, + parameters, +): + output = hidden_states @ parameters.dense.weight + output = output + parameters.dense.bias + output = output + residual + + return output + + +def vit_feedforward( + config, + hidden_states, + attention_output, + *, + parameters, +): + intermediate = vit_intermediate(hidden_states, parameters=parameters.intermediate) + hidden_states = vit_output(config, intermediate, attention_output, parameters=parameters.output) + return hidden_states + + +def vit_layer( + config, + hidden_states, + attention_mask, + *, + parameters, +): + layernorm_before_output = vit_layernorm_before( + config, + hidden_states, + parameters=parameters, + ) + attention_output = vit_attention( + config, + layernorm_before_output, + attention_mask, + parameters=parameters.attention, + ) + attention_output = attention_output + hidden_states + layernorm_after_output = vit_layernorm_after( + config, + attention_output, + 
parameters=parameters, + ) + feedforward_output = vit_feedforward( + config, + layernorm_after_output, + attention_output, + parameters=parameters, + ) + + return feedforward_output + + +def vit_encoder( + config, + hidden_states, + attention_mask, + *, + parameters, +): + encoder_input = hidden_states + encoder_output = None + for encoder_parameters in parameters.layer: + encoder_output = vit_layer( + config, + encoder_input, + attention_mask, + parameters=encoder_parameters, + ) + encoder_input = encoder_output + return encoder_output + + +def vit( + config, + pixel_values, + attention_mask, + position_embeddings_interpolated, + *, + parameters, +): + embeddings_output = vit_embeddings(config, pixel_values, position_embeddings_interpolated, parameters=parameters) + embeddings_output = ttnn.to_layout(embeddings_output, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + hidden_states = vit_encoder( + config, + embeddings_output, + attention_mask=attention_mask, + parameters=parameters.vit.encoder, + ) + + # Final LayerNorm + output = ttnn.layer_norm( + hidden_states, + weight=parameters.vit.layernorm.weight, + bias=parameters.vit.layernorm.bias, + ) + + # Classifier + classifier_output = output @ parameters.classifier.weight + classifier_output = classifier_output + parameters.classifier.bias + return classifier_output + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.vit.modeling_vit.ViTEmbeddings): + weight = torch_model.patch_embeddings.projection.weight + bias = torch_model.patch_embeddings.projection.bias + + three_times_hidden_size, c, _, _ = weight.shape + pad_value = 4 - c + preprocessed_weight = torch.nn.functional.pad(weight, (0, 0, 0, 0, 0, pad_value)) + preprocessed_weight = torch.permute(preprocessed_weight, (2, 3, 1, 0)) + preprocessed_weight = torch.reshape( + preprocessed_weight, (int(three_times_hidden_size * (4 / c)), three_times_hidden_size) + ) + + parameters = {"patch_embeddings": {}} + parameters["patch_embeddings"] = {"projection": {}} + parameters["patch_embeddings"]["projection"]["weight"] = ttnn.from_torch( + preprocessed_weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT + ) + parameters["patch_embeddings"]["projection"]["bias"] = ttnn.from_torch( + bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT + ) + + parameters["cls_token"] = ttnn.from_torch(torch_model.cls_token, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + parameters["position_embeddings"] = ttnn.from_torch( + torch_model.position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT + ) + + return parameters diff --git a/models/experimental/functional_vit/tt/ttnn_optimized_interleaved_vit.py b/models/experimental/functional_vit/tt/ttnn_optimized_interleaved_vit.py new file mode 100644 index 00000000000..0bd124f0e76 --- /dev/null +++ b/models/experimental/functional_vit/tt/ttnn_optimized_interleaved_vit.py @@ -0,0 +1,446 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import transformers +import torch + +import ttnn +from ttnn.model_preprocessing import ( + preprocess_linear_weight, + preprocess_linear_bias, +) + +core_grid = ttnn.CoreGrid(y=8, x=12) + + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def vit_patch_embeddings(config, pixel_values, *, parameters, unittest_check=False): + # batch_size, img_c, img_h, img_w = pixel_values.shape # NCHW + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + patch_count = img_h // patch_size # 14 + patch_size_sq_trpl = int(patch_size * patch_size * 3) # 768 + patch_count_all = int(patch_count * patch_count) # 196 + stride_h = patch_size + stride_w = 1 + + folded_pixel_values = ttnn.experimental.tensor.fold(pixel_values, stride_h, stride_w) # 1568, 1024 + ttnn.deallocate(pixel_values) + x = ttnn.reallocate(folded_pixel_values) + folded_pixel_values = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b) + + if unittest_check: + parameters = parameters.vit.embeddings.patch_embeddings + + patch_embedding_output = ttnn.linear( + folded_pixel_values, + parameters.projection.weight, + bias=parameters.projection.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=8, x=12), + ) + # ttnn.deallocate(pixel_values) + + patch_embedding_output = ttnn.to_layout(patch_embedding_output, layout=ttnn.ROW_MAJOR_LAYOUT) + patch_embedding_output = ttnn.reshape(patch_embedding_output, (batch_size, patch_count_all, patch_size_sq_trpl)) + + return patch_embedding_output + + +def vit_embeddings( + config, + pixel_values, + cls_token, + position_embeddings, + *, + parameters, +): + parameters = parameters.vit.embeddings + # cls_token = parameters.cls_token + # position_embeddings = parameters.position_embeddings + + l1_memory_config = ttnn.experimental.tensor.MemoryConfig( + memory_layout=ttnn.experimental.tensor.TensorMemoryLayout.INTERLEAVED, + buffer_type=ttnn.experimental.tensor.BufferType.L1, + ) + + patch_embeddings = vit_patch_embeddings(config, pixel_values, parameters=parameters.patch_embeddings) + + embedding_output = ttnn.experimental.tensor.concat([cls_token, patch_embeddings], -2, l1_memory_config) + # embedding_output = ttnn.pad(embedding_output, padding=((0, 0), (0, 27), (0, 0)), value=0) + # print("out", embedding_output.shape) + embedding_output = ttnn.to_layout(embedding_output, layout=ttnn.TILE_LAYOUT) + + embedding_output = ttnn.add( + embedding_output, position_embeddings, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b + ) + # embedding_output = ttnn.to_layout(embedding_output, layout=ttnn.TILE_LAYOUT) + + return embedding_output + + +def vit_layernorm_before( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_layernorm_after( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=12), + # 
program_config=program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_attention( + config, + hidden_states, + attention_mask, + parameters, +): + num_heads = config.num_attention_heads + *_, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query_key_value = ttnn.linear( + hidden_states, + parameters.attention.query_key_value.weight, + bias=parameters.attention.query_key_value.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["query_key_value_matmul_program_config"], + ) + + ( + query, + key, + value, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value, + memory_config=ttnn.L1_MEMORY_CONFIG, + num_heads=num_heads, + ) + ttnn.deallocate(query_key_value) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["query_by_key_matmul_program_config"], + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + + attention_probs = ttnn.transformer.attention_softmax_( + attention_scores, + attention_mask=attention_mask, + head_size=head_size, + # program_config=program_configs["softmax_program_config"], + ) + + context_layer = ttnn.matmul( + attention_probs, + value, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["attention_probabilities_by_value_matmul_program_config"], + ) + ttnn.deallocate(attention_probs) + ttnn.deallocate(value) + + context_layer = ttnn.transformer.concatenate_heads( + context_layer, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + self_output = ttnn.linear( + context_layer, + parameters.output.dense.weight, + bias=parameters.output.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["self_output_matmul_program_config"], + ) + ttnn.deallocate(context_layer) + + return self_output + + +def vit_intermediate( + hidden_states, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + # program_config=program_configs["ff1_matmul_program_config"], + core_grid=ttnn.CoreGrid(y=8, x=12), + activation="gelu", + ) + # ttnn.deallocate(hidden_states) + + return output + + +def vit_output( + config, + hidden_states, + residual, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["ff2_matmul_program_config"], + ) + ttnn.deallocate(hidden_states) + + output = ttnn.add(output, residual, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b) + + # ttnn.deallocate(residual) + + return output + + +def vit_feedforward( + config, + hidden_states, + attention_output, + *, + parameters, +): + intermediate = vit_intermediate(hidden_states, parameters=parameters.intermediate) + hidden_states = vit_output(config, intermediate, attention_output, parameters=parameters.output) + return hidden_states + + +def vit_layer( + config, + hidden_states, + attention_mask, + parameters, +): + layernorm_before_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + 
bias=parameters.layernorm_before.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + # program_config=program_configs["layernorm_program_config"], + ) + + multi_head_attention_output = vit_attention( + config, + layernorm_before_output, + attention_mask=attention_mask, + parameters=parameters.attention, + ) + + multi_head_attention_output = ttnn.add( + multi_head_attention_output, hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b + ) + # return multi_head_attention_output + + layernorm_after_output = ttnn.layer_norm( + multi_head_attention_output, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + # program_config=program_configs["layernorm_program_config"], + ) + + feedforward_output = vit_feedforward( + config, + layernorm_after_output, + multi_head_attention_output, + parameters=parameters, + ) + + return feedforward_output + + +def vit_encoder( + config, + embeddings, + head_masks, + parameters, +): + encoder_input = embeddings + + encoder_output = None + for index, encoder_parameters in enumerate(parameters.layer): + encoder_output = vit_layer( + config, + encoder_input, + head_masks[index], + encoder_parameters, + ) + encoder_input = encoder_output + + return encoder_output + + +def vit( + config, + pixel_values, + attention_mask, + cls_token, + position_embeddings, + parameters, +): + embeddings_output = vit_embeddings(config, pixel_values, cls_token, position_embeddings, parameters=parameters) + + hidden_states = vit_encoder( + config, + embeddings_output, + attention_mask, + parameters=parameters.vit.encoder, + ) + # ttnn.deallocate(embeddings_output) + + # Final LayerNorm + output = ttnn.layer_norm( + hidden_states, + weight=parameters.vit.layernorm.weight, + bias=parameters.vit.layernorm.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + # core_grid=ttnn.CoreGrid(y=8, x=12), + ) + + # Classifier + # classifier_output = ttnn.linear( + # output, + # parameters.classifier.weight, + # bias=parameters.classifier.bias, + # memory_config=ttnn.L1_MEMORY_CONFIG, + # dtype=ttnn.bfloat8_b, + # core_grid=ttnn.CoreGrid(y=8, x=12), + # ) + classifier_output = output @ parameters.classifier.weight + classifier_output = classifier_output + parameters.classifier.bias + + return classifier_output + + +def preprocess_inputs( + input_ids, + token_type_ids, + attention_mask, + device, +): + batch_size, _ = input_ids.shape + + input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) + + token_type_ids = ttnn.from_torch( + token_type_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) + attention_mask = torch.nn.functional.pad(attention_mask, (0, 0, 0, 0, 0, 0, 0, batch_size - 1)) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return input_ids, token_type_ids, attention_mask + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.vit.modeling_vit.ViTEmbeddings): + weight = torch_model.patch_embeddings.projection.weight + bias = torch_model.patch_embeddings.projection.bias + + three_times_hidden_size, c, _, _ = weight.shape + pad_value = 4 - c + preprocessed_weight = torch.nn.functional.pad(weight, (0, 0, 0, 0, 0, pad_value)) + 
preprocessed_weight = torch.permute(preprocessed_weight, (2, 3, 1, 0)) + preprocessed_weight = torch.reshape( + preprocessed_weight, (int(three_times_hidden_size * (4 / c)), three_times_hidden_size) + ) + + parameters = {"patch_embeddings": {}} + parameters["patch_embeddings"] = {"projection": {}} + parameters["patch_embeddings"]["projection"]["weight"] = ttnn.from_torch( + preprocessed_weight, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + parameters["patch_embeddings"]["projection"]["bias"] = ttnn.from_torch( + bias.unsqueeze(0), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + parameters["cls_token"] = ttnn.from_torch(torch_model.cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT) + parameters["position_embeddings"] = ttnn.from_torch( + torch_model.position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + if hasattr(torch_model, "query") and hasattr(torch_model, "key") and hasattr(torch_model, "value"): + qkv_weight = torch.cat( + [ + torch_model.query.weight, + torch_model.key.weight, + torch_model.value.weight, + ], + dim=0, + ) + qkv_bias = torch.cat( + [torch_model.query.bias, torch_model.key.bias, torch_model.value.bias], + dim=0, + ) + + parameters = {"query_key_value": {}} + parameters["query_key_value"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat8_b) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(qkv_bias, dtype=ttnn.bfloat8_b) + + elif isinstance(torch_model, torch.nn.Linear): + parameters["weight"] = preprocess_linear_weight(torch_model.weight, dtype=ttnn.bfloat8_b) + parameters["bias"] = preprocess_linear_bias(torch_model.bias, dtype=ttnn.bfloat8_b) + + return parameters diff --git a/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py new file mode 100644 index 00000000000..130088e8e01 --- /dev/null +++ b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit.py @@ -0,0 +1,621 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import transformers +import torch +from ttnn.model_preprocessing import ( + preprocess_linear_weight, + preprocess_linear_bias, +) + +import ttnn + +# import tt_lib as ttl +from ttnn.dot_access import DotAccessDict + + +def update_model_config(config, batch_size): + core_grid = ttnn.CoreGrid(y=8, x=12) + + program_configs = { + "fold_output_program_config": ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + ttnn.experimental.tensor.ShardSpec( + ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(12, 7), + ), + } + ), + [ + 224, + 192, + ], + ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, + False, + ), + ), + "embedding_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=3, + out_subblock_h=1, + out_subblock_w=6, + per_core_M=7, + per_core_N=6, + transpose_mcast=False, + fused_activation=None, + ), + "query_key_value_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=6, + per_core_M=7, + per_core_N=6, + transpose_mcast=False, + fused_activation=None, + ), + "query_by_key_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=7, + per_core_M=7, + per_core_N=7, + ), + "attention_probabilities_by_value_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=7, + out_subblock_h=1, + out_subblock_w=2, + per_core_M=7, + per_core_N=2, + ), + "self_output_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=7, + out_subblock_w=2, + per_core_M=7, + per_core_N=2, + transpose_mcast=False, + fused_activation=None, + ), + "ff1_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=4, + per_core_M=7, + per_core_N=8, + transpose_mcast=False, + fused_activation=(ttnn.experimental.tensor.FusibleActivation.GELU, True), + ), + "ff2_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=8, + out_subblock_h=7, + out_subblock_w=2, + per_core_M=7, + per_core_N=2, + transpose_mcast=False, + fused_activation=None, + ), + "classifer_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=3, + per_core_M=7, + per_core_N=3, + transpose_mcast=False, + fused_activation=(ttnn.experimental.tensor.FusibleActivation.GELU, True), + ), + "layernorm_program_config": 
ttnn.experimental.operations.primary.LayerNormShardedMultiCoreProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + subblock_w=2, + block_h=7, + block_w=2, + # math_fidelity=ttnn.experimental.tensor.MathFidelity.HiFi4, + # im_data_format=ttnn.experimental.tensor.DataType.BFLOAT16, + # out_data_format=ttnn.bfloat8_b, + inplace=False, + ), + "layernorm_after_output_program_config": ttnn.experimental.operations.primary.LayerNormShardedMultiCoreProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + subblock_w=2, + block_h=7, + block_w=2, + # math_fidelity=ttnn.experimental.tensor.MathFidelity.HiFi4, + # im_data_format=ttnn.experimental.tensor.DataType.BFLOAT16, + # out_data_format=ttnn.bfloat8_b, + inplace=False, + ), + "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + subblock_w=7, + block_h=7, + block_w=7, + # math_fidelity=ttnn.experimental.tensor.MathFidelity.HiFi4, + # im_data_format=ttnn.experimental.tensor.DataType.BFLOAT16, + ), + } + + return DotAccessDict(dict(**config.to_dict(), core_grid=core_grid, program_configs=program_configs)) + + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def vit_patch_embeddings(config, pixel_values, *, parameters, unittest_check=False): + # batch_size, img_c, img_h, img_w = pixel_values.shape # NCHW + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + patch_count = img_h // patch_size # 14 + patch_size_sq_trpl = int(patch_size * patch_size * 3) # 768 + patch_count_all = int(patch_count * patch_count) # 196 + stride_h = patch_size + stride_w = 1 + + fold_h_padded = (batch_size * img_h * patch_count_all) + 224 + fold_w_padded = (4 * patch_size * patch_size) + 128 + + folded_pixel_values = ttnn.experimental.tensor.fold(pixel_values, stride_h, stride_w) # 1568, 1024 + ttnn.deallocate(pixel_values) + x = ttnn.reallocate(folded_pixel_values) + folded_pixel_values = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b) + + #### Exp 1 of resharding after Fold and before Matmul + # pixel_values = ttnn.pad(pixel_values, ((0, 0), (0, 0), (0, 224), (0, 128)), 0) + # output_sharded_memory_config_args = dict(core_grid=ttnn.CoreGrid(y=8, x=12), strategy=ttnn.ShardStrategy.BLOCK) + # input_shape = [fold_h_padded, fold_w_padded] + # output_shard_memory_config = ttnn.create_sharded_memory_config(input_shape, **output_sharded_memory_config_args) + # resharded_pixel_values = ttnn.to_memory_config(pixel_values, output_shard_memory_config) + + #### Exp 2 of resharding after Fold and before Matmul + # pixel_values = ttnn.pad(pixel_values, ((0, 0), (0, 0), (0, 224), (0, 128)), 0) + # post_fold_config = ttnn.experimental.tensor.MemoryConfig( + # ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED, + # ttnn.experimental.tensor.BufferType.L1, + # ttnn.experimental.tensor.ShardSpec( + # ttnn.experimental.tensor.CoreRangeSet( + # {ttnn.experimental.tensor.CoreRange( + # ttnn.experimental.tensor.CoreCoord(0, 0), + # ttnn.experimental.tensor.CoreCoord(11, 7), + # ),}, + # ), + # [224,192], + # ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, + # False, + # ), + # ) + # resharded_pixel_values = ttnn.experimental.tensor.reshard(pixel_values, post_fold_config) + + # return resharded_pixel_values + + if unittest_check: + parameters = parameters.vit.embeddings.patch_embeddings + + 
patch_embedding_output = ttnn.linear( + folded_pixel_values, + parameters.projection.weight, + bias=parameters.projection.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=config.program_configs["embedding_matmul_program_config"], + ) + # ttnn.deallocate(pixel_values) + + patch_embedding_output = ttnn.to_layout(patch_embedding_output, layout=ttnn.ROW_MAJOR_LAYOUT) + patch_embedding_output = ttnn.reshape(patch_embedding_output, (batch_size, patch_count_all, patch_size_sq_trpl)) + + return patch_embedding_output + + +def vit_embeddings( + config, + pixel_values, + cls_token, + position_embeddings, + *, + parameters, +): + parameters = parameters.vit.embeddings + + l1_memory_config = ttnn.experimental.tensor.MemoryConfig( + memory_layout=ttnn.experimental.tensor.TensorMemoryLayout.INTERLEAVED, + buffer_type=ttnn.experimental.tensor.BufferType.L1, + ) + + patch_embeddings = vit_patch_embeddings(config, pixel_values, parameters=parameters.patch_embeddings) + embedding_output = ttnn.experimental.tensor.concat([cls_token, patch_embeddings], -2, l1_memory_config) + embedding_output = ttnn.to_layout(embedding_output, layout=ttnn.TILE_LAYOUT) + embedding_output = ttnn.add( + embedding_output, position_embeddings, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b + ) + # Needed to improve PCC in an older commit + # embedding_output = ttnn.pad(embedding_output, ((0, 0), (0, 27), (0, 0)), 0) + + return embedding_output + + +def vit_layernorm_before( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_layernorm_after( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_attention( + config, + hidden_states, + attention_mask, + parameters, +): + num_heads = config.num_attention_heads + num_heads = 12 + *_, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query_key_value = ttnn.linear( + hidden_states, + parameters.attention.query_key_value.weight, + bias=parameters.attention.query_key_value.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["query_key_value_matmul_program_config"], + ) + + ( + query, + key, + value, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + num_heads=num_heads, + ) + ttnn.deallocate(query_key_value) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["query_by_key_matmul_program_config"], + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + + attention_probs = ttnn.transformer.attention_softmax_( + attention_scores, + attention_mask=attention_mask, + head_size=head_size, + program_config=config.program_configs["softmax_program_config"], + 
) + + context_layer = ttnn.matmul( + attention_probs, + value, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["attention_probabilities_by_value_matmul_program_config"], + ) + ttnn.deallocate(attention_probs) + ttnn.deallocate(value) + + context_layer = ttnn.transformer.concatenate_heads( + context_layer, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + ) + + self_output = ttnn.linear( + context_layer, + parameters.output.dense.weight, + bias=parameters.output.dense.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["self_output_matmul_program_config"], + ) + ttnn.deallocate(context_layer) + + return self_output + + +def vit_intermediate( + config, + hidden_states, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["ff1_matmul_program_config"], + ) + ttnn.deallocate(hidden_states) + + return output + + +def vit_output( + config, + hidden_states, + residual, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["ff2_matmul_program_config"], + ) + ttnn.deallocate(hidden_states) + + output = ttnn.add(output, residual, memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, dtype=ttnn.bfloat8_b) + ttnn.deallocate(residual) + + return output + + +def vit_feedforward( + config, + hidden_states, + attention_output, + *, + parameters, +): + intermediate = vit_intermediate(config, hidden_states, parameters=parameters.intermediate) + hidden_states = vit_output(config, intermediate, attention_output, parameters=parameters.output) + return hidden_states + + +def vit_layer( + config, + hidden_states, + attention_mask, + parameters, +): + layernorm_before_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + multi_head_attention_output = vit_attention( + config, + layernorm_before_output, + attention_mask=attention_mask, + parameters=parameters.attention, + ) + + multi_head_attention_output = ttnn.add( + multi_head_attention_output, + hidden_states, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + ) + + layernorm_after_output = ttnn.layer_norm( + multi_head_attention_output, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_after_output_program_config"], + ) + + feedforward_output = vit_feedforward( + config, + layernorm_after_output, + multi_head_attention_output, + parameters=parameters, + ) + + return feedforward_output + + +def vit_encoder( + config, + embeddings, + head_masks, + parameters, +): + encoder_input = ttnn.to_memory_config( + embeddings, + memory_config=ttnn.create_sharded_memory_config( + [8, 224, 768], # embeddings.shape, # hardcoded because a bug where it still sees the 197 not 224 + core_grid=config.core_grid, + strategy=ttnn.ShardStrategy.BLOCK, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ), + 
dtype=ttnn.bfloat8_b, + ) + ttnn.deallocate(embeddings) + + for index, encoder_parameters in enumerate(parameters.layer): + encoder_output = vit_layer( + config, + encoder_input, + head_masks[index], + encoder_parameters, + ) + encoder_input = encoder_output + + return encoder_output + + +def vit( + config, + pixel_values, + attention_mask, + cls_token, + position_embeddings, + parameters, +): + embeddings_output = vit_embeddings(config, pixel_values, cls_token, position_embeddings, parameters=parameters) + + hidden_states = vit_encoder( + config, + embeddings_output, + attention_mask, + parameters=parameters.vit.encoder, + ) + + # Final LayerNorm + output = ttnn.layer_norm( + hidden_states, + weight=parameters.vit.layernorm.weight, + bias=parameters.vit.layernorm.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + # Classifier + classifier_output = ttnn.linear( + output, + parameters.classifier.weight, + bias=parameters.classifier.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["classifer_matmul_program_config"], + ) + + return classifier_output + + +def preprocess_inputs( + input_ids, + token_type_ids, + attention_mask, + device, +): + batch_size, _ = input_ids.shape + + input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) + + token_type_ids = ttnn.from_torch( + token_type_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) + attention_mask = torch.nn.functional.pad(attention_mask, (0, 0, 0, 0, 0, 0, 0, batch_size - 1)) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return input_ids, token_type_ids, attention_mask + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.vit.modeling_vit.ViTEmbeddings): + weight = torch_model.patch_embeddings.projection.weight + bias = torch_model.patch_embeddings.projection.bias + + three_times_hidden_size, c, _, _ = weight.shape + pad_value = 4 - c + preprocessed_weight = torch.nn.functional.pad(weight, (0, 0, 0, 0, 0, pad_value)) + preprocessed_weight = torch.permute(preprocessed_weight, (2, 3, 1, 0)) + preprocessed_weight = torch.reshape( + preprocessed_weight, (int(three_times_hidden_size * (4 / c)), three_times_hidden_size) + ) + + parameters = {"patch_embeddings": {}} + parameters["patch_embeddings"] = {"projection": {}} + parameters["patch_embeddings"]["projection"]["weight"] = ttnn.from_torch( + preprocessed_weight, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + parameters["patch_embeddings"]["projection"]["bias"] = ttnn.from_torch( + bias.unsqueeze(0), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + parameters["cls_token"] = ttnn.from_torch(torch_model.cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT) + parameters["position_embeddings"] = ttnn.from_torch( + torch_model.position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + if hasattr(torch_model, "query") and hasattr(torch_model, "key") and hasattr(torch_model, "value"): + num_heads = 12 + head_size = 64 + hidden_size = num_heads * head_size * 3 + qkv_weight = torch.cat( + [ + 
torch_model.query.weight.reshape([num_heads, head_size, -1]), + torch_model.key.weight.reshape([num_heads, head_size, -1]), + torch_model.value.weight.reshape([num_heads, head_size, -1]), + ], + dim=1, + ).reshape([hidden_size, -1]) + qkv_bias = torch.cat( + [ + torch_model.query.bias.reshape([num_heads, head_size]), + torch_model.key.bias.reshape([num_heads, head_size]), + torch_model.value.bias.reshape([num_heads, head_size]), + ], + dim=1, + ).reshape([hidden_size]) + + parameters = {"query_key_value": {}} + parameters["query_key_value"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat8_b) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(qkv_bias, dtype=ttnn.bfloat8_b) + + elif isinstance(torch_model, torch.nn.Linear): + # TODO: better way of detection for the classify linear weights + if torch_model.weight.shape[0] == 1000: + preprocessed_weight = torch.nn.functional.pad(torch_model.weight, (0, 0, 0, int(1152 - 1000))) + preprocessed_bias = torch.nn.functional.pad(torch_model.bias, (0, int(1152 - 1000))) + parameters["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat8_b) + parameters["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat8_b) + else: + parameters["weight"] = preprocess_linear_weight(torch_model.weight, dtype=ttnn.bfloat8_b) + parameters["bias"] = preprocess_linear_bias(torch_model.bias, dtype=ttnn.bfloat8_b) + + return parameters diff --git a/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py new file mode 100644 index 00000000000..557aae803bd --- /dev/null +++ b/models/experimental/functional_vit/tt/ttnn_optimized_sharded_vit_backup.py @@ -0,0 +1,688 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import transformers +import torch +from ttnn.model_preprocessing import ( + preprocess_linear_weight, + preprocess_linear_bias, +) + +import ttnn + +# import tt_lib as ttl +from ttnn.dot_access import DotAccessDict + + +def update_model_config(config, batch_size): + core_grid = ttnn.CoreGrid(y=8, x=12) + + program_configs = { + "fold_output_program_config": ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + ttnn.experimental.tensor.ShardSpec( + ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(12, 7), + ), + } + ), + [ + 224, + 192, + ], + ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, + False, + ), + ), + "embedding_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=3, + out_subblock_h=1, + out_subblock_w=6, + per_core_M=7, + per_core_N=6, + transpose_mcast=False, + fused_activation=None, + ), + "query_key_value_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=6, + per_core_M=7, + per_core_N=6, + transpose_mcast=False, + fused_activation=None, + ), + "query_by_key_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=7, + per_core_M=7, + per_core_N=7, + ), + "attention_probabilities_by_value_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=7, + out_subblock_h=1, + out_subblock_w=2, + per_core_M=7, + per_core_N=2, + ), + "self_output_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=7, + out_subblock_w=2, + per_core_M=7, + per_core_N=2, + transpose_mcast=False, + fused_activation=None, + ), + "ff1_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=4, + per_core_M=7, + per_core_N=8, + transpose_mcast=False, + fused_activation=(ttnn.experimental.tensor.FusibleActivation.GELU, True), + ), + "ff2_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=8, + out_subblock_h=7, + out_subblock_w=2, + per_core_M=7, + per_core_N=2, + transpose_mcast=False, + fused_activation=None, + ), + "classifer_matmul_program_config": ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=3, + per_core_M=7, + per_core_N=3, + transpose_mcast=False, + fused_activation=(ttnn.experimental.tensor.FusibleActivation.GELU, True), + ), + "layernorm_program_config": 
ttnn.experimental.operations.primary.LayerNormShardedMultiCoreProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + subblock_w=2, + block_h=7, + block_w=2, + # math_fidelity=ttnn.experimental.tensor.MathFidelity.HiFi4, + # im_data_format=ttnn.experimental.tensor.DataType.BFLOAT16, + # out_data_format=ttnn.bfloat8_b, + inplace=True, + ), + "layernorm_after_output_program_config": ttnn.experimental.operations.primary.LayerNormShardedMultiCoreProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + subblock_w=2, + block_h=7, + block_w=2, + # math_fidelity=ttnn.experimental.tensor.MathFidelity.HiFi4, + # im_data_format=ttnn.experimental.tensor.DataType.BFLOAT16, + # out_data_format=ttnn.bfloat8_b, + inplace=False, + ), + "softmax_program_config": ttnn.experimental.operations.primary.transformers.SoftmaxShardedMultiCoreProgramConfig( + compute_with_storage_grid_size=(core_grid.x, core_grid.y), + subblock_w=7, + block_h=7, + block_w=7, + # math_fidelity=ttnn.experimental.tensor.MathFidelity.HiFi4, + # im_data_format=ttnn.experimental.tensor.DataType.BFLOAT16, + ), + } + + return DotAccessDict(dict(**config.to_dict(), core_grid=core_grid, program_configs=program_configs)) + + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def vit_patch_embeddings( + config, + pixel_values, + *, + parameters, +): + # batch_size, img_c, img_h, img_w = pixel_values.shape # NCHW + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + patch_count = img_h // patch_size # 14 + patch_size_sq_trpl = int(patch_size * patch_size * 3) # 768 + patch_count_all = int(patch_count * patch_count) # 196 + stride_h = patch_size + stride_w = 1 + + fold_h_padded = (batch_size * img_h * patch_count_all) + 224 + fold_w_padded = (4 * patch_size * patch_size) + 128 + + # pixel_values = ttnn.reshape(pixel_values, (batch_size, img_h, img_w // patch_size, 4 * patch_size)) + folded_pixel_values = ttnn.experimental.tensor.fold(pixel_values, stride_h, stride_w) # 1568, 1024 + ttnn.deallocate(pixel_values) + x = ttnn.reallocate(folded_pixel_values) + folded_pixel_values = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b) + + #### Exp 1 of resharding after Fold and before Matmul + # pixel_values = ttnn.pad(pixel_values, ((0, 0), (0, 0), (0, 224), (0, 128)), 0) + # output_sharded_memory_config_args = dict(core_grid=ttnn.CoreGrid(y=8, x=12), strategy=ttnn.ShardStrategy.BLOCK) + # input_shape = [fold_h_padded, fold_w_padded] + # output_shard_memory_config = ttnn.create_sharded_memory_config(input_shape, **output_sharded_memory_config_args) + # resharded_pixel_values = ttnn.to_memory_config(pixel_values, output_shard_memory_config) + + #### Exp 2 of resharding after Fold and before Matmul + # pixel_values = ttnn.pad(pixel_values, ((0, 0), (0, 0), (0, 224), (0, 128)), 0) + # post_fold_config = ttnn.experimental.tensor.MemoryConfig( + # ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED, + # ttnn.experimental.tensor.BufferType.L1, + # ttnn.experimental.tensor.ShardSpec( + # ttnn.experimental.tensor.CoreRangeSet( + # {ttnn.experimental.tensor.CoreRange( + # ttnn.experimental.tensor.CoreCoord(0, 0), + # ttnn.experimental.tensor.CoreCoord(11, 7), + # ),}, + # ), + # [224,192], + # ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, + # False, + # ), + # ) + # resharded_pixel_values = ttnn.experimental.tensor.reshard(pixel_values, post_fold_config) + + # return resharded_pixel_values + + 
## Needed only when running the standalone module pytest test_vit_patch_embeddings + ## Please comment out when running the pytest on parent module like test_vit_embeddings or test_vit + # parameters = parameters.vit.embeddings.patch_embeddings + + patch_embedding_output = ttnn.linear( + folded_pixel_values, + parameters.projection.weight, + bias=parameters.projection.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat16, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=config.program_configs["embedding_matmul_program_config"], + ) + # ttnn.deallocate(pixel_values) + + patch_embedding_output = ttnn.to_layout(patch_embedding_output, layout=ttnn.ROW_MAJOR_LAYOUT) + patch_embedding_output = ttnn.reshape(patch_embedding_output, (batch_size, patch_count_all, patch_size_sq_trpl)) + + return patch_embedding_output + + +def vit_embeddings( + config, + pixel_values, + cls_token, + position_embeddings, + *, + parameters, +): + parameters = parameters.vit.embeddings + # cls_token = parameters.cls_token + # position_embeddings = parameters.position_embeddings + + l1_memory_config = ttnn.experimental.tensor.MemoryConfig( + memory_layout=ttnn.experimental.tensor.TensorMemoryLayout.INTERLEAVED, + buffer_type=ttnn.experimental.tensor.BufferType.L1, + ) + + patch_embeddings = vit_patch_embeddings(config, pixel_values, parameters=parameters.patch_embeddings) + # print("clcs", cls_token.shape) + # print("patch", patch_embeddings.shape) + # patch_embeddings = ttnn.pad(patch_embeddings, padding=((0, 0), (1, 27), (0, 0)), value=0) + # embedding_output = ttnn.to_layout(patch_embeddings, layout=ttnn.TILE_LAYOUT) + + embedding_output = ttnn.experimental.tensor.concat([cls_token, patch_embeddings], -2, l1_memory_config) + embedding_output = ttnn.pad(embedding_output, padding=((0, 0), (0, 27), (0, 0)), value=0) + # print("out", embedding_output.shape) + embedding_output = ttnn.to_layout(embedding_output, layout=ttnn.TILE_LAYOUT) + # print("outTilized", embedding_output.shape) + embedding_output = ttnn.add( + embedding_output, position_embeddings, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b + ) + # embedding_output = ttnn.to_layout(embedding_output, layout=ttnn.TILE_LAYOUT) + # Needed to improve PCC in an older commit + # embedding_output = ttnn.pad(embedding_output, ((0, 0), (0, 27), (0, 0)), 0) + + return embedding_output + + +def vit_layernorm_before( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_layernorm_after( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_attention( + config, + hidden_states, + attention_mask, + parameters, +): + num_heads = config.num_attention_heads + num_heads = 12 + *_, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + # encoder_input = ttnn.to_memory_config( + # hidden_states, + # memory_config=ttnn.create_sharded_memory_config( + # hidden_states.shape, 
+ # core_grid=config.core_grid, + # strategy=ttnn.ShardStrategy.BLOCK, + # orientation=ttnn.ShardOrientation.ROW_MAJOR, + # #orientation=ttnn.ShardOrientation.COLUMN_MAJOR, + # ), + # dtype=ttnn.bfloat8_b, + # ) + # ttnn.deallocate(hidden_states) + + encoder_input = hidden_states + + query_key_value = ttnn.linear( + encoder_input, + parameters.attention.query_key_value.weight, + bias=parameters.attention.query_key_value.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + # core_grid=ttnn.CoreGrid(y=8, x=8), + program_config=config.program_configs["query_key_value_matmul_program_config"], + ) + + ( + query, + key, + value, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + num_heads=num_heads, + ) + ttnn.deallocate(query_key_value) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + # core_grid=ttnn.CoreGrid(y=8, x=8), + program_config=config.program_configs["query_by_key_matmul_program_config"], + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + + attention_probs = ttnn.transformer.attention_softmax_( + attention_scores, + attention_mask=attention_mask, + head_size=head_size, + program_config=config.program_configs["softmax_program_config"], + ) + + context_layer = ttnn.matmul( + attention_probs, + value, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + # core_grid=ttnn.CoreGrid(y=8, x=8), + program_config=config.program_configs["attention_probabilities_by_value_matmul_program_config"], + ) + ttnn.deallocate(attention_probs) + ttnn.deallocate(value) + + context_layer = ttnn.transformer.concatenate_heads( + context_layer, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + ) + + self_output = ttnn.linear( + context_layer, + parameters.output.dense.weight, + bias=parameters.output.dense.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + # core_grid=ttnn.CoreGrid(y=8, x=8), + program_config=config.program_configs["self_output_matmul_program_config"], + ) + ttnn.deallocate(context_layer) + + return self_output + + +def vit_intermediate( + config, + hidden_states, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + program_config=config.program_configs["ff1_matmul_program_config"], + # core_grid=ttnn.CoreGrid(y=8, x=8), + # activation="gelu", + ) + ttnn.deallocate(hidden_states) + + return output + + +def vit_output( + config, + hidden_states, + residual, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + # core_grid=ttnn.CoreGrid(y=8, x=8), + program_config=config.program_configs["ff2_matmul_program_config"], + ) + ttnn.deallocate(hidden_states) + + # residual_sh = ttnn.to_memory_config( + # residual, + # memory_config=ttnn.create_sharded_memory_config( + # residual.shape, + # core_grid=config.core_grid, + # strategy=ttnn.ShardStrategy.BLOCK, + # orientation=ttnn.ShardOrientation.ROW_MAJOR, + # # orientation=ttnn.ShardOrientation.COLUMN_MAJOR, + # ), + # dtype=ttnn.bfloat8_b, + # ) + # ttnn.deallocate(residual) + + residual_sh = residual + + output = ttnn.add(output, residual_sh, memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, 
dtype=ttnn.bfloat8_b) + ttnn.deallocate(residual_sh) + + return output + + +def vit_feedforward( + config, + hidden_states, + attention_output, + *, + parameters, +): + intermediate = vit_intermediate(config, hidden_states, parameters=parameters.intermediate) + hidden_states = vit_output(config, intermediate, attention_output, parameters=parameters.output) + return hidden_states + + +def vit_layer( + config, + hidden_states, + attention_mask, + parameters, +): + # encoder_input = ttnn.to_memory_config( + # hidden_states, + # memory_config=ttnn.create_sharded_memory_config( + # hidden_states.shape, + # core_grid=config.core_grid, + # strategy=ttnn.ShardStrategy.BLOCK, + # orientation=ttnn.ShardOrientation.ROW_MAJOR, + # # orientation=ttnn.ShardOrientation.COLUMN_MAJOR, + # ), + # dtype=ttnn.bfloat8_b, + # ) + # ttnn.deallocate(hidden_states) + + encoder_input = hidden_states + + layernorm_before_output = ttnn.layer_norm( + encoder_input, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + multi_head_attention_output = vit_attention( + config, + layernorm_before_output, + attention_mask=attention_mask, + parameters=parameters.attention, + ) + + residual = ttnn.add( + multi_head_attention_output, + encoder_input, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + ) + ttnn.deallocate(multi_head_attention_output) + + layernorm_after_output = ttnn.layer_norm( + residual, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_after_output_program_config"], + ) + + feedforward_output = vit_feedforward( + config, + layernorm_after_output, + residual, + parameters=parameters, + ) + + return feedforward_output + + +def vit_encoder( + config, + embeddings, + head_masks, + parameters, +): + encoder_input = ttnn.to_memory_config( + embeddings, + memory_config=ttnn.create_sharded_memory_config( + [8, 224, 768], # embeddings.shape + core_grid=config.core_grid, + strategy=ttnn.ShardStrategy.BLOCK, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ), + dtype=ttnn.bfloat8_b, + ) + + # ttnn.deallocate(embeddings) + # encoder_input = embeddings + + encoder_output = None + for index, encoder_parameters in enumerate(parameters.layer): + encoder_output = vit_layer( + config, + encoder_input, + head_masks[index], + encoder_parameters, + ) + encoder_input = encoder_output + + return encoder_output + + +def vit( + config, + pixel_values, + attention_mask, + cls_token, + position_embeddings, + parameters, +): + embeddings_output = vit_embeddings(config, pixel_values, cls_token, position_embeddings, parameters=parameters) + + hidden_states = vit_encoder( + config, + embeddings_output, + attention_mask, + parameters=parameters.vit.encoder, + ) + + # Final LayerNorm + output = ttnn.layer_norm( + hidden_states, + weight=parameters.vit.layernorm.weight, + bias=parameters.vit.layernorm.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + program_config=config.program_configs["layernorm_program_config"], + ) + + # Classifier + classifier_output = ttnn.linear( + output, + parameters.classifier.weight, + bias=parameters.classifier.bias, + memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + 
program_config=config.program_configs["classifer_matmul_program_config"], + ) + + return classifier_output + + +def preprocess_inputs( + input_ids, + token_type_ids, + attention_mask, + device, +): + batch_size, _ = input_ids.shape + + input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) + + token_type_ids = ttnn.from_torch( + token_type_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) + attention_mask = torch.nn.functional.pad(attention_mask, (0, 0, 0, 0, 0, 0, 0, batch_size - 1)) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return input_ids, token_type_ids, attention_mask + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.vit.modeling_vit.ViTEmbeddings): + weight = torch_model.patch_embeddings.projection.weight + bias = torch_model.patch_embeddings.projection.bias + + three_times_hidden_size, c, _, _ = weight.shape + pad_value = 4 - c + preprocessed_weight = torch.nn.functional.pad(weight, (0, 0, 0, 0, 0, pad_value)) + preprocessed_weight = torch.permute(preprocessed_weight, (2, 3, 1, 0)) + preprocessed_weight = torch.reshape( + preprocessed_weight, (int(three_times_hidden_size * (4 / c)), three_times_hidden_size) + ) + + parameters = {"patch_embeddings": {}} + parameters["patch_embeddings"] = {"projection": {}} + parameters["patch_embeddings"]["projection"]["weight"] = ttnn.from_torch( + preprocessed_weight, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + parameters["patch_embeddings"]["projection"]["bias"] = ttnn.from_torch( + bias.unsqueeze(0), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + parameters["cls_token"] = ttnn.from_torch(torch_model.cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT) + parameters["position_embeddings"] = ttnn.from_torch( + torch_model.position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + if hasattr(torch_model, "query") and hasattr(torch_model, "key") and hasattr(torch_model, "value"): + qkv_weight = torch.cat( + [ + torch_model.query.weight, + torch_model.key.weight, + torch_model.value.weight, + ], + dim=0, + ) + qkv_bias = torch.cat( + [torch_model.query.bias, torch_model.key.bias, torch_model.value.bias], + dim=0, + ) + + parameters = {"query_key_value": {}} + parameters["query_key_value"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat8_b) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(qkv_bias, dtype=ttnn.bfloat8_b) + + elif isinstance(torch_model, torch.nn.Linear): + # print(torch_model.weight.shape) + if torch_model.weight.shape[0] == 1000: + preprocessed_weight = torch.nn.functional.pad(torch_model.weight, (0, 0, 0, int(1152 - 1000))) + preprocessed_bias = torch.nn.functional.pad(torch_model.bias, (0, int(1152 - 1000))) + parameters["weight"] = preprocess_linear_weight(preprocessed_weight, dtype=ttnn.bfloat8_b) + parameters["bias"] = preprocess_linear_bias(preprocessed_bias, dtype=ttnn.bfloat8_b) + else: + parameters["weight"] = preprocess_linear_weight(torch_model.weight, dtype=ttnn.bfloat8_b) + parameters["bias"] = preprocess_linear_bias(torch_model.bias, dtype=ttnn.bfloat8_b) + + return parameters diff --git a/models/experimental/functional_vit/tt/ttnn_optimized_vit_highres.py 
b/models/experimental/functional_vit/tt/ttnn_optimized_vit_highres.py new file mode 100644 index 00000000000..a56777dc10e --- /dev/null +++ b/models/experimental/functional_vit/tt/ttnn_optimized_vit_highres.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import transformers +import torch + +import ttnn +from ttnn.model_preprocessing import ( + preprocess_linear_weight, + preprocess_linear_bias, +) +import tt_lib as ttl +import tt_lib.fallback_ops + +core_grid = ttnn.CoreGrid(y=8, x=12) + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def vit_patch_embeddings( + config, + pixel_values, + *, + parameters, +): + # batch_size, img_c, img_h, img_w = pixel_values.shape # NCHW + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + patch_count = img_h // patch_size # 14 + patch_size_sq_trpl = int(patch_size * patch_size * 3) # 768 + patch_count_all = int(patch_count * patch_count) # 196 + stride_h = patch_size + stride_w = 1 + + pixel_values = ttnn.reshape(pixel_values, (batch_size, img_h, img_w // patch_size, 4 * patch_size)) + pixel_values = ttl.tensor.fold(pixel_values, stride_h, stride_w) + pixel_values = ttnn.to_layout(pixel_values, layout=ttnn.TILE_LAYOUT) + + ## Needed only when running the standalone module pytest test_vit_patch_embeddings + ## Please comment out when running the pytest on parent module like test_vit_embeddings or test_vit + # parameters = parameters.vit.embeddings.patch_embeddings + + patch_embedding_output = ttnn.linear( + pixel_values, + parameters.projection.weight, + bias=parameters.projection.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + ) + ttnn.deallocate(pixel_values) + + patch_embedding_output = ttnn.to_layout(patch_embedding_output, layout=ttnn.ROW_MAJOR_LAYOUT) + patch_embedding_output = ttnn.reshape(patch_embedding_output, (batch_size, patch_count_all, patch_size_sq_trpl)) + + return patch_embedding_output + + +def vit_embeddings( + config, + pixel_values, + position_embeddings_interpolated, + *, + parameters, +): + parameters = parameters.vit.embeddings + + patch_embeddings = vit_patch_embeddings(config, pixel_values, parameters=parameters.patch_embeddings) + embedding_output = ttnn.concat((parameters.cls_token, patch_embeddings), dim=1) + + embedding_output = ttnn.add( + embedding_output, position_embeddings_interpolated, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b + ) + + # padding + # 1024 / 16 = 64 + # 64*64 + 32 = 4128 (from cls_token concat) + # 4352 = (4128 + 224) + # 4352 / 8 = 136 + + # embedding_output = ttnn.pad(embedding_output, ((0, 0), (0, 224), (0, 0)), 0) + + return embedding_output + + +def vit_layernorm_before( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_layernorm_after( + config, + hidden_states, + *, + parameters, +): + attention_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + epsilon=config.layer_norm_eps, + memory_config=ttnn.L1_MEMORY_CONFIG, + 
core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["layernorm_program_config"], + ) + + return attention_output + + +def vit_attention_experimental( + config, + hidden_states, + attention_mask, + parameters, +): + num_heads = config.num_attention_heads + *_, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query_key_value = ttnn.linear( + hidden_states, + parameters.attention.query_key_value.weight, + bias=parameters.attention.query_key_value.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["query_key_value_matmul_program_config"], + ) + ttnn.reallocate(hidden_states) + ( + query, + key, + value, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value, + memory_config=ttnn.L1_MEMORY_CONFIG, + num_heads=num_heads, + ) + ttnn.deallocate(query_key_value) + ttnn.reallocate(value) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["query_by_key_matmul_program_config"], + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + """ + attention_probs = ttnn.transformer.attention_softmax_( + attention_scores, + attention_mask=attention_mask, + head_size=head_size, + # program_config=program_configs["softmax_program_config"], + ) + + context_layer = ttnn.matmul( + attention_probs, + value, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["attention_probabilities_by_value_matmul_program_config"], + ) + ttnn.deallocate(attention_probs) + ttnn.deallocate(value) + + context_layer = ttnn.transformer.concatenate_heads( + context_layer, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + self_output = ttnn.linear( + context_layer, + parameters.output.dense.weight, + bias=parameters.output.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["self_output_matmul_program_config"], + ) + ttnn.deallocate(context_layer) + + return self_output + """ + return attention_scores + + +def vit_attention( + config, + hidden_states, + attention_mask, + parameters, +): + num_heads = config.num_attention_heads + *_, hidden_size = hidden_states.shape + head_size = hidden_size // num_heads + + query_key_value = ttnn.linear( + hidden_states, + parameters.attention.query_key_value.weight, + bias=parameters.attention.query_key_value.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["query_key_value_matmul_program_config"], + ) + ttnn.reallocate(hidden_states) + ( + query, + key, + value, + ) = ttnn.transformer.split_query_key_value_and_split_heads( + query_key_value, + memory_config=ttnn.L1_MEMORY_CONFIG, + num_heads=num_heads, + ) + ttnn.deallocate(query_key_value) + value = ttnn.reallocate(value) + + attention_scores = ttnn.matmul( + query, + key, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["query_by_key_matmul_program_config"], + ) + ttnn.deallocate(query) + ttnn.deallocate(key) + + attention_probs = ttnn.transformer.attention_softmax_( + attention_scores, + attention_mask=attention_mask, + head_size=head_size, + # program_config=program_configs["softmax_program_config"], + ) + 
+ context_layer = ttnn.matmul( + attention_probs, + value, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["attention_probabilities_by_value_matmul_program_config"], + ) + ttnn.deallocate(attention_probs) + ttnn.deallocate(value) + + context_layer = ttnn.transformer.concatenate_heads( + context_layer, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + self_output = ttnn.linear( + context_layer, + parameters.output.dense.weight, + bias=parameters.output.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["self_output_matmul_program_config"], + ) + ttnn.deallocate(context_layer) + + return self_output + + +def vit_intermediate( + hidden_states, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + # program_config=program_configs["ff1_matmul_program_config"], + core_grid=ttnn.CoreGrid(y=8, x=12), + activation="gelu", + ) + # ttnn.deallocate(hidden_states) + + return output + + +def vit_output( + config, + hidden_states, + residual, + *, + parameters, +): + output = ttnn.linear( + hidden_states, + parameters.dense.weight, + bias=parameters.dense.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + core_grid=ttnn.CoreGrid(y=8, x=12), + # program_config=program_configs["ff2_matmul_program_config"], + ) + ttnn.deallocate(hidden_states) + + output = ttnn.add(output, residual, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b) + + # ttnn.deallocate(residual) + + return output + + +def vit_feedforward( + config, + hidden_states, + attention_output, + *, + parameters, +): + intermediate = vit_intermediate(hidden_states, parameters=parameters.intermediate) + hidden_states = vit_output(config, intermediate, attention_output, parameters=parameters.output) + return hidden_states + + +def vit_layer( + config, + hidden_states, + attention_mask, + parameters, +): + layernorm_before_output = ttnn.layer_norm( + hidden_states, + weight=parameters.layernorm_before.weight, + bias=parameters.layernorm_before.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + # program_config=program_configs["layernorm_program_config"], + ) + + multi_head_attention_output = vit_attention( + config, + layernorm_before_output, + attention_mask=attention_mask, + parameters=parameters.attention, + ) + + multi_head_attention_output = ttnn.add( + multi_head_attention_output, hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat8_b + ) + + layernorm_after_output = ttnn.layer_norm( + multi_head_attention_output, + weight=parameters.layernorm_after.weight, + bias=parameters.layernorm_after.bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + # program_config=program_configs["layernorm_program_config"], + ) + + feedforward_output = vit_feedforward( + config, + layernorm_after_output, + multi_head_attention_output, + parameters=parameters, + ) + + return feedforward_output + + +def vit_encoder( + config, + embeddings, + head_masks, + parameters, +): + # encoder_input = ttnn.to_memory_config( + # embeddings, + # memory_config=ttnn.create_sharded_memory_config( + # embeddings.shape, + # core_grid=core_grid, + # strategy=ttnn.ShardStrategy.BLOCK, + # orientation=ttnn.ShardOrientation.ROW_MAJOR, + # ), + # dtype=ttnn.bfloat8_b, + # ) + # ttnn.deallocate(embeddings) + encoder_input = embeddings + + 
encoder_output = None + for index, encoder_parameters in enumerate(parameters.layer): + encoder_output = vit_layer( + config, + encoder_input, + head_masks[index], + encoder_parameters, + ) + encoder_input = encoder_output + + return encoder_output + + +def vit( + config, + pixel_values, + attention_mask, + position_embeddings_interpolated, + parameters, +): + embeddings_output = vit_embeddings(config, pixel_values, position_embeddings_interpolated, parameters=parameters) + + hidden_states = vit_encoder( + config, + embeddings_output, + attention_mask, + parameters=parameters.vit.encoder, + ) + + # Final LayerNorm + output = ttnn.layer_norm( + hidden_states, + weight=parameters.vit.layernorm.weight, + bias=parameters.vit.layernorm.bias, + ) + + # Classifier + classifier_output = output @ parameters.classifier.weight + classifier_output = classifier_output + parameters.classifier.bias + + return classifier_output + + +def preprocess_inputs( + input_ids, + token_type_ids, + attention_mask, + device, +): + batch_size, _ = input_ids.shape + + input_ids = ttnn.from_torch(input_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) + + token_type_ids = ttnn.from_torch( + token_type_ids, dtype=ttnn.uint32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) + attention_mask = torch.nn.functional.pad(attention_mask, (0, 0, 0, 0, 0, 0, 0, batch_size - 1)) + attention_mask = ttnn.from_torch( + attention_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + return input_ids, token_type_ids, attention_mask + + +def custom_preprocessor(torch_model, name): + parameters = {} + if isinstance(torch_model, transformers.models.vit.modeling_vit.ViTEmbeddings): + weight = torch_model.patch_embeddings.projection.weight + bias = torch_model.patch_embeddings.projection.bias + + three_times_hidden_size, c, _, _ = weight.shape + pad_value = 4 - c + preprocessed_weight = torch.nn.functional.pad(weight, (0, 0, 0, 0, 0, pad_value)) + preprocessed_weight = torch.permute(preprocessed_weight, (2, 3, 1, 0)) + preprocessed_weight = torch.reshape( + preprocessed_weight, (int(three_times_hidden_size * (4 / c)), three_times_hidden_size) + ) + + parameters = {"patch_embeddings": {}} + parameters["patch_embeddings"] = {"projection": {}} + parameters["patch_embeddings"]["projection"]["weight"] = ttnn.from_torch( + preprocessed_weight, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + parameters["patch_embeddings"]["projection"]["bias"] = ttnn.from_torch( + bias.unsqueeze(0), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + parameters["cls_token"] = ttnn.from_torch(torch_model.cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT) + parameters["position_embeddings"] = ttnn.from_torch( + torch_model.position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT + ) + + if hasattr(torch_model, "query") and hasattr(torch_model, "key") and hasattr(torch_model, "value"): + qkv_weight = torch.cat( + [ + torch_model.query.weight, + torch_model.key.weight, + torch_model.value.weight, + ], + dim=0, + ) + qkv_bias = torch.cat( + [torch_model.query.bias, torch_model.key.bias, torch_model.value.bias], + dim=0, + ) + + parameters = {"query_key_value": {}} + parameters["query_key_value"]["weight"] = preprocess_linear_weight(qkv_weight, dtype=ttnn.bfloat8_b) + parameters["query_key_value"]["bias"] = preprocess_linear_bias(qkv_bias, 
dtype=ttnn.bfloat8_b) + + elif isinstance(torch_model, torch.nn.Linear): + parameters["weight"] = preprocess_linear_weight(torch_model.weight, dtype=ttnn.bfloat8_b) + parameters["bias"] = preprocess_linear_bias(torch_model.bias, dtype=ttnn.bfloat8_b) + + return parameters diff --git a/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_functional_vit.py b/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_functional_vit.py new file mode 100644 index 00000000000..c519890a139 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_functional_vit.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_functional_vit +from models.experimental.vit.vit_helper_funcs import get_data_loader, get_batch + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, +) +from models.perf.perf_utils import prep_perf_report +import ast +from pathlib import Path + + +def get_expected_times(functional_vit): + return { + ttnn_functional_vit: (12, 17), + }[functional_vit] + + +def get_imagenet_label_dict(): + path = "models/sample_data/imagenet_class_labels.txt" + with open(path, "r") as file: + class_labels = ast.literal_eval(file.read()) + return class_labels + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit]) +def test_accuracy( + device, + use_program_cache, + model_name, + batch_size, + image_size, + sequence_size, + functional_vit, + model_location_generator, +): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + model = model.to(torch.bfloat16) + + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + 
cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + ################## + + iterations = 50 + imagenet_label_dict = get_imagenet_label_dict() + + data_loader = get_data_loader("ImageNet_data", batch_size, iterations) + correct = 0 + for iter in range(iterations): + predictions = [] + inputs, labels = get_batch(data_loader, image_processor) + + inputs = torch.permute(inputs, (0, 2, 3, 1)) + inputs = torch.nn.functional.pad(inputs, (0, 1, 0, 0, 0, 0, 0, 0)) + tt_inputs = ttnn.from_torch(inputs, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + tt_output = functional_vit.vit( + config, + tt_inputs, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + print(tt_output.shape) + + prediction = ttnn.to_torch(tt_output[:, 0]).argmax(dim=-1) + + for i in range(batch_size): + predictions.append(imagenet_label_dict[prediction[i].item()]) + logger.info( + f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- Predicted Label: {predictions[-1]}" + ) + if imagenet_label_dict[labels[i]] == predictions[-1]: + correct += 1 + del tt_output, tt_inputs, inputs, labels, predictions + + enable_persistent_kernel_cache() + + accuracy = correct / (batch_size * iterations) + logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}") diff --git a/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_interleaved_vit.py b/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_interleaved_vit.py new file mode 100644 index 00000000000..b8ddd66f54f --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_interleaved_vit.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_optimized_interleaved_vit +from models.experimental.vit.vit_helper_funcs import get_data_loader, get_batch + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch2tt_tensor, +) +from models.perf.perf_utils import prep_perf_report +import ast +from pathlib import Path + + +def get_expected_times(functional_vit): + return { + ttnn_optimized_interleaved_vit: (12, 0.08), + }[functional_vit] + + +def get_imagenet_label_dict(): + path = "models/sample_data/imagenet_class_labels.txt" + with open(path, "r") as file: + class_labels = ast.literal_eval(file.read()) + return class_labels + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_interleaved_vit]) +def test_accuracy( + device, + use_program_cache, + model_name, + batch_size, + image_size, + sequence_size, + functional_vit, + model_location_generator, +): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + model = model.to(torch.bfloat16) + + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_optimized_interleaved_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + batch_size = 8 + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + ################## + + iterations = 50 + imagenet_label_dict = get_imagenet_label_dict() + + data_loader = get_data_loader("ImageNet_data", 
batch_size, iterations) + correct = 0 + for iter in range(iterations): + predictions = [] + inputs, labels = get_batch(data_loader, image_processor) + + torch_pixel_values = torch.permute(inputs, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = ttnn.experimental.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False + ) + + tt_inputs = torch2tt_tensor( + torch_pixel_values, + device, + ttnn.experimental.tensor.Layout.ROW_MAJOR, + tt_memory_config=ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + shard_spec, + ), + tt_dtype=ttnn.experimental.tensor.DataType.BFLOAT16, + ) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + tt_output = functional_vit.vit( + config, + tt_inputs, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + print(tt_output.shape) + + prediction = ttnn.to_torch(tt_output[:, 0, :1000]).argmax(dim=-1) + + for i in range(batch_size): + predictions.append(imagenet_label_dict[prediction[i].item()]) + logger.info( + f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- Predicted Label: {predictions[-1]}" + ) + if imagenet_label_dict[labels[i]] == predictions[-1]: + correct += 1 + del tt_output, tt_inputs, inputs, labels, predictions + + enable_persistent_kernel_cache() + + accuracy = correct / (batch_size * iterations) + logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}") diff --git a/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_sharded_vit.py b/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_sharded_vit.py new file mode 100644 index 00000000000..a2cfea3d888 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_accuracy_ttnn_optim_sharded_vit.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_optimized_sharded_vit +from models.experimental.vit.vit_helper_funcs import get_data_loader, get_batch + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, +) +from models.perf.perf_utils import prep_perf_report +import ast +from pathlib import Path + + +def get_expected_times(functional_vit): + return { + ttnn_optimized_sharded_vit: (12, 0.08), + }[functional_vit] + + +def get_imagenet_label_dict(): + path = "models/sample_data/imagenet_class_labels.txt" + with open(path, "r") as file: + class_labels = ast.literal_eval(file.read()) + return class_labels + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_sharded_vit]) +def test_accuracy( + device, + use_program_cache, + model_name, + batch_size, + image_size, + sequence_size, + functional_vit, + model_location_generator, +): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + model = model.to(torch.bfloat16) + + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_optimized_sharded_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + 
layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + ################## + + iterations = 50 + imagenet_label_dict = get_imagenet_label_dict() + + data_loader = get_data_loader("ImageNet_data", batch_size, iterations) + correct = 0 + for iter in range(iterations): + predictions = [] + inputs, labels = get_batch(data_loader, image_processor) + + inputs = torch.permute(inputs, (0, 2, 3, 1)) + inputs = torch.nn.functional.pad(inputs, (0, 1, 0, 0, 0, 0, 0, 0)) + tt_inputs = ttnn.from_torch(inputs, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + tt_output = functional_vit.vit( + config, + tt_inputs, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + print(tt_output.shape) + + prediction = ttnn.to_torch(tt_output[:, 0]).argmax(dim=-1) + + for i in range(batch_size): + predictions.append(imagenet_label_dict[prediction[i].item()]) + logger.info( + f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- Predicted Label: {predictions[-1]}" + ) + if imagenet_label_dict[labels[i]] == predictions[-1]: + correct += 1 + del tt_output, tt_inputs, inputs, labels, predictions + + enable_persistent_kernel_cache() + + accuracy = correct / (batch_size * iterations) + logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_functional_vit.py b/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_functional_vit.py new file mode 100644 index 00000000000..e7cf4494a73 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_functional_vit.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_functional_vit + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_functional_vit: (12, 17), + }[functional_vit] + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) ## padded from 197 to 224 +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_functional_vit: + tt_model_name = f"ttnn_{model_name}" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + if torch_attention_mask is not None: + head_masks = ttnn.from_torch(torch_attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + else: + head_masks = None + head_masks = None + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + + inference_time, *_ = durations + logger.info(f"Inference time: {inference_time}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + dataset = 
load_dataset("huggingface/cats-image") + image = dataset["test"]["image"] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + # cls_token expand to batch_size + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + + if functional_vit == ttnn_functional_vit: + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + if functional_vit == ttnn_functional_vit: + tt_model_name = f"ttnn_{model_name}" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + + inference_time, *_ = durations + logger.info(f"Inference time: {inference_time}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_interleaved_vit.py b/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_interleaved_vit.py new file mode 100644 index 00000000000..8671a6a0bbe --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_interleaved_vit.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_optimized_interleaved_vit + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_functional_vit: (12, 17), + ttnn_optimized_interleaved_vit: (12, 0.08), + }[functional_vit] + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [196]) ## padded from 197 to 224 +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_interleaved_vit]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + # disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_optimized_interleaved_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + + inference_time, *_ = durations + logger.info(f"Inference time: {inference_time}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", 
[ttnn_optimized_interleaved_vit]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + # disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + # cls_token expand to batch_size + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + if functional_vit == ttnn_optimized_interleaved_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + + inference_time, *_ = durations + logger.info(f"Inference time: {inference_time}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_sharded_vit.py b/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_sharded_vit.py new file mode 100644 index 00000000000..26644e62c17 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_deviceOPs_ttnn_optim_sharded_vit.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
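The sharded tests that follow pre-fold the pixel values on host before handing them to the device: NCHW is permuted to NHWC, the 3 channels are padded to 4, and the width axis is folded by the patch size so each row carries 4 * patch_size values. A small shape-only sketch of that preprocessing, assuming the 224x224, batch-8, patch-16 configuration used throughout these tests:

import torch

batch_size, patch_size = 8, 16
pixel_values = torch.rand(batch_size, 3, 224, 224)             # NCHW, as returned by the image processor

pixel_values = torch.permute(pixel_values, (0, 2, 3, 1))        # -> NHWC (8, 224, 224, 3)
pixel_values = torch.nn.functional.pad(pixel_values, (0, 1))    # pad channels 3 -> 4
pixel_values = pixel_values.reshape(batch_size, 224, 224 // patch_size, 4 * patch_size)

assert pixel_values.shape == (8, 224, 14, 64)                   # N, H, W / patch_size, 4 * patch_size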
+ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn +import tt_lib +from models.experimental.functional_vit.tt import ttnn_optimized_sharded_vit +from models.utility_functions import torch_random, skip_for_wormhole_b0, torch2tt_tensor +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_functional_vit: (12, 17), + ttnn_optimized_sharded_vit: (12, 0.08), + }[functional_vit] + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_sharded_vit]) +def test_performance_vit_embeddings(device, model_name, batch_size, image_size, image_channels, functional_vit): + # tt_lib.device.EnableMemoryReports() + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.float32) + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + # cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + # position_embeddings = ttnn.from_torch( + # torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + # ) + torch_cls_token_padded = torch.nn.functional.pad(torch_cls_token, (0, 0, 0, 196, 0, 0)) + torch_cls_position_embeddings = torch.add(torch_cls_token_padded, torch_position_embeddings) + cls_position_embeddings = ttnn.from_torch( + torch_cls_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_sharded_vit.custom_preprocessor, + ) + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = tt_lib.tensor.CoreRangeSet( + { + 
tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = tt_lib.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + tt_lib.tensor.Layout.ROW_MAJOR, + tt_memory_config=tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec + ), + tt_dtype=tt_lib.tensor.DataType.BFLOAT16, + ) + # pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit_embeddings( + config, + pixel_values, + # cls_token, + cls_position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + + inference_time, *_ = durations + logger.info(f"Inference time: {inference_time}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [196]) ## padded from 197 to 224 +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_sharded_vit]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + # disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_optimized_sharded_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + + inference_time, *_ = durations + logger.info(f"Inference time: {inference_time}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() 
+@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_sharded_vit]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + # disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + # cls_token expand to batch_size + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + torch_position_embeddings = torch.nn.functional.pad(torch_position_embeddings, (0, 0, 0, 27, 0, 0)) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + # torch_cls_token_padded = torch.nn.functional.pad(torch_cls_token, (0, 0, 0, 196, 0, 0)) + # torch_cls_position_embeddings = torch.add(torch_cls_token_padded, torch_position_embeddings) + # cls_position_embeddings = ttnn.from_torch( + # torch_cls_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + # ) + + if functional_vit == ttnn_optimized_sharded_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + + durations = [] + for _ in range(1): + start = time.time() + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, 
img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = tt_lib.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + tt_lib.tensor.Layout.ROW_MAJOR, + tt_memory_config=tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec + ), + tt_dtype=tt_lib.tensor.DataType.BFLOAT16, + ) + # pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + + inference_time, *_ = durations + logger.info(f"Inference time: {inference_time}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_highres.py b/tests/ttnn/integration_tests/vit/test_performance_highres.py new file mode 100644 index 00000000000..14c98d227b6 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_highres.py @@ -0,0 +1,288 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_functional_vit_highres +from models.experimental.functional_vit.tt import ttnn_optimized_vit_highres + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_functional_vit_highres: (12, 17), + ttnn_optimized_vit_highres: (12, 0.08), + }[functional_vit] + + +def interpolate_pos_encoding( + position_embeddings: torch.Tensor, patch_size, num_patches, height: int, width: int +) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + # num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = position_embeddings.shape[-1] + h0 = height // patch_size + w0 = width // patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = torch.nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [448]) +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit_highres, ttnn_optimized_vit_highres]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + # disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_functional_vit_highres: + tt_model_name = f"ttnn_{model_name}" + elif functional_vit == ttnn_optimized_vit_highres: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + if functional_vit == ttnn_optimized_vit_highres: + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + else: + hidden_states = ttnn.from_torch( + torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + if torch_attention_mask is not 
None: + head_masks = ttnn.from_torch( + torch_attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + else: + head_masks = None + head_masks = None + + durations = [] + for _ in range(2): + start = time.time() + tt_output = functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("image_size", [960]) +@pytest.mark.parametrize("sequence_size", [448]) +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit_highres, ttnn_optimized_vit_highres]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0].resize((image_size, image_size)) + + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor( + image, return_tensors="pt", do_resize=False, do_center_crop=False + ).pixel_values.to(torch.bfloat16) + + # torch_pixel_values = torch.rand((1, 3, 960, 960)) + torch_attention_mask = ( + None # torch.zeros(1, sequence_size) if functional_vit == ttnn_optimized_functional_vit else None + ) + + if functional_vit == ttnn_functional_vit_highres: + tt_model_name = f"ttnn_{model_name}" + elif functional_vit == ttnn_optimized_vit_highres: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + # High resolution patch_parameters interpolation + model_state_dict = model.state_dict() + init_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"]) + patch_size = 16 + tot_patch_count = (image_size // patch_size) * (image_size // patch_size) + torch_position_embeddings = torch.nn.Parameter( + interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size, image_size) + ) + position_embeddings = ttnn.from_torch(torch_position_embeddings, layout=ttnn.TILE_LAYOUT, device=device) + torch_cls_token = 
model_state_dict["vit.embeddings.cls_token"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + torch_pixel_values = torch_pixel_values.to(torch.bfloat16) + pixel_values = ttnn.from_torch( + torch_pixel_values, + layout=ttnn.TILE_LAYOUT, + device=device, + dtype=ttnn.bfloat8_b, + # memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(2): + start = time.time() + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + position_embeddings, + cls_token, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_highres_deviceOPs.py b/tests/ttnn/integration_tests/vit/test_performance_highres_deviceOPs.py new file mode 100644 index 00000000000..02609beca7d --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_highres_deviceOPs.py @@ -0,0 +1,294 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_functional_vit_highres +from models.experimental.functional_vit.tt import ttnn_optimized_vit_highres + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_functional_vit_highres: (12, 17), + ttnn_optimized_vit_highres: (12, 0.08), + }[functional_vit] + + +def interpolate_pos_encoding( + position_embeddings: torch.Tensor, patch_size, num_patches, height: int, width: int +) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + # num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = position_embeddings.shape[-1] + h0 = height // patch_size + w0 = width // patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = torch.nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [448]) +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit_highres, ttnn_optimized_vit_highres]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + # disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_functional_vit_highres: + tt_model_name = f"ttnn_{model_name}" + elif functional_vit == ttnn_optimized_vit_highres: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + if functional_vit == ttnn_optimized_vit_highres: + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + else: + hidden_states = ttnn.from_torch( + torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + if torch_attention_mask is not 
None: + head_masks = ttnn.from_torch( + torch_attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + else: + head_masks = None + head_masks = None + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + # enable_persistent_kernel_cache() + + # inference_and_compile_time, inference_time, *_ = durations + inference_time = durations + + """ + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + """ + + # logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + # logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("image_size", [960]) +@pytest.mark.parametrize("sequence_size", [448]) +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit_highres, ttnn_optimized_vit_highres]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + # disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0].resize((image_size, image_size)) + + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor( + image, return_tensors="pt", do_resize=False, do_center_crop=False + ).pixel_values.to(torch.bfloat16) + + # torch_pixel_values = torch.rand((1, 3, 960, 960)) + torch_attention_mask = ( + None # torch.zeros(1, sequence_size) if functional_vit == ttnn_optimized_functional_vit else None + ) + + if functional_vit == ttnn_functional_vit_highres: + tt_model_name = f"ttnn_{model_name}" + elif functional_vit == ttnn_optimized_vit_highres: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + # High resolution patch_parameters interpolation + model_state_dict = model.state_dict() + init_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"]) + patch_size = 16 + tot_patch_count = (image_size // patch_size) * (image_size // patch_size) + torch_position_embeddings = torch.nn.Parameter( + interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size, image_size) + ) + position_embeddings = ttnn.from_torch(torch_position_embeddings, 
layout=ttnn.TILE_LAYOUT, device=device) + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + torch_pixel_values = torch_pixel_values.to(torch.bfloat16) + pixel_values = ttnn.from_torch( + torch_pixel_values, + layout=ttnn.TILE_LAYOUT, + device=device, + dtype=ttnn.bfloat8_b, + # memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(1): + start = time.time() + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + position_embeddings, + cls_token, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + # enable_persistent_kernel_cache() + + # inference_and_compile_time, inference_time, *_ = durations + inference_time, *_ = durations + + """ + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + """ + + # logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + # logger.info(f"Samples per second: {1 / inference_time * batch_size}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_ttnn_functional_vit.py b/tests/ttnn/integration_tests/vit/test_performance_ttnn_functional_vit.py new file mode 100644 index 00000000000..f182ed9f1a9 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_ttnn_functional_vit.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
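For the high-resolution tests above, the pretrained 14x14 grid of patch position embeddings is bicubic-resized to the new patch grid and the CLS embedding is re-attached. A quick usage sketch of the interpolate_pos_encoding helper defined in those files (assumed to be in scope here), for the 960x960, patch-16 case parametrized above; hidden_size 768 is ViT-base's hidden width:

import torch

image_size, patch_size, hidden_size = 960, 16, 768
num_patches = (image_size // patch_size) ** 2                   # 60 * 60 = 3600 patches

# Pretrained table: CLS position + 14*14 patch positions.
position_embeddings = torch.rand(1, 14 * 14 + 1, hidden_size)

# interpolate_pos_encoding resizes the 14x14 patch grid to 60x60 with bicubic
# interpolation and re-attaches the CLS embedding in front:
resized = interpolate_pos_encoding(position_embeddings, patch_size, num_patches, image_size, image_size)
assert resized.shape == (1, num_patches + 1, hidden_size)       # (1, 3601, 768)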
+ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_functional_vit + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_functional_vit: (12, 17), + }[functional_vit] + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) ## padded from 197 to 224 +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_functional_vit: + tt_model_name = f"ttnn_{model_name}" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + if torch_attention_mask is not None: + head_masks = ttnn.from_torch(torch_attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + else: + head_masks = None + head_masks = None + + durations = [] + for _ in range(2): + start = time.time() + tt_output = functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine 
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_functional_vit]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + # cls_token expand to batch_size + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + if functional_vit == ttnn_functional_vit: + tt_model_name = f"ttnn_{model_name}" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(2): + start = time.time() + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + 
inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_interleaved_vit.py b/tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_interleaved_vit.py new file mode 100644 index 00000000000..568a737da98 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_interleaved_vit.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn + +from models.experimental.functional_vit.tt import ttnn_optimized_interleaved_vit + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_optimized_interleaved_vit: (12, 0.08), + }[functional_vit] + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [196]) ## padded from 197 to 224 +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_interleaved_vit]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_optimized_interleaved_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(2): + start = time.time() + tt_output = 
functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_interleaved_vit]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + # cls_token expand to batch_size + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + if functional_vit == ttnn_optimized_interleaved_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + # pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + # pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + # pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + 
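    # Shard-shape arithmetic for the timed loop below, assuming this test's batch-8,
    # 224x224, patch-16 parametrization: after the NHWC permute, the channel pad to 4
    # and the width fold, the pixel tensor is (N, H, W, C) = (8, 224, 14, 64).
    # Height-sharding it across n_cores = 8 cores gives a shard shape of
    #     [N * H * W // n_cores, C] = [8 * 224 * 14 // 8, 64] = [3136, 64],
    # i.e. one image's worth of folded rows (224 * 14 = 3136) per core.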
torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + durations = [] + for _ in range(2): + start = time.time() + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + pixel_values = pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = pixel_values.shape + shard_grid = tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = tt_lib.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + pixel_values, + device, + tt_lib.tensor.Layout.ROW_MAJOR, + tt_memory_config=tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec + ), + tt_dtype=tt_lib.tensor.DataType.BFLOAT16, + ) + # pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") diff --git a/tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_sharded_vit.py b/tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_sharded_vit.py new file mode 100644 index 00000000000..5917a823e6b --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_performance_ttnn_optim_sharded_vit.py @@ -0,0 +1,264 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from loguru import logger +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn +import tt_lib +from models.experimental.functional_vit.tt import ttnn_optimized_sharded_vit +from models.utility_functions import torch_random, skip_for_wormhole_b0, torch2tt_tensor + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.utility_functions import ( + skip_for_wormhole_b0, + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, + torch_random, +) +from models.perf.perf_utils import prep_perf_report + + +def get_expected_times(functional_vit): + return { + ttnn_optimized_sharded_vit: (12, 0.08), + }[functional_vit] + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [196]) ## padded from 197 to 224 +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_sharded_vit]) +def test_performance_vit_encoder(device, use_program_cache, model_name, batch_size, sequence_size, functional_vit): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + if functional_vit == ttnn_optimized_sharded_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + + durations = [] + for _ in range(2): + start = time.time() + tt_output = functional_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + 
expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.models_performance_bare_metal +@pytest.mark.models_performance_virtual_machine +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("functional_vit", [ttnn_optimized_sharded_vit]) +def test_performance_vit_e2e( + device, use_program_cache, model_name, batch_size, image_size, sequence_size, functional_vit +): + disable_persistent_kernel_cache() + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + # cls_token expand to batch_size + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + # torch_cls_token_padded = torch.nn.functional.pad(torch_cls_token, (0, 0, 0, 196, 0, 0)) + # torch_cls_position_embeddings = torch.add(torch_cls_token_padded, torch_position_embeddings) + # cls_position_embeddings = ttnn.from_torch( + # torch_cls_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + # ) + + if functional_vit == ttnn_optimized_sharded_vit: + tt_model_name = f"ttnn_{model_name}_optimized" + else: + raise ValueError(f"Unknown functional_vit: {functional_vit}") + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + custom_preprocessor=functional_vit.custom_preprocessor, + device=device, + ) + + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + config = 
ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + + durations = [] + import tracy + + # tracyProfiler = tracy.Profiler() + # tracyProfiler.enable() + for _ in range(2): + start = time.time() + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = pixel_values.shape # permuted input NHWC + patch_size = 16 + pixel_values = pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = pixel_values.shape + shard_grid = tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = tt_lib.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], tt_lib.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + pixel_values, + device, + tt_lib.tensor.Layout.ROW_MAJOR, + tt_memory_config=tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec + ), + tt_dtype=tt_lib.tensor.DataType.BFLOAT16, + ) + # pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + tt_output = functional_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + tt_output = ttnn.from_device(tt_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + # tracyProfiler.disable() + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times(functional_vit) + prep_perf_report( + model_name=tt_model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") diff --git a/tests/ttnn/integration_tests/vit/test_torch_functional_vit.py b/tests/ttnn/integration_tests/vit/test_torch_functional_vit.py new file mode 100644 index 00000000000..e59777136c3 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_torch_functional_vit.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch +import transformers + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.experimental.functional_vit.reference import torch_functional_vit +from models.utility_functions import torch_random, skip_for_wormhole_b0 + +from tests.ttnn.utils_for_testing import assert_with_pcc + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_patch_embeddings(model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTPatchEmbeddings(config).eval() + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.float32) + torch_output, *_ = model(torch_pixel_values) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_vit.custom_preprocessor, + ) + + output = torch_functional_vit.vit_patch_embeddings( + torch_pixel_values, + parameters=parameters, + ) + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_embeddings(model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTEmbeddings(config).eval() + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.float32) + torch_output, *_ = model(torch_pixel_values) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_vit.custom_preprocessor, + ) + + # TODO: integrate within paramters + model_state_dict = model.state_dict() + torch_cls_token = torch.nn.Parameter(model_state_dict["cls_token"]) + torch_position_embeddings = torch.nn.Parameter(model_state_dict["position_embeddings"]) + + output = torch_functional_vit.vit_embeddings( + config, + torch_pixel_values, + torch_position_embeddings, + torch_cls_token, + parameters=parameters, + ) + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [196]) +def test_vit_attention(model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTAttention(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_attention_mask = 
torch.ones(1, sequence_size) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + print(parameters) + + output = torch_functional_vit.vit_attention( + config, + torch_hidden_states, + attention_mask=torch_attention_mask, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [196]) +def test_vit_intermediate(model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTIntermediate(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_output = model(torch_hidden_states) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_intermediate( + torch_hidden_states, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [196]) +def test_vit_output(model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTOutput(config).eval() + + torch_intermediate = torch_random( + (batch_size, sequence_size, config.intermediate_size), -0.1, 0.1, dtype=torch.float32 + ) + torch_residual = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_output = model(torch_intermediate, torch_residual) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_output( + config, + torch_intermediate, + torch_residual, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [196]) +def test_vit_layer(model_name, batch_size, sequence_size): + torch.manual_seed(322) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTLayer(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_attention_mask = torch.ones(1, sequence_size) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_layer( + config, + torch_hidden_states, + attention_mask=torch_attention_mask, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: 
Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [198]) +def test_vit_encoder(model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTEncoder(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_output = model(torch_hidden_states).last_hidden_state + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_encoder( + config, + torch_hidden_states, + attention_mask=None, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit(model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + model = model.to(torch.bfloat16) + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.bfloat16) + torch_output, *_ = model(torch_pixel_values).logits + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_vit.custom_preprocessor, + ) + + # TODO: integrate within paramters + model_state_dict = model.state_dict() + torch_cls_token = torch.nn.Parameter(model_state_dict["vit.embeddings.cls_token"]) + torch_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"]) + + output = torch_functional_vit.vit( + config, + torch_pixel_values, + torch_position_embeddings, + torch_cls_token, + attention_mask=None, + parameters=parameters, + ) + + print(torch_output.shape, output.shape) + assert_with_pcc(torch_output, output[0], 0.9999) diff --git a/tests/ttnn/integration_tests/vit/test_torch_functional_vit_highres.py b/tests/ttnn/integration_tests/vit/test_torch_functional_vit_highres.py new file mode 100644 index 00000000000..186102c53e3 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_torch_functional_vit_highres.py @@ -0,0 +1,347 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch +import math +import transformers + +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.experimental.functional_vit.reference import torch_functional_vit +from models.utility_functions import torch_random, skip_for_wormhole_b0 + +from tests.ttnn.utils_for_testing import assert_with_pcc + +# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/vit/modeling_vit.py + + +def interpolate_pos_encoding( + position_embeddings: torch.Tensor, patch_size, num_patches, height: int, width: int +) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + # num_patches = embeddings.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = position_embeddings.shape[-1] + h0 = height // patch_size + w0 = width // patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = torch.nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("image_size", [1280]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_patch_embeddings(model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTPatchEmbeddings(config).eval() + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.float32) + torch_output, *_ = model(torch_pixel_values, interpolate_pos_encoding=True) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_vit.custom_preprocessor, + ) + + output = torch_functional_vit.vit_patch_embeddings( + torch_pixel_values, + parameters=parameters, + ) + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("image_size", [1280]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_embeddings(model_name, batch_size, image_size, image_channels): 
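+    # Runs HF ViTEmbeddings with interpolate_pos_encoding=True on 1280x1280 inputs and compares it
+    # against the functional reference fed with position embeddings interpolated to that resolution.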
+    torch.manual_seed(0)
+
+    config = transformers.ViTConfig.from_pretrained(model_name)
+    model = transformers.models.vit.modeling_vit.ViTEmbeddings(config).eval()
+
+    torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.float32)
+    torch_output, *_ = model(torch_pixel_values, interpolate_pos_encoding=True)
+
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: model,
+        convert_to_ttnn=lambda *_: False,
+        custom_preprocessor=torch_functional_vit.custom_preprocessor,
+    )
+
+    # TODO: integrate within parameters
+    model_state_dict = model.state_dict()
+    torch_cls_token = torch.nn.Parameter(model_state_dict["cls_token"])
+    init_position_embeddings = torch.nn.Parameter(model_state_dict["position_embeddings"])
+    patch_size = 16
+    tot_patch_count = (image_size // patch_size) * (image_size // patch_size)
+    torch_position_embeddings = torch.nn.Parameter(
+        interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size, image_size)
+    )
+
+    output = torch_functional_vit.vit_embeddings(
+        config,
+        torch_pixel_values,
+        torch_position_embeddings,
+        torch_cls_token,
+        parameters=parameters,
+    )
+    assert_with_pcc(torch_output, output[0], 0.9999)
+
+
+@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review")
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("sequence_size", [198])
+def test_vit_layernorm_before(model_name, batch_size, sequence_size):
+    torch.manual_seed(0)
+
+    config = transformers.ViTConfig.from_pretrained(model_name)
+    model = transformers.models.vit.modeling_vit.ViTLayer(config).eval()
+
+    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
+    torch_output = model.layernorm_before(torch_hidden_states)
+
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: model,
+        convert_to_ttnn=lambda *_: False,
+    )
+
+    output = torch_functional_vit.vit_layernorm_before(
+        config,
+        torch_hidden_states,
+        parameters=parameters,
+    )
+
+    assert_with_pcc(torch_output, output, 0.9999)
+
+
+@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review")
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("sequence_size", [196])
+def test_vit_attention(model_name, batch_size, sequence_size):
+    torch.manual_seed(0)
+
+    config = transformers.ViTConfig.from_pretrained(model_name)
+    model = transformers.models.vit.modeling_vit.ViTAttention(config).eval()
+
+    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32)
+    torch_attention_mask = torch.ones(1, sequence_size)
+    torch_output, *_ = model(torch_hidden_states, torch_attention_mask)
+
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: model,
+        convert_to_ttnn=lambda *_: False,
+    )
+    print(parameters)
+
+    output = torch_functional_vit.vit_attention(
+        config,
+        torch_hidden_states,
+        attention_mask=torch_attention_mask,
+        parameters=parameters,
+    )
+
+    assert_with_pcc(torch_output, output, 0.9999)
+
+
+@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review")
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [196]) +def test_vit_intermediate(model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTIntermediate(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_output = model(torch_hidden_states) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_intermediate( + torch_hidden_states, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [196]) +def test_vit_output(model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTOutput(config).eval() + + torch_intermediate = torch_random( + (batch_size, sequence_size, config.intermediate_size), -0.1, 0.1, dtype=torch.float32 + ) + torch_residual = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_output = model(torch_intermediate, torch_residual) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_output( + config, + torch_intermediate, + torch_residual, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [196]) +def test_vit_layer(model_name, batch_size, sequence_size): + torch.manual_seed(322) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTLayer(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.float32) + torch_attention_mask = torch.ones(1, sequence_size) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_layer( + config, + torch_hidden_states, + attention_mask=torch_attention_mask, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("sequence_size", [198]) +def test_vit_encoder(model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTEncoder(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), 
-0.1, 0.1, dtype=torch.float32) + torch_output = model(torch_hidden_states).last_hidden_state + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + + output = torch_functional_vit.vit_encoder( + config, + torch_hidden_states, + attention_mask=None, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("image_size", [1280]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit(model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + model = model.to(torch.bfloat16) + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.bfloat16) + # torch_output, *_ = model(torch_pixel_values).last_hidden_state + torch_output, *_ = model(torch_pixel_values, interpolate_pos_encoding=True).logits + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + custom_preprocessor=torch_functional_vit.custom_preprocessor, + ) + + # TODO: integrate within paramters + model_state_dict = model.state_dict() + torch_cls_token = torch.nn.Parameter(model_state_dict["vit.embeddings.cls_token"]) + init_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"]) + patch_size = 16 + tot_patch_count = (image_size // patch_size) * (image_size // patch_size) + torch_position_embeddings = torch.nn.Parameter( + interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size, image_size) + ) + + output = torch_functional_vit.vit( + config, + torch_pixel_values, + torch_position_embeddings, + torch_cls_token, + attention_mask=None, + parameters=parameters, + ) + + assert_with_pcc(torch_output, output[0], 0.9999) diff --git a/tests/ttnn/integration_tests/vit/test_ttnn_functional_vit.py b/tests/ttnn/integration_tests/vit/test_ttnn_functional_vit.py new file mode 100644 index 00000000000..0ae4f77a462 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_ttnn_functional_vit.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.experimental.functional_vit.tt import ttnn_functional_vit +from models.utility_functions import torch_random, skip_for_wormhole_b0 + +from tests.ttnn.utils_for_testing import assert_with_pcc + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_patch_embeddings(device, model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + model = model.to(torch.bfloat16) + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.bfloat16) + torch_output, *_ = model(torch_pixel_values) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + output = ttnn_functional_vit.vit_patch_embeddings(config, pixel_values, parameters=parameters, unittest_check=True) + output = ttnn.to_torch(output) + print(output.shape) + + torch_output, *_ = model.vit.embeddings.patch_embeddings(torch_pixel_values) + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_embeddings(device, model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + model = model.to(torch.bfloat16) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, 
-1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + output = ttnn_functional_vit.vit_embeddings( + config, + pixel_values, + cls_token, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + print(output.shape) + + torch_output, *_ = model.vit.embeddings(torch_pixel_values) + + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) +def test_vit_attention(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTAttention(config).eval() + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_attention_mask = torch.ones(1, sequence_size, dtype=torch.bfloat16) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device) + attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_vit.vit_attention( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16]) +def test_vit_intermediate(device, model_name, batch_size, sequence_size, torch_dtype): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTIntermediate(config).eval() + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_output = model(torch_hidden_states) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device) + + output = 
ttnn_functional_vit.vit_intermediate( + hidden_states, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) +def test_vit_output(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTOutput(config).eval() + model = model.to(torch.bfloat16) + + torch_intermediate = torch_random( + (batch_size, sequence_size, config.intermediate_size), -0.1, 0.1, dtype=torch.bfloat16 + ) + torch_residual = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_output = model(torch_intermediate, torch_residual) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + intermediate = ttnn.from_torch(torch_intermediate, layout=ttnn.TILE_LAYOUT, device=device) + residual = ttnn.from_torch(torch_residual, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_vit.vit_output( + config, + intermediate, + residual, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) # 9994 + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) +def test_vit_layer(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").vit.encoder.layer[0] + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.bfloat16) + torch_attention_mask = torch.ones(1, sequence_size, dtype=torch.bfloat16) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device) + attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_vit.vit_layer( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) # 0.9957 + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) +def test_vit_encoder(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = 
transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").vit.encoder + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_attention_mask = None + torch_output = model(torch_hidden_states, torch_attention_mask).last_hidden_state + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + if torch_attention_mask is not None: + attention_mask = ttnn.from_torch( + torch_attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + else: + attention_mask = None + + output = ttnn_functional_vit.vit_encoder( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) # 0.9294 + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit(device, model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0:batch_size] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(torch.bfloat16) + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + + torch_output, *_ = model(torch_pixel_values).logits + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit.custom_preprocessor, + ) + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + output = ttnn_functional_vit.vit( + config, + pixel_values, + None, + cls_token, + position_embeddings, + parameters=parameters, + ) + output = 
ttnn.to_torch(output) + + assert_with_pcc(torch_output, output[0][0], 0.9999) # 0.9806 diff --git a/tests/ttnn/integration_tests/vit/test_ttnn_functional_vit_highres.py b/tests/ttnn/integration_tests/vit/test_ttnn_functional_vit_highres.py new file mode 100644 index 00000000000..a0a67a4f415 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_ttnn_functional_vit_highres.py @@ -0,0 +1,399 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.experimental.functional_vit.tt import ttnn_functional_vit_highres +from models.utility_functions import torch_random, skip_for_wormhole_b0 + +from tests.ttnn.utils_for_testing import assert_with_pcc + + +def interpolate_pos_encoding( + position_embeddings: torch.Tensor, patch_size, num_patches, height: int, width: int +) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = position_embeddings.shape[-1] + h0 = height // patch_size + w0 = width // patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = torch.nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("image_size_h", [1024]) +@pytest.mark.parametrize("image_size_w", [1024]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_patch_embeddings(device, model_name, image_size_h, image_size_w, image_channels): + torch.manual_seed(0) + + # strictly batch=1 for large resolution + batch_size = 1 + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + model = model.to(torch.bfloat16) + + torch_pixel_values = torch_random( + (batch_size, image_channels, image_size_h, image_size_w), -1, 1, dtype=torch.bfloat16 + ) + torch_output, *_ = model.vit.embeddings.patch_embeddings(torch_pixel_values, interpolate_pos_encoding=True) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + device=device, + 
custom_preprocessor=ttnn_functional_vit_highres.custom_preprocessor, + ) + + # pixel_values = ttnn.from_torch(torch_pixel_values, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + print(pixel_values.shape) + # pixel_values = pixel_values.reshape(batch_size, image_size, image_size // patch_size, 4 * patch_size) # run it on device + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + output = ttnn_functional_vit_highres.vit_patch_embeddings( + config, pixel_values, parameters=parameters, unittest_check=True + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, torch.squeeze(output, 0), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("image_size_h", [1024]) +@pytest.mark.parametrize("image_size_w", [1024]) +def test_vit_embeddings(device, model_name, image_size_h, image_size_w): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0].resize((image_size_h, image_size_w)) + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt", do_resize=False, do_center_crop=False).pixel_values + torch_output, *_ = model.vit.embeddings(torch_pixel_values, interpolate_pos_encoding=True) + + # High resolution patch_parameters interpolation + model_state_dict = model.state_dict() + init_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"]) + patch_size = 16 + tot_patch_count = (image_size_h // patch_size) * (image_size_w // patch_size) + torch_position_embeddings = torch.nn.Parameter( + interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size_h, image_size_w) + ) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit_highres.custom_preprocessor, + ) + + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + print(pixel_values.shape) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + output = ttnn_functional_vit_highres.vit_embeddings( + config, + pixel_values, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + print(output.shape) + assert_with_pcc(torch_output, torch.squeeze(output, 0), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [224]) +def test_vit_attention(device, model_name, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = 
transformers.models.vit.modeling_vit.ViTAttention(config).eval() + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((8, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_attention_mask = torch.ones(1, sequence_size, dtype=torch.bfloat16) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device) + attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_vit_highres.vit_attention( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [4096]) +def test_vit_intermediate(device, model_name, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTIntermediate(config).eval() + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((1, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_output = model(torch_hidden_states) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_vit_highres.vit_intermediate( + hidden_states, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [4096]) +def test_vit_output(device, model_name, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTOutput(config).eval() + model = model.to(torch.bfloat16) + + torch_intermediate = torch_random((1, sequence_size, config.intermediate_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_residual = torch_random((1, sequence_size, config.hidden_size), -0.1, 0.1, dtype=torch.bfloat16) + torch_output = model(torch_intermediate, torch_residual) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + ) + + intermediate = ttnn.from_torch(torch_intermediate, layout=ttnn.TILE_LAYOUT, device=device) + residual = ttnn.from_torch(torch_residual, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_vit_highres.vit_output( + config, + intermediate, + residual, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [4096]) +def test_vit_layer(device, model_name, sequence_size): + torch.manual_seed(0) + + config = 
transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").vit.encoder.layer[0] + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((1, sequence_size, config.hidden_size), -1, 1, dtype=torch.bfloat16) + torch_attention_mask = torch.ones(1, sequence_size, dtype=torch.bfloat16) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + attention_mask = ttnn.from_torch(torch_attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_functional_vit_highres.vit_layer( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [4096]) +def test_vit_encoder(device, model_name, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").vit.encoder + model = model.to(torch.bfloat16) + + torch_hidden_states = torch_random((1, sequence_size, config.hidden_size), -1, 1, dtype=torch.bfloat16) + torch_attention_mask = None + torch_output = model(torch_hidden_states, torch_attention_mask).last_hidden_state + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + if torch_attention_mask is not None: + attention_mask = ttnn.from_torch( + torch_attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + else: + attention_mask = None + + output = ttnn_functional_vit_highres.vit_encoder( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("image_size_h", [1024]) +@pytest.mark.parametrize("image_size_w", [1024]) +@pytest.mark.parametrize("sequence_size", [4128]) +def test_vit(device, model_name, image_size_h, image_size_w, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + model = model.to(torch.bfloat16) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0].resize((image_size_h, image_size_w)) + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor( + image, return_tensors="pt", do_resize=False, do_center_crop=False + ).pixel_values.to(torch.bfloat16) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + + torch_output, *_ 
= model(torch_pixel_values, interpolate_pos_encoding=True).logits + + # High resolution patch_parameters interpolation + model_state_dict = model.state_dict() + init_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"]) + patch_size = 16 + tot_patch_count = (image_size_h // patch_size) * (image_size_w // patch_size) + torch_position_embeddings = torch.nn.Parameter( + interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size_h, image_size_w) + ) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_functional_vit_highres.custom_preprocessor, + ) + + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + print(pixel_values.shape) + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(1, -1, -1, -1), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + output = ttnn_functional_vit_highres.vit( + config, + pixel_values, + None, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + print(torch_output.shape) + print(output.shape) + + assert_with_pcc(torch_output, output[0][0], 0.9990) diff --git a/tests/ttnn/integration_tests/vit/test_ttnn_optimized_interleaved_vit.py b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_interleaved_vit.py new file mode 100644 index 00000000000..b10991c2ee4 --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_interleaved_vit.py @@ -0,0 +1,496 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.experimental.functional_vit.tt import ttnn_optimized_interleaved_vit +from models.utility_functions import torch_random, skip_for_wormhole_b0, torch2tt_tensor + +from tests.ttnn.utils_for_testing import assert_with_pcc +from models.experimental.functional_vit.reference import torch_functional_vit + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_patch_embeddings(device, model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.float32) + torch_output, *_ = model(torch_pixel_values) + torch_output, *_ = model.vit.embeddings.patch_embeddings(torch_pixel_values) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_interleaved_vit.custom_preprocessor, + ) + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = ttnn.experimental.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + ttnn.experimental.tensor.Layout.ROW_MAJOR, + tt_memory_config=ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + shard_spec, + ), + tt_dtype=ttnn.experimental.tensor.DataType.BFLOAT16, + ) + + output = ttnn_optimized_interleaved_vit.vit_patch_embeddings( + config, pixel_values, parameters=parameters, unittest_check=True + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_embeddings(device, model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = 
transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + torch_output, *_ = model.vit.embeddings(torch_pixel_values) + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_interleaved_vit.custom_preprocessor, + ) + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = ttnn.experimental.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + ttnn.experimental.tensor.Layout.ROW_MAJOR, + tt_memory_config=ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + shard_spec, + ), + tt_dtype=ttnn.experimental.tensor.DataType.BFLOAT16, + ) + + output = ttnn_optimized_interleaved_vit.vit_embeddings( + config, + pixel_values, + cls_token, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_attention(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTAttention(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, 
dtype=torch.float32) + torch_attention_mask = torch.ones(batch_size, 1, 1, sequence_size, dtype=torch.float32) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_interleaved_vit.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + attention_mask = ttnn.from_torch( + torch_attention_mask, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + output = ttnn_optimized_interleaved_vit.vit_attention( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_intermediate(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTIntermediate(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_output = model(torch_hidden_states) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_optimized_interleaved_vit.vit_intermediate( + hidden_states, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_output(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTOutput(config).eval() + + torch_intermediate = torch_random((batch_size, sequence_size, config.intermediate_size), -1, 1, dtype=torch.float32) + torch_residual = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_output = model(torch_intermediate, torch_residual) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + ) + + intermediate = ttnn.from_torch(torch_intermediate, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + residual = ttnn.from_torch(torch_residual, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_optimized_interleaved_vit.vit_output( + config, + intermediate, + residual, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() 
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_layer(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").vit.encoder.layer[0] + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(batch_size, 1, 1, sequence_size, dtype=torch.float32) + + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_interleaved_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + attention_mask = ttnn.from_torch( + torch_attention_mask, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + output = ttnn_optimized_interleaved_vit.vit_layer( + config, + hidden_states, + attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + # assert_with_pcc(torch_output, output, 0.9999) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + convert_to_ttnn=lambda *_: False, + ) + torch_functional_output = torch_functional_vit.vit_layer( + config, + torch_hidden_states, + attention_mask=torch_attention_mask, + parameters=parameters, + ) + + assert_with_pcc(torch_functional_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) ## padded from 197 to 224 +def test_vit_encoder(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + torch_output = model(torch_hidden_states, torch_attention_mask).last_hidden_state + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_interleaved_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + output = ttnn_optimized_interleaved_vit.vit_encoder( + config, + 
hidden_states, + head_masks, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +@pytest.mark.parametrize("sequence_size", [224]) +def test_vit(device, model_name, batch_size, image_size, image_channels, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + torch_output, *_ = model(torch_pixel_values).logits + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_interleaved_vit.custom_preprocessor, + ) + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = ttnn.experimental.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + ttnn.experimental.tensor.Layout.ROW_MAJOR, + tt_memory_config=ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + shard_spec, + ), + tt_dtype=ttnn.experimental.tensor.DataType.BFLOAT16, + ) + # pixel_values = 
ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + output = ttnn_optimized_interleaved_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output[0, 0, :1000], 0.9999) diff --git a/tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py new file mode 100644 index 00000000000..35c2a9648eb --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_sharded_vit.py @@ -0,0 +1,542 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.experimental.functional_vit.tt import ttnn_optimized_sharded_vit +from models.experimental.functional_vit.reference import torch_functional_vit +from models.utility_functions import torch_random, skip_for_wormhole_b0, torch2tt_tensor + +from tests.ttnn.utils_for_testing import assert_with_pcc + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_patch_embeddings(device, model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + torch_pixel_values = torch_random((batch_size, image_channels, image_size, image_size), -1, 1, dtype=torch.float32) + torch_output, *_ = model(torch_pixel_values) + torch_output, *_ = model.vit.embeddings.patch_embeddings(torch_pixel_values) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_sharded_vit.custom_preprocessor, + ) + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = ttnn.experimental.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + 
ttnn.experimental.tensor.Layout.ROW_MAJOR, + tt_memory_config=ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + shard_spec, + ), + tt_dtype=ttnn.experimental.tensor.DataType.BFLOAT16, + ) + # pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + + output = ttnn_optimized_sharded_vit.vit_patch_embeddings( + config, pixel_values, parameters=parameters, unittest_check=True + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_embeddings(device, model_name, batch_size, image_size, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + torch_output, *_ = model.vit.embeddings(torch_pixel_values) + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + # torch_cls_token_padded = torch.nn.functional.pad(torch_cls_token, (0, 0, 0, 196, 0, 0)) + # torch_cls_position_embeddings = torch.add(torch_cls_token_padded, torch_position_embeddings) + # cls_position_embeddings = ttnn.from_torch( + # torch_cls_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + # ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_sharded_vit.custom_preprocessor, + ) + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = 
ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = ttnn.experimental.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + ttnn.experimental.tensor.Layout.ROW_MAJOR, + tt_memory_config=ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + shard_spec, + ), + tt_dtype=ttnn.experimental.tensor.DataType.BFLOAT16, + ) + + output = ttnn_optimized_sharded_vit.vit_embeddings( + config, + pixel_values, + cls_token, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + print(output.shape) + assert_with_pcc(torch_output, output[0][:197:], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_attention(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + model = transformers.models.vit.modeling_vit.ViTAttention(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(batch_size, 1, 1, sequence_size, dtype=torch.float32) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_sharded_vit.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + attention_mask = ttnn.from_torch( + torch_attention_mask, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + encoder_input = ttnn.to_memory_config( + hidden_states, + memory_config=ttnn.create_sharded_memory_config( + hidden_states.shape, + core_grid=config.core_grid, + strategy=ttnn.ShardStrategy.BLOCK, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + # orientation=ttnn.ShardOrientation.COLUMN_MAJOR, + ), + dtype=ttnn.bfloat8_b, + ) + ttnn.deallocate(hidden_states) + + output = ttnn_optimized_sharded_vit.vit_attention( + config, + encoder_input, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_intermediate(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + model = 
transformers.models.vit.modeling_vit.ViTIntermediate(config).eval() + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_output = model(torch_hidden_states) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_optimized_sharded_vit.vit_intermediate( + config, + hidden_states, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_output(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + model = transformers.models.vit.modeling_vit.ViTOutput(config).eval() + + torch_intermediate = torch_random((batch_size, sequence_size, config.intermediate_size), -1, 1, dtype=torch.float32) + torch_residual = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_output = model(torch_intermediate, torch_residual) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + ) + + intermediate = ttnn.from_torch(torch_intermediate, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + residual = ttnn.from_torch(torch_residual, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + residual_sh = ttnn.to_memory_config( + residual, + memory_config=ttnn.create_sharded_memory_config( + residual.shape, + core_grid=config.core_grid, + strategy=ttnn.ShardStrategy.BLOCK, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + # orientation=ttnn.ShardOrientation.COLUMN_MAJOR, + ), + dtype=ttnn.bfloat8_b, + ) + ttnn.deallocate(residual) + + output = ttnn_optimized_sharded_vit.vit_output( + config, + intermediate, + residual_sh, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) # padded from 197 to 224 +def test_vit_layer(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").vit.encoder.layer[0] + + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(batch_size, 1, 1, sequence_size, dtype=torch.float32) + + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_sharded_vit.custom_preprocessor, + 
device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + attention_mask = ttnn.from_torch( + torch_attention_mask, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + encoder_input = ttnn.to_memory_config( + hidden_states, + memory_config=ttnn.create_sharded_memory_config( + hidden_states.shape, + core_grid=config.core_grid, + strategy=ttnn.ShardStrategy.BLOCK, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + # orientation=ttnn.ShardOrientation.COLUMN_MAJOR, + ), + dtype=ttnn.bfloat8_b, + ) + ttnn.deallocate(hidden_states) + + output = ttnn_optimized_sharded_vit.vit_layer( + config, + encoder_input, + attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("sequence_size", [224]) ## padded from 197 to 224 +def test_vit_encoder(device, model_name, batch_size, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained( + "google/vit-base-patch16-224", config=config + ).vit.encoder + + config = ttnn_optimized_sharded_vit.update_model_config(config, batch_size) + torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + torch_output = model(torch_hidden_states, torch_attention_mask).last_hidden_state + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + custom_preprocessor=ttnn_optimized_sharded_vit.custom_preprocessor, + device=device, + ) + + hidden_states = ttnn.from_torch( + torch_hidden_states, + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + output = ttnn_optimized_sharded_vit.vit_encoder( + config, + hidden_states, + head_masks, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("image_size", [224]) +@pytest.mark.parametrize("image_channels", [3]) +@pytest.mark.parametrize("sequence_size", [224]) +def test_vit(device, model_name, batch_size, image_size, image_channels, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + config.num_hidden_layers = 12 + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config) + + config = ttnn_optimized_sharded_vit.update_model_config(config, 
batch_size) + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0] + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt").pixel_values + torch_pixel_values = torch_pixel_values.repeat(batch_size, 1, 1, 1) + torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32) + torch_output, *_ = model(torch_pixel_values).logits + + # cls_token & position embeddings expand to batch_size + # TODO: pass batch_size to preprocess_model_parameters + model_state_dict = model.state_dict() + torch_cls_token = model_state_dict["vit.embeddings.cls_token"] + torch_position_embeddings = model_state_dict["vit.embeddings.position_embeddings"] + if batch_size > 1: + torch_cls_token = torch.nn.Parameter(torch_cls_token.expand(batch_size, -1, -1)) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings.expand(batch_size, -1, -1)) + else: + torch_cls_token = torch.nn.Parameter(torch_cls_token) + torch_position_embeddings = torch.nn.Parameter(torch_position_embeddings) + cls_token = ttnn.from_torch(torch_cls_token, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + # torch_cls_token_padded = torch.nn.functional.pad(torch_cls_token, (0, 0, 0, 196, 0, 0)) + # torch_cls_position_embeddings = torch.add(torch_cls_token_padded, torch_position_embeddings) + # cls_position_embeddings = ttnn.from_torch( + # torch_cls_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + # ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_sharded_vit.custom_preprocessor, + ) + + torch_pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + torch_pixel_values = torch.nn.functional.pad(torch_pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + batch_size, img_h, img_w, img_c = torch_pixel_values.shape # permuted input NHWC + patch_size = 16 + torch_pixel_values = torch_pixel_values.reshape(batch_size, img_h, img_w // patch_size, 4 * patch_size) + N, H, W, C = torch_pixel_values.shape + shard_grid = ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(7, 0), + ), + } + ) + n_cores = 8 + shard_spec = ttnn.experimental.tensor.ShardSpec( + shard_grid, [N * H * W // n_cores, C], ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, False + ) + + pixel_values = torch2tt_tensor( + torch_pixel_values, + device, + ttnn.experimental.tensor.Layout.ROW_MAJOR, + tt_memory_config=ttnn.experimental.tensor.MemoryConfig( + ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.experimental.tensor.BufferType.L1, + shard_spec, + ), + tt_dtype=ttnn.experimental.tensor.DataType.BFLOAT16, + ) + + if torch_attention_mask is not None: + head_masks = [ + ttnn.from_torch( + torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(batch_size, -1, -1, -1), + dtype=ttnn.bfloat8_b, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + output = ttnn_optimized_sharded_vit.vit( + config, + pixel_values, + head_masks, + cls_token, + 
position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output[0, 0, :1000], 0.8146) diff --git a/tests/ttnn/integration_tests/vit/test_ttnn_optimized_vit_highres.py b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_vit_highres.py new file mode 100644 index 00000000000..c3153c10c9a --- /dev/null +++ b/tests/ttnn/integration_tests/vit/test_ttnn_optimized_vit_highres.py @@ -0,0 +1,465 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch +import math +import transformers +from datasets import load_dataset +from transformers import AutoImageProcessor + +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters + +from models.experimental.functional_vit.tt import ttnn_optimized_vit_highres +from models.utility_functions import torch_random, skip_for_wormhole_b0 + +from tests.ttnn.utils_for_testing import assert_with_pcc + + +def interpolate_pos_encoding( + position_embeddings: torch.Tensor, patch_size, num_patches, height: int, width: int +) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = position_embeddings.shape[-1] + h0 = height // patch_size + w0 = width // patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = torch.nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("image_size_h", [1024]) +@pytest.mark.parametrize("image_size_w", [1024]) +@pytest.mark.parametrize("image_channels", [3]) +def test_vit_patch_embeddings(device, model_name, image_size_h, image_size_w, image_channels): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + torch_pixel_values = torch_random((1, image_channels, image_size_h, image_size_w), -1, 1, dtype=torch.float32) + torch_output, *_ = model.vit.embeddings.patch_embeddings(torch_pixel_values, interpolate_pos_encoding=True) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_vit_highres.custom_preprocessor, + ) + + # pixel_values = 
ttnn.from_torch(torch_pixel_values, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + print(pixel_values.shape) + # pixel_values = pixel_values.reshape(batch_size, image_size, image_size // patch_size, 4 * patch_size) # run it on device + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + output = ttnn_optimized_vit_highres.vit_patch_embeddings( + config, + pixel_values, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output[0], 0.9999) + + +# padding +# 1024 / 16 = 64 +# 64*64 + 32 = 4128 (from cls_token concat) +# 4352 = (4128 + 224) +# 4352 / 8 = 136 + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("image_size_h", [1024]) +@pytest.mark.parametrize("image_size_w", [1024]) +def test_vit_embeddings(device, model_name, image_size_h, image_size_w): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + dataset = load_dataset("huggingface/cats-image") + image = dataset["test"]["image"][0].resize((image_size_h, image_size_w)) + image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + torch_pixel_values = image_processor(image, return_tensors="pt", do_resize=False, do_center_crop=False).pixel_values + torch_output, *_ = model.vit.embeddings(torch_pixel_values, interpolate_pos_encoding=True) + + # High resolution patch_parameters interpolation + model_state_dict = model.state_dict() + init_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"]) + patch_size = 16 + tot_patch_count = (image_size_h // patch_size) * (image_size_w // patch_size) + torch_position_embeddings = torch.nn.Parameter( + interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size_h, image_size_w) + ) + position_embeddings = ttnn.from_torch( + torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device + ) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_vit_highres.custom_preprocessor, + ) + + # pixel_values = ttnn.from_torch(torch_pixel_values, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + patch_size = 16 + pixel_values = torch.permute(torch_pixel_values, (0, 2, 3, 1)) + pixel_values = torch.nn.functional.pad(pixel_values, (0, 1, 0, 0, 0, 0, 0, 0)) + print(pixel_values.shape) + # pixel_values = pixel_values.reshape(batch_size, image_size, image_size // patch_size, 4 * patch_size) # run it on device + pixel_values = ttnn.from_torch(pixel_values, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + + output = ttnn_optimized_vit_highres.vit_embeddings( + config, + pixel_values, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + print(output.shape) + + # pad_size = (0, 0, 0, 255) # 224 + 31 + # torch_output = torch.nn.functional.pad(torch_output, pad_size, "constant", 0) + + assert_with_pcc(torch_output, output[0], 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs 
review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [640]) +def test_vit_attention_experimental(device, model_name, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTAttention(config).eval() + + torch_hidden_states = torch_random((1, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(1, 1, 1, sequence_size, dtype=torch.float32) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_vit_highres.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + attention_mask = ttnn.from_torch(torch_attention_mask, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_optimized_vit_highres.vit_attention_experimental( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [640]) +def test_vit_attention(device, model_name, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTAttention(config).eval() + + torch_hidden_states = torch_random((1, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_attention_mask = torch.ones(1, 1, 1, sequence_size, dtype=torch.float32) + torch_output, *_ = model(torch_hidden_states, torch_attention_mask) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, + device=device, + custom_preprocessor=ttnn_optimized_vit_highres.custom_preprocessor, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + attention_mask = ttnn.from_torch(torch_attention_mask, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_optimized_vit_highres.vit_attention( + config, + hidden_states, + attention_mask=attention_mask, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output, 0.9999) + + +@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review") +@skip_for_wormhole_b0() +@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"]) +@pytest.mark.parametrize("sequence_size", [4032]) +def test_vit_intermediate(device, model_name, sequence_size): + torch.manual_seed(0) + + config = transformers.ViTConfig.from_pretrained(model_name) + model = transformers.models.vit.modeling_vit.ViTIntermediate(config).eval() + + torch_hidden_states = torch_random((1, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32) + torch_output = model(torch_hidden_states) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model.to(torch.bfloat16), + device=device, + ) + + hidden_states = ttnn.from_torch(torch_hidden_states, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) + + output = ttnn_optimized_vit_highres.vit_intermediate( + 
        hidden_states,
+        parameters=parameters,
+    )
+    output = ttnn.to_torch(output)
+
+    assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999)
+
+
+@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review")
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"])
+@pytest.mark.parametrize("sequence_size", [4032])
+def test_vit_output(device, model_name, sequence_size):
+    torch.manual_seed(0)
+
+    config = transformers.ViTConfig.from_pretrained(model_name)
+    model = transformers.models.vit.modeling_vit.ViTOutput(config).eval()
+
+    torch_intermediate = torch_random((1, sequence_size, config.intermediate_size), -1, 1, dtype=torch.float32)
+    torch_residual = torch_random((1, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32)
+    torch_output = model(torch_intermediate, torch_residual)
+
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: model,
+        device=device,
+    )
+
+    intermediate = ttnn.from_torch(torch_intermediate, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
+    residual = ttnn.from_torch(torch_residual, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
+
+    output = ttnn_optimized_vit_highres.vit_output(
+        config,
+        intermediate,
+        residual,
+        parameters=parameters,
+    )
+    output = ttnn.to_torch(output)
+
+    assert_with_pcc(torch_output, output.to(torch_output.dtype), 0.9999)
+
+
+@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review")
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"])
+@pytest.mark.parametrize("batch_size", [1])  # assumed batch of 1; batch_size was used below but never parametrized, and the other high-res tests run a single image
+@pytest.mark.parametrize("sequence_size", [4032])
+def test_vit_layer(device, model_name, batch_size, sequence_size):
+    torch.manual_seed(0)
+
+    config = transformers.ViTConfig.from_pretrained(model_name)
+    model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").vit.encoder.layer[0]
+    # print(model)
+
+    torch_hidden_states = torch_random((batch_size, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32)
+    torch_attention_mask = torch.ones(batch_size, 1, 1, sequence_size, dtype=torch.float32)
+
+    torch_output, *_ = model(torch_hidden_states, torch_attention_mask)
+
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: model,
+        custom_preprocessor=ttnn_optimized_vit_highres.custom_preprocessor,
+        device=device,
+    )
+
+    hidden_states = ttnn.from_torch(
+        torch_hidden_states,
+        dtype=ttnn.bfloat8_b,
+        layout=ttnn.TILE_LAYOUT,
+        device=device,
+        memory_config=ttnn.L1_MEMORY_CONFIG,
+    )
+    attention_mask = ttnn.from_torch(
+        torch_attention_mask,
+        dtype=ttnn.bfloat8_b,
+        layout=ttnn.TILE_LAYOUT,
+        device=device,
+        memory_config=ttnn.L1_MEMORY_CONFIG,
+    )
+
+    output = ttnn_optimized_vit_highres.vit_layer(
+        config,
+        hidden_states,
+        attention_mask,
+        parameters=parameters,
+    )
+    output = ttnn.to_torch(output)
+
+    assert_with_pcc(torch_output, output, 0.9999)
+
+
+@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review")
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"])
+@pytest.mark.parametrize("sequence_size", [640])
+def test_vit_encoder(device, model_name, sequence_size):
+    torch.manual_seed(0)
+
+    config = transformers.ViTConfig.from_pretrained(model_name)
+    config.num_hidden_layers = 12
+    model = transformers.ViTForImageClassification.from_pretrained(
+        "google/vit-base-patch16-224", config=config
+    ).vit.encoder
+
+    torch_hidden_states = torch_random((1, sequence_size, config.hidden_size), -1, 1, dtype=torch.float32)
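+    # Note: the (num_hidden_layers, sequence_size) mask built below is sliced per layer,
+    # reshaped to (1, 1, 1, sequence_size), and kept in L1, so vit_encoder receives one
+    # broadcastable head mask per encoder layer.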
dtype=torch.float32)
+ torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32)
+ torch_output = model(torch_hidden_states, torch_attention_mask).last_hidden_state
+
+ parameters = preprocess_model_parameters(
+ initialize_model=lambda: model,
+ custom_preprocessor=ttnn_optimized_vit_highres.custom_preprocessor,
+ device=device,
+ )
+
+ hidden_states = ttnn.from_torch(
+ torch_hidden_states,
+ dtype=ttnn.bfloat8_b,
+ layout=ttnn.TILE_LAYOUT,
+ device=device,
+ memory_config=ttnn.L1_MEMORY_CONFIG,
+ )
+ if torch_attention_mask is not None:
+ head_masks = [
+ ttnn.from_torch(
+ torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(1, -1, -1, -1),
+ dtype=ttnn.bfloat8_b,
+ layout=ttnn.TILE_LAYOUT,
+ device=device,
+ memory_config=ttnn.L1_MEMORY_CONFIG,
+ )
+ for index in range(config.num_hidden_layers)
+ ]
+ else:
+ head_masks = [None for _ in range(config.num_hidden_layers)]
+
+ output = ttnn_optimized_vit_highres.vit_encoder(
+ config,
+ hidden_states,
+ head_masks,
+ parameters=parameters,
+ )
+ output = ttnn.to_torch(output)
+
+ assert_with_pcc(torch_output, output, 0.9999)
+
+
+# padding
+# 1024 / 16 = 64
+# 64*64 + 32 = 4128 (from cls_token concat)
+# 4352 = (4128 + 224)
+# 4352 / 32 = 136
+
+
+@pytest.mark.skip(reason="#7527: Test and PCC threshold needs review")
+@skip_for_wormhole_b0()
+@pytest.mark.parametrize("model_name", ["google/vit-base-patch16-224"])
+@pytest.mark.parametrize("image_size_h", [1024])
+@pytest.mark.parametrize("image_size_w", [1024])
+@pytest.mark.parametrize("sequence_size", [4032])
+def test_vit(device, model_name, image_size_h, image_size_w, sequence_size):
+ torch.manual_seed(0)
+
+ config = transformers.ViTConfig.from_pretrained(model_name)
+ config.num_hidden_layers = 12
+ model = transformers.ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", config=config)
+
+ dataset = load_dataset("huggingface/cats-image")
+ image = dataset["test"]["image"][0].resize((image_size_h, image_size_w))
+ image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
+ torch_pixel_values = image_processor(image, return_tensors="pt", do_resize=False, do_center_crop=False).pixel_values
+ torch_attention_mask = torch.ones(config.num_hidden_layers, sequence_size, dtype=torch.float32)
+ torch_output, *_ = model(torch_pixel_values, interpolate_pos_encoding=True).logits
+
+ # High resolution patch_parameters interpolation
+ model_state_dict = model.state_dict()
+ init_position_embeddings = torch.nn.Parameter(model_state_dict["vit.embeddings.position_embeddings"])
+ patch_size = 16
+ tot_patch_count = (image_size_h // patch_size) * (image_size_w // patch_size)
+ torch_position_embeddings = torch.nn.Parameter(
+ interpolate_pos_encoding(init_position_embeddings, patch_size, tot_patch_count, image_size_h, image_size_w)
+ )
+ position_embeddings = ttnn.from_torch(
+ torch_position_embeddings, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device
+ )
+
+ parameters = preprocess_model_parameters(
+ initialize_model=lambda: model,
+ device=device,
+ custom_preprocessor=ttnn_optimized_vit_highres.custom_preprocessor,
+ )
+
+ torch_pixel_values = torch_pixel_values.to(torch.bfloat16)
+ pixel_values = ttnn.from_torch(torch_pixel_values, layout=ttnn.TILE_LAYOUT, device=device)
+
+ if torch_attention_mask is not None:
+ head_masks = [
+ ttnn.from_torch(
+ torch_attention_mask[index].reshape(1, 1, 1, sequence_size).expand(1, -1, -1, -1),
+ dtype=ttnn.bfloat8_b,
+ layout=ttnn.TILE_LAYOUT, +
device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + for index in range(config.num_hidden_layers) + ] + else: + head_masks = [None for _ in range(config.num_hidden_layers)] + + output = ttnn_optimized_vit_highres.vit( + config, + pixel_values, + head_masks, + position_embeddings, + parameters=parameters, + ) + output = ttnn.to_torch(output) + + assert_with_pcc(torch_output, output[0][0], 0.9999) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 5fba74dd97f..1a15b2bfd09 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -660,11 +660,15 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layo // TT_ASSERT(shard_shape[1] * tensor_impl::element_size_bytes_wrapper(data_type) % ADDRESS_ALIGNMENT == 0); } + auto width = shape[-1]; + auto other_dims = 1; + for (int i = 0; i < shape.rank() - 1; i++) { + other_dims *= shape[i]; + } + auto element_size = tensor_impl::element_size_bytes_wrapper(data_type); auto page_shape = tensor_impl::get_sharded_page_shape(layout, data_type, shard_spec.shape); - std::array tensor2d_size = {shape[0]*shape[1] * shape[2]/page_shape[0], - shape[3]/page_shape[1] - }; + std::array tensor2d_size = {other_dims/page_shape[0], width/page_shape[1]}; ShardSpecBuffer shard_spec_buffer(shard_spec, page_shape, tensor2d_size); uint32_t packed_size_in_bytes; diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp index 1d622c804fd..d3ee017221e 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp @@ -111,7 +111,8 @@ void EltwiseBinary::validate(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); TT_FATAL( - input_tensor_a.get_legacy_shape().without_padding() == input_tensor_b.get_legacy_shape().without_padding(), + (input_tensor_a.get_legacy_shape() == input_tensor_b.get_legacy_shape()) or + (input_tensor_a.get_legacy_shape().without_padding() == input_tensor_b.get_legacy_shape().without_padding()), "Input shapes must be the same!"); TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE and input_tensor_b.storage_type() == StorageType::DEVICE, "Operands to eltwise binary need to be on device!"); TT_FATAL(input_tensor_a.buffer() != nullptr and input_tensor_b.buffer() != nullptr, "Operands to eltwise binary need to be allocated in buffers on device!"); diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp index e01fd305eb3..7a810cce35e 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp @@ -132,7 +132,9 @@ struct make_eltwise_binary { } } TT_FATAL( - in_a.get_legacy_shape() == in_b.get_legacy_shape(), "Input shapes must be the same!"); + (in_a.get_legacy_shape() == in_b.get_legacy_shape()) or + (in_a.get_legacy_shape().without_padding() == in_b.get_legacy_shape().without_padding()), + "Input shapes must be the same!"); return operation::run_with_autoformat( EltwiseBinary{ binary_op_type, @@ -199,7 +201,8 @@ inline Tensor add( } } TT_FATAL( - in_a.get_legacy_shape().without_padding() == in_b.get_legacy_shape().without_padding(), + (input_tensor_a.get_legacy_shape() == input_tensor_b.get_legacy_shape()) or + 
(input_tensor_a.get_legacy_shape().without_padding() == input_tensor_b.get_legacy_shape().without_padding()), "Input shapes must be the same!"); auto output = operation::run( EltwiseBinary{