Skip to content

Commit

Permalink
Fix to_cugraph and from_cugraph tests in test_convert (#7908)
Browse files Browse the repository at this point in the history
Fix to match the behavior expected with cugraph >= 23.08

Signed-off-by: Serge Panev <spanev@nvidia.com>
  • Loading branch information
Kh4L authored Aug 21, 2023
1 parent deff5a4 commit 1ffce71
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions test/utils/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,13 @@ def test_to_cugraph(edge_weight, directed, relabel_nodes):
edge_list = graph.view_edge_list()
assert edge_list is not None

edge_list = edge_list.sort_values(by=['src', 'dst'])
edge_list = edge_list.sort_values(by=[0, 1])

cu_edge_index = edge_list[['src', 'dst']].to_pandas().values
cu_edge_index = edge_list[[0, 1]].to_pandas().values
cu_edge_index = torch.from_numpy(cu_edge_index).t()
cu_edge_weight = None
if edge_weight is not None:
cu_edge_weight = edge_list['weights'].to_pandas().values
cu_edge_weight = edge_list['2'].to_pandas().values
cu_edge_weight = torch.from_numpy(cu_edge_weight)

cu_edge_index, cu_edge_weight = sort_edge_index(cu_edge_index,
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,13 @@ def from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:
"""
df = g.view_edge_list()

src = from_dlpack(df['src'].to_dlpack()).long()
dst = from_dlpack(df['dst'].to_dlpack()).long()
src = from_dlpack(df[0].to_dlpack()).long()
dst = from_dlpack(df[1].to_dlpack()).long()
edge_index = torch.stack([src, dst], dim=0)

edge_weight = None
if 'weights' in df:
edge_weight = from_dlpack(df['weights'].to_dlpack())
if '2' in df:
edge_weight = from_dlpack(df['2'].to_dlpack())

return edge_index, edge_weight

Expand Down

0 comments on commit 1ffce71

Please sign in to comment.