Skip to content

Commit

Permalink
Remove warning in segmatmul_heuristic (#7379)
Browse files Browse the repository at this point in the history
otherwise I get some warning

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored May 17, 2023
1 parent 7c4aef8 commit def5301
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
3 changes: 1 addition & 2 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ def radius(
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = radius(x, y, 1.5, batch_x, batch_y)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
Args:
x (torch.Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
Expand Down
7 changes: 6 additions & 1 deletion torch_geometric/utils/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def segmatmul_heuristic(inputs: Tensor, type_ptr, weight: Tensor):
in_feat = inputs.size(1)
out_feat = weight.size(-1)
# this heuristic was learned with learn_sklearn_heuristic on an A100
x = torch.tensor([num_types, max_num_nodes_per_types, in_feat, out_feat])
x = torch.tensor([
int(num_types),
int(max_num_nodes_per_types),
int(in_feat),
int(out_feat)
])
scale_mean = torch.tensor(
[125.11603189, 12133.21523472, 163.81222321, 32.43755536])
scale_scale = torch.tensor(
Expand Down

0 comments on commit def5301

Please sign in to comment.