diff --git a/README.md b/README.md
index 90e4e88..cd31af2 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ python -m pip install -e .
 ### GPU support
 Upgrade `jax` to the gpu version
 ```
-pip install --upgrade "jax[cuda]==0.4.1" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pip install --upgrade "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 ```
 
 ## Validation
@@ -57,13 +57,13 @@ Times are remeasured on Quadro RTX 4000, __model only__ on batches of 100 graphs
         QM9 (alpha)
-        .075*
+        .066*
         82.53
-        .098
+        .082
         105.98**
 
-* rerun
+* rerun under the same conditions
 
 ** padded (naive)
@@ -76,7 +76,6 @@ git clone https://github.com/gerkone/segnn-jax
 They are adapted from the original implementation, so additionally `torch` and `torch_geometric` are needed (cpu versions are enough).
 ```
-pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
 python -m pip install -r experiments/requirements.txt
 ```
@@ -96,17 +95,17 @@ python3 -u generate_dataset.py --simulation=gravity --n-balls=100
 ### Usage
 #### N-body (charged)
 ```
-python main.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
+python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
 ```
 
 #### N-body (gravity)
 ```
-python main.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
+python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
 ```
 
 #### QM9
 ```
-python main.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
+python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
 ```
 (configurations used in validation)
diff --git a/experiments/__init__.py b/experiments/__init__.py
index fff0437..be203b3 100644
--- a/experiments/__init__.py
+++ b/experiments/__init__.py
@@ -4,6 +4,10 @@
 from .nbody.utils import setup_nbody_data
 from .qm9.utils import setup_qm9_data
 
+from .train import train
+
+__all__ = ["setup_data", "train"]
+
 __setup_conf = {
     "qm9": setup_qm9_data,
@@ -12,7 +16,7 @@
 }
 
 
-def setup_datasets(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
+def setup_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
     assert args.dataset in [
         "qm9",
         "charged",
diff --git a/experiments/nbody/utils.py b/experiments/nbody/utils.py
index c4cab65..e00d1ca 100644
--- a/experiments/nbody/utils.py
+++ b/experiments/nbody/utils.py
@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
 import jax.tree_util as tree
+import jraph
 import numpy as np
 import torch
 from jraph import GraphsTuple, segment_mean
@@ -111,7 +112,10 @@ def NbodyGraphTransform(
         ]
     ).T
 
-    def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
+    def _to_steerable_graph(
+        data: List, training: bool = True
+ ) -> Tuple[SteerableGraphsTuple, jnp.ndarray]: + _ = training loc, vel, _, q, targets = data cur_batch = int(loc.shape[0] / n_nodes) @@ -138,6 +142,7 @@ def _to_steerable_graph(data: List) -> Tuple[SteerableGraphsTuple, jnp.ndarray]: ) ) st_graph = transform(st_graph, loc, vel, q) + # relative shift as target if relative_target: targets = targets - loc @@ -157,7 +162,9 @@ def numpy_collate(batch): return jnp.array(batch) -def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]: +def setup_nbody_data( + args, +) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]: if args.dataset == "charged": dataset_train = ChargedDataset( partition="train", @@ -234,4 +241,4 @@ def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable collate_fn=numpy_collate, ) - return loader_train, loader_val, loader_test, graph_transform + return loader_train, loader_val, loader_test, graph_transform, None diff --git a/experiments/qm9/utils.py b/experiments/qm9/utils.py index b90f364..1007cd0 100644 --- a/experiments/qm9/utils.py +++ b/experiments/qm9/utils.py @@ -12,25 +12,26 @@ def QM9GraphTransform( - node_features_irreps: e3nn.Irreps, - edge_features_irreps: e3nn.Irreps, - lmax_attributes: int, + args, max_batch_nodes: int, max_batch_edges: int, + train_trn: Callable, ) -> Callable: """ Build a function that converts torch DataBatch into SteerableGraphsTuple. Mostly a quick fix out of lazyness. Rewriting QM9 in jax is not trivial. """ - attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes) + attribute_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes) - def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]: + def _to_steerable_graph( + data: Data, training: bool = True + ) -> Tuple[SteerableGraphsTuple, jnp.array]: ptr = jnp.array(data.ptr) senders = jnp.array(data.edge_index[0]) receivers = jnp.array(data.edge_index[1]) graph = jraph.GraphsTuple( - nodes=e3nn.IrrepsArray(node_features_irreps, jnp.array(data.x)), + nodes=e3nn.IrrepsArray(args.node_irreps, jnp.array(data.x)), edges=None, senders=senders, receivers=receivers, @@ -54,7 +55,7 @@ def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]: node_attributes.array = node_attributes.array.at[:, 0].set(1.0) additional_message_features = e3nn.IrrepsArray( - edge_features_irreps, + args.additional_message_irreps, jnp.pad(jnp.array(data.additional_message_features), edge_attr_pad), ) edge_attributes = e3nn.IrrepsArray( @@ -69,13 +70,24 @@ def _to_steerable_graph(data: Data) -> Tuple[SteerableGraphsTuple, jnp.array]: ) # pad targets - target = jnp.append(jnp.array(data.y), 0) + target = jnp.array(data.y) + if args.task == "node": + target = jnp.pad(target, [(0, max_batch_nodes - target.shape[0] - 1)]) + if args.task == "graph": + target = jnp.append(target, 0) + + # normalize targets + if training and train_trn is not None: + target = train_trn(target) + return st_graph, target return _to_steerable_graph -def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]: +def setup_qm9_data( + args, +) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]: dataset_train = QM9( "datasets", args.target, @@ -115,6 +127,10 @@ def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]: ) ) + target_mean, target_mad = dataset_train.calc_stats() + + remove_offsets = lambda t: (t - target_mean) / target_mad + # not great and very slow due to huge padding loader_train = DataLoader( 
        dataset_train,
@@ -136,10 +152,12 @@ def setup_qm9_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
     )
 
     to_graphs_tuple = QM9GraphTransform(
-        args.node_irreps,
-        args.additional_message_irreps,
-        args.lmax_attributes,
+        args,
         max_batch_nodes=max_batch_nodes,
         max_batch_edges=max_batch_edges,
+        train_trn=remove_offsets,
     )
-    return loader_train, loader_val, loader_test, to_graphs_tuple
+
+    add_offsets = lambda p: p * target_mad + target_mean
+
+    return loader_train, loader_val, loader_test, to_graphs_tuple, add_offsets
diff --git a/experiments/requirements.txt b/experiments/requirements.txt
index e75b483..56d23f4 100644
--- a/experiments/requirements.txt
+++ b/experiments/requirements.txt
@@ -1,8 +1,11 @@
---find-links https://data.pyg.org/whl/torch-1.12.1+cpu.html
+--extra-index-url https://download.pytorch.org/whl/cpu
+
+--find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html
+
 e3nn==0.5.0
 matplotlib>=3.6.2
 rdkit==2022.9.2
-torch==1.12.1
+torch==1.13.1
 torch-cluster==1.6.0
 torch-geometric==2.1.0
 torch-scatter==2.1.0
diff --git a/experiments/train.py b/experiments/train.py
new file mode 100644
index 0000000..f021131
--- /dev/null
+++ b/experiments/train.py
@@ -0,0 +1,166 @@
+import time
+from functools import partial
+from typing import Callable, Tuple
+
+import haiku as hk
+import jax
+from jax import jit
+import jax.numpy as jnp
+import jraph
+import optax
+
+from segnn_jax import SteerableGraphsTuple
+
+
+@partial(jit, static_argnames=["model_fn", "criterion", "task", "do_mask", "eval_trn"])
+def loss_fn(
+    params: hk.Params,
+    state: hk.State,
+    st_graph: SteerableGraphsTuple,
+    target: jnp.ndarray,
+    model_fn: Callable,
+    criterion: Callable,
+    task: str = "node",
+    do_mask: bool = True,
+    eval_trn: Callable = None,
+) -> Tuple[float, hk.State]:
+    pred, state = model_fn(params, state, st_graph)
+    if eval_trn is not None:
+        pred = eval_trn(pred)
+    if task == "node":
+        mask = jraph.get_node_padding_mask(st_graph.graph)
+    if task == "graph":
+        mask = jraph.get_graph_padding_mask(st_graph.graph)
+    # broadcast mask for vector targets
+    if len(pred.shape) == 2:
+        mask = mask[:, jnp.newaxis]
+    if do_mask:
+        target = target * mask
+        pred = pred * mask
+    assert target.shape == pred.shape
+    return jnp.sum(criterion(pred, target)) / jnp.count_nonzero(mask), state
+
+
+@partial(jit, static_argnames=["loss_fn", "opt_update"])
+def update(
+    params: hk.Params,
+    state: hk.State,
+    graph: SteerableGraphsTuple,
+    target: jnp.ndarray,
+    opt_state: optax.OptState,
+    loss_fn: Callable,
+    opt_update: Callable,
+) -> Tuple[float, hk.Params, hk.State, optax.OptState]:
+    (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+        params, state, graph, target
+    )
+    updates, opt_state = opt_update(grads, opt_state, params)
+    return loss, optax.apply_updates(params, updates), state, opt_state
+
+
+def evaluate(
+    loader,
+    params: hk.Params,
+    state: hk.State,
+    loss_fn: Callable,
+    graph_transform: Callable,
+) -> Tuple[float, float]:
+    eval_loss = 0.0
+    eval_times = 0.0
+    for data in loader:
+        graph, target = graph_transform(data, training=False)
+        eval_start = time.perf_counter_ns()
+        loss, _ = jax.lax.stop_gradient(loss_fn(params, state, graph, target))
+        eval_loss += jax.block_until_ready(loss)
+        eval_times += (time.perf_counter_ns() - eval_start) / 1e6
+    return eval_times / len(loader), eval_loss / len(loader)
+
+
+def train(
+    key,
+    segnn,
+    loader_train,
+    loader_val,
+    loader_test,
+    loss_fn,
+    eval_loss_fn,
+    graph_transform,
+    args,
+):
+    init_graph, _ =
graph_transform(next(iter(loader_train))) + params, segnn_state = segnn.init(key, init_graph) + + print( + f"Starting {args.epochs} epochs " + f"with {hk.data_structures.tree_size(params)} parameters.\n" + "Jitting..." + ) + + total_steps = args.epochs * len(loader_train) + + # set up learning rate and optimizer + learning_rate = args.lr + if args.lr_scheduling: + learning_rate = optax.piecewise_constant_schedule( + learning_rate, + boundaries_and_scales={ + int(total_steps * 0.7): 0.1, + int(total_steps * 0.9): 0.1, + }, + ) + opt_init, opt_update = optax.adamw( + learning_rate=learning_rate, weight_decay=args.weight_decay + ) + + model_fn = segnn.apply + + loss_fn = partial(loss_fn, model_fn=model_fn) + eval_loss_fn = partial(eval_loss_fn, model_fn=model_fn) + update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update) + eval_fn = partial(evaluate, loss_fn=eval_loss_fn, graph_transform=graph_transform) + + opt_state = opt_init(params) + avg_time = [] + best_val = 1e10 + + for e in range(args.epochs): + train_loss = 0.0 + train_start = time.perf_counter_ns() + for data in loader_train: + graph, target = graph_transform(data) + loss, params, segnn_state, opt_state = update_fn( + params=params, + state=segnn_state, + graph=graph, + target=target, + opt_state=opt_state, + ) + train_loss += loss + train_time = (time.perf_counter_ns() - train_start) / 1e6 + train_loss /= len(loader_train) + print( + f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms", + end="", + ) + if e % args.val_freq == 0: + eval_time, val_loss = eval_fn(loader_val, params, segnn_state) + avg_time.append(eval_time) + tag = "" + if val_loss < best_val: + best_val = val_loss + tag = " (best)" + _, test_loss_ckp = eval_fn(loader_test, params, segnn_state) + print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="") + + print() + + test_loss = 0 + _, test_loss = eval_fn(loader_test, params, segnn_state) + # ignore compilation time + avg_time = avg_time[2:] + avg_time = sum(avg_time) / len(avg_time) + print( + "Training done.\n" + f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n" + f"Average (model) eval time {avg_time:.2f}ms" + ) diff --git a/main.py b/main.py deleted file mode 100644 index fe6c591..0000000 --- a/main.py +++ /dev/null @@ -1,410 +0,0 @@ -import argparse -import time -from functools import partial -from typing import Callable, Iterable, Tuple - -import e3nn_jax as e3nn -import haiku as hk -import jax -import jax.numpy as jnp -import optax -import wandb - -from experiments import setup_datasets -from segnn_jax import SEGNN, SteerableGraphsTuple, weight_balanced_irreps - -key = jax.random.PRNGKey(1337) - - -def predict( - model_fn: hk.TransformedWithState, - params: hk.Params, - state: hk.State, - graph: SteerableGraphsTuple, - mean_shift: float = 0, - mad_shift: float = 1, -) -> Tuple[jnp.ndarray, hk.State]: - pred, state = model_fn(params, state, graph) - return pred * mad_shift + mean_shift, state - - -@partial(jax.jit, static_argnames=["model_fn", "mean_shift", "mad_shift", "mask_last"]) -def mae( - params: hk.Params, - state: hk.State, - graph: SteerableGraphsTuple, - target: jnp.ndarray, - model_fn: Callable, - mean_shift: float = 0, - mad_shift: float = 1, - mask_last: bool = False, -) -> Tuple[float, hk.State]: - pred, state = predict(model_fn, params, state, graph, mean_shift, mad_shift) - assert target.shape == pred.shape - # similar to get_graph_padding_mask - if mask_last: - return (jnp.abs(pred[:-1] - target[:-1])).mean(), 
state - else: - return (jnp.abs(pred - target)).mean(), state - - -@partial(jax.jit, static_argnames=["model_fn", "mean_shift", "mad_shift", "mask_last"]) -def mse( - params: hk.Params, - state: hk.State, - graph: SteerableGraphsTuple, - target: jnp.ndarray, - model_fn: Callable, - mean_shift: float = 0, - mad_shift: float = 1, - mask_last: bool = False, -) -> Tuple[float, hk.State]: - pred, state = predict(model_fn, params, state, graph, mean_shift, mad_shift) - assert target.shape == pred.shape - if mask_last: - return (jnp.power(pred[:-1] - target[:-1], 2)).mean(), state - else: - return (jnp.power(pred - target, 2)).mean(), state - - -@partial(jax.jit, static_argnames=["loss_fn", "opt_update"]) -def update( - params: hk.Params, - state: hk.State, - graph: SteerableGraphsTuple, - target: jnp.ndarray, - opt_state: optax.OptState, - loss_fn: Callable, - opt_update: Callable, -) -> Tuple[float, hk.Params, hk.State, optax.OptState]: - (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)( - params, state, graph, target - ) - updates, opt_state = opt_update(grads, opt_state, params) - return loss, optax.apply_updates(params, updates), state, opt_state - - -def evaluate( - loader: Iterable, - params: hk.Params, - state: hk.State, - loss_fn: Callable, - graph_transform: Callable, -) -> Tuple[float, float]: - eval_loss = 0.0 - eval_times = 0.0 - for data in loader: - graph, target = graph_transform(data) - eval_start = time.perf_counter_ns() - loss, _ = jax.lax.stop_gradient(loss_fn(params, state, graph, target)) - eval_loss += jax.block_until_ready(loss) - eval_times += (time.perf_counter_ns() - eval_start) / 1e6 - - return eval_times / len(loader), eval_loss / len(loader) - - -def train( - segnn: hk.Transformed, loader_train, loader_val, loader_test, graph_transform, args -): - init_graph, _ = graph_transform(next(iter(loader_train))) - params, segnn_state = segnn.init(key, init_graph) - - print( - f"Starting {args.epochs} epochs on {args.dataset} " - f"with {hk.data_structures.tree_size(params)} parameters.\n" - "Jitting..." 
- ) - - total_steps = args.epochs * len(loader_train) - - # set up learning rate and optimizer - if args.lr_scheduling: - learning_rate = optax.piecewise_constant_schedule( - args.lr, - boundaries_and_scales={ - int(total_steps * 0.8): 0.1, - int(total_steps * 0.9): 0.1, - }, - ) - else: - learning_rate = args.lr - - opt_init, opt_update = optax.adamw( - learning_rate=learning_rate, weight_decay=args.weight_decay - ) - - if args.dataset == "qm9": - # qm9 - target_mean, target_mad = loader_train.dataset.calc_stats() - # ignore padded target - loss_fn = partial(mae, model_fn=segnn.apply, mask_last=True) - eval_loss_fn = partial( - mae, - model_fn=segnn.apply, - mask_last=True, - mean_shift=target_mean, - mad_shift=target_mad, - ) - else: - # nbody - target_mean, target_mad = 0, 1 - loss_fn = partial(mse, model_fn=segnn.apply) - eval_loss_fn = partial(mse, model_fn=segnn.apply) - - update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update) - eval_fn = partial(evaluate, loss_fn=eval_loss_fn, graph_transform=graph_transform) - - opt_state = opt_init(params) - avg_time = [] - best_val = 1e10 - - for e in range(args.epochs): - train_loss = 0.0 - train_start = time.perf_counter_ns() - for data in loader_train: - graph, target = graph_transform(data) - # normalize targets - loss, params, segnn_state, opt_state = update_fn( - params=params, - state=segnn_state, - graph=graph, - target=(target - target_mean) / target_mad, - opt_state=opt_state, - ) - train_loss += loss - train_time = (time.perf_counter_ns() - train_start) / 1e6 - train_loss /= len(loader_train) - wandb_logs = {"train_loss": float(train_loss), "update_time": float(train_time)} - print( - f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms", - end="", - ) - if e % args.val_freq == 0: - eval_time, val_loss = eval_fn(loader_val, params, segnn_state) - avg_time.append(eval_time) - tag = "" - if val_loss < best_val: - best_val = val_loss - _, test_loss_ckp = eval_fn(loader_test, params, segnn_state) - wandb_logs.update({"test_loss": float(test_loss_ckp)}) - tag = " (BEST)" - wandb_logs.update( - {"val_loss": float(val_loss), "eval_time": float(eval_time)} - ) - print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="") - - print() - if args.wandb: - wandb.log(wandb_logs) - - test_loss = 0 - _, test_loss = eval_fn(loader_test, params, segnn_state) - # ignore compilation time - avg_time = avg_time[2:] - avg_time = sum(avg_time) / len(avg_time) - if args.wandb: - wandb.log({"test_loss": float(test_loss), "avg_eval_time": float(avg_time)}) - print( - "Training done.\n" - f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n" - f"Average (model) eval time {avg_time:.2f}ms" - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Run parameters - parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") - parser.add_argument( - "--batch-size", - type=int, - default=128, - help="Batch size (number of graphs).", - ) - parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate") - parser.add_argument( - "--lr-scheduling", - action="store_true", - help="Use learning rate scheduling", - ) - parser.add_argument( - "--weight-decay", type=float, default=1e-12, help="Weight decay" - ) - parser.add_argument( - "--dataset", - type=str, - choices=["qm9", "charged", "gravity"], - help="Dataset name", - ) - parser.add_argument( - "--max-samples", - type=int, - default=3000, - help="Maximum number of samples in nbody dataset", - ) - 
parser.add_argument( - "--val-freq", - type=int, - default=10, - help="Evaluation frequency (number of epochs)", - ) - - # nbody parameters - parser.add_argument( - "--target", - type=str, - default="pos", - help="Target. e.g. pos, force (gravity), alpha (qm9)", - ) - parser.add_argument( - "--neighbours", - type=int, - default=20, - help="Number of connected nearest neighbours", - ) - parser.add_argument( - "--n-bodies", - type=int, - default=5, - help="Number of bodies in the dataset", - ) - parser.add_argument( - "--dataset-name", - type=str, - default="small", - choices=["small", "default", "small_out_dist"], - help="Name of nbody data partition: default (200 steps), small (1000 steps)", - ) - - # qm9 parameters - parser.add_argument( - "--radius", - type=float, - default=2.0, - help="Radius (Angstrom) between which atoms to add links.", - ) - parser.add_argument( - "--feature-type", - type=str, - default="one_hot", - choices=["one_hot", "cormorant", "gilmer"], - help="Type of input feature", - ) - - # Model parameters - parser.add_argument( - "--units", type=int, default=64, help="Number of values in the hidden layers" - ) - parser.add_argument( - "--lmax-hidden", - type=int, - default=1, - help="Max degree of hidden representations.", - ) - parser.add_argument( - "--lmax-attributes", - type=int, - default=1, - help="Max degree of geometric attribute embedding", - ) - parser.add_argument( - "--layers", type=int, default=7, help="Number of message passing layers" - ) - parser.add_argument( - "--blocks", type=int, default=2, help="Number of layers in steerable MLPs." - ) - parser.add_argument( - "--norm", - type=str, - default="none", - choices=["instance", "batch", "none"], - help="Normalisation type", - ) - parser.add_argument( - "--double-precision", - action="store_true", - help="Use double precision in model", - ) - - # wandb parameters - parser.add_argument( - "--wandb", - action="store_true", - help="Activate weights and biases logging", - ) - parser.add_argument( - "--wandb-project", - type=str, - default="segnn", - help="Weights and biases project", - ) - parser.add_argument( - "--wandb-entity", - type=str, - default="", - help="Weights and biases entity", - ) - - args = parser.parse_args() - - # if specified set jax in double precision - jax.config.update("jax_enable_x64", args.double_precision) - - # connect to wandb - if args.wandb: - wandb_name = "_".join( - [ - args.wandb_project, - args.dataset, - args.target, - str(int(time.time())), - ] - ) - wandb.init( - project=args.wandb_project, - name=wandb_name, - config=args, - entity=args.wandb_entity, - ) - - # feature representations - if args.dataset == "qm9": - task = "graph" - if args.feature_type == "one_hot": - args.node_irreps = e3nn.Irreps("5x0e") - elif args.feature_type == "cormorant": - args.node_irreps = e3nn.Irreps("15x0e") - elif args.feature_type == "gilmer": - args.node_irreps = e3nn.Irreps("11x0e") - args.output_irreps = e3nn.Irreps("1x0e") - args.additional_message_irreps = e3nn.Irreps("1x0e") - elif args.dataset in ["charged", "gravity"]: - task = "node" - args.node_irreps = e3nn.Irreps("2x1o + 1x0e") - args.output_irreps = e3nn.Irreps("1x1o") - args.additional_message_irreps = e3nn.Irreps("2x0e") - - # Create hidden irreps - hidden_irreps = weight_balanced_irreps( - scalar_units=args.units, - # attribute irreps - irreps_right=e3nn.Irreps.spherical_harmonics(args.lmax_attributes), - use_sh=True, - lmax=args.lmax_hidden, - ) - - # build model - segnn = lambda x: SEGNN( - hidden_irreps=hidden_irreps, - 
output_irreps=args.output_irreps, - num_layers=args.layers, - task=task, - pool="avg", - blocks_per_layer=args.blocks, - norm=args.norm, - )(x) - segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) - - dataset_train, dataset_val, dataset_test, graph_transform = setup_datasets(args) - - train(segnn, dataset_train, dataset_val, dataset_test, graph_transform, args) diff --git a/requirements.txt b/requirements.txt index 2aa0260..2fe59a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html dm-haiku==0.0.9 -e3nn-jax==0.17.0 -jax[cuda]==0.4.1 +e3nn-jax==0.17.4 +jax[cuda]==0.4.8 jraph==0.0.6.dev0 numpy>=1.23.4 optax==0.1.3 diff --git a/segnn_jax/__init__.py b/segnn_jax/__init__.py index 987e13a..64f5182 100644 --- a/segnn_jax/__init__.py +++ b/segnn_jax/__init__.py @@ -14,4 +14,4 @@ "SteerableGraphsTuple", ] -__version__ = "0.5" +__version__ = "0.6" diff --git a/segnn_jax/blocks.py b/segnn_jax/blocks.py index 8ce752a..58e74c1 100644 --- a/segnn_jax/blocks.py +++ b/segnn_jax/blocks.py @@ -31,21 +31,15 @@ class O3TensorProduct(hk.Module): """ O(3) equivariant linear parametrized tensor product layer. - Attributes: - left_irreps: Left input representation - right_irreps: Right input representation - output_irreps: Output representation - get_parameter: Haiku parameter getter and init function - tensor_product: Tensor product function - biases: Bias wrapper function + Functionally the same as O3TensorProductLegacy, but around 5-10% faster. + FullyConnectedTensorProduct seems faster than tensor_product + linear: + https://github.com/e3nn/e3nn-jax/releases/tag/0.14.0 """ def __init__( self, output_irreps: e3nn.Irreps, *, - left_irreps: e3nn.Irreps, - right_irreps: Optional[e3nn.Irreps] = None, biases: bool = True, name: Optional[str] = None, init_fn: Optional[InitFn] = None, @@ -56,8 +50,6 @@ def __init__( Args: output_irreps: Output representation - left_irreps: Left input representation - right_irreps: Right input representation (optional, defaults to 1x0e) biases: If set ot true will add biases name: Name of the linear layer params init_fn: Weight initialization function. Default is uniform. 
@@ -67,41 +59,34 @@ def __init__( """ super().__init__(name) - if not right_irreps: - # NOTE: this is equivalent to a linear recombination of the left vectors - right_irreps = e3nn.Irreps("1x0e") - if not isinstance(output_irreps, e3nn.Irreps): output_irreps = e3nn.Irreps(output_irreps) - if not isinstance(left_irreps, e3nn.Irreps): - left_irreps = e3nn.Irreps(left_irreps) - if not isinstance(right_irreps, e3nn.Irreps): - right_irreps = e3nn.Irreps(right_irreps) - self.output_irreps = output_irreps - self.right_irreps = right_irreps - self.left_irreps = left_irreps - + # tp weight init if not init_fn: init_fn = uniform_init - self.get_parameter = init_fn if not gradient_normalization: gradient_normalization = config("gradient_normalization") if not path_normalization: path_normalization = config("path_normalization") + self._gradient_normalization = gradient_normalization + self._path_normalization = path_normalization + + self.biases = biases and "0e" in self.output_irreps - # NOTE FunctionalFullyConnectedTensorProduct appears to be faster than combining - # tensor_product+linear: https://github.com/e3nn/e3nn-jax/releases/tag/0.14.0 - # Implementation adapted from e3nn.haiku.FullyConnectedTensorProduct + def _build_tensor_product( + self, left_irreps: e3nn.Irreps, right_irreps: e3nn.Irreps + ) -> Callable: + """Build the tensor product function.""" tp = e3nn.FunctionalFullyConnectedTensorProduct( left_irreps, right_irreps, - output_irreps, - gradient_normalization=gradient_normalization, - path_normalization=path_normalization, + self.output_irreps, + gradient_normalization=self._gradient_normalization, + path_normalization=self._path_normalization, ) ws = [ self.get_parameter( @@ -118,35 +103,32 @@ def __init__( ] def tensor_product(x, y, **kwargs): - return tp.left_right(ws, x, y, **kwargs)._convert(output_irreps) + return tp.left_right(ws, x, y, **kwargs)._convert(self.output_irreps) - self.tensor_product = naive_broadcast_decorator(tensor_product) - self.biases = None + return naive_broadcast_decorator(tensor_product) - if biases and "0e" in self.output_irreps: - # add biases - b = [ - self.get_parameter( - f"b[{i_out}] {tp.irreps_out[i_out]}", - path_shape=(mul_ir.dim,), - weight_std=1 / jnp.sqrt(mul_ir.dim), - ) - for i_out, mul_ir in enumerate(output_irreps) - if mul_ir.ir.is_scalar() - ] - b = e3nn.IrrepsArray( - f"{self.output_irreps.count('0e')}x0e", jnp.concatenate(b) + def _build_biases(self) -> Callable: + """Build the add bias function.""" + b = [ + self.get_parameter( + f"b[{i_out}] {self.output_irreps}", + path_shape=(mul_ir.dim,), + weight_std=1 / jnp.sqrt(mul_ir.dim), + ) + for i_out, mul_ir in enumerate(self.output_irreps) + if mul_ir.ir.is_scalar() + ] + b = e3nn.IrrepsArray(f"{self.output_irreps.count('0e')}x0e", jnp.concatenate(b)) + + # TODO: could be improved + def _wrapper(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray: + scalars = x.filter("0e") + other = x.filter(drop="0e") + return e3nn.concatenate( + [scalars + b.broadcast_to(scalars.shape), other], axis=1 ) - # TODO: could be improved - def _wrapper(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray: - scalars = x.filter("0e") - other = x.filter(drop="0e") - return e3nn.concatenate( - [scalars + b.broadcast_to(scalars.shape), other], axis=1 - ) - - self.biases = _wrapper + return _wrapper def __call__( self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs @@ -162,7 +144,7 @@ def __call__( """ if not y: - y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1))) + y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), 
dtype=x.dtype)) if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0: warnings.warn( @@ -171,18 +153,13 @@ def __call__( "redistributing them into scalars or choose higher orders." ) - assert ( - x.irreps == self.left_irreps - ), f"Left irreps do not match. Got {x.irreps}, expected {self.left_irreps}" - assert ( - y.irreps == self.right_irreps - ), f"Right irreps do not match. Got {y.irreps}, expected {self.right_irreps}" - - output = self.tensor_product(x, y, **kwargs) + tp = self._build_tensor_product(x.irreps, y.irreps) + output = tp(x, y, **kwargs) if self.biases: # add biases - return self.biases(output) + bias_fn = self._build_biases() + return bias_fn(output) return output @@ -190,8 +167,6 @@ def __call__( def O3TensorProductLegacy( output_irreps: e3nn.Irreps, *, - left_irreps: e3nn.Irreps, - right_irreps: Optional[e3nn.Irreps] = None, biases: bool = True, name: Optional[str] = None, init_fn: Optional[InitFn] = None, @@ -215,15 +190,8 @@ def O3TensorProductLegacy( A function that returns the output to the weighted tensor product. """ - if not right_irreps: - right_irreps = e3nn.Irreps("1x0e") - if not isinstance(output_irreps, e3nn.Irreps): output_irreps = e3nn.Irreps(output_irreps) - if not isinstance(left_irreps, e3nn.Irreps): - left_irreps = e3nn.Irreps(left_irreps) - if not isinstance(right_irreps, e3nn.Irreps): - right_irreps = e3nn.Irreps(right_irreps) if not init_fn: init_fn = uniform_init @@ -251,7 +219,7 @@ def _tensor_product( """ if not y: - y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1))) + y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype)) if x.irreps.lmax == 0 and y.irreps.lmax == 0 and output_irreps.lmax > 0: warnings.warn( @@ -260,13 +228,6 @@ def _tensor_product( "redistributing them into scalars or choose higher orders." ) - assert ( - x.irreps == left_irreps - ), f"Left irreps do not match. Got {x.irreps}, expected {left_irreps}" - assert ( - y.irreps == right_irreps - ), f"Right irreps do not match. 
Got {y.irreps}, expected {right_irreps}" - tp = e3nn.tensor_product(x, y) return linear(tp) @@ -280,8 +241,6 @@ def _tensor_product( def O3TensorProductGate( output_irreps: e3nn.Irreps, *, - left_irreps: e3nn.Irreps, - right_irreps: Optional[e3nn.Irreps] = None, biases: bool = True, scalar_activation: Optional[Callable] = None, gate_activation: Optional[Callable] = None, @@ -312,8 +271,6 @@ def O3TensorProductGate( ) tensor_product = O3Layer( (gate_irreps + output_irreps).regroup(), - left_irreps=left_irreps, - right_irreps=right_irreps, biases=biases, name=name, init_fn=init_fn, diff --git a/segnn_jax/segnn.py b/segnn_jax/segnn.py index 4f12ffb..a2567a4 100644 --- a/segnn_jax/segnn.py +++ b/segnn_jax/segnn.py @@ -26,12 +26,9 @@ def O3Embedding(embed_irreps: e3nn.Irreps, embed_edges: bool = True) -> Callable def _embedding( st_graph: SteerableGraphsTuple, ) -> SteerableGraphsTuple: - # TODO update graph = st_graph.graph nodes = O3Layer( embed_irreps, - left_irreps=graph.nodes.irreps, - right_irreps=st_graph.node_attributes.irreps, name="embedding_nodes", )(graph.nodes, st_graph.node_attributes) st_graph = st_graph._replace(graph=graph._replace(nodes=nodes)) @@ -40,8 +37,6 @@ def _embedding( if embed_edges: additional_message_features = O3Layer( embed_irreps, - left_irreps=graph.nodes.irreps, - right_irreps=st_graph.node_attributes.irreps, name="embedding_msg_features", )( st_graph.additional_message_features, @@ -82,30 +77,21 @@ def _decoder(st_graph: SteerableGraphsTuple): nodes = st_graph.graph.nodes # pre pool block for i in range(blocks): - nodes = O3TensorProductGate( - latent_irreps, - left_irreps=nodes.irreps, - right_irreps=st_graph.node_attributes.irreps, - name=f"prepool_{i}", - )(nodes, st_graph.node_attributes) + nodes = O3TensorProductGate(latent_irreps, name=f"prepool_{i}")( + nodes, st_graph.node_attributes + ) if task == "node": - nodes = O3Layer( - output_irreps, - left_irreps=nodes.irreps, - right_irreps=st_graph.node_attributes.irreps, - name="output", - )(nodes, st_graph.node_attributes) + nodes = O3Layer(output_irreps, name="output")( + nodes, st_graph.node_attributes + ) if task == "graph": # pool over graph pooled_irreps = (latent_irreps.num_irreps * output_irreps).regroup() - nodes = O3Layer( - pooled_irreps, - left_irreps=nodes.irreps, - right_irreps=st_graph.node_attributes.irreps, - name=f"prepool_{blocks}", - )(nodes, st_graph.node_attributes) + nodes = O3Layer(pooled_irreps, name=f"prepool_{blocks}")( + nodes, st_graph.node_attributes + ) # pooling layer if pool == "avg": @@ -117,12 +103,8 @@ def _decoder(st_graph: SteerableGraphsTuple): # post pool mlp (not steerable) for i in range(blocks): - nodes = O3TensorProductGate( - pooled_irreps, left_irreps=nodes.irreps, name=f"postpool_{i}" - )(nodes) - nodes = O3Layer(output_irreps, left_irreps=nodes.irreps, name="output")( - nodes - ) + nodes = O3TensorProductGate(pooled_irreps, name=f"postpool_{i}")(nodes) + nodes = O3Layer(output_irreps, name="output")(nodes) return nodes @@ -179,12 +161,9 @@ def _message( msg = e3nn.concatenate([msg, additional_message_features], axis=-1) # message mlp (phi_m in the paper) steered by edge attributeibutes for i in range(self._blocks): - msg = O3TensorProductGate( - self._output_irreps, - left_irreps=msg.irreps, - right_irreps=getattr(edge_attribute, "irreps", None), - name=f"tp_{i}", - )(msg, edge_attribute) + msg = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")( + msg, edge_attribute + ) # NOTE: original implementation only applied batch norm to messages if self._norm 
== "batch": msg = e3nn.haiku.BatchNorm(irreps=self._output_irreps)(msg) @@ -204,19 +183,13 @@ def _update( x = e3nn.concatenate([nodes, msg], axis=-1) # update mlp (phi_f in the paper) steered by node attributeibutes for i in range(self._blocks - 1): - x = O3TensorProductGate( - self._output_irreps, - left_irreps=x.irreps, - right_irreps=getattr(node_attribute, "irreps", None), - name=f"tp_{i}", - )(x, node_attribute) + x = O3TensorProductGate(self._output_irreps, name=f"tp_{i}")( + x, node_attribute + ) # last update layer without activation - update = O3Layer( - self._output_irreps, - left_irreps=x.irreps, - right_irreps=getattr(node_attribute, "irreps", None), - name=f"tp_{self._blocks - 1}", - )(x, node_attribute) + update = O3Layer(self._output_irreps, name=f"tp_{self._blocks - 1}")( + x, node_attribute + ) # residual connection nodes += update # message norm diff --git a/setup.cfg b/setup.cfg index 53a3feb..9bd45b6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,9 +15,9 @@ packages = segnn_jax python_requires = >=3.8 install_requires = dm_haiku==0.0.9 - e3nn_jax==0.17.0 - jax==0.4.1 - jaxlib==0.4.1 + e3nn_jax==0.17.4 + jax==0.4.8 + jaxlib==0.4.8 jraph==0.0.6.dev0 numpy>=1.23.4 optax==0.1.3 diff --git a/tests/test_blocks.py b/tests/test_blocks.py index 8af60ad..e7a8336 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -8,9 +8,7 @@ @pytest.mark.parametrize("biases", [False, True]) def test_linear(key, biases): - f = lambda x1, x2: O3TensorProduct( - "1x1o", left_irreps="1x1o", right_irreps="1x1o", biases=biases - )(x1, x2) + f = lambda x1, x2: O3TensorProduct("1x1o", biases=biases)(x1, x2) f = hk.without_apply_rng(hk.transform(f)) v = e3nn.normal("1x1o", key, (5,)) @@ -31,9 +29,7 @@ def test_gated(key, biases): segnn_jax.blocks.O3Layer = segnn_jax.blocks.O3TensorProduct - f = lambda x1, x2: O3TensorProductGate( - "1x1o", left_irreps="1x1o", right_irreps="1x1o", biases=biases - )(x1, x2) + f = lambda x1, x2: O3TensorProductGate("1x1o", biases=biases)(x1, x2) f = hk.without_apply_rng(hk.transform(f)) v = e3nn.normal("1x1o", key, (5,)) @@ -50,9 +46,7 @@ def test_gated(key, biases): @pytest.mark.parametrize("biases", [False, True]) def test_linear_legacy(key, biases): - f = lambda x1, x2: O3TensorProductLegacy( - "1x1o", left_irreps="1x1o", right_irreps="1x1o", biases=biases - )(x1, x2) + f = lambda x1, x2: O3TensorProductLegacy("1x1o", biases=biases)(x1, x2) f = hk.without_apply_rng(hk.transform(f)) v = e3nn.normal("1x1o", key, (5,)) diff --git a/validate.py b/validate.py new file mode 100644 index 0000000..4eaa14e --- /dev/null +++ b/validate.py @@ -0,0 +1,238 @@ +import argparse +import time +from functools import partial + +import e3nn_jax as e3nn +import haiku as hk +import jax +import jax.numpy as jnp +import wandb + +from experiments import setup_data, train +from segnn_jax import SEGNN, weight_balanced_irreps + +key = jax.random.PRNGKey(1337) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Run parameters + parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") + parser.add_argument( + "--batch-size", + type=int, + default=128, + help="Batch size (number of graphs).", + ) + parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate") + parser.add_argument( + "--lr-scheduling", + action="store_true", + help="Use learning rate scheduling", + ) + parser.add_argument( + "--weight-decay", type=float, default=1e-12, help="Weight decay" + ) + parser.add_argument( + "--dataset", + type=str, + choices=["qm9", 
"charged", "gravity"], + help="Dataset name", + ) + parser.add_argument( + "--max-samples", + type=int, + default=3000, + help="Maximum number of samples in nbody dataset", + ) + parser.add_argument( + "--val-freq", + type=int, + default=10, + help="Evaluation frequency (number of epochs)", + ) + + # nbody parameters + parser.add_argument( + "--target", + type=str, + default="pos", + help="Target. e.g. pos, force (gravity), alpha (qm9)", + ) + parser.add_argument( + "--neighbours", + type=int, + default=20, + help="Number of connected nearest neighbours", + ) + parser.add_argument( + "--n-bodies", + type=int, + default=5, + help="Number of bodies in the dataset", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="small", + choices=["small", "default", "small_out_dist"], + help="Name of nbody data partition: default (200 steps), small (1000 steps)", + ) + + # qm9 parameters + parser.add_argument( + "--radius", + type=float, + default=2.0, + help="Radius (Angstrom) between which atoms to add links.", + ) + parser.add_argument( + "--feature-type", + type=str, + default="one_hot", + choices=["one_hot", "cormorant", "gilmer"], + help="Type of input feature", + ) + + # Model parameters + parser.add_argument( + "--units", type=int, default=64, help="Number of values in the hidden layers" + ) + parser.add_argument( + "--lmax-hidden", + type=int, + default=1, + help="Max degree of hidden representations.", + ) + parser.add_argument( + "--lmax-attributes", + type=int, + default=1, + help="Max degree of geometric attribute embedding", + ) + parser.add_argument( + "--layers", type=int, default=7, help="Number of message passing layers" + ) + parser.add_argument( + "--blocks", type=int, default=2, help="Number of layers in steerable MLPs." + ) + parser.add_argument( + "--norm", + type=str, + default="none", + choices=["instance", "batch", "none"], + help="Normalisation type", + ) + parser.add_argument( + "--double-precision", + action="store_true", + help="Use double precision in model", + ) + + # wandb parameters + parser.add_argument( + "--wandb", + action="store_true", + help="Activate weights and biases logging", + ) + parser.add_argument( + "--wandb-project", + type=str, + default="segnn", + help="Weights and biases project", + ) + parser.add_argument( + "--wandb-entity", + type=str, + default="", + help="Weights and biases entity", + ) + + args = parser.parse_args() + + # if specified set jax in double precision + jax.config.update("jax_enable_x64", args.double_precision) + + # connect to wandb + if args.wandb: + wandb_name = "_".join( + [ + args.wandb_project, + args.dataset, + args.target, + str(int(time.time())), + ] + ) + wandb.init( + project=args.wandb_project, + name=wandb_name, + config=args, + entity=args.wandb_entity, + ) + + # feature representations + if args.dataset == "qm9": + args.task = "graph" + if args.feature_type == "one_hot": + args.node_irreps = e3nn.Irreps("5x0e") + elif args.feature_type == "cormorant": + args.node_irreps = e3nn.Irreps("15x0e") + elif args.feature_type == "gilmer": + args.node_irreps = e3nn.Irreps("11x0e") + args.output_irreps = e3nn.Irreps("1x0e") + args.additional_message_irreps = e3nn.Irreps("1x0e") + elif args.dataset in ["charged", "gravity"]: + args.task = "node" + args.node_irreps = e3nn.Irreps("2x1o + 1x0e") + args.output_irreps = e3nn.Irreps("1x1o") + args.additional_message_irreps = e3nn.Irreps("2x0e") + + # Create hidden irreps + hidden_irreps = weight_balanced_irreps( + scalar_units=args.units, + # attribute irreps + 
irreps_right=e3nn.Irreps.spherical_harmonics(args.lmax_attributes), + use_sh=True, + lmax=args.lmax_hidden, + ) + + # build model + segnn = lambda x: SEGNN( + hidden_irreps=hidden_irreps, + output_irreps=args.output_irreps, + num_layers=args.layers, + task=args.task, + pool="avg", + blocks_per_layer=args.blocks, + norm=args.norm, + )(x) + segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) + + loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args) + + if args.dataset == "qm9": + from experiments.train import loss_fn + + _mae = lambda p, t: jnp.abs(p - t) + + train_loss = partial(loss_fn, criterion=_mae, task=args.task) + eval_loss = partial(loss_fn, criterion=_mae, eval_trn=eval_trn, task=args.task) + if args.dataset in ["charged", "gravity"]: + from experiments.train import loss_fn + + _mse = lambda p, t: jnp.power(p - t, 2) + + train_loss = partial(loss_fn, criterion=_mse, do_mask=False) + eval_loss = partial(loss_fn, criterion=_mse, do_mask=False) + + train( + key, + segnn, + loader_train, + loader_val, + loader_test, + train_loss, + eval_loss, + graph_transform, + args, + )