diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a8cff91ed..17028e270b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - minimal requirements issue ([#1147](https://github.com/catalyst-team/catalyst/issues/1147))
 - nested dicts in `loaders_params`/`samplers_params` overriding fixed ([#1150](https://github.com/catalyst-team/catalyst/pull/1150))
+- fixed hitrate calculation issue ([#1155](https://github.com/catalyst-team/catalyst/issues/1155))
 
 ## [21.03.1] - 2021-03-28
 
diff --git a/catalyst/metrics/functional/_hitrate.py b/catalyst/metrics/functional/_hitrate.py
index 04553cf8bd..7356a74fb6 100644
--- a/catalyst/metrics/functional/_hitrate.py
+++ b/catalyst/metrics/functional/_hitrate.py
@@ -5,9 +5,12 @@
 from catalyst.metrics.functional._misc import process_recsys_components
 
 
-def hitrate(outputs: torch.Tensor, targets: torch.Tensor, topk: List[int]) -> List[torch.Tensor]:
+def hitrate(
+    outputs: torch.Tensor, targets: torch.Tensor, topk: List[int], zero_division: int = 0
+) -> List[torch.Tensor]:
     """
-    Calculate the hit rate score given model outputs and targets.
+    Calculate the hit rate (aka recall) score given
+    model outputs and targets.
     Hit-rate is a metric for evaluating ranking systems.
     Generate top-N recommendations and if one of the recommendation is
     actually what user has rated, you consider that a hit.
@@ -30,6 +33,9 @@ def hitrate(outputs: torch.Tensor, targets: torch.Tensor, topk: List[int]) -> Li
             ground truth, labels
         topk (List[int]):
             Parameter fro evaluation on top-k items
+        zero_division (int):
+            value returned in the case of division by zero;
+            should be either 0 or 1
 
     Returns:
         hitrate_at_k (List[torch.Tensor]): the hitrate score
@@ -39,7 +45,8 @@ def hitrate(outputs: torch.Tensor, targets: torch.Tensor, topk: List[int]) -> Li
     targets_sort_by_outputs = process_recsys_components(outputs, targets)
     for k in topk:
         k = min(outputs.size(1), k)
-        hits_score = torch.sum(targets_sort_by_outputs[:, :k], dim=1) / k
+        hits_score = torch.sum(targets_sort_by_outputs[:, :k], dim=1) / targets.sum(dim=1)
+        hits_score = hits_score.nan_to_num(zero_division)
         results.append(torch.mean(hits_score))
 
     return results
diff --git a/catalyst/metrics/functional/tests/test_hitrate.py b/catalyst/metrics/functional/tests/test_hitrate.py
index ab1799ef13..57a7b6a378 100644
--- a/catalyst/metrics/functional/tests/test_hitrate.py
+++ b/catalyst/metrics/functional/tests/test_hitrate.py
@@ -9,13 +9,16 @@ def test_hitrate():
     """
     Tests for catalyst.metrics.hitrate metric.
     """
-    y_pred = [0.5, 0.2]
-    y_true = [1.0, 0.0]
-    k = [1, 2]
+    y_pred = [0.5, 0.2, 0.1]
+    y_true = [1.0, 0.0, 1.0]
+    k = [1, 2, 3]
 
-    hitrate_at1, hitrate_at2 = hitrate(torch.Tensor([y_pred]), torch.Tensor([y_true]), k)
-    assert hitrate_at1 == 1.0
+    hitrate_at1, hitrate_at2, hitrate_at3 = hitrate(
+        torch.Tensor([y_pred]), torch.Tensor([y_true]), k
+    )
+    assert hitrate_at1 == 0.5
     assert hitrate_at2 == 0.5
+    assert hitrate_at3 == 1.0
 
     # check 1 simple case
     y_pred = [0.5, 0.2]
@@ -24,3 +27,20 @@ def test_hitrate():
 
     hitrate_at2 = hitrate(torch.Tensor([y_pred]), torch.Tensor([y_true]), k)[0]
     assert hitrate_at2 == 0.0
+
+    # check batch case
+    y_pred1 = [4.0, 2.0, 3.0, 1.0]
+    y_pred2 = [1.0, 2.0, 3.0, 4.0]
+    y_true1 = [0, 0, 1.0, 1.0]
+    y_true2 = [0, 0, 0.0, 0.0]
+    k = [1, 2, 3, 4]
+
+    y_pred_torch = torch.Tensor([y_pred1, y_pred2])
+    y_true_torch = torch.Tensor([y_true1, y_true2])
+
+    hitrate_at1, hitrate_at2, hitrate_at3, hitrate_at4 = hitrate(y_pred_torch, y_true_torch, k)
+
+    assert hitrate_at1 == 0.0
+    assert hitrate_at2 == 0.25
+    assert hitrate_at3 == 0.25
+    assert hitrate_at4 == 0.5