diff --git a/src/fairseq2/models/wav2vec2/model.py b/src/fairseq2/models/wav2vec2/model.py
index 9f457c4c9..e9f9bced7 100644
--- a/src/fairseq2/models/wav2vec2/model.py
+++ b/src/fairseq2/models/wav2vec2/model.py
@@ -252,6 +252,16 @@ def _sample_distractors(self, targets: Tensor) -> Tensor:
 
         return distractors
 
+    def cosine_similarity(
+        self, x1: torch.Tensor, x2: torch.Tensor, dim=1, eps=1e-8
+    ) -> torch.Tensor:
+        # Normalize along the specified dimension
+        x1_norm = x1 / (x1.norm(dim=dim, dtype=x1.dtype).clamp(min=eps).unsqueeze(dim))
+        x2_norm = x2 / (x2.norm(dim=dim, dtype=x2.dtype).clamp(min=eps).unsqueeze(dim))
+
+        # Compute dot product along the specified dimension
+        return torch.sum(x1_norm * x2_norm, dim=dim, dtype=x1.dtype)
+
     def _compute_logits(
         self, seqs: Tensor, targets: Tensor, distractors: Tensor
     ) -> Tensor:
@@ -264,7 +274,8 @@ def _compute_logits(
 
         # Perform in fp32.
         # (N, S, L + 1, M) -> (N, S, L + 1)
-        logits = torch.cosine_similarity(seqs.float(), candidates.float(), dim=-1)
+        # logits = torch.cosine_similarity(seqs.float(), candidates.float(), dim=-1)
+        logits = self.cosine_similarity(seqs, candidates, dim=-1)
 
         if self.logit_temp != 1.0:
             logits = logits / self.logit_temp
diff --git a/src/fairseq2/models/wav2vec2/vector_quantizer.py b/src/fairseq2/models/wav2vec2/vector_quantizer.py
index bb1570050..f6ddbc1fe 100644
--- a/src/fairseq2/models/wav2vec2/vector_quantizer.py
+++ b/src/fairseq2/models/wav2vec2/vector_quantizer.py
@@ -166,22 +166,26 @@ def forward(self, x: Tensor) -> "GumbelVectorQuantizerOutput":
             .scatter_(-1, k.view(-1, 1), 1.0)
             .view(bsz * tsz, self.num_codebooks, -1)
         )
-        hard_probs = torch.mean(hard_x.float(), dim=0)
+        hard_probs = torch.mean(hard_x, dim=0)
 
-        code_perplexity = torch.exp(
-            -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
-        ).sum()
+        @torch.compile(fullgraph=True)
+        def compute_code_perplexity(probs: torch.Tensor) -> torch.Tensor:
+            return torch.exp(-torch.sum(probs * torch.log(probs + 1e-7), dim=-1)).sum()
+
+        code_perplexity = compute_code_perplexity(probs=hard_probs)
+
+        @torch.compile(fullgraph=True)
+        def compute_softmax(x: torch.Tensor) -> torch.Tensor:
+            return torch.softmax(
+                x.view(bsz * tsz, self.num_codebooks, -1), dim=-1
+            ).mean(dim=0)
 
-        avg_probs = torch.softmax(
-            x.view(bsz * tsz, self.num_codebooks, -1).float(), dim=-1
-        ).mean(dim=0)
+        avg_probs = compute_softmax(x=x)
 
-        prob_perplexity = torch.exp(
-            -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
-        ).sum()
+        prob_perplexity = compute_code_perplexity(probs=avg_probs)
 
         if self.training:
-            x = gumbel_softmax(x.float(), tau=current_temp, hard=True).type_as(x)
+            x = gumbel_softmax(x, tau=current_temp, hard=True).type_as(x)
         else:
             x = hard_x
 
diff --git a/src/fairseq2/nn/transformer/attention.py b/src/fairseq2/nn/transformer/attention.py
index 744e39d87..aca0ca658 100644
--- a/src/fairseq2/nn/transformer/attention.py
+++ b/src/fairseq2/nn/transformer/attention.py
@@ -162,15 +162,14 @@ def forward(
         else:
             mask = None
 
-        with _with_memory_efficient_kernel(self._enable_memory_efficient):
-            attn = scaled_dot_product_attention(
-                seqs,
-                keys,
-                values,
-                attn_mask=mask,
-                dropout_p=dropout_p,
-                is_causal=is_causal,
-            )
+        attn = scaled_dot_product_attention(
+            seqs,
+            keys,
+            values,
+            attn_mask=mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+        )
 
         return attn, None
 
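
Note (not part of the patch): a minimal standalone sketch checking that the hand-rolled cosine_similarity added in model.py agrees with torch.cosine_similarity on well-conditioned inputs. The tensor shapes and tolerance below are illustrative assumptions, not taken from the fairseq2 test suite.

import torch

def cosine_similarity(x1, x2, dim=-1, eps=1e-8):
    # Same math as the helper added to model.py: normalize each input along
    # `dim`, then reduce with an elementwise product. The replaced call cast
    # inputs to fp32 explicitly; this helper stays in the input dtype.
    x1_norm = x1 / (x1.norm(dim=dim, dtype=x1.dtype).clamp(min=eps).unsqueeze(dim))
    x2_norm = x2 / (x2.norm(dim=dim, dtype=x2.dtype).clamp(min=eps).unsqueeze(dim))
    return torch.sum(x1_norm * x2_norm, dim=dim, dtype=x1.dtype)

seqs = torch.randn(2, 5, 101, 768)        # (N, S, L + 1, M), illustrative sizes
candidates = torch.randn(2, 5, 101, 768)
reference = torch.cosine_similarity(seqs, candidates, dim=-1)
assert torch.allclose(cosine_similarity(seqs, candidates, dim=-1), reference, atol=1e-6)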
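
Note (not part of the patch): a minimal sketch of the torch.compile'd perplexity helper from vector_quantizer.py, showing it returns the same value as the original eager expression. The codebook sizes are made-up examples, and torch.compile requires PyTorch 2.x.

import torch

@torch.compile(fullgraph=True)
def compute_code_perplexity(probs: torch.Tensor) -> torch.Tensor:
    # exp of the per-codebook entropy (natural log), summed over codebooks
    return torch.exp(-torch.sum(probs * torch.log(probs + 1e-7), dim=-1)).sum()

probs = torch.softmax(torch.randn(2, 320), dim=-1)  # (num_codebooks, num_entries), illustrative
eager = torch.exp(-torch.sum(probs * torch.log(probs + 1e-7), dim=-1)).sum()
assert torch.allclose(compute_code_perplexity(probs), eager, atol=1e-5)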