From 95ea37cd39b5e55bf349a3a8c8d1b31d48bc58f3 Mon Sep 17 00:00:00 2001 From: WoodieDudy Date: Tue, 2 Jul 2024 10:18:50 +0000 Subject: [PATCH] add use_sdpa arg Signed-off-by: WoodieDudy --- .../parts/submodules/multi_head_attention.py | 111 +++++---- sdpa_testing/old_multi_head_attention.py | 210 ------------------ sdpa_testing/sdpa_mha_benchmark.py | 84 ------- sdpa_testing/sdpa_relpos_mha_benchmark.py | 84 ------- 4 files changed, 65 insertions(+), 424 deletions(-) delete mode 100644 sdpa_testing/old_multi_head_attention.py delete mode 100644 sdpa_testing/sdpa_mha_benchmark.py delete mode 100644 sdpa_testing/sdpa_relpos_mha_benchmark.py diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index e4a58b635a20..ab5b7b9d7ea4 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -57,11 +57,12 @@ class MultiHeadAttention(nn.Module): dropout_rate (float): dropout rate """ - def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0): + def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0, use_sdpa=False): """Construct an MultiHeadedAttention object.""" super(MultiHeadAttention, self).__init__() - self.dropout_rate = dropout_rate + self.use_sdpa = use_sdpa self.cache_drop_size = None + self.dropout_rate = dropout_rate assert n_feat % n_head == 0 # We assume d_v always equals d_k self.d_k = n_feat // n_head @@ -96,28 +97,28 @@ def forward_qkv(self, query, key, value): return q, k, v - # def forward_attention(self, value, scores, mask): - # """Compute attention context vector. - # Args: - # value (torch.Tensor): (batch, time2, size) - # scores(torch.Tensor): (batch, time1, time2) - # mask(torch.Tensor): (batch, time1, time2) - # returns: - # value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores - # """ - # n_batch = value.size(0) - # if mask is not None: - # mask = mask.unsqueeze(1) # (batch, 1, time1, time2) - # scores = scores.masked_fill(mask, -10000.0) - # attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) - # else: - # attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - # p_attn = self.dropout(attn) - # x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - # x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) - - # return self.linear_out(x) # (batch, time1, d_model) + def forward_attention(self, value, scores, mask): + """Compute attention context vector. 
+ Args: + value (torch.Tensor): (batch, time2, size) + scores(torch.Tensor): (batch, time1, time2) + mask(torch.Tensor): (batch, time1, time2) + returns: + value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -10000.0) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) def forward(self, query, key, value, mask, pos_emb=None, cache=None): """Compute 'Scaled Dot Product Attention'. @@ -133,21 +134,28 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None): cache (torch.Tensor) : (batch, time_cache_next, size) """ key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache) - n_batch = value.size(0) if torch.is_autocast_enabled(): query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) - + + # temporary until we solve this more gracefully with avoid_float16_autocast_context(): q, k, v = self.forward_qkv(query, key, value) - scale = 1 / self.s_d_k - - if mask is not None: - mask = mask.unsqueeze(1).logical_not() - - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale) - out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) - out = self.linear_out(out) # (batch, time1, d_model) + + if self.use_sdpa: + scale = 1 / self.s_d_k + n_batch = value.size(0) + + if mask is not None: + mask = mask.unsqueeze(1) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale) + out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + out = self.linear_out(out) # (batch, time1, d_model) + else: + scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k + out = self.forward_attention(v, scores, mask) + if cache is None: return out else: @@ -170,10 +178,9 @@ class RelPositionMultiHeadAttention(MultiHeadAttention): dropout_rate (float): dropout rate """ - def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0): + def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0, use_sdpa=False): """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len) - self.dropout_rate = dropout_rate + super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len, use_sdpa=use_sdpa) # linear transformation for positional encoding self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) # these two learnable biases are used in matrix c and matrix d @@ -236,21 +243,33 @@ def forward(self, query, key, value, mask, pos_emb, cache=None): # (batch, head, time1, d_k) q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + # compute matrix b and 
matrix d # (batch, head, time1, time2) matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) matrix_bd = self.rel_shift(matrix_bd) - scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1)) - # drops extra elements in the matrix_bd to match the matrix_ac's size - matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor - if mask is not None: - mask = mask.unsqueeze(1) - matrix_bd.masked_fill_(mask, float("-inf")) + if self.use_sdpa: + scale_factor = 1 / math.sqrt(q_with_bias_u.size(-1)) + matrix_bd = matrix_bd[:, :, :, : k.size(-2)] * scale_factor - out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=self.dropout_rate, scale=scale_factor) - out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) - out = self.linear_out(out) # (batch, time1, d_model) + if mask is not None: + mask = mask.unsqueeze(1) + matrix_bd.masked_fill_(mask.logical_not(), float("-inf")) + + out = torch.nn.functional.scaled_dot_product_attention(q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=self.dropout_rate, scale=scale_factor) + out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) + out = self.linear_out(out) # (batch, time1, d_model) + else: + # drops extra elements in the matrix_bd to match the matrix_ac's size + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)] + scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2) + out = self.forward_attention(v, scores, mask) if cache is None: return out diff --git a/sdpa_testing/old_multi_head_attention.py b/sdpa_testing/old_multi_head_attention.py deleted file mode 100644 index 41ba181c8ee3..000000000000 --- a/sdpa_testing/old_multi_head_attention.py +++ /dev/null @@ -1,210 +0,0 @@ -import math -from functools import lru_cache -from typing import List, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from nemo.utils import avoid_float16_autocast_context - - -class MultiHeadAttention(nn.Module): - """Multi-Head Attention layer of Transformer. - Args: - n_head (int): number of heads - n_feat (int): size of the features - dropout_rate (float): dropout rate - """ - - def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0): - """Construct an MultiHeadedAttention object.""" - super(MultiHeadAttention, self).__init__() - self.cache_drop_size = None - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.s_d_k = math.sqrt(self.d_k) - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - self._max_cache_len = max_cache_len - - def forward_qkv(self, query, key, value): - """Transforms query, key and value. 
- Args: - query (torch.Tensor): (batch, time1, size) - key (torch.Tensor): (batch, time2, size) - value (torch.Tensor): (batch, time2, size) - returns: - q (torch.Tensor): (batch, head, time1, size) - k (torch.Tensor): (batch, head, time2, size) - v (torch.Tensor): (batch, head, time2, size) - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - return q, k, v - - def forward_attention(self, value, scores, mask): - """Compute attention context vector. - Args: - value (torch.Tensor): (batch, time2, size) - scores(torch.Tensor): (batch, time1, time2) - mask(torch.Tensor): (batch, time1, time2) - returns: - value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores - """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1) # (batch, 1, time1, time2) - scores = scores.masked_fill(mask, -10000.0) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) - else: - attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = x.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward(self, query, key, value, mask, pos_emb=None, cache=None): - """Compute 'Scaled Dot Product Attention'. - Args: - query (torch.Tensor): (batch, time1, size) - key (torch.Tensor): (batch, time2, size) - value(torch.Tensor): (batch, time2, size) - mask (torch.Tensor): (batch, time1, time2) - cache (torch.Tensor) : (batch, time_cache, size) - - returns: - output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention - cache (torch.Tensor) : (batch, time_cache_next, size) - """ - key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache) - - if torch.is_autocast_enabled(): - query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) - - # temporary until we solve this more gracefully - with avoid_float16_autocast_context(): - q, k, v = self.forward_qkv(query, key, value) - scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k - out = self.forward_attention(v, scores, mask) - if cache is None: - return out - else: - return out, cache - - def update_cache(self, key, value, query, cache): - if cache is not None: - key = value = torch.cat([cache, key], dim=1) - q_keep_size = query.shape[1] - self.cache_drop_size - cache = torch.cat([cache[:, q_keep_size:, :], query[:, :q_keep_size, :]], dim=1) - return key, value, query, cache - - -class RelPositionMultiHeadAttention(MultiHeadAttention): - """Multi-Head Attention layer of Transformer-XL with support of relative positional encoding. 
- Paper: https://arxiv.org/abs/1901.02860 - Args: - n_head (int): number of heads - n_feat (int): size of the features - dropout_rate (float): dropout rate - """ - - def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len) - # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) - # these two learnable biases are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - if pos_bias_u is None or pos_bias_v is None: - self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k)) - # nn.init.normal_(self.pos_bias_u, 0.0, 0.02) - # nn.init.normal_(self.pos_bias_v, 0.0, 0.02) - nn.init.zeros_(self.pos_bias_u) - nn.init.zeros_(self.pos_bias_v) - else: - self.pos_bias_u = pos_bias_u - self.pos_bias_v = pos_bias_v - - def rel_shift(self, x): - """Compute relative positional encoding. - Args: - x (torch.Tensor): (batch, nheads, time, 2*time-1) - """ - b, h, qlen, pos_len = x.size() # (b, h, t1, t2) - # need to add a column of zeros on the left side of last dimension to perform the relative shifting - x = torch.nn.functional.pad(x, pad=(1, 0)) # (b, h, t1, t2+1) - x = x.view(b, h, -1, qlen) # (b, h, t2+1, t1) - # need to drop the first row - x = x[:, :, 1:].view(b, h, qlen, pos_len) # (b, h, t1, t2) - return x - - def forward(self, query, key, value, mask, pos_emb, cache=None): - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - Args: - query (torch.Tensor): (batch, time1, size) - key (torch.Tensor): (batch, time2, size) - value(torch.Tensor): (batch, time2, size) - mask (torch.Tensor): (batch, time1, time2) - pos_emb (torch.Tensor) : (batch, time1, size) - cache (torch.Tensor) : (batch, time_cache, size) - - Returns: - output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention - cache (torch.Tensor) : (batch, time_cache_next, size) - """ - key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache) - - if torch.is_autocast_enabled(): - query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32) - - # temporary until we solve this more gracefully - with avoid_float16_autocast_context(): - q, k, v = self.forward_qkv(query, key, value) - q = q.transpose(1, 2) # (batch, time1, head, d_k) - - n_batch_pos = pos_emb.size(0) - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose(1, 2) # (batch, head, time1, d_k) - - # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) - # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) - - # compute matrix b and matrix d - # (batch, head, time1, time2) - matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) - matrix_bd = self.rel_shift(matrix_bd) - # drops extra elements in the matrix_bd to match the matrix_ac's size - matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)] - - scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2) 
- - out = self.forward_attention(v, scores, mask) - - if cache is None: - return out - else: - return out, cache diff --git a/sdpa_testing/sdpa_mha_benchmark.py b/sdpa_testing/sdpa_mha_benchmark.py deleted file mode 100644 index 71febbf3c5d2..000000000000 --- a/sdpa_testing/sdpa_mha_benchmark.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -import torch.nn as nn -import torch.utils.benchmark as benchmark -from nemo.collections.asr.parts.submodules.multi_head_attention import MultiHeadAttention as SDPAMultiHeadAttention -from old_multi_head_attention import MultiHeadAttention -from nemo.utils import avoid_float16_autocast_context - - -torch.manual_seed(123) - -device = "cuda" -batch_size = 32 -seq_len = 1024 -d_model = 512 -n_head = 8 - -query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True) -key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True) -value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True) -mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False) -mask = torch.triu(mask, diagonal=1).bool() -mask = None - -attention_sdpa = SDPAMultiHeadAttention(n_head, d_model, 0.0).to(device) -attention_original = MultiHeadAttention(n_head, d_model, 0.0).to(device) -for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()): - original_param.data.copy_(sdpa_param.data) -# attention_sdpa = torch.compile(attention_sdpa) -# attention_original = torch.compile(attention_original) - - -def measure_time(attention, query, key, value, mask): - with torch.no_grad(): - timer = benchmark.Timer( - stmt='attention(query, key, value, mask);torch.cuda.synchronize()', - setup='torch.cuda.synchronize()', - globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask} - ) - torch.cuda.synchronize() - - with torch.no_grad(): - results = timer.blocked_autorange(min_run_time=10) - forward_time = results.mean - output = attention(query, key, value, mask) - return forward_time, output - - -def measure_fwd_bwd_time(attention, query, key, value, mask): - timer = benchmark.Timer( - stmt='loss=attention(query, key, value, mask).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()', - setup='torch.cuda.synchronize()', - globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask} - ) - torch.cuda.synchronize() - results = timer.blocked_autorange(min_run_time=10) - fwd_bwd_time = results.mean - return fwd_bwd_time - - -time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask) -time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask) - -print(f"Original implementation time: {time_fwd_original:.6f} seconds") -print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds") -print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%") - -time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask) -time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask) -time_bwd_original = time_fwd_bwd_original - time_fwd_original -time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa - -print(f"Original implementation backward time: {time_bwd_original:.6f} seconds") -print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds") -print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%") - -print(f"Outputs are {'the same' if 
torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}") - -# Original implementation time: 0.017988 seconds -# SDPA implementation time: 0.012850 seconds -# SDPA boost 28.562% -# Original implementation backward time: 0.034470 seconds -# SDPA implementation backward time: 0.038126 seconds -# SDPA backward boost -10.608% -# Outputs are the same \ No newline at end of file diff --git a/sdpa_testing/sdpa_relpos_mha_benchmark.py b/sdpa_testing/sdpa_relpos_mha_benchmark.py deleted file mode 100644 index f732fc7e72ed..000000000000 --- a/sdpa_testing/sdpa_relpos_mha_benchmark.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -import torch.nn as nn -import torch.utils.benchmark as benchmark -from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention as SDPARelPositionMultiHeadAttention -from old_multi_head_attention import RelPositionMultiHeadAttention -from nemo.utils import avoid_float16_autocast_context - -torch.manual_seed(123) - -device = "cuda" -batch_size = 32 -seq_len = 1024 -d_model = 512 -n_head = 8 - -query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True) -key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True) -value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True) -mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False) -mask = torch.triu(mask, diagonal=1).bool() -mask = None -pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True) - -attention_sdpa = SDPARelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device) -attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device) -for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()): - original_param.data.copy_(sdpa_param.data) - -# attention_sdpa = torch.compile(attention_sdpa) -# attention_original = torch.compile(attention_original) - - -def measure_time(attention, query, key, value, mask, pos_emb): - with torch.no_grad(): - timer = benchmark.Timer( - stmt='attention(query, key, value, mask, pos_emb);torch.cuda.synchronize()', - setup='torch.cuda.synchronize()', - globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb} - ) - - with torch.no_grad(): - results = timer.blocked_autorange(min_run_time=10) - forward_time = results.mean - output = attention(query, key, value, mask, pos_emb) - return forward_time, output - - -def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb): - timer = benchmark.Timer( - stmt='loss=attention(query, key, value, mask, pos_emb).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()', - setup='torch.cuda.synchronize()', - globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb} - ) - torch.cuda.synchronize() - results = timer.blocked_autorange(min_run_time=10) - fwd_bwd_time = results.mean - return fwd_bwd_time - - -time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb) -time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb) - -print(f"Original implementation time: {time_fwd_original:.6f} seconds") -print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds") -print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%") - -time_fwd_bwd_original = 
measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb) -time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb) -time_bwd_original = time_fwd_bwd_original - time_fwd_original -time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa - -print(f"Original implementation backward time: {time_bwd_original:.6f} seconds") -print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds") -print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%") - -print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}") - -# Original implementation time: 0.031546 seconds -# SDPA implementation time: 0.025657 seconds -# SDPA boost 18.67% -# Original implementation backward time: 0.069962 seconds -# SDPA implementation backward time: 0.078320 seconds -# SDPA backward boost -11.95% -# Outputs are the same \ No newline at end of file
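
A note on the mask conventions the SDPA branch has to bridge: forward_attention above treats mask == True as positions to exclude (it does masked_fill(mask, -10000.0)), whereas a boolean attn_mask passed to torch.nn.functional.scaled_dot_product_attention means the opposite, True = "may attend". A minimal, self-contained sketch of the equivalence (shapes and tensor names are illustrative, not taken from NeMo):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    # (batch, heads, time, d_k); pad_mask is True where attention is NOT allowed,
    # matching the masked_fill convention of forward_attention above.
    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    pad_mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)
    pad_mask[..., 6:] = True  # pretend the last two key positions are padding

    # Eager reference: mask with -inf (forward_attention uses -10000.0 as a finite stand-in), then softmax.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (16 ** 0.5)
    attn = torch.softmax(scores.masked_fill(pad_mask, float("-inf")), dim=-1)
    ref = torch.matmul(attn, v)

    # SDPA: a boolean attn_mask uses the opposite convention (True = keep),
    # so the padding mask must be inverted before it is passed in.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=~pad_mask, dropout_p=0.0)

    print(torch.allclose(ref, out, atol=1e-5))  # expected: True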
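
The RelPositionMultiHeadAttention branch relies on the second mask type accepted by scaled_dot_product_attention: a floating-point attn_mask is added to the attention scores before the softmax, so the rel-shifted matrix_bd term (pre-multiplied by the scale factor, with masked positions pushed to a large negative value) can be passed in as a bias. A small sketch of that identity, with made-up shapes and a random bias standing in for matrix_bd:

    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    b, h, t, d_k = 2, 4, 8, 16
    q = torch.randn(b, h, t, d_k)
    k = torch.randn(b, h, t, d_k)
    v = torch.randn(b, h, t, d_k)
    bias = torch.randn(b, h, t, t)  # stands in for the rel-shifted matrix_bd term

    scale = 1 / math.sqrt(d_k)

    # Eager reference: softmax((q k^T + bias) / sqrt(d_k)) v,
    # i.e. the (matrix_ac + matrix_bd) / self.s_d_k formulation in the non-SDPA branch.
    scores = (torch.matmul(q, k.transpose(-2, -1)) + bias) * scale
    ref = torch.matmul(torch.softmax(scores, dim=-1), v)

    # SDPA applies scale only to q k^T, so the bias has to be pre-scaled before it is
    # handed over as a float attn_mask -- which is what multiplying matrix_bd by
    # scale_factor does in the patch. (The scale= keyword needs PyTorch >= 2.1.)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias * scale, dropout_p=0.0, scale=scale)

    print(torch.allclose(ref, out, atol=1e-5))  # expected: True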
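
The deleted sdpa_testing scripts above measured speed and checked output parity between the old and new implementations as separate classes; after this patch the same kind of parity check can be driven by the use_sdpa flag alone. A hedged sketch, assuming the patched NeMo is importable, with dropout at 0 and mask=None so both code paths are deterministic and directly comparable:

    import torch
    from nemo.collections.asr.parts.submodules.multi_head_attention import MultiHeadAttention

    torch.manual_seed(0)

    batch_size, seq_len, d_model, n_head = 4, 64, 256, 8

    mha_eager = MultiHeadAttention(n_head, d_model, dropout_rate=0.0, use_sdpa=False)
    mha_sdpa = MultiHeadAttention(n_head, d_model, dropout_rate=0.0, use_sdpa=True)
    mha_sdpa.load_state_dict(mha_eager.state_dict())  # same weights, only the attention kernel differs

    x = torch.randn(batch_size, seq_len, d_model)

    with torch.no_grad():
        out_eager = mha_eager(x, x, x, mask=None)
        out_sdpa = mha_sdpa(x, x, x, mask=None)

    # With identical weights, no dropout and no mask, the two branches compute the same
    # attention and the outputs should agree up to kernel-level numerical differences.
    print(torch.allclose(out_eager, out_sdpa, atol=1e-5))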