Skip to content

Commit

Permalink
Merge branch 'master' into yanbing/benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
yanbing-j committed Aug 17, 2022
2 parents 1789e6c + 503b181 commit 1cf7d7e
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))
- Added the `MLP.plain_last=False` option ([4652](https://github.com/pyg-team/pytorch_geometric/pull/4652))
- Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))
- Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
- Added `HeteroData.subgraph()`, `HeteroData.node_type_subgraph()` and `HeteroData.edge_type_subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
- Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626))
- Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644))
- Added PyTorch Lightning support in GraphGym ([#4511](https://github.com/pyg-team/pytorch_geometric/pull/4511), [#4516](https://github.com/pyg-team/pytorch_geometric/pull/4516) [#4531](https://github.com/pyg-team/pytorch_geometric/pull/4531), [#4689](https://github.com/pyg-team/pytorch_geometric/pull/4689), [#4843](https://github.com/pyg-team/pytorch_geometric/pull/4843))
Expand Down
2 changes: 1 addition & 1 deletion examples/hetero/metapath2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def train(epoch, log_steps=100, eval_steps=2000):
def test(train_ratio=0.1):
model.eval()

z = model('author', batch=data['author'].y_index)
z = model('author', batch=data['author'].y_index.to(device))
y = data['author'].y

perm = torch.randperm(z.size(0))
Expand Down
10 changes: 10 additions & 0 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,16 @@ def test_hetero_data_subgraph():
assert len(out['author', 'paper']) == 1
assert out['author', 'paper'].edge_index is not None

out = data.node_type_subgraph(['paper', 'author'])
assert out.node_types == ['paper', 'author']
assert out.edge_types == [('paper', 'to', 'paper'),
('author', 'to', 'paper'),
('paper', 'to', 'author')]

out = data.edge_type_subgraph([('paper', 'author')])
assert out.node_types == ['paper', 'author']
assert out.edge_types == [('paper', 'to', 'author')]


def test_copy_hetero_data():
data = HeteroData()
Expand Down
33 changes: 33 additions & 0 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,39 @@ def subgraph(self, subset_dict: Dict[NodeType, Tensor]) -> 'HeteroData':

return data

def node_type_subgraph(self, node_types: List[NodeType]) -> 'HeteroData':
r"""Returns the subgraph induced by the given :obj:`node_types`, *i.e.*
the returned :class:`HeteroData` object only contains the node types
which are included in :obj:`node_types`, and only contains the edge
types where both end points are included in :obj:`node_types`."""
data = copy.copy(self)
for edge_type in self.edge_types:
src, _, dst = edge_type
if src not in node_types or dst not in node_types:
del data[edge_type]
for node_type in self.node_types:
if node_type not in node_types:
del data[node_type]
return data

def edge_type_subgraph(self, edge_types: List[EdgeType]) -> 'HeteroData':
r"""Returns the subgraph induced by the given :obj:`edge_types`, *i.e.*
the returned :class:`HeteroData` object only contains the edge types
which are included in :obj:`edge_types`, and only contains the node
types of the end points which are included in :obj:`node_types`."""
edge_types = [self._to_canonical(e) for e in edge_types]

data = copy.copy(self)
for edge_type in self.edge_types:
if edge_type not in edge_types:
del data[edge_type]
node_types = set(e[0] for e in edge_types)
node_types |= set(e[-1] for e in edge_types)
for node_type in self.node_types:
if node_type not in node_types:
del data[node_type]
return data

def to_homogeneous(self, node_attrs: Optional[List[str]] = None,
edge_attrs: Optional[List[str]] = None,
add_node_type: bool = True,
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/aggr/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def reset_parameters(self):

def init_output(self, index: Optional[Tensor] = None) -> Tensor:
index_size = 1 if index is None else int(index.max().item() + 1)
return torch.zeros(index_size, self.output_dim,
requires_grad=True).float()
return torch.zeros(index_size, self.output_dim, requires_grad=True,
device=self.lamb.device).float()

def reg(self, y: Tensor) -> Tensor:
return self.softplus(self.lamb) * y.square().sum(dim=-1).mean()
Expand Down

0 comments on commit 1cf7d7e

Please sign in to comment.