diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index 155e2cc5c0b2..e4a58b635a20 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -60,6 +60,7 @@ class MultiHeadAttention(nn.Module):
     def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadAttention, self).__init__()
+        self.dropout_rate = dropout_rate
         self.cache_drop_size = None
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
@@ -132,14 +133,18 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
             cache (torch.Tensor) : (batch, time_cache_next, size)
         """
         key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
+        n_batch = value.size(0)

         if torch.is_autocast_enabled():
             query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
-
-        # temporary until we solve this more gracefully
+
         with avoid_float16_autocast_context():
             q, k, v = self.forward_qkv(query, key, value)
             scale = 1 / self.s_d_k
+
+            if mask is not None:
+                mask = mask.unsqueeze(1).logical_not()
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale)
             out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
             out = self.linear_out(out)  # (batch, time1, d_model)
diff --git a/old_multi_head_attention.py b/old_multi_head_attention.py
deleted file mode 100644
index 99941fa67ced..000000000000
--- a/old_multi_head_attention.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import math
-from functools import lru_cache
-from typing import List, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from nemo.utils import avoid_float16_autocast_context
-from nemo.collections.asr.parts.submodules.multi_head_attention import MultiHeadAttention
-
-
-class RelPositionMultiHeadAttention(MultiHeadAttention):
-    """Multi-Head Attention layer of Transformer-XL with support of relative positional encoding.
-    Paper: https://arxiv.org/abs/1901.02860
-    Args:
-        n_head (int): number of heads
-        n_feat (int): size of the features
-        dropout_rate (float): dropout rate
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, pos_bias_u, pos_bias_v, max_cache_len=0):
-        """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=max_cache_len)
-        # linear transformation for positional encoding
-        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
-        # these two learnable biases are used in matrix c and matrix d
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        if pos_bias_u is None or pos_bias_v is None:
-            self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
-            self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
-            # nn.init.normal_(self.pos_bias_u, 0.0, 0.02)
-            # nn.init.normal_(self.pos_bias_v, 0.0, 0.02)
-            nn.init.zeros_(self.pos_bias_u)
-            nn.init.zeros_(self.pos_bias_v)
-        else:
-            self.pos_bias_u = pos_bias_u
-            self.pos_bias_v = pos_bias_v
-
-    def rel_shift(self, x):
-        """Compute relative positional encoding.
-        Args:
-            x (torch.Tensor): (batch, nheads, time, 2*time-1)
-        """
-        b, h, qlen, pos_len = x.size()  # (b, h, t1, t2)
-        # need to add a column of zeros on the left side of last dimension to perform the relative shifting
-        x = torch.nn.functional.pad(x, pad=(1, 0))  # (b, h, t1, t2+1)
-        x = x.view(b, h, -1, qlen)  # (b, h, t2+1, t1)
-        # need to drop the first row
-        x = x[:, :, 1:].view(b, h, qlen, pos_len)  # (b, h, t1, t2)
-        return x
-
-    def forward(self, query, key, value, mask, pos_emb, cache=None):
-        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
-        Args:
-            query (torch.Tensor): (batch, time1, size)
-            key (torch.Tensor): (batch, time2, size)
-            value(torch.Tensor): (batch, time2, size)
-            mask (torch.Tensor): (batch, time1, time2)
-            pos_emb (torch.Tensor) : (batch, time1, size)
-            cache (torch.Tensor) : (batch, time_cache, size)
-
-        Returns:
-            output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
-            cache (torch.Tensor) : (batch, time_cache_next, size)
-        """
-        key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
-
-        if torch.is_autocast_enabled():
-            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
-
-        # temporary until we solve this more gracefully
-        with avoid_float16_autocast_context():
-            q, k, v = self.forward_qkv(query, key, value)
-            q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-            n_batch_pos = pos_emb.size(0)
-            p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-            p = p.transpose(1, 2)  # (batch, head, time1, d_k)
-
-            # (batch, head, time1, d_k)
-            q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-            # (batch, head, time1, d_k)
-            q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-            # compute attention score
-            # first compute matrix a and matrix c
-            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-            # (batch, head, time1, time2)
-            matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-            # compute matrix b and matrix d
-            # (batch, head, time1, time2)
-            matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-            matrix_bd = self.rel_shift(matrix_bd)
-            # drops extra elements in the matrix_bd to match the matrix_ac's size
-            matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
-
-            scores = (matrix_ac + matrix_bd) / self.s_d_k  # (batch, head, time1, time2)
-
-        out = self.forward_attention(v, scores, mask)
-
-        if cache is None:
-            return out
-        else:
-            return out, cache
diff --git a/sdpa_testing/sdpa_mha_benchmark.py b/sdpa_testing/sdpa_mha_benchmark.py
index 04d035216d18..71febbf3c5d2 100644
--- a/sdpa_testing/sdpa_mha_benchmark.py
+++ b/sdpa_testing/sdpa_mha_benchmark.py
@@ -5,9 +5,10 @@
 from old_multi_head_attention import MultiHeadAttention
 from nemo.utils import avoid_float16_autocast_context
 
+
 torch.manual_seed(123)
 
-device = "cpu"
+device = "cuda"
 batch_size = 32
 seq_len = 1024
 d_model = 512
@@ -18,7 +19,7 @@
 value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
 mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
 mask = torch.triu(mask, diagonal=1).bool()
-# mask = None
+mask = None
 
 attention_sdpa = SDPAMultiHeadAttention(n_head, d_model, 0.0).to(device)
 attention_original = MultiHeadAttention(n_head, d_model, 0.0).to(device)
@@ -27,50 +28,57 @@
 
 # attention_sdpa = torch.compile(attention_sdpa)
 # attention_original = torch.compile(attention_original)
 
+
 def measure_time(attention, query, key, value, mask):
     with torch.no_grad():
         timer = benchmark.Timer(
-            stmt='attention(query, key, value, mask)',
+            stmt='attention(query, key, value, mask);torch.cuda.synchronize()',
             setup='torch.cuda.synchronize()',
             globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask}
         )
     torch.cuda.synchronize()
-    with torch.no_grad(), torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient = False, enable_flash = True):
+    with torch.no_grad():
         results = timer.blocked_autorange(min_run_time=10)
         forward_time = results.mean
        output = attention(query, key, value, mask)
     return forward_time, output
 
-def measure_backward_time(attention, query, key, value, mask):
+
+def measure_fwd_bwd_time(attention, query, key, value, mask):
     timer = benchmark.Timer(
-        stmt='loss.backward()',
-        setup='''
-torch.cuda.empty_cache()
-output = attention(query, key, value, mask)
-loss = output.sum()
-torch.cuda.synchronize()
-''',
+        stmt='loss=attention(query, key, value, mask).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
+        setup='torch.cuda.synchronize()',
         globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask}
     )
     torch.cuda.synchronize()
     results = timer.blocked_autorange(min_run_time=10)
-    backward_time = results.mean
-    return backward_time
+    fwd_bwd_time = results.mean
+    return fwd_bwd_time
+
+time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask)
+time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask)
 
-# time_original, output_original = measure_time(attention_original, query, key, value, mask)
-# time_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask)
+print(f"Original implementation time: {time_fwd_original:.6f} seconds")
+print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
+print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")
 
-# print(f"Original implementation time: {time_original:.6f} seconds")
-# print(f"SDPA implementation time: {time_sdpa:.6f} seconds")
-# print(f"SDPA boost {(time_original - time_sdpa) / time_original * 100:.3f}%")
+time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask)
+time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask)
+time_bwd_original = time_fwd_bwd_original - time_fwd_original
+time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa
 
-time_backward_original = measure_backward_time(attention_original, query, key, value, mask)
-time_backward_sdpa = measure_backward_time(attention_sdpa, query, key, value, mask)
+print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
+print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
+print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")
 
-print(f"Original implementation backward time: {time_backward_original:.6f} seconds")
-print(f"SDPA implementation backward time: {time_backward_sdpa:.6f} seconds")
-print(f"SDPA backward boost {(time_backward_original - time_backward_sdpa) / time_backward_original * 100:.3f}%")
+print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
 
-# print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
+# Original implementation time: 0.017988 seconds
+# SDPA implementation time: 0.012850 seconds
+# SDPA boost 28.562%
+# Original implementation backward time: 0.034470 seconds
+# SDPA implementation backward time: 0.038126 seconds
+# SDPA backward boost -10.608%
+# Outputs are the same
\ No newline at end of file
diff --git a/sdpa_testing/sdpa_relpos_mha_benchmark.py b/sdpa_testing/sdpa_relpos_mha_benchmark.py
index 9e83f7033201..f732fc7e72ed 100644
--- a/sdpa_testing/sdpa_relpos_mha_benchmark.py
+++ b/sdpa_testing/sdpa_relpos_mha_benchmark.py
@@ -18,65 +18,67 @@
 value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
 mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
 mask = torch.triu(mask, diagonal=1).bool()
-# mask = None
+mask = None
 pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
 
 attention_sdpa = SDPARelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
 attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
 for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()):
     original_param.data.copy_(sdpa_param.data)
-attention_sdpa = torch.compile(attention_sdpa)
-attention_original = torch.compile(attention_original)
+
+# attention_sdpa = torch.compile(attention_sdpa)
+# attention_original = torch.compile(attention_original)
+
 
 def measure_time(attention, query, key, value, mask, pos_emb):
     with torch.no_grad():
         timer = benchmark.Timer(
-            stmt='attention(query, key, value, mask, pos_emb)',
+            stmt='attention(query, key, value, mask, pos_emb);torch.cuda.synchronize()',
             setup='torch.cuda.synchronize()',
            globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
         )
-        torch.cuda.synchronize()
-        results = timer.blocked_autorange(min_run_time=10)
-        forward_time = results.mean
-        output = attention(query, key, value, mask, pos_emb)
+
+    with torch.no_grad():
+        results = timer.blocked_autorange(min_run_time=10)
+        forward_time = results.mean
+        output = attention(query, key, value, mask, pos_emb)
     return forward_time, output
 
-def measure_backward_time(attention, query, key, value, mask, pos_emb):
+
+def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb):
     timer = benchmark.Timer(
-        stmt='loss.backward()',
-        setup='''
-output = attention(query, key, value, mask, pos_emb)
-loss = output.sum()
-torch.cuda.synchronize()
-''',
-        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
+        stmt='loss=attention(query, key, value, mask, pos_emb).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
+        setup='torch.cuda.synchronize()',
+        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
     )
     torch.cuda.synchronize()
     results = timer.blocked_autorange(min_run_time=10)
-    backward_time = results.mean
-    return backward_time
-
+    fwd_bwd_time = results.mean
+    return fwd_bwd_time
 
-time_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
-time_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)
+time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
+time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)
 
-print(f"Original implementation time: {time_original:.6f} seconds")
-print(f"SDPA implementation time: {time_sdpa:.6f} seconds")
-print(f"SDPA boost {(time_original - time_sdpa) / time_original * 100:.3f}%")
+print(f"Original implementation time: {time_fwd_original:.6f} seconds")
+print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
+print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")
 
-time_backward_original = measure_backward_time(attention_original, query, key, value, mask, pos_emb)
-time_backward_sdpa = measure_backward_time(attention_sdpa, query, key, value, mask, pos_emb)
+time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb)
+time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb)
+time_bwd_original = time_fwd_bwd_original - time_fwd_original
+time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa
 
-print(f"Original implementation backward time: {time_backward_original:.6f} seconds")
-print(f"SDPA implementation backward time: {time_backward_sdpa:.6f} seconds")
-print(f"SDPA backward boost {(time_backward_original - time_backward_sdpa) / time_backward_original * 100:.3f}%")
+print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
+print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
+print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")
 
 print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
-# Original implementation time: 0.042381 seconds
-# SDPA implementation time: 0.028353 seconds
-# SDPA boost 33.099%
-# Original implementation backward time: 0.080170 seconds
-# SDPA implementation backward time: 0.083670 seconds
-# SDPA backward boost -4.365%
+
+# Original implementation time: 0.031546 seconds
+# SDPA implementation time: 0.025657 seconds
+# SDPA boost 18.67%
+# Original implementation backward time: 0.069962 seconds
+# SDPA implementation backward time: 0.078320 seconds
+# SDPA backward boost -11.95%
 # Outputs are the same
\ No newline at end of file
diff --git a/test.py b/test.py
deleted file mode 100644
index f7ed76478019..000000000000
--- a/test.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.utils.benchmark as benchmark
-from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention as NewRelPositionMultiHeadAttention
-from old_multi_head_attention import RelPositionMultiHeadAttention
-from nemo.utils import avoid_float16_autocast_context
-
-torch.manual_seed(123)
-
-device = "cuda"
-batch_size = 32
-seq_len = 1024
-d_model = 512
-n_head = 8
-
-query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
-key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
-value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
-mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
-mask = torch.triu(mask, diagonal=1).bool()
-# mask = None
-pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
-
-# torch.manual_seed(123)
-attention_sdpa = NewRelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
-# torch.manual_seed(123)
-attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
-
-def measure_time(attention, query, key, value, mask, pos_emb):
-    with torch.no_grad():
-        timer = benchmark.Timer(
-            stmt='attention(query, key, value, mask, pos_emb)',
-            setup='torch.cuda.synchronize()',
-            globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
-        )
-        torch.cuda.synchronize()
-        results = timer.blocked_autorange(min_run_time=10)
-        forward_time = results.mean
-        output = attention(query, key, value, mask, pos_emb)
-    return forward_time, output
-
-def measure_backward_time(attention, query, key, value, mask, pos_emb):
-    timer = benchmark.Timer(
-        stmt='loss.backward()',
-        setup='''
-query.grad = None
-key.grad = None
-value.grad = None
-pos_emb.grad = None
-output = attention(query, key, value, mask, pos_emb)
-loss = output.sum()
-torch.cuda.synchronize()
-''',
-        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
-    )
-    torch.cuda.synchronize()
-    results = timer.blocked_autorange(min_run_time=10)
-    backward_time = results.mean
-    return backward_time
-
-
-time_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
-time_backward_original = measure_backward_time(attention_original, query, key, value, mask, pos_emb)
-
-time_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)
-time_backward_sdpa = measure_backward_time(attention_sdpa, query, key, value, mask, pos_emb)
-
-print(f"Original implementation time: {time_original:.6f} seconds")
-print(f"SDPA implementation time: {time_sdpa:.6f} seconds")
-print(f"SDPA boost {(time_original - time_sdpa) / time_original * 100:.3f}%")
-
-print(f"Original implementation backward time: {time_backward_original:.6f} seconds")
-print(f"SDPA implementation backward time: {time_backward_sdpa:.6f} seconds")
-print(f"SDPA backward boost {(time_backward_original - time_backward_sdpa) / time_backward_original * 100:.3f}%")
-
-print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
-# Original implementation time: 0.042344 seconds
-# SDPA implementation time: 0.028285 seconds
-# SDPA boost 33.202%
-# Original implementation backward time: 0.079229 seconds
-# SDPA implementation backward time: 0.082796 seconds
-# SDPA backward boost -4.502%
-# Outputs are the same
\ No newline at end of file
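
Note on the SDPA forward path: below is a minimal standalone sketch of the pattern the multi_head_attention.py hunk above switches to. The class and parameter names are illustrative, not NeMo's; it assumes (batch, time, d_model) inputs, a boolean mask that is True at masked positions (as in the benchmark scripts), and PyTorch >= 2.1 for the scale argument. Unlike the patch, the sketch only applies dropout while training, since the functional API takes a raw probability and does not consult module state.

import math

import torch
import torch.nn as nn


class SDPASelfAttentionSketch(nn.Module):
    # Illustrative multi-head attention built on torch.nn.functional.scaled_dot_product_attention.
    def __init__(self, n_head, n_feat, dropout_rate):
        super().__init__()
        assert n_feat % n_head == 0
        self.h = n_head
        self.d_k = n_feat // n_head
        self.dropout_rate = dropout_rate
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, query, key, value, mask=None):
        # query/key/value: (batch, time, n_feat); mask: (batch, time1, time2), True = masked position
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (b, h, t1, d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)    # (b, h, t2, d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)  # (b, h, t2, d_k)

        if mask is not None:
            # SDPA's boolean attn_mask means True = "may attend"; the masks above mean True = "masked",
            # hence the broadcast over heads plus logical_not, mirroring the hunk in the patch.
            mask = mask.unsqueeze(1).logical_not()

        # Assumption: apply dropout only in training mode; F.scaled_dot_product_attention does not check self.training.
        dropout_p = self.dropout_rate if self.training else 0.0
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=dropout_p, scale=1.0 / math.sqrt(self.d_k)
        )
        out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.linear_out(out)  # (batch, time1, d_model)

Passing scale=1/sqrt(d_k) matches SDPA's default scaling, so it mainly documents intent; omitting the argument would behave the same here.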
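Timing methodology used by the two benchmark scripts, as a sketch: forward and forward+backward are timed separately with torch.utils.benchmark, and backward time is estimated as their difference. torch.cuda.synchronize() sits inside stmt so queued asynchronous CUDA work is actually counted. A CUDA device is assumed and the helper names below are illustrative, not the scripts' own.

import torch
import torch.utils.benchmark as benchmark


def time_forward(attention, query, key, value, mask, min_run_time=10.0):
    # Forward-only mean runtime in seconds; sync inside stmt so async kernel launches are measured.
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask); torch.cuda.synchronize()',
        setup='torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask},
    )
    with torch.no_grad():
        return timer.blocked_autorange(min_run_time=min_run_time).mean


def time_forward_backward(attention, query, key, value, mask, min_run_time=10.0):
    # Forward + backward mean runtime; backward time is later estimated as (fwd+bwd) - fwd.
    timer = benchmark.Timer(
        stmt='loss = attention(query, key, value, mask).sum(); torch.cuda.synchronize(); '
             'loss.backward(); torch.cuda.synchronize()',
        setup='torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask},
    )
    torch.cuda.synchronize()
    return timer.blocked_autorange(min_run_time=min_run_time).mean


# Hypothetical usage (CUDA required):
#   fwd = time_forward(module, q, k, v, None)
#   fwd_bwd = time_forward_backward(module, q, k, v, None)
#   print(f"forward {fwd:.6f}s, backward ~ {fwd_bwd - fwd:.6f}s")

Gradients accumulate into the leaf tensors across repetitions of the forward+backward statement; that is harmless for timing, but zeroing .grad in setup would keep the measured work identical to a single training step.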