Skip to content

Commit

Permalink
PCQM4Mv2 dataset: Reference implementation for OnDiskDataset (#8102)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Sep 30, 2023
1 parent 1e12d41 commit 14e160d
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102))
- Added `module_headers` property to `nn.Sequential` models ([#8093](https://github.com/pyg-team/pytorch_geometric/pull/8093))
- Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088), [#8092](https://github.com/pyg-team/pytorch_geometric/pull/8092))
- Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938))
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def insert(self, index: int, data: Any):
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.execute(query, (index, *self._serialize(data)))
self._connection.commit()

def _multi_insert(
self,
Expand All @@ -331,6 +332,7 @@ def _multi_insert(
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.executemany(query, data_list)
self._connection.commit()

def get(self, index: int) -> Any:
query = (f'SELECT {self._joined_col_names} FROM {self.name} '
Expand Down
1 change: 0 additions & 1 deletion torch_geometric/data/on_disk_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
self,
root: str,
transform: Optional[Callable] = None,
*,
pre_filter: Optional[Callable] = None,
backend: str = 'sqlite',
schema: Schema = object,
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .zinc import ZINC
from .aqsol import AQSOL
from .molecule_net import MoleculeNet
from .pcqm4m import PCQM4Mv2
from .entities import Entities
from .rel_link_pred_dataset import RelLinkPredDataset
from .ged_dataset import GEDDataset
Expand Down Expand Up @@ -125,6 +126,7 @@
'ZINC',
'AQSOL',
'MoleculeNet',
'PCQM4Mv2',
'Entities',
'RelLinkPredDataset',
'GEDDataset',
Expand Down
107 changes: 107 additions & 0 deletions torch_geometric/datasets/pcqm4m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import os
import os.path as osp
from typing import Any, Callable, Dict, List, Optional

import torch
from tqdm import tqdm

from torch_geometric.data import Data, OnDiskDataset, download_url, extract_zip
from torch_geometric.utils import from_smiles


class PCQM4Mv2(OnDiskDataset):
    r"""The PCQM4Mv2 dataset from the `"OGB-LSC: A Large-Scale Challenge for
    Machine Learning on Graphs" <https://arxiv.org/abs/2103.09430>`_ paper.

    :class:`PCQM4Mv2` is a quantum chemistry dataset originally curated under
    the PubChemQC project. The task is to predict the DFT-calculated HOMO-LUMO
    energy gap of molecules given their 2D molecular graphs.

    .. note::
        This dataset uses the :class:`OnDiskDataset` base class to load data
        dynamically from disk.

    Args:
        root (str): Root directory where the dataset should be saved.
        split (str, optional): If :obj:`"train"`, loads the training dataset.
            If :obj:`"val"`, loads the validation dataset.
            If :obj:`"test"`, loads the test dataset.
            If :obj:`"holdout"`, loads the holdout dataset.
            (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        backend (str): The :class:`Database` backend to use.
            (default: :obj:`"sqlite"`)
    """
    url = ('https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/'
           'pcqm4m-v2.zip')

    # Maps the user-facing split names to the keys used in 'split_dict.pt':
    split_mapping = {
        'train': 'train',
        'val': 'valid',
        'test': 'test-dev',
        'holdout': 'test-challenge',
    }

    def __init__(
        self,
        root: str,
        split: str = 'train',
        transform: Optional[Callable] = None,
        backend: str = 'sqlite',
    ):
        assert split in ['train', 'val', 'test', 'holdout']

        # Column schema for the on-disk database. 'x' holds 9 atom features
        # per node, 'edge_attr' holds 3 bond features per edge (the layout
        # produced by `from_smiles`); 'y' is the scalar HOMO-LUMO gap.
        schema = {
            'x': dict(dtype=torch.int64, size=(-1, 9)),
            'edge_index': dict(dtype=torch.int64, size=(2, -1)),
            'edge_attr': dict(dtype=torch.int64, size=(-1, 3)),
            'smiles': str,
            'y': float,
        }

        super().__init__(root, transform, backend=backend, schema=schema)

        # Restrict the dataset view to the indices of the requested split:
        split_idx = torch.load(self.raw_paths[1])
        self._indices = split_idx[self.split_mapping[split]].tolist()

    @property
    def raw_file_names(self) -> List[str]:
        return [
            osp.join('pcqm4m-v2', 'raw', 'data.csv.gz'),
            osp.join('pcqm4m-v2', 'split_dict.pt'),
        ]

    def download(self):
        # BUG FIX: previously referenced `self.url_2d`, which is never
        # defined on this class (only `url` exists) and would raise an
        # `AttributeError` on download.
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        import pandas as pd  # Local import: only needed during processing.

        df = pd.read_csv(self.raw_paths[0])

        data_list: List[Data] = []
        iterator = enumerate(zip(df['smiles'], df['homolumogap']))
        for i, (smiles, y) in tqdm(iterator, total=len(df)):
            data = from_smiles(smiles)
            data.y = y

            data_list.append(data)
            # Flush to the database in batches of 1000 (and on the final
            # row) to avoid holding the whole dataset in memory:
            if i + 1 == len(df) or (i + 1) % 1000 == 0:
                self.extend(data_list)
                data_list = []

    def serialize(self, data: Data) -> Dict[str, Any]:
        """Map a :class:`Data` object to the database row dictionary."""
        # BUG FIX: 'y' was previously serialized as `data.x` (the node
        # feature matrix). The schema declares 'y' as a float and
        # `process()` stores the HOMO-LUMO gap in `data.y`.
        return dict(
            x=data.x,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            y=data.y,
            smiles=data.smiles,
        )

    def deserialize(self, data: Dict[str, Any]) -> Data:
        """Re-assemble a :class:`Data` object from a database row."""
        return Data.from_dict(data)

0 comments on commit 14e160d

Please sign in to comment.