From 8d292d6ef218d5dc5f40878cde2052448e08aabb Mon Sep 17 00:00:00 2001
From: WoodieDudy
Date: Wed, 21 Aug 2024 21:23:55 +0500
Subject: [PATCH] sdpa work

Signed-off-by: WoodieDudy
---
 .../fastconformer/fast-conformer_ctc_bpe.yaml |  1 +
 .../fast-conformer_transducer_bpe.yaml        |  2 +-
 .../asr/modules/conformer_encoder.py          |  2 +-
 .../multi_head_attention_adapter_module.py    |  8 ++-
 .../parts/submodules/multi_head_attention.py  | 34 ++++++++++--
 .../adapters/test_asr_adapter_modules.py      | 53 +++++++++++++++++++
 6 files changed, 93 insertions(+), 7 deletions(-)

diff --git a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml
index 9b51edf614b8..3320a6c2ae77 100644
--- a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml
+++ b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml
@@ -136,6 +136,7 @@ model:
     xscaling: true # scales up the input embeddings by sqrt(d_model)
     untie_biases: true # unties the biases of the TransformerXL layers
     pos_emb_max_len: 5000
+    use_pytorch_sdpa: true # use torch sdpa instead of manual attention

     # Convolution module's params
     conv_kernel_size: 9
diff --git a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
index c6fa75ddce1c..a6d8e0e8599d 100644
--- a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
+++ b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
@@ -145,7 +145,7 @@ model:
     xscaling: true # scales up the input embeddings by sqrt(d_model)
     untie_biases: true # unties the biases of the TransformerXL layers
     pos_emb_max_len: 5000
-    use_pytorch_sdpa: true
+    use_pytorch_sdpa: true # use torch sdpa instead of manual attention

     # Convolution module's params
     conv_kernel_size: 9
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index ea887918b143..7934e2afce9c 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -1044,7 +1044,7 @@ def change_attention_model(
                         att_context_size=att_context_size,
                         pos_bias_u=None,
                         pos_bias_v=None,
-                        # use_pytorch_sdpa=self.use_pytorch_sdpa,
+                        use_pytorch_sdpa=self.use_pytorch_sdpa,
                     )
                 elif self_attention_model == 'abs_pos':
                     new_attn = MultiHeadAttention(
diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
index 2617ed6f575b..2bfdaacf9dad 100644
--- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
+++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py
@@ -134,8 +134,9 @@ def __init__(
         dropout_rate: float,
         proj_dim: Optional[int] = None,
         adapter_strategy: MHAResidualAddAdapterStrategy = None,
+        use_pytorch_sdpa: bool = True,
     ):
-        super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0)
+        super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa)

         self.pre_norm = nn.LayerNorm(n_feat)

@@ -200,6 +201,7 @@ class MultiHeadAttentionAdapterConfig:
     dropout_rate: float = 0.0
     proj_dim: Optional[int] = None
     adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig())
+    use_pytorch_sdpa: bool = True
     _target_: str = "{0}.{1}".format(MultiHeadAttentionAdapter.__module__, MultiHeadAttentionAdapter.__name__)


@@ -225,9 +227,10 @@ def __init__(
         dropout_rate: float,
         proj_dim: Optional[int] = None,
         adapter_strategy: MHAResidualAddAdapterStrategyConfig = None,
+        use_pytorch_sdpa: bool = False,
     ):
         super().__init__(
-            n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0
+            n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa
         )

         self.pre_norm = nn.LayerNorm(n_feat)

@@ -305,6 +308,7 @@ class RelPositionMultiHeadAttentionAdapterConfig:
     dropout_rate: float = 0.0
     proj_dim: Optional[int] = None
     adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig())
+    use_pytorch_sdpa: bool = True
     _target_: str = "{0}.{1}".format(
         RelPositionMultiHeadAttentionAdapter.__module__, RelPositionMultiHeadAttentionAdapter.__name__
     )
diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py
index e873eca78468..f359ea07c1c1 100644
--- a/nemo/collections/asr/parts/submodules/multi_head_attention.py
+++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -145,15 +145,26 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
             q, k, v = self.forward_qkv(query, key, value)

             if self.use_pytorch_sdpa:
-                scale = 1 / self.s_d_k
                 n_batch = value.size(0)

                 if mask is not None:
                     mask = mask.unsqueeze(1)
+                    # add an extra column to the mask to avoid NaNs after softmax when a row is fully masked
+                    rows_all_false = torch.all(mask, dim=-1)
+                    modified_tensor = torch.where(mask, torch.tensor(-10000.0), torch.tensor(0.0))
+                    new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
+                    mask = torch.cat([modified_tensor, new_column.unsqueeze(-1)], dim=-1).to(mask.device)

                 dropout_rate = self.dropout_rate if self.training else 0
+
+                # add an extra zero column to key and value to match the extra mask column
+                extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
+                k = torch.cat([k, extra_column], dim=-2)
+
+                extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
+                v = torch.cat([v, extra_column], dim=-2)
                 out = torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=dropout_rate, scale=scale
+                    q, k, v, attn_mask=mask, dropout_p=dropout_rate
                 )
                 out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
                 out = self.linear_out(out)  # (batch, time1, d_model)
@@ -275,10 +286,21 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
                 if mask is not None:
                     mask = mask.unsqueeze(1)
                     matrix_bd.masked_fill_(mask, -10000.0)
+                    # add an extra column to the mask (matrix_bd) to avoid NaNs after softmax when a row is fully masked
+                    rows_all_false = torch.all(mask, dim=-1)
+                    new_column = torch.where(rows_all_false, torch.tensor(10000.0), torch.tensor(-10000.0))
+                    new_column = new_column.repeat(1, self.h, 1)
+                    matrix_bd = torch.cat([matrix_bd, new_column.unsqueeze(-1)], dim=-1).to(matrix_bd.device)

                 dropout_rate = self.dropout_rate if self.training else 0
+                # add an extra zero column to key and value to match the extra mask column
+                extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
+                k = torch.cat([k, extra_column], dim=-2)
+
+                extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
+                v = torch.cat([v, extra_column], dim=-2)
                 out = torch.nn.functional.scaled_dot_product_attention(
-                    q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate, scale=scale_factor
+                    q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
                 )
                 out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
                 out = self.linear_out(out)  # (batch, time1, d_model)
@@ -328,6 +350,7 @@ def __init__(
         global_tokens_spacing=1,
         global_attn_separate=False,
         use_bias=True,
+        use_pytorch_sdpa=False,
     ):
         """Construct an RelPositionMultiHeadAttentionLongformer object."""
         super().__init__(
@@ -338,7 +361,12 @@ def __init__(
             pos_bias_v=pos_bias_v,
             max_cache_len=max_cache_len,
             use_bias=use_bias,
+            use_pytorch_sdpa=use_pytorch_sdpa,
         )
+
+        if use_pytorch_sdpa:
+            raise NotImplementedError("Not implemented for Longformer yet")
+
         self.att_context_size = att_context_size
         self.global_tokens = global_tokens
         self.global_tokens_spacing = global_tokens_spacing
diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
index ffaf1e640f3e..f15892707201 100644
--- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
+++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
@@ -178,6 +178,59 @@ def test_relmha_adapter_init(self, n_head, proj_dim):
         assert out.sum().abs() <= 1e-8
         assert out.shape == x.shape

+    @pytest.mark.unit
+    def test_relmha_adapter_with_torch_sdpa(self):
+        torch.random.manual_seed(0)
+        x = torch.randn(2, 32, 50)
+        lengths = torch.randint(1, x.size(1), size=(x.size(0),))
+        lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1)
+
+        adapter_torch_sdpa = adapter_modules.RelPositionMultiHeadAttentionAdapter(
+            n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=True
+        )
+        adapter = adapter_modules.RelPositionMultiHeadAttentionAdapter(
+            n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=False
+        )
+        # re-create linear_out so its parameters are not zero-initialized
+        adapter.linear_out = torch.nn.Linear(adapter.linear_out.in_features, adapter.linear_out.out_features)
+        for original_param, sdpa_param in zip(adapter.parameters(), adapter_torch_sdpa.parameters()):
+            sdpa_param.data.copy_(original_param.data)
+        relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)
+
+        pad_mask, att_mask = get_mask(lengths)
+        relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)
+
+        with torch.no_grad():
+            _, pos_emb = relpos_enc(x)
+            out = adapter(x, x, x, att_mask, pos_emb)
+            out_sdpa = adapter_torch_sdpa(x, x, x, att_mask, pos_emb)
+            assert torch.allclose(out_sdpa, out, atol=1e-5)
+
+    @pytest.mark.unit
+    def test_mha_adapter_with_torch_sdpa(self):
+        torch.random.manual_seed(0)
+        x = torch.randn(2, 32, 50)
+        lengths = torch.randint(1, x.size(1), size=(x.size(0),))
+        lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1)
+
+        adapter_torch_sdpa = adapter_modules.MultiHeadAttentionAdapter(
+            n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=True
+        )
+        adapter = adapter_modules.MultiHeadAttentionAdapter(
+            n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=False
+        )
+        # re-create linear_out so its parameters are not zero-initialized
+        adapter.linear_out = torch.nn.Linear(adapter.linear_out.in_features, adapter.linear_out.out_features)
+
+        for original_param, sdpa_param in zip(adapter.parameters(), adapter_torch_sdpa.parameters()):
+            sdpa_param.data.copy_(original_param.data)
+
+        pad_mask, att_mask = get_mask(lengths)
+        with torch.no_grad():
+            out = adapter(x, x, x, att_mask)
+            out_sdpa = adapter_torch_sdpa(x, x, x, att_mask)
+            assert torch.allclose(out_sdpa, out, atol=1e-5)
+
     @pytest.mark.unit
     def test_abspos_encoding_init(self):
         torch.random.manual_seed(0)
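
Note (not part of the patch): below is a minimal, self-contained sketch of the extra-column trick the patch applies before calling torch.nn.functional.scaled_dot_product_attention. A query row whose keys are all masked makes every score in that row -inf, so softmax returns NaN; appending a dummy key/value column of zeros that only fully masked rows may attend to keeps those rows finite (and approximately zero). Variable names such as float_mask, k_ext, and v_ext, as well as the tensor shapes, are illustrative only and do not appear in the patch.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, n_head, t, d_k = 2, 2, 4, 8
q = torch.randn(n_batch, n_head, t, d_k)
k = torch.randn(n_batch, n_head, t, d_k)
v = torch.randn(n_batch, n_head, t, d_k)

# Boolean padding mask, True = position is masked out (same convention the patch uses).
# Query row 2 of the first batch element has every key masked.
mask = torch.zeros(n_batch, 1, t, t, dtype=torch.bool)
mask[0, :, 2, :] = True

# Passing the inverted boolean mask directly: the fully masked row has no valid key,
# so every score becomes -inf and softmax produces NaN for that row.
naive = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask)
print(torch.isnan(naive).any())  # tensor(True)

# The patch's workaround: build an additive float mask and append one dummy column
# that only fully masked rows are allowed to attend to.
rows_all_masked = torch.all(mask, dim=-1)  # (n_batch, 1, t)
float_mask = torch.where(mask, torch.tensor(-10000.0), torch.tensor(0.0))
extra_col = torch.where(rows_all_masked, torch.tensor(10000.0), torch.tensor(-10000.0))
float_mask = torch.cat([float_mask, extra_col.unsqueeze(-1)], dim=-1)  # (n_batch, 1, t, t + 1)

# Zero-valued dummy key/value so the extra column contributes nothing to real rows.
zeros = torch.zeros(k.shape[:-2] + (1, d_k))
k_ext = torch.cat([k, zeros], dim=-2)
v_ext = torch.cat([v, zeros], dim=-2)

out = F.scaled_dot_product_attention(q, k_ext, v_ext, attn_mask=float_mask)
print(torch.isnan(out).any())    # tensor(False)
print(out[0, :, 2].abs().max())  # ~0: fully masked rows collapse onto the zero dummy value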