Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the Freebase FB15k_237 dataset #3204

Merged
merged 14 commits into from
Dec 29, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Added the Freebase `FB15k_237` dataset ([#3204](https://github.com/pyg-team/pytorch_geometric/pull/3204))
- Added `Data.update()` and `HeteroData.update()` functionality ([#6313](https://github.com/pyg-team/pytorch_geometric/pull/6313))
- Added `PGExplainer` ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6204))
- Added the `AirfRANS` dataset ([#6287](https://github.com/pyg-team/pytorch_geometric/pull/6287))
Expand Down
10 changes: 10 additions & 0 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,16 @@ def num_edge_features(self) -> int:
r"""Returns the number of features per edge in the graph."""
return self._store.num_edge_features

@property
def num_node_types(self) -> int:
    r"""Returns the number of node types in the graph.

    If no :obj:`node_type` attribute is present, the graph is assumed
    to be homogeneous with a single node type.
    """
    if 'node_type' not in self:
        return 1
    # Types are assumed to be consecutive integers starting at zero,
    # so the count is the largest observed type id plus one.
    return int(self.node_type.max()) + 1

@property
def num_edge_types(self) -> int:
    r"""Returns the number of edge types in the graph.

    If no :obj:`edge_type` attribute is present, the graph is assumed
    to be homogeneous with a single edge type.
    """
    if 'edge_type' not in self:
        return 1
    # Types are assumed to be consecutive integers starting at zero,
    # so the count is the largest observed type id plus one.
    return int(self.edge_type.max()) + 1

def __iter__(self) -> Iterable:
r"""Iterates over all attributes in the data, yielding their attribute
names and values."""
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .suite_sparse import SuiteSparseMatrixCollection
from .aminer import AMiner
from .word_net import WordNet18, WordNet18RR
from .freebase import FB15k_237
from .wikics import WikiCS
from .webkb import WebKB
from .wikipedia_network import WikipediaNetwork
Expand Down Expand Up @@ -135,6 +136,7 @@
'AMiner',
'WordNet18',
'WordNet18RR',
'FB15k_237',
'WikiCS',
'WebKB',
'WikipediaNetwork',
Expand Down
89 changes: 89 additions & 0 deletions torch_geometric/datasets/freebase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Callable, List, Optional

import torch

from torch_geometric.data import Data, InMemoryDataset, download_url


class FB15k_237(InMemoryDataset):
    r"""The FB15K237 dataset from the `"Translating Embeddings for Modeling
    Multi-Relational Data"
    <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling
    -multi-relational-data>`_ paper,
    containing 14,541 entities, 237 relations and 310,116 fact triples.

    .. note::

        The original :class:`FB15k` dataset suffers from major test leakage
        through inverse relations, where a large number of test triples could
        be obtained by inverting triples in the training set.
        In order to create a dataset without this characteristic, the
        :class:`~torch_geometric.datasets.FB15k_237` describes a subset of
        :class:`FB15k` where inverse relations are removed.

    Args:
        root (string): Root directory where the dataset should be saved.
        split (string): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    url = ('https://raw.githubusercontent.com/villmow/'
           'datasets_knowledge_embedding/master/FB15k-237')

    def __init__(self, root: str, split: str = "train",
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        # Validate `split` *before* calling `super().__init__()`, since the
        # base constructor may trigger download and processing — we should
        # fail fast on a bad argument instead of doing that work first.
        if split not in {'train', 'val', 'test'}:
            raise ValueError(f"Invalid 'split' argument (got {split})")

        super().__init__(root, transform, pre_transform)

        # `processed_paths` is ordered as ['train', 'val', 'test'], matching
        # `processed_file_names` below.
        path = self.processed_paths[['train', 'val', 'test'].index(split)]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        return ['train.txt', 'valid.txt', 'test.txt']

    @property
    def processed_file_names(self) -> List[str]:
        return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def download(self):
        for filename in self.raw_file_names:
            # Bug fix: the loop variable was previously unused and a literal
            # placeholder string was interpolated into the URL, so the raw
            # files could never be fetched. Download each split file by name.
            download_url(f'{self.url}/{filename}', self.raw_dir)

    def process(self):
        # Entity and relation ids are assigned globally across all three
        # splits, so the dictionaries are shared over the loop.
        data_list, node_dict, rel_dict = [], {}, {}
        for path in self.raw_paths:
            with open(path, 'r') as f:
                # Each line is a tab-separated (head, relation, tail) triple;
                # the trailing empty line after the final '\n' is dropped.
                data = [x.split('\t') for x in f.read().split('\n')[:-1]]

            edge_index = torch.empty((2, len(data)), dtype=torch.long)
            edge_type = torch.empty(len(data), dtype=torch.long)
            for i, (src, rel, dst) in enumerate(data):
                if src not in node_dict:
                    node_dict[src] = len(node_dict)
                if dst not in node_dict:
                    node_dict[dst] = len(node_dict)
                if rel not in rel_dict:
                    rel_dict[rel] = len(rel_dict)

                edge_index[0, i] = node_dict[src]
                edge_index[1, i] = node_dict[dst]
                edge_type[i] = rel_dict[rel]

            data = Data(edge_index=edge_index, edge_type=edge_type)
            data_list.append(data)

        for data, path in zip(data_list, self.processed_paths):
            # All splits share the full entity set, so `num_nodes` is the
            # global entity count, not the number of nodes seen in one split.
            data.num_nodes = len(node_dict)
            torch.save(self.collate([data]), path)