S-Prompts for ViT and Text Transformers (#388)
prabhuteja12 authored Dec 4, 2023
1 parent 17733be commit c54887c
Showing 14 changed files with 675 additions and 2 deletions.
7 changes: 7 additions & 0 deletions doc/benchmarking/renate_benchmarks.rst
@@ -105,6 +105,13 @@ The full list of models and model names including a short description is provided
* ``pool_selection_size``: Number of prompts to select for each input from the pool.
* ``prompt_size``: Number of input tokens each prompt is equivalent to.
* ``prompt_key_dim``: Dimensionality of the features used for prompt matching.
* - `~renate.benchmark.models.spromptmodel.SPromptTransformer`
- `S-Prompt Transformer <https://arxiv.org/abs/2207.12819>`_.
- * ``pretrained_model_name_or_path``: Hugging Face `transformer ID <https://huggingface.co/models>`__.
* ``num_outputs``: The number of classes.
* ``prompt_size``: Number of input tokens each prompt is equivalent to.
* ``clusters_per_task``: Number of clusters for K-Means in task identification.
* ``per_task_classifier``: If ``True``, use an individual classifier head per task; otherwise share one across tasks (see the configuration sketch below).
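
For orientation, a minimal configuration sketch. The parameter names come from the list above and the import path is the one exported by this commit's ``models/__init__.py``; the values themselves are illustrative assumptions, not recommendations:

.. code-block:: python

    from renate.benchmark.models import SPromptTransformer

    model = SPromptTransformer(
        pretrained_model_name_or_path="google/vit-base-patch16-224",  # HF transformer ID
        num_outputs=10,             # number of classes
        prompt_size=10,             # each prompt equals 10 input tokens
        clusters_per_task=5,        # K-Means clusters for task identification
        per_task_classifier=True,   # individual classifier head per task
    )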

.. _benchmarking-renate-benchmarks-datasets:

5 changes: 4 additions & 1 deletion doc/getting_started/supported_algorithms.rst
@@ -41,7 +41,10 @@ using Renate (e.g., using :py:func:`~renate.training.training.run_training_job`;
- A class that implements the Learning to Prompt method for ViTs. The method trains only the input prompts, which are sampled from a prompt pool in an input-dependent fashion.
* - ``"LearningToPromptReplay"``
- :py:class:`LearningToPromptLearner <renate.updaters.experimental.l2p.LearningToPromptReplayLearner>`
- A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER".
* - ``"S-Prompts"``
- :py:class:`SPeft <renate.updaters.experimental.speft.SPeftLearner>`
- A class that currently implements the S-Prompts method for memory-free continual learning when used with the `SPromptTransformer` model. The method trains a set of input prompts in an update-dependent fashion (a usage sketch follows this table excerpt).
* - ``"Avalanche-ER"``
- :py:class:`AvalancheReplayLearner <renate.updaters.avalanche.learner.AvalancheReplayLearner>`
- A wrapper which gives access to Experience Replay as implemented in the Avalanche library. This method is the equivalent to our Offline-ER.
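A hedged usage sketch for the new ``"S-Prompts"`` entry above. The ``run_training_job`` signature is not shown in this excerpt, so every argument except ``updater="SPeft"``, which matches the dispatch added in ``parsing_functions.py`` later in this commit, is an assumption:

    from renate.training import run_training_job

    # Illustrative only: argument names other than updater are assumptions.
    run_training_job(
        config_file="renate_config.py",  # hypothetical Renate config module
        updater="SPeft",                 # selects SPeftModelUpdater / SPeftLearner
        mode="max",
        metric="val_accuracy",
        max_epochs=5,
    )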
2 changes: 2 additions & 0 deletions src/renate/benchmark/datasets/vision_datasets.py
@@ -367,6 +367,8 @@ def __init__(
def prepare_data(self) -> None:
"""Download DomainNet dataset for given domain."""
file_name = f"{self.data_id}.zip"
# Update the dataset name to the requested domain:
self._dataset_name = self.data_id
url = "http://csr.bu.edu/ftp/visda/2019/multi-source/"
if self.data_id in ["clipart", "painting"]:
url = os.path.join(url, "groundtruth")
16 changes: 16 additions & 0 deletions src/renate/benchmark/experiment_config.py
@@ -52,6 +52,8 @@
from renate.models import RenateModule
from renate.models.prediction_strategies import ICaRLClassificationStrategy

from renate.benchmark.models.spromptmodel import SPromptTransformer

models = {
"MultiLayerPerceptron": MultiLayerPerceptron,
"ResNet18CIFAR": ResNet18CIFAR,
@@ -68,6 +70,7 @@
"VisionTransformerH14": VisionTransformerH14,
"HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer,
"LearningToPromptTransformer": LearningToPromptTransformer,
"SPromptTransformer": SPromptTransformer,
}


@@ -81,6 +84,9 @@ def model_fn(
hidden_size: Optional[Tuple[int]] = None,
dataset_name: Optional[str] = None,
pretrained_model_name_or_path: Optional[str] = None,
prompt_size: int = 10,
clusters_per_task: int = 5,
per_task_classifier: bool = True,
) -> RenateModule:
"""Returns a model instance."""
if model_name not in models:
@@ -110,6 +116,16 @@ def model_fn(
f"LearningToPromptTransformer, but model name specified is {model_name}."
)
model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path
elif (updater is not None) and ("SPeft" in updater):
if not model_name.startswith("SPrompt"):
raise ValueError(
"SPrompt model updater is designed to work only with "
f"SPromptTransformer, but model name specified is {model_name}."
)
model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path
model_kwargs["prompt_size"] = prompt_size
model_kwargs["clusters_per_task"] = clusters_per_task
model_kwargs["per_task_classifier"] = per_task_classifier
if model_state_url is None:
model = model_class(**model_kwargs)
else:
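The new branch can be exercised roughly as follows (hedged: ``model_fn`` is truncated in this diff, so the call below is an assumption built from the visible keyword arguments):

    # Illustrative only: model_fn's full signature is not shown above.
    model = model_fn(
        model_name="SPromptTransformer",  # must start with "SPrompt" when updater contains "SPeft"
        updater="SPeft",
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        prompt_size=10,
        clusters_per_task=5,
        per_task_classifier=True,
    )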
5 changes: 4 additions & 1 deletion src/renate/benchmark/models/__init__.py
@@ -9,7 +9,8 @@
ResNet50,
ResNet50CIFAR,
)
from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptedTransformer
from renate.benchmark.models.spromptmodel import SPromptTransformer
from renate.benchmark.models.vision_transformer import (
VisionTransformerB16,
VisionTransformerB32,
@@ -28,6 +29,8 @@
"ResNet50",
"ResNet50CIFAR",
"LearningToPromptTransformer",
"PromptedTransformer",
"SPromptTransformer",
"VisionTransformerB16",
"VisionTransformerB32",
"VisionTransformerCIFAR",
222 changes: 222 additions & 0 deletions src/renate/benchmark/models/spromptmodel.py
@@ -0,0 +1,222 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn

from renate.models.layers.shared_linear import SharedMultipleLinear
from renate.models.prediction_strategies import PredictionStrategy
from renate.models.task_identification_strategies import TaskPrototypes

from .l2p import PromptedTransformer
from .base import RenateBenchmarkingModule

logger = logging.getLogger(__name__)


class PromptPool(nn.Module):
"""Implements a pool of prompts to be used in for S-Prompts.
Args:
prompt_size: Equivalent to number of input tokens used per update . Defaults to 10.
embedding_size: Hidden size of the transformer used.. Defaults to 768.
current_update_id: Current update it. Used to init number of prompts. Defaults to 0.
"""

def __init__(
self, prompt_size: int = 10, embedding_size: int = 768, current_update_id: int = 0
) -> None:
super().__init__()
self._M = prompt_size
self._embedding_size = embedding_size
self._curr_task = current_update_id

self._pool = nn.ParameterDict()
for id in range(self._curr_task):
# This always needs to be initialized, as torch's state dict only restores
# values for an existing Parameter.
self._pool[f"{id}"] = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty((self._M, self._embedding_size)), a=math.sqrt(5)
)
)

self._pool.requires_grad_(True)

def forward(self, id: int) -> torch.nn.Parameter:
return self._pool[f"{id}"]

def get_params(self, id: int) -> List[torch.nn.Parameter]:
return [self._pool[f"{id}"]]

def increment_task(self) -> None:
self._pool[f"{len(self._pool)}"] = nn.Parameter(
torch.empty((self._M, self._embedding_size)).uniform_(-1, 1)
)
self._pool.requires_grad_(True)

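# Usage sketch (illustrative comment, not part of the committed module): after
# two completed updates the pool holds prompts "0" and "1", each of shape
# (prompt_size, embedding_size).
#
#     pool = PromptPool(prompt_size=10, embedding_size=768, current_update_id=2)
#     prompt = pool(1)         # nn.Parameter of shape (10, 768)
#     pool.increment_task()    # adds a uniformly initialized prompt "2"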

class SPromptTransformer(RenateBenchmarkingModule):
"""Implements Transformer Model for S-Prompts as described in Wang, Yabin, et.al ."S-prompts
learning with pre-trained transformers: An occam’s razor for domain incremental learning."
Advances in Neural Information Processing Systems 35 (2022): 5682-5695.
Args:
pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
to use.
image_size: Image size. Used if `pretrained_model_name_or_path` is not set.
patch_size: Patch size to be extracted. Used if `pretrained_model_name_or_path` is not set.
num_layers: Number of transformer layers. Used only if `pretrained_model_name_or_path` is not
set.
num_heads: Number of heads in MHSA. Used only if `pretrained_model_name_or_path` is not set.
hidden_dim: Hidden dimension of the transformer. Used only if `pretrained_model_name_or_path`
is not set.
mlp_dim: Hidden dimension of the MLP in each transformer block. Used only if
`pretrained_model_name_or_path` is not set.
dropout: Dropout probability. Used only if `pretrained_model_name_or_path` is not set.
attention_dropout: Dropout probability in the attention layers. Used only if
`pretrained_model_name_or_path` is not set.
num_outputs: Number of output classes. Defaults to 10.
prediction_strategy: Continual learning strategies may alter the prediction at train or test
time. Defaults to None.
add_icarl_class_means: If ``True``, additional parameters used only by the
``ICaRLModelUpdater`` are added. Only required when using that updater.
prompt_size: Number of input tokens each prompt is equivalent to. Defaults to 10.
task_id: Internal variable used to track the current update id. Shouldn't be set by the user.
Defaults to 0.
clusters_per_task: Number of clusters for K-Means in task identification. Defaults to 5.
per_task_classifier: If ``True``, use an individual classifier head per task; otherwise
share one across tasks. Defaults to False.
"""

def __init__(
self,
pretrained_model_name_or_path="google/vit-base-patch16-224",
image_size: int = 32,
patch_size: int = 4,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.1,
attention_dropout: float = 0.1,
num_outputs: int = 10,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
prompt_size: int = 10,
task_id: int = 0,
clusters_per_task: int = 5,
per_task_classifier: bool = False,
):
transformer = PromptedTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
super().__init__(
embedding_size=transformer.transformer._embedding_size,
num_outputs=num_outputs,
constructor_arguments=dict(
**transformer.transformer._constructor_arguments,
prompt_size=prompt_size,
task_id=task_id,
clusters_per_task=clusters_per_task,
per_task_classifier=per_task_classifier,
),
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
self._M = prompt_size
self._task_id = task_id
self._per_task_classifier = per_task_classifier

prompt_pool = PromptPool(
prompt_size=self._M,
embedding_size=self._embedding_size,
current_update_id=self._task_id,
)

self._backbone = nn.ModuleDict({"transformer": transformer, "prompt_pool": prompt_pool})
self._task_id_method = TaskPrototypes(
task_id=task_id,
clusters_per_task=clusters_per_task,
embedding_size=self._embedding_size,
)
self._backbone["transformer"].requires_grad_(False)
self._backbone["prompt_pool"].requires_grad_(True)

self._backbone["classifier"] = SharedMultipleLinear(
self._embedding_size,
self._num_outputs,
share_parameters=not self._per_task_classifier,
num_updates=self._task_id + 1,
)

self._tasks_params = nn.ModuleDict(
{k: nn.Identity() for k, _ in self._tasks_params.items()}
)

self._backbone.forward = self._forward_for_monkey_patching
self.task_ids = None

def increment_task(self) -> None:
# This cannot be part of add_task_params, as super().__init__ calls
# add_task_params, and at that point we would be trying to add parameters to
# the not-yet-created prompt pool.
self._backbone["prompt_pool"].increment_task()

def _forward_for_monkey_patching(
self, x: Union[torch.Tensor, Dict[str, Any]], task_id: Optional[str] = None
) -> torch.Tensor:
prompt = None
task_ids = None
if self.training:
prompt = self._backbone["prompt_pool"](self._task_id)
else:
task_ids = self._task_id_method.infer_task(self._backbone["transformer"](x))
if task_ids is not None:
prompt = torch.stack([self._backbone["prompt_pool"](i) for i in task_ids])
self.task_ids = task_ids.detach().cpu().numpy()

features = self._backbone["transformer"](x, prompt)

# additional logic for separate classifiers
# a. This forward returns logits directly, and the RenateBenchmarkingModule's _task_params
# now are identities. Thus, the overall operation is still the network forward pass.
# b. Additional handling of params is not needed as backbone's params will return all the
# necessary elements.

if self.training:
logits = self._backbone["classifier"][f"{self._task_id}"](features)
elif task_ids is not None:
logits = torch.cat(
[
self._backbone["classifier"][f"{t}"](feat.unsqueeze(0))
for t, feat in zip(task_ids, features)
]
)
else:
logits = self._backbone["classifier"]["0"](features)

return logits

def update_task_identifier(self, features: torch.Tensor, labels: torch.Tensor) -> None:
self._task_id_method.update_task_prototypes(features, labels)

def set_extra_state(self, state: Any, decode=True):
super().set_extra_state(state, decode)
# Once the state is set (after loading), increment the stored update id by one.
self._constructor_arguments["task_id"] = self._task_id + 1

def features(self, x: torch.Tensor) -> torch.Tensor:
return self._backbone["transformer"](x)
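
Putting the pieces together, a hedged end-to-end sketch (``model(x)`` relies on the forward inherited from ``RenateBenchmarkingModule``, which is outside this diff, and this workflow normally runs inside the ``SPeft`` updater, so treat it as an assumption-laden illustration):

    import torch
    from renate.benchmark.models import SPromptTransformer

    model = SPromptTransformer(num_outputs=10, per_task_classifier=True)
    model.increment_task()  # create the prompt for the first update

    # Training: the current update's prompt and classifier head are used.
    model.train()
    logits = model(torch.randn(2, 3, 224, 224))

    # Register task prototypes so task ids can be inferred at test time
    # (labels here are dummy values for illustration).
    feats = model.features(torch.randn(16, 3, 224, 224))
    model.update_task_identifier(feats, torch.zeros(16, dtype=torch.long))

    # Evaluation: a task id is inferred per sample, then the matching prompt
    # and per-task classifier head are applied.
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 3, 224, 224))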
8 changes: 8 additions & 0 deletions src/renate/cli/parsing_functions.py
@@ -28,6 +28,7 @@
)
from renate.updaters.experimental.offline_er import OfflineExperienceReplayModelUpdater
from renate.updaters.experimental.repeated_distill import RepeatedDistillationModelUpdater
from renate.updaters.experimental.speft import SPeftModelUpdater
from renate.updaters.model_updater import ModelUpdater

REQUIRED_ARGS_GROUP = "Required Arguments"
@@ -70,6 +71,8 @@ def get_updater_and_learner_kwargs(
elif args.updater == "LearningToPromptReplay":
learner_args = learner_args + ["prompt_sim_loss_weight", "memory_size", "memory_batch_size"]
updater_class = LearningToPromptReplayModelUpdater
elif args.updater == "SPeft":
updater_class = SPeftModelUpdater
elif args.updater == "DER":
learner_args = base_er_args + ["alpha", "beta"]
updater_class = DarkExperienceReplayModelUpdater
@@ -488,6 +491,10 @@ def _add_l2preplay_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
_add_offline_er_arguments(arguments)


def _add_speft_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
"""A helper function that adds S-Prompts (SPeft) arguments; currently none are required."""


def _add_gdumb_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
"""A helper function that adds GDumb arguments."""
_add_replay_learner_arguments(arguments)
@@ -973,6 +980,7 @@ get_scheduler_kwargs(
"ER": _add_experience_replay_arguments,
"LearningToPrompt": _add_l2p_arguments,
"LearningToPromptReplay": _add_l2preplay_arguments,
"SPeft": _add_speft_arguments,
"DER": _add_dark_experience_replay_arguments,
"POD-ER": _add_pod_experience_replay_arguments,
"CLS-ER": _add_cls_experience_replay_arguments,
3 changes: 3 additions & 0 deletions src/renate/defaults.py
@@ -112,6 +112,9 @@
# L2p
PROMPT_SIM_LOSS_WEIGHT = 0.5

# S-Prompts
CLUSTERS_PER_TASK = 5


def scheduler(config_space: Dict[str, Any], mode: str, metric: str):
return FIFOScheduler(