[Documentation] Data Splitting (#8366)
Part of #7892
Fixed #10004

---------

Co-authored-by: Rishi Puri <puririshi98@berkeley.edu>
xnuohz and puririshi98 authored Feb 10, 2025
1 parent bb6601c commit 0d142bb
Showing 6 changed files with 164 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added Data Splitting Tutorial ([#8366](https://github.com/pyg-team/pytorch_geometric/pull/8366))
- Added `Diversity` metric for link prediction ([#10009](https://github.com/pyg-team/pytorch_geometric/pull/10009))
- Added `Coverage` metric for link prediction ([#10006](https://github.com/pyg-team/pytorch_geometric/pull/10006))
- Added Graph Transformer Tutorial ([#8144](https://github.com/pyg-team/pytorch_geometric/pull/8144))
(Two of the changed files in this commit cannot be displayed in the diff view.)
29 changes: 11 additions & 18 deletions docs/source/conf.py
@@ -50,28 +50,21 @@
typehints_defaults = 'comma'

nbsphinx_thumbnails = {
-    'tutorial/create_gnn':
-    '_static/thumbnails/create_gnn.png',
-    'tutorial/heterogeneous':
-    '_static/thumbnails/heterogeneous.png',
-    'tutorial/create_dataset':
-    '_static/thumbnails/create_dataset.png',
-    'tutorial/load_csv':
-    '_static/thumbnails/load_csv.png',
-    'tutorial/neighbor_loader':
-    '_static/thumbnails/neighbor_loader.png',
-    'tutorial/point_cloud':
-    '_static/thumbnails/point_cloud.png',
-    'tutorial/explain':
-    '_static/thumbnails/explain.png',
+    'tutorial/create_gnn': '_static/thumbnails/create_gnn.png',
+    'tutorial/heterogeneous': '_static/thumbnails/heterogeneous.png',
+    'tutorial/create_dataset': '_static/thumbnails/create_dataset.png',
+    'tutorial/load_csv': '_static/thumbnails/load_csv.png',
+    'tutorial/dataset_splitting': '_static/thumbnails/dataset_splitting.png',
+    'tutorial/neighbor_loader': '_static/thumbnails/neighbor_loader.png',
+    'tutorial/point_cloud': '_static/thumbnails/point_cloud.png',
+    'tutorial/explain': '_static/thumbnails/explain.png',
    'tutorial/shallow_node_embeddings':
    '_static/thumbnails/shallow_node_embeddings.png',
-    'tutorial/distributed_pyg':
-    '_static/thumbnails/distributed_pyg.png',
-    'tutorial/multi_gpu_vanilla':
-    '_static/thumbnails/multi_gpu_vanilla.png',
+    'tutorial/distributed_pyg': '_static/thumbnails/distributed_pyg.png',
+    'tutorial/multi_gpu_vanilla': '_static/thumbnails/multi_gpu_vanilla.png',
    'tutorial/multi_node_multi_gpu_vanilla':
    '_static/thumbnails/multi_gpu_vanilla.png',
+    'tutorial/graph_transformer': '_static/thumbnails/graph_transformer.png',
}


1 change: 1 addition & 0 deletions docs/source/tutorial/dataset.rst
@@ -6,3 +6,4 @@ Working with Graph Datasets

create_dataset
load_csv
dataset_splitting
151 changes: 151 additions & 0 deletions docs/source/tutorial/dataset_splitting.rst
@@ -0,0 +1,151 @@
Dataset Splitting
=================

Dataset splitting is a critical step in graph machine learning, where we divide our dataset into subsets for training, validation, and testing.
It ensures that our models are evaluated properly, helping us detect overfitting and estimate how well they generalize to unseen data.
In this tutorial, we will explore the basics of dataset splitting, focusing on three fundamental tasks: node prediction, link prediction, and graph prediction.
We will introduce commonly used techniques, including :class:`~torch_geometric.transforms.RandomNodeSplit` and :class:`~torch_geometric.transforms.RandomLinkSplit` transformations.
We will also cover how to create custom dataset splits beyond random ones.

Node Prediction
---------------

.. note::

    In this section, we'll learn how to use the :class:`~torch_geometric.transforms.RandomNodeSplit` transform of :pyg:`PyG` to randomly divide nodes into training, validation, and test sets.
    A fully working example on the :class:`~torch_geometric.datasets.Planetoid` dataset is available in `examples/cora.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cora.py>`_.

The :class:`~torch_geometric.transforms.RandomNodeSplit` transform splits the nodes of both :pyg:`PyG` :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` objects.
Its most important arguments are:

* :obj:`split` defines the split strategy (one of :obj:`"train_rest"`, :obj:`"test_rest"` or :obj:`"random"`).
* :obj:`num_splits` defines the number of splits to add.
* :obj:`num_train_per_class` defines the number of training nodes per class.
* :obj:`num_val` defines the number of validation nodes after data splitting.
* :obj:`num_test` defines the number of test nodes after data splitting.
* :obj:`key` defines the name of the attribute holding the ground-truth labels (default: :obj:`"y"`).

.. code-block:: python

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import RandomNodeSplit

    x = torch.randn(8, 32)  # Node features of shape [num_nodes, num_features]
    y = torch.randint(0, 4, (8, ))  # Node labels of shape [num_nodes]
    edge_index = torch.tensor([
        [2, 3, 3, 4, 5, 6, 7],
        [0, 0, 1, 1, 2, 3, 4]],
    )
    #   0   1
    #  / \ / \
    # 2   3   4
    # |   |   |
    # 5   6   7
    data = Data(x=x, y=y, edge_index=edge_index)

    node_transform = RandomNodeSplit(num_val=2, num_test=3)
    node_splits = node_transform(data)

Here, we initialize a :class:`~torch_geometric.transforms.RandomNodeSplit` transform to split the graph data by nodes.
After the transformation, :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask` will be attached to the graph data.

.. code-block:: python

    node_splits.train_mask
    >>> tensor([ True, False, False, False,  True,  True, False, False])
    node_splits.val_mask
    >>> tensor([False, False, False, False, False, False,  True,  True])
    node_splits.test_mask
    >>> tensor([False,  True,  True,  True, False, False, False, False])

In this example, there are 8 nodes: we sample 2 nodes for validation, 3 nodes for testing, and keep the remaining nodes for training.
As a result, nodes :obj:`0, 4, 5` form the training set, nodes :obj:`6, 7` the validation set, and nodes :obj:`1, 2, 3` the test set.
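
Continuing from the snippet above, a minimal training-step sketch shows how the masks restrict the loss to the training nodes (the two-layer GCN below is just a stand-in model for illustration, not part of the transform):

.. code-block:: python

    import torch.nn.functional as F

    from torch_geometric.nn import GCNConv

    class GCN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index).relu()
            return self.conv2(x, edge_index)

    model = GCN(in_channels=32, hidden_channels=16, out_channels=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    model.train()
    optimizer.zero_grad()
    out = model(node_splits.x, node_splits.edge_index)
    # Only the training nodes contribute to the loss:
    loss = F.cross_entropy(out[node_splits.train_mask],
                           node_splits.y[node_splits.train_mask])
    loss.backward()
    optimizer.step()

Validation and test metrics are computed analogously by indexing :obj:`out` and :obj:`node_splits.y` with :obj:`val_mask` and :obj:`test_mask`.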

Link Prediction
---------------

.. note::

    In this section, we'll learn how to use the :class:`~torch_geometric.transforms.RandomLinkSplit` transform of :pyg:`PyG` to randomly divide edges into training, validation, and test sets.
    A fully working example on the :class:`~torch_geometric.datasets.Planetoid` dataset is available in `examples/link_pred.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py>`_.

The :class:`~torch_geometric.transforms.RandomLinkSplit` transform splits the edges of both :pyg:`PyG` :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` objects.
Its most important arguments are:

* :obj:`num_val` defines the number (or, if given as a float in :obj:`[0, 1]`, the fraction) of validation edges.
* :obj:`num_test` defines the number (or fraction) of test edges.
* :obj:`is_undirected` defines whether the graph is assumed to be undirected.
* :obj:`key` defines the name of the attribute holding the ground-truth edge labels (default: :obj:`"edge_label"`).
* :obj:`add_negative_train_samples` defines whether to add negative samples to the training split.

.. code-block:: python

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import RandomLinkSplit

    x = torch.randn(8, 32)  # Node features of shape [num_nodes, num_features]
    y = torch.randint(0, 4, (8, ))  # Node labels of shape [num_nodes]
    edge_index = torch.tensor([
        [2, 3, 3, 4, 5, 6, 7],
        [0, 0, 1, 1, 2, 3, 4]],
    )
    edge_y = torch.tensor([0, 0, 0, 0, 1, 1, 1])  # Edge labels of shape [num_edges]
    #   0   1
    #  / \ / \
    # 2   3   4
    # |   |   |
    # 5   6   7
    data = Data(x=x, y=y, edge_index=edge_index, edge_y=edge_y)

    edge_transform = RandomLinkSplit(num_val=0.2, num_test=0.2, key='edge_y',
                                     is_undirected=False,
                                     add_negative_train_samples=False)
    train_data, val_data, test_data = edge_transform(data)

Similar to node splitting, we initialize a :class:`~torch_geometric.transforms.RandomLinkSplit` transform to split the graph data by edges.
The splitting results are shown below.

.. code-block:: python

    train_data
    >>> Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[5], edge_y_index=[2, 5])
    val_data
    >>> Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[2], edge_y_index=[2, 2])
    test_data
    >>> Data(x=[8, 32], edge_index=[2, 6], y=[8], edge_y=[2], edge_y_index=[2, 2])

:obj:`train_data.edge_index` and :obj:`val_data.edge_index` refer to the edges that are used for message passing.
As such, during training and validation, we are only allowed to propagate information over the training edges, while during testing, we can propagate information over the union of training and validation edges.
For evaluation and testing, :obj:`val_data.edge_y_index` and :obj:`test_data.edge_y_index` (named after the :obj:`key='edge_y'` argument) hold the batches of positive and negative samples on which our model should be evaluated, with the corresponding labels stored in :obj:`val_data.edge_y` and :obj:`test_data.edge_y`.
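
To make these roles concrete, here is a minimal evaluation sketch (the node encoder :obj:`model` is hypothetical and we use a simple dot-product decoder; with a custom :obj:`key`, negative samples receive the label :obj:`0`, so any non-zero label is treated as a positive edge):

.. code-block:: python

    import torch

    @torch.no_grad()
    def evaluate(model, split_data):
        model.eval()
        # Message passing only uses the edges stored in `split_data.edge_index`:
        z = model(split_data.x, split_data.edge_index)
        src, dst = split_data.edge_y_index          # candidate node pairs to score
        scores = (z[src] * z[dst]).sum(dim=-1)      # dot-product decoder
        pred = (scores > 0).float()
        target = (split_data.edge_y > 0).float()    # 0 denotes a negative sample
        return (pred == target).float().mean().item()

    val_acc = evaluate(model, val_data)
    test_acc = evaluate(model, test_data)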

Graph Prediction
----------------

.. note::

    In this section, we'll learn how to randomly divide graphs into training, validation, and test sets.
    A fully working example on the :class:`~torch_geometric.datasets.PPI` dataset is available in `examples/ppi.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ppi.py>`_.

In graph prediction tasks, each graph is an independent sample, so we usually divide the graph dataset itself according to a given ratio.
:pyg:`PyG` provides several datasets that already come with predefined training, validation, and test splits, such as :class:`~torch_geometric.datasets.PPI`.

.. code-block:: python

    from torch_geometric.datasets import PPI

    path = './data/PPI'
    train_dataset = PPI(path, split='train')
    val_dataset = PPI(path, split='val')
    test_dataset = PPI(path, split='test')

In addition, we can use :obj:`scikit-learn` or :obj:`numpy` to randomly split a :pyg:`PyG` dataset ourselves, as sketched below.
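
For example, a random 80/10/10 split over graph indices might look like the following sketch (the :class:`~torch_geometric.datasets.TUDataset` collection is used purely as an illustrative stand-in):

.. code-block:: python

    from sklearn.model_selection import train_test_split

    from torch_geometric.datasets import TUDataset

    dataset = TUDataset(root='./data/TUDataset', name='MUTAG')

    # Split the graph indices 80/10/10:
    train_idx, rest_idx = train_test_split(list(range(len(dataset))),
                                           test_size=0.2, random_state=42)
    val_idx, test_idx = train_test_split(rest_idx, test_size=0.5, random_state=42)

    # PyG datasets can be indexed with lists of graph indices:
    train_dataset = dataset[train_idx]
    val_dataset = dataset[val_idx]
    test_dataset = dataset[test_idx]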

Creating Custom Splits
----------------------

If random splitting does not suit our specific use case, we can create custom splits instead.
This requirement commonly arises in real-world applications.
For example, e-commerce platforms operate on large-scale heterogeneous graphs in which nodes represent users, products, merchants, and so on.
To evaluate how well a model performs on newly registered users, we may want to split the user nodes into new and existing users.
Since such splitting criteria are highly application-specific, we do not prescribe a definitive recipe here; the sketch below only illustrates the general pattern of attaching custom boolean masks to the data object.
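
A minimal sketch of such a custom split (the :obj:`signup_time` node attribute and the cutoff are hypothetical and purely illustrative):

.. code-block:: python

    import torch
    from torch_geometric.data import Data

    num_nodes = 8
    data = Data(
        x=torch.randn(num_nodes, 32),
        y=torch.randint(0, 4, (num_nodes, )),
        edge_index=torch.tensor([[2, 3, 3, 4, 5, 6, 7],
                                 [0, 0, 1, 1, 2, 3, 4]]),
        # Hypothetical per-user attribute, e.g., the time at which a user signed up:
        signup_time=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
    )

    cutoff = 5  # Users with signup_time >= cutoff are treated as "new" users.
    data.train_mask = data.signup_time < cutoff   # Train on existing users.
    data.test_mask = data.signup_time >= cutoff   # Evaluate on new users.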
