Skip to content

Commit

Permalink
Merge pull request #1225 from Sherry-XLL/master
Browse files Browse the repository at this point in the history
FIX: fix UserWarning in get_norm_adj_mat and accelerate csr2tensor
  • Loading branch information
2017pxy authored Apr 4, 2022
2 parents 4f76169 + 390c305 commit 86b20cd
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/lightgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_norm_adj_mat(self):
L = sp.coo_matrix(L)
row = L.row
col = L.col
i = torch.LongTensor([row, col])
i = torch.LongTensor(np.array([row, col]))
data = torch.FloatTensor(L.data)
SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
return SparseL
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/ncl.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_norm_adj_mat(self):
L = sp.coo_matrix(L)
row = L.row
col = L.col
i = torch.LongTensor([row, col])
i = torch.LongTensor(np.array([row, col]))
data = torch.FloatTensor(L.data)
SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
return SparseL
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/ngcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def get_norm_adj_mat(self):
L = sp.coo_matrix(L)
row = L.row
col = L.col
i = torch.LongTensor([row, col])
i = torch.LongTensor(np.array([row, col]))
data = torch.FloatTensor(L.data)
SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
return SparseL
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/sgl.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def csr2tensor(self, matrix: sp.csr_matrix):
"""
matrix = matrix.tocoo()
x = torch.sparse.FloatTensor(
torch.LongTensor([matrix.row.tolist(), matrix.col.tolist()]),
torch.LongTensor(np.array([matrix.row, matrix.col])),
torch.FloatTensor(matrix.data.astype(np.float32)), matrix.shape
).to(self.device)
return x
Expand Down

0 comments on commit 86b20cd

Please sign in to comment.