Skip to content

Commit

Permalink
Jittable fixes - knn_interpolate and mlp (#5025)
Browse files Browse the repository at this point in the history
* typehints and explicit unpack for `torch.jit.script`

* fix for jit script error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
GericoVi and pre-commit-ci[bot] authored Jul 21, 2022
1 parent 405ae34 commit b757161
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,14 @@ def reset_parameters(self):

def forward(self, x: Tensor, return_emb: NoneType = None) -> Tensor:
""""""
for lin, norm, dropout in zip(self.lins, self.norms, self.dropout):
for i, (lin, norm) in enumerate(zip(self.lins, self.norms)):
x = lin(x)
if self.act is not None and self.act_first:
x = self.act(x)
x = norm(x)
if self.act is not None and not self.act_first:
x = self.act(x)
x = F.dropout(x, p=dropout, training=self.training)
x = F.dropout(x, p=self.dropout[i], training=self.training)
emb = x

if self.plain_last:
Expand Down
8 changes: 5 additions & 3 deletions torch_geometric/nn/unpool/knn_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from torch_scatter import scatter_add

from torch_geometric.nn import knn
from torch_geometric.typing import OptTensor


def knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3,
num_workers=1):
def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
batch_x: OptTensor = None, batch_y: OptTensor = None,
k: int = 3, num_workers: int = 1):
r"""The k-NN interpolation from the `"PointNet++: Deep Hierarchical
Feature Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper.
Expand Down Expand Up @@ -44,7 +46,7 @@ def knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3,
with torch.no_grad():
assign_index = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y,
num_workers=num_workers)
y_idx, x_idx = assign_index
y_idx, x_idx = assign_index[0], assign_index[1]
diff = pos_x[x_idx] - pos_y[y_idx]
squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
Expand Down

0 comments on commit b757161

Please sign in to comment.