diff --git a/doc/benchmarking/custom_benchmarks.rst b/doc/benchmarking/custom_benchmarks.rst index 6076333c..65aa2dbc 100644 --- a/doc/benchmarking/custom_benchmarks.rst +++ b/doc/benchmarking/custom_benchmarks.rst @@ -40,7 +40,7 @@ with your data module as follows. .. code-block:: python def data_module_fn( - data_path: Union[Path, str], chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]] + data_path: str, chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]] ): data_module = CustomDataModule(data_path=data_path, seed=seed) return ClassIncrementalScenario( diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index 036e9cdf..3f128736 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -93,6 +93,10 @@ The full list of models and model names including a short description is provide * - `~renate.benchmark.models.vision_transformer.VisionTransformerH14` - Huge `Vision Transformer `_ architecture for images of size 224x224 with patch size 14. - * ``num_outputs``: Output dimensionality, for classification the number of classes. + * - `~renate.benchmark.models.transformer.HuggingFaceSequenceClassificationTransformer` + - Wrapper around Hugging Face transformers. + - * ``pretrained_model_name``: Hugging Face `transformer ID `__. + * ``num_outputs``: The number of classes. .. _benchmarking-renate-benchmarks-datasets: @@ -133,6 +137,10 @@ The following table contains the list of supported datasets. - Image Classification - 60k train, 10k test, 10 classes, image shape 28x28x1 - Li Deng: The MNIST Database of Handwritten Digit Images for Machine Learning Research. IEEE Signal Processing Magazine. 2012. + * - hfd-{dataset_name} + - multiple + - Any `Hugging Face dataset `__ can be used. Just prepend the prefix ``hfd-``, e.g., ``hfd-rotten_tomatoes``. Select input and target columns via ``config_space``, e.g., add ``"input_column": "text", "target_column": "label"`` for the `rotten_tomatoes `__ example. + - Please refer to `the official documentation `__. .. _benchmarking-renate-benchmarks-scenarios: diff --git a/doc/getting_started/how_to_renate_config.rst b/doc/getting_started/how_to_renate_config.rst index e246d5fa..3320cb12 100644 --- a/doc/getting_started/how_to_renate_config.rst +++ b/doc/getting_started/how_to_renate_config.rst @@ -16,7 +16,7 @@ Its signature is .. code-block:: python - def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(model_state_url: Optional[str] = None) -> RenateModule: A :py:class:`~renate.models.renate_module.RenateModule` is a :code:`torch.nn.Module` with some additional functionality relevant to continual learning. @@ -29,7 +29,7 @@ method, which automatically handles model hyperparameters. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 13-39 + :lines: 14-39 If you are using a torch model with **no or fixed hyperparameters**, you can use :py:class:`~renate.models.renate_module.RenateWrapper`. @@ -40,7 +40,7 @@ method, but simply reinstantiate your model and call :code:`load_state_dict`. .. code-block:: python :caption: Example - def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(model_state_url: Optional[str] = None) -> RenateModule: my_torch_model = torch.nn.Linear(28 * 28, 10) # Instantiate your torch model. model = RenateWrapper(my_torch_model) if model_state_url is not None: @@ -58,7 +58,7 @@ Its signature is .. code-block:: python - def data_module_fn(data_path: Union[Path, str], seed: int = defaults.SEED) -> RenateDataModule: + def data_module_fn(data_path: str, seed: int = defaults.SEED) -> RenateDataModule: :py:class:`~renate.data.data_module.RenateDataModule` provides a structured interface to download, set up, and access train/val/test datasets. @@ -67,7 +67,7 @@ such as data subsampling or splitting. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 42-72 + :lines: 43-69 Transforms ========== @@ -143,11 +143,11 @@ Let us assume we already have a config file in which we implemented a simple lin .. code-block:: python - def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(model_state_url: Optional[str] = None) -> RenateModule: my_torch_model = torch.nn.Linear(28 * 28, 10) model = RenateWrapper(my_torch_model) if model_state_url is not None: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model.load_state_dict(state_dict) return model @@ -156,11 +156,11 @@ The natural change would be to change it to something like .. code-block:: python - def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[str] = None) -> RenateModule: my_torch_model = torch.nn.Linear(num_inputs, num_outputs) model = RenateWrapper(my_torch_model) if model_state_url is not None: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model.load_state_dict(state_dict) return model diff --git a/doc/getting_started/output.rst b/doc/getting_started/output.rst index 1c722bd3..08d36061 100644 --- a/doc/getting_started/output.rst +++ b/doc/getting_started/output.rst @@ -31,13 +31,13 @@ In the following, we refer with :code:`model_fn` to the function defined by the Output Saved Locally ~~~~~~~~~~~~~~~~~~~~ -If :code:`next_state_url` is a path to a local folder, loading the updated model can be done as follows: +If :code:`output_state_url` is a path to a local folder, loading the updated model can be done as follows: .. code-block:: python - from renate.defaults import current_state_folder, model_file + from renate.defaults import input_state_folder, model_file - my_model = model_fn(model_file(current_state_folder(next_state_url))) + my_model = model_fn(model_file(input_state_folder(output_state_url))) Output Saved on S3 ~~~~~~~~~~~~~~~~~~ @@ -46,9 +46,9 @@ If the Renate output was saved on S3, the model checkpoint :code:`model.ckpt` ca .. code-block:: python - from renate.defaults import current_state_folder, model_file + from renate.defaults import input_state_folder, model_file - print(model_file(current_state_folder(next_state_url))) + print(model_file(input_state_folder(output_state_url))) and then loaded via diff --git a/examples/simple_classifier_cifar10/renate_config.py b/examples/simple_classifier_cifar10/renate_config.py index 6f70102b..563d46e7 100644 --- a/examples/simple_classifier_cifar10/renate_config.py +++ b/examples/simple_classifier_cifar10/renate_config.py @@ -1,7 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, Optional import torch from torchvision import transforms @@ -13,12 +12,12 @@ from renate.models import RenateModule -def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: +def model_fn(model_state_url: Optional[str] = None) -> RenateModule: """Returns a model instance.""" if model_state_url is None: model = ResNet18CIFAR() else: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model = ResNet18CIFAR.from_state_dict(state_dict) return model diff --git a/examples/train_mlp_locally/renate_config.py b/examples/train_mlp_locally/renate_config.py index 86cfdc94..779f4ef3 100644 --- a/examples/train_mlp_locally/renate_config.py +++ b/examples/train_mlp_locally/renate_config.py @@ -1,7 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, Optional import torch from torchvision.transforms import transforms @@ -34,14 +33,14 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> return class_incremental_scenario -def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: +def model_fn(model_state_url: Optional[str] = None) -> RenateModule: """Returns a model instance.""" if model_state_url is None: model = MultiLayerPerceptron( num_inputs=784, num_outputs=10, num_hidden_layers=2, hidden_size=128 ) else: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model = MultiLayerPerceptron.from_state_dict(state_dict) return model diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 20fd8a84..3131142d 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -120,23 +120,23 @@ def current_timestamp() -> str: return str(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")) -def data_folder(working_directory: str): +def data_folder(working_directory: str) -> str: return os.path.join(working_directory, "data") -def input_state_folder(working_directory: str): +def input_state_folder(working_directory: str) -> str: return os.path.join(working_directory, "input_state") -def output_state_folder(working_directory: str): +def output_state_folder(working_directory: str) -> str: return os.path.join(working_directory, "output_state") -def logs_folder(working_directory: str): +def logs_folder(working_directory: str) -> str: return os.path.join(working_directory, "logs") -def model_file(state_folder: str): +def model_file(state_folder: str) -> str: return os.path.join(state_folder, "model.ckpt") @@ -144,17 +144,17 @@ def model_file(state_folder: str): AVALANCHE_CHECKPOINT_NAME = "avalanche.ckpt" -def learner_state_file(state_folder: str): +def learner_state_file(state_folder: str) -> str: return os.path.join(state_folder, LEARNER_CHECKPOINT_NAME) -def avalanche_state_file(state_folder: str): +def avalanche_state_file(state_folder: str) -> str: return os.path.join(state_folder, AVALANCHE_CHECKPOINT_NAME) -def metric_summary_file(logs_folder: str, special_str: str = ""): +def metric_summary_file(logs_folder: str, special_str: str = "") -> str: return os.path.join(logs_folder, f"metrics_summary{special_str}.csv") -def hpo_file(state_folder: str): +def hpo_file(state_folder: str) -> str: return os.path.join(state_folder, "hpo.csv")