S-Prompts for ViT and Text Transformers #388
Changes from 100 commits
@@ -367,6 +367,8 @@ def __init__(
    def prepare_data(self) -> None:
        """Download DomainNet dataset for given domain."""
        file_name = f"{self.data_id}.zip"
        # update dataset name:
        self._dataset_name = self.data_id
Review comment: why is this here?
        url = "http://csr.bu.edu/ftp/visda/2019/multi-source/"
        if self.data_id in ["clipart", "painting"]:
            url = os.path.join(url, "groundtruth")
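For illustration only (values derived from the snippet above, not additional diff content):

# data_id = "clipart" -> file_name = "clipart.zip", fetched from .../multi-source/groundtruth
# data_id = "sketch"  -> file_name = "sketch.zip",  fetched from .../multi-source/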
@@ -0,0 +1,225 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn

from renate.models.layers.shared_linear import SharedMultipleLinear
from renate.models.prediction_strategies import PredictionStrategy
from renate.models.task_identification_strategies import TaskPrototypes

from .l2p import PromptedTransformer
from .base import RenateBenchmarkingModule

logger = logging.getLogger(__name__)

class PromptPool(nn.Module):
    """Implements a pool of prompts to be used for S-Prompts.

    Args:
        prompt_size: Equivalent to the number of input tokens used per update. Defaults to 10.
        embedding_size: Hidden size of the transformer used. Defaults to 768.
        current_update_id: Current update id. Used to initialize the number of prompts.
            Defaults to 0.
    """

    def __init__(
        self, prompt_size: int = 10, embedding_size: int = 768, current_update_id: int = 0
    ) -> None:
        super().__init__()
        self._M = prompt_size
        self._embedding_size = embedding_size
        self._curr_task = current_update_id

        self._pool = nn.ParameterDict()
        for id in range(self._curr_task):
            # This always needs to be initialized, as torch's state dict only restores values
            # for an existing Parameter.
            self._pool[f"{id}"] = nn.Parameter(
                torch.nn.init.kaiming_uniform_(
                    torch.empty((self._M, self._embedding_size)), a=math.sqrt(5)
                )
            )

        self._pool.requires_grad_(True)

    def forward(self, id: int) -> torch.nn.Parameter:
        return self._pool[f"{id}"]

    def get_params(self, id: int) -> List[torch.nn.Parameter]:
        return [self._pool[f"{id}"]]

    def increment_task(self) -> None:
        self._pool[f"{len(self._pool)}"] = nn.Parameter(
            torch.empty((self._M, self._embedding_size)).uniform_(-1, 1)
        )
        self._pool.requires_grad_(True)

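For orientation, a minimal usage sketch of the pool above (the update count and shapes are illustrative assumptions, not taken from the PR):

pool = PromptPool(prompt_size=10, embedding_size=768, current_update_id=2)
prompt = pool(0)                 # prompts of update 0, shape (10, 768)
pool.increment_task()            # allocates prompts for the next update, keyed "2"
trainable = pool.get_params(2)   # parameters to optimize for that update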
class SPromptTransformer(RenateBenchmarkingModule):
    """Implements the transformer model for S-Prompts as described in Wang, Yabin, et al.
    "S-Prompts learning with pre-trained transformers: An Occam's razor for domain incremental
    learning." Advances in Neural Information Processing Systems 35 (2022): 5682-5695.

    Args:
        pretrained_model_name_or_path: A string that denotes which pretrained model from the HF
            hub to use.
        image_size: Image size. Used if `pretrained_model_name_or_path` is not set.
        patch_size: Patch size to be extracted. Used if `pretrained_model_name_or_path` is not
            set.
        num_layers: Number of transformer layers. Used only if `pretrained_model_name_or_path`
            is not set.
        num_heads: Number of heads in MHSA. Used only if `pretrained_model_name_or_path` is not
            set.
        hidden_dim: Hidden dimension of the transformer. Used only if
            `pretrained_model_name_or_path` is not set.
        mlp_dim: Dimension of the MLP block in each transformer layer. Used only if
            `pretrained_model_name_or_path` is not set.
        dropout: Dropout probability. Used only if `pretrained_model_name_or_path` is not set.
        attention_dropout: Attention dropout probability. Used only if
            `pretrained_model_name_or_path` is not set.
        num_outputs: Number of output classes. Defaults to 10.
        prediction_strategy: Continual learning strategies may alter the prediction at train or
            test time. Defaults to None.
        add_icarl_class_means: If ``True``, additional parameters used only by the
            ``ICaRLModelUpdater`` are added. Only required when using that updater.
        prompt_size: Equivalent to the number of input tokens used per update. Defaults to 10.
        task_id: Internal variable used to increment the update id. Shouldn't be set by the
            user. Defaults to 0.
        clusters_per_task: Number of clusters in the k-means used for task identification.
            Defaults to 5.
        per_task_classifier: If ``True``, each task uses its own classifier head; otherwise a
            common head is shared across all tasks. Defaults to False.
    """

    def __init__(
        self,
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        image_size: int = 32,
        patch_size: int = 4,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        mlp_dim: int = 3072,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        num_outputs: int = 10,
        prediction_strategy: Optional[PredictionStrategy] = None,
        add_icarl_class_means: bool = True,
        prompt_size: int = 10,
        task_id: int = 0,
        clusters_per_task: int = 5,
        per_task_classifier: bool = False,
    ):
        transformer = PromptedTransformer(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
            num_outputs=num_outputs,
            prediction_strategy=prediction_strategy,
            add_icarl_class_means=add_icarl_class_means,
        )
        super().__init__(
            embedding_size=transformer.transformer._embedding_size,
            num_outputs=num_outputs,
            constructor_arguments=dict(
                **transformer.transformer._constructor_arguments,
                prompt_size=prompt_size,
                task_id=task_id,
                clusters_per_task=clusters_per_task,
                per_task_classifier=per_task_classifier,
            ),
            prediction_strategy=prediction_strategy,
            add_icarl_class_means=add_icarl_class_means,
        )
        self._M = prompt_size
        self._task_id = task_id
        self._per_task_classifier = per_task_classifier
        logger.warning(f"Task id is {self._task_id}")
Review comment: remove this line?
Reply: Done.
        prompt_pool = PromptPool(
            prompt_size=self._M,
            embedding_size=self._embedding_size,
            current_update_id=self._task_id,
        )

        self._backbone = nn.ModuleDict({"transformer": transformer, "prompt_pool": prompt_pool})
        self._task_id_method = TaskPrototypes(
            task_id=task_id,
            clusters_per_task=clusters_per_task,
            embedding_size=self._embedding_size,
        )
        self._backbone["transformer"].requires_grad_(False)
        self._backbone["prompt_pool"].requires_grad_(True)

        # self._backbone["transformer"].transformer._backbone.enable_gradient_checkpointing()
Review comment: cleanup?
Reply: Done.
        # With this, we make the task params identities and use only this classifier.
        self._backbone["classifier"] = SharedMultipleLinear(
            self._embedding_size,
            self._num_outputs,
            share_parameters=not self._per_task_classifier,
            num_updates=self._task_id + 1,
        )

        self._tasks_params = nn.ModuleDict(
            {k: nn.Identity() for k, _ in self._tasks_params.items()}
        )

        self._backbone.forward = self.forward_for_monkey_patching
        self.task_ids = None
    def increment_task(self) -> None:
        # This cannot be a part of add_task_params, as the super().__init__ call invokes
        # add_task_params and we would thus be trying to add parameters to the non-existent
        # self.s_prompts.
        self._backbone["prompt_pool"].increment_task()
    def forward_for_monkey_patching(
Review comment: prepend
Reply: Done
        self, x: Union[torch.Tensor, Dict[str, Any]], task_id: str = None
    ) -> torch.Tensor:
        prompt = None
        task_ids = None
        if not self.training:
Review comment: move below for readability.
Reply: Done.
            task_ids = self._task_id_method.infer_task(self._backbone["transformer"](x))
        if self.training:
            prompt = self._backbone["prompt_pool"](self._task_id)
        elif task_ids is not None:
            prompt = torch.stack([self._backbone["prompt_pool"](i) for i in task_ids])
            self.task_ids = task_ids.detach().cpu().numpy()

        features = self._backbone["transformer"](x, prompt)

        # Additional logic for the separate classifiers:
        # a. This forward returns logits directly, and the RenateBenchmarkingModule's
        #    _task_params are now identities. Thus, the overall operation is still the
        #    network's forward pass.
        # b. Additional handling of params is not needed, as the backbone's params will return
        #    all the necessary elements.

        if self.training:
            logits = self._backbone["classifier"][f"{self._task_id}"](features)
        elif task_ids is not None:
            logits = torch.cat(
                [
                    self._backbone["classifier"][f"{t}"](feat.unsqueeze(0))
                    for t, feat in zip(task_ids, features)
                ]
            )
        else:
            logits = self._backbone["classifier"]["0"](features)
Review comment: can we remove the hard-coding of "0"?
Reply: It is the update/task_id variable converted into a string. 0 just implies that we are in the first update.
        return logits

    def update_task_identifier(self, features: torch.Tensor, labels: torch.Tensor) -> None:
        self._task_id_method.update_task_prototypes(features, labels)

    def set_extra_state(self, state: Any, decode=True):
        super().set_extra_state(state, decode)
        # Once this is set (after loading), increase the task id by one.
        self._constructor_arguments["task_id"] = self._task_id + 1

    def features(self, x: torch.Tensor) -> torch.Tensor:
        return self._backbone["transformer"](x)
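To show how the pieces above fit together, here is a hedged end-to-end sketch. The batch shape, label counts, and the assumption that the inherited forward simply dispatches through the monkey-patched backbone are illustrative only and not taken from the PR:

model = SPromptTransformer(num_outputs=10, prompt_size=10, task_id=0)
x = torch.randn(4, 3, 224, 224)  # dummy batch for the default 224x224 ViT

# Training on the current update uses the prompt and classifier of self._task_id.
model.train()
logits = model(x)

# After training, task prototypes are fitted so that evaluation can route each sample
# to the prompts/classifier of the update it most likely came from.
with torch.no_grad():
    feats = model.features(x)
model.update_task_identifier(feats, torch.randint(0, 10, (4,)))

model.eval()
with torch.no_grad():
    logits = model(x)  # per-sample prompt and classifier selected via TaskPrototypes
print(model.task_ids)  # inferred update ids for the batch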
Review comment: since this is about supported algorithms, it should list S-Prompts and give guidance on how to use it with SPeft.
Reply: Done.