Skip to content

Commit

Permalink
errant logic fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zachfox committed Oct 30, 2024
1 parent 3837de6 commit 4b45f08
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions hydragnn/models/Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,69 +307,70 @@ def forward(self, data):
x = data.x
pos = data.pos

# print("data.x IN: ", x)
# print("data.pos IN", pos)

### encoder part ####
conv_args = self._conv_args(data)
for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
c, pos = conv(x=x, pos=pos, **conv_args)
if not self.conv_checkpointing:
c, pos = conv(x=x, pos=pos, **conv_args)
else:
c, pos = checkpoint(
conv, use_reentrant=False, x=x, pos=pos, **conv_args
)
x = self.activation_function(feat_layer(c))

#### multi-head decoder part####
# shared dense layers for graph level output
if data.batch is None:
x_graph = x.mean(dim=0, keepdim=True)
else:
x_graph = global_mean_pool(x, data.batch.to(x.device))
outputs = []
outputs_var = []
for head_dim, headloc, type_head in zip(
self.head_dims, self.heads_NN, self.head_type
):
if type_head == "graph":
x_graph_head = self.graph_shared(x_graph)
outputs.append(headloc(x_graph_head))
output_head = headloc(x_graph_head)
outputs.append(output_head[:, :head_dim])
outputs_var.append(output_head[:, head_dim:] ** 2)
elif type_head == "node":
if self.node_NN_type == "conv":
for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
x_node = self.activation_function(
batch_norm(conv(x=x, edge_index=data.edge_index))
)
c, pos = conv(x=x, pos=pos, **conv_args)
c = batch_norm(c)
x = self.activation_function(c)
outputs.append(x_node[:, :head_dim])
outputs_var.append(x_node[:, head_dim:] ** 2)
else:
x_node = headloc(x=x, batch=data.batch)

# print("NODE OUT: ", x_node)
elif type_head == "pos":
# print("POS OUT: ", pos)
if self.equivariance:
x_node = (
pos - data.pos
) # following 3.2 The Dynamics in "Equivariant Diffusion for Molecule Generation in 3D" (Hoogeboom et al 2022)
# calculate the center of gravity for each subgraph
)
sg_num_nodes = [
d.num_nodes for d in data.to_data_list()
] # TODO - inefficient
]
com_ten = []
# std_ten = []
place = 0
for sgnn in sg_num_nodes:
sg_x_node = x_node[place : place + sgnn]
com_ten.append(
sg_x_node.mean(dim=0, keepdim=True).tile(sgnn, 1)
)
# std_ten.append(sg_x_node.std() * torch.ones_like(sg_x_node))
place += sgnn
com_ten = torch.cat(com_ten, dim=0)
# std_ten = torch.cat(std_ten, dim=0)
x_node = x_node - com_ten # subtract centers of mass
# x_node = x_node / std_ten # normalize output like GroupNorm
x_node = x_node - com_ten
else:
x_node = pos
else:
raise NotImplementedError(
"Head type {} not recognized".format(type_head)
)
outputs.append(x_node)
if self.var_output:
return outputs, outputs_var
return outputs

def loss(self, pred, value, head_index):
Expand Down

0 comments on commit 4b45f08

Please sign in to comment.