
Commit

comparison script & dropout
kramarenko.gs committed May 26, 2024
1 parent b4a1c74 commit c3e2cab
Showing 2 changed files with 55 additions and 1 deletion.
@@ -166,6 +166,7 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
     def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0):
         """Construct an RelPositionMultiHeadedAttention object."""
         super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len)
+        self.dropout_rate = dropout_rate
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable biases are used in matrix c and matrix d
@@ -242,7 +243,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         else:
             attn_mask = matrix_bd
 
-        out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=attn_mask)
+        out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=attn_mask, dropout_p=self.dropout_rate)
         out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
         out = torch.nan_to_num(out, nan=0.0)
         out = self.linear_out(out) # (batch, time1, d_model)
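A note on the dropout change above: torch.nn.functional.scaled_dot_product_attention applies dropout_p whenever it is non-zero, independent of the module's train/eval state, so passing self.dropout_rate directly keeps attention dropout active at inference time as well. Below is a minimal sketch of the commonly used gating pattern; it is not part of this commit, and the function name is illustrative only.

import torch
import torch.nn.functional as F

def sdpa_with_gated_dropout(q, k, v, attn_mask, dropout_rate, training):
    # F.scaled_dot_product_attention has no notion of train/eval mode, so the
    # caller zeroes the rate outside of training to keep inference deterministic.
    p = dropout_rate if training else 0.0
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=p)

Inside the module this would typically read dropout_p=self.dropout_rate if self.training else 0.0.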
53 changes: 53 additions & 0 deletions test.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
import time
from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention as NewRelPositionMultiHeadAttention
from old_multi_head_attention import RelPositionMultiHeadAttention
from nemo.utils import avoid_float16_autocast_context


torch.manual_seed(123)


device = "cuda"
batch_size = 32
seq_len = 1024
d_model = 512
n_head = 8

query = torch.rand(batch_size, seq_len, d_model).to(device)
key = torch.rand(batch_size, seq_len, d_model).to(device)
value = torch.rand(batch_size, seq_len, d_model).to(device)
mask = torch.ones(batch_size, seq_len, seq_len)
mask = torch.triu(mask).bool().to(device)
# mask = None
pos_emb = torch.rand(batch_size, seq_len, d_model).to(device)

torch.manual_seed(123)
attention_sdpa = NewRelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
torch.manual_seed(123)
attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)

output_sdpa = attention_sdpa(query, key, value, mask, pos_emb)
output_original = attention_original(query, key, value, mask, pos_emb)

def measure_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask, pos_emb)',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )
    results = timer.blocked_autorange(min_run_time=10)
    return results.mean, results

time_original, _ = measure_time(attention_original, query, key, value, mask, pos_emb)
time_sdpa, _ = measure_time(attention_sdpa, query, key, value, mask, pos_emb)

print(f"Original implementation time: {time_original:.6f} seconds")
print(f"SDPA implementation time: {time_sdpa:.6f} seconds")
print(f"SDPA boost {(time_original - time_sdpa) / time_original * 100:.3f}%")
print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
# Original implementation time: 0.042316 seconds
# SDPA implementation time: 0.030923 seconds
# SDPA boost 26.924%
# Outputs are the same
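To reproduce the numbers, the script is run as python test.py on a CUDA-capable machine; old_multi_head_attention.py is assumed to be a local copy of the pre-change RelPositionMultiHeadAttention, since it is not part of this commit. Note that both modules are constructed with dropout_rate=0.0, so the allclose check compares only the deterministic paths.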
