Docs: Add ModelHub documentation #6591

Merged: 26 commits, Feb 8, 2023
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `cat` aggregation type to the `HeteroConv` class so that features can be concatenated during grouping ([#6634](https://github.com/pyg-team/pytorch_geometric/pull/6634))
- Added `torch.compile` support and benchmark study ([#6610](https://github.com/pyg-team/pytorch_geometric/pull/6610))
- Added the `AntiSymmetricConv` layer ([#6577](https://github.com/pyg-team/pytorch_geometric/pull/6577))
- Added a mixin for Huggingface model hub integration ([#5930](https://github.com/pyg-team/pytorch_geometric/pull/5930))
- Added a mixin for Huggingface model hub integration ([#5930](https://github.com/pyg-team/pytorch_geometric/pull/5930), [#6591](https://github.com/pyg-team/pytorch_geometric/pull/6591))
- Added support for accelerated GNN layers in `nn.conv.cugraph` via `cugraph-ops` ([#6278](https://github.com/pyg-team/pytorch_geometric/pull/6278), [#6388](https://github.com/pyg-team/pytorch_geometric/pull/6388), [#6412](https://github.com/pyg-team/pytorch_geometric/pull/6412))
- Added accelerated `index_sort` function from `pyg-lib` for faster sorting ([#6554](https://github.com/pyg-team/pytorch_geometric/pull/6554))
- Fix incorrect device in `EquilibriumAggregration` ([#6560](https://github.com/pyg-team/pytorch_geometric/pull/6560))
docs/source/modules/nn.rst (10 changes: 7 additions & 3 deletions)
@@ -289,11 +289,15 @@ DataParallel Layers

.. automodule:: torch_geometric.nn.data_parallel
   :members:
   :undoc-members:
   :exclude-members: training

Model Hub
---------

.. automodule:: torch_geometric.nn.model_hub
   :members:

Model Summary
-------------

.. autofunction:: torch_geometric.nn.summary.summary
.. automodule:: torch_geometric.nn.summary
   :members:
torch_geometric/nn/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -5,6 +5,7 @@
from .to_hetero_with_bases_transformer import to_hetero_with_bases
from .to_fixed_size_transformer import to_fixed_size
from .encoding import PositionalEncoding
from .model_hub import PyGModelHubMixin
from .summary import summary

from .aggr import * # noqa
@@ -26,5 +27,6 @@
    'to_hetero_with_bases',
    'to_fixed_size',
    'PositionalEncoding',
    'PyGModelHubMixin',
    'summary',
]
torch_geometric/nn/model_hub.py (190 changes: 83 additions & 107 deletions)
@@ -17,61 +17,55 @@


class PyGModelHubMixin(ModelHubMixin):
r"""
Mixin for saving and loading models to
`Huggingface Model Hub <https://huggingface.co/docs/hub/index>`.

Sample code for saving a :obj:`Node2Vec` model to the model hub:
r"""A mixin for saving and loading models to the
`Huggingface Model Hub <https://huggingface.co/docs/hub/index>`_.

.. code-block:: python

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec
from torch_geometric.nn.model_hub import PyGModelHubMixin

# Define your class with the mixin:
class N2V(Node2Vec, PyGModelHubMixin):
    def __init__(self, model_name, dataset_name, model_kwargs):
        Node2Vec.__init__(self, **model_kwargs)
        PyGModelHubMixin.__init__(self, model_name,
                                  dataset_name, model_kwargs)

# instantiate your model:
n2v = N2V(model_name='node2vec',
          dataset_name='Cora', model_kwargs=dict(
              edge_index=data.edge_index, embedding_dim=128,
              walk_length=20, context_size=10, walks_per_node=10,
              num_negative_samples=1, p=1, q=1, sparse=True))

# train model
...

# push to Huggingface:
repo_id = ...  # your repo id
n2v.save_pretrained(local_file_path, push_to_hub=True,
                    repo_id=repo_id)

# Load the model for inference:
# The required arguments are the repo id/local folder, and any model
# initialisation arguments that are not native python types (e.g
# Node2Vec requires the edge_index argument which is a tensor--
# this is not saved in model hub)

model = N2V.from_pretrained(repo_id,
                            model_name='node2vec', dataset_name='Cora',
                            edge_index=data.edge_index)

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vec, PyGModelHubMixin

# Load the 'Cora' dataset (the root path here is illustrative):
dataset = Planetoid('/tmp/Cora', name='Cora')
data = dataset[0]

# Define your class with the mixin:
class N2V(Node2Vec, PyGModelHubMixin):
    def __init__(self, model_name, dataset_name, model_kwargs):
        Node2Vec.__init__(self, **model_kwargs)
        PyGModelHubMixin.__init__(self, model_name,
                                  dataset_name, model_kwargs)

# Instantiate your model:
n2v = N2V(model_name='node2vec',
          dataset_name='Cora', model_kwargs=dict(
              edge_index=data.edge_index, embedding_dim=128,
              walk_length=20, context_size=10, walks_per_node=10,
              num_negative_samples=1, p=1, q=1, sparse=True))

# Train the model:
...

# Push to the HuggingFace hub:
repo_id = ...  # your repo id
n2v.save_pretrained(
    local_file_path,
    push_to_hub=True,
    repo_id=repo_id,
)

..note::
At the moment the model card is fairly basic. Override the
:obj:`construct_model_card` method if you want a more detailed
model card
# Load the model for inference:
# The required arguments are the repo id/local folder, and any model
# initialisation arguments that are not native Python types (e.g.,
# Node2Vec requires the edge_index argument, which is not stored in
# the model hub).
model = N2V.from_pretrained(
    repo_id,
    model_name='node2vec',
    dataset_name='Cora',
    edge_index=data.edge_index,
)

Args:
model_name (str): Name of the model shown on the model card
on hugging face hub.
model_name (str): Name of the model.
dataset_name (str): Name of the dataset the model was trained against.
model_kwargs (Dict): Arguments to initialise the Pyg model.
model_kwargs (Dict[str, Any]): The arguments to initialise the model.
"""
def __init__(self, model_name: str, dataset_name: str, model_kwargs: Dict):
    ModelHubMixin.__init__(self)
@@ -100,32 +94,25 @@ def construct_model_card(self, model_name: str, dataset_name: str) -> Any:
    return card

def _save_pretrained(self, save_directory: Union[Path, str]):
    r"""
    Args:
        save_directory (Path or str): The local filepath to save the
            model state dict to.
    """
    path = os.path.join(save_directory, MODEL_WEIGHTS_NAME)
    # Unwrap a wrapped model (e.g., `torch.nn.DataParallel` stores the
    # underlying model in `self.module`) before saving its state dict:
    model_to_save = self.module if hasattr(self, "module") else self
    torch.save(model_to_save.state_dict(), path)

def save_pretrained(self, save_directory: Union[str, Path],
                    push_to_hub: bool = False,
                    repo_id: Optional[str] = None, **kwargs):
r"""
Save a trained model to a local directory or to huggingface model hub.
r"""Save a trained model to a local directory or to the HuggingFace
model hub.

Args:
save_directory (str, Path): The directory where weights are saved,
to a file called :obj:`"model.pth"`.
push_to_hub(bool): If :obj:`True`, push the model to the
model hub. (default: :obj:`False`)
save_directory (str): The directory where weights are saved.
push_to_hub (bool, optional): If :obj:`True`, push the model to the
HuggingFace model hub. (default: :obj:`False`)
repo_id (str, optional): The repository name in the hub.
If not provided will default to the name of
:obj:`save_directory` in your namespace.
(default: :obj:`None`)
**kwargs: Additional keyword arguments passed along to
:obj:`huggingface_hub.ModelHubMixin.save_pretrained`.
:obj:`save_directory` in your namespace. (default: :obj:`None`)
**kwargs: Additional keyword arguments passed to
:meth:`huggingface_hub.ModelHubMixin.save_pretrained`.
"""
    config = self.model_config
    # Due to the way the HuggingFace hub handles the loading/saving of models,
@@ -158,7 +145,6 @@ def _from_pretrained(
    strict=False,
    **model_kwargs,
):
r"""Load trained model."""
map_location = torch.device(map_location)

if os.path.isdir(model_id):
@@ -200,57 +186,47 @@ def from_pretrained(
    cache_dir: Optional[str] = None,
    local_files_only: bool = False,
    **model_kwargs,
):
r"""
Download and instantiate a model from the Hugging Face Hub.
) -> Any:
r"""Downloads and instantiates a model from the HuggingFace hub.

Args:
pretrained_model_name_or_path (str, Path):
Can be either:
- A string, the `model id` of a pretrained model
hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level,
like `bert-base-uncased`, or namespaced under a
user or organization name, like
`dbmdz/bert-base-german-cased`.
- You can add `revision` by appending `@` at the end
of model_id simply like this:
`dbmdz/bert-base-german-cased@main` Revision is
the specific model version to use. It can be a
branch name, a tag name, or a commit id, since we
use a git-based system for storing models and
other artifacts on huggingface.co, so `revision`
can be any identifier allowed by git.
- A path to a `directory` containing model weights
saved using
[`~transformers.PreTrainedModel.save_pretrained`],
e.g., `./my_model_directory/`.
- `None` if you are both providing the configuration
and state dictionary (resp. with keyword arguments
:obj:`config` and :obj:`state_dict`).
force_download (bool): Whether to force the (re-)download of the
model weights and configuration files, overriding the cached
versions if they exist. (default: :obj:`False`)
resume_download (bool): Whether to delete incompletely received
files. Will attempt to resume the download if such a
file exists.(default: :obj:`False`)
pretrained_model_name_or_path (str): Can be either:

- The :obj:`model_id` of a pretrained model hosted inside the
HuggingFace hub.

- You can add a :obj:`revision` by appending :obj:`@` at the
end of :obj:`model_id` to load a specific model version.

- A path to a directory containing the saved model weights.

- :obj:`None` if you are both providing the configuration
:obj:`config` and state dictionary :obj:`state_dict`.

force_download (bool, optional): Whether to force the
(re-)download of the model weights and configuration files,
overriding the cached versions if they exist.
(default: :obj:`False`)
resume_download (bool, optional): Whether to delete incompletely
received files. Will attempt to resume the download if such a
file exists. (default: :obj:`False`)
proxies (Dict[str, str], optional): A dictionary of proxy servers
to use by protocol or endpoint,
e.g.,`{'http': 'foo.bar:3128', 'http://host': 'foo.bar:4012'}`.
to use by protocol or endpoint, *e.g.*,
:obj:`{'http': 'foo.bar:3128', 'http://host': 'foo.bar:4012'}`.
The proxies are used on each request. (default: :obj:`None`)
token (str, bool, optional): The token to use as HTTP bearer
authorization for remote files. If `True`, will use the token
generated when running `transformers-cli login` (stored in
`~/.huggingface`). It is **required** if you
token (str or bool, optional): The token to use as HTTP bearer
authorization for remote files. If set to :obj:`True`, will use
the token generated when running :obj:`transformers-cli login`
(stored in :obj:`~/.huggingface`). It is **required** if you
want to use a private model. (default: :obj:`None`)
cache_dir (str, Path, optional): Path to a directory in which a
cache_dir (str, optional): The path to a directory in which a
downloaded model configuration should be cached if the
standard cache should not be used. (default: :obj:`None`)
local_files_only(bool): Whether to only look at local files
(i.e., do not try to download the model).
local_files_only (bool, optional): Whether to only look at local
files, *i.e.* do not try to download the model.
(default: :obj:`False`)
**model_kwargs: Keyword arguments passed along to
model during initialization. (default: :obj:`None`)
**model_kwargs: Additional keyword arguments passed to the
model during initialization.
"""
    return super().from_pretrained(
        pretrained_model_name_or_path,
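Taken together, the new `save_pretrained`/`from_pretrained` docstrings describe a save/load round trip. A condensed sketch of the local-directory variant (no hub upload), reusing the `N2V` class and `data` object from the docstring example above; the directory name is illustrative:

```python
# Save the trained model to a local directory only
# (push_to_hub defaults to False):
n2v.save_pretrained('n2v_cora')

# Reload it later. `edge_index` is a tensor and is not serialized to
# the hub/config, so it must be passed again at load time:
model = N2V.from_pretrained(
    'n2v_cora',
    model_name='node2vec',
    dataset_name='Cora',
    edge_index=data.edge_index,
)
```

To publish instead of saving locally, pass `push_to_hub=True` together with a `repo_id`, as shown in the class docstring.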