Torch sparse instead of scipy sparse #275

Merged: 6 commits, Apr 1, 2023
48 changes: 38 additions & 10 deletions goli/data/dataset.py
@@ -6,10 +6,9 @@
from loguru import logger
from copy import deepcopy


import torch
from torch.utils.data.dataloader import Dataset
from torch_geometric.data import Data
from torch_geometric.data import Data, Batch

from goli.data.smiles_transform import smiles_to_unique_mol_ids
from goli.features import GraphDict
@@ -207,22 +206,22 @@ def num_graphs_total(self):
@property
def num_nodes_total(self):
"""Total number of nodes for all graphs"""
return sum([data.num_nodes for data in self.features])
return sum(get_num_nodes_per_graph(self.features))

@property
def max_num_nodes_per_graph(self):
"""Maximum number of nodes per graph"""
return max([data.num_nodes for data in self.features])
return max(get_num_nodes_per_graph(self.features))

@property
def std_num_nodes_per_graph(self):
"""Standard deviation of number of nodes per graph"""
return np.std([data.num_nodes for data in self.features])
return np.std(get_num_nodes_per_graph(self.features))

@property
def min_num_nodes_per_graph(self):
"""Minimum number of nodes per graph"""
return min([data.num_nodes for data in self.features])
return min(get_num_nodes_per_graph(self.features))

@property
def mean_num_nodes_per_graph(self):
@@ -232,22 +231,22 @@ def mean_num_nodes_per_graph(self):
@property
def num_edges_total(self):
"""Total number of edges for all graphs"""
return sum([data.num_edges for data in self.features])
return sum(get_num_edges_per_graph(self.features))

@property
def max_num_edges_per_graph(self):
"""Maximum number of edges per graph"""
return max([data.num_edges for data in self.features])
return max(get_num_edges_per_graph(self.features))

@property
def min_num_edges_per_graph(self):
"""Minimum number of edges per graph"""
return min([data.num_edges for data in self.features])
return min(get_num_edges_per_graph(self.features))

@property
def std_num_edges_per_graph(self):
"""Standard deviation of number of nodes per graph"""
return np.std([data.num_edges for data in self.features])
return np.std(get_num_edges_per_graph(self.features))

@property
def mean_num_edges_per_graph(self):
@@ -407,6 +406,11 @@ def __repr__(self) -> str:
)
return out_str

# Faster to compute the statistics if we unbatch first.
features = self.features
if isinstance(self.features, Batch):
self.features = self.features.to_data_list()

out_str = (
f"-------------------\n{self.__class__.__name__}\n"
+ f"\tabout = {self.about}\n"
@@ -423,6 +427,10 @@ def __repr__(self) -> str:
+ f"\tmean_num_edges_per_graph = {self.mean_num_edges_per_graph}\n"
+ f"-------------------\n"
)

# Restore the original features.
self.features = features

return out_str


@@ -525,3 +533,23 @@ def __getitem__(self, idx):
datum["features"] = self.features[idx]

return datum


def get_num_nodes_per_graph(graphs):
r"""
number of nodes per graph
"""
if isinstance(graphs, Batch):
graphs = graphs.to_data_list()
counts = [graph.num_nodes for graph in graphs]
return counts


def get_num_edges_per_graph(graphs):
r"""
number of edges per graph
"""
if isinstance(graphs, Batch):
graphs = graphs.to_data_list()
counts = [graph.num_edges for graph in graphs]
return counts
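
A quick usage sketch of the two new helpers, assuming they are importable from `goli.data.dataset` as defined above; the two toy graphs are made up for illustration. Both a plain list of `Data` objects and a `Batch` should be accepted, since a `Batch` is unbatched internally before counting.

```python
import torch
from torch_geometric.data import Data, Batch

from goli.data.dataset import get_num_nodes_per_graph, get_num_edges_per_graph

# Two toy graphs: 2 nodes / 2 directed edges and 3 nodes / 4 directed edges.
g1 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]]))
g2 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]))
graphs = [g1, g2]
batch = Batch.from_data_list(graphs)

# Same counts whether the graphs are passed as a list or as a Batch.
assert get_num_nodes_per_graph(graphs) == [2, 3]
assert get_num_nodes_per_graph(batch) == [2, 3]
assert get_num_edges_per_graph(graphs) == [2, 4]
assert get_num_edges_per_graph(batch) == [2, 4]
```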
95 changes: 95 additions & 0 deletions goli/expts/pyg_batching_sparse.ipynb
@@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 1.13.0+cpu\n",
"pyg version: 2.3.0.dev20230306\n",
"batch.x = tensor(indices=tensor([[0, 1, 1, 2, 3, 4],\n",
" [1, 0, 1, 1, 0, 1]]),\n",
" values=tensor([1, 2, 3, 4, 5, 6]),\n",
" size=(5, 2), nnz=6, layout=torch.sparse_coo)\n",
"Data(x=[2, 2])\n",
"Data(x=[2, 2])\n",
"[Data(x=[2, 2]), Data(x=[3, 2])]\n",
"[tensor(indices=tensor([[0, 1, 1],\n",
" [1, 0, 1]]),\n",
" values=tensor([1, 2, 3]),\n",
" size=(2, 2), nnz=3, layout=torch.sparse_coo), tensor(indices=tensor([[0, 1, 2],\n",
" [1, 0, 1]]),\n",
" values=tensor([4, 5, 6]),\n",
" size=(3, 2), nnz=3, layout=torch.sparse_coo)]\n"
]
}
],
"source": [
"import torch\n",
"import torch_geometric\n",
"from torch_geometric.data import Data, Batch\n",
"\n",
"data1 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 1]]), torch.tensor([1, 2, 3]), (2, 2)))\n",
"data2 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 1]]), torch.tensor([4, 5, 6]), (3, 2)))\n",
"batch = Batch.from_data_list([data1, data2])\n",
"\n",
"print(\"torch version: \", torch.__version__)\n",
"print(\"pyg version: \", torch_geometric.__version__)\n",
"print(\"batch.x = \", batch.x) # WORKS\n",
"print(batch.get_example(0)) # FAILS\n",
"print(batch[0]) # FAILS\n",
"print(batch.to_data_list())\n",
"print([b.x for b in batch.to_data_list()])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(indices=tensor([[0, 1, 1, 2],\n",
" [1, 0, 1, 0]]),\n",
" values=tensor([1, 2, 3, 5]),\n",
" size=(3, 2), nnz=4, layout=torch.sparse_coo)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch.x.index_select(dim=0, index=torch.tensor([0, 1, 3]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "goli_ipu",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
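
A condensed sketch of what the notebook demonstrates, under the same assumption as above (a PyG build where `Batch.from_data_list` accepts sparse COO node features, e.g. the dev build printed in the notebook): plain slicing of a sparse COO tensor is unsupported, but `index_select` works, so each graph's rows can be recovered from the concatenated batch tensor. The toy data mirrors the notebook's first cell.

```python
import torch
from torch_geometric.data import Data, Batch

# Sparse COO node features of sizes (2, 2) and (3, 2), as in the notebook.
data1 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 1]]), torch.tensor([1, 2, 3]), (2, 2)))
data2 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 1]]), torch.tensor([4, 5, 6]), (3, 2)))
batch = Batch.from_data_list([data1, data2])  # batch.x is a (5, 2) sparse COO tensor

# Recover each graph's rows: batch.x[start:end] would fail on a sparse tensor,
# but index_select along dim 0 works, as the notebook's second cell shows.
offset = 0
for n in (data1.x.size(0), data2.x.size(0)):
    rows = torch.arange(offset, offset + n)
    print(batch.x.index_select(dim=0, index=rows))
    offset += n
```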
26 changes: 18 additions & 8 deletions goli/features/featurizer.py
@@ -820,16 +820,26 @@ def make_pyg_graph(self, **kwargs) -> Data:
num_nodes = self.adj.shape[0]

# Get the node and edge data
data_dict = {key: val for key, val in self.items()}
data_dict.pop("adj")
data_dict.pop("dtype")
data_dict.pop("mask_nan")

# Convert the data to torch
for key, val in data_dict.items():
if isinstance(val, np.ndarray):
data_dict = {}

# Convert the data and sparse data to torch
for key, val in self.items():
if key in ["adj", "dtype", "mask_nan"]: # Skip the parameters
continue
elif isinstance(val, np.ndarray):
# Convert the data to the specified dtype in torch format
val = val.astype(self.dtype)
data_dict[key] = torch.as_tensor(val)
elif issparse(val):
data_dict[key] = val
# TODO: Convert sparse data to torch once the bug is fixed in PyG
# See https://github.com/pyg-team/pytorch_geometric/pull/7037
# indices = torch.from_numpy(np.vstack((val.row, val.col)).astype(np.int64))
# data_dict[key] = torch.sparse_coo_tensor(indices=indices, values=val.data, size=val.shape)
elif isinstance(val, torch.Tensor):
data_dict[key] = val
else:
pass # Skip the other parameters

# Create the PyG graph object `Data`
data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes, **data_dict)
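
For reference, a standalone sketch of the scipy-to-torch conversion that the TODO above defers until the upstream PyG fix lands. The matrix is made up, and this runs outside `make_pyg_graph`; it only illustrates the round trip hinted at in the commented-out lines.

```python
import numpy as np
import torch
from scipy.sparse import coo_matrix, issparse

# A small scipy COO matrix standing in for a sparse feature produced by the featurizer.
val = coo_matrix(
    (np.array([1.0, 2.0, 3.0]), (np.array([0, 1, 2]), np.array([1, 0, 1]))),
    shape=(3, 2),
)

if issparse(val):
    val = val.tocoo()  # guarantee row/col/data attributes
    indices = torch.from_numpy(np.vstack((val.row, val.col)).astype(np.int64))
    x = torch.sparse_coo_tensor(indices=indices, values=torch.from_numpy(val.data), size=val.shape)

# Round trip: the dense views should match.
assert torch.allclose(x.to_dense(), torch.from_numpy(val.toarray()))
```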