Skip to content

Commit

Permalink
[Bugfix] Update loading datasets (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
cenyk1230 authored Nov 13, 2021
1 parent c819937 commit bd47b29
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions cogdl/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ def build_dataset_from_path(data_path, dataset=None):
module = importlib.import_module(path)
class_name = SUPPORTED_DATASETS[dataset].split(".")[-1]
dataset = getattr(module, class_name)(data_path=data_path)
return dataset

if dataset is None:
try:
return torch.load(data_path)
except Exception as e:
print(e)
exit(0)
raise ValueError("You are expected to specify `dataset` and `data_path`")


Expand Down
2 changes: 1 addition & 1 deletion cogdl/utils/spmm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def mh_spmm(graph, attention, h, csrmhspmm=None, fast_spmm=None):
h = h.permute(1, 0, 2).contiguous()
for i in range(nhead):
edge_weight = attention[:, i]
graph.edge_weight = edge_weight
graph.edge_weight = edge_weight.contiguous()
hidden = h[i]
assert not torch.isnan(hidden).any()
h_prime.append(spmm(graph, hidden, fast_spmm=fast_spmm))
Expand Down
6 changes: 3 additions & 3 deletions examples/VRGCN/VRGCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def forward(self, x, sample_ids_adjs, full_ids_adjs) -> Tensor:
x = x - self.histories[i].pull(cur_id).detach()
h = self.histories[i].pull(full_id)

x = self.spmm(sample_adj, x)[: target_id.shape[0]] + self.spmm(full_adj, h)[: target_id.shape[0]].detach()
x = spmm(sample_adj, x)[: target_id.shape[0]] + spmm(full_adj, h)[: target_id.shape[0]].detach()
x = self.lins[i](x)

if i != self.num_layers - 1:
Expand All @@ -131,7 +131,7 @@ def inference(self, x, adj):
adj = adj.to(self._device)
xs = [x]
for i in range(self.num_layers):
x = self.spmm(adj, x)
x = spmm(adj, x)
x = self.lins[i](x)
if i != self.num_layers - 1:
x = self.norms[i](x)
Expand All @@ -148,7 +148,7 @@ def inference_batch(self, x, test_loader):
tmp_x = []
for target_id, full_id, full_adj in test_loader:
full_adj = full_adj.to(device)
agg_x = self.spmm(full_adj, x[full_id].to(device))[: target_id.shape[0]]
agg_x = spmm(full_adj, x[full_id].to(device))[: target_id.shape[0]]
agg_x = self.lins[i](agg_x)

if i != self.num_layers - 1:
Expand Down

0 comments on commit bd47b29

Please sign in to comment.