-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR make the LinkX model to be jittable and get better performance. The original LinkX model cann't be JIT due to many reasons: 1. The type of second parameter `edge_index` of LinkX.forward is `Adj`. So we need to add `@torch.jit._overload_method` for this method. 2. The `edge_norm` and `edge_mlp` in LinkX model isn't always initialized. So we modify the initialize logic. 3. After we add `@torch.jit._overload_method` for forward, we need a wrapper of model. Overwise, the TorchScript can't find the forward function. 4. The return type of MLP is an `Union`, which will caused type error in TorchScript when we use the return value of MLP. I think there might be a better way to make it jittable, but I haven't found yet. --------- Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
- Loading branch information
Showing
4 changed files
with
93 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters