Merge local branch
gerkone committed May 8, 2023
2 parents 7fb05da + 56e7460 commit 738cb1d
Showing 14 changed files with 534 additions and 585 deletions.
README.md: 15 changes (7 additions, 8 deletions)
@@ -17,7 +17,7 @@ python -m pip install -e .
### GPU support
Upgrade `jax` to the GPU version:
```
-pip install --upgrade "jax[cuda]==0.4.1" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pip install --upgrade "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
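
A quick way to sanity-check that the CUDA build is active after the upgrade (a minimal sketch, not part of this repository):

```python
import jax

# Lists the accelerators JAX can see; expect a GPU entry after the upgrade,
# and [CpuDevice(id=0)] on a CPU-only install.
print(jax.devices())
```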

## Validation
@@ -57,13 +57,13 @@ Times are remeasured on Quadro RTX 4000, __model only__ on batches of 100 graphs
</tr>
<tr>
<td> <code>QM9 (alpha)</code> </td>
-<td>.075*</td>
+<td>.066*</td>
<td>82.53</td>
-<td>.098</td>
+<td>.082</td>
<td>105.98**</td>
</tr>
</table>
-* rerun
+* rerun under the same conditions

** padded (naive)

@@ -76,7 +76,6 @@ git clone https://github.com/gerkone/segnn-jax

They are adapted from the original implementation, so `torch` and `torch_geometric` are additionally required (CPU versions are enough).
```
-pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r experiments/requirements.txt
```

@@ -96,17 +95,17 @@ python3 -u generate_dataset.py --simulation=gravity --n-balls=100
### Usage
#### N-body (charged)
```
-python main.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
+python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
```

#### N-body (gravity)
```
-python main.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
+python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
```

#### QM9
```
-python main.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
+python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
```

(configurations used in validation)
experiments/__init__.py: 6 changes (5 additions, 1 deletion)
@@ -4,6 +4,10 @@

from .nbody.utils import setup_nbody_data
from .qm9.utils import setup_qm9_data
+from .train import train

+__all__ = ["setup_data", "train"]


__setup_conf = {
"qm9": setup_qm9_data,
@@ -12,7 +16,7 @@
}


-def setup_datasets(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
+def setup_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
assert args.dataset in [
"qm9",
"charged",
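The renamed `setup_data` now returns a fifth element: an evaluation-time transform applied to model predictions, or `None` when the dataset does not normalize targets. A minimal sketch of a caller, assuming an `args` namespace as produced by the repository's argument parser (`validate.py` per the README above):

```python
from experiments import setup_data

# args is assumed here: the parsed CLI namespace built by validate.py.
# setup_data (renamed from setup_datasets) returns five values; the last,
# eval_trn, maps normalized predictions back to target units, or is None.
loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args)
```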
experiments/nbody/utils.py: 13 changes (10 additions, 3 deletions)
@@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import jax.tree_util as tree
+import jraph
import numpy as np
import torch
from jraph import GraphsTuple, segment_mean
@@ -111,7 +112,10 @@ def NbodyGraphTransform(
]
).T

-def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
+def _to_steerable_graph(
+    data: List, training: bool = True
+) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
+    _ = training
loc, vel, _, q, targets = data

cur_batch = int(loc.shape[0] / n_nodes)
@@ -138,6 +142,7 @@ def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
)
)
st_graph = transform(st_graph, loc, vel, q)

# relative shift as target
if relative_target:
targets = targets - loc
@@ -157,7 +162,9 @@ def numpy_collate(batch):
return jnp.array(batch)


-def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
+def setup_nbody_data(
+    args,
+) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
if args.dataset == "charged":
dataset_train = ChargedDataset(
partition="train",
@@ -234,4 +241,4 @@ def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable
collate_fn=numpy_collate,
)

-    return loader_train, loader_val, loader_test, graph_transform
+    return loader_train, loader_val, loader_test, graph_transform, None
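
Both dataset transforms now share the `(data, training)` call signature so the training loop can invoke them uniformly; the n-body transform simply ignores the flag, as the `_ = training` line above makes explicit. A toy illustration of the contract (stand-in code, not the repository's):

```python
# Stand-in transform with the same (data, training) interface; the n-body
# version accepts the flag for parity across datasets but does not use it.
def nbody_transform_like(data, training: bool = True):
    _ = training  # unused, kept so all dataset transforms look alike
    loc, targets = data
    return loc, targets - loc  # relative shift as target, as in the diff above

graph_like, target = nbody_transform_like((3.0, 5.0), training=False)
assert target == 2.0
```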
experiments/qm9/utils.py: 44 changes (31 additions, 13 deletions)
@@ -12,25 +12,26 @@


def QM9GraphTransform(
-node_features_irreps: e3nn.Irreps,
-edge_features_irreps: e3nn.Irreps,
-lmax_attributes: int,
+args,
max_batch_nodes: int,
max_batch_edges: int,
+train_trn: Callable,
) -> Callable:
"""
Build a function that converts torch DataBatch into SteerableGraphsTuple.
Mostly a quick fix out of laziness. Rewriting QM9 in jax is not trivial.
"""
-attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes)
+attribute_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes)

-def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
+def _to_steerable_graph(
+    data: Data, training: bool = True
+) -> Tuple[SteerableGraphsTuple, jnp.array]:
ptr = jnp.array(data.ptr)
senders = jnp.array(data.edge_index[0])
receivers = jnp.array(data.edge_index[1])
graph = jraph.GraphsTuple(
-nodes=e3nn.IrrepsArray(node_features_irreps, jnp.array(data.x)),
+nodes=e3nn.IrrepsArray(args.node_irreps, jnp.array(data.x)),
edges=None,
senders=senders,
receivers=receivers,
@@ -54,7 +55,7 @@ def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
node_attributes.array = node_attributes.array.at[:, 0].set(1.0)

additional_message_features = e3nn.IrrepsArray(
-edge_features_irreps,
+args.additional_message_irreps,
jnp.pad(jnp.array(data.additional_message_features), edge_attr_pad),
)
edge_attributes = e3nn.IrrepsArray(
@@ -69,13 +70,24 @@ def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
)

# pad targets
-target = jnp.append(jnp.array(data.y), 0)
+target = jnp.array(data.y)
+if args.task == "node":
+    target = jnp.pad(target, [(0, max_batch_nodes - target.shape[0] - 1)])
+if args.task == "graph":
+    target = jnp.append(target, 0)

# normalize targets
if training and train_trn is not None:
target = train_trn(target)

return st_graph, target

return _to_steerable_graph


-def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
+def setup_qm9_data(
+    args,
+) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
dataset_train = QM9(
"datasets",
args.target,
Expand Down Expand Up @@ -115,6 +127,10 @@ def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
)
)

+target_mean, target_mad = dataset_train.calc_stats()
+
+remove_offsets = lambda t: (t - target_mean) / target_mad

# not great and very slow due to huge padding
loader_train = DataLoader(
dataset_train,
Expand All @@ -136,10 +152,12 @@ def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
)

to_graphs_tuple = QM9GraphTransform(
-args.node_irreps,
-args.additional_message_irreps,
-args.lmax_attributes,
+args,
max_batch_nodes=max_batch_nodes,
max_batch_edges=max_batch_edges,
+train_trn=remove_offsets,
)
-return loader_train, loader_val, loader_test, to_graphs_tuple
+
+add_offsets = lambda p: p * target_mad + target_mean
+
+return loader_train, loader_val, loader_test, to_graphs_tuple, add_offsets
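
The two lambdas implement a standard mean/MAD target normalization: training targets are shifted and scaled by `remove_offsets`, and predictions are mapped back to target units by `add_offsets` at evaluation. A self-contained sketch of the round trip (statistics invented for illustration; the real ones come from `dataset_train.calc_stats()`):

```python
import jax.numpy as jnp

# Invented statistics standing in for dataset_train.calc_stats().
target_mean, target_mad = 75.3, 8.2

remove_offsets = lambda t: (t - target_mean) / target_mad  # train targets
add_offsets = lambda p: p * target_mad + target_mean       # eval predictions

targets = jnp.array([70.0, 80.0, 90.0])
# add_offsets is the exact inverse of remove_offsets.
assert jnp.allclose(add_offsets(remove_offsets(targets)), targets)
```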
experiments/requirements.txt: 7 changes (5 additions, 2 deletions)
@@ -1,8 +1,11 @@
---find-links https://data.pyg.org/whl/torch-1.12.1+cpu.html
+--extra-index-url https://download.pytorch.org/whl/cpu
+
+--find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html

e3nn==0.5.0
matplotlib>=3.6.2
rdkit==2022.9.2
-torch==1.12.1
+torch==1.13.1
torch-cluster==1.6.0
torch-geometric==2.1.0
torch-scatter==2.1.0
experiments/train.py: 166 changes (166 additions, 0 deletions)
@@ -0,0 +1,166 @@
import time
from functools import partial
from typing import Callable, Tuple

import haiku as hk
import jax
from jax import jit
import jax.numpy as jnp
import jraph
import optax

from segnn_jax import SteerableGraphsTuple


@partial(jit, static_argnames=["model_fn", "criterion", "task", "do_mask", "eval_trn"])
def loss_fn(
params: hk.Params,
state: hk.State,
st_graph: SteerableGraphsTuple,
target: jnp.ndarray,
model_fn: Callable,
criterion: Callable,
task: str = "node",
do_mask: bool = True,
eval_trn: Callable = None,
) -> Tuple[float, hk.State]:
pred, state = model_fn(params, state, st_graph)
if eval_trn is not None:
pred = eval_trn(pred)
if task == "node":
mask = jraph.get_node_padding_mask(st_graph.graph)
if task == "graph":
mask = jraph.get_graph_padding_mask(st_graph.graph)
# broadcast mask for vector targets
if len(pred.shape) == 2:
mask = mask[:, jnp.newaxis]
if do_mask:
target = target * mask
pred = pred * mask
assert target.shape == pred.shape
return jnp.sum(criterion(pred, target)) / jnp.count_nonzero(mask), state


@partial(jit, static_argnames=["loss_fn", "opt_update"])
def update(
params: hk.Params,
state: hk.State,
graph: SteerableGraphsTuple,
target: jnp.ndarray,
opt_state: optax.OptState,
loss_fn: Callable,
opt_update: Callable,
) -> Tuple[float, hk.Params, hk.State, optax.OptState]:
(loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
params, state, graph, target
)
updates, opt_state = opt_update(grads, opt_state, params)
return loss, optax.apply_updates(params, updates), state, opt_state


def evaluate(
loader,
params: hk.Params,
state: hk.State,
loss_fn: Callable,
graph_transform: Callable,
) -> Tuple[float, float]:
eval_loss = 0.0
eval_times = 0.0
for data in loader:
graph, target = graph_transform(data, training=False)
eval_start = time.perf_counter_ns()
loss, _ = jax.lax.stop_gradient(loss_fn(params, state, graph, target))
eval_loss += jax.block_until_ready(loss)
eval_times += (time.perf_counter_ns() - eval_start) / 1e6
return eval_times / len(loader), eval_loss / len(loader)


def train(
key,
segnn,
loader_train,
loader_val,
loader_test,
loss_fn,
eval_loss_fn,
graph_transform,
args,
):
init_graph, _ = graph_transform(next(iter(loader_train)))
params, segnn_state = segnn.init(key, init_graph)

print(
f"Starting {args.epochs} epochs "
f"with {hk.data_structures.tree_size(params)} parameters.\n"
"Jitting..."
)

total_steps = args.epochs * len(loader_train)

# set up learning rate and optimizer
learning_rate = args.lr
if args.lr_scheduling:
learning_rate = optax.piecewise_constant_schedule(
learning_rate,
boundaries_and_scales={
int(total_steps * 0.7): 0.1,
int(total_steps * 0.9): 0.1,
},
)
opt_init, opt_update = optax.adamw(
learning_rate=learning_rate, weight_decay=args.weight_decay
)

model_fn = segnn.apply

loss_fn = partial(loss_fn, model_fn=model_fn)
eval_loss_fn = partial(eval_loss_fn, model_fn=model_fn)
update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update)
eval_fn = partial(evaluate, loss_fn=eval_loss_fn, graph_transform=graph_transform)

opt_state = opt_init(params)
avg_time = []
best_val = 1e10

for e in range(args.epochs):
train_loss = 0.0
train_start = time.perf_counter_ns()
for data in loader_train:
graph, target = graph_transform(data)
loss, params, segnn_state, opt_state = update_fn(
params=params,
state=segnn_state,
graph=graph,
target=target,
opt_state=opt_state,
)
train_loss += loss
train_time = (time.perf_counter_ns() - train_start) / 1e6
train_loss /= len(loader_train)
print(
f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms",
end="",
)
if e % args.val_freq == 0:
eval_time, val_loss = eval_fn(loader_val, params, segnn_state)
avg_time.append(eval_time)
tag = ""
if val_loss < best_val:
best_val = val_loss
tag = " (best)"
_, test_loss_ckp = eval_fn(loader_test, params, segnn_state)
print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="")

print()

test_loss = 0
_, test_loss = eval_fn(loader_test, params, segnn_state)
# ignore compilation time
avg_time = avg_time[2:]
avg_time = sum(avg_time) / len(avg_time)
print(
"Training done.\n"
f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n"
f"Average (model) eval time {avg_time:.2f}ms"
)
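
The core of the new `loss_fn` is masking out jraph's padding before averaging. A self-contained sketch of that pattern for a node-level task (toy graph and values, not the repository's model):

```python
import jax.numpy as jnp
import jraph

# Toy batch: one real graph with 3 nodes, padded to 5 nodes / 4 edges / 2 graphs.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 2)),
    edges=None,
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=None,
)
padded = jraph.pad_with_graphs(graph, n_node=5, n_edge=4, n_graph=2)

# Node-level task: mask is True for real nodes, False for padding nodes.
mask = jraph.get_node_padding_mask(padded)
pred = jnp.ones((5, 2))      # stand-in per-node predictions
target = jnp.zeros((5, 2))   # stand-in per-node targets
mask = mask[:, jnp.newaxis]  # broadcast mask for vector targets
# Zero the padded entries, then normalize by the number of real nodes,
# exactly as loss_fn above divides by jnp.count_nonzero(mask).
loss = jnp.sum(((pred - target) * mask) ** 2) / jnp.count_nonzero(mask)
print(loss)  # 2.0: six unit errors over real entries, divided by 3 real nodes
```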