This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Graph embedding and graph backbones (#592)
Co-authored-by: tchaton <thomas@grid.ai> Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk> Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
- Loading branch information
1 parent
8c9903d
commit 00ef240
Showing
22 changed files
with
360 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
.. _graph_embedder: | ||
|
||
############## | ||
Graph Embedder | ||
############## | ||
|
||
******** | ||
The Task | ||
******** | ||
This task consists of creating an embedding of a graph. That is, a vector of features which can be used for a downstream task. | ||
The :class:`~flash.graph.classification.model.GraphEmbedder` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/rusty1s/pytorch_geometric>`_. | ||
|
||
------ | ||
|
||
******* | ||
Example | ||
******* | ||
|
||
Let's look at generating embeddings of graphs from the KKI data set from `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_. | ||
|
||
We start by creating the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/tu_dataset.html#TUDataset>`. | ||
Next, we load a trained :class:`~flash.graph.classification.model.GraphEmbedder` (from a previously trained :class:`~flash.graph.classification.model.GraphClassifier`). | ||
Finally, we save the model. | ||
Here's the full example: | ||
|
||
.. literalinclude:: ../../../flash_examples/graph_embedder.py | ||
:language: python | ||
:lines: 14 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401 | ||
from flash.graph.embedding import GraphEmbedder # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from functools import partial | ||
|
||
from flash.core.registry import FlashRegistry | ||
from flash.core.utilities.imports import _GRAPH_AVAILABLE | ||
from flash.core.utilities.providers import _PYTORCH_GEOMETRIC | ||
|
||
if _GRAPH_AVAILABLE: | ||
from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE | ||
|
||
MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN} | ||
else: | ||
MODELS = {} | ||
|
||
GRAPH_BACKBONES = FlashRegistry("backbones") | ||
|
||
|
||
def _load_graph_backbone( | ||
model_name: str, | ||
in_channels: int, | ||
hidden_channels: int = 512, | ||
num_layers: int = 4, | ||
): | ||
model = MODELS[model_name] | ||
return model(in_channels, hidden_channels, num_layers) | ||
|
||
|
||
for model_name in MODELS.keys(): | ||
GRAPH_BACKBONES(name=model_name, providers=_PYTORCH_GEOMETRIC)(partial(_load_graph_backbone, model_name)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from flash.graph.classification.data import GraphClassificationData # noqa: F401 | ||
from flash.graph.embedding.model import GraphEmbedder # noqa: F401 |
Oops, something went wrong.