Skip to content

Commit

Permalink
Merge pull request #1 from gerkone/fixes-0.5
Browse files Browse the repository at this point in the history
New tensor product, fixes
  • Loading branch information
gerkone authored Mar 31, 2023
2 parents d22df64 + 5b42f7c commit 7fb05da
Show file tree
Hide file tree
Showing 20 changed files with 838 additions and 313 deletions.
27 changes: 26 additions & 1 deletion .github/workflows/build_branch.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,33 @@
name: Publish on PyPi
name: Build

on: push

jobs:
tests:
name: Tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- name: Set up docker image
uses: actions/setup-python@v3
with:
python-version: "3.10"
- name: Install dependencies
run: >-
python -m
pip install
-r requirements.txt
--user
- name: Install pytest
run: >-
python -m
pip install
pytest
--user
- name: Run tests
run: >-
python -m pytest tests/
build-publish:
name: Build and publish
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# experiments
*.npy
wandb/
nohup.out
*.out
datasets/

# dev
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: end-of-file-fixer
- id: requirements-txt-fixer
- repo: https://github.com/pycqa/isort
rev: 5.8.0
rev: 5.12.0
hooks:
- id: isort
args: [ --profile, black ]
Expand Down
2 changes: 2 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ disable=C0114, # module docstring
W0108, # unnecessary lambda
W0621, # redifining from outer scope
C0103, # bad function names
C0415, # import outside toplevel
W0212, # protected access
abstract-method,
apply-builtin,
arguments-differ,
Expand Down
44 changes: 25 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Steerable E(3) GNN in jax
Reimplementation of [SEGNN](https://arxiv.org/abs/2110.02905) in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.

## Why jax?
**40-50% faster** inference and training compared to the [original torch implementation](https://github.com/RobDHess/Steerable-E3-GNN). Also JAX-MD.

## Installation
```
python -m pip install segnn-jax
Expand All @@ -19,9 +22,12 @@ pip install --upgrade "jax[cuda]==0.4.1" -f https://storage.googleapis.com/jax-r

## Validation
N-body (charged and gravity) and QM9 datasets are included for completeness from the original paper.
The implementation is validated on all three of them, getting close results and considerably faster runtimes.

### Results
Charged is on 5 bodies, gravity on 100 bodies. QM9 has graphs of variable sizes, so in jax samples are padded to the maximum size. Loss is MSE for Charged and Gravity and MAE for QM9.

Times are remeasured on Quadro RTX 4000, __model only__ on batches of 100 graphs, in (global) single precision.

<table>
<tr>
<td></td>
Expand All @@ -30,36 +36,36 @@ The implementation is validated on all three of them, getting close results and
</tr>
<tr>
<td></td>
<td>MSE</td>
<td>Inference [ms]*</td>
<td>MSE</td>
<td>Loss</td>
<td>Inference [ms]</td>
<td>Loss</td>
<td>Inference [ms]</td>
</tr>
<tr>
<td> <code>charged (position)</code> </td>
<td>.0043</td>
<td>40.76</td>
<td>.0047</td>
<td><b>28.67</td>
<td>21.22</td>
<td>.0045</td>
<td>4.47</td>
</tr>
<tr>
<td><code>gravity (position)</code> </td>
<td>.265</td>
<td>392.20</td>
<td>.28</td>
<td><b>240.34</td>
<td>60.55</td>
<td>.264</td>
<td>41.72</td>
</tr>
<tr>
<td> <code>QM9 (alpha)</code> </td>
<td>.06</td>
<td>159.17</td>
<td></td>
<td>109.58**</td>
<td>.075*</td>
<td>82.53</td>
<td>.098</td>
<td>105.98**</td>
</tr>
</table>
* remeasured (Quadro RTX 4000), batch of 100 graphs, single precision
* rerun

** padded
** padded (naive)

### Validation install

Expand Down Expand Up @@ -90,12 +96,12 @@ python3 -u generate_dataset.py --simulation=gravity --n-balls=100
### Usage
#### N-body (charged)
```
python main.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-4 --weight-decay=1e-8
python main.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
```

#### N-body (gravity)
```
python main.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=1e-4 --weight-decay=1e-8 --neighbours=5 --n-bodies=100
python main.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
```

#### QM9
Expand All @@ -108,4 +114,4 @@ python main.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax

## Acknowledgments
- [e3nn_jax](https://github.com/e3nn/e3nn-jax) made this reimplementation possible.
- [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandsetter](https://github.com/brandstetter-johannes), for supporting developement.
- [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandsetter](https://github.com/brandstetter-johannes), for support.
5 changes: 2 additions & 3 deletions experiments/nbody/data/generate_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Generate charged and gravity datasets.
charged: python3 generate_dataset.py --simulation=charged --num-train=10000 --seed=43
gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --seed=43 --n-balls=100
charged: python3 generate_dataset.py --simulation=charged --num-train=10000
gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --n-balls=100
"""
import argparse
import time
Expand Down Expand Up @@ -95,7 +95,6 @@ def generate_dataset(num_sims, length, sample_freq):


if __name__ == "__main__":

print("Generating {} training simulations".format(args.num_train))
loc_train, vel_train, edges_train, charges_train = generate_dataset(
args.num_train, args.length, args.sample_freq
Expand Down
2 changes: 0 additions & 2 deletions experiments/nbody/data/synthetic_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ def _l2(self, A, B):
return dist

def _energy(self, loc, vel, edges):

# disables division by zero warning, since I fix it with fill_diagonal
with np.errstate(divide="ignore"):

K = 0.5 * (vel**2).sum()
U = 0
for i in range(loc.shape[1]):
Expand Down
14 changes: 7 additions & 7 deletions experiments/nbody/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, List, Optional, Tuple

import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import numpy as np
Expand All @@ -24,13 +25,13 @@ def O3Transform(
"""
attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes)

@jax.jit
def _o3_transform(
st_graph: SteerableGraphsTuple,
loc: jnp.ndarray,
vel: jnp.ndarray,
charges: jnp.ndarray,
) -> SteerableGraphsTuple:

graph = st_graph.graph
prod_charges = charges[graph.senders] * charges[graph.receivers]
rel_pos = loc[graph.senders] - loc[graph.receivers]
Expand Down Expand Up @@ -91,7 +92,7 @@ def NbodyGraphTransform(
dataset_name: str,
n_nodes: int,
batch_size: int,
neighbours: Optional[int] = 0,
neighbours: Optional[int] = 6,
relative_target: bool = False,
) -> Callable:
"""
Expand All @@ -100,7 +101,7 @@ def NbodyGraphTransform(

if dataset_name == "charged":
# charged system is a connected graph
full_edge_indices = np.array(
full_edge_indices = jnp.array(
[
(i + n_nodes * b, j + n_nodes * b)
for b in range(batch_size)
Expand All @@ -111,7 +112,6 @@ def NbodyGraphTransform(
).T

def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:

loc, vel, _, q, targets = data

cur_batch = int(loc.shape[0] / n_nodes)
Expand All @@ -124,7 +124,7 @@ def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
batch = batch.repeat_interleave(n_nodes).long()
edge_indices = knn_graph(torch.from_numpy(np.array(loc)), neighbours, batch)
# switched by default
senders, receivers = jnp.array(edge_indices[1]), jnp.array(edge_indices[0])
senders, receivers = jnp.array(edge_indices[0]), jnp.array(edge_indices[1])

st_graph = SteerableGraphsTuple(
graph=GraphsTuple(
Expand Down Expand Up @@ -215,8 +215,8 @@ def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable
loader_train = DataLoader(
dataset_train,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
shuffle=True,
drop_last=True,
collate_fn=numpy_collate,
)
loader_val = DataLoader(
Expand Down
40 changes: 25 additions & 15 deletions experiments/qm9/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ def QM9GraphTransform(
attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes)

def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
ptr = jnp.array(data.ptr)
senders = jnp.array(data.edge_index[0])
receivers = jnp.array(data.edge_index[1])
graph = jraph.GraphsTuple(
nodes=e3nn.IrrepsArray(node_features_irreps, jnp.array(data.x)),
edges=None,
senders=jnp.array(data.edge_index[0]),
receivers=jnp.array(data.edge_index[1]),
n_node=jnp.diff(jnp.array(data.ptr)),
# n_edge is not used anywhere by segnn, but is neded for padding
n_edge=jnp.array([jnp.array(data.edge_index[1]).shape[0]]),
senders=senders,
receivers=receivers,
n_node=jnp.diff(ptr),
n_edge=jnp.diff(jnp.sum(senders[:, jnp.newaxis] < ptr, axis=0)),
globals=None,
)
# pad for jax static shapes
Expand All @@ -45,19 +47,27 @@ def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
n_edge=max_batch_edges + 1,
n_graph=graph.n_node.shape[0] + 1,
)

node_attributes = e3nn.IrrepsArray(
attribute_irreps, jnp.pad(jnp.array(data.node_attr), node_attr_pad)
)
node_attributes.array = node_attributes.array.at[:, 0].set(1.0)

additional_message_features = e3nn.IrrepsArray(
edge_features_irreps,
jnp.pad(jnp.array(data.additional_message_features), edge_attr_pad),
)
edge_attributes = e3nn.IrrepsArray(
attribute_irreps, jnp.pad(jnp.array(data.edge_attr), edge_attr_pad)
)

st_graph = SteerableGraphsTuple(
graph=graph,
node_attributes=e3nn.IrrepsArray(
attribute_irreps, jnp.pad(jnp.array(data.node_attr), node_attr_pad)
),
edge_attributes=e3nn.IrrepsArray(
attribute_irreps, jnp.pad(jnp.array(data.edge_attr), edge_attr_pad)
),
additional_message_features=e3nn.IrrepsArray(
edge_features_irreps,
jnp.pad(jnp.array(data.additional_message_features), edge_attr_pad),
),
node_attributes=node_attributes,
edge_attributes=edge_attributes,
additional_message_features=additional_message_features,
)

# pad targets
target = jnp.append(jnp.array(data.y), 0)
return st_graph, target
Expand Down
Loading

0 comments on commit 7fb05da

Please sign in to comment.