Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
  • Loading branch information
WoodieDudy committed Aug 21, 2024
1 parent 2be301e commit b5930a7
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def __init__(
adapter_strategy: MHAResidualAddAdapterStrategy = None,
use_pytorch_sdpa: bool = True,
):
super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa)
super().__init__(
n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa
)

self.pre_norm = nn.LayerNorm(n_feat)

Expand Down Expand Up @@ -230,7 +232,13 @@ def __init__(
use_pytorch_sdpa: bool = False,
):
super().__init__(
n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa
n_head=n_head,
n_feat=n_feat,
dropout_rate=dropout_rate,
pos_bias_u=None,
pos_bias_v=None,
max_cache_len=0,
use_pytorch_sdpa=use_pytorch_sdpa,
)

self.pre_norm = nn.LayerNorm(n_feat)
Expand Down
10 changes: 4 additions & 6 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,11 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):

# add extra col for key and value to handle problem with nan after solfmax
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=dropout_rate
)
v = torch.cat([v, extra_column], dim=-2)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)
out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
Expand Down Expand Up @@ -303,7 +301,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
dropout_rate = self.dropout_rate if self.training else 0
# add extra col for key and value to handle problem with nan after solfmax
extra_column = torch.zeros(k.shape[:-2] + (1, k.shape[-1])).to(k.device)
k = torch.cat([k, extra_column], dim=-2)
k = torch.cat([k, extra_column], dim=-2)

extra_column = torch.zeros(v.shape[:-2] + (1, v.shape[-1])).to(v.device)
v = torch.cat([v, extra_column], dim=-2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_relmha_adapter_with_torch_sdpa(self):
_, pos_emb = relpos_enc(x)
out = adapter(x, x, x, att_mask, pos_emb)
out_sdpa = adapter_torch_sdpa(x, x, x, att_mask, pos_emb)
assert torch.allclose(out_sdpa, out, atol=1e-5)
assert torch.allclose(out_sdpa, out, atol=1e-5)

@pytest.mark.unit
def test_mha_adapter_with_torch_sdpa(self):
Expand All @@ -229,7 +229,7 @@ def test_mha_adapter_with_torch_sdpa(self):
with torch.no_grad():
out = adapter(x, x, x, att_mask)
out_sdpa = adapter_torch_sdpa(x, x, x, att_mask)
assert torch.allclose(out_sdpa, out, atol=1e-5)
assert torch.allclose(out_sdpa, out, atol=1e-5)

@pytest.mark.unit
def test_abspos_encoding_init(self):
Expand Down
6 changes: 5 additions & 1 deletion tests/collections/asr/test_conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def test_stochastic_depth_model_creation(self):
for start_layer in [-1, 0, 5]:
with pytest.raises(ValueError, match="stochastic_depth_start_layer has to be in"):
ConformerEncoder(
feat_in=10, n_layers=n_layers, d_model=4, feat_out=8, stochastic_depth_start_layer=start_layer,
feat_in=10,
n_layers=n_layers,
d_model=4,
feat_out=8,
stochastic_depth_start_layer=start_layer,
)

@pytest.mark.pleasefixme
Expand Down

0 comments on commit b5930a7

Please sign in to comment.