From c54887c0e30dab52f8e2a38e55c5e9f21e347779 Mon Sep 17 00:00:00 2001
From: Prabhu Teja
Date: Mon, 4 Dec 2023 18:08:58 +0100
Subject: [PATCH] S-Prompts for ViT and Text Transformers (#388)

---
 doc/benchmarking/renate_benchmarks.rst | 7 +
 doc/getting_started/supported_algorithms.rst | 5 +-
 .../benchmark/datasets/vision_datasets.py | 2 +
 src/renate/benchmark/experiment_config.py | 16 ++
 src/renate/benchmark/models/__init__.py | 5 +-
 src/renate/benchmark/models/spromptmodel.py | 222 ++++++++++++++++++
 src/renate/cli/parsing_functions.py | 8 +
 src/renate/defaults.py | 3 +
 src/renate/models/layers/shared_linear.py | 50 ++++
 .../models/task_identification_strategies.py | 97 ++++++++
 src/renate/updaters/experimental/speft.py | 198 ++++++++++++++++
 test/renate/benchmark/models/test_sprompt.py | 20 ++
 test/renate/models/test_shared_linear.py | 17 ++
 .../models/test_task_identification_strat.py | 27 +++
 14 files changed, 675 insertions(+), 2 deletions(-)
 create mode 100644 src/renate/benchmark/models/spromptmodel.py
 create mode 100644 src/renate/models/layers/shared_linear.py
 create mode 100644 src/renate/models/task_identification_strategies.py
 create mode 100644 src/renate/updaters/experimental/speft.py
 create mode 100644 test/renate/benchmark/models/test_sprompt.py
 create mode 100644 test/renate/models/test_shared_linear.py
 create mode 100644 test/renate/models/test_task_identification_strat.py

diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst
index 3bb602e6..21579398 100644
--- a/doc/benchmarking/renate_benchmarks.rst
+++ b/doc/benchmarking/renate_benchmarks.rst
@@ -105,6 +105,13 @@ The full list of models and model names including a short description is provide
        * ``pool_selection_size``: Number of prompts to select for each input from the pool.
        * ``prompt_size``: Number of input tokens each prompt is equivalent to.
        * ``prompt_key_dim``: Dimensionality of the features used for prompt matching.
+   * - `~renate.benchmark.models.spromptmodel.SPromptTransformer`
+     - `S-Prompt Transformer `_.
+     - * ``pretrained_model_name_or_path``: Hugging Face `transformer ID `__.
+       * ``num_outputs``: The number of classes.
+       * ``prompt_size``: Number of input tokens each prompt is equivalent to.
+       * ``clusters_per_task``: Number of clusters for K-Means in task identification.
+       * ``per_task_classifier``: Flag to use an individual classifier per task instead of a shared one.

.. _benchmarking-renate-benchmarks-datasets:

diff --git a/doc/getting_started/supported_algorithms.rst b/doc/getting_started/supported_algorithms.rst
index 02c0ce98..9c89c7a9 100644
--- a/doc/getting_started/supported_algorithms.rst
+++ b/doc/getting_started/supported_algorithms.rst
@@ -41,7 +41,10 @@ using Renate (e.g., using :py:func:`~renate.training.training.run_training_job`;
      - A class that implements a Learning to Prompt method for ViTs. The method trains only the input prompts, which are sampled from a prompt pool in an input-dependent fashion.
    * - ``"LearningToPromptReplay"``
      - :py:class:`LearningToPromptLearner `
-     - A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER"
+     - A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER".
+   * - ``"S-Prompts"``
+     - :py:class:`SPeft `
+     - A class that (currently) implements the S-Prompts method for memory-free continual learning when used with the `SPromptTransformer` model. The method trains a set of input prompts in an update-dependent fashion.
* - ``"Avalanche-ER"`` - :py:class:`AvalancheReplayLearner ` - A wrapper which gives access to Experience Replay as implemented in the Avalanche library. This method is the equivalent to our Offline-ER. diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index 7b575c9b..c57031f9 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -367,6 +367,8 @@ def __init__( def prepare_data(self) -> None: """Download DomainNet dataset for given domain.""" file_name = f"{self.data_id}.zip" + # update dataset name: + self._dataset_name = self.data_id url = "http://csr.bu.edu/ftp/visda/2019/multi-source/" if self.data_id in ["clipart", "painting"]: url = os.path.join(url, "groundtruth") diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 9c712f56..00549b49 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -52,6 +52,8 @@ from renate.models import RenateModule from renate.models.prediction_strategies import ICaRLClassificationStrategy +from renate.benchmark.models.spromptmodel import SPromptTransformer + models = { "MultiLayerPerceptron": MultiLayerPerceptron, "ResNet18CIFAR": ResNet18CIFAR, @@ -68,6 +70,7 @@ "VisionTransformerH14": VisionTransformerH14, "HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer, "LearningToPromptTransformer": LearningToPromptTransformer, + "SPromptTransformer": SPromptTransformer, } @@ -81,6 +84,9 @@ def model_fn( hidden_size: Optional[Tuple[int]] = None, dataset_name: Optional[str] = None, pretrained_model_name_or_path: Optional[str] = None, + prompt_size: int = 10, + clusters_per_task: int = 5, + per_task_classifier: bool = True, ) -> RenateModule: """Returns a model instance.""" if model_name not in models: @@ -110,6 +116,16 @@ def model_fn( f"LearningToPromptTransformer, but model name specified is {model_name}." ) model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path + elif (updater is not None) and ("SPeft" in updater): + if not model_name.startswith("SPrompt"): + raise ValueError( + "SPrompt model updater is designed to work only with " + f"SPromptTransformer, but model name specified is {model_name}." 
diff --git a/src/renate/benchmark/models/__init__.py b/src/renate/benchmark/models/__init__.py
index edfdf6bb..f343432d 100644
--- a/src/renate/benchmark/models/__init__.py
+++ b/src/renate/benchmark/models/__init__.py
@@ -9,7 +9,8 @@
     ResNet50,
     ResNet50CIFAR,
 )
-from renate.benchmark.models.l2p import LearningToPromptTransformer
+from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptedTransformer
+from renate.benchmark.models.spromptmodel import SPromptTransformer
 from renate.benchmark.models.vision_transformer import (
     VisionTransformerB16,
     VisionTransformerB32,
@@ -28,6 +29,8 @@
     "ResNet50",
     "ResNet50CIFAR",
     "LearningToPromptTransformer",
+    "PromptedTransformer",
+    "SPromptTransformer",
     "VisionTransformerB16",
     "VisionTransformerB32",
     "VisionTransformerCIFAR",

diff --git a/src/renate/benchmark/models/spromptmodel.py b/src/renate/benchmark/models/spromptmodel.py
new file mode 100644
index 00000000..dac81b81
--- /dev/null
+++ b/src/renate/benchmark/models/spromptmodel.py
@@ -0,0 +1,222 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import logging
+import math
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from renate.models.layers.shared_linear import SharedMultipleLinear
+from renate.models.prediction_strategies import PredictionStrategy
+from renate.models.task_identification_strategies import TaskPrototypes
+
+from .l2p import PromptedTransformer
+from .base import RenateBenchmarkingModule
+
+logger = logging.getLogger(__name__)
+
+
+class PromptPool(nn.Module):
+    """Implements a pool of prompts to be used for S-Prompts.
+
+    Args:
+        prompt_size: Number of input tokens each per-update prompt is equivalent to.
+            Defaults to 10.
+        embedding_size: Hidden size of the transformer used. Defaults to 768.
+        current_update_id: Current update id. Used to initialize the number of prompts.
+            Defaults to 0.
+    """
+
+    def __init__(
+        self, prompt_size: int = 10, embedding_size: int = 768, current_update_id: int = 0
+    ) -> None:
+        super().__init__()
+        self._M = prompt_size
+        self._embedding_size = embedding_size
+        self._curr_task = current_update_id
+
+        self._pool = nn.ParameterDict()
+        for id in range(self._curr_task):
+            # This always needs to be initialized, as torch's state dict only restores values
+            # for an already existing Parameter.
+            self._pool[f"{id}"] = nn.Parameter(
+                torch.nn.init.kaiming_uniform_(
+                    torch.empty((self._M, self._embedding_size)), a=math.sqrt(5)
+                )
+            )
+
+        self._pool.requires_grad_(True)
+
+    def forward(self, id: int) -> torch.nn.Parameter:
+        return self._pool[f"{id}"]
+
+    def get_params(self, id: int) -> List[torch.nn.Parameter]:
+        return [self._pool[f"{id}"]]
+
+    def increment_task(self) -> None:
+        self._pool[f"{len(self._pool)}"] = nn.Parameter(
+            torch.empty((self._M, self._embedding_size)).uniform_(-1, 1)
+        )
+        self._pool.requires_grad_(True)
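# Usage sketch (illustrative only, not part of the patch): the pool holds one prompt
# per past update and grows by one prompt on each new update.
#
#     pool = PromptPool(prompt_size=10, embedding_size=768, current_update_id=2)
#     pool(0).shape          # torch.Size([10, 768]): the prompt for update 0
#     pool.increment_task()  # registers a freshly initialized prompt for update 2
#     len(pool._pool)        # 3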
+
+
+class SPromptTransformer(RenateBenchmarkingModule):
+    """Implements the transformer model for S-Prompts as described in Wang, Yabin, et al.
+    "S-Prompts learning with pre-trained transformers: An Occam's razor for domain incremental
+    learning." Advances in Neural Information Processing Systems 35 (2022): 5682-5695.
+
+    Args:
+        pretrained_model_name_or_path: A string that denotes which pretrained model from the HF
+            hub to use.
+        image_size: Image size. Used only if `pretrained_model_name_or_path` is not set.
+        patch_size: Patch size to be extracted. Used only if `pretrained_model_name_or_path` is
+            not set.
+        num_layers: Number of transformer layers. Used only if `pretrained_model_name_or_path`
+            is not set.
+        num_heads: Number of heads in multi-head self-attention. Used only if
+            `pretrained_model_name_or_path` is not set.
+        hidden_dim: Hidden dimension of the transformer. Used only if
+            `pretrained_model_name_or_path` is not set.
+        mlp_dim: Dimension of the MLP in the transformer blocks. Used only if
+            `pretrained_model_name_or_path` is not set.
+        dropout: Dropout probability. Used only if `pretrained_model_name_or_path` is not set.
+        attention_dropout: Dropout probability in the attention layers. Used only if
+            `pretrained_model_name_or_path` is not set.
+        num_outputs: Number of output classes. Defaults to 10.
+        prediction_strategy: Continual learning strategies may alter the prediction at train or
+            test time. Defaults to None.
+        add_icarl_class_means: If ``True``, additional parameters used only by the
+            ``ICaRLModelUpdater`` are added. Only required when using that updater.
+        prompt_size: Number of input tokens each per-update prompt is equivalent to.
+            Defaults to 10.
+        task_id: Internal variable used to track the update id. Should not be set by the user.
+            Defaults to 0.
+        clusters_per_task: Number of clusters in the k-means used for task identification.
+            Defaults to 5.
+        per_task_classifier: Flag whether to use an individual classifier head per task or to
+            share a common one. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str = "google/vit-base-patch16-224",
+        image_size: int = 32,
+        patch_size: int = 4,
+        num_layers: int = 12,
+        num_heads: int = 12,
+        hidden_dim: int = 768,
+        mlp_dim: int = 3072,
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        num_outputs: int = 10,
+        prediction_strategy: Optional[PredictionStrategy] = None,
+        add_icarl_class_means: bool = True,
+        prompt_size: int = 10,
+        task_id: int = 0,
+        clusters_per_task: int = 5,
+        per_task_classifier: bool = False,
+    ):
+        transformer = PromptedTransformer(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            image_size=image_size,
+            patch_size=patch_size,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            mlp_dim=mlp_dim,
+            dropout=dropout,
+            attention_dropout=attention_dropout,
+            num_outputs=num_outputs,
+            prediction_strategy=prediction_strategy,
+            add_icarl_class_means=add_icarl_class_means,
+        )
+        super().__init__(
+            embedding_size=transformer.transformer._embedding_size,
+            num_outputs=num_outputs,
+            constructor_arguments=dict(
+                **transformer.transformer._constructor_arguments,
+                prompt_size=prompt_size,
+                task_id=task_id,
+                clusters_per_task=clusters_per_task,
+                per_task_classifier=per_task_classifier,
+            ),
+            prediction_strategy=prediction_strategy,
+            add_icarl_class_means=add_icarl_class_means,
+        )
+        self._M = prompt_size
+        self._task_id = task_id
+        self._per_task_classifier = per_task_classifier
+
+        prompt_pool = PromptPool(
+            prompt_size=self._M,
+            embedding_size=self._embedding_size,
+            current_update_id=self._task_id,
+        )
+
+        self._backbone = nn.ModuleDict({"transformer": transformer, "prompt_pool": prompt_pool})
+        self._task_id_method = TaskPrototypes(
+            task_id=task_id,
+            clusters_per_task=clusters_per_task,
+            embedding_size=self._embedding_size,
+        )
+        self._backbone["transformer"].requires_grad_(False)
+        self._backbone["prompt_pool"].requires_grad_(True)
+
+        self._backbone["classifier"] = SharedMultipleLinear(
+            self._embedding_size,
+            self._num_outputs,
+            share_parameters=not self._per_task_classifier,
+            num_updates=self._task_id + 1,
+        )
+
+        self._tasks_params = nn.ModuleDict(
+            {k: nn.Identity() for k, _ in self._tasks_params.items()}
+        )
+
+        self._backbone.forward = self._forward_for_monkey_patching
+        self.task_ids = None
+
+    def increment_task(self) -> None:
+        # This cannot be a part of add_task_params, as the super().__init__ call invokes
+        # add_task_params, and we would thus be trying to add parameters to the not yet
+        # existing prompt pool.
+        self._backbone["prompt_pool"].increment_task()
+
+    def _forward_for_monkey_patching(
+        self, x: Union[torch.Tensor, Dict[str, Any]], task_id: Optional[str] = None
+    ) -> torch.Tensor:
+        prompt = None
+        task_ids = None
+        if self.training:
+            prompt = self._backbone["prompt_pool"](self._task_id)
+        else:
+            task_ids = self._task_id_method.infer_task(self._backbone["transformer"](x))
+            if task_ids is not None:
+                prompt = torch.stack([self._backbone["prompt_pool"](i) for i in task_ids])
+                self.task_ids = task_ids.detach().cpu().numpy()
+
+        features = self._backbone["transformer"](x, prompt)
+
+        # Additional logic for separate classifiers:
+        # a. This forward returns logits directly, and the RenateBenchmarkingModule's
+        #    _task_params are now identities. Thus, the overall operation is still the
+        #    network's forward pass.
+        # b. Additional handling of params is not needed, as the backbone's params will return
+        #    all the necessary elements.
+
+        if self.training:
+            logits = self._backbone["classifier"][f"{self._task_id}"](features)
+        elif task_ids is not None:
+            logits = torch.cat(
+                [
+                    self._backbone["classifier"][f"{t}"](feat.unsqueeze(0))
+                    for t, feat in zip(task_ids, features)
+                ]
+            )
+        else:
+            logits = self._backbone["classifier"]["0"](features)
+
+        return logits
+
+    def update_task_identifier(self, features: torch.Tensor, labels: torch.Tensor) -> None:
+        self._task_id_method.update_task_prototypes(features, labels)
+
+    def set_extra_state(self, state: Any, decode=True):
+        super().set_extra_state(state, decode)
+        # Once the state has been restored (after loading), increment the update id by one.
+        self._constructor_arguments["task_id"] = self._task_id + 1
+
+    def features(self, x: torch.Tensor) -> torch.Tensor:
+        return self._backbone["transformer"](x)
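To make the monkey-patched forward concrete, here is a self-contained sketch (plain PyTorch, illustration only; all names and sizes are made up) of the inference-time routing implemented above: features are matched to the nearest task centroid, the corresponding per-update prompt is selected, and the matching per-task head produces the logits.

import torch

torch.manual_seed(0)
embedding_size, num_tasks, num_classes = 8, 3, 5
centroids = torch.nn.functional.normalize(torch.rand(num_tasks, embedding_size))
prompts = [torch.rand(10, embedding_size) for _ in range(num_tasks)]  # one prompt per update
heads = [torch.nn.Linear(embedding_size, num_classes) for _ in range(num_tasks)]

features = torch.nn.functional.normalize(torch.rand(4, embedding_size))
task_ids = torch.cdist(features, centroids).argmin(1)  # nearest centroid per input
prompt = torch.stack([prompts[i] for i in task_ids])  # (4, 10, 8), as in the model above
logits = torch.cat([heads[t](f.unsqueeze(0)) for t, f in zip(task_ids, features)])  # (4, 5)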
diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py
index 0d94173e..f14a5c10 100644
--- a/src/renate/cli/parsing_functions.py
+++ b/src/renate/cli/parsing_functions.py
@@ -28,6 +28,7 @@
 )
 from renate.updaters.experimental.offline_er import OfflineExperienceReplayModelUpdater
 from renate.updaters.experimental.repeated_distill import RepeatedDistillationModelUpdater
+from renate.updaters.experimental.speft import SPeftModelUpdater
 from renate.updaters.model_updater import ModelUpdater

 REQUIRED_ARGS_GROUP = "Required Arguments"
@@ -70,6 +71,8 @@ def get_updater_and_learner_kwargs(
     elif args.updater == "LearningToPromptReplay":
         learner_args = learner_args + ["prompt_sim_loss_weight", "memory_size", "memory_batch_size"]
         updater_class = LearningToPromptReplayModelUpdater
+    elif args.updater == "SPeft":
+        updater_class = SPeftModelUpdater
     elif args.updater == "DER":
         learner_args = base_er_args + ["alpha", "beta"]
         updater_class = DarkExperienceReplayModelUpdater
@@ -488,6 +491,10 @@ def _add_l2preplay_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
     _add_offline_er_arguments(arguments)


+def _add_speft_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
+    """A helper function that adds SPeft arguments. The SPeft learner needs no extra arguments."""
+
+
 def _add_gdumb_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
     """A helper function that adds GDumb arguments."""
     _add_replay_learner_arguments(arguments)
@@ -973,6 +980,7 @@ def get_scheduler_kwargs(
     "ER": _add_experience_replay_arguments,
     "LearningToPrompt": _add_l2p_arguments,
     "LearningToPromptReplay": _add_l2preplay_arguments,
+    "SPeft": _add_speft_arguments,
     "DER": _add_dark_experience_replay_arguments,
     "POD-ER": _add_pod_experience_replay_arguments,
     "CLS-ER": _add_cls_experience_replay_arguments,

diff --git a/src/renate/defaults.py b/src/renate/defaults.py
index e0854022..a92f7b34 100644
--- a/src/renate/defaults.py
+++ b/src/renate/defaults.py
@@ -112,6 +112,9 @@
 # L2p
 PROMPT_SIM_LOSS_WEIGHT = 0.5

+# S-prompt
+CLUSTERS_PER_TASK = 5
+

 def scheduler(config_space: Dict[str, Any], mode: str, metric: str):
     return FIFOScheduler(

diff --git a/src/renate/models/layers/shared_linear.py b/src/renate/models/layers/shared_linear.py
new file mode 100644
index 00000000..bef9d8ed
--- /dev/null
+++ b/src/renate/models/layers/shared_linear.py
@@ -0,0 +1,50 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import torch.nn as nn
+
+
+class SharedMultipleLinear(nn.ModuleDict):
+    """Implements a linear classification layer for multiple tasks (updates).
+
+    The layer can be shared across all tasks, or each task can have its own separate layer.
+    This follows `_task_params` in the `RenateBenchmarkingModule`, which is a `nn.ModuleDict`
+    that holds one classifier per task (as in task-incremental learning).
+
+    Args:
+        in_features: Size of each input sample.
+        out_features: Size of each output sample.
+        bias: If set to ``False``, the layer will not learn an additive bias.
+            Default: ``True``.
+        share_parameters: Flag whether to share parameters or use an individual linear layer
+            per task. The interface remains identical, and the underlying linear layer is
+            shared (or not).
+        num_updates: Number of updates that have happened or are in progress.
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + share_parameters: bool = True, + num_updates: int = 0, + ) -> None: + self._share_parameters = share_parameters + self.in_features = in_features + self.out_features = out_features + self.bias = bias + + super().__init__() + for _ in range(num_updates): + self.increment_task() + + def increment_task(self) -> None: + currlen = len(self) + if self._share_parameters: + self[f"{currlen}"] = ( + self[list(self.keys())[0]] + if currlen > 0 + else nn.Linear(in_features=self.in_features, out_features=self.out_features) + ) + else: + self[f"{currlen}"] = nn.Linear( + in_features=self.in_features, out_features=self.out_features + ) diff --git a/src/renate/models/task_identification_strategies.py b/src/renate/models/task_identification_strategies.py new file mode 100644 index 00000000..8cf0dbee --- /dev/null +++ b/src/renate/models/task_identification_strategies.py @@ -0,0 +1,97 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Union + +import numpy as np +import numpy.typing as npt +import torch +import torch.nn as nn +from sklearn.cluster import KMeans + + +class TaskEstimator(nn.Module, ABC): + """An ABC that all task estimator methods inherit. + + They implement two methods `update_task_prototypes` and `infer_task`. + """ + + @abstractmethod + def update_task_prototypes(self): + return + + @abstractmethod + def infer_task(self): + return + + +class TaskPrototypes(TaskEstimator): + """Task identification method proposed in S-Prompts. + + Args: + task_id: The current update id of the method. Required to deserialize. + clusters_per_task: Number of clusters to use in K-means. + embedding_size: Embedding size of the transformer features. + """ + + def __init__(self, task_id, clusters_per_task, embedding_size) -> None: + super().__init__() + self.register_buffer( + "_training_feat_centroids", + torch.empty(task_id * clusters_per_task, embedding_size), + ) + self.register_buffer( + "_training_feat_task_ids", + torch.full( + (self._training_feat_centroids.size(0),), fill_value=task_id, dtype=torch.long + ), + ) + self._clusters_per_task = clusters_per_task + self._task_id = task_id + self._embedding_size = embedding_size + + @torch.no_grad() + def update_task_prototypes( + self, + features: Union[torch.Tensor, npt.ArrayLike], + labels: Union[torch.Tensor, npt.ArrayLike], + ) -> None: + # At training. + if isinstance(features, torch.Tensor): + features = features.cpu().numpy() + + # l2 normalize features: + features = features / np.power(np.einsum("ij, ij -> i", features, features), 0.5)[:, None] + + centroids = torch.from_numpy( + KMeans(n_clusters=self._clusters_per_task, random_state=0) + .fit(features) + .cluster_centers_ + ).to(self._training_feat_centroids.device) + + self._training_feat_centroids = torch.cat( + [ + self._training_feat_centroids, + centroids, + ] + ) + self._training_feat_task_ids = torch.cat( + [ + self._training_feat_task_ids, + torch.full( + (centroids.size(0),), + fill_value=self._task_id, + dtype=torch.int8, + device=self._training_feat_task_ids.device, + ), + ] + ) + + def infer_task(self, features: torch.Tensor) -> torch.Tensor: + # At inference. 
diff --git a/src/renate/updaters/experimental/speft.py b/src/renate/updaters/experimental/speft.py
new file mode 100644
index 00000000..e6c0a095
--- /dev/null
+++ b/src/renate/updaters/experimental/speft.py
@@ -0,0 +1,198 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import torchmetrics
+from pytorch_lightning.loggers.logger import Logger
+from torch.nn import Parameter
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import Dataset
+
+from renate import defaults
+from renate.benchmark.models.spromptmodel import SPromptTransformer
+from renate.models import RenateModule
+from renate.updaters.learner import Learner
+from renate.updaters.model_updater import SingleTrainingLoopUpdater
+
+
+class SPeftLearner(Learner):
+    """Learner to implement S-Prompts from
+    ```Wang, Yabin, et al.
+    "S-Prompts learning with pre-trained transformers: An Occam's razor for domain incremental learning."  # noqa: E501
+    Advances in Neural Information Processing Systems 35 (2022): 5682-5695.```
+
+    Args:
+        model: The SPromptTransformer model to be trained.
+        loss_fn: Loss function to be trained with.
+        optimizer: Partial optimizer used to create an optimizer by passing the model parameters.
+        learning_rate_scheduler: Partial object of learning rate scheduler that will be created
+            by passing the optimizer.
+        learning_rate_scheduler_interval: When to update the learning rate scheduler.
+            Options: `epoch` and `step`.
+        batch_size: Training batch size.
+        train_transform: The transformation applied during training.
+        train_target_transform: The target transformation applied during training.
+        test_transform: The transformation at test time.
+        test_target_transform: The target transformation at test time.
+        logged_metrics: Metrics logged additional to the default ones.
+        seed: See :func:`renate.models.utils.get_generator`.
+        mask_unused_classes: Masking logits corresponding to unused classes. Useful only for
+            class-incremental problems. Defaults to defaults.MASK_UNUSED_CLASSES.
+    """
+
+    def __init__(
+        self,
+        model: RenateModule,
+        loss_fn: torch.nn.Module,
+        optimizer: Callable[[List[Parameter]], Optimizer],
+        learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None,
+        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
+        batch_size: int = defaults.BATCH_SIZE,
+        train_transform: Optional[Callable] = None,
+        train_target_transform: Optional[Callable] = None,
+        test_transform: Optional[Callable] = None,
+        test_target_transform: Optional[Callable] = None,
+        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
+        seed: int = defaults.SEED,
+        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
+    ) -> None:
+        if not isinstance(model, SPromptTransformer):
+            raise ValueError(
+                "SPeftLearner can only be used with an SPromptTransformer model, "
+ f"But got {type(model)}" + ) + super().__init__( + model, + loss_fn, + optimizer, + learning_rate_scheduler, + learning_rate_scheduler_interval, + batch_size, + train_transform, + train_target_transform, + test_transform, + test_target_transform, + logged_metrics, + seed, + mask_unused_classes, + ) + + def on_model_update_start( + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, + ) -> None: + """A custom on_model_update_start hook for S-Peft methods. + + Here, we iterate oer the train data set and extract features. These features used to compute + the task prototypes by the `update_task_identifier` call. Having this function in the model + update start instead of end results in val metrics being reflective of test accuracy. + """ + super().on_model_update_start( + train_dataset, val_dataset, train_dataset_collate_fn, val_dataset_collate_fn, task_id + ) + ## k-means + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + self._model.to(device) + features, labels = [], [] + with torch.inference_mode(): + for x, y in self.train_dataloader(): + features.append(self._model.features(x.to(device)).cpu()) + labels.append(y) + features = torch.cat(features) + labels = torch.cat(labels) + self._model.update_task_identifier(features=features, labels=labels) + + def setup(self, stage: str) -> None: + # We dont support distributed + assert ( + self.trainer.world_size == 1 + ), "SPrompt learner does not support Multi-GPU training yet." + if stage == "fit": + # This needs to run before configure optimizers is called. The only hook is setup("fit") + self._model.increment_task() + + def optimizer_zero_grad( + self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int + ) -> None: + """Explicitly setting grads to None instead of zero.""" + optimizer.zero_grad(set_to_none=True) + + +class SPeftModelUpdater(SingleTrainingLoopUpdater): + def __init__( + self, + model: RenateModule, + loss_fn: torch.nn.Module, + optimizer: Callable[[List[nn.Parameter]], Optimizer], + batch_size: int = defaults.BATCH_SIZE, + seed: int = defaults.SEED, + learner_kwargs: Optional[Dict[str, Any]] = None, + input_state_folder: Optional[str] = None, + output_state_folder: Optional[str] = None, + max_epochs: int = defaults.MAX_EPOCHS, + learning_rate_scheduler: Optional[Optional[Callable[[Optimizer], _LRScheduler]]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 + train_transform: Optional[Callable] = None, + train_target_transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + test_target_transform: Optional[Callable] = None, + buffer_transform: Optional[Callable] = None, + buffer_target_transform: Optional[Callable] = None, + metric: Optional[str] = None, + mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min", + logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, + early_stopping_enabled: bool = False, + logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), + accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, + devices: Optional[int] = None, + strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, + deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + 
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, + ): + learner_kwargs = { + "batch_size": batch_size, + "seed": seed, + "loss_fn": loss_fn, + } + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + learner_class=SPeftLearner, + learner_kwargs=learner_kwargs, + input_state_folder=input_state_folder, + output_state_folder=output_state_folder, + max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, + train_transform=train_transform, + train_target_transform=train_target_transform, + test_transform=test_transform, + test_target_transform=test_target_transform, + buffer_transform=buffer_transform, + buffer_target_transform=buffer_target_transform, + metric=metric, + mode=mode, + logged_metrics=logged_metrics, + early_stopping_enabled=early_stopping_enabled, + logger=logger, + accelerator=accelerator, + devices=devices, + strategy=strategy, + precision=precision, + deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, + ) diff --git a/test/renate/benchmark/models/test_sprompt.py b/test/renate/benchmark/models/test_sprompt.py new file mode 100644 index 00000000..03eefbf2 --- /dev/null +++ b/test/renate/benchmark/models/test_sprompt.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from renate.benchmark.models.spromptmodel import PromptPool + + +def test_prompt_pool(): + prompt_size = 7 + embedding_size = 12 + curr_update_id = 3 + pool = PromptPool( + prompt_size=prompt_size, embedding_size=embedding_size, current_update_id=curr_update_id + ) + + for i in range(curr_update_id): + assert pool(i).shape == (prompt_size, embedding_size) + assert pool.get_params(i)[0].shape == (prompt_size, embedding_size) + + pool.increment_task() + assert len(pool._pool) == curr_update_id + 1 diff --git a/test/renate/models/test_shared_linear.py b/test/renate/models/test_shared_linear.py new file mode 100644 index 00000000..4960d2d6 --- /dev/null +++ b/test/renate/models/test_shared_linear.py @@ -0,0 +1,17 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from renate.models.layers.shared_linear import SharedMultipleLinear + + +@pytest.mark.parametrize("share_parameters", [True, False]) +def test_shared_multiple_classifier(share_parameters): + model = SharedMultipleLinear(5, 3, share_parameters=share_parameters, num_updates=10) + num_elems = sum(x.numel() for x in model.parameters()) + assert num_elems == [10 * 5 * 3 + 10 * 3, 5 * 3 + 3][share_parameters] + + model.increment_task() + num_elems = sum(x.numel() for x in model.parameters()) + assert num_elems == [(10 + 1) * 5 * 3 + (10 + 1) * 3, 5 * 3 + 3][share_parameters] diff --git a/test/renate/models/test_task_identification_strat.py b/test/renate/models/test_task_identification_strat.py new file mode 100644 index 00000000..a575d0c0 --- /dev/null +++ b/test/renate/models/test_task_identification_strat.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+import torch
+from sklearn.cluster import KMeans
+
+from renate.models.task_identification_strategies import TaskPrototypes
+
+
+def test_task_prototypes():
+    data = torch.nn.functional.normalize(torch.rand(10, 3))
+    labels = torch.arange(start=0, end=data.size(0))
+    task_proto = TaskPrototypes(0, 0, data.size(1))
+    # Attach the data directly as centroids and the labels as their task ids.
+    task_proto._training_feat_centroids = data
+    task_proto._training_feat_task_ids = labels
+
+    test_data = torch.nn.functional.normalize(torch.rand(5, 3))
+    predictions = task_proto.infer_task(test_data)
+
+    # Build a reference KMeans whose fitted attributes are set manually, so that predict()
+    # can be used without calling fit().
+    kmeans = KMeans(n_clusters=data.size(0))
+    kmeans.cluster_centers_ = data.numpy()
+    kmeans.labels_ = labels.numpy()
+    kmeans._n_threads = 1
+
+    gnd_truth = kmeans.predict(test_data.numpy())
+
+    assert (predictions.numpy() == gnd_truth).all()
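As a closing note on the parameter-sharing behavior exercised by `test_shared_multiple_classifier` above, a small illustrative sketch (not part of the patch): with sharing enabled, every key of the `nn.ModuleDict` points at the same `nn.Linear` instance, so `parameters()` deduplicates it.

from renate.models.layers.shared_linear import SharedMultipleLinear

shared = SharedMultipleLinear(5, 3, share_parameters=True, num_updates=2)
separate = SharedMultipleLinear(5, 3, share_parameters=False, num_updates=2)

assert shared["0"] is shared["1"]  # the same module is registered under both keys
assert sum(p.numel() for p in shared.parameters()) == 5 * 3 + 3  # deduplicated
assert sum(p.numel() for p in separate.parameters()) == 2 * (5 * 3 + 3)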