Add KAN readout options for MACE with possibly better accuracy #655

Open · wants to merge 13 commits into develop
3 changes: 2 additions & 1 deletion mace/calculators/mace.py
@@ -10,6 +10,7 @@
 from pathlib import Path
 from typing import Union

+import dill
 import numpy as np
 import torch
 from ase.calculators.calculator import Calculator, all_changes
@@ -127,7 +128,7 @@ def __init__(

         # Load models from files
         self.models = [
-            torch.load(f=model_path, map_location=device)
+            torch.load(f=model_path, map_location=device, pickle_module=dill)
             for model_path in model_paths
         ]
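This branch switches every `torch.load`/`torch.save` call over to dill, presumably because the KAN readout (built on pykan's `MultKAN`) holds objects that the default `pickle` module cannot serialize. A minimal sketch of loading such a checkpoint, with a hypothetical file name:

```python
import dill
import torch

# Checkpoints written by this branch must also be read back with dill;
# "mace_kan.model" is a placeholder path for illustration.
model = torch.load("mace_kan.model", map_location="cpu", pickle_module=dill)
model.eval()
```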
2 changes: 2 additions & 0 deletions mace/cli/create_lammps_model.py
@@ -1,5 +1,6 @@
 import argparse

+import dill
 import torch
 from e3nn.util import jit
@@ -64,6 +65,7 @@ def main():
     model = torch.load(
         model_path,
         map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        pickle_module=dill,
     )
     if args.dtype == "float64":
         model = model.double().to("cpu")
7 changes: 4 additions & 3 deletions mace/cli/eval_configs.py
@@ -8,6 +8,7 @@

 import ase.data
 import ase.io
+import dill
 import numpy as np
 import torch

@@ -58,7 +59,7 @@ def parse_args() -> argparse.Namespace:
         help="Model head used for evaluation",
         type=str,
         required=False,
-        default=None
+        default=None,
     )
     return parser.parse_args()

@@ -73,7 +74,7 @@ def run(args: argparse.Namespace) -> None:
     device = torch_tools.init_device(args.device)

     # Load model
-    model = torch.load(f=args.model, map_location=args.device)
+    model = torch.load(f=args.model, map_location=args.device, pickle_module=dill)
     model = model.to(
         args.device
     )  # shouldn't be necessary but seems to help with CUDA problems
@@ -94,7 +95,7 @@ def run(args: argparse.Namespace) -> None:
         heads = model.heads
     except AttributeError:
         heads = None

     data_loader = torch_geometric.dataloader.DataLoader(
         dataset=[
             data.AtomicData.from_config(
9 changes: 5 additions & 4 deletions mace/cli/run_train.py
@@ -14,6 +14,7 @@
 from pathlib import Path
 from typing import List, Optional

+import dill
 import torch.distributed
 import torch.nn.functional
 from e3nn.util import jit
@@ -142,7 +143,7 @@ def run(args: argparse.Namespace) -> None:
             model_foundation = calc.models[0]
         else:
             model_foundation = torch.load(
-                args.foundation_model, map_location=args.device
+                args.foundation_model, map_location=args.device, pickle_module=dill
             )
         logging.info(
             f"Using foundation model {args.foundation_model} as initial checkpoint."
@@ -731,7 +732,7 @@ def run(args: argparse.Namespace) -> None:
         logging.info(f"Saving model to {model_path}")
         if args.save_cpu:
             model = model.to("cpu")
-        torch.save(model, model_path)
+        torch.save(model, model_path, pickle_module=dill)
         extra_files = {
             "commit.txt": commit.encode("utf-8") if commit is not None else b"",
             "config.yaml": json.dumps(
@@ -740,7 +741,7 @@ def run(args: argparse.Namespace) -> None:
         }
         if swa_eval:
             torch.save(
-                model, Path(args.model_dir) / (args.name + "_stagetwo.model")
+                model, Path(args.model_dir) / (args.name + "_stagetwo.model"), pickle_module=dill
             )
             try:
                 path_complied = Path(args.model_dir) / (
@@ -756,7 +757,7 @@ def run(args: argparse.Namespace) -> None:
             except Exception as e:  # pylint: disable=W0703
                 pass
         else:
-            torch.save(model, Path(args.model_dir) / (args.name + ".model"))
+            torch.save(model, Path(args.model_dir) / (args.name + ".model"), pickle_module=dill)
             try:
                 path_complied = Path(args.model_dir) / (
                     args.name + "_compiled.model"
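Both save paths (the main checkpoint and the stage-two/SWA model) now pass `pickle_module=dill`, so the save and load sides stay symmetric. A sketch of the round trip, assuming `model` is a trained MACE module and with hypothetical paths:

```python
import dill
import torch

# Saving with dill means any later torch.load of this file
# should also pass pickle_module=dill.
torch.save(model, "checkpoints/mace_kan.model", pickle_module=dill)
reloaded = torch.load(
    "checkpoints/mace_kan.model", map_location="cpu", pickle_module=dill
)
```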
4 changes: 4 additions & 0 deletions mace/modules/__init__.py
@@ -8,6 +8,8 @@
     AtomicEnergiesBlock,
     EquivariantProductBasisBlock,
     InteractionBlock,
+    KANNonLinearReadoutBlock,
+    KANReadoutBlock,
     LinearDipoleReadoutBlock,
     LinearNodeEmbeddingBlock,
     LinearReadoutBlock,
@@ -77,6 +79,8 @@
     "ZBLBasis",
     "LinearNodeEmbeddingBlock",
     "LinearReadoutBlock",
+    "KANReadoutBlock",
+    "KANNonLinearReadoutBlock",
     "EquivariantProductBasisBlock",
     "ScaleShiftBlock",
     "LinearDipoleReadoutBlock",
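After this change both blocks are part of the public `mace.modules` namespace, so downstream code can import them directly:

```python
from mace.modules import KANNonLinearReadoutBlock, KANReadoutBlock
```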
91 changes: 91 additions & 0 deletions mace/modules/blocks.py
@@ -13,6 +13,7 @@
 from e3nn.util.jit import compile_mode

 from mace.tools.compile import simplify_if_compile
+from mace.tools.MultKAN_jit import MultKAN
 from mace.tools.scatter import scatter_sum

 from .irreps_tools import (
@@ -59,6 +60,96 @@ def forward(
         return self.linear(x)  # [n_nodes, 1]


+@compile_mode("trace")
+class KANReadoutBlock(torch.nn.Module):
+    def __init__(
+        self,
+        irreps_in: o3.Irreps,
+        MLP_irreps: o3.Irreps,
+        irrep_out: o3.Irreps = o3.Irreps("0e"),
+    ):
+        super().__init__()
+        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=MLP_irreps)
+        self.linear_2 = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out)
+        self.irreps_in = o3.Irreps(irreps_in)
+        self.hidden_irreps = MLP_irreps
+        assert MLP_irreps.dim >= 8, "MLP_irreps.dim must be at least 8"
+        dim = [MLP_irreps.dim, MLP_irreps.dim // 2, MLP_irreps.dim // 4, irrep_out.dim]
+        self.kan = MultKAN(
+            width=dim,
+            grid=3,
+            k=3,
+            mult_arity=2,
+            symbolic_enabled=False,
+            auto_save=False,
+            save_act=False,
+        )
+        # self.kan.speed(compile=True)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        heads: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
+    ) -> torch.Tensor:  # [n_nodes, irreps] -> [n_nodes, irrep_out.dim]
+        x1 = self.linear(x)
+        # KAN output plus a linear skip connection from the input features
+        return self.kan(x1) + self.linear_2(x)  # [n_nodes, irrep_out.dim]
+
+    def _make_tracing_inputs(self, n: int):
+        return [
+            {"forward": (torch.randn(6, self.irreps_in.dim), None)}
+            for _ in range(n)
+        ]
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(dim={self.kan.width})"
+
+
+@compile_mode("trace")
+class KANNonLinearReadoutBlock(torch.nn.Module):
+    def __init__(
+        self,
+        irreps_in: o3.Irreps,
+        MLP_irreps: o3.Irreps,
+        irrep_out: o3.Irreps = o3.Irreps("0e"),
+        num_heads: int = 1,
+    ):
+        super().__init__()
+        self.irreps_in = o3.Irreps(irreps_in)
+        self.hidden_irreps = MLP_irreps
+        self.num_heads = num_heads
+        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps)
+        # self.linear_2 = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out)
+        assert MLP_irreps.dim >= 8, "MLP_irreps.dim must be at least 8"
+        dim = [MLP_irreps.dim, MLP_irreps.dim // 2, MLP_irreps.dim // 4, irrep_out.dim]
+        self.kan = MultKAN(
+            width=dim,
+            grid=3,
+            k=3,
+            mult_arity=2,
+            symbolic_enabled=False,
+            auto_save=False,
+            save_act=False,
+        )
+
+    def forward(
+        self, x: torch.Tensor, heads: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:  # [n_nodes, irreps] -> [n_nodes, irrep_out.dim]
+        if hasattr(self, "num_heads"):
+            if self.num_heads > 1 and heads is not None:
+                x = mask_head(x, heads, self.num_heads)
+        x1 = self.linear_1(x)
+        return self.kan(x1)  # + self.linear_2(x)
+
+    def _make_tracing_inputs(self, n: int):
+        return [
+            {"forward": (torch.randn(6, self.irreps_in.dim), None)}
+            for _ in range(n)
+        ]
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(dim={self.kan.width})"
+
+
 @simplify_if_compile
 @compile_mode("script")
 class NonLinearReadoutBlock(torch.nn.Module):
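For orientation, a quick sketch of exercising `KANReadoutBlock` on its own; the irreps sizes here are illustrative (in the model it receives the hidden irreps and `MLP_irreps`), and `MLP_irreps.dim` must be at least 8 to satisfy the assertion:

```python
import torch
from e3nn import o3

from mace.modules import KANReadoutBlock

# Illustrative sizes only: 32 scalar node features in, one energy channel out.
readout = KANReadoutBlock(
    irreps_in=o3.Irreps("32x0e"),
    MLP_irreps=o3.Irreps("16x0e"),  # dim 16 -> KAN width [16, 8, 4, 1]
    irrep_out=o3.Irreps("1x0e"),
)
x = torch.randn(6, 32)   # 6 nodes, 32 scalar features each
out = readout(x)         # [6, 1]: KAN output plus the linear skip term
```

Compared with the blocks they replace, the KAN variants route the features through learnable spline activations (`MultKAN` with width `[dim, dim // 2, dim // 4, out]`) instead of a fixed gate nonlinearity, and `KANReadoutBlock` additionally keeps a purely linear path to the output via `linear_2`.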
57 changes: 43 additions & 14 deletions mace/modules/models.py
@@ -19,6 +19,8 @@
     AtomicEnergiesBlock,
     EquivariantProductBasisBlock,
     InteractionBlock,
+    KANNonLinearReadoutBlock,
+    KANReadoutBlock,
     LinearDipoleReadoutBlock,
     LinearNodeEmbeddingBlock,
     LinearReadoutBlock,
@@ -62,6 +64,7 @@ def __init__(
         radial_MLP: Optional[List[int]] = None,
         radial_type: Optional[str] = "bessel",
         heads: Optional[List[str]] = None,
+        KAN_readout: bool = False,
     ):
         super().__init__()
         self.register_buffer(
@@ -135,9 +138,18 @@ def __init__(
         self.products = torch.nn.ModuleList([prod])

         self.readouts = torch.nn.ModuleList()
-        self.readouts.append(
-            LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
-        )
+        self.KAN_readout = KAN_readout
+
+        if KAN_readout:
+            self.readouts.append(
+                KANReadoutBlock(
+                    hidden_irreps, MLP_irreps, o3.Irreps(f"{len(heads)}x0e")
+                )
+            )
+        else:
+            self.readouts.append(
+                LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
+            )

         for i in range(num_interactions - 1):
             if i == num_interactions - 2:
@@ -166,19 +178,36 @@ def __init__(
             )
             self.products.append(prod)
             if i == num_interactions - 2:
-                self.readouts.append(
-                    NonLinearReadoutBlock(
-                        hidden_irreps_out,
-                        (len(heads) * MLP_irreps).simplify(),
-                        gate,
-                        o3.Irreps(f"{len(heads)}x0e"),
-                        len(heads),
-                    )
-                )
+                if KAN_readout:
+                    self.readouts.append(
+                        KANNonLinearReadoutBlock(
+                            hidden_irreps_out,
+                            (len(heads) * MLP_irreps).simplify(),
+                            o3.Irreps(f"{len(heads)}x0e"),
+                            len(heads),
+                        )
+                    )
+                else:
+                    self.readouts.append(
+                        NonLinearReadoutBlock(
+                            hidden_irreps_out,
+                            (len(heads) * MLP_irreps).simplify(),
+                            gate,
+                            o3.Irreps(f"{len(heads)}x0e"),
+                            len(heads),
+                        )
+                    )
             else:
-                self.readouts.append(
-                    LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
-                )
+                if KAN_readout:
+                    self.readouts.append(
+                        KANReadoutBlock(
+                            hidden_irreps, MLP_irreps, o3.Irreps(f"{len(heads)}x0e")
+                        )
+                    )
+                else:
+                    self.readouts.append(
+                        LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
+                    )

     def forward(
         self,
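With the flag threaded through the constructor, switching readouts is a single argument on the model. A hedged sketch, assuming the usual MACE hyperparameters (`r_max`, `hidden_irreps`, `MLP_irreps`, `atomic_energies`, and so on) are already collected in a `model_config` dict:

```python
from mace import modules

# model_config holds the standard MACE constructor arguments;
# only the KAN_readout flag below is new in this branch.
model = modules.MACE(**model_config, KAN_readout=True)
```

With `KAN_readout=True`, each linear readout is built as a `KANReadoutBlock` and the non-linear readout of the last interaction as a `KANNonLinearReadoutBlock`; with the default `False`, the model is constructed exactly as before.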