Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make LinkX model jittable #6712

Merged
merged 6 commits into from
Feb 15, 2023
Merged

Make LinkX model jittable #6712

merged 6 commits into from
Feb 15, 2023

Conversation

ftxj
Copy link
Contributor

@ftxj ftxj commented Feb 15, 2023

This PR make the LinkX model to be jittable and get better performance.

The original LinkX model cann't be JIT due to many reasons:

  1. The type of second parameter edge_index of LinkX.forward is Adj. So we need to add @torch.jit._overload_method for this method.
  2. The edge_norm and edge_mlp in LinkX model isn't always initialized. So we modify the initialize logic.
  3. After we add @torch.jit._overload_method for forward, we need a wrapper of model. Overwise, the TorchScript can't find the forward function.
  4. The return type of MLP is an Union, which will caused type error in TorchScript when we use the return value of MLP.

I think there might be a better way to make it jittable, but I haven't found yet.

@codecov
Copy link

codecov bot commented Feb 15, 2023

Codecov Report

Merging #6712 (6b4c2db) into master (48a3686) will increase coverage by 0.08%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #6712      +/-   ##
==========================================
+ Coverage   88.97%   89.05%   +0.08%     
==========================================
  Files         423      423              
  Lines       22965    22978      +13     
==========================================
+ Hits        20433    20464      +31     
+ Misses       2532     2514      -18     
Impacted Files Coverage Δ
torch_geometric/nn/models/linkx.py 100.00% <100.00%> (ø)
torch_geometric/nn/models/mlp.py 95.18% <100.00%> (ø)
torch_geometric/data/storage.py 82.76% <0.00%> (+0.08%) ⬆️
torch_geometric/nn/aggr/multi.py 98.90% <0.00%> (+1.09%) ⬆️
torch_geometric/nn/conv/gen_conv.py 94.94% <0.00%> (+17.17%) ⬆️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@rusty1s rusty1s changed the title Make LinkX model Jittable Make LinkX model jittable Feb 15, 2023
Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is cool, thank you!

If we plan to do this for other models, I think it would be a nice idea to out-source the functionality of creating a jittable parent module.

@rusty1s rusty1s merged commit 34668c7 into pyg-team:master Feb 15, 2023
rusty1s added a commit that referenced this pull request Feb 17, 2023
This PR add the TorchScript support for `RECT_L` model.

The fail reason and our solution for original code is very similar with
PR [#6721](#6712),
except that this model using the `torch.jit.export` on `embed` and
`get_semantic_labels` methods. And another fail reason is
`@torch.no_grad`, which bring some error msg I cann't understand.

Adding TorchScript support will bring a lot of extra code and reduce the
code readability. I will consider how to do better in another PR.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants