diff --git a/goli/data/dataset.py b/goli/data/dataset.py index dd562076a..e2fef061b 100644 --- a/goli/data/dataset.py +++ b/goli/data/dataset.py @@ -6,10 +6,9 @@ from loguru import logger from copy import deepcopy - import torch from torch.utils.data.dataloader import Dataset -from torch_geometric.data import Data +from torch_geometric.data import Data, Batch from goli.data.smiles_transform import smiles_to_unique_mol_ids from goli.features import GraphDict @@ -207,22 +206,22 @@ def num_graphs_total(self): @property def num_nodes_total(self): """Total number of nodes for all graphs""" - return sum([data.num_nodes for data in self.features]) + return sum(get_num_nodes_per_graph(self.features)) @property def max_num_nodes_per_graph(self): """Maximum number of nodes per graph""" - return max([data.num_nodes for data in self.features]) + return max(get_num_nodes_per_graph(self.features)) @property def std_num_nodes_per_graph(self): """Standard deviation of number of nodes per graph""" - return np.std([data.num_nodes for data in self.features]) + return np.std(get_num_nodes_per_graph(self.features)) @property def min_num_nodes_per_graph(self): """Minimum number of nodes per graph""" - return min([data.num_nodes for data in self.features]) + return min(get_num_nodes_per_graph(self.features)) @property def mean_num_nodes_per_graph(self): @@ -232,22 +231,22 @@ def mean_num_nodes_per_graph(self): @property def num_edges_total(self): """Total number of edges for all graphs""" - return sum([data.num_edges for data in self.features]) + return sum(get_num_edges_per_graph(self.features)) @property def max_num_edges_per_graph(self): """Maximum number of edges per graph""" - return max([data.num_edges for data in self.features]) + return max(get_num_edges_per_graph(self.features)) @property def min_num_edges_per_graph(self): """Minimum number of edges per graph""" - return min([data.num_edges for data in self.features]) + return min(get_num_edges_per_graph(self.features)) @property def std_num_edges_per_graph(self): """Standard deviation of number of nodes per graph""" - return np.std([data.num_edges for data in self.features]) + return np.std(get_num_edges_per_graph(self.features)) @property def mean_num_edges_per_graph(self): @@ -407,6 +406,11 @@ def __repr__(self) -> str: ) return out_str + # Faster to compute the statistics if we unbatch first. + features = self.features + if isinstance(self.features, Batch): + self.features = self.features.to_data_list() + out_str = ( f"-------------------\n{self.__class__.__name__}\n" + f"\tabout = {self.about}\n" @@ -423,6 +427,10 @@ def __repr__(self) -> str: + f"\tmean_num_edges_per_graph = {self.mean_num_edges_per_graph}\n" + f"-------------------\n" ) + + # Restore the original features. + self.features = features + return out_str @@ -525,3 +533,23 @@ def __getitem__(self, idx): datum["features"] = self.features[idx] return datum + + +def get_num_nodes_per_graph(graphs): + r""" + number of nodes per graph + """ + if isinstance(graphs, Batch): + graphs = graphs.to_data_list() + counts = [graph.num_nodes for graph in graphs] + return counts + + +def get_num_edges_per_graph(graphs): + r""" + number of edges per graph + """ + if isinstance(graphs, Batch): + graphs = graphs.to_data_list() + counts = [graph.num_edges for graph in graphs] + return counts diff --git a/goli/expts/pyg_batching_sparse.ipynb b/goli/expts/pyg_batching_sparse.ipynb new file mode 100644 index 000000000..c6a52a7a7 --- /dev/null +++ b/goli/expts/pyg_batching_sparse.ipynb @@ -0,0 +1,95 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch version: 1.13.0+cpu\n", + "pyg version: 2.3.0.dev20230306\n", + "batch.x = tensor(indices=tensor([[0, 1, 1, 2, 3, 4],\n", + " [1, 0, 1, 1, 0, 1]]),\n", + " values=tensor([1, 2, 3, 4, 5, 6]),\n", + " size=(5, 2), nnz=6, layout=torch.sparse_coo)\n", + "Data(x=[2, 2])\n", + "Data(x=[2, 2])\n", + "[Data(x=[2, 2]), Data(x=[3, 2])]\n", + "[tensor(indices=tensor([[0, 1, 1],\n", + " [1, 0, 1]]),\n", + " values=tensor([1, 2, 3]),\n", + " size=(2, 2), nnz=3, layout=torch.sparse_coo), tensor(indices=tensor([[0, 1, 2],\n", + " [1, 0, 1]]),\n", + " values=tensor([4, 5, 6]),\n", + " size=(3, 2), nnz=3, layout=torch.sparse_coo)]\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch_geometric\n", + "from torch_geometric.data import Data, Batch\n", + "\n", + "data1 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 1]]), torch.tensor([1, 2, 3]), (2, 2)))\n", + "data2 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 1]]), torch.tensor([4, 5, 6]), (3, 2)))\n", + "batch = Batch.from_data_list([data1, data2])\n", + "\n", + "print(\"torch version: \", torch.__version__)\n", + "print(\"pyg version: \", torch_geometric.__version__)\n", + "print(\"batch.x = \", batch.x) # WORKS\n", + "print(batch.get_example(0)) # FAILS\n", + "print(batch[0]) # FAILS\n", + "print(batch.to_data_list())\n", + "print([b.x for b in batch.to_data_list()])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(indices=tensor([[0, 1, 1, 2],\n", + " [1, 0, 1, 0]]),\n", + " values=tensor([1, 2, 3, 5]),\n", + " size=(3, 2), nnz=4, layout=torch.sparse_coo)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch.x.index_select(dim=0, index=torch.tensor([0, 1, 3]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "goli_ipu", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/goli/features/featurizer.py b/goli/features/featurizer.py index 4fef83437..9791a4da0 100644 --- a/goli/features/featurizer.py +++ b/goli/features/featurizer.py @@ -820,16 +820,26 @@ def make_pyg_graph(self, **kwargs) -> Data: num_nodes = self.adj.shape[0] # Get the node and edge data - data_dict = {key: val for key, val in self.items()} - data_dict.pop("adj") - data_dict.pop("dtype") - data_dict.pop("mask_nan") - - # Convert the data to torch - for key, val in data_dict.items(): - if isinstance(val, np.ndarray): + data_dict = {} + + # Convert the data and sparse data to torch + for key, val in self.items(): + if key in ["adj", "dtype", "mask_nan"]: # Skip the parameters + continue + elif isinstance(val, np.ndarray): + # Convert the data to the specified dtype in torch format val = val.astype(self.dtype) data_dict[key] = torch.as_tensor(val) + elif issparse(val): + data_dict[key] = val + # TODO: Convert sparse data to torch once the bug is fixed in PyG + # See https://github.com/pyg-team/pytorch_geometric/pull/7037 + # indices = torch.from_numpy(np.vstack((val.row, val.col)).astype(np.int64)) + # data_dict[key] = torch.sparse_coo_tensor(indices=indices, values=val.data, size=val.shape) + elif isinstance(val, torch.Tensor): + data_dict[key] = val + else: + pass # Skip the other parameters # Create the PyG graph object `Data` data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes, **data_dict)