sdpa work
Signed-off-by: WoodieDudy <goshagks@gmail.com>
WoodieDudy committed Aug 26, 2024
1 parent 6a32a7a commit fad3414
Showing 7 changed files with 95 additions and 17 deletions.
@@ -136,6 +136,7 @@ model:
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
use_pytorch_sdpa: true # use torch sdpa instead of manual attention

# Convolution module's params
conv_kernel_size: 9
@@ -145,7 +145,7 @@ model:
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
use_pytorch_sdpa: true
use_pytorch_sdpa: true # use torch sdpa instead of manual attention

# Convolution module's params
conv_kernel_size: 9
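
For reference, a minimal sketch (not part of this commit) of reading the new switch from one of these YAML files with OmegaConf, assuming the usual model.encoder layout of NeMo Conformer configs; the file path below is a placeholder:

    # Hypothetical illustration only: the path is a placeholder, not a file in this commit.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("conf/conformer_ctc_bpe.yaml")
    # The flag added above defaults to true; fall back to True if a config omits it.
    use_sdpa = cfg.model.encoder.get("use_pytorch_sdpa", True)
    print(f"use_pytorch_sdpa = {use_sdpa}")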
2 changes: 1 addition & 1 deletion nemo/collections/asr/modules/conformer_encoder.py
@@ -1044,7 +1044,7 @@ def change_attention_model(
att_context_size=att_context_size,
pos_bias_u=None,
pos_bias_v=None,
# use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa=self.use_pytorch_sdpa,
)
elif self_attention_model == 'abs_pos':
new_attn = MultiHeadAttention(
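
A hedged usage sketch of the method touched here: with the previously commented-out kwarg now active, attention layers rebuilt via change_attention_model keep the encoder's use_pytorch_sdpa setting. The pretrained-model name is a placeholder and only the arguments visible in the hunk are assumed:

    # Sketch only; assumes NeMo is installed and that these keyword arguments are accepted.
    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_small")  # placeholder name
    # Rebuilds the self-attention layers; rel_pos attention now inherits encoder.use_pytorch_sdpa.
    model.encoder.change_attention_model(self_attention_model="rel_pos", att_context_size=[-1, -1])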
@@ -134,8 +134,9 @@ def __init__(
dropout_rate: float,
proj_dim: Optional[int] = None,
adapter_strategy: MHAResidualAddAdapterStrategy = None,
use_pytorch_sdpa: bool = True,
):
super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0)
super().__init__(n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa)

self.pre_norm = nn.LayerNorm(n_feat)

@@ -200,6 +201,7 @@ class MultiHeadAttentionAdapterConfig:
dropout_rate: float = 0.0
proj_dim: Optional[int] = None
adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig())
use_pytorch_sdpa: bool = True
_target_: str = "{0}.{1}".format(MultiHeadAttentionAdapter.__module__, MultiHeadAttentionAdapter.__name__)


@@ -225,9 +227,10 @@ def __init__(
dropout_rate: float,
proj_dim: Optional[int] = None,
adapter_strategy: MHAResidualAddAdapterStrategyConfig = None,
use_pytorch_sdpa: bool = False,
):
super().__init__(
n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0
n_head=n_head, n_feat=n_feat, dropout_rate=dropout_rate, pos_bias_u=None, pos_bias_v=None, max_cache_len=0, use_pytorch_sdpa=use_pytorch_sdpa
)

self.pre_norm = nn.LayerNorm(n_feat)
@@ -305,6 +308,7 @@ class RelPositionMultiHeadAttentionAdapterConfig:
dropout_rate: float = 0.0
proj_dim: Optional[int] = None
adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig())
use_pytorch_sdpa: bool = True
_target_: str = "{0}.{1}".format(
RelPositionMultiHeadAttentionAdapter.__module__, RelPositionMultiHeadAttentionAdapter.__name__
)
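
The adapter modules and their dataclass configs now expose the same switch. A small sketch mirroring the tests further below (the module path is assumed from the NeMo source tree, not shown in this diff):

    # Sketch: constructing the adapters with and without the PyTorch SDPA path.
    from nemo.collections.asr.parts.submodules.adapters import multi_head_attention_adapter_module as adapter_modules

    sdpa_adapter = adapter_modules.MultiHeadAttentionAdapter(
        n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=True
    )
    manual_adapter = adapter_modules.RelPositionMultiHeadAttentionAdapter(
        n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=False
    )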
44 changes: 32 additions & 12 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -48,6 +48,8 @@
'PositionalEncoding',
]

inf_val = 10000.0


class MultiHeadAttention(nn.Module):
"""Multi-Head Attention layer of Transformer.
@@ -111,7 +113,7 @@ def forward_attention(self, value, scores, mask):
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, -10000.0)
scores = scores.masked_fill(mask, -inf_val)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
@@ -145,16 +147,20 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
q, k, v = self.forward_qkv(query, key, value)

if self.use_pytorch_sdpa:
scale = 1 / self.s_d_k
n_batch = value.size(0)

if mask is not None:
mask = mask.unsqueeze(1)
mask = ~mask.unsqueeze(1)

dropout_rate = self.dropout_rate if self.training else 0
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=dropout_rate, scale=scale
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_rate)

# this if-block can be deleted once https://github.com/pytorch/pytorch/pull/131863 lands in a stable PyTorch release
if mask is not None:
all_masked_rows = torch.all(~mask, dim=-1)
all_masked_rows.unsqueeze_(-1)
out = out.masked_fill(all_masked_rows, 0.0)

out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
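
Two details of the block above are worth spelling out: NeMo's padding mask uses True for positions to ignore, while scaled_dot_product_attention treats True as "may attend", hence the ~mask inversion; and rows whose keys are all masked come back from SDPA as NaN, which is why they are zeroed afterwards (the explicit scale argument could be dropped because SDPA already defaults to 1/sqrt(d_k)). A standalone sketch of that pattern, independent of NeMo:

    # Standalone sketch (plain PyTorch) of the mask handling used above.
    import torch
    import torch.nn.functional as F

    batch, heads, t, d = 2, 4, 6, 8
    q, k, v = (torch.randn(batch, heads, t, d) for _ in range(3))

    # NeMo-style mask: True marks positions that must NOT be attended to.
    pad = torch.zeros(batch, t, t, dtype=torch.bool)
    pad[1, :, 3:] = True   # keys 3+ of the second sample are padding
    pad[1, 4:, :] = True   # queries 4+ of that sample are padding too (fully masked rows)

    attn_mask = ~pad.unsqueeze(1)  # invert: SDPA expects True = "may attend"; shape (batch, 1, t, t)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

    # Fully masked rows softmax over nothing and come back as NaN; zero them as the diff does.
    all_masked_rows = torch.all(~attn_mask, dim=-1, keepdim=True)
    out = out.masked_fill(all_masked_rows, 0.0)
    assert not torch.isnan(out).any()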
@@ -274,12 +280,20 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):

if mask is not None:
mask = mask.unsqueeze(1)
matrix_bd.masked_fill_(mask, -10000.0)
matrix_bd.masked_fill_(mask, -inf_val)

dropout_rate = self.dropout_rate if self.training else 0
out = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate, scale=scale_factor
q_with_bias_u, k, v, attn_mask=matrix_bd, dropout_p=dropout_rate
)

# this if-block can be deleted once https://github.com/pytorch/pytorch/pull/131863 lands in a stable PyTorch release
if mask is not None:
all_masked_rows = torch.all(mask, dim=-1)
all_masked_rows.unsqueeze_(-1)
all_masked_rows = all_masked_rows.expand(-1, out.size(1), -1, out.size(-1))
out = out.masked_fill(all_masked_rows, 0.0)

out = out.transpose(1, 2).reshape(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
out = self.linear_out(out) # (batch, time1, d_model)
else:
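
In the relative-position variant the mask is folded into matrix_bd instead: masked entries are filled with -inf_val and the whole tensor is handed to SDPA as an additive float attn_mask. Because -10000.0 is finite, fully padded query rows do not become NaN but still hold meaningless values, so they are zeroed explicitly. A standalone sketch of that flow (the random bias stands in for matrix_bd and is assumed to already be on the same scale as q·kᵀ):

    # Standalone sketch (plain PyTorch) of using an additive float bias as the SDPA mask.
    import torch
    import torch.nn.functional as F

    batch, heads, t, d = 2, 4, 6, 8
    inf_val = 10000.0
    q, k, v = (torch.randn(batch, heads, t, d) for _ in range(3))

    bias = torch.randn(batch, heads, t, t)                # stand-in for the rel-pos term matrix_bd
    pad = torch.zeros(batch, 1, t, t, dtype=torch.bool)   # True = masked, as in the diff
    pad[1, :, :, 3:] = True
    pad[1, :, 4:, :] = True
    bias = bias.masked_fill(pad, -inf_val)                # float attn_mask is added to q·kᵀ inside SDPA

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=0.0)

    # Rows whose keys are all padded hold finite garbage (softmax is shift-invariant), so zero them.
    all_masked_rows = torch.all(pad, dim=-1, keepdim=True).expand(-1, out.size(1), -1, out.size(-1))
    out = out.masked_fill(all_masked_rows, 0.0)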
@@ -328,6 +342,7 @@ def __init__(
global_tokens_spacing=1,
global_attn_separate=False,
use_bias=True,
use_pytorch_sdpa=False,
):
"""Construct an RelPositionMultiHeadAttentionLongformer object."""
super().__init__(
@@ -338,7 +353,12 @@
pos_bias_v=pos_bias_v,
max_cache_len=max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=use_pytorch_sdpa,
)

if use_pytorch_sdpa:
raise NotImplementedError("Not implemented for Longformer yet")

self.att_context_size = att_context_size
self.global_tokens = global_tokens
self.global_tokens_spacing = global_tokens_spacing
@@ -410,14 +430,14 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
# (batch, head, time, 2w + 1)

# mask invalid positions
scores[:, :, :, :start_pos] = -10000.0
scores[:, :, :, end_pos + 1 :] = -10000.0
scores[:, :, :, :start_pos] = -inf_val
scores[:, :, :, end_pos + 1 :] = -inf_val

# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
# from (bsz x seq_len) to (bsz x num_heads x seqlen x hidden_size)
mask = mask.unsqueeze(dim=1).unsqueeze(dim=-1)
# cast to float/half then replace 1's with -inf
float_mask = mask.type_as(scores).masked_fill(mask, -10000.0)
float_mask = mask.type_as(scores).masked_fill(mask, -inf_val)
ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
# diagonal mask with zeros everywhere and -inf inplace of padding
d_mask = self.sliding_chunks_matmul_qk(ones, float_mask, w, padding_value=0.0)
@@ -950,7 +970,7 @@ def create_pe(self, positions, dtype):
pe = torch.zeros(pos_length, self.d_model, device=positions.device)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32, device=positions.device)
* -(math.log(10000.0) / self.d_model)
* -(math.log(inf_val) / self.d_model)
)
pe[:, 0::2] = torch.sin(positions * div_term)
pe[:, 1::2] = torch.cos(positions * div_term)
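
Note that the constant reused as inf_val in this last hunk is the base of the standard sinusoidal positional encoding rather than a masking value. For clarity, a self-contained sketch of the formula create_pe implements, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)):

    # Standalone sketch of the sinusoidal table built by create_pe (toy sizes for illustration).
    import math
    import torch

    d_model = 8
    positions = torch.arange(4, dtype=torch.float32).unsqueeze(1)        # (pos_length, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model)
    )
    pe = torch.zeros(positions.size(0), d_model)
    pe[:, 0::2] = torch.sin(positions * div_term)   # even channels
    pe[:, 1::2] = torch.cos(positions * div_term)   # odd channels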
53 changes: 53 additions & 0 deletions tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py
@@ -178,6 +178,59 @@ def test_relmha_adapter_init(self, n_head, proj_dim):
assert out.sum().abs() <= 1e-8
assert out.shape == x.shape

@pytest.mark.unit
def test_relmha_adapter_with_torch_sdpa(self):
torch.random.manual_seed(0)
x = torch.randn(2, 32, 50)
lengths = torch.randint(1, x.size(1), size=(x.size(0),))
lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1)

adapter_torch_sdpa = adapter_modules.RelPositionMultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=True
)
adapter = adapter_modules.RelPositionMultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=False
)
# replace linear_out so its parameters are not the zero-initialized defaults
adapter.linear_out = torch.nn.Linear(adapter.linear_out.in_features, adapter.linear_out.out_features)
for original_param, sdpa_param in zip(adapter.parameters(), adapter_torch_sdpa.parameters()):
sdpa_param.data.copy_(original_param.data)
relpos_enc = adapter_modules.RelPositionalEncodingAdapter(d_model=50)

pad_mask, att_mask = get_mask(lengths)
relpos_enc.extend_pe(lengths.max(), device='cpu', dtype=torch.float32)

with torch.no_grad():
_, pos_emb = relpos_enc(x)
out = adapter(x, x, x, att_mask, pos_emb)
out_sdpa = adapter_torch_sdpa(x, x, x, att_mask, pos_emb)
assert torch.allclose(out_sdpa, out, atol=1e-5)

@pytest.mark.unit
def test_mha_adapter_with_torch_sdpa(self):
torch.random.manual_seed(0)
x = torch.randn(2, 32, 50)
lengths = torch.randint(1, x.size(1), size=(x.size(0),))
lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1)

adapter_torch_sdpa = adapter_modules.MultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=True
)
adapter = adapter_modules.MultiHeadAttentionAdapter(
n_head=2, n_feat=50, dropout_rate=0.0, proj_dim=-1, use_pytorch_sdpa=False
)
# replace linear_out so its parameters are not the zero-initialized defaults
adapter.linear_out = torch.nn.Linear(adapter.linear_out.in_features, adapter.linear_out.out_features)

for original_param, sdpa_param in zip(adapter.parameters(), adapter_torch_sdpa.parameters()):
sdpa_param.data.copy_(original_param.data)

pad_mask, att_mask = get_mask(lengths)
with torch.no_grad():
out = adapter(x, x, x, att_mask)
out_sdpa = adapter_torch_sdpa(x, x, x, att_mask)
assert torch.allclose(out_sdpa, out, atol=1e-5)

@pytest.mark.unit
def test_abspos_encoding_init(self):
torch.random.manual_seed(0)
2 changes: 1 addition & 1 deletion tests/collections/asr/test_conformer_encoder.py
@@ -81,7 +81,7 @@ def test_stochastic_depth_model_creation(self):
def test_stochastic_depth_forward(self):
"""Testing that forward works and we get randomness during training, but not during eval."""
random_input = torch.rand((1, 2, 2))
random_length = torch.tensor([2, 2], dtype=torch.int64)
random_length = torch.tensor([2], dtype=torch.int64)

model = ConformerEncoder(
feat_in=2,