diff --git a/notebooks/CORA-geometric.ipynb b/notebooks/CORA-geometric.ipynb new file mode 100644 index 000000000..cfe3699c1 --- /dev/null +++ b/notebooks/CORA-geometric.ipynb @@ -0,0 +1,649 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "93e10022", + "metadata": {}, + "source": [ + "# torch geometric + skorch @ CORA dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f8b30bc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fr 22. Jul 18:52:31 CEST 2022\r\n" + ] + } + ], + "source": [ + "!date" + ] + }, + { + "cell_type": "markdown", + "id": "4e144398", + "metadata": {}, + "source": [ + "This is an example of how to use skorch with [torch geometric](https://pytorch-geometric.readthedocs.io/). The code is based on the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) but modified to have a proper train/valid/test split. It showcases a small dataset that does not need batching to be trained efficiently. Batching with skorch + torch geometric is not covered here since it is non-trivial and quite dataset-specific; if you need it and are stuck, feel free to open [an issue](https://github.com/skorch-dev/skorch/issues) so that we can support you as best we can.\n", + "\n", + "Dependencies of this notebook besides the skorch base installation:" + ] + }, + { + "cell_type": "raw", + "id": "3748b2a0", + "metadata": {}, + "source": [ + "!pip install torch_geometric==2.0.4" + ] + }, + { + "cell_type": "markdown", + "id": "7ce31902", + "metadata": {}, + "source": [ + "It is recommended to install the dependencies [as documented by pytorch geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)." + ] + }, + { + "cell_type": "markdown", + "id": "ff707ef1", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "efa2b5a5", + "metadata": {}, + "outputs": [], + "source": [ + "import skorch\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "934fa438", + "metadata": {}, + "source": [ + "### Data Loading" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f1505f8f", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.datasets import Planetoid\n", + "\n", + "dataset = Planetoid(root='/tmp/Cora', name='Cora')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4979ba6f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708]),\n", + " 7)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.data, dataset.num_classes" + ] + }, + { + "cell_type": "markdown", + "id": "9cd7e104", + "metadata": {}, + "source": [ + "In order to use pytorch geometric / the CORA dataset with skorch\n", + "we need to address the following things:\n", + " \n", + "1. graph convolutions cannot handle missing nodes (=> splitting node attributes but keeping edge_index intact will lead to errors)\n", + "2. the CORA dataset provides separate mask attributes for the different splits (i.e. `train_mask`, `val_mask`, `test_mask`)\n", + "3. skorch expects to have (X, y) pairs for classification tasks\n", + "\n", + "To deal with (1) we will split the data into three datasets, creating three sub-graphs in the process; these complete sub-graphs can then be convolved over without errors. \n", + "We use the masks mentioned in (2) to identify the nodes and edges of the sub-graphs (a short illustration of the masks follows below).\n", + "\n", + "(3) will be handled by defining our own `XYDataset`, which has length 1 and returns the whole dataset and the respective y values. We therefore basically simulate a `batch_size=1` scenario." + ] + },
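+ { + "cell_type": "markdown", + "id": "0ab51e77", + "metadata": {}, + "source": [ + "Regarding (2), the masks are disjoint boolean tensors over all 2708 nodes. A quick check like the following (a hypothetical snippet, not executed in this notebook) should show the sizes of the standard public split, i.e. 140/500/1000 nodes:\n", + "\n", + "```python\n", + "data = dataset[0]\n", + "print(int(data.train_mask.sum()), int(data.val_mask.sum()), int(data.test_mask.sum()))\n", + "```" + ] + },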
+ { + "cell_type": "code", + "execution_count": 5, + "id": "35aa1770", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.data import Data\n", + "\n", + "# simulating batch_size=1 by returning the whole dataset and the\n", + "# y-values. this way, the data loader can iterate over the 'batches'\n", + "# and produce X/y values for us.\n", + "class XYDataset(torch.utils.data.Dataset):\n", + "    def __init__(self, data: Data, y: torch.Tensor):\n", + "        self.data = data\n", + "        self.y = y\n", + "    \n", + "    def __len__(self):\n", + "        return 1\n", + "    \n", + "    def __getitem__(self, i):\n", + "        return self.data, self.y" + ] + }, + { + "cell_type": "markdown", + "id": "7cd78cd1", + "metadata": {}, + "source": [ + "### Data Splitting" + ] + }, + { + "cell_type": "markdown", + "id": "468d8d6f", + "metadata": {}, + "source": [ + "Split the graph into train, validation and test sub-graphs.\n", + "This ensures that there is no leakage between the splits when we apply graph\n", + "convolution operators on the graph since each split has its own sub-graph.\n", + "\n", + "We use `relabel_nodes=True` to make the node indices in the edge tensor \n", + "zero-based for each sub-graph. If we did not do this, the node subsets\n", + "(now zero-based after applying the mask) would not match the indices in the\n", + "edge tensor." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f3d908e4", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.utils import subgraph\n", + "\n", + "data = dataset[0]\n", + "\n", + "edge_index_train, _ = subgraph(\n", + "    subset=data.train_mask, \n", + "    edge_index=data.edge_index, \n", + "    relabel_nodes=True\n", + ")\n", + "ds_train = XYDataset(\n", + "    Data(x=data.x[data.train_mask], edge_index=edge_index_train),\n", + "    data.y[data.train_mask],\n", + ")\n", + "\n", + "edge_index_valid, _ = subgraph(\n", + "    subset=data.val_mask, \n", + "    edge_index=data.edge_index, \n", + "    relabel_nodes=True\n", + ")\n", + "ds_valid = XYDataset(\n", + "    Data(x=data.x[data.val_mask], edge_index=edge_index_valid),\n", + "    data.y[data.val_mask],\n", + ")\n", + "\n", + "edge_index_test, _ = subgraph(\n", + "    subset=data.test_mask, \n", + "    edge_index=data.edge_index, \n", + "    relabel_nodes=True\n", + ")\n", + "ds_test = XYDataset(\n", + "    Data(x=data.x[data.test_mask], edge_index=edge_index_test),\n", + "    data.y[data.test_mask],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bf59f7a9", + "metadata": {}, + "source": [ + "### Data Feeding" + ] + }, + { + "cell_type": "markdown", + "id": "4d23039d", + "metadata": {}, + "source": [ + "Our \"batch\" consists of the whole dataset so if we unpack the\n", + "batch into `(X, y)` we will have `X = Data(...)` and `y = [y_true]`.\n", + "The `DataLoader` does not modify `X` but `y` gets a new batch dimension.\n", + "This will lead to a shape mismatch as `y.shape` would then be `(1, #num_samples)`. Therefore, we need our own loader that strips the first dimension to\n", + "match the predicted `y` and the labelled `y` in length.\n", + "\n", + "Note: It is also possible to avoid this by stripping the extra dimension in an overridden `get_loss` of the `NeuralNet` class; we don't use this approach here, but a minimal sketch is shown after the loader below. It is possible to use [one of the many `DataLoader` classes](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html) provided by torch geometric with the approach outlined below (just base the `RawLoader` on one of the other classes); chances are, though, that if you are doing this you need to deal with batching anyway, a topic that is not handled here since it is not trivial." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "83594638", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader\n", + "\n", + "class RawLoader(DataLoader):\n", + "    def __iter__(self):\n", + "        it = super().__iter__()\n", + "        for X, y in it:\n", + "            # strip the batch dimension from y: (1, N) -> (N,)\n", + "            yield X, y[0]" + ] + },
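+ { + "cell_type": "markdown", + "id": "9b1d4c2e", + "metadata": {}, + "source": [ + "For completeness, here is a minimal sketch of the `get_loss` alternative mentioned above. It is not used in this notebook, the subclass name is made up, and it relies on skorch's `get_loss(y_pred, y_true, X=None, training=False)` signature:\n", + "\n", + "```python\n", + "class GraphNetClassifier(skorch.NeuralNetClassifier):\n", + "    def get_loss(self, y_pred, y_true, X=None, training=False):\n", + "        # strip the batch dimension from y instead of using a custom loader\n", + "        return super().get_loss(y_pred, y_true[0], X=X, training=training)\n", + "```" + ] + },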
+ { + "cell_type": "markdown", + "id": "d51e7140", + "metadata": {}, + "source": [ + "### Modelling" + ] + }, + { + "cell_type": "markdown", + "id": "67e6ff9e", + "metadata": {}, + "source": [ + "This is the CORA example module as seen in the [torch geometric introduction](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "10de0020", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv\n", + "\n", + "class GCN(torch.nn.Module):\n", + "    def __init__(self):\n", + "        super().__init__()\n", + "        self.conv1 = GCNConv(dataset.num_node_features, 16)\n", + "        self.conv2 = GCNConv(16, dataset.num_classes)\n", + "\n", + "    def forward(self, data): \n", + "        x, edge_index = data.x, data.edge_index\n", + "\n", + "        x = self.conv1(x, edge_index)\n", + "        x = F.relu(x)\n", + "        x = F.dropout(x, training=self.training)\n", + "        x = self.conv2(x, edge_index)\n", + "\n", + "        # NeuralNetClassifier's default criterion NLLLoss expects log-probabilities\n", + "        return F.log_softmax(x, dim=1)" + ] + }, + { + "cell_type": "markdown", + "id": "f58e93d8", + "metadata": {}, + "source": [ + "### Fitting" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b3f5ef3c", + "metadata": {}, + "outputs": [], + "source": [ + "from skorch.helper import predefined_split\n", + "\n", + "torch.manual_seed(42)\n", + "\n", + "net = skorch.NeuralNetClassifier(\n", + "    module=GCN,\n", + "    lr=0.1,\n", + "    optimizer__weight_decay=5e-4,\n", + "    max_epochs=200,\n", + "    train_split=predefined_split(ds_valid),\n", + "    batch_size=1,\n", + "    iterator_train=RawLoader,\n", + "    iterator_valid=RawLoader,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c2cf8bbd", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "  epoch    train_loss    valid_acc    valid_loss     dur\n", + "-------  ------------  -----------  ------------  ------\n", + "      1        \u001b[36m1.9724\u001b[0m       \u001b[32m0.1680\u001b[0m        \u001b[35m1.9398\u001b[0m  0.0074\n", + "      2        \u001b[36m1.9625\u001b[0m       \u001b[32m0.1740\u001b[0m        \u001b[35m1.9376\u001b[0m  0.0052\n", + "      3        \u001b[36m1.9327\u001b[0m       0.1720        \u001b[35m1.9342\u001b[0m  0.0043\n", + "      4        \u001b[36m1.9321\u001b[0m       \u001b[32m0.1760\u001b[0m        \u001b[35m1.9324\u001b[0m  0.0045\n", + "      5        \u001b[36m1.9142\u001b[0m       \u001b[32m0.1800\u001b[0m        \u001b[35m1.9307\u001b[0m  0.0037\n", + "      6        
\u001b[36m1.8923\u001b[0m 0.1800 \u001b[35m1.9290\u001b[0m 0.0036\n", + " 7 \u001b[36m1.8848\u001b[0m \u001b[32m0.1880\u001b[0m \u001b[35m1.9269\u001b[0m 0.0042\n", + " 8 1.8936 \u001b[32m0.1920\u001b[0m \u001b[35m1.9247\u001b[0m 0.0042\n", + " 9 \u001b[36m1.8783\u001b[0m \u001b[32m0.1960\u001b[0m \u001b[35m1.9219\u001b[0m 0.0044\n", + " 10 \u001b[36m1.8737\u001b[0m \u001b[32m0.2040\u001b[0m \u001b[35m1.9192\u001b[0m 0.0044\n", + " 11 \u001b[36m1.8542\u001b[0m \u001b[32m0.2060\u001b[0m \u001b[35m1.9176\u001b[0m 0.0042\n", + " 12 \u001b[36m1.8489\u001b[0m \u001b[32m0.2100\u001b[0m \u001b[35m1.9156\u001b[0m 0.0042\n", + " 13 \u001b[36m1.8314\u001b[0m \u001b[32m0.2120\u001b[0m \u001b[35m1.9121\u001b[0m 0.0074\n", + " 14 1.8334 \u001b[32m0.2220\u001b[0m \u001b[35m1.9100\u001b[0m 0.0055\n", + " 15 \u001b[36m1.8041\u001b[0m \u001b[32m0.2240\u001b[0m \u001b[35m1.9085\u001b[0m 0.0081\n", + " 16 1.8089 0.2220 \u001b[35m1.9065\u001b[0m 0.0063\n", + " 17 1.8082 0.2200 \u001b[35m1.9043\u001b[0m 0.0081\n", + " 18 \u001b[36m1.7759\u001b[0m 0.2220 \u001b[35m1.9023\u001b[0m 0.0094\n", + " 19 \u001b[36m1.7745\u001b[0m 0.2220 \u001b[35m1.8992\u001b[0m 0.0049\n", + " 20 \u001b[36m1.7630\u001b[0m \u001b[32m0.2280\u001b[0m \u001b[35m1.8970\u001b[0m 0.0069\n", + " 21 \u001b[36m1.7411\u001b[0m \u001b[32m0.2340\u001b[0m \u001b[35m1.8946\u001b[0m 0.0068\n", + " 22 1.7732 \u001b[32m0.2360\u001b[0m \u001b[35m1.8921\u001b[0m 0.0054\n", + " 23 \u001b[36m1.7407\u001b[0m \u001b[32m0.2420\u001b[0m \u001b[35m1.8893\u001b[0m 0.0061\n", + " 24 \u001b[36m1.7259\u001b[0m 0.2420 \u001b[35m1.8857\u001b[0m 0.0046\n", + " 25 \u001b[36m1.6920\u001b[0m \u001b[32m0.2520\u001b[0m \u001b[35m1.8836\u001b[0m 0.0095\n", + " 26 1.7033 \u001b[32m0.2540\u001b[0m \u001b[35m1.8805\u001b[0m 0.0070\n", + " 27 1.7080 \u001b[32m0.2580\u001b[0m \u001b[35m1.8767\u001b[0m 0.0071\n", + " 28 1.6924 \u001b[32m0.2620\u001b[0m \u001b[35m1.8741\u001b[0m 0.0064\n", + " 29 \u001b[36m1.6882\u001b[0m 0.2620 \u001b[35m1.8703\u001b[0m 0.0049\n", + " 30 \u001b[36m1.6850\u001b[0m \u001b[32m0.2660\u001b[0m \u001b[35m1.8679\u001b[0m 0.0061\n", + " 31 \u001b[36m1.6438\u001b[0m 0.2580 \u001b[35m1.8650\u001b[0m 0.0077\n", + " 32 \u001b[36m1.6345\u001b[0m \u001b[32m0.2680\u001b[0m \u001b[35m1.8618\u001b[0m 0.0088\n", + " 33 1.6816 0.2660 \u001b[35m1.8579\u001b[0m 0.0089\n", + " 34 \u001b[36m1.6169\u001b[0m 0.2620 \u001b[35m1.8559\u001b[0m 0.0065\n", + " 35 1.6373 \u001b[32m0.2720\u001b[0m \u001b[35m1.8522\u001b[0m 0.0047\n", + " 36 \u001b[36m1.6107\u001b[0m 0.2700 \u001b[35m1.8491\u001b[0m 0.0100\n", + " 37 \u001b[36m1.6035\u001b[0m \u001b[32m0.2800\u001b[0m \u001b[35m1.8449\u001b[0m 0.0069\n", + " 38 1.6060 \u001b[32m0.2840\u001b[0m \u001b[35m1.8421\u001b[0m 0.0062\n", + " 39 \u001b[36m1.5604\u001b[0m \u001b[32m0.2960\u001b[0m \u001b[35m1.8389\u001b[0m 0.0075\n", + " 40 1.5724 \u001b[32m0.3060\u001b[0m \u001b[35m1.8354\u001b[0m 0.0093\n", + " 41 \u001b[36m1.5371\u001b[0m \u001b[32m0.3160\u001b[0m \u001b[35m1.8319\u001b[0m 0.0066\n", + " 42 \u001b[36m1.5246\u001b[0m \u001b[32m0.3240\u001b[0m \u001b[35m1.8281\u001b[0m 0.0044\n", + " 43 1.5524 0.3200 \u001b[35m1.8241\u001b[0m 0.0078\n", + " 44 1.5282 \u001b[32m0.3300\u001b[0m \u001b[35m1.8211\u001b[0m 0.0068\n", + " 45 1.5356 \u001b[32m0.3380\u001b[0m \u001b[35m1.8169\u001b[0m 0.0057\n", + " 46 \u001b[36m1.5079\u001b[0m \u001b[32m0.3440\u001b[0m \u001b[35m1.8137\u001b[0m 0.0089\n", + " 47 1.5192 \u001b[32m0.3500\u001b[0m \u001b[35m1.8090\u001b[0m 0.0072\n", + " 48 \u001b[36m1.4991\u001b[0m \u001b[32m0.3540\u001b[0m 
\u001b[35m1.8063\u001b[0m 0.0045\n", + " 49 \u001b[36m1.4949\u001b[0m 0.3460 \u001b[35m1.8036\u001b[0m 0.0085\n", + " 50 \u001b[36m1.4892\u001b[0m \u001b[32m0.3640\u001b[0m \u001b[35m1.8000\u001b[0m 0.0051\n", + " 51 1.5165 \u001b[32m0.3760\u001b[0m \u001b[35m1.7968\u001b[0m 0.0090\n", + " 52 \u001b[36m1.4367\u001b[0m 0.3740 \u001b[35m1.7931\u001b[0m 0.0079\n", + " 53 1.4473 0.3700 \u001b[35m1.7894\u001b[0m 0.0045\n", + " 54 1.4387 \u001b[32m0.3840\u001b[0m \u001b[35m1.7855\u001b[0m 0.0052\n", + " 55 \u001b[36m1.4261\u001b[0m 0.3840 \u001b[35m1.7825\u001b[0m 0.0088\n", + " 56 1.4355 \u001b[32m0.4040\u001b[0m \u001b[35m1.7768\u001b[0m 0.0075\n", + " 57 1.4270 0.3900 \u001b[35m1.7749\u001b[0m 0.0098\n", + " 58 \u001b[36m1.4029\u001b[0m 0.4000 \u001b[35m1.7714\u001b[0m 0.0087\n", + " 59 \u001b[36m1.3793\u001b[0m 0.4040 \u001b[35m1.7679\u001b[0m 0.0080\n", + " 60 \u001b[36m1.3493\u001b[0m 0.4020 \u001b[35m1.7629\u001b[0m 0.0051\n", + " 61 1.3624 \u001b[32m0.4160\u001b[0m \u001b[35m1.7597\u001b[0m 0.0083\n", + " 62 1.3970 \u001b[32m0.4180\u001b[0m \u001b[35m1.7562\u001b[0m 0.0082\n", + " 63 1.3552 \u001b[32m0.4220\u001b[0m \u001b[35m1.7516\u001b[0m 0.0057\n", + " 64 1.3745 \u001b[32m0.4240\u001b[0m \u001b[35m1.7480\u001b[0m 0.0054\n", + " 65 1.4002 \u001b[32m0.4260\u001b[0m \u001b[35m1.7448\u001b[0m 0.0086\n", + " 66 \u001b[36m1.2924\u001b[0m \u001b[32m0.4280\u001b[0m \u001b[35m1.7405\u001b[0m 0.0071\n", + " 67 1.2954 \u001b[32m0.4300\u001b[0m \u001b[35m1.7375\u001b[0m 0.0070\n", + " 68 \u001b[36m1.2785\u001b[0m \u001b[32m0.4320\u001b[0m \u001b[35m1.7319\u001b[0m 0.0070\n", + " 69 1.3192 0.4300 \u001b[35m1.7290\u001b[0m 0.0088\n", + " 70 1.3049 \u001b[32m0.4360\u001b[0m \u001b[35m1.7246\u001b[0m 0.0057\n", + " 71 \u001b[36m1.2504\u001b[0m \u001b[32m0.4420\u001b[0m \u001b[35m1.7198\u001b[0m 0.0085\n", + " 72 1.2841 0.4340 \u001b[35m1.7165\u001b[0m 0.0077\n", + " 73 \u001b[36m1.2304\u001b[0m \u001b[32m0.4460\u001b[0m \u001b[35m1.7120\u001b[0m 0.0067\n", + " 74 1.2414 \u001b[32m0.4540\u001b[0m \u001b[35m1.7070\u001b[0m 0.0085\n", + " 75 \u001b[36m1.1753\u001b[0m 0.4520 \u001b[35m1.7020\u001b[0m 0.0076\n", + " 76 1.2608 \u001b[32m0.4580\u001b[0m \u001b[35m1.6981\u001b[0m 0.0076\n", + " 77 1.2053 0.4580 \u001b[35m1.6935\u001b[0m 0.0084\n", + " 78 1.2640 \u001b[32m0.4600\u001b[0m \u001b[35m1.6910\u001b[0m 0.0056\n", + " 79 1.2251 \u001b[32m0.4700\u001b[0m \u001b[35m1.6845\u001b[0m 0.0069\n", + " 80 1.2221 \u001b[32m0.4780\u001b[0m \u001b[35m1.6801\u001b[0m 0.0044\n", + " 81 \u001b[36m1.1499\u001b[0m 0.4760 \u001b[35m1.6761\u001b[0m 0.0073\n", + " 82 1.1761 \u001b[32m0.4820\u001b[0m \u001b[35m1.6727\u001b[0m 0.0062\n", + " 83 \u001b[36m1.1286\u001b[0m \u001b[32m0.4880\u001b[0m \u001b[35m1.6673\u001b[0m 0.0082\n", + " 84 1.1338 \u001b[32m0.4920\u001b[0m \u001b[35m1.6634\u001b[0m 0.0053\n", + " 85 \u001b[36m1.1273\u001b[0m \u001b[32m0.4940\u001b[0m \u001b[35m1.6593\u001b[0m 0.0089\n", + " 86 1.1289 0.4900 \u001b[35m1.6548\u001b[0m 0.0078\n", + " 87 1.1618 \u001b[32m0.4960\u001b[0m \u001b[35m1.6512\u001b[0m 0.0064\n", + " 88 1.1306 \u001b[32m0.4980\u001b[0m \u001b[35m1.6474\u001b[0m 0.0054\n", + " 89 1.1436 \u001b[32m0.5000\u001b[0m \u001b[35m1.6438\u001b[0m 0.0091\n", + " 90 \u001b[36m1.0675\u001b[0m \u001b[32m0.5020\u001b[0m \u001b[35m1.6397\u001b[0m 0.0083\n", + " 91 1.0798 0.5000 \u001b[35m1.6360\u001b[0m 0.0072\n", + " 92 1.1148 0.4980 \u001b[35m1.6330\u001b[0m 0.0085\n", + " 93 1.0830 \u001b[32m0.5040\u001b[0m \u001b[35m1.6276\u001b[0m 0.0064\n", + " 94 1.1569 0.5020 \u001b[35m1.6246\u001b[0m 
0.0051\n", + " 95 \u001b[36m1.0338\u001b[0m 0.5020 \u001b[35m1.6197\u001b[0m 0.0079\n", + " 96 1.0800 \u001b[32m0.5100\u001b[0m \u001b[35m1.6139\u001b[0m 0.0063\n", + " 97 1.0869 0.5080 \u001b[35m1.6121\u001b[0m 0.0064\n", + " 98 1.1144 0.5100 \u001b[35m1.6074\u001b[0m 0.0053\n", + " 99 \u001b[36m1.0271\u001b[0m 0.5060 \u001b[35m1.6045\u001b[0m 0.0073\n", + " 100 1.0465 0.5100 \u001b[35m1.6020\u001b[0m 0.0062\n", + " 101 1.0348 \u001b[32m0.5200\u001b[0m \u001b[35m1.5963\u001b[0m 0.0047\n", + " 102 \u001b[36m1.0045\u001b[0m \u001b[32m0.5220\u001b[0m \u001b[35m1.5936\u001b[0m 0.0053\n", + " 103 1.0307 \u001b[32m0.5260\u001b[0m \u001b[35m1.5895\u001b[0m 0.0048\n", + " 104 \u001b[36m0.9970\u001b[0m \u001b[32m0.5300\u001b[0m \u001b[35m1.5839\u001b[0m 0.0047\n", + " 105 \u001b[36m0.9644\u001b[0m 0.5300 \u001b[35m1.5814\u001b[0m 0.0041\n", + " 106 0.9879 \u001b[32m0.5320\u001b[0m \u001b[35m1.5770\u001b[0m 0.0050\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 107 0.9986 0.5320 \u001b[35m1.5732\u001b[0m 0.0047\n", + " 108 \u001b[36m0.9234\u001b[0m 0.5320 \u001b[35m1.5692\u001b[0m 0.0070\n", + " 109 0.9704 0.5300 \u001b[35m1.5642\u001b[0m 0.0065\n", + " 110 1.0256 0.5300 \u001b[35m1.5621\u001b[0m 0.0087\n", + " 111 0.9590 0.5240 \u001b[35m1.5585\u001b[0m 0.0074\n", + " 112 1.0168 0.5300 \u001b[35m1.5565\u001b[0m 0.0057\n", + " 113 0.9994 0.5320 \u001b[35m1.5534\u001b[0m 0.0073\n", + " 114 0.9635 0.5320 \u001b[35m1.5492\u001b[0m 0.0082\n", + " 115 0.9872 \u001b[32m0.5340\u001b[0m \u001b[35m1.5452\u001b[0m 0.0046\n", + " 116 0.9749 0.5340 \u001b[35m1.5411\u001b[0m 0.0092\n", + " 117 0.9667 0.5340 \u001b[35m1.5392\u001b[0m 0.0048\n", + " 118 \u001b[36m0.8757\u001b[0m 0.5300 \u001b[35m1.5351\u001b[0m 0.0088\n", + " 119 0.9306 0.5340 \u001b[35m1.5340\u001b[0m 0.0072\n", + " 120 \u001b[36m0.8284\u001b[0m \u001b[32m0.5380\u001b[0m \u001b[35m1.5300\u001b[0m 0.0085\n", + " 121 0.8389 \u001b[32m0.5400\u001b[0m \u001b[35m1.5254\u001b[0m 0.0086\n", + " 122 0.9347 \u001b[32m0.5440\u001b[0m \u001b[35m1.5226\u001b[0m 0.0082\n", + " 123 0.8502 0.5340 \u001b[35m1.5207\u001b[0m 0.0064\n", + " 124 0.8519 \u001b[32m0.5480\u001b[0m \u001b[35m1.5163\u001b[0m 0.0085\n", + " 125 0.8536 0.5460 \u001b[35m1.5127\u001b[0m 0.0073\n", + " 126 0.8926 0.5480 \u001b[35m1.5082\u001b[0m 0.0065\n", + " 127 0.8605 0.5480 \u001b[35m1.5050\u001b[0m 0.0088\n", + " 128 0.8853 \u001b[32m0.5540\u001b[0m \u001b[35m1.5037\u001b[0m 0.0067\n", + " 129 0.8483 0.5540 \u001b[35m1.4992\u001b[0m 0.0092\n", + " 130 0.8745 0.5540 \u001b[35m1.4954\u001b[0m 0.0048\n", + " 131 \u001b[36m0.7866\u001b[0m 0.5500 \u001b[35m1.4940\u001b[0m 0.0096\n", + " 132 0.8322 0.5520 \u001b[35m1.4901\u001b[0m 0.0108\n", + " 133 0.8019 0.5520 \u001b[35m1.4864\u001b[0m 0.0062\n", + " 134 0.8829 0.5540 \u001b[35m1.4846\u001b[0m 0.0097\n", + " 135 0.8545 \u001b[32m0.5560\u001b[0m \u001b[35m1.4829\u001b[0m 0.0078\n", + " 136 0.9028 0.5560 \u001b[35m1.4802\u001b[0m 0.0071\n", + " 137 0.8797 0.5540 \u001b[35m1.4794\u001b[0m 0.0077\n", + " 138 0.7967 0.5560 \u001b[35m1.4744\u001b[0m 0.0070\n", + " 139 \u001b[36m0.7614\u001b[0m 0.5560 \u001b[35m1.4723\u001b[0m 0.0084\n", + " 140 0.8399 \u001b[32m0.5580\u001b[0m \u001b[35m1.4696\u001b[0m 0.0063\n", + " 141 0.8502 0.5580 \u001b[35m1.4671\u001b[0m 0.0077\n", + " 142 \u001b[36m0.7301\u001b[0m 0.5580 \u001b[35m1.4648\u001b[0m 0.0088\n", + " 143 0.7543 \u001b[32m0.5640\u001b[0m \u001b[35m1.4617\u001b[0m 0.0061\n", + " 144 \u001b[36m0.7023\u001b[0m 0.5620 \u001b[35m1.4580\u001b[0m 0.0093\n", + " 145 0.7329 
0.5640 \u001b[35m1.4561\u001b[0m 0.0078\n", + " 146 0.7820 0.5640 \u001b[35m1.4526\u001b[0m 0.0066\n", + " 147 0.8137 0.5640 \u001b[35m1.4521\u001b[0m 0.0065\n", + " 148 0.7950 0.5600 \u001b[35m1.4489\u001b[0m 0.0072\n", + " 149 0.7702 0.5600 \u001b[35m1.4468\u001b[0m 0.0089\n", + " 150 0.7851 0.5580 1.4469 0.0067\n", + " 151 0.7881 0.5600 \u001b[35m1.4456\u001b[0m 0.0090\n", + " 152 0.7375 0.5600 \u001b[35m1.4405\u001b[0m 0.0093\n", + " 153 0.7888 0.5580 \u001b[35m1.4401\u001b[0m 0.0069\n", + " 154 0.8128 0.5580 \u001b[35m1.4376\u001b[0m 0.0078\n", + " 155 \u001b[36m0.6960\u001b[0m 0.5600 \u001b[35m1.4345\u001b[0m 0.0081\n", + " 156 0.7073 0.5600 \u001b[35m1.4328\u001b[0m 0.0078\n", + " 157 0.7129 0.5620 \u001b[35m1.4301\u001b[0m 0.0065\n", + " 158 0.7282 0.5620 \u001b[35m1.4283\u001b[0m 0.0076\n", + " 159 0.7855 0.5580 \u001b[35m1.4263\u001b[0m 0.0090\n", + " 160 0.7444 0.5620 \u001b[35m1.4237\u001b[0m 0.0065\n", + " 161 0.7081 0.5620 \u001b[35m1.4204\u001b[0m 0.0085\n", + " 162 \u001b[36m0.6947\u001b[0m 0.5620 \u001b[35m1.4188\u001b[0m 0.0091\n", + " 163 0.7121 0.5620 \u001b[35m1.4152\u001b[0m 0.0074\n", + " 164 0.7374 0.5620 \u001b[35m1.4139\u001b[0m 0.0078\n", + " 165 \u001b[36m0.6866\u001b[0m 0.5600 \u001b[35m1.4113\u001b[0m 0.0100\n", + " 166 0.7360 0.5640 \u001b[35m1.4102\u001b[0m 0.0061\n", + " 167 \u001b[36m0.6596\u001b[0m 0.5600 \u001b[35m1.4073\u001b[0m 0.0082\n", + " 168 \u001b[36m0.6245\u001b[0m 0.5600 \u001b[35m1.4048\u001b[0m 0.0079\n", + " 169 0.6546 0.5600 \u001b[35m1.4032\u001b[0m 0.0077\n", + " 170 0.6534 0.5640 \u001b[35m1.4001\u001b[0m 0.0087\n", + " 171 0.7578 0.5640 \u001b[35m1.3992\u001b[0m 0.0074\n", + " 172 \u001b[36m0.5842\u001b[0m 0.5640 \u001b[35m1.3976\u001b[0m 0.0080\n", + " 173 0.5937 0.5600 \u001b[35m1.3950\u001b[0m 0.0067\n", + " 174 0.6404 0.5620 \u001b[35m1.3937\u001b[0m 0.0071\n", + " 175 0.6674 \u001b[32m0.5700\u001b[0m \u001b[35m1.3930\u001b[0m 0.0104\n", + " 176 0.6230 0.5660 \u001b[35m1.3901\u001b[0m 0.0051\n", + " 177 0.7345 0.5660 \u001b[35m1.3891\u001b[0m 0.0102\n", + " 178 0.6696 0.5700 \u001b[35m1.3872\u001b[0m 0.0078\n", + " 179 \u001b[36m0.5714\u001b[0m \u001b[32m0.5720\u001b[0m \u001b[35m1.3838\u001b[0m 0.0064\n", + " 180 0.5873 0.5720 \u001b[35m1.3810\u001b[0m 0.0079\n", + " 181 0.6660 \u001b[32m0.5780\u001b[0m \u001b[35m1.3787\u001b[0m 0.0073\n", + " 182 0.6348 0.5760 \u001b[35m1.3771\u001b[0m 0.0091\n", + " 183 0.6988 0.5780 \u001b[35m1.3746\u001b[0m 0.0058\n", + " 184 0.5842 0.5760 \u001b[35m1.3719\u001b[0m 0.0102\n", + " 185 0.5969 0.5760 \u001b[35m1.3698\u001b[0m 0.0076\n", + " 186 0.6382 0.5780 \u001b[35m1.3682\u001b[0m 0.0072\n", + " 187 0.5822 0.5780 1.3693 0.0086\n", + " 188 0.6263 0.5760 \u001b[35m1.3677\u001b[0m 0.0094\n", + " 189 0.6212 0.5740 \u001b[35m1.3666\u001b[0m 0.0063\n", + " 190 0.6059 0.5740 \u001b[35m1.3641\u001b[0m 0.0087\n", + " 191 0.5797 0.5760 \u001b[35m1.3610\u001b[0m 0.0062\n", + " 192 0.6375 0.5780 \u001b[35m1.3598\u001b[0m 0.0083\n", + " 193 0.6777 0.5740 1.3610 0.0059\n", + " 194 0.6331 0.5700 1.3609 0.0074\n", + " 195 0.6305 0.5720 1.3601 0.0081\n", + " 196 \u001b[36m0.5502\u001b[0m 0.5700 \u001b[35m1.3568\u001b[0m 0.0063\n", + " 197 0.7014 0.5720 \u001b[35m1.3562\u001b[0m 0.0109\n", + " 198 0.5605 0.5720 \u001b[35m1.3551\u001b[0m 0.0105\n", + " 199 0.5541 0.5760 1.3557 0.0072\n", + " 200 \u001b[36m0.5496\u001b[0m 0.5720 \u001b[35m1.3506\u001b[0m 0.0088\n" + ] + }, + { + "data": { + "text/plain": [ + "[initialized](\n", + " module_=GCN(\n", + " (conv1): GCNConv(1433, 16)\n", + " (conv2): GCNConv(16, 7)\n", 
+ " ),\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.fit(ds_train, None)" + ] + }, + { + "cell_type": "markdown", + "id": "fed65b00", + "metadata": {}, + "source": [ + "### Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ae562c9e", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ef5f2409", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.682" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy_score(ds_test.y, net.predict(ds_test))" + ] + }, + { + "cell_type": "markdown", + "id": "2fb7c99b", + "metadata": {}, + "source": [ + "In conclusion this example showed you how to use a basic data graph dataset using pytorch geometric in conjunction with skorch. The final test score is lower than the ~80% accuracy in the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) which can be explained by the reduced leakage between train and validation sets due to our splitting the data into subgraphs beforehand.\n", + "\n", + "The model is now incorporated into the sklearn world (as you could already see, you can simply use sklearn metrics to evaluate the model). Thus, tools like grid and random search are available to you and it is easily possible to include a graph neural net as a feature transformer in your next ML pipeline!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/README.md b/notebooks/README.md index 040b3966b..6361d40d4 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -6,4 +6,5 @@ * [MNIST using torchvision](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/MNIST-torchvision.ipynb) * [Transfer Learning](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb) * [Gaussian Processes](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb) -* [Huggingface fine-tuning](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Huggingface_Finetuning.ipynb) +* [PyTorch Geometric on the CORA dataset](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/CORA-geometric.ipynb) +* [Hugging Face fine-tuning](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb) diff --git a/skorch/utils.py b/skorch/utils.py index 410edb23b..5fa570926 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -40,6 +40,12 @@ except ImportError: gpytorch = None +try: + import torch_geometric + TORCH_GEOMETRIC_INSTALLED = True +except ImportError: + TORCH_GEOMETRIC_INSTALLED = False + class Ansi(Enum): BLUE = '\033[94m' @@ -59,6 +65,11 @@ def is_dataset(x): return isinstance(x, torch.utils.data.Dataset) +def is_geometric_data_type(x): + from torch_geometric.data import Data + return isinstance(x, Data) + + # pylint: disable=not-callable def 
to_tensor(X, device, accept_sparse=False): """Turn input data to torch tensor. @@ -93,6 +104,8 @@ def to_tensor(X, device, accept_sparse=False): if is_torch_data_type(X): return to_device(X, device) + if TORCH_GEOMETRIC_INSTALLED and is_geometric_data_type(X): + return to_device(X, device) if hasattr(X, 'convert_to_tensors'): # huggingface transformers BatchEncoding return X.convert_to_tensors('pt')