This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
Graph embedding and graph backbones #592
Merged
ethanwharris
merged 205 commits into
Lightning-Universe:master
from
PabloAMC:task_a_thon
Nov 9, 2021
Merged
Changes from 192 commits
Commits
Show all changes
205 commits
Select commit
Hold shift + click to select a range
ea4e7b6
Initial structure of GraphClassification model.py
1255f8f
Improvement of model.py. Still need to debug etc
365863f
BasicDataset Implemented
02b0f6a
Create __init__.py
f28e949
Implemented dataset and DataModule as for image processing
ad76827
Pipeline taken from images.
ea6ee9d
Initial structure of GraphClassification model.py
8b93a4a
Improvement of model.py. Still need to debug etc
6b4d7e3
BasicDataset Implemented
48dcf2d
Implemented dataset and DataModule as for image processing
49dfe4d
Pipeline taken from images.
151f7d9
Choice of model implemented (you can pass a model to GraphClassifier)
6236d95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 28e315f
Initial readaptation of the structure
08ace7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 35a79cb
Minimal structure of how to structure data.py files
93dd638
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 82576a5
Minor corrections
089cb07
update
tchaton 920fc68
i
tchaton c970b5f
update
tchaton fe41405
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 868b8d7
Added auto_dataset.num_features
debf5da
Deleted manually included num_features so that it is extracted from G…
1e6b2b0
Test for GraphClassification implemented
faa8709
Documentation for GraphClassification included
072a35b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] be015df
Creation of from_pygdatasequence method in DataModule and GraphSequen…
1fb160b
Update graph_classification.py
bb3b941
Update datatype_graph.txt
3583d5c
Tests and docs for the from_pygdatasequence method
193c2bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f59a2fc
Graph requirements
71a15bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 370ea24
Update CHANGELOG.md
e9d4e93
Update requirements with pytorch geometric libraries
a2b208e
Simplified, version with only the DataSource
7c3eaf4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 809b615
Minor tweaks
bf4a9b6
Merge branch 'master' of https://github.com/PabloAMC/lightning-flash
3089d94
Update the flash_example to reflect the new template
338c3ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0f39778
Delete IMDB-BINARY_A.txt
b051f36
Delete IMDB-BINARY_graph_indicator.txt
4e23aff
Delete IMDB-BINARY_graph_labels.txt
1ad2ce3
Class method from_pygdatasequence from flash/core/data/data_module.py
4e9178c
Merge branch 'master' of https://github.com/PabloAMC/lightning-flash
2b1bbd9
Creating backbones
94d4e60
Merge branch 'master' into master
ethanwharris c2cac21
Changing GRAPH_BACKBONES to GRAPH_CLASSIFICATION_BACKBONES
73663bd
Update docs
ethanwharris 287132f
Merge branch 'master' into master
ethanwharris 2631bd4
fix imports.py
ethanwharris b4d3b41
remove unused imports
ethanwharris b19b5b8
clean init.py
ethanwharris 7a4a914
updates
ethanwharris c831751
Minor tweaks in the docs and change from Graph_backbones to graph_cla…
13aa012
Updates
ethanwharris e9cedb0
Updates
ethanwharris fe95a77
Updates
ethanwharris 54f6a88
Graph embedding task implemented, modulo corrections
667140a
Error corrections
db0b599
Merge branch 'master' into task_a_thon
bbdad91
Updates
ethanwharris d5deb38
Update docs
ethanwharris f634e9f
Update docs
ethanwharris 428f313
Update docs
ethanwharris 435cc95
fix tests
ethanwharris b54e543
fix tests
ethanwharris 4453818
Add API reference
ethanwharris b4877b1
Try fix
ethanwharris aceef22
Merge branch 'master' into task_a_thon
c6e3b84
Included Networkx as requirement for graph library
8fe2813
Try fix
ethanwharris 113b6d0
Try fix
ethanwharris bba0395
Merge branch 'master' into task_a_thon
d6dec3d
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])…
d8d26ab
Update flash/core/data/auto_dataset.py
ethanwharris 7b2734f
Update docstring
ethanwharris a09254f
Correction of minor errors
62ff1b3
Updating docs
97b5e07
Merge branch 'master' into task_a_thon
8d9db4c
Update graph_embedding docs
2021a03
Minor tweaks
b064be3
Merge remote-tracking branch 'upstream/master' into task_a_thon
31354fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b1fc948
Correct documentation
a1e5bb7
Creating a head suited for PyG backbones
bc06e43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 335d30c
Update the head of the embedding model
56598f7
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
7de0a50
Update graph_classification example
628ef9b
Update backbones to match how they will work in Pytorch Geometric
bed29f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e49c290
Update CHANGELOG.md
95abca4
Merge remote-tracking branch 'upstream/master' into task_a_thon
0903585
Update graph requirements
cf45529
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a872b75
Update test_model.py
88812b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 360250b
Update data_module.py
41db022
Update data_module.py
077b34b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5d3e1b4
Update test_model.py
70d0570
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
9a04c2f
_TORCH_GEOMETRIC_AVAILABLE to _GRAPH_AVAILABLE
7d63be0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ac4616a
Adding num_features to uses of backbones
53395e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f902c5f
Update graph_classification.py
c9c3c0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 093a266
Update test_model.py
c1b5cab
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
f2962b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 89a8e2d
Backbone kwargs default changed from None to {}
56e9c2e
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
9012d61
error correction in graph/embedding/test_model.py: num_classes -> emb…
76c6d55
Pretrained option not implemented for backbones error
8be004b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2bdf6a6
Update model.py
9ecff7e
Small error correction
cc1ee34
Merge branch 'master' into task_a_thon
ethanwharris 63577b9
Updates
ethanwharris 62802a0
models adapted to torch geometric basic models
7590175
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
92d6b3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 3428fd6
Merge remote-tracking branch 'upstream/master' into task_a_thon
5b72219
Merge branch 'master' into task_a_thon
c467e60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ff5a592
Updating minor corrections
de9c45a
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
c878c11
Eliminating needless stuff
667f5d0
Changed descriptions
5a7060d
Update CHANGELOG.md
ce168a7
Update flash/graph/embedding/__init__.py
d9d9298
Update flash_examples/graph_classification.py
6816d76
Update docs/source/reference/graph_embedding.rst
caad522
Updates based on ethan comments
fe92e11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 472e00d
Importing partial
4b1923f
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
1c76bd3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2de9f26
Update backbones.py
de9577a
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
4b8f4b5
Update backbones.py
11bd028
Update model.py
c40ad61
Eliminating pretrained option
3453e2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 85dc93e
Update backbones.py
e4de58d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6d631e6
Update backbones.py
e74f17b
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
b886e96
Update backbones.py
38196c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 90cdfa8
Update backbones.py
ee7cb69
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
14cd39c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4199aab
Update backbones.py
9a2d046
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
4858083
Merge branch 'master' into task_a_thon
80e4e3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 65e499a
Update backbones.py
54f54d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 18d78d7
Update backbones.py
a69e307
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
ac9bb8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6658640
Update backbones.py
11b07f9
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
89e7040
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 18c69b6
Update backbones.py
e980c25
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
8f9217d
Update backbones.py
0e9968a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a440d47
Update backbones.py
e6e84df
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
58b9348
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] fdbc4db
Minor correction in model.py
c708d04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 994dd30
self.backbone(x.x, x.edge_index)
57571f8
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
8b9d733
Update model.py
a10749c
Update model.py
21e8078
update
tchaton 3aa0e95
Merge branch 'master' into task_a_thon
6bab7b1
Merge branch 'master' into task_a_thon
ethanwharris d3148be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 80bb1e8
Updates
ethanwharris 86ff38b
Merge branch 'task_a_thon' of https://github.com/PabloAMC/lightning-f…
ethanwharris b0429ca
Update CHANGELOG.md
ethanwharris dc27f8a
Docs fixes
ethanwharris 00b7159
Updates
ethanwharris da8fefb
Docs
ethanwharris 2738e75
Tests
ethanwharris cdb8217
Pre-commit
ethanwharris 3a67f66
Merge branch 'master' into task_a_thon
ethanwharris 47dae8a
Update requirements
ethanwharris 9ca6084
Formatting
ethanwharris acf92e2
Fix reqs.
ethanwharris bff6228
Merge branch 'master' into task_a_thon
ethanwharris 8b490b0
Fixes
ethanwharris 54d5de8
Multiple pooling functions
ethanwharris c59c766
Fixes
ethanwharris 8ee6262
Fixes
ethanwharris dfa14e6
Try fix
ethanwharris f357f62
Speed up CI
ethanwharris File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
.. _graph_embedder: | ||
|
||
############## | ||
Graph Embedder | ||
############## | ||
|
||
******** | ||
The Task | ||
******** | ||
This task consists of creating an embedding of a graph. That is, a vector of features which can be used for a downstream task. | ||
The :class:`~flash.graph.classification.model.GraphEmbedder` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/rusty1s/pytorch_geometric>`_. | ||
|
||
------ | ||
|
||
******* | ||
Example | ||
******* | ||
|
||
Let's look at generating embeddings of graphs from the KKI data set from `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_. | ||
|
||
We start by creating the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/tu_dataset.html#TUDataset>`. | ||
Next, we load a trained :class:`~flash.graph.classification.model.GraphEmbedder` (from a previously trained :class:`~flash.graph.classification.model.GraphClassifier`). | ||
Finally, we save the model. | ||
Here's the full example: | ||
|
||
.. literalinclude:: ../../../flash_examples/graph_embedder.py | ||
:language: python | ||
:lines: 14 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401 | ||
from flash.graph.embedding import GraphEmbedder # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from functools import partial | ||
|
||
from flash.core.registry import FlashRegistry | ||
from flash.core.utilities.imports import _GRAPH_AVAILABLE | ||
from flash.core.utilities.providers import _PYTORCH_GEOMETRIC | ||
|
||
if _GRAPH_AVAILABLE: | ||
from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE | ||
|
||
MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think we could integrate GraphGym directly there ? |
||
else: | ||
MODELS = {} | ||
|
||
GRAPH_BACKBONES = FlashRegistry("backbones") | ||
|
||
|
||
def _load_graph_backbone( | ||
model_name: str, | ||
in_channels: int, | ||
hidden_channels: int = 512, | ||
num_layers: int = 4, | ||
): | ||
model = MODELS[model_name] | ||
return model(in_channels, hidden_channels, num_layers) | ||
|
||
|
||
for model_name in MODELS.keys(): | ||
GRAPH_BACKBONES(name=model_name, providers=_PYTORCH_GEOMETRIC)(partial(_load_graph_backbone, model_name)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from flash.graph.classification.data import GraphClassificationData # noqa: E402 | ||
from flash.graph.embedding.model import GraphEmbedder # noqa: E402 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we automate this and get all their models ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's not possible to do cleanly at the moment as the models package also contains many other things that wouldn't work in the GraphClassifier