Skip to content

Commit

Permalink
fixed a bug in the normalized laplacian implementation and a bug in b…
Browse files Browse the repository at this point in the history
…atch for sparse graph creation
  • Loading branch information
AmitaiYacobi committed Nov 25, 2023
1 parent 53a5c8b commit da19231
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/spectralnet/_losses/_spectralnet_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def forward(
m = Y.size(0)
if is_normalized:
D = torch.sum(W, dim=1)
Y = Y / D[:, None]
Y = Y / torch.sqrt(D)[:, None]

Dy = torch.cdist(Y, Y)
loss = torch.sum(W * Dy.pow(2)) / (2 * m)
Expand Down
2 changes: 1 addition & 1 deletion src/spectralnet/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def make_batch_for_sparse_grapsh(batch_x: torch.Tensor) -> torch.Tensor:
x = x.detach().cpu().numpy()
nn_indices = u.get_nns_by_vector(x, n_neighbors)
nn_tensors = [u.get_item_vector(i) for i in nn_indices[1:]]
nn_tensors = torch.tensor(nn_tensors)
nn_tensors = torch.tensor(nn_tensors, device=batch_x.device)
new_batch_x = torch.cat((new_batch_x, nn_tensors))

return new_batch_x
Expand Down

0 comments on commit da19231

Please sign in to comment.