Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
wistuba committed May 4, 2023
1 parent 85e4451 commit 18dcf20
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 32 deletions.
2 changes: 1 addition & 1 deletion doc/benchmarking/custom_benchmarks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ with your data module as follows.
.. code-block:: python
def data_module_fn(
data_path: Union[Path, str], chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]]
data_path: str, chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]]
):
data_module = CustomDataModule(data_path=data_path, seed=seed)
return ClassIncrementalScenario(
Expand Down
8 changes: 8 additions & 0 deletions doc/benchmarking/renate_benchmarks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ The full list of models and model names including a short description is provide
* - `~renate.benchmark.models.vision_transformer.VisionTransformerH14`
- Huge `Vision Transformer <https://arxiv.org/abs/2010.11929>`_ architecture for images of size 224x224 with patch size 14.
- * ``num_outputs``: Output dimensionality, for classification the number of classes.
* - `~renate.benchmark.models.transformer.HuggingFaceSequenceClassificationTransformer`
- Wrapper around Hugging Face transformers.
- * ``pretrained_model_name``: Hugging Face `transformer ID <https://huggingface.co/models>`__.
* ``num_outputs``: The number of classes.


.. _benchmarking-renate-benchmarks-datasets:
Expand Down Expand Up @@ -133,6 +137,10 @@ The following table contains the list of supported datasets.
- Image Classification
- 60k train, 10k test, 10 classes, image shape 28x28x1
- Li Deng: The MNIST Database of Handwritten Digit Images for Machine Learning Research. IEEE Signal Processing Magazine. 2012.
* - hfd-{dataset_name}
- multiple
- Any `Hugging Face dataset <https://huggingface.co/datasets>`__ can be used. Just prepend the prefix ``hfd-``, e.g., ``hfd-rotten_tomatoes``. Select input and target columns via ``config_space``, e.g., add ``"input_column": "text", "target_column": "label"`` for the `rotten_tomatoes <https://huggingface.co/datasets/rotten_tomatoes>`__ example.
- Please refer to `the official documentation <https://huggingface.co/datasets>`__.

.. _benchmarking-renate-benchmarks-scenarios:

Expand Down
18 changes: 9 additions & 9 deletions doc/getting_started/how_to_renate_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Its signature is

.. code-block:: python
def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
A :py:class:`~renate.models.renate_module.RenateModule` is a :code:`torch.nn.Module` with some
additional functionality relevant to continual learning.
Expand All @@ -29,7 +29,7 @@ method, which automatically handles model hyperparameters.

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Example
:lines: 13-39
:lines: 14-39

If you are using a torch model with **no or fixed hyperparameters**, you can use
:py:class:`~renate.models.renate_module.RenateWrapper`.
Expand All @@ -40,7 +40,7 @@ method, but simply reinstantiate your model and call :code:`load_state_dict`.
.. code-block:: python
:caption: Example
def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
my_torch_model = torch.nn.Linear(28 * 28, 10) # Instantiate your torch model.
model = RenateWrapper(my_torch_model)
if model_state_url is not None:
Expand All @@ -58,7 +58,7 @@ Its signature is

.. code-block:: python
def data_module_fn(data_path: Union[Path, str], seed: int = defaults.SEED) -> RenateDataModule:
def data_module_fn(data_path: str, seed: int = defaults.SEED) -> RenateDataModule:
:py:class:`~renate.data.data_module.RenateDataModule` provides a structured interface to
download, set up, and access train/val/test datasets.
Expand All @@ -67,7 +67,7 @@ such as data subsampling or splitting.

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Example
:lines: 42-72
:lines: 43-69

Transforms
==========
Expand Down Expand Up @@ -143,11 +143,11 @@ Let us assume we already have a config file in which we implemented a simple lin

.. code-block:: python
def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
my_torch_model = torch.nn.Linear(28 * 28, 10)
model = RenateWrapper(my_torch_model)
if model_state_url is not None:
state_dict = torch.load(str(model_state_url))
state_dict = torch.load(model_state_url)
model.load_state_dict(state_dict)
return model
Expand All @@ -156,11 +156,11 @@ The natural change would be to change it to something like

.. code-block:: python
def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[str] = None) -> RenateModule:
my_torch_model = torch.nn.Linear(num_inputs, num_outputs)
model = RenateWrapper(my_torch_model)
if model_state_url is not None:
state_dict = torch.load(str(model_state_url))
state_dict = torch.load(model_state_url)
model.load_state_dict(state_dict)
return model
Expand Down
10 changes: 5 additions & 5 deletions doc/getting_started/output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ In the following, we refer with :code:`model_fn` to the function defined by the
Output Saved Locally
~~~~~~~~~~~~~~~~~~~~

If :code:`next_state_url` is a path to a local folder, loading the updated model can be done as follows:
If :code:`output_state_url` is a path to a local folder, loading the updated model can be done as follows:

.. code-block:: python
from renate.defaults import current_state_folder, model_file
from renate.defaults import input_state_folder, model_file
my_model = model_fn(model_file(current_state_folder(next_state_url)))
my_model = model_fn(model_file(input_state_folder(output_state_url)))
Output Saved on S3
~~~~~~~~~~~~~~~~~~
Expand All @@ -46,9 +46,9 @@ If the Renate output was saved on S3, the model checkpoint :code:`model.ckpt` ca

.. code-block:: python
from renate.defaults import current_state_folder, model_file
from renate.defaults import input_state_folder, model_file
print(model_file(current_state_folder(next_state_url)))
print(model_file(input_state_folder(output_state_url)))
and then loaded via

Expand Down
7 changes: 3 additions & 4 deletions examples/simple_classifier_cifar10/renate_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Callable, Optional

import torch
from torchvision import transforms
Expand All @@ -13,12 +12,12 @@
from renate.models import RenateModule


def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
"""Returns a model instance."""
if model_state_url is None:
model = ResNet18CIFAR()
else:
state_dict = torch.load(str(model_state_url))
state_dict = torch.load(model_state_url)
model = ResNet18CIFAR.from_state_dict(state_dict)
return model

Expand Down
7 changes: 3 additions & 4 deletions examples/train_mlp_locally/renate_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Callable, Optional

import torch
from torchvision.transforms import transforms
Expand Down Expand Up @@ -34,14 +33,14 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) ->
return class_incremental_scenario


def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
"""Returns a model instance."""
if model_state_url is None:
model = MultiLayerPerceptron(
num_inputs=784, num_outputs=10, num_hidden_layers=2, hidden_size=128
)
else:
state_dict = torch.load(str(model_state_url))
state_dict = torch.load(model_state_url)
model = MultiLayerPerceptron.from_state_dict(state_dict)
return model

Expand Down
18 changes: 9 additions & 9 deletions src/renate/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,41 +120,41 @@ def current_timestamp() -> str:
return str(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))


def data_folder(working_directory: str):
def data_folder(working_directory: str) -> str:
return os.path.join(working_directory, "data")


def input_state_folder(working_directory: str):
def input_state_folder(working_directory: str) -> str:
return os.path.join(working_directory, "input_state")


def output_state_folder(working_directory: str):
def output_state_folder(working_directory: str) -> str:
return os.path.join(working_directory, "output_state")


def logs_folder(working_directory: str):
def logs_folder(working_directory: str) -> str:
return os.path.join(working_directory, "logs")


def model_file(state_folder: str):
def model_file(state_folder: str) -> str:
return os.path.join(state_folder, "model.ckpt")


LEARNER_CHECKPOINT_NAME = "learner.ckpt"
AVALANCHE_CHECKPOINT_NAME = "avalanche.ckpt"


def learner_state_file(state_folder: str):
def learner_state_file(state_folder: str) -> str:
return os.path.join(state_folder, LEARNER_CHECKPOINT_NAME)


def avalanche_state_file(state_folder: str):
def avalanche_state_file(state_folder: str) -> str:
return os.path.join(state_folder, AVALANCHE_CHECKPOINT_NAME)


def metric_summary_file(logs_folder: str, special_str: str = ""):
def metric_summary_file(logs_folder: str, special_str: str = "") -> str:
return os.path.join(logs_folder, f"metrics_summary{special_str}.csv")


def hpo_file(state_folder: str):
def hpo_file(state_folder: str) -> str:
return os.path.join(state_folder, "hpo.csv")

0 comments on commit 18dcf20

Please sign in to comment.