Skip to content

Commit

Permalink
Merge pull request #363 from donglihe-hub/predict
Browse files Browse the repository at this point in the history
Don't Calculate Loss in Prediction
  • Loading branch information
Eleven1Liu authored Mar 13, 2024
2 parents d9307a0 + b0b08f4 commit 2c41ee1
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions libmultilabel/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,18 @@ def predict_step(self, batch, batch_idx):
Returns:
dict: Top k label indexes and the prediction scores.
"""
_, pred_logits = self.shared_step(batch)
pred_logits = self(batch)
pred_scores = pred_logits.detach().cpu().numpy()
k = self.save_k_predictions
top_k_idx = argsort_top_k(pred_scores, k, axis=1)
top_k_scores = np.take_along_axis(pred_scores, top_k_idx, axis=1)

return {"top_k_pred": top_k_idx, "top_k_pred_scores": top_k_scores}

def forward(self, batch):
"""compute predicted logits"""
return self.network(batch)["logits"]

def print(self, *args, **kwargs):
"""Prints only from process 0 and not in silent mode. Use this in any
distributed mode to log only once."""
Expand Down Expand Up @@ -224,8 +228,7 @@ def shared_step(self, batch):
pred_logits (torch.Tensor): The predict logits (batch_size, num_classes).
"""
target_labels = batch["label"]
outputs = self.network(batch)
pred_logits = outputs["logits"]
pred_logits = self(batch)
loss = self.loss_function(pred_logits, target_labels.float())

return loss, pred_logits

0 comments on commit 2c41ee1

Please sign in to comment.