-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Light refactor to enhance development #302
Conversation
hydragnn/models/EGCLStack.py
Outdated
super().__init__(input_args, conv_args, *args, **kwargs) | ||
|
||
assert ( | ||
self.input_args == "inv_node_feat, equiv_node_feat, edge_index, edge_attr" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does EGCL always require edge_attributes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does in the current/previous implementation. There is a way around this, but it may take some unrolling to update it.
hydragnn/models/PAINNStack.py
Outdated
], | ||
) | ||
|
||
def forward(self, data): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the largest advantage of the approach you're taking here, right? By generalizing (pos/v --> equiv_node_feature), you've taken out the need to rewrite the whole forward function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The aim is to support the development equivariant architectures by providing a specific variable for handling equivariant information. This is done by improving the variable names to inv_node_feat
and equiv_node_feat
instead of vaguely named x, pos
, respectively. One advantage would be that for many architectures, the forward function is handled by the Base class.
Are the model stacks like PAINN, DimeNet, etc... finished with conversion to your proposed method here? If so, then I can take inspiration and adjust MACE tomorrow :) Also, I think a version diff could be making black give you extraneous changes to the files in HYDRA, maybe check that |
Yes. I will change the extra formatting back. I ran into some other issues with PAINN and PNAEq as observed in the most recent check. I ran out of time to reason through it yesterday. I will check it out this evening. I have temporarily converted MACE to this format but I believe there is a better way to handle the variables. I would definitely appreciate help cleaning up MACE. Let me ping you when the checks pass for PAINN and PNAEq. |
I'm initializing a PR with the following initiatives:
(1) light refactoring as discussed this morning (args -> create.py)
(2) equivariance clean up (inv_feat + geom_feat tracking inside of the MPNN)
I will complete this PR and add full documentation if there is interest in these initiatives.