Skip to content

Commit

Permalink
Add mean reciprocal rank (MRR) computation in KGEModel (#8298)
Browse files Browse the repository at this point in the history
Greetings, thank you for confirming this request.
I tried to add mean reciprocal rank (MRR) in the
torch_geometric/nn/kge/base.py and examples/kge_fb15k_237.py.
(Please see the related discussion in
[here](#8256))
If there are any problems, please forgive me and freely edit the code.
Thanks!

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Nov 1, 2023
1 parent 4fa9466 commit fd7ff50
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Support MRR computation in `KGEModel.test()` ([#8298](https://github.com/pyg-team/pytorch_geometric/pull/8298))
- Added an example for model parallelism (`examples/multi_gpu/model_parallel.py`) ([#8309](https://github.com/pyg-team/pytorch_geometric/pull/8309))
- Added a tutorial for multi-node multi-GPU training with pure PyTorch ([#8071](https://github.com/pyg-team/pytorch_geometric/pull/8071))
- Added a multinode-multigpu example on `ogbn-papers100M` ([#8070](https://github.com/pyg-team/pytorch_geometric/pull/8070))
Expand Down
9 changes: 5 additions & 4 deletions examples/kge_fb15k_237.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ def test(data):
loss = train()
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
if epoch % 25 == 0:
rank, hits = test(val_data)
rank, mrr, hits = test(val_data)
print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
f'Val Hits@10: {hits:.4f}')
f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')

rank, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}')
rank, mrr, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '
f'Test Hits@10: {hits_at_10:.4f}')
7 changes: 4 additions & 3 deletions test/nn/kge/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_complex():
loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, hits_at_10 = model.test(h, r, t, batch_size=5, log=False)
assert mean_rank <= 10
assert hits_at_10 == 1.0
mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
assert 0 <= mean_rank <= 10
assert 0 < mrr <= 1
assert hits == 1.0
7 changes: 4 additions & 3 deletions test/nn/kge/test_distmult.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_distmult():
loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, hits_at_10 = model.test(h, r, t, batch_size=5, log=False)
assert mean_rank <= 10
assert hits_at_10 == 1.0
mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
assert 0 <= mean_rank <= 10
assert 0 < mrr <= 1
assert hits == 1.0
7 changes: 4 additions & 3 deletions test/nn/kge/test_rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_rotate():
loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, hits_at_10 = model.test(h, r, t, batch_size=5, log=False)
assert mean_rank <= 10
assert hits_at_10 == 1.0
mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
assert 0 <= mean_rank <= 10
assert 0 < mrr <= 1
assert hits == 1.0
7 changes: 4 additions & 3 deletions test/nn/kge/test_transe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_transe():
loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, hits_at_10 = model.test(h, r, t, batch_size=5, log=False)
assert mean_rank <= 10
assert hits_at_10 == 1.0
mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
assert 0 <= mean_rank <= 10
assert 0 < mrr <= 1
assert hits == 1.0
12 changes: 7 additions & 5 deletions torch_geometric/nn/kge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def test(
batch_size: int,
k: int = 10,
log: bool = True,
) -> Tuple[float, float]:
r"""Evaluates the model quality by computing Mean Rank and
Hits @ :math:`k` across all possible tail entities.
) -> Tuple[float, float, float]:
r"""Evaluates the model quality by computing Mean Rank, MRR and
Hits@:math:`k` across all possible tail entities.
Args:
head_index (torch.Tensor): The head indices.
Expand All @@ -115,7 +115,7 @@ def test(
arange = range(head_index.numel())
arange = tqdm(arange) if log else arange

mean_ranks, hits_at_k = [], []
mean_ranks, reciprocal_ranks, hits_at_k = [], [], []
for i in arange:
h, r, t = head_index[i], rel_type[i], tail_index[i]

Expand All @@ -126,12 +126,14 @@ def test(
rank = int((torch.cat(scores).argsort(
descending=True) == t).nonzero().view(-1))
mean_ranks.append(rank)
reciprocal_ranks.append(1 / (rank + 1))
hits_at_k.append(rank < k)

mean_rank = float(torch.tensor(mean_ranks, dtype=torch.float).mean())
mrr = float(torch.tensor(reciprocal_ranks, dtype=torch.float).mean())
hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k)

return mean_rank, hits_at_k
return mean_rank, mrr, hits_at_k

@torch.no_grad()
def random_sample(
Expand Down

0 comments on commit fd7ff50

Please sign in to comment.