DenseGNN: universal and scalable deeper graph neural networks for high-performance property prediction in crystals and molecules

Modern deep-learning-driven generative models have made it possible to design millions of hypothetical materials. To sift through these candidates and identify promising new materials, however, we need fast and accurate models for predicting material properties. Graph neural networks (GNNs) have emerged as a research hotspot because they operate directly on graph representations of molecules and materials, capture key structural information comprehensively, and show outstanding performance in property prediction. Nevertheless, GNNs still face several key problems in practice. First, existing nested graph network strategies can incorporate critical structural information such as bond angles, but they significantly increase the number of trainable parameters and thus the training cost. Second, extending GNN models to broader domains such as molecules, crystalline materials, and catalysis, and adapting them to small datasets, remains a challenge. Finally, the scalability of GNN models is limited by the over-smoothing problem.

To address these problems, we propose DenseGNN, which combines a dense connectivity network (DCN), hierarchical node-edge-graph residual networks (HRN), and local structure order parameters embedding (LOPE), aiming at a universal, scalable and efficient GNN model. DenseGNN achieves state-of-the-art (SOTA) performance on multiple datasets, including JARVIS-DFT, Materials Project, QM9, Lipop, FreeSolv, ESOL, and OC22, demonstrating the generality and scalability of the approach. By fusing the DCN and LOPE strategies into GNN models from the fields of computer science, crystalline materials, and molecules, we significantly improve the performance of models such as GIN, SchNet, and HamNet on materials benchmarks like Matbench. The LOPE strategy optimizes the embedding representation of atoms, which lets our model train efficiently with a minimal number of edge connections, substantially reducing computational cost and the time required to train large GNNs while maintaining accuracy. Our technique not only supports the construction of deeper GNNs without the performance degradation seen in other models, but is also applicable to a wide range of applications requiring large deep learning models. Furthermore, our study demonstrates that, by utilizing structural embeddings from pre-trained models, our model not only outperforms other GNNs in distinguishing crystal structures but also approaches the standard X-ray diffraction (XRD) method.

Model comparison and analysis

Crystal structure distinction improvement

Requirements

Standard Python package requirements are listed in the setup.py and are installed automatically (kgcnn >= 2.2). Some packages must be installed manually for full functionality.

Installation

pip install kgcnn==3.0.2

Implementation details

Representation

The most frequent usage for graph convolutions is either node or graph classification. Regarding size, either a single large graph (e.g. a citation network) or many small, batched graphs (e.g. molecules) have to be considered. Graphs can be represented by an index list of connections plus feature information. Typical quantities in tensor format used to describe a graph are listed below.

  • nodes: Node-list of shape (batch, [N], F) where N is the number of nodes and F is the node feature dimension.
  • edges: Edge-list of shape (batch, [M], F) where M is the number of edges and F is the edge feature dimension.
  • indices: Connection-list of shape (batch, [M], 2) where M is the number of edges. The indices denote a connection of incoming or receiving node i and outgoing or sending node j as (i, j).
  • state: Graph state information of shape (batch, F) where F denotes the feature dimension.

A major issue for graphs is their flexible size and shape when using mini-batches. Here, for a graph implementation in the spirit of Keras, the batch dimension should also be kept in between layers. This is realized by using RaggedTensors, as illustrated by the sketch below.
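As a concrete illustration, all four quantities for a toy batch of two graphs (hypothetical values) can be written as:

import tensorflow as tf

# Toy batch of two graphs (hypothetical values): 2 nodes/2 edges and 3 nodes/3 edges.
nodes = tf.ragged.constant([[[0.0], [1.0]], [[0.0], [1.0], [2.0]]], ragged_rank=1, inner_shape=(1,))  # (batch, [N], F)
edges = tf.ragged.constant([[[1.0], [1.0]], [[0.5], [0.5], [0.5]]], ragged_rank=1, inner_shape=(1,))  # (batch, [M], F)
indices = tf.ragged.constant([[[0, 1], [1, 0]], [[0, 1], [1, 2], [2, 0]]], ragged_rank=1, inner_shape=(2,))  # (batch, [M], 2)
state = tf.constant([[0.0], [1.0]])  # (batch, F)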

Input

Graph tensors for edge indices or attributes of multiple graphs are passed to the model as ragged tensors of shape (batch, None, Dim), where Dim denotes a fixed feature or index dimension. Such a ragged tensor has ragged_rank=1, with the one ragged dimension indicated by None, and is built from a value tensor plus a partition tensor. For example, the graph structure is represented by an index list of shape (batch, None, 2), with the index of the incoming or receiving node i and the outgoing or sending node j as (i, j). Note that an additional edge (j, i) is required for undirected graphs. A ragged constant can easily be created and passed to a model:

import tensorflow as tf
import numpy as np

idx = [[[0, 1], [1, 0]], [[0, 1], [1, 2], [2, 0]], [[0, 0]]]  # batch_size=3
# Two equivalent ways to build a ragged tensor of shape (3, None, 2).
print(tf.ragged.constant(idx, ragged_rank=1, inner_shape=(2,)).shape)  # (3, None, 2)
print(tf.RaggedTensor.from_row_lengths(np.concatenate(idx), [len(i) for i in idx]).shape)  # (3, None, 2)

Model

Models can be set up in a functional way. An example of message passing built from fundamental operations:

import tensorflow as tf
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.modules import Dense, LazyConcatenate  # ragged support
from kgcnn.layers.aggr import AggregateLocalMessages
from kgcnn.layers.pooling import PoolingNodes

ks = tf.keras

n = ks.layers.Input(shape=(None, 3), name='node_input', dtype="float32", ragged=True)
ei = ks.layers.Input(shape=(None, 2), name='edge_index_input', dtype="int64", ragged=True)

n_in_out = GatherNodes()([n, ei])  # gather receiving and sending node features per edge
node_messages = Dense(10, activation='relu')(n_in_out)  # transform gathered features into messages
node_updates = AggregateLocalMessages(is_sorted=False)([n, node_messages, ei])  # aggregate messages per node
n_node_updates = LazyConcatenate(axis=-1)([n, node_updates])  # concatenate old node state with update
n_embedding = Dense(1)(n_node_updates)  # project to node embedding
g_embedding = PoolingNodes()(n_embedding)  # pool node embeddings into a graph embedding

message_passing = ks.models.Model(inputs=[n, ei], outputs=g_embedding)
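A minimal usage sketch with hypothetical toy inputs (two graphs with three-dimensional node features) could then look like this:

import tensorflow as tf

# Hypothetical toy batch of two graphs.
nodes = tf.ragged.constant([[[0., 0., 1.], [1., 0., 0.]], [[1., 1., 0.]]], ragged_rank=1, inner_shape=(3,))
edge_idx = tf.ragged.constant([[[0, 1], [1, 0]], [[0, 0]]], dtype="int64", ragged_rank=1, inner_shape=(2,))
out = message_passing([nodes, edge_idx])  # one embedding per graph, shape (2, 1)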

Alternatively, models can be built by sub-classing the message passing base layer, where only message_function and update_nodes must be implemented:

from kgcnn.layers.message import MessagePassingBase
from kgcnn.layers.modules import Dense, LazyConcatenate


class MyMessageNN(MessagePassingBase):

    def __init__(self, units, **kwargs):
        super(MyMessageNN, self).__init__(**kwargs)
        self.dense = Dense(units)
        self.concat = LazyConcatenate(axis=-1)

    def message_function(self, inputs, **kwargs):
        # Gathered receiving/sending node features and edge features per edge.
        n_in, n_out, edges = inputs
        return self.dense(n_out)

    def update_nodes(self, inputs, **kwargs):
        # Old node features and the aggregated messages.
        nodes, nodes_update = inputs
        return self.concat([nodes, nodes_update])
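As a rough usage sketch: judging from the message_function signature, the layer call is assumed here to receive node features, edge features and edge indices in that order (an assumption, not verified against every kgcnn version):

# Hypothetical usage; the input order [nodes, edges, edge_indices] is an assumption.
n = ks.layers.Input(shape=(None, 3), name='node_input', dtype="float32", ragged=True)
e = ks.layers.Input(shape=(None, 5), name='edge_input', dtype="float32", ragged=True)
ei = ks.layers.Input(shape=(None, 2), name='edge_index_input', dtype="int64", ragged=True)
n_updated = MyMessageNN(units=16)([n, e, ei])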

Literature

The following models proposed in the literature have a module in kgcnn.literature. The module usually exposes a make_model function to create a tf.keras.models.Model, which features ragged tensors as input or output. The models can, but need not, be built completely from kgcnn.layers and may for example include original implementations (with proper licensing).


Data

How to construct ragged tensors is shown above. Moreover, some data handling classes are given in kgcnn.data.

Graph dictionary

Graphs are represented by a GraphDict, a dictionary of (numpy) arrays that behaves like a python dict. There are graph pre- and postprocessors in kgcnn.graph which take specific properties by name and apply a processing function or transformation.

from kgcnn.data.base import GraphDict
# Single graph.
graph = GraphDict({"edge_indices": [[1, 0], [0, 1]], "node_label": [[0], [1]]})
graph.set("graph_labels", [0])  # use set(), get() to assign (tensor) properties.
graph.set("edge_attributes", [[1.0], [2.0]])
graph.to_networkx()
# Modify with e.g. preprocessor.
from kgcnn.graph.preprocessor import SortEdgeIndices
SortEdgeIndices(edge_indices="edge_indices", edge_attributes="^edge_(?!indices$).*", in_place=True)(graph)
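Properties can be read back with the get() counterpart mentioned in the comments above:

print(graph.get("edge_attributes"))  # edge attributes, reordered by the preprocessor above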

List of graph dictionaries

A MemoryGraphList behaves identically to a python list but contains only GraphDict items.

from kgcnn.data.base import MemoryGraphList
# List of graph dicts.
graph_list = MemoryGraphList([{"edge_indices": [[0, 1], [1, 0]]}, {"edge_indices": [[0, 0]]}, {}])
graph_list.clean(["edge_indices"])  # Remove graphs without property
graph_list.get("edge_indices")  # opposite is set()
# Easily cast to (ragged) tf-tensor; makes copy.
tensor = graph_list.tensor([{"name": "edge_indices", "ragged": True}])  # config of keras `Input` layer
# Or directly modify list.
for i, x in enumerate(graph_list):
    x.set("graph_number", [i])
print(len(graph_list), graph_list[:2])  # Also supports indexing lists.

Datasets

The MemoryGraphDataset inherits from MemoryGraphList but must be initialized with file information that points to a data_directory for the dataset on disk. The data_directory can contain a subdirectory for individual files and/or a single file such as a CSV file:

├── data_directory
    ├── file_directory
    │   ├── *.*
    │   └── ... 
    ├── file_name
    └── dataset_name.kgcnn.pickle

A base dataset class is created with path and name information:

from kgcnn.data.base import MemoryGraphDataset
dataset = MemoryGraphDataset(data_directory="ExampleDir/", 
                             dataset_name="Example",
                             file_name=None, file_directory=None)
dataset.save()  # opposite is load(). 

The subclasses QMDataset, MoleculeNetDataset, CrystalDataset, VisualGraphDataset and GraphTUDataset further provide the functions required for the specific dataset type to convert and process files such as '.txt', '.sdf' or '.xyz'. Most subclasses implement prepare_data() and read_in_memory() with dataset-dependent arguments. An example for MoleculeNetDataset is shown below. For more details, see the tutorials in notebooks.

from kgcnn.data.moleculenet import MoleculeNetDataset
# File directory and files must exist. 
# Here 'ExampleDir' and 'ExampleDir/data.csv' with columns "smiles" and "label".
dataset = MoleculeNetDataset(dataset_name="Example",
                             data_directory="ExampleDir/",
                             file_name="data.csv")
dataset.prepare_data(overwrite=True, smiles_column_name="smiles", add_hydrogen=True,
                     make_conformers=True, optimize_conformer=True, num_workers=None)
dataset.read_in_memory(label_column_name="label", add_hydrogen=False, 
                       has_conformers=True)
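Since MemoryGraphDataset inherits from MemoryGraphList, the graphs read into memory can be cast to ragged tensors just as shown above. The property names in this sketch are assumptions and depend on how the dataset was processed:

# Hypothetical property names for illustration.
tensor_input = dataset.tensor([
    {"name": "edge_indices", "ragged": True},
    {"name": "node_attributes", "ragged": True},
])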

In data.datasets there are graph learning benchmark datasets as subclasses, which are downloaded from popular graph archives such as TUDatasets, MatBench or MoleculeNet. The subclasses GraphTUDataset2020, MatBenchDataset2020 and MoleculeNetDataset2018 download and read the available datasets by name. There are also specific dataset subclasses for individual datasets that handle additional processing or download from individual sources:

from kgcnn.data.datasets.MUTAGDataset import MUTAGDataset
dataset = MUTAGDataset()  # inherits from GraphTUDataset2020

Downloaded datasets are stored in ~/.kgcnn/datasets on your computer. Please remove them manually if they are no longer required.

Training

A set of example training scripts can be found in training. The training scripts are configurable with a hyperparameter config file and command-line arguments regarding model and dataset.

You can find a table of common benchmark datasets in results.

Citing and Acknowledgement

If you want to cite this repo, please refer to our paper:

@article{DenseGNN,
  title  = {DenseGNN: universal and scalable deeper graph neural networks for high-performance property prediction in crystals and molecules},
  author = {Hongwei Du and Hong Wang},
  note   = {Under review. School of Materials Science and Engineering, Shanghai Jiao Tong University, Shanghai 200240, China;
            Zhangjiang Institute for Advanced Study, Shanghai Jiao Tong University, Shanghai 201203, China;
            Materials Genome Initiative Center, Shanghai Jiao Tong University, Shanghai 200240, China.}
}

@article{REISER2021100095,
  title   = {Graph neural networks in TensorFlow-Keras with RaggedTensor representation (kgcnn)},
  journal = {Software Impacts},
  pages   = {100095},
  year    = {2021},
  issn    = {2665-9638},
  doi     = {10.1016/j.simpa.2021.100095},
  url     = {https://www.sciencedirect.com/science/article/pii/S266596382100035X},
  author  = {Patrick Reiser and Andre Eberhard and Pascal Friederich}
}
