diff --git a/ding/torch_utils/loss/contrastive_loss.py b/ding/torch_utils/loss/contrastive_loss.py
index ed17eebcb1..0871ebdd85 100644
--- a/ding/torch_utils/loss/contrastive_loss.py
+++ b/ding/torch_utils/loss/contrastive_loss.py
@@ -79,7 +79,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         x_n = x.view(-1, self._encode_shape)
         y_n = y.view(-1, self._encode_shape)
 
-        # Use inner product to obtain postive samples.
+        # Use inner product to obtain positive samples.
         # [N, x_heads, encode_dim] * [N, encode_dim, y_heads] -> [N, x_heads, y_heads]
         u_pos = torch.matmul(x, y.permute(0, 2, 1)).unsqueeze(2)
         # Use outer product to obtain all sample permutations.
@@ -92,7 +92,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         u_neg = (n_mask * u_all) - (10. * (1 - n_mask))
         u_neg = u_neg.view(N, N * x_heads, y_heads).unsqueeze(dim=1).expand(-1, x_heads, -1, -1)
 
-        # Concatenate postive and negative samples and apply log softmax.
+        # Concatenate positive and negative samples and apply log softmax.
         pred_lgt = torch.cat([u_pos, u_neg], dim=2)
         pred_log = F.log_softmax(pred_lgt * self._temperature, dim=2)
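For context on what the corrected comments describe, below is a minimal, self-contained sketch of the positive/negative construction in this loss, assuming `x` and `y` are already encoded to shape `[N, heads, encode_dim]`. The sizes `N`, `x_heads`, `y_heads`, `encode_dim` and the `temperature` value are illustrative stand-ins, not the module's defaults.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only (assumptions, not the module's defaults).
N, x_heads, y_heads, encode_dim = 4, 2, 3, 8
temperature = 1.0
x = torch.randn(N, x_heads, encode_dim)
y = torch.randn(N, y_heads, encode_dim)
x_n = x.view(-1, encode_dim)
y_n = y.view(-1, encode_dim)

# Inner product pairs each sample with its own counterpart (the positives).
# [N, x_heads, encode_dim] @ [N, encode_dim, y_heads] -> [N, x_heads, 1, y_heads]
u_pos = torch.matmul(x, y.permute(0, 2, 1)).unsqueeze(2)

# Outer product scores every cross-batch pairing.
# [N * y_heads, encode_dim] @ [encode_dim, N * x_heads] -> [N, N, x_heads, y_heads]
u_all = torch.mm(y_n, x_n.t()).view(N, y_heads, N, x_heads).permute(0, 2, 3, 1)

# Mask the batch diagonal (a sample paired with itself is not a negative);
# the -10 offset pushes those entries toward zero probability after softmax.
n_mask = 1 - torch.eye(N).view(N, N, 1, 1)
u_neg = (n_mask * u_all) - (10. * (1 - n_mask))
u_neg = u_neg.view(N, N * x_heads, y_heads).unsqueeze(dim=1).expand(-1, x_heads, -1, -1)

# The positive logit sits at index 0 along dim=2; the loss is its NLL.
pred_lgt = torch.cat([u_pos, u_neg], dim=2)
pred_log = F.log_softmax(pred_lgt * temperature, dim=2)
loss = -pred_log[:, :, 0].mean()
print(loss.item())
```

Note the design choice mirrored from the patched code: masked diagonal entries are kept in `u_neg` with a logit of -10 rather than being dropped, which preserves a uniform tensor shape while making their softmax contribution negligible.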