Skip to content

Commit

Permalink
refactor rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Justin committed Oct 31, 2024
1 parent 8e94ce9 commit 4db2f73
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
14 changes: 8 additions & 6 deletions hydragnn/models/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def forward(self, data):
conv,
use_reentrant=False,
inv_node_feat=inv_node_feat,
pos=pos,
equiv_node_feat=equiv_node_feat,
**conv_args
)
inv_node_feat = self.activation_function(feat_layer(inv_node_feat))
Expand All @@ -349,15 +349,17 @@ def forward(self, data):
outputs_var.append(output_head[:, head_dim:] ** 2)
else:
if self.node_NN_type == "conv":
inv_node_feat = x
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
c, pos = conv(
inv_node_feat=x,
inv_node_feat, equiv_node_feat = conv(
inv_node_feat=inv_node_feat,
equiv_node_feat=equiv_node_feat,
**conv_args
)
c = batch_norm(c)
x = self.activation_function(c)
x_node = x
inv_node_feat = batch_norm(inv_node_feat)
inv_node_feat = self.activation_function(inv_node_feat)
x_node = inv_node_feat
x = inv_node_feat
else:
x_node = headloc(x=x, batch=data.batch)
outputs.append(x_node[:, :head_dim])
Expand Down
9 changes: 6 additions & 3 deletions hydragnn/models/PAINNStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
assert (
hidden_dim > 1
), "PainnNet requires more than one hidden dimension between input_dim and output_dim."
print("hidden_dim", input_dim)
self_inter = PainnMessage(
node_size=input_dim, edge_size=self.num_radial, cutoff=self.radius
)
cross_inter = PainnUpdate(node_size=input_dim, last_layer=last_layer)
"""
The following linear layers are to get the correct sizing of embeddings. This is
necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly
The following linear layers are to get the correct sizing of embeddings. This is
necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly
because node_scalar and node_vector are updated through a sum.
"""
node_embed_out = nn.Sequential(
Expand Down Expand Up @@ -137,7 +138,6 @@ def _embedding(self, data):

# Instantiate tensor to hold equivariant traits
v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device)
data.v = v

conv_args = {
"edge_index": data.edge_index.t().to(torch.long),
Expand All @@ -153,6 +153,7 @@ class PainnMessage(nn.Module):

def __init__(self, node_size: int, edge_size: int, cutoff: float):
super().__init__()
print(node_size)

self.node_size = node_size
self.edge_size = edge_size
Expand Down Expand Up @@ -183,6 +184,8 @@ def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):
dim=1,
)

print(node_vector[edge[:, 1]].shape)
print(gate_state_vector.shape)
# num_pairs * 3 * node_size, num_pairs * node_size
message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1)
edge_vector = gate_edge_vector.unsqueeze(1) * (
Expand Down
4 changes: 2 additions & 2 deletions hydragnn/models/PNAEqStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
)
update = PainnUpdate(node_size=input_dim, last_layer=last_layer)
"""
The following linear layers are to get the correct sizing of embeddings. This is
necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly
The following linear layers are to get the correct sizing of embeddings. This is
necessary to use the hidden_dim, output_dim of HYDRAGNN's stacked conv layers correctly
because node_scalar and node_vector are updated through an additive skip connection.
"""
# Embed down to output size
Expand Down

0 comments on commit 4db2f73

Please sign in to comment.