diff --git a/README.md b/README.md
index 90e4e88..cd31af2 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ python -m pip install -e .
### GPU support
Upgrade `jax` to the GPU version
```
-pip install --upgrade "jax[cuda]==0.4.1" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pip install --upgrade "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
## Validation
@@ -57,13 +57,13 @@ Times are remeasured on Quadro RTX 4000, __model only__ on batches of 100 graphs
QM9 (alpha) |
- .075* |
+ .066* |
82.53 |
- .098 |
+ .082 |
105.98** |
-* rerun
+* rerun under the same conditions
** padded (naive)
@@ -76,7 +76,6 @@ git clone https://github.com/gerkone/segnn-jax
They are adapted from the original implementation, so `torch` and `torch_geometric` are additionally required (CPU versions are sufficient).
```
-pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r experiments/requirements.txt
```
@@ -96,17 +95,17 @@ python3 -u generate_dataset.py --simulation=gravity --n-balls=100
### Usage
#### N-body (charged)
```
-python main.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
+python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
```
#### N-body (gravity)
```
-python main.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
+python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
```
#### QM9
```
-python main.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
+python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
```
(configurations used in validation)
diff --git a/experiments/__init__.py b/experiments/__init__.py
index fff0437..be203b3 100644
--- a/experiments/__init__.py
+++ b/experiments/__init__.py
@@ -4,6 +4,10 @@
from .nbody.utils import setup_nbody_data
from .qm9.utils import setup_qm9_data
+from .train import train
+
+__all__ = ["setup_data", "train"]
+
__setup_conf = {
"qm9": setup_qm9_data,
@@ -12,7 +16,7 @@
}
-def setup_datasets(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
+def setup_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
assert args.dataset in [
"qm9",
"charged",
diff --git a/experiments/nbody/utils.py b/experiments/nbody/utils.py
index c4cab65..e00d1ca 100644
--- a/experiments/nbody/utils.py
+++ b/experiments/nbody/utils.py
@@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import jax.tree_util as tree
+import jraph
import numpy as np
import torch
from jraph import GraphsTuple, segment_mean
@@ -111,7 +112,10 @@ def NbodyGraphTransform(
]
).T
- def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
+ def _to_steerable_graph(
+ data: List, training: bool = True
+ ) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
+ _ = training
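+        # the training flag is unused for the nbody transforms; it is kept only for
+        # interface parity with the QM9 transform, which normalizes targets during training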
loc, vel, _, q, targets = data
cur_batch = int(loc.shape[0] / n_nodes)
@@ -138,6 +142,7 @@ def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
)
)
st_graph = transform(st_graph, loc, vel, q)
+
# relative shift as target
if relative_target:
targets = targets - loc
@@ -157,7 +162,9 @@ def numpy_collate(batch):
return jnp.array(batch)
-def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
+def setup_nbody_data(
+ args,
+) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
if args.dataset == "charged":
dataset_train = ChargedDataset(
partition="train",
@@ -234,4 +241,4 @@ def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable
collate_fn=numpy_collate,
)
- return loader_train, loader_val, loader_test, graph_transform
+ return loader_train, loader_val, loader_test, graph_transform, None
diff --git a/experiments/qm9/utils.py b/experiments/qm9/utils.py
index b90f364..1007cd0 100644
--- a/experiments/qm9/utils.py
+++ b/experiments/qm9/utils.py
@@ -12,25 +12,26 @@
def QM9GraphTransform(
- node_features_irreps: e3nn.Irreps,
- edge_features_irreps: e3nn.Irreps,
- lmax_attributes: int,
+ args,
max_batch_nodes: int,
max_batch_edges: int,
+ train_trn: Callable,
) -> Callable:
"""
Build a function that converts torch DataBatch into SteerableGraphsTuple.
Mostly a quick fix out of laziness; rewriting QM9 in JAX is not trivial.
"""
- attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes)
+ attribute_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes)
- def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
+ def _to_steerable_graph(
+ data: Data, training: bool = True
+ ) -> Tuple[SteerableGraphsTuple, jnp.array]:
ptr = jnp.array(data.ptr)
senders = jnp.array(data.edge_index[0])
receivers = jnp.array(data.edge_index[1])
graph = jraph.GraphsTuple(
- nodes=e3nn.IrrepsArray(node_features_irreps, jnp.array(data.x)),
+ nodes=e3nn.IrrepsArray(args.node_irreps, jnp.array(data.x)),
edges=None,
senders=senders,
receivers=receivers,
@@ -54,7 +55,7 @@ def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
node_attributes.array = node_attributes.array.at[:, 0].set(1.0)
additional_message_features = e3nn.IrrepsArray(
- edge_features_irreps,
+ args.additional_message_irreps,
jnp.pad(jnp.array(data.additional_message_features), edge_attr_pad),
)
edge_attributes = e3nn.IrrepsArray(
@@ -69,13 +70,24 @@ def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]:
)
# pad targets
- target = jnp.append(jnp.array(data.y), 0)
+ target = jnp.array(data.y)
+ if args.task == "node":
+ target = jnp.pad(target, [(0, max_batch_nodes - target.shape[0] - 1)])
+ if args.task == "graph":
+ target = jnp.append(target, 0)
+
+ # normalize targets
+ if training and train_trn is not None:
+ target = train_trn(target)
+
return st_graph, target
return _to_steerable_graph
-def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
+def setup_qm9_data(
+ args,
+) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
dataset_train = QM9(
"datasets",
args.target,
@@ -115,6 +127,10 @@ def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
)
)
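+    # target statistics (mean and mean absolute deviation) from the training split,
+    # used to normalize training targets and to de-normalize predictions at eval time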
+ target_mean, target_mad = dataset_train.calc_stats()
+
+ remove_offsets = lambda t: (t - target_mean) / target_mad
+
# not great and very slow due to huge padding
loader_train = DataLoader(
dataset_train,
@@ -136,10 +152,12 @@ def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
)
to_graphs_tuple = QM9GraphTransform(
- args.node_irreps,
- args.additional_message_irreps,
- args.lmax_attributes,
+ args,
max_batch_nodes=max_batch_nodes,
max_batch_edges=max_batch_edges,
+ train_trn=remove_offsets,
)
- return loader_train, loader_val, loader_test, to_graphs_tuple
+
+ add_offsets = lambda p: p * target_mad + target_mean
+
+ return loader_train, loader_val, loader_test, to_graphs_tuple, add_offsets
diff --git a/experiments/requirements.txt b/experiments/requirements.txt
index e75b483..56d23f4 100644
--- a/experiments/requirements.txt
+++ b/experiments/requirements.txt
@@ -1,8 +1,11 @@
---find-links https://data.pyg.org/whl/torch-1.12.1+cpu.html
+--extra-index-url https://download.pytorch.org/whl/cpu
+
+--find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html
+
e3nn==0.5.0
matplotlib>=3.6.2
rdkit==2022.9.2
-torch==1.12.1
+torch==1.13.1
torch-cluster==1.6.0
torch-geometric==2.1.0
torch-scatter==2.1.0
diff --git a/experiments/train.py b/experiments/train.py
new file mode 100644
index 0000000..f021131
--- /dev/null
+++ b/experiments/train.py
@@ -0,0 +1,166 @@
+import time
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import haiku as hk
+import jax
+from jax import jit
+import jax.numpy as jnp
+import jraph
+import optax
+
+from segnn_jax import SteerableGraphsTuple
+
+
+@partial(jit, static_argnames=["model_fn", "criterion", "task", "do_mask", "eval_trn"])
+def loss_fn(
+ params: hk.Params,
+ state: hk.State,
+ st_graph: SteerableGraphsTuple,
+ target: jnp.ndarray,
+ model_fn: Callable,
+ criterion: Callable,
+ task: str = "node",
+ do_mask: bool = True,
+    eval_trn: Optional[Callable] = None,
+) -> Tuple[float, hk.State]:
+ pred, state = model_fn(params, state, st_graph)
+ if eval_trn is not None:
+ pred = eval_trn(pred)
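+    # pick the padding mask matching the target granularity:
+    # per-node targets (nbody) use the node mask, per-graph targets (qm9) use the graph mask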
+ if task == "node":
+ mask = jraph.get_node_padding_mask(st_graph.graph)
+ if task == "graph":
+ mask = jraph.get_graph_padding_mask(st_graph.graph)
+    # broadcast the mask for vector targets
+ if len(pred.shape) == 2:
+ mask = mask[:, jnp.newaxis]
+ if do_mask:
+ target = target * mask
+ pred = pred * mask
+ assert target.shape == pred.shape
+ return jnp.sum(criterion(pred, target)) / jnp.count_nonzero(mask), state
+
+
+@partial(jit, static_argnames=["loss_fn", "opt_update"])
+def update(
+ params: hk.Params,
+ state: hk.State,
+ graph: SteerableGraphsTuple,
+ target: jnp.ndarray,
+ opt_state: optax.OptState,
+ loss_fn: Callable,
+ opt_update: Callable,
+) -> Tuple[float, hk.Params, hk.State, optax.OptState]:
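+    # single optimisation step: differentiate the loss w.r.t. the parameters and apply the optax update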
+ (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+ params, state, graph, target
+ )
+ updates, opt_state = opt_update(grads, opt_state, params)
+ return loss, optax.apply_updates(params, updates), state, opt_state
+
+
+def evaluate(
+ loader,
+ params: hk.Params,
+ state: hk.State,
+ loss_fn: Callable,
+ graph_transform: Callable,
+) -> Tuple[float, float]:
+ eval_loss = 0.0
+ eval_times = 0.0
+ for data in loader:
+ graph, target = graph_transform(data, training=False)
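+        # time only the model evaluation; block_until_ready ensures the device work
+        # has finished before the timer is read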
+ eval_start = time.perf_counter_ns()
+ loss, _ = jax.lax.stop_gradient(loss_fn(params, state, graph, target))
+ eval_loss += jax.block_until_ready(loss)
+ eval_times += (time.perf_counter_ns() - eval_start) / 1e6
+ return eval_times / len(loader), eval_loss / len(loader)
+
+
+def train(
+ key,
+ segnn,
+ loader_train,
+ loader_val,
+ loader_test,
+ loss_fn,
+ eval_loss_fn,
+ graph_transform,
+ args,
+):
+ init_graph, _ = graph_transform(next(iter(loader_train)))
+ params, segnn_state = segnn.init(key, init_graph)
+
+ print(
+ f"Starting {args.epochs} epochs "
+ f"with {hk.data_structures.tree_size(params)} parameters.\n"
+ "Jitting..."
+ )
+
+ total_steps = args.epochs * len(loader_train)
+
+ # set up learning rate and optimizer
+ learning_rate = args.lr
+ if args.lr_scheduling:
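+        # step decay: scale the learning rate by 0.1 at 70% and 90% of the total steps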
+ learning_rate = optax.piecewise_constant_schedule(
+ learning_rate,
+ boundaries_and_scales={
+ int(total_steps * 0.7): 0.1,
+ int(total_steps * 0.9): 0.1,
+ },
+ )
+ opt_init, opt_update = optax.adamw(
+ learning_rate=learning_rate, weight_decay=args.weight_decay
+ )
+
+ model_fn = segnn.apply
+
+ loss_fn = partial(loss_fn, model_fn=model_fn)
+ eval_loss_fn = partial(eval_loss_fn, model_fn=model_fn)
+ update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update)
+ eval_fn = partial(evaluate, loss_fn=eval_loss_fn, graph_transform=graph_transform)
+
+ opt_state = opt_init(params)
+ avg_time = []
+ best_val = 1e10
+
+ for e in range(args.epochs):
+ train_loss = 0.0
+ train_start = time.perf_counter_ns()
+ for data in loader_train:
+ graph, target = graph_transform(data)
+ loss, params, segnn_state, opt_state = update_fn(
+ params=params,
+ state=segnn_state,
+ graph=graph,
+ target=target,
+ opt_state=opt_state,
+ )
+ train_loss += loss
+ train_time = (time.perf_counter_ns() - train_start) / 1e6
+ train_loss /= len(loader_train)
+ print(
+ f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms",
+ end="",
+ )
+ if e % args.val_freq == 0:
+ eval_time, val_loss = eval_fn(loader_val, params, segnn_state)
+ avg_time.append(eval_time)
+ tag = ""
+ if val_loss < best_val:
+ best_val = val_loss
+ tag = " (best)"
+ _, test_loss_ckp = eval_fn(loader_test, params, segnn_state)
+ print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="")
+
+ print()
+
+ test_loss = 0
+ _, test_loss = eval_fn(loader_test, params, segnn_state)
+ # ignore compilation time
+ avg_time = avg_time[2:]
+ avg_time = sum(avg_time) / len(avg_time)
+ print(
+ "Training done.\n"
+ f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n"
+ f"Average (model) eval time {avg_time:.2f}ms"
+ )
diff --git a/main.py b/main.py
deleted file mode 100644
index fe6c591..0000000
--- a/main.py
+++ /dev/null
@@ -1,410 +0,0 @@
-import argparse
-import time
-from functools import partial
-from typing import Callable, Iterable, Tuple
-
-import e3nn_jax as e3nn
-import haiku as hk
-import jax
-import jax.numpy as jnp
-import optax
-import wandb
-
-from experiments import setup_datasets
-from segnn_jax import SEGNN, SteerableGraphsTuple, weight_balanced_irreps
-
-key = jax.random.PRNGKey(1337)
-
-
-def predict(
- model_fn: hk.TransformedWithState,
- params: hk.Params,
- state: hk.State,
- graph: SteerableGraphsTuple,
- mean_shift: float = 0,
- mad_shift: float = 1,
-) -> Tuple[jnp.ndarray, hk.State]:
- pred, state = model_fn(params, state, graph)
- return pred * mad_shift + mean_shift, state
-
-
-@partial(jax.jit, static_argnames=["model_fn", "mean_shift", "mad_shift", "mask_last"])
-def mae(
- params: hk.Params,
- state: hk.State,
- graph: SteerableGraphsTuple,
- target: jnp.ndarray,
- model_fn: Callable,
- mean_shift: float = 0,
- mad_shift: float = 1,
- mask_last: bool = False,
-) -> Tuple[float, hk.State]:
- pred, state = predict(model_fn, params, state, graph, mean_shift, mad_shift)
- assert target.shape == pred.shape
- # similar to get_graph_padding_mask
- if mask_last:
- return (jnp.abs(pred[:-1] - target[:-1])).mean(), state
- else:
- return (jnp.abs(pred - target)).mean(), state
-
-
-@partial(jax.jit, static_argnames=["model_fn", "mean_shift", "mad_shift", "mask_last"])
-def mse(
- params: hk.Params,
- state: hk.State,
- graph: SteerableGraphsTuple,
- target: jnp.ndarray,
- model_fn: Callable,
- mean_shift: float = 0,
- mad_shift: float = 1,
- mask_last: bool = False,
-) -> Tuple[float, hk.State]:
- pred, state = predict(model_fn, params, state, graph, mean_shift, mad_shift)
- assert target.shape == pred.shape
- if mask_last:
- return (jnp.power(pred[:-1] - target[:-1], 2)).mean(), state
- else:
- return (jnp.power(pred - target, 2)).mean(), state
-
-
-@partial(jax.jit, static_argnames=["loss_fn", "opt_update"])
-def update(
- params: hk.Params,
- state: hk.State,
- graph: SteerableGraphsTuple,
- target: jnp.ndarray,
- opt_state: optax.OptState,
- loss_fn: Callable,
- opt_update: Callable,
-) -> Tuple[float, hk.Params, hk.State, optax.OptState]:
- (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
- params, state, graph, target
- )
- updates, opt_state = opt_update(grads, opt_state, params)
- return loss, optax.apply_updates(params, updates), state, opt_state
-
-
-def evaluate(
- loader: Iterable,
- params: hk.Params,
- state: hk.State,
- loss_fn: Callable,
- graph_transform: Callable,
-) -> Tuple[float, float]:
- eval_loss = 0.0
- eval_times = 0.0
- for data in loader:
- graph, target = graph_transform(data)
- eval_start = time.perf_counter_ns()
- loss, _ = jax.lax.stop_gradient(loss_fn(params, state, graph, target))
- eval_loss += jax.block_until_ready(loss)
- eval_times += (time.perf_counter_ns() - eval_start) / 1e6
-
- return eval_times / len(loader), eval_loss / len(loader)
-
-
-def train(
- segnn: hk.Transformed, loader_train, loader_val, loader_test, graph_transform, args
-):
- init_graph, _ = graph_transform(next(iter(loader_train)))
- params, segnn_state = segnn.init(key, init_graph)
-
- print(
- f"Starting {args.epochs} epochs on {args.dataset} "
- f"with {hk.data_structures.tree_size(params)} parameters.\n"
- "Jitting..."
- )
-
- total_steps = args.epochs * len(loader_train)
-
- # set up learning rate and optimizer
- if args.lr_scheduling:
- learning_rate = optax.piecewise_constant_schedule(
- args.lr,
- boundaries_and_scales={
- int(total_steps * 0.8): 0.1,
- int(total_steps * 0.9): 0.1,
- },
- )
- else:
- learning_rate = args.lr
-
- opt_init, opt_update = optax.adamw(
- learning_rate=learning_rate, weight_decay=args.weight_decay
- )
-
- if args.dataset == "qm9":
- # qm9
- target_mean, target_mad = loader_train.dataset.calc_stats()
- # ignore padded target
- loss_fn = partial(mae, model_fn=segnn.apply, mask_last=True)
- eval_loss_fn = partial(
- mae,
- model_fn=segnn.apply,
- mask_last=True,
- mean_shift=target_mean,
- mad_shift=target_mad,
- )
- else:
- # nbody
- target_mean, target_mad = 0, 1
- loss_fn = partial(mse, model_fn=segnn.apply)
- eval_loss_fn = partial(mse, model_fn=segnn.apply)
-
- update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update)
- eval_fn = partial(evaluate, loss_fn=eval_loss_fn, graph_transform=graph_transform)
-
- opt_state = opt_init(params)
- avg_time = []
- best_val = 1e10
-
- for e in range(args.epochs):
- train_loss = 0.0
- train_start = time.perf_counter_ns()
- for data in loader_train:
- graph, target = graph_transform(data)
- # normalize targets
- loss, params, segnn_state, opt_state = update_fn(
- params=params,
- state=segnn_state,
- graph=graph,
- target=(target - target_mean) / target_mad,
- opt_state=opt_state,
- )
- train_loss += loss
- train_time = (time.perf_counter_ns() - train_start) / 1e6
- train_loss /= len(loader_train)
- wandb_logs = {"train_loss": float(train_loss), "update_time": float(train_time)}
- print(
- f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms",
- end="",
- )
- if e % args.val_freq == 0:
- eval_time, val_loss = eval_fn(loader_val, params, segnn_state)
- avg_time.append(eval_time)
- tag = ""
- if val_loss < best_val:
- best_val = val_loss
- _, test_loss_ckp = eval_fn(loader_test, params, segnn_state)
- wandb_logs.update({"test_loss": float(test_loss_ckp)})
- tag = " (BEST)"
- wandb_logs.update(
- {"val_loss": float(val_loss), "eval_time": float(eval_time)}
- )
- print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="")
-
- print()
- if args.wandb:
- wandb.log(wandb_logs)
-
- test_loss = 0
- _, test_loss = eval_fn(loader_test, params, segnn_state)
- # ignore compilation time
- avg_time = avg_time[2:]
- avg_time = sum(avg_time) / len(avg_time)
- if args.wandb:
- wandb.log({"test_loss": float(test_loss), "avg_eval_time": float(avg_time)})
- print(
- "Training done.\n"
- f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n"
- f"Average (model) eval time {avg_time:.2f}ms"
- )
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- # Run parameters
- parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
- parser.add_argument(
- "--batch-size",
- type=int,
- default=128,
- help="Batch size (number of graphs).",
- )
- parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
- parser.add_argument(
- "--lr-scheduling",
- action="store_true",
- help="Use learning rate scheduling",
- )
- parser.add_argument(
- "--weight-decay", type=float, default=1e-12, help="Weight decay"
- )
- parser.add_argument(
- "--dataset",
- type=str,
- choices=["qm9", "charged", "gravity"],
- help="Dataset name",
- )
- parser.add_argument(
- "--max-samples",
- type=int,
- default=3000,
- help="Maximum number of samples in nbody dataset",
- )
- parser.add_argument(
- "--val-freq",
- type=int,
- default=10,
- help="Evaluation frequency (number of epochs)",
- )
-
- # nbody parameters
- parser.add_argument(
- "--target",
- type=str,
- default="pos",
- help="Target. e.g. pos, force (gravity), alpha (qm9)",
- )
- parser.add_argument(
- "--neighbours",
- type=int,
- default=20,
- help="Number of connected nearest neighbours",
- )
- parser.add_argument(
- "--n-bodies",
- type=int,
- default=5,
- help="Number of bodies in the dataset",
- )
- parser.add_argument(
- "--dataset-name",
- type=str,
- default="small",
- choices=["small", "default", "small_out_dist"],
- help="Name of nbody data partition: default (200 steps), small (1000 steps)",
- )
-
- # qm9 parameters
- parser.add_argument(
- "--radius",
- type=float,
- default=2.0,
- help="Radius (Angstrom) between which atoms to add links.",
- )
- parser.add_argument(
- "--feature-type",
- type=str,
- default="one_hot",
- choices=["one_hot", "cormorant", "gilmer"],
- help="Type of input feature",
- )
-
- # Model parameters
- parser.add_argument(
- "--units", type=int, default=64, help="Number of values in the hidden layers"
- )
- parser.add_argument(
- "--lmax-hidden",
- type=int,
- default=1,
- help="Max degree of hidden representations.",
- )
- parser.add_argument(
- "--lmax-attributes",
- type=int,
- default=1,
- help="Max degree of geometric attribute embedding",
- )
- parser.add_argument(
- "--layers", type=int, default=7, help="Number of message passing layers"
- )
- parser.add_argument(
- "--blocks", type=int, default=2, help="Number of layers in steerable MLPs."
- )
- parser.add_argument(
- "--norm",
- type=str,
- default="none",
- choices=["instance", "batch", "none"],
- help="Normalisation type",
- )
- parser.add_argument(
- "--double-precision",
- action="store_true",
- help="Use double precision in model",
- )
-
- # wandb parameters
- parser.add_argument(
- "--wandb",
- action="store_true",
- help="Activate weights and biases logging",
- )
- parser.add_argument(
- "--wandb-project",
- type=str,
- default="segnn",
- help="Weights and biases project",
- )
- parser.add_argument(
- "--wandb-entity",
- type=str,
- default="",
- help="Weights and biases entity",
- )
-
- args = parser.parse_args()
-
- # if specified set jax in double precision
- jax.config.update("jax_enable_x64", args.double_precision)
-
- # connect to wandb
- if args.wandb:
- wandb_name = "_".join(
- [
- args.wandb_project,
- args.dataset,
- args.target,
- str(int(time.time())),
- ]
- )
- wandb.init(
- project=args.wandb_project,
- name=wandb_name,
- config=args,
- entity=args.wandb_entity,
- )
-
- # feature representations
- if args.dataset == "qm9":
- task = "graph"
- if args.feature_type == "one_hot":
- args.node_irreps = e3nn.Irreps("5x0e")
- elif args.feature_type == "cormorant":
- args.node_irreps = e3nn.Irreps("15x0e")
- elif args.feature_type == "gilmer":
- args.node_irreps = e3nn.Irreps("11x0e")
- args.output_irreps = e3nn.Irreps("1x0e")
- args.additional_message_irreps = e3nn.Irreps("1x0e")
- elif args.dataset in ["charged", "gravity"]:
- task = "node"
- args.node_irreps = e3nn.Irreps("2x1o + 1x0e")
- args.output_irreps = e3nn.Irreps("1x1o")
- args.additional_message_irreps = e3nn.Irreps("2x0e")
-
- # Create hidden irreps
- hidden_irreps = weight_balanced_irreps(
- scalar_units=args.units,
- # attribute irreps
- irreps_right=e3nn.Irreps.spherical_harmonics(args.lmax_attributes),
- use_sh=True,
- lmax=args.lmax_hidden,
- )
-
- # build model
- segnn = lambda x: SEGNN(
- hidden_irreps=hidden_irreps,
- output_irreps=args.output_irreps,
- num_layers=args.layers,
- task=task,
- pool="avg",
- blocks_per_layer=args.blocks,
- norm=args.norm,
- )(x)
- segnn = hk.without_apply_rng(hk.transform_with_state(segnn))
-
- dataset_train, dataset_val, dataset_test, graph_transform = setup_datasets(args)
-
- train(segnn, dataset_train, dataset_val, dataset_test, graph_transform, args)
diff --git a/requirements.txt b/requirements.txt
index 2aa0260..2fe59a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
dm-haiku==0.0.9
-e3nn-jax==0.17.0
-jax[cuda]==0.4.1
+e3nn-jax==0.17.4
+jax[cuda]==0.4.8
jraph==0.0.6.dev0
numpy>=1.23.4
optax==0.1.3
diff --git a/segnn_jax/__init__.py b/segnn_jax/__init__.py
index 987e13a..64f5182 100644
--- a/segnn_jax/__init__.py
+++ b/segnn_jax/__init__.py
@@ -14,4 +14,4 @@
"SteerableGraphsTuple",
]
-__version__ = "0.5"
+__version__ = "0.6"
diff --git a/segnn_jax/blocks.py b/segnn_jax/blocks.py
index 8ce752a..58e74c1 100644
--- a/segnn_jax/blocks.py
+++ b/segnn_jax/blocks.py
@@ -31,21 +31,15 @@ class O3TensorProduct(hk.Module):
"""
O(3) equivariant linear parametrized tensor product layer.
- Attributes:
- left_irreps: Left input representation
- right_irreps: Right input representation
- output_irreps: Output representation
- get_parameter: Haiku parameter getter and init function
- tensor_product: Tensor product function
- biases: Bias wrapper function
+ Functionally the same as O3TensorProductLegacy, but around 5-10% faster.
+    FunctionalFullyConnectedTensorProduct appears to be faster than tensor_product + linear:
+ https://github.com/e3nn/e3nn-jax/releases/tag/0.14.0
"""
def __init__(
self,
output_irreps: e3nn.Irreps,
*,
- left_irreps: e3nn.Irreps,
- right_irreps: Optional[e3nn.Irreps] = None,
biases: bool = True,
name: Optional[str] = None,
init_fn: Optional[InitFn] = None,
@@ -56,8 +50,6 @@ def __init__(
Args:
output_irreps: Output representation
- left_irreps: Left input representation
- right_irreps: Right input representation (optional, defaults to 1x0e)
             biases: If set to true, adds biases
name: Name of the linear layer params
init_fn: Weight initialization function. Default is uniform.
@@ -67,41 +59,34 @@ def __init__(
"""
super().__init__(name)
- if not right_irreps:
- # NOTE: this is equivalent to a linear recombination of the left vectors
- right_irreps = e3nn.Irreps("1x0e")
-
if not isinstance(output_irreps, e3nn.Irreps):
output_irreps = e3nn.Irreps(output_irreps)
- if not isinstance(left_irreps, e3nn.Irreps):
- left_irreps = e3nn.Irreps(left_irreps)
- if not isinstance(right_irreps, e3nn.Irreps):
- right_irreps = e3nn.Irreps(right_irreps)
-
self.output_irreps = output_irreps
- self.right_irreps = right_irreps
- self.left_irreps = left_irreps
-
+ # tp weight init
if not init_fn:
init_fn = uniform_init
-
self.get_parameter = init_fn
if not gradient_normalization:
gradient_normalization = config("gradient_normalization")
if not path_normalization:
path_normalization = config("path_normalization")
+ self._gradient_normalization = gradient_normalization
+ self._path_normalization = path_normalization
+
+ self.biases = biases and "0e" in self.output_irreps
- # NOTE FunctionalFullyConnectedTensorProduct appears to be faster than combining
- # tensor_product+linear: https://github.com/e3nn/e3nn-jax/releases/tag/0.14.0
- # Implementation adapted from e3nn.haiku.FullyConnectedTensorProduct
+ def _build_tensor_product(
+ self, left_irreps: e3nn.Irreps, right_irreps: e3nn.Irreps
+ ) -> Callable:
+ """Build the tensor product function."""
tp = e3nn.FunctionalFullyConnectedTensorProduct(
left_irreps,
right_irreps,
- output_irreps,
- gradient_normalization=gradient_normalization,
- path_normalization=path_normalization,
+ self.output_irreps,
+ gradient_normalization=self._gradient_normalization,
+ path_normalization=self._path_normalization,
)
ws = [
self.get_parameter(
@@ -118,35 +103,32 @@ def __init__(
]
def tensor_product(x, y, **kwargs):
- return tp.left_right(ws, x, y, **kwargs)._convert(output_irreps)
+ return tp.left_right(ws, x, y, **kwargs)._convert(self.output_irreps)
- self.tensor_product = naive_broadcast_decorator(tensor_product)
- self.biases = None
+ return naive_broadcast_decorator(tensor_product)
- if biases and "0e" in self.output_irreps:
- # add biases
- b = [
- self.get_parameter(
- f"b[{i_out}] {tp.irreps_out[i_out]}",
- path_shape=(mul_ir.dim,),
- weight_std=1 / jnp.sqrt(mul_ir.dim),
- )
- for i_out, mul_ir in enumerate(output_irreps)
- if mul_ir.ir.is_scalar()
- ]
- b = e3nn.IrrepsArray(
- f"{self.output_irreps.count('0e')}x0e", jnp.concatenate(b)
+ def _build_biases(self) -> Callable:
+ """Build the add bias function."""
+ b = [
+ self.get_parameter(
+ f"b[{i_out}] {self.output_irreps}",
+ path_shape=(mul_ir.dim,),
+ weight_std=1 / jnp.sqrt(mul_ir.dim),
+ )
+ for i_out, mul_ir in enumerate(self.output_irreps)
+ if mul_ir.ir.is_scalar()
+ ]
+ b = e3nn.IrrepsArray(f"{self.output_irreps.count('0e')}x0e", jnp.concatenate(b))
+
+ # TODO: could be improved
+ def _wrapper(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
+ scalars = x.filter("0e")
+ other = x.filter(drop="0e")
+ return e3nn.concatenate(
+ [scalars + b.broadcast_to(scalars.shape), other], axis=1
)
- # TODO: could be improved
- def _wrapper(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
- scalars = x.filter("0e")
- other = x.filter(drop="0e")
- return e3nn.concatenate(
- [scalars + b.broadcast_to(scalars.shape), other], axis=1
- )
-
- self.biases = _wrapper
+ return _wrapper
def __call__(
self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs
@@ -162,7 +144,7 @@ def __call__(
"""
if not y:
- y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1)))
+ y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype))
if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0:
warnings.warn(
@@ -171,18 +153,13 @@ def __call__(
"redistributing them into scalars or choose higher orders."
)
- assert (
- x.irreps == self.left_irreps
- ), f"Left irreps do not match. Got {x.irreps}, expected {self.left_irreps}"
- assert (
- y.irreps == self.right_irreps
- ), f"Right irreps do not match. Got {y.irreps}, expected {self.right_irreps}"
-
- output = self.tensor_product(x, y, **kwargs)
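+        # the tensor product is now built from the input irreps at call time,
+        # so left_irreps/right_irreps constructor arguments are no longer needed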
+ tp = self._build_tensor_product(x.irreps, y.irreps)
+ output = tp(x, y, **kwargs)
if self.biases:
# add biases
- return self.biases(output)
+ bias_fn = self._build_biases()
+ return bias_fn(output)
return output
@@ -190,8 +167,6 @@ def __call__(
def O3TensorProductLegacy(
output_irreps: e3nn.Irreps,
*,
- left_irreps: e3nn.Irreps,
- right_irreps: Optional[e3nn.Irreps] = None,
biases: bool = True,
name: Optional[str] = None,
init_fn: Optional[InitFn] = None,
@@ -215,15 +190,8 @@ def O3TensorProductLegacy(
A function that returns the output to the weighted tensor product.
"""
- if not right_irreps:
- right_irreps = e3nn.Irreps("1x0e")
-
if not isinstance(output_irreps, e3nn.Irreps):
output_irreps = e3nn.Irreps(output_irreps)
- if not isinstance(left_irreps, e3nn.Irreps):
- left_irreps = e3nn.Irreps(left_irreps)
- if not isinstance(right_irreps, e3nn.Irreps):
- right_irreps = e3nn.Irreps(right_irreps)
if not init_fn:
init_fn = uniform_init
@@ -251,7 +219,7 @@ def _tensor_product(
"""
if not y:
- y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1)))
+ y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype))
if x.irreps.lmax == 0 and y.irreps.lmax == 0 and output_irreps.lmax > 0:
warnings.warn(
@@ -260,13 +228,6 @@ def _tensor_product(
"redistributing them into scalars or choose higher orders."
)
- assert (
- x.irreps == left_irreps
- ), f"Left irreps do not match. Got {x.irreps}, expected {left_irreps}"
- assert (
- y.irreps == right_irreps
- ), f"Right irreps do not match. Got {y.irreps}, expected {right_irreps}"
-
tp = e3nn.tensor_product(x, y)
return linear(tp)
@@ -280,8 +241,6 @@ def _tensor_product(
def O3TensorProductGate(
output_irreps: e3nn.Irreps,
*,
- left_irreps: e3nn.Irreps,
- right_irreps: Optional[e3nn.Irreps] = None,
biases: bool = True,
scalar_activation: Optional[Callable] = None,
gate_activation: Optional[Callable] = None,
@@ -312,8 +271,6 @@ def O3TensorProductGate(
)
tensor_product = O3Layer(
(gate_irreps + output_irreps).regroup(),
- left_irreps=left_irreps,
- right_irreps=right_irreps,
biases=biases,
name=name,
init_fn=init_fn,
diff --git a/segnn_jax/segnn.py b/segnn_jax/segnn.py
index 4f12ffb..a2567a4 100644
--- a/segnn_jax/segnn.py
+++ b/segnn_jax/segnn.py
@@ -26,12 +26,9 @@ def O3Embedding(embed_irreps: e3nn.Irreps, embed_edges: bool = True) -> Callable
def _embedding(
st_graph: SteerableGraphsTuple,
) -> SteerableGraphsTuple:
- # TODO update
graph = st_graph.graph
nodes = O3Layer(
embed_irreps,
- left_irreps=graph.nodes.irreps,
- right_irreps=st_graph.node_attributes.irreps,
name="embedding_nodes",
)(graph.nodes, st_graph.node_attributes)
st_graph = st_graph._replace(graph=graph._replace(nodes=nodes))
@@ -40,8 +37,6 @@ def _embedding(
if embed_edges:
additional_message_features = O3Layer(
embed_irreps,
- left_irreps=graph.nodes.irreps,
- right_irreps=st_graph.node_attributes.irreps,
name="embedding_msg_features",
)(
st_graph.additional_message_features,
@@ -82,30 +77,21 @@ def _decoder(st_graph: SteerableGraphsTuple):
nodes = st_graph.graph.nodes
# pre pool block
for i in range(blocks):
- nodes = O3TensorProductGate(
- latent_irreps,
- left_irreps=nodes.irreps,
- right_irreps=st_graph.node_attributes.irreps,
- name=f"prepool_{i}",
- )(nodes, st_graph.node_attributes)
+ nodes = O3TensorProductGate(latent_irreps, name=f"prepool_{i}")(
+ nodes, st_graph.node_attributes
+ )
if task == "node":
- nodes = O3Layer(
- output_irreps,
- left_irreps=nodes.irreps,
- right_irreps=st_graph.node_attributes.irreps,
- name="output",
- )(nodes, st_graph.node_attributes)
+ nodes = O3Layer(output_irreps, name="output")(
+ nodes, st_graph.node_attributes
+ )
if task == "graph":
# pool over graph
pooled_irreps = (latent_irreps.num_irreps * output_irreps).regroup()
- nodes = O3Layer(
- pooled_irreps,
- left_irreps=nodes.irreps,
- right_irreps=st_graph.node_attributes.irreps,
- name=f"prepool_{blocks}",
- )(nodes, st_graph.node_attributes)
+ nodes = O3Layer(pooled_irreps, name=f"prepool_{blocks}")(
+ nodes, st_graph.node_attributes
+ )
# pooling layer
if pool == "avg":
@@ -117,12 +103,8 @@ def _decoder(st_graph: SteerableGraphsTuple):
# post pool mlp (not steerable)
for i in range(blocks):
- nodes = O3TensorProductGate(
- pooled_irreps, left_irreps=nodes.irreps, name=f"postpool_{i}"
- )(nodes)
- nodes = O3Layer(output_irreps, left_irreps=nodes.irreps, name="output")(
- nodes
- )
+ nodes = O3TensorProductGate(pooled_irreps, name=f"postpool_{i}")(nodes)
+ nodes = O3Layer(output_irreps, name="output")(nodes)
return nodes
@@ -179,12 +161,9 @@ def _message(
msg = e3nn.concatenate([msg, additional_message_features], axis=-1)
        # message mlp (phi_m in the paper) steered by edge attributes
for i in range(self._blocks):
- msg = O3TensorProductGate(
- self._output_irreps,
- left_irreps=msg.irreps,
- right_irreps=getattr(edge_attribute, "irreps", None),
- name=f"tp_{i}",
- )(msg, edge_attribute)
+ msg = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")(
+ msg, edge_attribute
+ )
# NOTE: original implementation only applied batch norm to messages
if self._norm == "batch":
msg = e3nn.haiku.BatchNorm(irreps=self._output_irreps)(msg)
@@ -204,19 +183,13 @@ def _update(
x = e3nn.concatenate([nodes, msg], axis=-1)
        # update mlp (phi_f in the paper) steered by node attributes
for i in range(self._blocks - 1):
- x = O3TensorProductGate(
- self._output_irreps,
- left_irreps=x.irreps,
- right_irreps=getattr(node_attribute, "irreps", None),
- name=f"tp_{i}",
- )(x, node_attribute)
+ x = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")(
+ x, node_attribute
+ )
# last update layer without activation
- update = O3Layer(
- self._output_irreps,
- left_irreps=x.irreps,
- right_irreps=getattr(node_attribute, "irreps", None),
- name=f"tp_{self._blocks - 1}",
- )(x, node_attribute)
+ update = O3Layer(self._output_irreps, name=f"tp_{self._blocks - 1}")(
+ x, node_attribute
+ )
# residual connection
nodes += update
# message norm
diff --git a/setup.cfg b/setup.cfg
index 53a3feb..9bd45b6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -15,9 +15,9 @@ packages = segnn_jax
python_requires = >=3.8
install_requires =
dm_haiku==0.0.9
- e3nn_jax==0.17.0
- jax==0.4.1
- jaxlib==0.4.1
+ e3nn_jax==0.17.4
+ jax==0.4.8
+ jaxlib==0.4.8
jraph==0.0.6.dev0
numpy>=1.23.4
optax==0.1.3
diff --git a/tests/test_blocks.py b/tests/test_blocks.py
index 8af60ad..e7a8336 100644
--- a/tests/test_blocks.py
+++ b/tests/test_blocks.py
@@ -8,9 +8,7 @@
@pytest.mark.parametrize("biases", [False, True])
def test_linear(key, biases):
- f = lambda x1, x2: O3TensorProduct(
- "1x1o", left_irreps="1x1o", right_irreps="1x1o", biases=biases
- )(x1, x2)
+ f = lambda x1, x2: O3TensorProduct("1x1o", biases=biases)(x1, x2)
f = hk.without_apply_rng(hk.transform(f))
v = e3nn.normal("1x1o", key, (5,))
@@ -31,9 +29,7 @@ def test_gated(key, biases):
segnn_jax.blocks.O3Layer = segnn_jax.blocks.O3TensorProduct
- f = lambda x1, x2: O3TensorProductGate(
- "1x1o", left_irreps="1x1o", right_irreps="1x1o", biases=biases
- )(x1, x2)
+ f = lambda x1, x2: O3TensorProductGate("1x1o", biases=biases)(x1, x2)
f = hk.without_apply_rng(hk.transform(f))
v = e3nn.normal("1x1o", key, (5,))
@@ -50,9 +46,7 @@ def test_gated(key, biases):
@pytest.mark.parametrize("biases", [False, True])
def test_linear_legacy(key, biases):
- f = lambda x1, x2: O3TensorProductLegacy(
- "1x1o", left_irreps="1x1o", right_irreps="1x1o", biases=biases
- )(x1, x2)
+ f = lambda x1, x2: O3TensorProductLegacy("1x1o", biases=biases)(x1, x2)
f = hk.without_apply_rng(hk.transform(f))
v = e3nn.normal("1x1o", key, (5,))
diff --git a/validate.py b/validate.py
new file mode 100644
index 0000000..4eaa14e
--- /dev/null
+++ b/validate.py
@@ -0,0 +1,238 @@
+import argparse
+import time
+from functools import partial
+
+import e3nn_jax as e3nn
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import wandb
+
+from experiments import setup_data, train
+from segnn_jax import SEGNN, weight_balanced_irreps
+
+key = jax.random.PRNGKey(1337)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Run parameters
+ parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=128,
+ help="Batch size (number of graphs).",
+ )
+ parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
+ parser.add_argument(
+ "--lr-scheduling",
+ action="store_true",
+ help="Use learning rate scheduling",
+ )
+ parser.add_argument(
+ "--weight-decay", type=float, default=1e-12, help="Weight decay"
+ )
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ choices=["qm9", "charged", "gravity"],
+ help="Dataset name",
+ )
+ parser.add_argument(
+ "--max-samples",
+ type=int,
+ default=3000,
+ help="Maximum number of samples in nbody dataset",
+ )
+ parser.add_argument(
+ "--val-freq",
+ type=int,
+ default=10,
+ help="Evaluation frequency (number of epochs)",
+ )
+
+ # nbody parameters
+ parser.add_argument(
+ "--target",
+ type=str,
+ default="pos",
+ help="Target. e.g. pos, force (gravity), alpha (qm9)",
+ )
+ parser.add_argument(
+ "--neighbours",
+ type=int,
+ default=20,
+ help="Number of connected nearest neighbours",
+ )
+ parser.add_argument(
+ "--n-bodies",
+ type=int,
+ default=5,
+ help="Number of bodies in the dataset",
+ )
+ parser.add_argument(
+ "--dataset-name",
+ type=str,
+ default="small",
+ choices=["small", "default", "small_out_dist"],
+ help="Name of nbody data partition: default (200 steps), small (1000 steps)",
+ )
+
+ # qm9 parameters
+ parser.add_argument(
+ "--radius",
+ type=float,
+ default=2.0,
+ help="Radius (Angstrom) between which atoms to add links.",
+ )
+ parser.add_argument(
+ "--feature-type",
+ type=str,
+ default="one_hot",
+ choices=["one_hot", "cormorant", "gilmer"],
+ help="Type of input feature",
+ )
+
+ # Model parameters
+ parser.add_argument(
+ "--units", type=int, default=64, help="Number of values in the hidden layers"
+ )
+ parser.add_argument(
+ "--lmax-hidden",
+ type=int,
+ default=1,
+ help="Max degree of hidden representations.",
+ )
+ parser.add_argument(
+ "--lmax-attributes",
+ type=int,
+ default=1,
+ help="Max degree of geometric attribute embedding",
+ )
+ parser.add_argument(
+ "--layers", type=int, default=7, help="Number of message passing layers"
+ )
+ parser.add_argument(
+ "--blocks", type=int, default=2, help="Number of layers in steerable MLPs."
+ )
+ parser.add_argument(
+ "--norm",
+ type=str,
+ default="none",
+ choices=["instance", "batch", "none"],
+ help="Normalisation type",
+ )
+ parser.add_argument(
+ "--double-precision",
+ action="store_true",
+ help="Use double precision in model",
+ )
+
+ # wandb parameters
+ parser.add_argument(
+ "--wandb",
+ action="store_true",
+ help="Activate weights and biases logging",
+ )
+ parser.add_argument(
+ "--wandb-project",
+ type=str,
+ default="segnn",
+ help="Weights and biases project",
+ )
+ parser.add_argument(
+ "--wandb-entity",
+ type=str,
+ default="",
+ help="Weights and biases entity",
+ )
+
+ args = parser.parse_args()
+
+ # if specified set jax in double precision
+ jax.config.update("jax_enable_x64", args.double_precision)
+
+ # connect to wandb
+ if args.wandb:
+ wandb_name = "_".join(
+ [
+ args.wandb_project,
+ args.dataset,
+ args.target,
+ str(int(time.time())),
+ ]
+ )
+ wandb.init(
+ project=args.wandb_project,
+ name=wandb_name,
+ config=args,
+ entity=args.wandb_entity,
+ )
+
+ # feature representations
+ if args.dataset == "qm9":
+ args.task = "graph"
+ if args.feature_type == "one_hot":
+ args.node_irreps = e3nn.Irreps("5x0e")
+ elif args.feature_type == "cormorant":
+ args.node_irreps = e3nn.Irreps("15x0e")
+ elif args.feature_type == "gilmer":
+ args.node_irreps = e3nn.Irreps("11x0e")
+ args.output_irreps = e3nn.Irreps("1x0e")
+ args.additional_message_irreps = e3nn.Irreps("1x0e")
+ elif args.dataset in ["charged", "gravity"]:
+ args.task = "node"
+ args.node_irreps = e3nn.Irreps("2x1o + 1x0e")
+ args.output_irreps = e3nn.Irreps("1x1o")
+ args.additional_message_irreps = e3nn.Irreps("2x0e")
+
+ # Create hidden irreps
+ hidden_irreps = weight_balanced_irreps(
+ scalar_units=args.units,
+ # attribute irreps
+ irreps_right=e3nn.Irreps.spherical_harmonics(args.lmax_attributes),
+ use_sh=True,
+ lmax=args.lmax_hidden,
+ )
+
+ # build model
+ segnn = lambda x: SEGNN(
+ hidden_irreps=hidden_irreps,
+ output_irreps=args.output_irreps,
+ num_layers=args.layers,
+ task=args.task,
+ pool="avg",
+ blocks_per_layer=args.blocks,
+ norm=args.norm,
+ )(x)
+ segnn = hk.without_apply_rng(hk.transform_with_state(segnn))
+
+ loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args)
+
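+    # qm9: MAE on normalized targets during training; eval_trn maps predictions back to the original scale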
+ if args.dataset == "qm9":
+ from experiments.train import loss_fn
+
+ _mae = lambda p, t: jnp.abs(p - t)
+
+ train_loss = partial(loss_fn, criterion=_mae, task=args.task)
+ eval_loss = partial(loss_fn, criterion=_mae, eval_trn=eval_trn, task=args.task)
+ if args.dataset in ["charged", "gravity"]:
+ from experiments.train import loss_fn
+
+ _mse = lambda p, t: jnp.power(p - t, 2)
+
+ train_loss = partial(loss_fn, criterion=_mse, do_mask=False)
+ eval_loss = partial(loss_fn, criterion=_mse, do_mask=False)
+
+ train(
+ key,
+ segnn,
+ loader_train,
+ loader_val,
+ loader_test,
+ train_loss,
+ eval_loss,
+ graph_transform,
+ args,
+ )