diff --git a/torch_geometric/nn/models/mlp.py b/torch_geometric/nn/models/mlp.py index d6ecab84bbb7..360995b012ea 100644 --- a/torch_geometric/nn/models/mlp.py +++ b/torch_geometric/nn/models/mlp.py @@ -179,14 +179,14 @@ def reset_parameters(self): def forward(self, x: Tensor, return_emb: NoneType = None) -> Tensor: """""" - for lin, norm, dropout in zip(self.lins, self.norms, self.dropout): + for i, (lin, norm) in enumerate(zip(self.lins, self.norms)): x = lin(x) if self.act is not None and self.act_first: x = self.act(x) x = norm(x) if self.act is not None and not self.act_first: x = self.act(x) - x = F.dropout(x, p=dropout, training=self.training) + x = F.dropout(x, p=self.dropout[i], training=self.training) emb = x if self.plain_last: diff --git a/torch_geometric/nn/unpool/knn_interpolate.py b/torch_geometric/nn/unpool/knn_interpolate.py index 93401049dd19..7d2d0e07191d 100644 --- a/torch_geometric/nn/unpool/knn_interpolate.py +++ b/torch_geometric/nn/unpool/knn_interpolate.py @@ -2,10 +2,12 @@ from torch_scatter import scatter_add from torch_geometric.nn import knn +from torch_geometric.typing import OptTensor -def knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3, - num_workers=1): +def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, + batch_x: OptTensor = None, batch_y: OptTensor = None, + k: int = 3, num_workers: int = 1): r"""The k-NN interpolation from the `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" `_ paper. @@ -44,7 +46,7 @@ def knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3, with torch.no_grad(): assign_index = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y, num_workers=num_workers) - y_idx, x_idx = assign_index + y_idx, x_idx = assign_index[0], assign_index[1] diff = pos_x[x_idx] - pos_y[y_idx] squared_distance = (diff * diff).sum(dim=-1, keepdim=True) weights = 1.0 / torch.clamp(squared_distance, min=1e-16)