
Commit

benchmark backward
Signed-off-by: WoodieDudy <goshagks@yandex.ru>
WoodieDudy authored and kramarenko.gs committed Jul 18, 2024
1 parent 230913b commit fafcd0a
Showing 5 changed files with 77 additions and 252 deletions.
9 changes: 7 additions & 2 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -60,6 +60,7 @@ class MultiHeadAttention(nn.Module):
def __init__(self, n_head, n_feat, dropout_rate, max_cache_len=0):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadAttention, self).__init__()
self.dropout_rate = dropout_rate
self.cache_drop_size = None
assert n_feat % n_head == 0
# We assume d_v always equals d_k
@@ -132,14 +133,18 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
cache (torch.Tensor) : (batch, time_cache_next, size)
"""
key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
n_batch = value.size(0)

if torch.is_autocast_enabled():
query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)

# temporary until we solve this more gracefully

with avoid_float16_autocast_context():
q, k, v = self.forward_qkv(query, key, value)
scale = 1 / self.s_d_k

if mask is not None:
mask = mask.unsqueeze(1).logical_not()

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout_rate, scale=scale)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
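For reference, a minimal standalone sketch of the `scaled_dot_product_attention` call pattern the hunk above switches to. The shapes, the toy causal mask, and the explicit `1 / sqrt(d_k)` scale are illustrative assumptions, not code from the repository:

```python
import torch
import torch.nn.functional as F

# Toy shapes (batch, heads, time, d_k); values are only for illustration.
B, H, T, DK = 2, 4, 16, 64
q = torch.rand(B, H, T, DK)
k = torch.rand(B, H, T, DK)
v = torch.rand(B, H, T, DK)

# Boolean mask of shape (batch, time1, time2) where True marks positions to
# drop (the masked_fill convention). SDPA keeps True positions, so the mask
# is inverted and broadcast across heads, as in the hunk above.
mask = torch.triu(torch.ones(B, T, T, dtype=torch.bool), diagonal=1)
sdpa_mask = mask.unsqueeze(1).logical_not()

scale = 1 / DK ** 0.5  # plays the role of 1 / self.s_d_k
out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask, dropout_p=0.0, scale=scale)
out = out.transpose(1, 2).reshape(B, T, H * DK)  # (batch, time1, d_model)
print(out.shape)
```

SDPA's boolean `attn_mask` keeps positions marked `True`, which is why the diff applies `logical_not()` to the mask before the call.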
107 changes: 0 additions & 107 deletions old_multi_head_attention.py

This file was deleted.

58 changes: 33 additions & 25 deletions sdpa_testing/sdpa_mha_benchmark.py
@@ -5,9 +5,10 @@
from old_multi_head_attention import MultiHeadAttention
from nemo.utils import avoid_float16_autocast_context


torch.manual_seed(123)

device = "cpu"
device = "cuda"
batch_size = 32
seq_len = 1024
d_model = 512
@@ -18,7 +19,7 @@
value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
mask = torch.triu(mask, diagonal=1).bool()
# mask = None
mask = None

attention_sdpa = SDPAMultiHeadAttention(n_head, d_model, 0.0).to(device)
attention_original = MultiHeadAttention(n_head, d_model, 0.0).to(device)
@@ -27,50 +28,57 @@
# attention_sdpa = torch.compile(attention_sdpa)
# attention_original = torch.compile(attention_original)


def measure_time(attention, query, key, value, mask):
with torch.no_grad():
timer = benchmark.Timer(
stmt='attention(query, key, value, mask)',
stmt='attention(query, key, value, mask);torch.cuda.synchronize()',
setup='torch.cuda.synchronize()',
globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask}
)
torch.cuda.synchronize()

with torch.no_grad(), torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient = False, enable_flash = True):
with torch.no_grad():
results = timer.blocked_autorange(min_run_time=10)
forward_time = results.mean
output = attention(query, key, value, mask)
return forward_time, output

def measure_backward_time(attention, query, key, value, mask):

def measure_fwd_bwd_time(attention, query, key, value, mask):
timer = benchmark.Timer(
stmt='loss.backward()',
setup='''
torch.cuda.empty_cache()
output = attention(query, key, value, mask)
loss = output.sum()
torch.cuda.synchronize()
''',
stmt='loss=attention(query, key, value, mask).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
setup='torch.cuda.synchronize()',
globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask}
)
torch.cuda.synchronize()
results = timer.blocked_autorange(min_run_time=10)
backward_time = results.mean
return backward_time
fwd_bwd_time = results.mean
return fwd_bwd_time


time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask)
time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask)

# time_original, output_original = measure_time(attention_original, query, key, value, mask)
# time_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask)
print(f"Original implementation time: {time_fwd_original:.6f} seconds")
print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

# print(f"Original implementation time: {time_original:.6f} seconds")
# print(f"SDPA implementation time: {time_sdpa:.6f} seconds")
# print(f"SDPA boost {(time_original - time_sdpa) / time_original * 100:.3f}%")
time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask)
time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask)
time_bwd_original = time_fwd_bwd_original - time_fwd_original
time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa

time_backward_original = measure_backward_time(attention_original, query, key, value, mask)
time_backward_sdpa = measure_backward_time(attention_sdpa, query, key, value, mask)
print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")

print(f"Original implementation backward time: {time_backward_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_backward_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_backward_original - time_backward_sdpa) / time_backward_original * 100:.3f}%")
print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")

# print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
# Original implementation time: 0.017988 seconds
# SDPA implementation time: 0.012850 seconds
# SDPA boost 28.562%
# Original implementation backward time: 0.034470 seconds
# SDPA implementation backward time: 0.038126 seconds
# SDPA backward boost -10.608%
# Outputs are the same
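
The backward numbers above are not timed directly: the script times forward-only and forward-plus-backward separately and reports the difference. Below is a minimal sketch of that pattern with a stand-in module; the `torch.nn.Linear` placeholder, the shapes, and the `min_run_time` value are assumptions, and a CUDA device is assumed as in the benchmark:

```python
import torch
import torch.utils.benchmark as benchmark

# Assumes a CUDA device, as in the benchmark above. The module and shapes are
# placeholders; any nn.Module whose forward returns a tensor would do.
device = "cuda"
module = torch.nn.Linear(512, 512).to(device)
x = torch.rand(32, 1024, 512, device=device, requires_grad=True)

def mean_time(stmt):
    timer = benchmark.Timer(
        stmt=stmt,
        setup="torch.cuda.synchronize()",
        globals={"module": module, "x": x, "torch": torch},
    )
    return timer.blocked_autorange(min_run_time=5).mean

# Forward only, under no_grad (mirrors measure_time above).
with torch.no_grad():
    t_fwd = mean_time("module(x); torch.cuda.synchronize()")

# Forward + backward in one statement; each run rebuilds a fresh graph.
t_fwd_bwd = mean_time(
    "loss = module(x).sum(); torch.cuda.synchronize(); loss.backward(); torch.cuda.synchronize()"
)

# Backward time is estimated as the difference of the two measurements.
t_bwd = t_fwd_bwd - t_fwd
print(f"forward {t_fwd:.6f}s  backward (estimated) {t_bwd:.6f}s")
```

Because the forward-only pass runs under `torch.no_grad()` while the combined pass also builds the autograd graph, the subtraction is an estimate of backward time rather than a direct measurement.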
72 changes: 37 additions & 35 deletions sdpa_testing/sdpa_relpos_mha_benchmark.py
@@ -18,65 +18,67 @@
value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
mask = torch.triu(mask, diagonal=1).bool()
# mask = None
mask = None
pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)

attention_sdpa = SDPARelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None).to(device)
for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()):
original_param.data.copy_(sdpa_param.data)
attention_sdpa = torch.compile(attention_sdpa)
attention_original = torch.compile(attention_original)

# attention_sdpa = torch.compile(attention_sdpa)
# attention_original = torch.compile(attention_original)


def measure_time(attention, query, key, value, mask, pos_emb):
with torch.no_grad():
timer = benchmark.Timer(
stmt='attention(query, key, value, mask, pos_emb)',
stmt='attention(query, key, value, mask, pos_emb);torch.cuda.synchronize()',
setup='torch.cuda.synchronize()',
globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
)
torch.cuda.synchronize()
results = timer.blocked_autorange(min_run_time=10)
forward_time = results.mean
output = attention(query, key, value, mask, pos_emb)

with torch.no_grad():
results = timer.blocked_autorange(min_run_time=10)
forward_time = results.mean
output = attention(query, key, value, mask, pos_emb)
return forward_time, output

def measure_backward_time(attention, query, key, value, mask, pos_emb):

def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb):
timer = benchmark.Timer(
stmt='loss.backward()',
setup='''
output = attention(query, key, value, mask, pos_emb)
loss = output.sum()
torch.cuda.synchronize()
''',
globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
stmt='loss=attention(query, key, value, mask, pos_emb).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
setup='torch.cuda.synchronize()',
globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
)
torch.cuda.synchronize()
results = timer.blocked_autorange(min_run_time=10)
backward_time = results.mean
return backward_time

fwd_bwd_time = results.mean
return fwd_bwd_time

time_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)

time_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)
time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)

print(f"Original implementation time: {time_original:.6f} seconds")
print(f"SDPA implementation time: {time_sdpa:.6f} seconds")
print(f"SDPA boost {(time_original - time_sdpa) / time_original * 100:.3f}%")
print(f"Original implementation time: {time_fwd_original:.6f} seconds")
print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

time_backward_original = measure_backward_time(attention_original, query, key, value, mask, pos_emb)
time_backward_sdpa = measure_backward_time(attention_sdpa, query, key, value, mask, pos_emb)
time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb)
time_bwd_original = time_fwd_bwd_original - time_fwd_original
time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa

print(f"Original implementation backward time: {time_backward_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_backward_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_backward_original - time_backward_sdpa) / time_backward_original * 100:.3f}%")
print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")

print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")
# Original implementation time: 0.042381 seconds
# SDPA implementation time: 0.028353 seconds
# SDPA boost 33.099%
# Original implementation backward time: 0.080170 seconds
# SDPA implementation backward time: 0.083670 seconds
# SDPA backward boost -4.365%

# Original implementation time: 0.031546 seconds
# SDPA implementation time: 0.025657 seconds
# SDPA boost 18.67%
# Original implementation backward time: 0.069962 seconds
# SDPA implementation backward time: 0.078320 seconds
# SDPA backward boost -11.95%
# Outputs are the same
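
For context on the relative-position variant benchmarked above: `F.scaled_dot_product_attention` also accepts a floating-point `attn_mask` that is added to the scaled attention logits, which is one way a positional bias term can be folded into an SDPA call. The sketch below only demonstrates that documented API behaviour; it is not the NeMo `SDPARelPositionMultiHeadAttention` implementation, and the bias tensor here is random:

```python
import torch
import torch.nn.functional as F

B, H, T, DK = 2, 4, 16, 64
q, k, v = (torch.rand(B, H, T, DK) for _ in range(3))

# A float attn_mask is added to q @ k^T / sqrt(d_k) before the softmax, so any
# precomputed bias broadcastable to (B, H, T, T) can be injected this way.
pos_bias = torch.rand(B, H, T, T)  # illustrative positional bias term
out = F.scaled_dot_product_attention(q, k, v, attn_mask=pos_bias)

# Reference computation for comparison.
scores = q @ k.transpose(-2, -1) / DK ** 0.5 + pos_bias
ref = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))
```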
