From 8db8a0d4d23e696c54cc96494b54a83f5ac55d69 Mon Sep 17 00:00:00 2001 From: Marian Tietz Date: Tue, 21 Jun 2022 15:05:46 +0200 Subject: [PATCH 1/3] torch geometric support + example notebook Adding support for torch geometric has so far been straight-forward since its data types also implement the `.to()` interface and use tensors in the background, they can basically be passed forward to the module which will deal with the data type accordingly. There are probably other places that might still fail, I can imagine that different data loaders (there are quite a [few][1]) might be problematic in terms of splitting X/y in batches, but this is a first step forward and we can address these issues once they arise. [1]: https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html --- notebooks/CORA-geometric.ipynb | 628 +++++++++++++++++++++++++++++++++ skorch/utils.py | 13 + 2 files changed, 641 insertions(+) create mode 100644 notebooks/CORA-geometric.ipynb diff --git a/notebooks/CORA-geometric.ipynb b/notebooks/CORA-geometric.ipynb new file mode 100644 index 000000000..7dde251bb --- /dev/null +++ b/notebooks/CORA-geometric.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "93e10022", + "metadata": {}, + "source": [ + "# torch geometric + skorch @ CORA dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f8b30bc0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Di 21. Jun 15:04:50 CEST 2022\r\n" + ] + } + ], + "source": [ + "!date" + ] + }, + { + "cell_type": "markdown", + "id": "4e144398", + "metadata": {}, + "source": [ + "This is an example for how to use skorch with [torch geometric](https://pytorch-geometric.readthedocs.io/). The code is based on the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) but modified to have a proper train/valid/test split. This example is showcasing a quite small data set that does not need to employ batching to be trained efficiently. How to do batching with skorch + torch geometric will not be handled here.\n", + "\n", + "Dependencies of this notebook besides skorch base installation:" + ] + }, + { + "cell_type": "raw", + "id": "ac7d2260", + "metadata": {}, + "source": [ + "!pip install torch_geometric==2.0.4" + ] + }, + { + "cell_type": "markdown", + "id": "ff707ef1", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "efa2b5a5", + "metadata": {}, + "outputs": [], + "source": [ + "import skorch\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "934fa438", + "metadata": {}, + "source": [ + "### Data Loading" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f1505f8f", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.datasets import Planetoid\n", + "\n", + "dataset = Planetoid(root='/tmp/Cora', name='Cora')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4979ba6f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708]),\n", + " 7)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.data, dataset.num_classes" + ] + }, + { + "cell_type": "markdown", + "id": "9cd7e104", + "metadata": {}, + "source": [ + "In order to use pytorch geometric / the cora dataset with skorch\n", + "we need to address the following things:\n", + " \n", + "1. graph convolutions cannot handle missing nodes (=> splitting node attributes but keeping edge_index intact will lead to errors)\n", + "2. cora dataset has different attributes for the different split masks (i.e. `train_mask`, `val_mask`, `test_mask`)\n", + "3. skorch expects to have (X, y) pairs for classification tasks\n", + "\n", + "To deal with (1) we will split the data into three datasets, creating three sub-graphs in the process; these complete sub-graphs can then be convolved over without errors. \n", + "We use the masks mentioned in (2) to identify the nodes and edges of the subgraphs.\n", + "\n", + "(3) will be handled by specifying our own `XYDataset` which will just have length 1 and return the dataset and the respective y values. We will therefore basically simulate a `batch_size=1` scenario." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "35aa1770", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.data import Data\n", + "\n", + "# simulating batch_size=1 by returning the whole dataset and the\n", + "# y-values. this way, the data loader can iterate over the 'batches'\n", + "# and produce X/y values for us.\n", + "class XYDataset(torch.utils.data.Dataset):\n", + " def __init__(self, data: Data, y: torch.tensor):\n", + " self.data = data\n", + " self.y = y\n", + " \n", + " def __len__(self):\n", + " return 1\n", + " \n", + " def __getitem__(self, i):\n", + " return self.data, self.y" + ] + }, + { + "cell_type": "markdown", + "id": "7cd78cd1", + "metadata": {}, + "source": [ + "### Data Splitting" + ] + }, + { + "cell_type": "markdown", + "id": "468d8d6f", + "metadata": {}, + "source": [ + "Split the graph into train, validation and test sub-graphs.\n", + "This ensures that there will be no leakage between steps when we apply graph\n", + "convolution operators on the graph.\n", + "\n", + "We use `relabel_nodes=True` to re-index the edges when the mask starts in the\n", + "middle of the node tensor. Without this flag a mask like `[0, 1, 0]` will assume\n", + "that `nodes[1]` exists but our mask forbids this: `nodes = all_nodes[[0, 1, 0]] == []`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f3d908e4", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.utils import subgraph\n", + "\n", + "data = dataset[0]\n", + "\n", + "edge_index_train, _ = subgraph(\n", + " subset=data.train_mask, \n", + " edge_index=data.edge_index, \n", + " relabel_nodes=True\n", + ")\n", + "ds_train = XYDataset(\n", + " Data(x=data.x[data.train_mask], edge_index=edge_index_train),\n", + " data.y[data.train_mask],\n", + ")\n", + "\n", + "edge_index_valid, _ = subgraph(\n", + " subset=data.val_mask, \n", + " edge_index=data.edge_index, \n", + " relabel_nodes=True\n", + ")\n", + "ds_valid = XYDataset(\n", + " Data(x=data.x[data.val_mask], edge_index=edge_index_valid),\n", + " data.y[data.val_mask],\n", + ")\n", + "\n", + "edge_index_test, _ = subgraph(\n", + " subset=data.test_mask, \n", + " edge_index=data.edge_index, \n", + " relabel_nodes=True\n", + ")\n", + "ds_test = XYDataset(\n", + " Data(x=data.x[data.test_mask], edge_index=edge_index_test),\n", + " data.y[data.test_mask],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bf59f7a9", + "metadata": {}, + "source": [ + "### Data Feeding" + ] + }, + { + "cell_type": "markdown", + "id": "4d23039d", + "metadata": {}, + "source": [ + "Our \"batch\" consists of the whole dataset so if we unpack the\n", + "batch into `(X, y)` we will have `X = Data(...)` and `y = [y_true]`.\n", + "The `DataLoader` does not modify `X` but `y` gets a new batch dimension.\n", + "This will lead to a shape mismatch as `y.shape` would then be `(1, #num_samples)`. Therefore, we need our own loader that strips the first dimension to \n", + "match the predicted `y` and the labelled `y` in length." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "83594638", + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader\n", + "\n", + "class RawLoader(DataLoader):\n", + " def __iter__(self):\n", + " it = super().__iter__()\n", + " for X, y in it:\n", + " yield X, y[0]" + ] + }, + { + "cell_type": "markdown", + "id": "d51e7140", + "metadata": {}, + "source": [ + "### Modelling" + ] + }, + { + "cell_type": "markdown", + "id": "67e6ff9e", + "metadata": {}, + "source": [ + "This is the CORA example module as seen in the [torch geometric introduction](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "10de0020", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = GCNConv(dataset.num_node_features, 16)\n", + " self.conv2 = GCNConv(16, dataset.num_classes)\n", + "\n", + " def forward(self, data): \n", + " x, edge_index = data.x, data.edge_index\n", + "\n", + " x = self.conv1(x, edge_index)\n", + " x = F.relu(x)\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.conv2(x, edge_index)\n", + "\n", + " return F.softmax(x, dim=1)" + ] + }, + { + "cell_type": "markdown", + "id": "f58e93d8", + "metadata": {}, + "source": [ + "### Fitting" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b3f5ef3c", + "metadata": {}, + "outputs": [], + "source": [ + "from skorch.helper import predefined_split\n", + "\n", + "torch.manual_seed(42)\n", + "\n", + "net = skorch.NeuralNetClassifier(\n", + " module=GCN,\n", + " lr=0.1,\n", + " optimizer__weight_decay=5e-4,\n", + " max_epochs=200,\n", + " train_split=skorch.helper.predefined_split(ds_valid),\n", + " batch_size=1,\n", + " iterator_train=RawLoader,\n", + " iterator_valid=RawLoader,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c2cf8bbd", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " epoch train_loss valid_acc valid_loss dur\n", + "------- ------------ ----------- ------------ ------\n", + " 1 \u001b[36m1.9659\u001b[0m \u001b[32m0.1620\u001b[0m \u001b[35m1.9380\u001b[0m 0.1744\n", + " 2 \u001b[36m1.9286\u001b[0m \u001b[32m0.1640\u001b[0m \u001b[35m1.9361\u001b[0m 0.0062\n", + " 3 1.9395 0.1640 \u001b[35m1.9333\u001b[0m 0.0062\n", + " 4 \u001b[36m1.9249\u001b[0m \u001b[32m0.1780\u001b[0m \u001b[35m1.9311\u001b[0m 0.0067\n", + " 5 \u001b[36m1.9008\u001b[0m \u001b[32m0.1880\u001b[0m \u001b[35m1.9292\u001b[0m 0.0066\n", + " 6 \u001b[36m1.8900\u001b[0m 0.1860 \u001b[35m1.9276\u001b[0m 0.0059\n", + " 7 \u001b[36m1.8748\u001b[0m \u001b[32m0.1900\u001b[0m \u001b[35m1.9261\u001b[0m 0.0098\n", + " 8 \u001b[36m1.8742\u001b[0m \u001b[32m0.1940\u001b[0m \u001b[35m1.9244\u001b[0m 0.0063\n", + " 9 \u001b[36m1.8715\u001b[0m 0.1900 \u001b[35m1.9227\u001b[0m 0.0093\n", + " 10 \u001b[36m1.8554\u001b[0m \u001b[32m0.2000\u001b[0m \u001b[35m1.9201\u001b[0m 0.0063\n", + " 11 \u001b[36m1.8497\u001b[0m \u001b[32m0.2100\u001b[0m \u001b[35m1.9173\u001b[0m 0.0085\n", + " 12 \u001b[36m1.8421\u001b[0m \u001b[32m0.2140\u001b[0m \u001b[35m1.9146\u001b[0m 0.0068\n", + " 13 \u001b[36m1.8155\u001b[0m 0.2140 \u001b[35m1.9122\u001b[0m 0.0074\n", + " 14 1.8206 \u001b[32m0.2200\u001b[0m \u001b[35m1.9101\u001b[0m 0.0065\n", + " 15 1.8275 0.2140 \u001b[35m1.9072\u001b[0m 0.0074\n", + " 16 1.8165 \u001b[32m0.2220\u001b[0m \u001b[35m1.9050\u001b[0m 0.0074\n", + " 17 \u001b[36m1.7948\u001b[0m 0.2220 \u001b[35m1.9018\u001b[0m 0.0073\n", + " 18 \u001b[36m1.7784\u001b[0m \u001b[32m0.2280\u001b[0m \u001b[35m1.8987\u001b[0m 0.0068\n", + " 19 \u001b[36m1.7709\u001b[0m \u001b[32m0.2300\u001b[0m \u001b[35m1.8966\u001b[0m 0.0066\n", + " 20 1.7712 \u001b[32m0.2340\u001b[0m \u001b[35m1.8941\u001b[0m 0.0071\n", + " 21 \u001b[36m1.7370\u001b[0m 0.2340 \u001b[35m1.8912\u001b[0m 0.0074\n", + " 22 1.7756 0.2320 \u001b[35m1.8886\u001b[0m 0.0069\n", + " 23 1.7419 \u001b[32m0.2380\u001b[0m \u001b[35m1.8864\u001b[0m 0.0064\n", + " 24 \u001b[36m1.7358\u001b[0m \u001b[32m0.2420\u001b[0m \u001b[35m1.8833\u001b[0m 0.0071\n", + " 25 1.7479 \u001b[32m0.2500\u001b[0m \u001b[35m1.8806\u001b[0m 0.0076\n", + " 26 \u001b[36m1.7270\u001b[0m \u001b[32m0.2520\u001b[0m \u001b[35m1.8784\u001b[0m 0.0084\n", + " 27 \u001b[36m1.6930\u001b[0m \u001b[32m0.2620\u001b[0m \u001b[35m1.8746\u001b[0m 0.0072\n", + " 28 1.6946 \u001b[32m0.2660\u001b[0m \u001b[35m1.8703\u001b[0m 0.0085\n", + " 29 \u001b[36m1.6879\u001b[0m 0.2620 \u001b[35m1.8682\u001b[0m 0.0072\n", + " 30 1.6928 0.2640 \u001b[35m1.8656\u001b[0m 0.0076\n", + " 31 1.6902 0.2640 \u001b[35m1.8617\u001b[0m 0.0074\n", + " 32 \u001b[36m1.6612\u001b[0m 0.2540 \u001b[35m1.8590\u001b[0m 0.0069\n", + " 33 \u001b[36m1.6379\u001b[0m 0.2640 \u001b[35m1.8563\u001b[0m 0.0068\n", + " 34 \u001b[36m1.6360\u001b[0m \u001b[32m0.2740\u001b[0m \u001b[35m1.8524\u001b[0m 0.0074\n", + " 35 \u001b[36m1.6216\u001b[0m \u001b[32m0.2800\u001b[0m \u001b[35m1.8489\u001b[0m 0.0066\n", + " 36 1.6563 \u001b[32m0.2840\u001b[0m \u001b[35m1.8459\u001b[0m 0.0065\n", + " 37 1.6321 \u001b[32m0.2960\u001b[0m \u001b[35m1.8430\u001b[0m 0.0067\n", + " 38 \u001b[36m1.6134\u001b[0m \u001b[32m0.2980\u001b[0m \u001b[35m1.8394\u001b[0m 0.0079\n", + " 39 \u001b[36m1.6074\u001b[0m 0.2980 \u001b[35m1.8377\u001b[0m 0.0080\n", + " 40 \u001b[36m1.5596\u001b[0m \u001b[32m0.3020\u001b[0m \u001b[35m1.8335\u001b[0m 0.0072\n", + " 41 1.5783 \u001b[32m0.3040\u001b[0m \u001b[35m1.8301\u001b[0m 0.0066\n", + " 42 \u001b[36m1.5340\u001b[0m \u001b[32m0.3060\u001b[0m \u001b[35m1.8276\u001b[0m 0.0063\n", + " 43 1.5659 \u001b[32m0.3140\u001b[0m \u001b[35m1.8241\u001b[0m 0.0069\n", + " 44 \u001b[36m1.5010\u001b[0m \u001b[32m0.3280\u001b[0m \u001b[35m1.8195\u001b[0m 0.0065\n", + " 45 1.5131 \u001b[32m0.3300\u001b[0m \u001b[35m1.8166\u001b[0m 0.0062\n", + " 46 \u001b[36m1.4980\u001b[0m \u001b[32m0.3340\u001b[0m \u001b[35m1.8127\u001b[0m 0.0062\n", + " 47 1.5246 \u001b[32m0.3400\u001b[0m \u001b[35m1.8094\u001b[0m 0.0060\n", + " 48 1.5106 \u001b[32m0.3420\u001b[0m \u001b[35m1.8066\u001b[0m 0.0063\n", + " 49 1.5101 \u001b[32m0.3500\u001b[0m \u001b[35m1.8032\u001b[0m 0.0066\n", + " 50 \u001b[36m1.4888\u001b[0m \u001b[32m0.3520\u001b[0m \u001b[35m1.7998\u001b[0m 0.0065\n", + " 51 \u001b[36m1.4591\u001b[0m \u001b[32m0.3640\u001b[0m \u001b[35m1.7969\u001b[0m 0.0069\n", + " 52 1.4993 \u001b[32m0.3660\u001b[0m \u001b[35m1.7946\u001b[0m 0.0072\n", + " 53 1.4803 \u001b[32m0.3760\u001b[0m \u001b[35m1.7909\u001b[0m 0.0065\n", + " 54 \u001b[36m1.4442\u001b[0m \u001b[32m0.3840\u001b[0m \u001b[35m1.7858\u001b[0m 0.0064\n", + " 55 1.4480 \u001b[32m0.3980\u001b[0m \u001b[35m1.7836\u001b[0m 0.0066\n", + " 56 \u001b[36m1.4204\u001b[0m \u001b[32m0.4000\u001b[0m \u001b[35m1.7784\u001b[0m 0.0066\n", + " 57 \u001b[36m1.3968\u001b[0m 0.3960 \u001b[35m1.7761\u001b[0m 0.0066\n", + " 58 \u001b[36m1.3669\u001b[0m \u001b[32m0.4100\u001b[0m \u001b[35m1.7721\u001b[0m 0.0060\n", + " 59 1.4043 \u001b[32m0.4140\u001b[0m \u001b[35m1.7684\u001b[0m 0.0065\n", + " 60 1.3997 \u001b[32m0.4180\u001b[0m \u001b[35m1.7645\u001b[0m 0.0062\n", + " 61 \u001b[36m1.3579\u001b[0m 0.4120 \u001b[35m1.7605\u001b[0m 0.0066\n", + " 62 \u001b[36m1.3489\u001b[0m \u001b[32m0.4240\u001b[0m \u001b[35m1.7577\u001b[0m 0.0074\n", + " 63 1.3591 0.4220 \u001b[35m1.7532\u001b[0m 0.0071\n", + " 64 1.4029 0.4240 \u001b[35m1.7498\u001b[0m 0.0062\n", + " 65 1.3672 \u001b[32m0.4260\u001b[0m \u001b[35m1.7439\u001b[0m 0.0061\n", + " 66 \u001b[36m1.3081\u001b[0m \u001b[32m0.4280\u001b[0m \u001b[35m1.7403\u001b[0m 0.0061\n", + " 67 1.3654 0.4260 \u001b[35m1.7355\u001b[0m 0.0062\n", + " 68 1.3166 \u001b[32m0.4340\u001b[0m \u001b[35m1.7326\u001b[0m 0.0061\n", + " 69 1.3307 \u001b[32m0.4380\u001b[0m \u001b[35m1.7291\u001b[0m 0.0061\n", + " 70 \u001b[36m1.2918\u001b[0m \u001b[32m0.4440\u001b[0m \u001b[35m1.7261\u001b[0m 0.0061\n", + " 71 \u001b[36m1.2555\u001b[0m \u001b[32m0.4460\u001b[0m \u001b[35m1.7215\u001b[0m 0.0063\n", + " 72 1.2723 0.4440 \u001b[35m1.7175\u001b[0m 0.0065\n", + " 73 1.3018 \u001b[32m0.4480\u001b[0m \u001b[35m1.7133\u001b[0m 0.0062\n", + " 74 1.2618 \u001b[32m0.4520\u001b[0m \u001b[35m1.7088\u001b[0m 0.0070\n", + " 75 1.2653 0.4460 \u001b[35m1.7054\u001b[0m 0.0075\n", + " 76 \u001b[36m1.2415\u001b[0m 0.4480 \u001b[35m1.7005\u001b[0m 0.0070\n", + " 77 1.2647 \u001b[32m0.4620\u001b[0m \u001b[35m1.6967\u001b[0m 0.0062\n", + " 78 \u001b[36m1.2124\u001b[0m \u001b[32m0.4660\u001b[0m \u001b[35m1.6913\u001b[0m 0.0063\n", + " 79 1.2216 \u001b[32m0.4720\u001b[0m \u001b[35m1.6870\u001b[0m 0.0062\n", + " 80 1.2192 \u001b[32m0.4740\u001b[0m \u001b[35m1.6834\u001b[0m 0.0062\n", + " 81 \u001b[36m1.1580\u001b[0m 0.4740 \u001b[35m1.6801\u001b[0m 0.0062\n", + " 82 \u001b[36m1.1166\u001b[0m \u001b[32m0.4760\u001b[0m \u001b[35m1.6758\u001b[0m 0.0065\n", + " 83 1.1823 \u001b[32m0.4780\u001b[0m \u001b[35m1.6715\u001b[0m 0.0072\n", + " 84 1.1905 \u001b[32m0.4900\u001b[0m \u001b[35m1.6669\u001b[0m 0.0068\n", + " 85 \u001b[36m1.1068\u001b[0m 0.4900 \u001b[35m1.6633\u001b[0m 0.0170\n", + " 86 \u001b[36m1.0972\u001b[0m \u001b[32m0.4940\u001b[0m \u001b[35m1.6573\u001b[0m 0.0110\n", + " 87 1.1711 \u001b[32m0.4980\u001b[0m \u001b[35m1.6554\u001b[0m 0.0089\n", + " 88 1.1307 0.4980 \u001b[35m1.6504\u001b[0m 0.0072\n", + " 89 \u001b[36m1.0547\u001b[0m \u001b[32m0.5020\u001b[0m \u001b[35m1.6459\u001b[0m 0.0068\n", + " 90 1.1255 0.5000 \u001b[35m1.6409\u001b[0m 0.0071\n", + " 91 1.0594 \u001b[32m0.5040\u001b[0m \u001b[35m1.6372\u001b[0m 0.0066\n", + " 92 1.1300 \u001b[32m0.5100\u001b[0m \u001b[35m1.6317\u001b[0m 0.0061\n", + " 93 1.0988 \u001b[32m0.5120\u001b[0m \u001b[35m1.6290\u001b[0m 0.0064\n", + " 94 \u001b[36m1.0492\u001b[0m 0.5100 \u001b[35m1.6251\u001b[0m 0.0063\n", + " 95 1.0698 \u001b[32m0.5160\u001b[0m \u001b[35m1.6196\u001b[0m 0.0077\n", + " 96 \u001b[36m1.0099\u001b[0m \u001b[32m0.5220\u001b[0m \u001b[35m1.6140\u001b[0m 0.0087\n", + " 97 1.0835 0.5180 \u001b[35m1.6117\u001b[0m 0.0081\n", + " 98 1.0613 0.5200 \u001b[35m1.6081\u001b[0m 0.0067\n", + " 99 1.0391 \u001b[32m0.5240\u001b[0m \u001b[35m1.6039\u001b[0m 0.0072\n", + " 100 1.1177 0.5200 \u001b[35m1.6006\u001b[0m 0.0075\n", + " 101 1.0404 0.5220 \u001b[35m1.5972\u001b[0m 0.0077\n", + " 102 1.0133 \u001b[32m0.5260\u001b[0m \u001b[35m1.5917\u001b[0m 0.0066\n", + " 103 \u001b[36m0.9971\u001b[0m \u001b[32m0.5280\u001b[0m \u001b[35m1.5888\u001b[0m 0.0073\n", + " 104 \u001b[36m0.9611\u001b[0m \u001b[32m0.5320\u001b[0m \u001b[35m1.5846\u001b[0m 0.0070\n", + " 105 \u001b[36m0.9147\u001b[0m 0.5320 \u001b[35m1.5816\u001b[0m 0.0073\n", + " 106 0.9746 0.5300 \u001b[35m1.5774\u001b[0m 0.0068\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 107 0.9672 0.5320 \u001b[35m1.5743\u001b[0m 0.0071\n", + " 108 0.9656 \u001b[32m0.5360\u001b[0m \u001b[35m1.5719\u001b[0m 0.0073\n", + " 109 1.0130 0.5340 \u001b[35m1.5680\u001b[0m 0.0072\n", + " 110 0.9821 \u001b[32m0.5380\u001b[0m \u001b[35m1.5630\u001b[0m 0.0070\n", + " 111 0.9151 0.5360 \u001b[35m1.5598\u001b[0m 0.0077\n", + " 112 \u001b[36m0.8390\u001b[0m 0.5360 \u001b[35m1.5549\u001b[0m 0.0093\n", + " 113 0.9395 \u001b[32m0.5420\u001b[0m \u001b[35m1.5505\u001b[0m 0.0074\n", + " 114 0.9234 \u001b[32m0.5440\u001b[0m \u001b[35m1.5465\u001b[0m 0.0069\n", + " 115 0.9149 0.5420 \u001b[35m1.5425\u001b[0m 0.0072\n", + " 116 0.9710 0.5400 \u001b[35m1.5399\u001b[0m 0.0071\n", + " 117 0.9081 0.5400 \u001b[35m1.5393\u001b[0m 0.0086\n", + " 118 0.9285 0.5380 \u001b[35m1.5347\u001b[0m 0.0063\n", + " 119 0.8962 0.5440 \u001b[35m1.5320\u001b[0m 0.0080\n", + " 120 0.9140 0.5420 \u001b[35m1.5285\u001b[0m 0.0066\n", + " 121 0.8729 \u001b[32m0.5460\u001b[0m \u001b[35m1.5242\u001b[0m 0.0075\n", + " 122 0.8770 0.5400 \u001b[35m1.5219\u001b[0m 0.0077\n", + " 123 0.8922 \u001b[32m0.5520\u001b[0m \u001b[35m1.5180\u001b[0m 0.0073\n", + " 124 0.8534 0.5460 \u001b[35m1.5161\u001b[0m 0.0080\n", + " 125 0.8476 0.5520 \u001b[35m1.5092\u001b[0m 0.0077\n", + " 126 0.9047 0.5520 \u001b[35m1.5058\u001b[0m 0.0076\n", + " 127 \u001b[36m0.8089\u001b[0m 0.5480 \u001b[35m1.5046\u001b[0m 0.0066\n", + " 128 0.8734 0.5500 \u001b[35m1.5015\u001b[0m 0.0074\n", + " 129 0.8506 \u001b[32m0.5540\u001b[0m \u001b[35m1.4975\u001b[0m 0.0069\n", + " 130 0.8230 0.5540 \u001b[35m1.4946\u001b[0m 0.0075\n", + " 131 \u001b[36m0.7985\u001b[0m 0.5540 \u001b[35m1.4900\u001b[0m 0.0093\n", + " 132 0.8636 \u001b[32m0.5580\u001b[0m \u001b[35m1.4891\u001b[0m 0.0073\n", + " 133 0.8238 0.5580 \u001b[35m1.4844\u001b[0m 0.0066\n", + " 134 \u001b[36m0.7810\u001b[0m \u001b[32m0.5600\u001b[0m \u001b[35m1.4820\u001b[0m 0.0061\n", + " 135 0.8755 0.5600 \u001b[35m1.4792\u001b[0m 0.0066\n", + " 136 0.8143 0.5580 \u001b[35m1.4761\u001b[0m 0.0063\n", + " 137 0.8134 0.5600 \u001b[35m1.4725\u001b[0m 0.0067\n", + " 138 0.8288 \u001b[32m0.5620\u001b[0m \u001b[35m1.4714\u001b[0m 0.0066\n", + " 139 0.8095 0.5600 \u001b[35m1.4695\u001b[0m 0.0065\n", + " 140 \u001b[36m0.7792\u001b[0m 0.5600 \u001b[35m1.4666\u001b[0m 0.0076\n", + " 141 \u001b[36m0.7747\u001b[0m \u001b[32m0.5660\u001b[0m \u001b[35m1.4646\u001b[0m 0.0071\n", + " 142 \u001b[36m0.7375\u001b[0m 0.5660 \u001b[35m1.4606\u001b[0m 0.0071\n", + " 143 0.8205 \u001b[32m0.5720\u001b[0m \u001b[35m1.4577\u001b[0m 0.0068\n", + " 144 0.7945 0.5700 \u001b[35m1.4561\u001b[0m 0.0069\n", + " 145 0.7934 0.5680 \u001b[35m1.4551\u001b[0m 0.0071\n", + " 146 \u001b[36m0.7044\u001b[0m 0.5660 \u001b[35m1.4501\u001b[0m 0.0074\n", + " 147 0.7447 0.5680 \u001b[35m1.4488\u001b[0m 0.0072\n", + " 148 0.8395 0.5660 \u001b[35m1.4486\u001b[0m 0.0070\n", + " 149 0.8077 0.5660 \u001b[35m1.4456\u001b[0m 0.0073\n", + " 150 0.7555 0.5660 \u001b[35m1.4424\u001b[0m 0.0071\n", + " 151 0.7340 0.5680 \u001b[35m1.4395\u001b[0m 0.0073\n", + " 152 0.7571 0.5680 \u001b[35m1.4362\u001b[0m 0.0078\n", + " 153 \u001b[36m0.6802\u001b[0m 0.5660 \u001b[35m1.4343\u001b[0m 0.0073\n", + " 154 0.7739 0.5660 \u001b[35m1.4316\u001b[0m 0.0072\n", + " 155 0.7414 0.5660 \u001b[35m1.4288\u001b[0m 0.0071\n", + " 156 0.7475 0.5660 \u001b[35m1.4274\u001b[0m 0.0072\n", + " 157 0.7376 0.5680 \u001b[35m1.4242\u001b[0m 0.0073\n", + " 158 0.7045 0.5660 \u001b[35m1.4220\u001b[0m 0.0073\n", + " 159 0.8076 0.5680 \u001b[35m1.4196\u001b[0m 0.0071\n", + " 160 0.7094 0.5720 \u001b[35m1.4156\u001b[0m 0.0073\n", + " 161 0.7091 0.5720 \u001b[35m1.4152\u001b[0m 0.0067\n", + " 162 0.7714 0.5680 1.4154 0.0077\n", + " 163 0.6879 \u001b[32m0.5760\u001b[0m \u001b[35m1.4114\u001b[0m 0.0067\n", + " 164 0.7088 0.5740 \u001b[35m1.4095\u001b[0m 0.0073\n", + " 165 0.7036 0.5760 \u001b[35m1.4072\u001b[0m 0.0070\n", + " 166 \u001b[36m0.6331\u001b[0m 0.5740 \u001b[35m1.4056\u001b[0m 0.0064\n", + " 167 0.6961 0.5740 \u001b[35m1.4026\u001b[0m 0.0068\n", + " 168 0.6787 0.5740 \u001b[35m1.4020\u001b[0m 0.0069\n", + " 169 0.6620 0.5760 \u001b[35m1.3986\u001b[0m 0.0072\n", + " 170 0.6767 0.5760 \u001b[35m1.3980\u001b[0m 0.0072\n", + " 171 0.6663 0.5720 \u001b[35m1.3974\u001b[0m 0.0076\n", + " 172 \u001b[36m0.5900\u001b[0m 0.5720 \u001b[35m1.3955\u001b[0m 0.0081\n", + " 173 0.6731 0.5740 \u001b[35m1.3931\u001b[0m 0.0074\n", + " 174 0.7121 0.5720 \u001b[35m1.3913\u001b[0m 0.0067\n", + " 175 0.6490 0.5740 \u001b[35m1.3882\u001b[0m 0.0062\n", + " 176 0.6253 0.5740 \u001b[35m1.3849\u001b[0m 0.0060\n", + " 177 0.6121 0.5740 \u001b[35m1.3835\u001b[0m 0.0062\n", + " 178 0.6331 0.5740 \u001b[35m1.3813\u001b[0m 0.0059\n", + " 179 0.6628 0.5720 \u001b[35m1.3810\u001b[0m 0.0070\n", + " 180 0.6165 0.5740 \u001b[35m1.3790\u001b[0m 0.0072\n", + " 181 0.5927 0.5720 \u001b[35m1.3769\u001b[0m 0.0074\n", + " 182 0.5952 0.5720 \u001b[35m1.3745\u001b[0m 0.0080\n", + " 183 \u001b[36m0.5650\u001b[0m 0.5740 \u001b[35m1.3740\u001b[0m 0.0074\n", + " 184 0.7050 0.5700 \u001b[35m1.3716\u001b[0m 0.0066\n", + " 185 0.6149 0.5720 \u001b[35m1.3708\u001b[0m 0.0060\n", + " 186 0.5872 \u001b[32m0.5780\u001b[0m \u001b[35m1.3678\u001b[0m 0.0059\n", + " 187 0.6876 0.5700 \u001b[35m1.3653\u001b[0m 0.0061\n", + " 188 0.6151 0.5740 \u001b[35m1.3649\u001b[0m 0.0059\n", + " 189 0.6035 0.5740 \u001b[35m1.3641\u001b[0m 0.0060\n", + " 190 0.6302 0.5720 \u001b[35m1.3637\u001b[0m 0.0071\n", + " 191 0.6218 \u001b[32m0.5800\u001b[0m \u001b[35m1.3597\u001b[0m 0.0068\n", + " 192 0.5813 0.5720 \u001b[35m1.3595\u001b[0m 0.0078\n", + " 193 0.6559 0.5760 \u001b[35m1.3581\u001b[0m 0.0067\n", + " 194 0.5938 0.5760 \u001b[35m1.3561\u001b[0m 0.0062\n", + " 195 0.6327 0.5780 \u001b[35m1.3544\u001b[0m 0.0069\n", + " 196 0.6192 0.5760 \u001b[35m1.3531\u001b[0m 0.0062\n", + " 197 0.6347 0.5760 \u001b[35m1.3515\u001b[0m 0.0065\n", + " 198 0.6515 0.5760 \u001b[35m1.3496\u001b[0m 0.0063\n", + " 199 0.6065 0.5780 \u001b[35m1.3485\u001b[0m 0.0064\n", + " 200 0.6256 0.5780 \u001b[35m1.3467\u001b[0m 0.0067\n" + ] + }, + { + "data": { + "text/plain": [ + "[initialized](\n", + " module_=GCN(\n", + " (conv1): GCNConv(1433, 16)\n", + " (conv2): GCNConv(16, 7)\n", + " ),\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.fit(ds_train, None)" + ] + }, + { + "cell_type": "markdown", + "id": "fed65b00", + "metadata": {}, + "source": [ + "### Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ae562c9e", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import accuracy_score" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ef5f2409", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.683" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy_score(ds_test.y, net.predict(ds_test))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/skorch/utils.py b/skorch/utils.py index a2c1d4af9..688e5a638 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -37,6 +37,12 @@ except ImportError: gpytorch = None +try: + import torch_geometric + TORCH_GEOMETRIC_INSTALLED = True +except ImportError: + TORCH_GEOMETRIC_INSTALLED = False + class Ansi(Enum): BLUE = '\033[94m' @@ -56,6 +62,11 @@ def is_dataset(x): return isinstance(x, torch.utils.data.Dataset) +def is_geometric_data_type(x): + from torch_geometric.data import Data + return isinstance(x, Data) + + # pylint: disable=not-callable def to_tensor(X, device, accept_sparse=False): """Turn input data to torch tensor. @@ -90,6 +101,8 @@ def to_tensor(X, device, accept_sparse=False): if is_torch_data_type(X): return to_device(X, device) + if TORCH_GEOMETRIC_INSTALLED and is_geometric_data_type(X): + return to_device(X, device) if hasattr(X, 'convert_to_tensors'): # huggingface transformers BatchEncoding return X.convert_to_tensors('pt') From 377040750e3d23ccb11d3a4aa6ae575faf1bd5a4 Mon Sep 17 00:00:00 2001 From: Marian Tietz Date: Fri, 22 Jul 2022 18:52:58 +0200 Subject: [PATCH 2/3] Address review comments --- notebooks/CORA-geometric.ipynb | 443 +++++++++++++++++---------------- 1 file changed, 232 insertions(+), 211 deletions(-) diff --git a/notebooks/CORA-geometric.ipynb b/notebooks/CORA-geometric.ipynb index 7dde251bb..cfe3699c1 100644 --- a/notebooks/CORA-geometric.ipynb +++ b/notebooks/CORA-geometric.ipynb @@ -18,7 +18,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Di 21. Jun 15:04:50 CEST 2022\r\n" + "Fr 22. Jul 18:52:31 CEST 2022\r\n" ] } ], @@ -31,19 +31,27 @@ "id": "4e144398", "metadata": {}, "source": [ - "This is an example for how to use skorch with [torch geometric](https://pytorch-geometric.readthedocs.io/). The code is based on the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) but modified to have a proper train/valid/test split. This example is showcasing a quite small data set that does not need to employ batching to be trained efficiently. How to do batching with skorch + torch geometric will not be handled here.\n", + "This is an example for how to use skorch with [torch geometric](https://pytorch-geometric.readthedocs.io/). The code is based on the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) but modified to have a proper train/valid/test split. This example is showcasing a quite small data set that does not need to employ batching to be trained efficiently. How to do batching with skorch + torch geometric will not be handled here since it is non-trivial and quite dataset specific - if you need this and are stuck, feel free to open [an issue](https://github.com/skorch-dev/skorch/issues) so that we can support you the best we can.\n", "\n", "Dependencies of this notebook besides skorch base installation:" ] }, { "cell_type": "raw", - "id": "ac7d2260", + "id": "3748b2a0", "metadata": {}, "source": [ "!pip install torch_geometric==2.0.4" ] }, + { + "cell_type": "markdown", + "id": "7ce31902", + "metadata": {}, + "source": [ + "It is recommended to install the dependencies [as documented by pytorch geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)." + ] + }, { "cell_type": "markdown", "id": "ff707ef1", @@ -162,11 +170,12 @@ "source": [ "Split the graph into train, validation and test sub-graphs.\n", "This ensures that there will be no leakage between steps when we apply graph\n", - "convolution operators on the graph.\n", + "convolution operators on the graph since each split has its own sub-graph.\n", "\n", - "We use `relabel_nodes=True` to re-index the edges when the mask starts in the\n", - "middle of the node tensor. Without this flag a mask like `[0, 1, 0]` will assume\n", - "that `nodes[1]` exists but our mask forbids this: `nodes = all_nodes[[0, 1, 0]] == []`." + "We use `relabel_nodes=True` to make the node indices in the edge tensor \n", + "zero-based for each sub-graph. If we would not do this the node subsets\n", + "(now zero-based after applying the mask) would not match the indices in the\n", + "edge tensor." ] }, { @@ -228,7 +237,9 @@ "batch into `(X, y)` we will have `X = Data(...)` and `y = [y_true]`.\n", "The `DataLoader` does not modify `X` but `y` gets a new batch dimension.\n", "This will lead to a shape mismatch as `y.shape` would then be `(1, #num_samples)`. Therefore, we need our own loader that strips the first dimension to \n", - "match the predicted `y` and the labelled `y` in length." + "match the predicted `y` and the labelled `y` in length.\n", + "\n", + "Note: It is possible to avoid this by stripping this dimension by overriding `get_loss` in the `NeuralNet` class. For brevity we won't do this in this example. It is possible to use [one of the many `DataLoader` classes](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html) provided by torch geometric using the approach outlined below (just base the `RawDataloader` on one of the other classes) - chances are, though, that if you are doing this you need to deal with batching anyway which is a topic that is not handled here since it is not trivial." ] }, { @@ -336,212 +347,212 @@ "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", - " 1 \u001b[36m1.9659\u001b[0m \u001b[32m0.1620\u001b[0m \u001b[35m1.9380\u001b[0m 0.1744\n", - " 2 \u001b[36m1.9286\u001b[0m \u001b[32m0.1640\u001b[0m \u001b[35m1.9361\u001b[0m 0.0062\n", - " 3 1.9395 0.1640 \u001b[35m1.9333\u001b[0m 0.0062\n", - " 4 \u001b[36m1.9249\u001b[0m \u001b[32m0.1780\u001b[0m \u001b[35m1.9311\u001b[0m 0.0067\n", - " 5 \u001b[36m1.9008\u001b[0m \u001b[32m0.1880\u001b[0m \u001b[35m1.9292\u001b[0m 0.0066\n", - " 6 \u001b[36m1.8900\u001b[0m 0.1860 \u001b[35m1.9276\u001b[0m 0.0059\n", - " 7 \u001b[36m1.8748\u001b[0m \u001b[32m0.1900\u001b[0m \u001b[35m1.9261\u001b[0m 0.0098\n", - " 8 \u001b[36m1.8742\u001b[0m \u001b[32m0.1940\u001b[0m \u001b[35m1.9244\u001b[0m 0.0063\n", - " 9 \u001b[36m1.8715\u001b[0m 0.1900 \u001b[35m1.9227\u001b[0m 0.0093\n", - " 10 \u001b[36m1.8554\u001b[0m \u001b[32m0.2000\u001b[0m \u001b[35m1.9201\u001b[0m 0.0063\n", - " 11 \u001b[36m1.8497\u001b[0m \u001b[32m0.2100\u001b[0m \u001b[35m1.9173\u001b[0m 0.0085\n", - " 12 \u001b[36m1.8421\u001b[0m \u001b[32m0.2140\u001b[0m \u001b[35m1.9146\u001b[0m 0.0068\n", - " 13 \u001b[36m1.8155\u001b[0m 0.2140 \u001b[35m1.9122\u001b[0m 0.0074\n", - " 14 1.8206 \u001b[32m0.2200\u001b[0m \u001b[35m1.9101\u001b[0m 0.0065\n", - " 15 1.8275 0.2140 \u001b[35m1.9072\u001b[0m 0.0074\n", - " 16 1.8165 \u001b[32m0.2220\u001b[0m \u001b[35m1.9050\u001b[0m 0.0074\n", - " 17 \u001b[36m1.7948\u001b[0m 0.2220 \u001b[35m1.9018\u001b[0m 0.0073\n", - " 18 \u001b[36m1.7784\u001b[0m \u001b[32m0.2280\u001b[0m \u001b[35m1.8987\u001b[0m 0.0068\n", - " 19 \u001b[36m1.7709\u001b[0m \u001b[32m0.2300\u001b[0m \u001b[35m1.8966\u001b[0m 0.0066\n", - " 20 1.7712 \u001b[32m0.2340\u001b[0m \u001b[35m1.8941\u001b[0m 0.0071\n", - " 21 \u001b[36m1.7370\u001b[0m 0.2340 \u001b[35m1.8912\u001b[0m 0.0074\n", - " 22 1.7756 0.2320 \u001b[35m1.8886\u001b[0m 0.0069\n", - " 23 1.7419 \u001b[32m0.2380\u001b[0m \u001b[35m1.8864\u001b[0m 0.0064\n", - " 24 \u001b[36m1.7358\u001b[0m \u001b[32m0.2420\u001b[0m \u001b[35m1.8833\u001b[0m 0.0071\n", - " 25 1.7479 \u001b[32m0.2500\u001b[0m \u001b[35m1.8806\u001b[0m 0.0076\n", - " 26 \u001b[36m1.7270\u001b[0m \u001b[32m0.2520\u001b[0m \u001b[35m1.8784\u001b[0m 0.0084\n", - " 27 \u001b[36m1.6930\u001b[0m \u001b[32m0.2620\u001b[0m \u001b[35m1.8746\u001b[0m 0.0072\n", - " 28 1.6946 \u001b[32m0.2660\u001b[0m \u001b[35m1.8703\u001b[0m 0.0085\n", - " 29 \u001b[36m1.6879\u001b[0m 0.2620 \u001b[35m1.8682\u001b[0m 0.0072\n", - " 30 1.6928 0.2640 \u001b[35m1.8656\u001b[0m 0.0076\n", - " 31 1.6902 0.2640 \u001b[35m1.8617\u001b[0m 0.0074\n", - " 32 \u001b[36m1.6612\u001b[0m 0.2540 \u001b[35m1.8590\u001b[0m 0.0069\n", - " 33 \u001b[36m1.6379\u001b[0m 0.2640 \u001b[35m1.8563\u001b[0m 0.0068\n", - " 34 \u001b[36m1.6360\u001b[0m \u001b[32m0.2740\u001b[0m \u001b[35m1.8524\u001b[0m 0.0074\n", - " 35 \u001b[36m1.6216\u001b[0m \u001b[32m0.2800\u001b[0m \u001b[35m1.8489\u001b[0m 0.0066\n", - " 36 1.6563 \u001b[32m0.2840\u001b[0m \u001b[35m1.8459\u001b[0m 0.0065\n", - " 37 1.6321 \u001b[32m0.2960\u001b[0m \u001b[35m1.8430\u001b[0m 0.0067\n", - " 38 \u001b[36m1.6134\u001b[0m \u001b[32m0.2980\u001b[0m \u001b[35m1.8394\u001b[0m 0.0079\n", - " 39 \u001b[36m1.6074\u001b[0m 0.2980 \u001b[35m1.8377\u001b[0m 0.0080\n", - " 40 \u001b[36m1.5596\u001b[0m \u001b[32m0.3020\u001b[0m \u001b[35m1.8335\u001b[0m 0.0072\n", - " 41 1.5783 \u001b[32m0.3040\u001b[0m \u001b[35m1.8301\u001b[0m 0.0066\n", - " 42 \u001b[36m1.5340\u001b[0m \u001b[32m0.3060\u001b[0m \u001b[35m1.8276\u001b[0m 0.0063\n", - " 43 1.5659 \u001b[32m0.3140\u001b[0m \u001b[35m1.8241\u001b[0m 0.0069\n", - " 44 \u001b[36m1.5010\u001b[0m \u001b[32m0.3280\u001b[0m \u001b[35m1.8195\u001b[0m 0.0065\n", - " 45 1.5131 \u001b[32m0.3300\u001b[0m \u001b[35m1.8166\u001b[0m 0.0062\n", - " 46 \u001b[36m1.4980\u001b[0m \u001b[32m0.3340\u001b[0m \u001b[35m1.8127\u001b[0m 0.0062\n", - " 47 1.5246 \u001b[32m0.3400\u001b[0m \u001b[35m1.8094\u001b[0m 0.0060\n", - " 48 1.5106 \u001b[32m0.3420\u001b[0m \u001b[35m1.8066\u001b[0m 0.0063\n", - " 49 1.5101 \u001b[32m0.3500\u001b[0m \u001b[35m1.8032\u001b[0m 0.0066\n", - " 50 \u001b[36m1.4888\u001b[0m \u001b[32m0.3520\u001b[0m \u001b[35m1.7998\u001b[0m 0.0065\n", - " 51 \u001b[36m1.4591\u001b[0m \u001b[32m0.3640\u001b[0m \u001b[35m1.7969\u001b[0m 0.0069\n", - " 52 1.4993 \u001b[32m0.3660\u001b[0m \u001b[35m1.7946\u001b[0m 0.0072\n", - " 53 1.4803 \u001b[32m0.3760\u001b[0m \u001b[35m1.7909\u001b[0m 0.0065\n", - " 54 \u001b[36m1.4442\u001b[0m \u001b[32m0.3840\u001b[0m \u001b[35m1.7858\u001b[0m 0.0064\n", - " 55 1.4480 \u001b[32m0.3980\u001b[0m \u001b[35m1.7836\u001b[0m 0.0066\n", - " 56 \u001b[36m1.4204\u001b[0m \u001b[32m0.4000\u001b[0m \u001b[35m1.7784\u001b[0m 0.0066\n", - " 57 \u001b[36m1.3968\u001b[0m 0.3960 \u001b[35m1.7761\u001b[0m 0.0066\n", - " 58 \u001b[36m1.3669\u001b[0m \u001b[32m0.4100\u001b[0m \u001b[35m1.7721\u001b[0m 0.0060\n", - " 59 1.4043 \u001b[32m0.4140\u001b[0m \u001b[35m1.7684\u001b[0m 0.0065\n", - " 60 1.3997 \u001b[32m0.4180\u001b[0m \u001b[35m1.7645\u001b[0m 0.0062\n", - " 61 \u001b[36m1.3579\u001b[0m 0.4120 \u001b[35m1.7605\u001b[0m 0.0066\n", - " 62 \u001b[36m1.3489\u001b[0m \u001b[32m0.4240\u001b[0m \u001b[35m1.7577\u001b[0m 0.0074\n", - " 63 1.3591 0.4220 \u001b[35m1.7532\u001b[0m 0.0071\n", - " 64 1.4029 0.4240 \u001b[35m1.7498\u001b[0m 0.0062\n", - " 65 1.3672 \u001b[32m0.4260\u001b[0m \u001b[35m1.7439\u001b[0m 0.0061\n", - " 66 \u001b[36m1.3081\u001b[0m \u001b[32m0.4280\u001b[0m \u001b[35m1.7403\u001b[0m 0.0061\n", - " 67 1.3654 0.4260 \u001b[35m1.7355\u001b[0m 0.0062\n", - " 68 1.3166 \u001b[32m0.4340\u001b[0m \u001b[35m1.7326\u001b[0m 0.0061\n", - " 69 1.3307 \u001b[32m0.4380\u001b[0m \u001b[35m1.7291\u001b[0m 0.0061\n", - " 70 \u001b[36m1.2918\u001b[0m \u001b[32m0.4440\u001b[0m \u001b[35m1.7261\u001b[0m 0.0061\n", - " 71 \u001b[36m1.2555\u001b[0m \u001b[32m0.4460\u001b[0m \u001b[35m1.7215\u001b[0m 0.0063\n", - " 72 1.2723 0.4440 \u001b[35m1.7175\u001b[0m 0.0065\n", - " 73 1.3018 \u001b[32m0.4480\u001b[0m \u001b[35m1.7133\u001b[0m 0.0062\n", - " 74 1.2618 \u001b[32m0.4520\u001b[0m \u001b[35m1.7088\u001b[0m 0.0070\n", - " 75 1.2653 0.4460 \u001b[35m1.7054\u001b[0m 0.0075\n", - " 76 \u001b[36m1.2415\u001b[0m 0.4480 \u001b[35m1.7005\u001b[0m 0.0070\n", - " 77 1.2647 \u001b[32m0.4620\u001b[0m \u001b[35m1.6967\u001b[0m 0.0062\n", - " 78 \u001b[36m1.2124\u001b[0m \u001b[32m0.4660\u001b[0m \u001b[35m1.6913\u001b[0m 0.0063\n", - " 79 1.2216 \u001b[32m0.4720\u001b[0m \u001b[35m1.6870\u001b[0m 0.0062\n", - " 80 1.2192 \u001b[32m0.4740\u001b[0m \u001b[35m1.6834\u001b[0m 0.0062\n", - " 81 \u001b[36m1.1580\u001b[0m 0.4740 \u001b[35m1.6801\u001b[0m 0.0062\n", - " 82 \u001b[36m1.1166\u001b[0m \u001b[32m0.4760\u001b[0m \u001b[35m1.6758\u001b[0m 0.0065\n", - " 83 1.1823 \u001b[32m0.4780\u001b[0m \u001b[35m1.6715\u001b[0m 0.0072\n", - " 84 1.1905 \u001b[32m0.4900\u001b[0m \u001b[35m1.6669\u001b[0m 0.0068\n", - " 85 \u001b[36m1.1068\u001b[0m 0.4900 \u001b[35m1.6633\u001b[0m 0.0170\n", - " 86 \u001b[36m1.0972\u001b[0m \u001b[32m0.4940\u001b[0m \u001b[35m1.6573\u001b[0m 0.0110\n", - " 87 1.1711 \u001b[32m0.4980\u001b[0m \u001b[35m1.6554\u001b[0m 0.0089\n", - " 88 1.1307 0.4980 \u001b[35m1.6504\u001b[0m 0.0072\n", - " 89 \u001b[36m1.0547\u001b[0m \u001b[32m0.5020\u001b[0m \u001b[35m1.6459\u001b[0m 0.0068\n", - " 90 1.1255 0.5000 \u001b[35m1.6409\u001b[0m 0.0071\n", - " 91 1.0594 \u001b[32m0.5040\u001b[0m \u001b[35m1.6372\u001b[0m 0.0066\n", - " 92 1.1300 \u001b[32m0.5100\u001b[0m \u001b[35m1.6317\u001b[0m 0.0061\n", - " 93 1.0988 \u001b[32m0.5120\u001b[0m \u001b[35m1.6290\u001b[0m 0.0064\n", - " 94 \u001b[36m1.0492\u001b[0m 0.5100 \u001b[35m1.6251\u001b[0m 0.0063\n", - " 95 1.0698 \u001b[32m0.5160\u001b[0m \u001b[35m1.6196\u001b[0m 0.0077\n", - " 96 \u001b[36m1.0099\u001b[0m \u001b[32m0.5220\u001b[0m \u001b[35m1.6140\u001b[0m 0.0087\n", - " 97 1.0835 0.5180 \u001b[35m1.6117\u001b[0m 0.0081\n", - " 98 1.0613 0.5200 \u001b[35m1.6081\u001b[0m 0.0067\n", - " 99 1.0391 \u001b[32m0.5240\u001b[0m \u001b[35m1.6039\u001b[0m 0.0072\n", - " 100 1.1177 0.5200 \u001b[35m1.6006\u001b[0m 0.0075\n", - " 101 1.0404 0.5220 \u001b[35m1.5972\u001b[0m 0.0077\n", - " 102 1.0133 \u001b[32m0.5260\u001b[0m \u001b[35m1.5917\u001b[0m 0.0066\n", - " 103 \u001b[36m0.9971\u001b[0m \u001b[32m0.5280\u001b[0m \u001b[35m1.5888\u001b[0m 0.0073\n", - " 104 \u001b[36m0.9611\u001b[0m \u001b[32m0.5320\u001b[0m \u001b[35m1.5846\u001b[0m 0.0070\n", - " 105 \u001b[36m0.9147\u001b[0m 0.5320 \u001b[35m1.5816\u001b[0m 0.0073\n", - " 106 0.9746 0.5300 \u001b[35m1.5774\u001b[0m 0.0068\n" + " 1 \u001b[36m1.9724\u001b[0m \u001b[32m0.1680\u001b[0m \u001b[35m1.9398\u001b[0m 0.0074\n", + " 2 \u001b[36m1.9625\u001b[0m \u001b[32m0.1740\u001b[0m \u001b[35m1.9376\u001b[0m 0.0052\n", + " 3 \u001b[36m1.9327\u001b[0m 0.1720 \u001b[35m1.9342\u001b[0m 0.0043\n", + " 4 \u001b[36m1.9321\u001b[0m \u001b[32m0.1760\u001b[0m \u001b[35m1.9324\u001b[0m 0.0045\n", + " 5 \u001b[36m1.9142\u001b[0m \u001b[32m0.1800\u001b[0m \u001b[35m1.9307\u001b[0m 0.0037\n", + " 6 \u001b[36m1.8923\u001b[0m 0.1800 \u001b[35m1.9290\u001b[0m 0.0036\n", + " 7 \u001b[36m1.8848\u001b[0m \u001b[32m0.1880\u001b[0m \u001b[35m1.9269\u001b[0m 0.0042\n", + " 8 1.8936 \u001b[32m0.1920\u001b[0m \u001b[35m1.9247\u001b[0m 0.0042\n", + " 9 \u001b[36m1.8783\u001b[0m \u001b[32m0.1960\u001b[0m \u001b[35m1.9219\u001b[0m 0.0044\n", + " 10 \u001b[36m1.8737\u001b[0m \u001b[32m0.2040\u001b[0m \u001b[35m1.9192\u001b[0m 0.0044\n", + " 11 \u001b[36m1.8542\u001b[0m \u001b[32m0.2060\u001b[0m \u001b[35m1.9176\u001b[0m 0.0042\n", + " 12 \u001b[36m1.8489\u001b[0m \u001b[32m0.2100\u001b[0m \u001b[35m1.9156\u001b[0m 0.0042\n", + " 13 \u001b[36m1.8314\u001b[0m \u001b[32m0.2120\u001b[0m \u001b[35m1.9121\u001b[0m 0.0074\n", + " 14 1.8334 \u001b[32m0.2220\u001b[0m \u001b[35m1.9100\u001b[0m 0.0055\n", + " 15 \u001b[36m1.8041\u001b[0m \u001b[32m0.2240\u001b[0m \u001b[35m1.9085\u001b[0m 0.0081\n", + " 16 1.8089 0.2220 \u001b[35m1.9065\u001b[0m 0.0063\n", + " 17 1.8082 0.2200 \u001b[35m1.9043\u001b[0m 0.0081\n", + " 18 \u001b[36m1.7759\u001b[0m 0.2220 \u001b[35m1.9023\u001b[0m 0.0094\n", + " 19 \u001b[36m1.7745\u001b[0m 0.2220 \u001b[35m1.8992\u001b[0m 0.0049\n", + " 20 \u001b[36m1.7630\u001b[0m \u001b[32m0.2280\u001b[0m \u001b[35m1.8970\u001b[0m 0.0069\n", + " 21 \u001b[36m1.7411\u001b[0m \u001b[32m0.2340\u001b[0m \u001b[35m1.8946\u001b[0m 0.0068\n", + " 22 1.7732 \u001b[32m0.2360\u001b[0m \u001b[35m1.8921\u001b[0m 0.0054\n", + " 23 \u001b[36m1.7407\u001b[0m \u001b[32m0.2420\u001b[0m \u001b[35m1.8893\u001b[0m 0.0061\n", + " 24 \u001b[36m1.7259\u001b[0m 0.2420 \u001b[35m1.8857\u001b[0m 0.0046\n", + " 25 \u001b[36m1.6920\u001b[0m \u001b[32m0.2520\u001b[0m \u001b[35m1.8836\u001b[0m 0.0095\n", + " 26 1.7033 \u001b[32m0.2540\u001b[0m \u001b[35m1.8805\u001b[0m 0.0070\n", + " 27 1.7080 \u001b[32m0.2580\u001b[0m \u001b[35m1.8767\u001b[0m 0.0071\n", + " 28 1.6924 \u001b[32m0.2620\u001b[0m \u001b[35m1.8741\u001b[0m 0.0064\n", + " 29 \u001b[36m1.6882\u001b[0m 0.2620 \u001b[35m1.8703\u001b[0m 0.0049\n", + " 30 \u001b[36m1.6850\u001b[0m \u001b[32m0.2660\u001b[0m \u001b[35m1.8679\u001b[0m 0.0061\n", + " 31 \u001b[36m1.6438\u001b[0m 0.2580 \u001b[35m1.8650\u001b[0m 0.0077\n", + " 32 \u001b[36m1.6345\u001b[0m \u001b[32m0.2680\u001b[0m \u001b[35m1.8618\u001b[0m 0.0088\n", + " 33 1.6816 0.2660 \u001b[35m1.8579\u001b[0m 0.0089\n", + " 34 \u001b[36m1.6169\u001b[0m 0.2620 \u001b[35m1.8559\u001b[0m 0.0065\n", + " 35 1.6373 \u001b[32m0.2720\u001b[0m \u001b[35m1.8522\u001b[0m 0.0047\n", + " 36 \u001b[36m1.6107\u001b[0m 0.2700 \u001b[35m1.8491\u001b[0m 0.0100\n", + " 37 \u001b[36m1.6035\u001b[0m \u001b[32m0.2800\u001b[0m \u001b[35m1.8449\u001b[0m 0.0069\n", + " 38 1.6060 \u001b[32m0.2840\u001b[0m \u001b[35m1.8421\u001b[0m 0.0062\n", + " 39 \u001b[36m1.5604\u001b[0m \u001b[32m0.2960\u001b[0m \u001b[35m1.8389\u001b[0m 0.0075\n", + " 40 1.5724 \u001b[32m0.3060\u001b[0m \u001b[35m1.8354\u001b[0m 0.0093\n", + " 41 \u001b[36m1.5371\u001b[0m \u001b[32m0.3160\u001b[0m \u001b[35m1.8319\u001b[0m 0.0066\n", + " 42 \u001b[36m1.5246\u001b[0m \u001b[32m0.3240\u001b[0m \u001b[35m1.8281\u001b[0m 0.0044\n", + " 43 1.5524 0.3200 \u001b[35m1.8241\u001b[0m 0.0078\n", + " 44 1.5282 \u001b[32m0.3300\u001b[0m \u001b[35m1.8211\u001b[0m 0.0068\n", + " 45 1.5356 \u001b[32m0.3380\u001b[0m \u001b[35m1.8169\u001b[0m 0.0057\n", + " 46 \u001b[36m1.5079\u001b[0m \u001b[32m0.3440\u001b[0m \u001b[35m1.8137\u001b[0m 0.0089\n", + " 47 1.5192 \u001b[32m0.3500\u001b[0m \u001b[35m1.8090\u001b[0m 0.0072\n", + " 48 \u001b[36m1.4991\u001b[0m \u001b[32m0.3540\u001b[0m \u001b[35m1.8063\u001b[0m 0.0045\n", + " 49 \u001b[36m1.4949\u001b[0m 0.3460 \u001b[35m1.8036\u001b[0m 0.0085\n", + " 50 \u001b[36m1.4892\u001b[0m \u001b[32m0.3640\u001b[0m \u001b[35m1.8000\u001b[0m 0.0051\n", + " 51 1.5165 \u001b[32m0.3760\u001b[0m \u001b[35m1.7968\u001b[0m 0.0090\n", + " 52 \u001b[36m1.4367\u001b[0m 0.3740 \u001b[35m1.7931\u001b[0m 0.0079\n", + " 53 1.4473 0.3700 \u001b[35m1.7894\u001b[0m 0.0045\n", + " 54 1.4387 \u001b[32m0.3840\u001b[0m \u001b[35m1.7855\u001b[0m 0.0052\n", + " 55 \u001b[36m1.4261\u001b[0m 0.3840 \u001b[35m1.7825\u001b[0m 0.0088\n", + " 56 1.4355 \u001b[32m0.4040\u001b[0m \u001b[35m1.7768\u001b[0m 0.0075\n", + " 57 1.4270 0.3900 \u001b[35m1.7749\u001b[0m 0.0098\n", + " 58 \u001b[36m1.4029\u001b[0m 0.4000 \u001b[35m1.7714\u001b[0m 0.0087\n", + " 59 \u001b[36m1.3793\u001b[0m 0.4040 \u001b[35m1.7679\u001b[0m 0.0080\n", + " 60 \u001b[36m1.3493\u001b[0m 0.4020 \u001b[35m1.7629\u001b[0m 0.0051\n", + " 61 1.3624 \u001b[32m0.4160\u001b[0m \u001b[35m1.7597\u001b[0m 0.0083\n", + " 62 1.3970 \u001b[32m0.4180\u001b[0m \u001b[35m1.7562\u001b[0m 0.0082\n", + " 63 1.3552 \u001b[32m0.4220\u001b[0m \u001b[35m1.7516\u001b[0m 0.0057\n", + " 64 1.3745 \u001b[32m0.4240\u001b[0m \u001b[35m1.7480\u001b[0m 0.0054\n", + " 65 1.4002 \u001b[32m0.4260\u001b[0m \u001b[35m1.7448\u001b[0m 0.0086\n", + " 66 \u001b[36m1.2924\u001b[0m \u001b[32m0.4280\u001b[0m \u001b[35m1.7405\u001b[0m 0.0071\n", + " 67 1.2954 \u001b[32m0.4300\u001b[0m \u001b[35m1.7375\u001b[0m 0.0070\n", + " 68 \u001b[36m1.2785\u001b[0m \u001b[32m0.4320\u001b[0m \u001b[35m1.7319\u001b[0m 0.0070\n", + " 69 1.3192 0.4300 \u001b[35m1.7290\u001b[0m 0.0088\n", + " 70 1.3049 \u001b[32m0.4360\u001b[0m \u001b[35m1.7246\u001b[0m 0.0057\n", + " 71 \u001b[36m1.2504\u001b[0m \u001b[32m0.4420\u001b[0m \u001b[35m1.7198\u001b[0m 0.0085\n", + " 72 1.2841 0.4340 \u001b[35m1.7165\u001b[0m 0.0077\n", + " 73 \u001b[36m1.2304\u001b[0m \u001b[32m0.4460\u001b[0m \u001b[35m1.7120\u001b[0m 0.0067\n", + " 74 1.2414 \u001b[32m0.4540\u001b[0m \u001b[35m1.7070\u001b[0m 0.0085\n", + " 75 \u001b[36m1.1753\u001b[0m 0.4520 \u001b[35m1.7020\u001b[0m 0.0076\n", + " 76 1.2608 \u001b[32m0.4580\u001b[0m \u001b[35m1.6981\u001b[0m 0.0076\n", + " 77 1.2053 0.4580 \u001b[35m1.6935\u001b[0m 0.0084\n", + " 78 1.2640 \u001b[32m0.4600\u001b[0m \u001b[35m1.6910\u001b[0m 0.0056\n", + " 79 1.2251 \u001b[32m0.4700\u001b[0m \u001b[35m1.6845\u001b[0m 0.0069\n", + " 80 1.2221 \u001b[32m0.4780\u001b[0m \u001b[35m1.6801\u001b[0m 0.0044\n", + " 81 \u001b[36m1.1499\u001b[0m 0.4760 \u001b[35m1.6761\u001b[0m 0.0073\n", + " 82 1.1761 \u001b[32m0.4820\u001b[0m \u001b[35m1.6727\u001b[0m 0.0062\n", + " 83 \u001b[36m1.1286\u001b[0m \u001b[32m0.4880\u001b[0m \u001b[35m1.6673\u001b[0m 0.0082\n", + " 84 1.1338 \u001b[32m0.4920\u001b[0m \u001b[35m1.6634\u001b[0m 0.0053\n", + " 85 \u001b[36m1.1273\u001b[0m \u001b[32m0.4940\u001b[0m \u001b[35m1.6593\u001b[0m 0.0089\n", + " 86 1.1289 0.4900 \u001b[35m1.6548\u001b[0m 0.0078\n", + " 87 1.1618 \u001b[32m0.4960\u001b[0m \u001b[35m1.6512\u001b[0m 0.0064\n", + " 88 1.1306 \u001b[32m0.4980\u001b[0m \u001b[35m1.6474\u001b[0m 0.0054\n", + " 89 1.1436 \u001b[32m0.5000\u001b[0m \u001b[35m1.6438\u001b[0m 0.0091\n", + " 90 \u001b[36m1.0675\u001b[0m \u001b[32m0.5020\u001b[0m \u001b[35m1.6397\u001b[0m 0.0083\n", + " 91 1.0798 0.5000 \u001b[35m1.6360\u001b[0m 0.0072\n", + " 92 1.1148 0.4980 \u001b[35m1.6330\u001b[0m 0.0085\n", + " 93 1.0830 \u001b[32m0.5040\u001b[0m \u001b[35m1.6276\u001b[0m 0.0064\n", + " 94 1.1569 0.5020 \u001b[35m1.6246\u001b[0m 0.0051\n", + " 95 \u001b[36m1.0338\u001b[0m 0.5020 \u001b[35m1.6197\u001b[0m 0.0079\n", + " 96 1.0800 \u001b[32m0.5100\u001b[0m \u001b[35m1.6139\u001b[0m 0.0063\n", + " 97 1.0869 0.5080 \u001b[35m1.6121\u001b[0m 0.0064\n", + " 98 1.1144 0.5100 \u001b[35m1.6074\u001b[0m 0.0053\n", + " 99 \u001b[36m1.0271\u001b[0m 0.5060 \u001b[35m1.6045\u001b[0m 0.0073\n", + " 100 1.0465 0.5100 \u001b[35m1.6020\u001b[0m 0.0062\n", + " 101 1.0348 \u001b[32m0.5200\u001b[0m \u001b[35m1.5963\u001b[0m 0.0047\n", + " 102 \u001b[36m1.0045\u001b[0m \u001b[32m0.5220\u001b[0m \u001b[35m1.5936\u001b[0m 0.0053\n", + " 103 1.0307 \u001b[32m0.5260\u001b[0m \u001b[35m1.5895\u001b[0m 0.0048\n", + " 104 \u001b[36m0.9970\u001b[0m \u001b[32m0.5300\u001b[0m \u001b[35m1.5839\u001b[0m 0.0047\n", + " 105 \u001b[36m0.9644\u001b[0m 0.5300 \u001b[35m1.5814\u001b[0m 0.0041\n", + " 106 0.9879 \u001b[32m0.5320\u001b[0m \u001b[35m1.5770\u001b[0m 0.0050\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - " 107 0.9672 0.5320 \u001b[35m1.5743\u001b[0m 0.0071\n", - " 108 0.9656 \u001b[32m0.5360\u001b[0m \u001b[35m1.5719\u001b[0m 0.0073\n", - " 109 1.0130 0.5340 \u001b[35m1.5680\u001b[0m 0.0072\n", - " 110 0.9821 \u001b[32m0.5380\u001b[0m \u001b[35m1.5630\u001b[0m 0.0070\n", - " 111 0.9151 0.5360 \u001b[35m1.5598\u001b[0m 0.0077\n", - " 112 \u001b[36m0.8390\u001b[0m 0.5360 \u001b[35m1.5549\u001b[0m 0.0093\n", - " 113 0.9395 \u001b[32m0.5420\u001b[0m \u001b[35m1.5505\u001b[0m 0.0074\n", - " 114 0.9234 \u001b[32m0.5440\u001b[0m \u001b[35m1.5465\u001b[0m 0.0069\n", - " 115 0.9149 0.5420 \u001b[35m1.5425\u001b[0m 0.0072\n", - " 116 0.9710 0.5400 \u001b[35m1.5399\u001b[0m 0.0071\n", - " 117 0.9081 0.5400 \u001b[35m1.5393\u001b[0m 0.0086\n", - " 118 0.9285 0.5380 \u001b[35m1.5347\u001b[0m 0.0063\n", - " 119 0.8962 0.5440 \u001b[35m1.5320\u001b[0m 0.0080\n", - " 120 0.9140 0.5420 \u001b[35m1.5285\u001b[0m 0.0066\n", - " 121 0.8729 \u001b[32m0.5460\u001b[0m \u001b[35m1.5242\u001b[0m 0.0075\n", - " 122 0.8770 0.5400 \u001b[35m1.5219\u001b[0m 0.0077\n", - " 123 0.8922 \u001b[32m0.5520\u001b[0m \u001b[35m1.5180\u001b[0m 0.0073\n", - " 124 0.8534 0.5460 \u001b[35m1.5161\u001b[0m 0.0080\n", - " 125 0.8476 0.5520 \u001b[35m1.5092\u001b[0m 0.0077\n", - " 126 0.9047 0.5520 \u001b[35m1.5058\u001b[0m 0.0076\n", - " 127 \u001b[36m0.8089\u001b[0m 0.5480 \u001b[35m1.5046\u001b[0m 0.0066\n", - " 128 0.8734 0.5500 \u001b[35m1.5015\u001b[0m 0.0074\n", - " 129 0.8506 \u001b[32m0.5540\u001b[0m \u001b[35m1.4975\u001b[0m 0.0069\n", - " 130 0.8230 0.5540 \u001b[35m1.4946\u001b[0m 0.0075\n", - " 131 \u001b[36m0.7985\u001b[0m 0.5540 \u001b[35m1.4900\u001b[0m 0.0093\n", - " 132 0.8636 \u001b[32m0.5580\u001b[0m \u001b[35m1.4891\u001b[0m 0.0073\n", - " 133 0.8238 0.5580 \u001b[35m1.4844\u001b[0m 0.0066\n", - " 134 \u001b[36m0.7810\u001b[0m \u001b[32m0.5600\u001b[0m \u001b[35m1.4820\u001b[0m 0.0061\n", - " 135 0.8755 0.5600 \u001b[35m1.4792\u001b[0m 0.0066\n", - " 136 0.8143 0.5580 \u001b[35m1.4761\u001b[0m 0.0063\n", - " 137 0.8134 0.5600 \u001b[35m1.4725\u001b[0m 0.0067\n", - " 138 0.8288 \u001b[32m0.5620\u001b[0m \u001b[35m1.4714\u001b[0m 0.0066\n", - " 139 0.8095 0.5600 \u001b[35m1.4695\u001b[0m 0.0065\n", - " 140 \u001b[36m0.7792\u001b[0m 0.5600 \u001b[35m1.4666\u001b[0m 0.0076\n", - " 141 \u001b[36m0.7747\u001b[0m \u001b[32m0.5660\u001b[0m \u001b[35m1.4646\u001b[0m 0.0071\n", - " 142 \u001b[36m0.7375\u001b[0m 0.5660 \u001b[35m1.4606\u001b[0m 0.0071\n", - " 143 0.8205 \u001b[32m0.5720\u001b[0m \u001b[35m1.4577\u001b[0m 0.0068\n", - " 144 0.7945 0.5700 \u001b[35m1.4561\u001b[0m 0.0069\n", - " 145 0.7934 0.5680 \u001b[35m1.4551\u001b[0m 0.0071\n", - " 146 \u001b[36m0.7044\u001b[0m 0.5660 \u001b[35m1.4501\u001b[0m 0.0074\n", - " 147 0.7447 0.5680 \u001b[35m1.4488\u001b[0m 0.0072\n", - " 148 0.8395 0.5660 \u001b[35m1.4486\u001b[0m 0.0070\n", - " 149 0.8077 0.5660 \u001b[35m1.4456\u001b[0m 0.0073\n", - " 150 0.7555 0.5660 \u001b[35m1.4424\u001b[0m 0.0071\n", - " 151 0.7340 0.5680 \u001b[35m1.4395\u001b[0m 0.0073\n", - " 152 0.7571 0.5680 \u001b[35m1.4362\u001b[0m 0.0078\n", - " 153 \u001b[36m0.6802\u001b[0m 0.5660 \u001b[35m1.4343\u001b[0m 0.0073\n", - " 154 0.7739 0.5660 \u001b[35m1.4316\u001b[0m 0.0072\n", - " 155 0.7414 0.5660 \u001b[35m1.4288\u001b[0m 0.0071\n", - " 156 0.7475 0.5660 \u001b[35m1.4274\u001b[0m 0.0072\n", - " 157 0.7376 0.5680 \u001b[35m1.4242\u001b[0m 0.0073\n", - " 158 0.7045 0.5660 \u001b[35m1.4220\u001b[0m 0.0073\n", - " 159 0.8076 0.5680 \u001b[35m1.4196\u001b[0m 0.0071\n", - " 160 0.7094 0.5720 \u001b[35m1.4156\u001b[0m 0.0073\n", - " 161 0.7091 0.5720 \u001b[35m1.4152\u001b[0m 0.0067\n", - " 162 0.7714 0.5680 1.4154 0.0077\n", - " 163 0.6879 \u001b[32m0.5760\u001b[0m \u001b[35m1.4114\u001b[0m 0.0067\n", - " 164 0.7088 0.5740 \u001b[35m1.4095\u001b[0m 0.0073\n", - " 165 0.7036 0.5760 \u001b[35m1.4072\u001b[0m 0.0070\n", - " 166 \u001b[36m0.6331\u001b[0m 0.5740 \u001b[35m1.4056\u001b[0m 0.0064\n", - " 167 0.6961 0.5740 \u001b[35m1.4026\u001b[0m 0.0068\n", - " 168 0.6787 0.5740 \u001b[35m1.4020\u001b[0m 0.0069\n", - " 169 0.6620 0.5760 \u001b[35m1.3986\u001b[0m 0.0072\n", - " 170 0.6767 0.5760 \u001b[35m1.3980\u001b[0m 0.0072\n", - " 171 0.6663 0.5720 \u001b[35m1.3974\u001b[0m 0.0076\n", - " 172 \u001b[36m0.5900\u001b[0m 0.5720 \u001b[35m1.3955\u001b[0m 0.0081\n", - " 173 0.6731 0.5740 \u001b[35m1.3931\u001b[0m 0.0074\n", - " 174 0.7121 0.5720 \u001b[35m1.3913\u001b[0m 0.0067\n", - " 175 0.6490 0.5740 \u001b[35m1.3882\u001b[0m 0.0062\n", - " 176 0.6253 0.5740 \u001b[35m1.3849\u001b[0m 0.0060\n", - " 177 0.6121 0.5740 \u001b[35m1.3835\u001b[0m 0.0062\n", - " 178 0.6331 0.5740 \u001b[35m1.3813\u001b[0m 0.0059\n", - " 179 0.6628 0.5720 \u001b[35m1.3810\u001b[0m 0.0070\n", - " 180 0.6165 0.5740 \u001b[35m1.3790\u001b[0m 0.0072\n", - " 181 0.5927 0.5720 \u001b[35m1.3769\u001b[0m 0.0074\n", - " 182 0.5952 0.5720 \u001b[35m1.3745\u001b[0m 0.0080\n", - " 183 \u001b[36m0.5650\u001b[0m 0.5740 \u001b[35m1.3740\u001b[0m 0.0074\n", - " 184 0.7050 0.5700 \u001b[35m1.3716\u001b[0m 0.0066\n", - " 185 0.6149 0.5720 \u001b[35m1.3708\u001b[0m 0.0060\n", - " 186 0.5872 \u001b[32m0.5780\u001b[0m \u001b[35m1.3678\u001b[0m 0.0059\n", - " 187 0.6876 0.5700 \u001b[35m1.3653\u001b[0m 0.0061\n", - " 188 0.6151 0.5740 \u001b[35m1.3649\u001b[0m 0.0059\n", - " 189 0.6035 0.5740 \u001b[35m1.3641\u001b[0m 0.0060\n", - " 190 0.6302 0.5720 \u001b[35m1.3637\u001b[0m 0.0071\n", - " 191 0.6218 \u001b[32m0.5800\u001b[0m \u001b[35m1.3597\u001b[0m 0.0068\n", - " 192 0.5813 0.5720 \u001b[35m1.3595\u001b[0m 0.0078\n", - " 193 0.6559 0.5760 \u001b[35m1.3581\u001b[0m 0.0067\n", - " 194 0.5938 0.5760 \u001b[35m1.3561\u001b[0m 0.0062\n", - " 195 0.6327 0.5780 \u001b[35m1.3544\u001b[0m 0.0069\n", - " 196 0.6192 0.5760 \u001b[35m1.3531\u001b[0m 0.0062\n", - " 197 0.6347 0.5760 \u001b[35m1.3515\u001b[0m 0.0065\n", - " 198 0.6515 0.5760 \u001b[35m1.3496\u001b[0m 0.0063\n", - " 199 0.6065 0.5780 \u001b[35m1.3485\u001b[0m 0.0064\n", - " 200 0.6256 0.5780 \u001b[35m1.3467\u001b[0m 0.0067\n" + " 107 0.9986 0.5320 \u001b[35m1.5732\u001b[0m 0.0047\n", + " 108 \u001b[36m0.9234\u001b[0m 0.5320 \u001b[35m1.5692\u001b[0m 0.0070\n", + " 109 0.9704 0.5300 \u001b[35m1.5642\u001b[0m 0.0065\n", + " 110 1.0256 0.5300 \u001b[35m1.5621\u001b[0m 0.0087\n", + " 111 0.9590 0.5240 \u001b[35m1.5585\u001b[0m 0.0074\n", + " 112 1.0168 0.5300 \u001b[35m1.5565\u001b[0m 0.0057\n", + " 113 0.9994 0.5320 \u001b[35m1.5534\u001b[0m 0.0073\n", + " 114 0.9635 0.5320 \u001b[35m1.5492\u001b[0m 0.0082\n", + " 115 0.9872 \u001b[32m0.5340\u001b[0m \u001b[35m1.5452\u001b[0m 0.0046\n", + " 116 0.9749 0.5340 \u001b[35m1.5411\u001b[0m 0.0092\n", + " 117 0.9667 0.5340 \u001b[35m1.5392\u001b[0m 0.0048\n", + " 118 \u001b[36m0.8757\u001b[0m 0.5300 \u001b[35m1.5351\u001b[0m 0.0088\n", + " 119 0.9306 0.5340 \u001b[35m1.5340\u001b[0m 0.0072\n", + " 120 \u001b[36m0.8284\u001b[0m \u001b[32m0.5380\u001b[0m \u001b[35m1.5300\u001b[0m 0.0085\n", + " 121 0.8389 \u001b[32m0.5400\u001b[0m \u001b[35m1.5254\u001b[0m 0.0086\n", + " 122 0.9347 \u001b[32m0.5440\u001b[0m \u001b[35m1.5226\u001b[0m 0.0082\n", + " 123 0.8502 0.5340 \u001b[35m1.5207\u001b[0m 0.0064\n", + " 124 0.8519 \u001b[32m0.5480\u001b[0m \u001b[35m1.5163\u001b[0m 0.0085\n", + " 125 0.8536 0.5460 \u001b[35m1.5127\u001b[0m 0.0073\n", + " 126 0.8926 0.5480 \u001b[35m1.5082\u001b[0m 0.0065\n", + " 127 0.8605 0.5480 \u001b[35m1.5050\u001b[0m 0.0088\n", + " 128 0.8853 \u001b[32m0.5540\u001b[0m \u001b[35m1.5037\u001b[0m 0.0067\n", + " 129 0.8483 0.5540 \u001b[35m1.4992\u001b[0m 0.0092\n", + " 130 0.8745 0.5540 \u001b[35m1.4954\u001b[0m 0.0048\n", + " 131 \u001b[36m0.7866\u001b[0m 0.5500 \u001b[35m1.4940\u001b[0m 0.0096\n", + " 132 0.8322 0.5520 \u001b[35m1.4901\u001b[0m 0.0108\n", + " 133 0.8019 0.5520 \u001b[35m1.4864\u001b[0m 0.0062\n", + " 134 0.8829 0.5540 \u001b[35m1.4846\u001b[0m 0.0097\n", + " 135 0.8545 \u001b[32m0.5560\u001b[0m \u001b[35m1.4829\u001b[0m 0.0078\n", + " 136 0.9028 0.5560 \u001b[35m1.4802\u001b[0m 0.0071\n", + " 137 0.8797 0.5540 \u001b[35m1.4794\u001b[0m 0.0077\n", + " 138 0.7967 0.5560 \u001b[35m1.4744\u001b[0m 0.0070\n", + " 139 \u001b[36m0.7614\u001b[0m 0.5560 \u001b[35m1.4723\u001b[0m 0.0084\n", + " 140 0.8399 \u001b[32m0.5580\u001b[0m \u001b[35m1.4696\u001b[0m 0.0063\n", + " 141 0.8502 0.5580 \u001b[35m1.4671\u001b[0m 0.0077\n", + " 142 \u001b[36m0.7301\u001b[0m 0.5580 \u001b[35m1.4648\u001b[0m 0.0088\n", + " 143 0.7543 \u001b[32m0.5640\u001b[0m \u001b[35m1.4617\u001b[0m 0.0061\n", + " 144 \u001b[36m0.7023\u001b[0m 0.5620 \u001b[35m1.4580\u001b[0m 0.0093\n", + " 145 0.7329 0.5640 \u001b[35m1.4561\u001b[0m 0.0078\n", + " 146 0.7820 0.5640 \u001b[35m1.4526\u001b[0m 0.0066\n", + " 147 0.8137 0.5640 \u001b[35m1.4521\u001b[0m 0.0065\n", + " 148 0.7950 0.5600 \u001b[35m1.4489\u001b[0m 0.0072\n", + " 149 0.7702 0.5600 \u001b[35m1.4468\u001b[0m 0.0089\n", + " 150 0.7851 0.5580 1.4469 0.0067\n", + " 151 0.7881 0.5600 \u001b[35m1.4456\u001b[0m 0.0090\n", + " 152 0.7375 0.5600 \u001b[35m1.4405\u001b[0m 0.0093\n", + " 153 0.7888 0.5580 \u001b[35m1.4401\u001b[0m 0.0069\n", + " 154 0.8128 0.5580 \u001b[35m1.4376\u001b[0m 0.0078\n", + " 155 \u001b[36m0.6960\u001b[0m 0.5600 \u001b[35m1.4345\u001b[0m 0.0081\n", + " 156 0.7073 0.5600 \u001b[35m1.4328\u001b[0m 0.0078\n", + " 157 0.7129 0.5620 \u001b[35m1.4301\u001b[0m 0.0065\n", + " 158 0.7282 0.5620 \u001b[35m1.4283\u001b[0m 0.0076\n", + " 159 0.7855 0.5580 \u001b[35m1.4263\u001b[0m 0.0090\n", + " 160 0.7444 0.5620 \u001b[35m1.4237\u001b[0m 0.0065\n", + " 161 0.7081 0.5620 \u001b[35m1.4204\u001b[0m 0.0085\n", + " 162 \u001b[36m0.6947\u001b[0m 0.5620 \u001b[35m1.4188\u001b[0m 0.0091\n", + " 163 0.7121 0.5620 \u001b[35m1.4152\u001b[0m 0.0074\n", + " 164 0.7374 0.5620 \u001b[35m1.4139\u001b[0m 0.0078\n", + " 165 \u001b[36m0.6866\u001b[0m 0.5600 \u001b[35m1.4113\u001b[0m 0.0100\n", + " 166 0.7360 0.5640 \u001b[35m1.4102\u001b[0m 0.0061\n", + " 167 \u001b[36m0.6596\u001b[0m 0.5600 \u001b[35m1.4073\u001b[0m 0.0082\n", + " 168 \u001b[36m0.6245\u001b[0m 0.5600 \u001b[35m1.4048\u001b[0m 0.0079\n", + " 169 0.6546 0.5600 \u001b[35m1.4032\u001b[0m 0.0077\n", + " 170 0.6534 0.5640 \u001b[35m1.4001\u001b[0m 0.0087\n", + " 171 0.7578 0.5640 \u001b[35m1.3992\u001b[0m 0.0074\n", + " 172 \u001b[36m0.5842\u001b[0m 0.5640 \u001b[35m1.3976\u001b[0m 0.0080\n", + " 173 0.5937 0.5600 \u001b[35m1.3950\u001b[0m 0.0067\n", + " 174 0.6404 0.5620 \u001b[35m1.3937\u001b[0m 0.0071\n", + " 175 0.6674 \u001b[32m0.5700\u001b[0m \u001b[35m1.3930\u001b[0m 0.0104\n", + " 176 0.6230 0.5660 \u001b[35m1.3901\u001b[0m 0.0051\n", + " 177 0.7345 0.5660 \u001b[35m1.3891\u001b[0m 0.0102\n", + " 178 0.6696 0.5700 \u001b[35m1.3872\u001b[0m 0.0078\n", + " 179 \u001b[36m0.5714\u001b[0m \u001b[32m0.5720\u001b[0m \u001b[35m1.3838\u001b[0m 0.0064\n", + " 180 0.5873 0.5720 \u001b[35m1.3810\u001b[0m 0.0079\n", + " 181 0.6660 \u001b[32m0.5780\u001b[0m \u001b[35m1.3787\u001b[0m 0.0073\n", + " 182 0.6348 0.5760 \u001b[35m1.3771\u001b[0m 0.0091\n", + " 183 0.6988 0.5780 \u001b[35m1.3746\u001b[0m 0.0058\n", + " 184 0.5842 0.5760 \u001b[35m1.3719\u001b[0m 0.0102\n", + " 185 0.5969 0.5760 \u001b[35m1.3698\u001b[0m 0.0076\n", + " 186 0.6382 0.5780 \u001b[35m1.3682\u001b[0m 0.0072\n", + " 187 0.5822 0.5780 1.3693 0.0086\n", + " 188 0.6263 0.5760 \u001b[35m1.3677\u001b[0m 0.0094\n", + " 189 0.6212 0.5740 \u001b[35m1.3666\u001b[0m 0.0063\n", + " 190 0.6059 0.5740 \u001b[35m1.3641\u001b[0m 0.0087\n", + " 191 0.5797 0.5760 \u001b[35m1.3610\u001b[0m 0.0062\n", + " 192 0.6375 0.5780 \u001b[35m1.3598\u001b[0m 0.0083\n", + " 193 0.6777 0.5740 1.3610 0.0059\n", + " 194 0.6331 0.5700 1.3609 0.0074\n", + " 195 0.6305 0.5720 1.3601 0.0081\n", + " 196 \u001b[36m0.5502\u001b[0m 0.5700 \u001b[35m1.3568\u001b[0m 0.0063\n", + " 197 0.7014 0.5720 \u001b[35m1.3562\u001b[0m 0.0109\n", + " 198 0.5605 0.5720 \u001b[35m1.3551\u001b[0m 0.0105\n", + " 199 0.5541 0.5760 1.3557 0.0072\n", + " 200 \u001b[36m0.5496\u001b[0m 0.5720 \u001b[35m1.3506\u001b[0m 0.0088\n" ] }, { @@ -584,17 +595,17 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "ef5f2409", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.683" + "0.682" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -602,6 +613,16 @@ "source": [ "accuracy_score(ds_test.y, net.predict(ds_test))" ] + }, + { + "cell_type": "markdown", + "id": "2fb7c99b", + "metadata": {}, + "source": [ + "In conclusion this example showed you how to use a basic data graph dataset using pytorch geometric in conjunction with skorch. The final test score is lower than the ~80% accuracy in the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) which can be explained by the reduced leakage between train and validation sets due to our splitting the data into subgraphs beforehand.\n", + "\n", + "The model is now incorporated into the sklearn world (as you could already see, you can simply use sklearn metrics to evaluate the model). Thus, tools like grid and random search are available to you and it is easily possible to include a graph neural net as a feature transformer in your next ML pipeline!" + ] } ], "metadata": { From cb1454a62552a5e6e31a284b28d9b6f14a366cec Mon Sep 17 00:00:00 2001 From: Marian Tietz Date: Mon, 25 Jul 2022 13:14:04 +0200 Subject: [PATCH 3/3] Add link to notebooks README --- notebooks/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/notebooks/README.md b/notebooks/README.md index 8176245ce..64826496e 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -6,3 +6,4 @@ * [MNIST using torchvision](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/MNIST-torchvision.ipynb) * [Transfer Learning](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb) * [Gaussian Processes](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb) +* [PyTorch Geometric on the CORA dataset](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/CORA-geometric.ipynb)