ESCN export II (#848)
* temp fix to train mptraj

* add compile option

* move jd to init

* fix dynamic export

* add value testing for export

* update test

* lint

* update comment

* update forward code

* reraise error

* wrap escn

* revert packages

* format

Former-commit-id: 362482e00920bea5af5bf7fcb9c035bd62966aa6
rayg1234 authored Sep 17, 2024
1 parent 60f56bd commit 7707b3e
Showing 3 changed files with 349 additions and 199 deletions.
src/fairchem/core/models/escn/escn_exportable.py — 153 changes: 110 additions & 43 deletions
@@ -9,11 +9,17 @@

import contextlib
import logging
+import os
+import typing

import torch
import torch.nn as nn

+if typing.TYPE_CHECKING:
+    from torch_geometric.data.batch import Batch

from fairchem.core.common.registry import registry
+from fairchem.core.models.base import GraphModelMixin
from fairchem.core.models.escn.so3_exportable import (
CoefficientMapping,
SO3_Grid,
@@ -32,15 +38,15 @@


@registry.register_model("escn_export")
-class eSCN(nn.Module):
+class eSCN(nn.Module, GraphModelMixin):
"""Equivariant Spherical Channel Network
Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs
Args:
-        regress_forces (bool): Compute forces
-        cutoff (float): Maximum distance between neighboring atoms in Angstroms
-        max_num_elements (int): Maximum atomic number
+        max_neighbors (int): Max neighbors to take per node when using graph generation
+        cutoff (float): Maximum distance between neighboring atoms in Angstroms
+        max_num_elements (int): Maximum atomic number
num_layers (int): Number of layers in the GNN
lmax (int): maximum degree of the spherical harmonics (1 to 10)
mmax (int): maximum order of the spherical harmonics (0 to lmax)
@@ -51,13 +57,15 @@ class eSCN(nn.Module):
distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu"): Basis function used for distances
basis_width_scalar (float): Width of distance basis function
distance_resolution (float): Distance between distance basis functions in Angstroms
+        compile (bool): use torch.compile on the forward
+        export (bool): use the exportable version of the module
"""

def __init__(
self,
-        regress_forces: bool = True,
+        max_neighbors: int = 300,
cutoff: float = 8.0,
-        max_num_elements: int = 90,
+        max_num_elements: int = 100,
num_layers: int = 8,
lmax: int = 4,
mmax: int = 2,
@@ -69,6 +77,8 @@ def __init__(
basis_width_scalar: float = 1.0,
distance_resolution: float = 0.02,
resolution: int | None = None,
+        compile: bool = False,
+        export: bool = False,
) -> None:
super().__init__()

@@ -78,7 +88,7 @@ def __init__(
logging.error("You need to install the e3nn library to use the SCN model")
raise ImportError

-        self.regress_forces = regress_forces
+        self.max_neighbors = max_neighbors
self.cutoff = cutoff
self.max_num_elements = max_num_elements
self.hidden_channels = hidden_channels
@@ -91,6 +101,8 @@ def __init__(
self.mmax = mmax
self.basis_width_scalar = basis_width_scalar
self.distance_function = distance_function
+        self.compile = compile
+        self.export = export

# non-linear activation function used throughout the network
self.act = nn.SiLU()
@@ -169,10 +181,9 @@ def __init__(
self.energy_block = EnergyBlock(
self.sphere_channels, self.num_sphere_samples, self.act
)
-        if self.regress_forces:
-            self.force_block = ForceBlock(
-                self.sphere_channels, self.num_sphere_samples, self.act
-            )
+        self.force_block = ForceBlock(
+            self.sphere_channels, self.num_sphere_samples, self.act
+        )

# Create a roughly evenly distributed point sampling of the sphere for the output blocks
self.sphere_points = nn.Parameter(
@@ -189,29 +200,96 @@ def __init__(
requires_grad=False,
)

-    def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        pos: torch.Tensor = data["pos"]
-        batch_idx: torch.Tensor = data["batch"]
-        natoms: torch.Tensor = data["natoms"]
-        atomic_numbers: torch.Tensor = data["atomic_numbers"]
-        edge_index: torch.Tensor = data["edge_index"]
-        edge_distance: torch.Tensor = data["distances"]
-        edge_distance_vec: torch.Tensor = data["edge_distance_vec"]
-
-        atomic_numbers = atomic_numbers.long()
-        # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable
-        # assert (
-        #     atomic_numbers.max().item() < self.max_num_elements
-        # ), "Atomic number exceeds that given in model config"
+        self.sph_feature_size = int((self.lmax + 1) ** 2)
+        # Pre-load Jd tensors for wigner matrices
+        # Borrowed from e3nn @ 0.4.0:
+        # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
+        # _Jd is a list of tensors of shape (2l+1, 2l+1)
+        # TODO: we should probably just bake this into the file as strings to avoid
+        # carrying this extra file around
+        Jd_list = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
+        for l in range(self.lmax + 1):
+            self.register_buffer(f"Jd_{l}", Jd_list[l])
+
+        if self.compile:
+            logging.info("Using the compiled escn forward function...")
+            self.forward = torch.compile(
+                options={"triton.cudagraphs": True}, fullgraph=True, dynamic=True
+            )(self.forward)
+
+        # torch.export only works with an nn.Module whose forward function is unaltered,
+        # and AOT Inductor currently requires a flat list of inputs,
+        # thus we need to keep module.forward as the fully exportable region.
+        # When not using export, i.e. for training, we swap out the forward with a
+        # version that wraps it with the graph generator.
+        #
+        # TODO: this is really ugly and confusing to read; find a better way to deal
+        # with a partially exportable model
+        if not self.export:
+            self._forward = self.forward
+            self.forward = self.forward_trainable

+    def forward_trainable(self, data: Batch) -> dict[str, torch.Tensor]:
+        # standard forward call that generates the graph on the fly with generate_graph;
+        # this part of the code is not compile/export friendly, so we keep it separate
+        # and have it wrap the exportable forward
+        graph = self.generate_graph(
+            data,
+            max_neighbors=self.max_neighbors,
+            otf_graph=True,
+            use_pbc=True,
+            use_pbc_single=True,
+        )
+        energy, forces = self._forward(
+            data.pos,
+            data.batch,
+            data.natoms,
+            data.atomic_numbers.long(),
+            graph.edge_index,
+            graph.edge_distance,
+            graph.edge_distance_vec,
+        )
+        return {"energy": energy, "forces": forces}

+    # a fully compilable/exportable forward function;
+    # takes a full graph with edges as input
+    def forward(
+        self,
+        pos: torch.Tensor,
+        batch_idx: torch.Tensor,
+        natoms: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_distance: torch.Tensor,
+        edge_distance_vec: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        N: num atoms
+        B: batch size
+        E: num edges
+        pos: [N, 3] atom positions
+        batch_idx: [N] batch index of each atom
+        natoms: [B] number of atoms in each batch
+        atomic_numbers: [N] atomic number per atom
+        edge_index: [2, E] edges between source and target atoms
+        edge_distance: [E] Cartesian distance for each edge
+        edge_distance_vec: [E, 3] direction vector of edges (includes pbc)
+        """
+        if not self.export and not self.compile:
+            assert atomic_numbers.max().item() < self.max_num_elements
+        num_atoms = len(atomic_numbers)

###############################################################
# Initialize data structures
###############################################################

# Compute 3x3 rotation matrix per edge
-        edge_rot_mat = self._init_edge_rot_mat(edge_index, edge_distance_vec)
-        wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax).detach()
+        edge_rot_mat = self._init_edge_rot_mat(edge_distance_vec)
+        Jd_buffers = [
+            getattr(self, f"Jd_{l}").type(edge_rot_mat.dtype)
+            for l in range(self.lmax + 1)
+        ]
+        wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax, Jd_buffers).detach()

###############################################################
# Initialize node embeddings
@@ -220,7 +298,7 @@ def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
# Init per node representations using an atomic number based embedding
x_message = torch.zeros(
num_atoms,
-            int((self.lmax + 1) ** 2),
+            self.sph_feature_size,
self.sphere_channels,
device=pos.device,
dtype=pos.dtype,
@@ -266,31 +344,20 @@ def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
# Scale energy to help balance numerical precision w.r.t. forces
energy = energy * 0.001

outputs = {"energy": energy}
###############################################################
# Force estimation
###############################################################
-        if self.regress_forces:
-            forces = self.force_block(x_pt, self.sphere_points)
-            outputs["forces"] = forces
+        forces = self.force_block(x_pt, self.sphere_points)

-        return outputs
+        return energy, forces

    # Initialize the edge rotation matrices
-    def _init_edge_rot_mat(self, edge_index, edge_distance_vec):
+    def _init_edge_rot_mat(self, edge_distance_vec):
edge_vec_0 = edge_distance_vec
edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))

# Make sure the atoms are far enough apart
-        # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable
-        # if torch.min(edge_vec_0_distance) < 0.0001:
-        #     logging.error(
-        #         f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}"
-        #     )
-        #     (minval, minidx) = torch.min(edge_vec_0_distance, 0)
-        #     logging.error(
-        #         f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}"
-        #     )
-        #     assert torch.min(edge_vec_0_distance) < 0.0001

norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))

Expand Down
src/fairchem/core/models/escn/so3_exportable.py — 55 changes: 28 additions & 27 deletions
@@ -1,7 +1,6 @@
from __future__ import annotations

import math
-import os

import torch

@@ -11,51 +10,53 @@
except ImportError:
pass

-# Borrowed from e3nn @ 0.4.0:
-# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
-# _Jd is a list of tensors of shape (2l+1, 2l+1)
-__Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
-
-
-@torch.compiler.assume_constant_result
-def get_jd() -> torch.Tensor:
-    return __Jd


# Borrowed from e3nn @ 0.4.0:
# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37
#
# In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower:
# https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92
def wigner_D(
-    lv: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor
+    lv: int,
+    alpha: torch.Tensor,
+    beta: torch.Tensor,
+    gamma: torch.Tensor,
+    _Jd: list[torch.Tensor],
) -> torch.Tensor:
-    _Jd = get_jd()
assert (
lv < len(_Jd)
), f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"

alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
-    J = _Jd[lv].to(dtype=alpha.dtype, device=alpha.device)
+    J = _Jd[lv]
Xa = _z_rot_mat(alpha, lv)
Xb = _z_rot_mat(beta, lv)
Xc = _z_rot_mat(gamma, lv)
return Xa @ J @ Xb @ J @ Xc


def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
-    shape, device, dtype = angle.shape, angle.device, angle.dtype
-    M = angle.new_zeros((*shape, 2 * lv + 1, 2 * lv + 1))
-    inds = torch.arange(0, 2 * lv + 1, 1, device=device)
-    reversed_inds = torch.arange(2 * lv, -1, -1, device=device)
-    frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device)
-    M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
-    M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
+    M = angle.new_zeros((*angle.shape, 2 * lv + 1, 2 * lv + 1))

+    # The following code needs to be replaced by a for loop because
+    # torch.export barfs on outer-product-like operations,
+    # i.e. torch.outer(frequencies, angle) (same as frequencies * angle[..., None])
+    # will place a nonsensical guard on the dimensions of angle when attempting
+    # to export with angle (the edge dimension) set as dynamic. This may be fixed in torch2.4.
+
+    # inds = torch.arange(0, 2 * lv + 1, 1, device=device)
+    # reversed_inds = torch.arange(2 * lv, -1, -1, device=device)
+    # frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device)
+    # M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
+    # M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
+
+    inds = list(range(0, 2 * lv + 1, 1))
+    reversed_inds = list(range(2 * lv, -1, -1))
+    frequencies = list(range(lv, -lv - 1, -1))
+    for i in range(len(frequencies)):
+        M[..., inds[i], reversed_inds[i]] = torch.sin(frequencies[i] * angle)
+        M[..., inds[i], inds[i]] = torch.cos(frequencies[i] * angle)
return M


def rotation_to_wigner(
-    edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int
+    edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int, Jd: list[torch.Tensor]
) -> torch.Tensor:
x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
alpha, beta = o3.xyz_to_angles(x)
@@ -69,7 +70,7 @@ def rotation_to_wigner(
wigner = torch.zeros(len(alpha), size, size, device=edge_rot_mat.device)
start = 0
for lmax in range(start_lmax, end_lmax + 1):
-        block = wigner_D(lmax, alpha, beta, gamma)
+        block = wigner_D(lmax, alpha, beta, gamma, Jd)
end = start + block.size()[1]
wigner[:, start:end, start:end] = block
start = end
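With `wigner_D` and `rotation_to_wigner` now taking the Jd tensors as an explicit argument instead of module-level state, a caller might use them roughly like this (a sketch, not from this commit; the `Jd.pt` path and `lmax` value are assumptions, and e3nn must be installed for the internal `xyz_to_angles` call):

```python
# Hypothetical usage sketch: build wigner matrices with explicitly passed Jd
# tensors (the file path and lmax here are illustrative).
import os
import torch

from fairchem.core.models.escn.so3_exportable import rotation_to_wigner

lmax = 2
jd_path = os.path.join("src", "fairchem", "core", "models", "escn", "Jd.pt")
Jd = [J.to(torch.float32) for J in torch.load(jd_path)[: lmax + 1]]

n_edges = 5
edge_rot_mat = torch.eye(3).repeat(n_edges, 1, 1)  # identity rotations

wigner = rotation_to_wigner(edge_rot_mat, 0, lmax, Jd)
print(wigner.shape)  # torch.Size([5, 9, 9]); (lmax + 1) ** 2 == 9
```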
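And a quick equivalence check (again only a sketch) that the export-friendly loop in `_z_rot_mat` reproduces the vectorized indexing it replaces; `z_rot_mat_vectorized` below is a reconstruction of the old code path kept purely for comparison:

```python
# Hypothetical sanity check: the for-loop rewrite of _z_rot_mat should be
# numerically identical to the original vectorized indexing.
import torch

from fairchem.core.models.escn.so3_exportable import _z_rot_mat


def z_rot_mat_vectorized(angle: torch.Tensor, lv: int) -> torch.Tensor:
    # original e3nn-style implementation, reproduced here only for comparison
    device, dtype = angle.device, angle.dtype
    M = angle.new_zeros((*angle.shape, 2 * lv + 1, 2 * lv + 1))
    inds = torch.arange(0, 2 * lv + 1, 1, device=device)
    reversed_inds = torch.arange(2 * lv, -1, -1, device=device)
    frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device)
    M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
    M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
    return M


angle = torch.randn(7)
for lv in range(4):
    assert torch.allclose(_z_rot_mat(angle, lv), z_rot_mat_vectorized(angle, lv))
```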
