Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
  • Loading branch information
WoodieDudy committed Sep 26, 2024
1 parent 41acec1 commit e2aab5b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(
dropout_rate=dropout_rate,
max_cache_len=0,
use_pytorch_sdpa=use_pytorch_sdpa,
use_pytorch_sdpa_backends=use_pytorch_sdpa_backends
use_pytorch_sdpa_backends=use_pytorch_sdpa_backends,
)

self.pre_norm = nn.LayerNorm(n_feat)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):

dropout_rate = self.dropout_rate if self.training else 0
with sdpa_kernel(self.use_pytorch_sdpa_backends):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=dropout_rate
)

# this IF block can be deleted when https://github.com/pytorch/pytorch/pull/131863 is in the stable version
if mask is not None:
Expand Down

0 comments on commit e2aab5b

Please sign in to comment.