DimeNet Edge Features Functionality (#291)
* Add edge features to DIMEStack

* take out unnecessary import

* Update comment
RylieWeaver authored Oct 8, 2024
1 parent 669cbd7 commit 076e2d6
Showing 3 changed files with 77 additions and 23 deletions.
93 changes: 73 additions & 20 deletions hydragnn/models/DIMEStack.py
@@ -19,7 +19,6 @@
from torch_geometric.nn import Linear, Sequential
from torch_geometric.nn.models.dimenet import (
BesselBasisLayer,
EmbeddingBlock,
InteractionPPBlock,
OutputPPBlock,
SphericalBasisLayer,
@@ -45,6 +44,7 @@ def __init__(
num_before_skip,
num_radial,
num_spherical,
edge_dim,
radius,
*args,
max_neighbours: Optional[int] = None,
@@ -57,6 +57,7 @@ def __init__(
self.num_spherical = num_spherical
self.num_before_skip = num_before_skip
self.num_after_skip = num_after_skip
self.edge_dim = edge_dim
self.radius = radius

super().__init__(*args, **kwargs)
@@ -83,7 +84,10 @@ def get_conv(self, input_dim, output_dim):
), "DimeNet requires more than one hidden dimension between input_dim and output_dim."
lin = Linear(input_dim, hidden_dim)
emb = HydraEmbeddingBlock(
num_radial=self.num_radial, hidden_channels=hidden_dim, act=SiLU()
num_radial=self.num_radial,
hidden_channels=hidden_dim,
act=SiLU(),
edge_dim=self.edge_dim,
)
inter = InteractionPPBlock(
hidden_channels=hidden_dim,
@@ -104,16 +108,29 @@ def get_conv(self, input_dim, output_dim):
act=SiLU(),
output_initializer="glorot_orthogonal",
)
return Sequential(
"x, pos, rbf, sbf, i, j, idx_kj, idx_ji",
[
(lin, "x -> x"),
(emb, "x, rbf, i, j -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
],
)

if self.use_edge_attr:
return Sequential(
"x, pos, rbf, edge_attr, sbf, i, j, idx_kj, idx_ji",
[
(lin, "x -> x"),
(emb, "x, rbf, i, j, edge_attr -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
],
)
else:
return Sequential(
"x, pos, rbf, sbf, i, j, idx_kj, idx_ji",
[
(lin, "x -> x"),
(emb, "x, rbf, i, j -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
],
)

def _conv_args(self, data):
assert (
@@ -143,6 +160,12 @@ def _conv_args(self, data):
"idx_ji": idx_ji,
}

if self.use_edge_attr:
assert (
data.edge_attr is not None
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})

return conv_args


@@ -182,20 +205,50 @@ def triplets(
return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji


class HydraEmbeddingBlock(EmbeddingBlock):
def __init__(self, num_radial: int, hidden_channels: int, act: Callable):
super().__init__(
num_radial=num_radial, hidden_channels=hidden_channels, act=act
)
del self.emb # Atomic embeddings are handled by Hydra.
class HydraEmbeddingBlock(torch.nn.Module):
def __init__(
self,
num_radial: int,
hidden_channels: int,
act: Callable,
edge_dim: Optional[int] = None,
):
super().__init__()
self.act = act

# self.emb = Embedding(95, hidden_channels) # Atomic embeddings are handled by HYDRA
self.lin_rbf = Linear(num_radial, hidden_channels)
if edge_dim is not None: # Optional edge features
self.edge_lin = Linear(edge_dim, hidden_channels)
self.lin = Linear(4 * hidden_channels, hidden_channels)
else:
self.lin = Linear(3 * hidden_channels, hidden_channels)

self.reset_parameters()

def reset_parameters(self):
# self.emb.weight.data.uniform_(-sqrt(3), sqrt(3))
self.lin_rbf.reset_parameters()
self.lin.reset_parameters()
if hasattr(self, "edge_lin"):
self.edge_lin.reset_parameters()

def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
rbf: Tensor,
i: Tensor,
j: Tensor,
edge_attr: Optional[Tensor] = None,
) -> Tensor:
# x = self.emb(x)
rbf = self.act(self.lin_rbf(rbf))
return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1)))

# Include edge features if they are provided
if edge_attr is not None and hasattr(self, "edge_lin"):
edge_attr = self.act(self.edge_lin(edge_attr))
out = torch.cat([x[i], x[j], rbf, edge_attr], dim=-1)
else:
out = torch.cat([x[i], x[j], rbf], dim=-1)

return self.act(self.lin(out))
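
The reworked HydraEmbeddingBlock makes edge features optional: when edge_dim is given, edge_attr is projected to hidden_channels and concatenated with x[i], x[j], and the radial basis expansion before the final linear layer (4 * hidden_channels in, hidden_channels out); otherwise it falls back to the original 3 * hidden_channels concatenation. A minimal standalone sketch of that behavior follows; the import path is inferred from the file shown above, and all sizes and tensors are illustrative only, not taken from the repository.

# Minimal usage sketch (assumed import path; shapes chosen only for illustration).
import torch
from torch.nn import SiLU
from hydragnn.models.DIMEStack import HydraEmbeddingBlock

hidden_dim, num_radial, edge_dim = 64, 6, 3
num_nodes, num_edges = 10, 30

emb = HydraEmbeddingBlock(
    num_radial=num_radial, hidden_channels=hidden_dim, act=SiLU(), edge_dim=edge_dim
)

x = torch.randn(num_nodes, hidden_dim)         # node features already projected to hidden_dim by `lin`
rbf = torch.randn(num_edges, num_radial)       # radial basis expansion of edge distances
i = torch.randint(0, num_nodes, (num_edges,))  # target node index per edge
j = torch.randint(0, num_nodes, (num_edges,))  # source node index per edge
edge_attr = torch.randn(num_edges, edge_dim)   # optional edge features

out = emb(x, rbf, i, j, edge_attr)             # cat [x[i], x[j], rbf, edge_attr] -> 4 * hidden -> hidden
assert out.shape == (num_edges, hidden_dim)

# Without edge_dim the block keeps the original 3 * hidden_channels concatenation.
emb_plain = HydraEmbeddingBlock(num_radial=num_radial, hidden_channels=hidden_dim, act=SiLU())
assert emb_plain(x, rbf, i, j).shape == (num_edges, hidden_dim)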
1 change: 1 addition & 0 deletions hydragnn/models/create.py
@@ -293,6 +293,7 @@ def create_model(
num_before_skip,
num_radial,
num_spherical,
edge_dim,
radius,
input_dim,
hidden_dim,
6 changes: 3 additions & 3 deletions hydragnn/utils/input_config_parsing/config_utils.py
@@ -124,15 +124,15 @@ def update_config_equivariance(config):

def update_config_edge_dim(config):
config["edge_dim"] = None
edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN"]
edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN", "DimeNet"]
if "edge_features" in config and config["edge_features"]:
assert (
config["model_type"] in edge_models
), "Edge features can only be used with EGNN, SchNet, PNA, PNAPlus, and CGCNN."
), "Edge features can only be used with DimeNet EGNN, SchNet, PNA, PNAPlus, and CGCNN."
config["edge_dim"] = len(config["edge_features"])
elif config["model_type"] == "CGCNN":
# CG always needs an integer edge_dim
# PNA would fail with integer edge_dim without edge_attr
# PNA, PNAPlus, and DimeNet would fail with integer edge_dim without edge_attr
config["edge_dim"] = 0
return config

