Skip to content

Commit

Permalink
add todo
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jun 5, 2022
1 parent 6fd8f43 commit d2b4b8c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
6 changes: 5 additions & 1 deletion torch_geometric/nn/aggr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def __repr__(self) -> str:
# Assertions ##############################################################

def assert_index_present(self, index: Optional[Tensor]):
# TODO Currently, not all aggregators support `ptr`. This assert helps
# to ensure that we require `index` to be passed to the computation:
if index is None:
raise NotImplementedError(f"'{self.__class__.__name__}' requires "
f"'index' to be specified")
Expand Down Expand Up @@ -107,9 +109,11 @@ def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None,
dim_size: Optional[int] = None,
dim: int = -2) -> Tuple[Tensor, Tensor]:

self.assert_index_present(index) # TODO
# TODO Currently, `to_dense_batch` can only operate on `index`:
self.assert_index_present(index)
self.assert_sorted_index(index)
self.assert_two_dimensional_input(x, dim)

return to_dense_batch(x, index, batch_size=dim_size)


Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/aggr/set2set.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

self.assert_index_present(index) # TODO
# TODO Currently, `to_dense_batch` can only operate on `index`:
self.assert_index_present(index)
self.assert_two_dimensional_input(x, dim)

h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))),
Expand Down

0 comments on commit d2b4b8c

Please sign in to comment.