Skip to content

Commit

Permalink
Added the Freebase FB15k_237 dataset (#3204)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rusty1s and pre-commit-ci[bot] authored Dec 29, 2022
1 parent aa42868 commit 5feb18a
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Added the Freebase `FB15k_237` dataset ([#3204](https://github.com/pyg-team/pytorch_geometric/pull/3204))
- Added `Data.update()` and `HeteroData.update()` functionality ([#6313](https://github.com/pyg-team/pytorch_geometric/pull/6313))
- Added `PGExplainer` ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6204))
- Added the `AirfRANS` dataset ([#6287](https://github.com/pyg-team/pytorch_geometric/pull/6287))
Expand Down
10 changes: 10 additions & 0 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,16 @@ def num_edge_features(self) -> int:
r"""Returns the number of features per edge in the graph."""
return self._store.num_edge_features

@property
def num_node_types(self) -> int:
    r"""Returns the number of node types in the graph."""
    if 'node_type' not in self:
        # No per-node type annotation present: treat as homogeneous.
        return 1
    return int(self.node_type.max()) + 1

@property
def num_edge_types(self) -> int:
    r"""Returns the number of edge types in the graph."""
    if 'edge_type' not in self:
        # No per-edge type annotation present: treat as homogeneous.
        return 1
    return int(self.edge_type.max()) + 1

def __iter__(self) -> Iterable:
r"""Iterates over all attributes in the data, yielding their attribute
names and values."""
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .suite_sparse import SuiteSparseMatrixCollection
from .aminer import AMiner
from .word_net import WordNet18, WordNet18RR
from .freebase import FB15k_237
from .wikics import WikiCS
from .webkb import WebKB
from .wikipedia_network import WikipediaNetwork
Expand Down Expand Up @@ -135,6 +136,7 @@
'AMiner',
'WordNet18',
'WordNet18RR',
'FB15k_237',
'WikiCS',
'WebKB',
'WikipediaNetwork',
Expand Down
89 changes: 89 additions & 0 deletions torch_geometric/datasets/freebase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Callable, List, Optional

import torch

from torch_geometric.data import Data, InMemoryDataset, download_url


class FB15k_237(InMemoryDataset):
    r"""The FB15K237 dataset from the `"Translating Embeddings for Modeling
    Multi-Relational Data"
    <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling
    -multi-relational-data>`_ paper,
    containing 14,541 entities, 237 relations and 310,116 fact triples.

    .. note::

        The original :class:`FB15k` dataset suffers from major test leakage
        through inverse relations, where a large number of test triples could
        be obtained by inverting triples in the training set.
        In order to create a dataset without this characteristic, the
        :class:`~torch_geometric.datasets.FB15k_237` describes a subset of
        :class:`FB15k` where inverse relations are removed.

    Args:
        root (string): Root directory where the dataset should be saved.
        split (string): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset. (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    url = ('https://raw.githubusercontent.com/villmow/'
           'datasets_knowledge_embedding/master/FB15k-237')

    def __init__(self, root: str, split: str = "train",
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        # Validate `split` *before* `super().__init__()`, which may trigger a
        # (potentially expensive) download + process cycle.
        if split not in {'train', 'val', 'test'}:
            raise ValueError(f"Invalid 'split' argument (got {split})")

        super().__init__(root, transform, pre_transform)

        # `processed_paths` is ordered like `processed_file_names`:
        # [train, val, test].
        path = self.processed_paths[['train', 'val', 'test'].index(split)]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        return ['train.txt', 'valid.txt', 'test.txt']

    @property
    def processed_file_names(self) -> List[str]:
        return ['train_data.pt', 'val_data.pt', 'test_data.pt']

    def download(self):
        # BUG FIX: the URL previously did not interpolate `filename`, so the
        # same bogus path was requested for every raw file.
        for filename in self.raw_file_names:
            download_url(f'{self.url}/{filename}', self.raw_dir)

    def process(self):
        # Entity and relation dictionaries are shared across the three splits
        # so that train/val/test use one consistent id space.
        data_list, node_dict, rel_dict = [], {}, {}
        for path in self.raw_paths:
            with open(path, 'r') as f:
                # Each non-empty line is a tab-separated (head, rel, tail)
                # triple; the trailing empty line after the final '\n' is
                # dropped via [:-1].
                data = [x.split('\t') for x in f.read().split('\n')[:-1]]

            edge_index = torch.empty((2, len(data)), dtype=torch.long)
            edge_type = torch.empty(len(data), dtype=torch.long)
            for i, (src, rel, dst) in enumerate(data):
                if src not in node_dict:
                    node_dict[src] = len(node_dict)
                if dst not in node_dict:
                    node_dict[dst] = len(node_dict)
                if rel not in rel_dict:
                    rel_dict[rel] = len(rel_dict)

                edge_index[0, i] = node_dict[src]
                edge_index[1, i] = node_dict[dst]
                edge_type[i] = rel_dict[rel]

            data = Data(edge_index=edge_index, edge_type=edge_type)
            data_list.append(data)

        # `num_nodes` is set only after all splits are read, since entities
        # seen only in val/test still enlarge the shared node dictionary.
        for data, path in zip(data_list, self.processed_paths):
            data.num_nodes = len(node_dict)
            torch.save(self.collate([data]), path)

0 comments on commit 5feb18a

Please sign in to comment.