From c3e2cab6d9e51d2a84324e5e2b55ed6498bffb70 Mon Sep 17 00:00:00 2001
From: "kramarenko.gs"
Date: Sun, 26 May 2024 19:03:07 +0000
Subject: [PATCH] comparison script & dropout

---
 .../parts/submodules/multi_head_attention.py |  3 +-
 test.py                                      | 53 +++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)
 create mode 100644 test.py

diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index 5da99cdc27f3..c5e64659e6a5 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -166,6 +166,7 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
     def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0):
         """Construct an RelPositionMultiHeadedAttention object."""
         super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len)
+        self.dropout_rate = dropout_rate  # stored so forward() can pass it to scaled_dot_product_attention
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable biases are used in matrix c and matrix d
@@ -242,7 +243,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         else:
             attn_mask = matrix_bd
 
-        out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=attn_mask)
+        out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=attn_mask, dropout_p=self.dropout_rate if self.training else 0.0)
         out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
         out = torch.nan_to_num(out, nan=0.0)
         out = self.linear_out(out)  # (batch, time1, d_model)
diff --git a/test.py b/test.py
new file mode 100644
index 000000000000..63e5132bc4ed
--- /dev/null
+++ b/test.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+import torch.utils.benchmark as benchmark
+import time
+from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention as NewRelPositionMultiHeadAttention
+from old_multi_head_attention import RelPositionMultiHeadAttention
+from nemo.utils import avoid_float16_autocast_context
+
+
+torch.manual_seed(123)
+
+
+device = "cuda"
+batch_size = 32
+seq_len = 1024
+d_model = 512
+n_head = 8
+
+query = torch.rand(batch_size, seq_len, d_model).to(device)
+key = torch.rand(batch_size, seq_len, d_model).to(device)
+value = torch.rand(batch_size, seq_len, d_model).to(device)
+mask = torch.ones(batch_size, seq_len, seq_len)
+mask = torch.triu(mask).bool().to(device)
+# mask = None
+pos_emb = torch.rand(batch_size, seq_len, d_model).to(device)
+
+torch.manual_seed(123)
+attention_sdpa = NewRelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
+torch.manual_seed(123)
+attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
+
+output_sdpa = attention_sdpa(query, key, value, mask, pos_emb)
+output_original = attention_original(query, key, value, mask, pos_emb)
+
+def measure_time(attention, query, key, value, mask, pos_emb):
+    timer = benchmark.Timer(
+        stmt='attention(query, key, value, mask, pos_emb)',
+        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
+    )
+    results = timer.blocked_autorange(min_run_time=10)
+    return results.mean, results
+
+time_original, _ = measure_time(attention_original, query, key, value, mask, pos_emb)
+time_sdpa, _ = measure_time(attention_sdpa, query, key, value, mask, pos_emb)
+
+print(f"Original implementation time: {time_original:.6f} seconds")
+print(f"SDPA implementation time: {time_sdpa:.6f} seconds")
+print(f"SDPA boost {(time_original - time_sdpa) / time_original * 100:.3f}%")
+print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
+# Original implementation time: 0.042316 seconds
+# SDPA implementation time: 0.030923 seconds
+# SDPA boost 26.924%
+# Outputs are the same
\ No newline at end of file
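Note (not part of the patch): a minimal standalone sketch, assuming PyTorch >= 2.0, of why dropout_p is gated on self.training in the hunk above. torch.nn.functional.scaled_dot_product_attention applies dropout whenever dropout_p > 0, regardless of any module's train/eval state, whereas the nn.Dropout used on the attention weights in the original module is a no-op after .eval(). The tensor shapes below are arbitrary.

    import torch
    import torch.nn.functional as F

    # (batch, heads, time, d_k); values are random, only the dropout behaviour matters
    q = k = v = torch.rand(1, 2, 16, 8)

    out_drop = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5)
    out_ref = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

    # Dropout was applied even though no module is in training mode, so the
    # outputs differ unless dropout_p is forced to 0.0 at evaluation time.
    print(torch.allclose(out_drop, out_ref))  # False

Passing dropout_p=self.dropout_rate if self.training else 0.0 keeps evaluation deterministic and matches the behaviour of the pre-SDPA implementation; the benchmark in test.py is unaffected because it constructs both modules with dropout_rate=0.0.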