Skip to content

Commit

Permalink
fix some bugs in link_prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
anpolol committed Dec 19, 2023
1 parent cb0a70c commit b7317da
Show file tree
Hide file tree
Showing 6 changed files with 2,343 additions and 147 deletions.
3 changes: 2 additions & 1 deletion stable_gnn/model_gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def forward(self, x: Tensor, edge_index: Adj, batch: Tensor) -> Tuple[Tensor, Te
:return: (Tensor, Tensor): Predicted probabilities of labels and predicted degrees of graphs
"""
# 1. Obtain node embeddings
x = x.type(torch.FloatTensor).to(self.device)
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i < self.num_layers - 1:
Expand Down Expand Up @@ -111,7 +112,7 @@ def loss_sup(pred: Tensor, label: Tensor) -> Tensor:
:param label: (Tensor): Genuine labels
:return: (Tensor): Loss
"""
return F.nll_loss(pred, label)
return F.nll_loss(pred, label)#.type(torch.LongTensor).to(device))

@staticmethod
def convert_dataset(
Expand Down
2 changes: 1 addition & 1 deletion stable_gnn/model_link_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def train_cl(self, train_edges: List[List[int]], neg_samples_train: List[List[in
:param neg_samples_train: (List): List of negative samples to train
:return: (BaseEstimator): Classifier which support fit predict notation
"""
if self.number_of_trials:
if self.number_of_trials > 0:
self.embeddings = EmbeddingFactory().build_embeddings(
loss_name=self.loss_name,
conv=self.emb_conv_name,
Expand Down
Loading

0 comments on commit b7317da

Please sign in to comment.