-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make LinkX
model jittable
#6712
Conversation
Codecov Report
@@ Coverage Diff @@
## master #6712 +/- ##
==========================================
+ Coverage 88.97% 89.05% +0.08%
==========================================
Files 423 423
Lines 22965 22978 +13
==========================================
+ Hits 20433 20464 +31
+ Misses 2532 2514 -18
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is cool, thank you!
If we plan to do this for other models, I think it would be a nice idea to out-source the functionality of creating a jittable parent module.
This PR add the TorchScript support for `RECT_L` model. The fail reason and our solution for original code is very similar with PR [#6721](#6712), except that this model using the `torch.jit.export` on `embed` and `get_semantic_labels` methods. And another fail reason is `@torch.no_grad`, which bring some error msg I cann't understand. Adding TorchScript support will bring a lot of extra code and reduce the code readability. I will consider how to do better in another PR. --------- Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
This PR make the LinkX model to be jittable and get better performance.
The original LinkX model cann't be JIT due to many reasons:
edge_index
of LinkX.forward isAdj
. So we need to add@torch.jit._overload_method
for this method.edge_norm
andedge_mlp
in LinkX model isn't always initialized. So we modify the initialize logic.@torch.jit._overload_method
for forward, we need a wrapper of model. Overwise, the TorchScript can't find the forward function.Union
, which will caused type error in TorchScript when we use the return value of MLP.I think there might be a better way to make it jittable, but I haven't found yet.