Commit: sdpa work
Signed-off-by: WoodieDudy <goshagks@gmail.com>
WoodieDudy committed Aug 21, 2024
1 parent d282a59 commit 8d292d6
Showing 6 changed files with 93 additions and 7 deletions.
@@ -136,6 +136,7 @@ model:
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
use_pytorch_sdpa: true # use torch sdpa instead of manual attention

# Convolution module's params
conv_kernel_size: 9
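
The use_pytorch_sdpa flag added above switches the attention modules between the manual softmax(QK^T / sqrt(d_k)) V computation and torch.nn.functional.scaled_dot_product_attention. A minimal standalone sketch (not part of this commit, shapes illustrative) of why the two paths agree when no mask or dropout is involved; SDPA's default scale is the same 1/sqrt(d_k), which appears to be why the explicit scale argument is dropped later in this diff:

import math
import torch
import torch.nn.functional as F

# (batch, heads, time, d_k) tensors, values purely illustrative
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

# manual attention: softmax(q @ k^T / sqrt(d_k)) @ v
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
manual = torch.softmax(scores, dim=-1) @ v

# fused PyTorch kernel; its default scale is also 1 / sqrt(d_k)
fused = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(manual, fused, atol=1e-5)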
@@ -145,7 +145,7 @@ model:
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
-use_pytorch_sdpa: true
+use_pytorch_sdpa: true # use torch sdpa instead of manual attention

# Convolution module's params
conv_kernel_size: 9
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/modules/conformer_encoder.py
@@ -1044,7 +1044,7 @@ def change_attention_model(
att_context_size=att_context_size,
pos_bias_u=None,
pos_bias_v=None,
-# use_pytorch_sdpa=self.use_pytorch_sdpa,
+use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'abs_pos':
new_attn = MultiHeadAttention(
@@ -134,8 +134,9 @@ def __init__(
dropout_rate: float,
proj_dim: Optional[int] = None,
adapter_strategy: MHAResidualAddAdapterStrategy = None,
use_pytorch_sdpa: bool = True,
):
-super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0)
+super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa)

self.pre_norm = nn.LayerNorm(n_feat)

@@ -200,6 +201,7 @@ class MultiHeadAttentionAdapterConfig:
dropout_rate: float = 0.0
proj_dim: Optional[int] = None
adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig())
use_pytorch_sdpa: bool = True
_target_: str = "{0}.{1}".format(MultiHeadAttentionAdapter.__module__, MultiHeadAttentionAdapter.__name__)


@@ -225,9 +227,10 @@ def __init__(
dropout_rate: float,
proj_dim: Optional[int] = None,
adapter_strategy: MHAResidualAddAdapterStrategyConfig = None,
use_pytorch_sdpa: bool = False,
):
super().__init__(
-n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0
+n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa
)

self.pre_norm = nn.LayerNorm(n_feat)
@@ -305,6 +308,7 @@ class RelPositionMultiHeadAttentionAdapterConfig:
dropout_rate: float = 0.0
proj_dim: Optional[int] = None
adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig())
use_pytorch_sdpa: bool = True
_target_: str = "{0}.{1}".format(
RelPositionMultiHeadAttentionAdapter.__module__, RelPositionMultiHeadAttentionAdapter.__name__
)
34 changes: 31 additions & 3 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -145,15 +145,26 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
q, k, v = self.forward_qkv(query, key, value)

if self.use_pytorch_sdpa:
scale = 1 / self.s_d_k
n_batch = value.size(0)

if mask is not None:
mask = mask.unsqueeze(1)
# add extra column to the mask to avoid NaN after softmax when a row is fully masked
rows_all_false = torch.all(mask, dim=-1)
modified_tensor = torch.where(mask, torch.tensor(-10000.0), torch.tensor(0.0))
new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
mask = torch.cat([modified_tensor, new_column.unsqueeze(-1)], dim=-1).to(mask.device)

dropout_rate = self.dropout_rate if self.training else 0

# add extra zero column to key and value to avoid NaN after softmax
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(
-q, k, v, attn_mask=mask, dropout_p=dropout_rate, scale=scale
+q, k, v, attn_mask=mask, dropout_p=dropout_rate
)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
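
The extra-column logic above works around a known SDPA pitfall: when every key position for some query row is masked, all of that row's scores are effectively -inf, the softmax is undefined, and the output for that row can come back as NaN. A hypothetical standalone repro and fix, assuming a boolean mask where True means "masked" (the opposite of SDPA's convention, hence the inversion):

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 3, 8)
k = torch.randn(1, 1, 3, 8)
v = torch.randn(1, 1, 3, 8)

# True = masked; the first query row cannot attend to anything
mask = torch.tensor([[[[True, True, True],
                       [False, True, True],
                       [False, False, True]]]])

out = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask)
print(torch.isnan(out).any())  # typically True: the fully masked row becomes NaN

# fix: append one zero key/value column that only fully masked rows may attend to
k_ext = torch.cat([k, torch.zeros_like(k[..., :1, :])], dim=-2)
v_ext = torch.cat([v, torch.zeros_like(v[..., :1, :])], dim=-2)
rows_fully_masked = mask.all(dim=-1, keepdim=True)
attn_mask = torch.cat([~mask, rows_fully_masked], dim=-1)

out = F.scaled_dot_product_attention(q, k_ext, v_ext, attn_mask=attn_mask)
print(torch.isnan(out).any())  # False: fully masked rows now yield zeros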
@@ -275,10 +286,21 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, -10000.0)
# add extra column to the mask (matrix_bd) to avoid NaN after softmax when a row is fully masked
rows_all_false = torch.all(mask, dim=-1)
new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
new_column = new_column.repeat(1, self.h, 1)
matrix_bd = torch.cat([matrix_bd, new_column.unsqueeze(-1)], dim=-1).to(matrix_bd.device)

dropout_rate = self.dropout_rate if self.training else 0
# add extra zero column to key and value to avoid NaN after softmax
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(
-q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate, scale=scale_factor
+q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
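
For context, a minimal sketch (again not part of the commit) of why matrix_bd can be handed to SDPA as attn_mask: a floating-point mask is added to the scaled q @ k^T scores before the softmax, so a suitably scaled relative-position bias, with -10000.0 already filled in at masked positions, can be folded into that argument. The bias tensor below is a hypothetical stand-in for matrix_bd:

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 16)
k = torch.randn(1, 2, 5, 16)
v = torch.randn(1, 2, 5, 16)
bias = torch.randn(1, 2, 5, 5)  # stand-in for the scaled matrix_bd

# a float attn_mask acts as an additive bias on the scaled scores
manual = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + bias, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

assert torch.allclose(manual, fused, atol=1e-5)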
@@ -328,6 +350,7 @@ def __init__(
global_tokens_spacing=1,
global_attn_separate=False,
use_bias=True,
use_pytorch_sdpa=False,
):
"""Construct an RelPositionMultiHeadAttentionLongformer object."""
super().__init__(
@@ -338,7 +361,12 @@ def __init__(
pos_bias_v=pos_bias_v,
max_cache_len=max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=use_pytorch_sdpa,
)

if use_pytorch_sdpa:
raise NotImplementedError("Not implemented for Longformer yet")

self.att_context_size = att_context_size
self.global_tokens = global_tokens
self.global_tokens_spacing = global_tokens_spacing
53 changes: 53 additions & 0 deletions tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
@@ -178,6 +178,59 @@ def test_relmha_adapter_init(self, n_head, proj_dim):
assert out.sum().abs() <= 1e-8
assert out.shape == x.shape

@pytest.mark.unit
def test_relmha_adapter_with_torch_sdpa(self):
torch.random.manual_seed(0)
x = torch.randn(2, 32, 50)
lengths = torch.randint(1, x.size(1), size=(x.size(0),))
lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1)

adapter_torch_sdpa = adapter_modules.RelPositionMultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=True
)
adapter = adapter_modules.RelPositionMultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=False
)
# reinitialize linear_out so its parameters are not zero
adapter.linear_out = torch.nn.Linear(adapter.linear_out.in_features, adapter.linear_out.out_features)
for original_param, sdpa_param in zip(adapter.parameters(), adapter_torch_sdpa.parameters()):
sdpa_param.data.copy_(original_param.data)
relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)

pad_mask, att_mask = get_mask(lengths)
relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)

with torch.no_grad():
_, pos_emb = relpos_enc(x)
out = adapter(x, x, x, att_mask, pos_emb)
out_sdpa = adapter_torch_sdpa(x, x, x, att_mask, pos_emb)
assert torch.allclose(out_sdpa, out, atol=1e-5)

@pytest.mark.unit
def test_mha_adapter_with_torch_sdpa(self):
torch.random.manual_seed(0)
x = torch.randn(2, 32, 50)
lengths = torch.randint(1, x.size(1), size=(x.size(0),))
lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1)

adapter_torch_sdpa = adapter_modules.MultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=True
)
adapter = adapter_modules.MultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=False
)
# reinitialize linear_out so its parameters are not zero
adapter.linear_out = torch.nn.Linear(adapter.linear_out.in_features, adapter.linear_out.out_features)

for original_param, sdpa_param in zip(adapter.parameters(), adapter_torch_sdpa.parameters()):
sdpa_param.data.copy_(original_param.data)

pad_mask, att_mask = get_mask(lengths)
with torch.no_grad():
out = adapter(x, x, x, att_mask)
out_sdpa = adapter_torch_sdpa(x, x, x, att_mask)
assert torch.allclose(out_sdpa, out, atol=1e-5)

@pytest.mark.unit
def test_abspos_encoding_init(self):
torch.random.manual_seed(0)
