Skip to content

Commit

Permalink
style(eltociear): fix typo in contrastive_loss.py
Browse files Browse the repository at this point in the history
postive -> positive
  • Loading branch information
eltociear authored Jan 28, 2023
1 parent a6e88cb commit 1439148
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ding/torch_utils/loss/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
x_n = x.view(-1, self._encode_shape)
y_n = y.view(-1, self._encode_shape)

# Use inner product to obtain postive samples.
# Use inner product to obtain positive samples.
# [N, x_heads, encode_dim] * [N, encode_dim, y_heads] -> [N, x_heads, y_heads]
u_pos = torch.matmul(x, y.permute(0, 2, 1)).unsqueeze(2)
# Use outer product to obtain all sample permutations.
Expand All @@ -92,7 +92,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
u_neg = (n_mask * u_all) - (10. * (1 - n_mask))
u_neg = u_neg.view(N, N * x_heads, y_heads).unsqueeze(dim=1).expand(-1, x_heads, -1, -1)

# Concatenate postive and negative samples and apply log softmax.
# Concatenate positive and negative samples and apply log softmax.
pred_lgt = torch.cat([u_pos, u_neg], dim=2)
pred_log = F.log_softmax(pred_lgt * self._temperature, dim=2)

Expand Down

0 comments on commit 1439148

Please sign in to comment.