
S-Prompts for ViT and Text Transformers #388

Merged: 101 commits, merged on Dec 4, 2023.
Changes shown are from 100 of the 101 commits.

Commits
7bdcab0
running l2p
prabhuteja12 Jun 22, 2023
962fe19
Merge remote-tracking branch 'origin/dev' into pt_l2p
prabhuteja12 Jun 22, 2023
4be7d52
reorganized prompt pool code
prabhuteja12 Jun 25, 2023
1082d74
minor changes to vit
prabhuteja12 Jun 26, 2023
c77bfe7
Merge remote-tracking branch 'origin/dev' into pt_l2p
prabhuteja12 Jun 27, 2023
bd6194f
l2p working version
prabhuteja12 Jun 30, 2023
d3d17fd
handling extra state
prabhuteja12 Jul 3, 2023
6222576
changing constructors args
prabhuteja12 Jul 3, 2023
5c1b225
fixes to extra state
prabhuteja12 Jul 11, 2023
482a337
prompt reimplementation
prabhuteja12 Jul 20, 2023
5b4564b
dev merge
prabhuteja12 Aug 3, 2023
500f093
first commit allowing for a class mask
prabhuteja12 Aug 8, 2023
b97510c
avalanche masking unsued classes
prabhuteja12 Aug 8, 2023
4949f68
Merge remote-tracking branch 'origin/pt_class_mask_for_CIL' into pt_l2p
prabhuteja12 Aug 8, 2023
1dd0784
Removing debug statements
prabhuteja12 Aug 8, 2023
8b40951
Merge remote-tracking branch 'origin/pt_class_mask_for_CIL' into pt_l2p
prabhuteja12 Aug 8, 2023
f6d0042
device placement fix
prabhuteja12 Aug 8, 2023
ef04cc8
addressing comments
prabhuteja12 Aug 9, 2023
7689574
Merge remote-tracking branch 'origin/pt_class_mask_for_CIL' into pt_l2p
prabhuteja12 Aug 9, 2023
2dc0745
removing legacy tensor constructors
prabhuteja12 Aug 9, 2023
d5da9a3
offline ER class mask
prabhuteja12 Aug 9, 2023
8911f2f
Merge remote-tracking branch 'origin/pt_class_mask_for_CIL' into pt_l2p
prabhuteja12 Aug 9, 2023
595f5d9
working l2p code
prabhuteja12 Aug 10, 2023
9b019bf
flake8
prabhuteja12 Aug 10, 2023
055b626
small changes to tests
prabhuteja12 Aug 11, 2023
f295c2d
code comments in l2p
prabhuteja12 Aug 11, 2023
e1dd41d
ViT outputs pooled or full feats flag
prabhuteja12 Aug 11, 2023
0a2239d
unified mask, abstractions
prabhuteja12 Aug 14, 2023
3142c50
typing fixes
prabhuteja12 Aug 14, 2023
e0d8a12
Merge remote-tracking branch 'origin/pt_class_mask_for_CIL' into pt_l2p
prabhuteja12 Aug 14, 2023
b685d98
avalanche bug fix
prabhuteja12 Aug 14, 2023
cc33efc
Merge remote-tracking branch 'origin/pt_class_mask_for_CIL' into pt_l2p
prabhuteja12 Aug 14, 2023
2baccff
linting + small comments
prabhuteja12 Aug 14, 2023
06411ea
documentation addition
prabhuteja12 Aug 14, 2023
18ab312
Merge remote-tracking branch 'origin/dev' into pt_class_mask_for_CIL
prabhuteja12 Aug 14, 2023
f0e9fbe
Merge remote-tracking branch 'origin/dev' into pt_l2p
prabhuteja12 Aug 15, 2023
a47005e
bug fix for tests
prabhuteja12 Aug 15, 2023
d1eafdf
Merge branch 'pt_class_mask_for_CIL' of https://github.com/awslabs/Re…
prabhuteja12 Aug 15, 2023
c5d316b
change default Vit
prabhuteja12 Aug 15, 2023
e4bf81f
addressing comments
prabhuteja12 Aug 15, 2023
6a83a7c
Merge remote-tracking branch 'origin/pt_class_mask_for_CIL' into pt_l2p
prabhuteja12 Aug 15, 2023
4cafcff
cleanup
prabhuteja12 Aug 15, 2023
1b8b28c
error checking in experiment config
prabhuteja12 Aug 15, 2023
7491921
Merge remote-tracking branch 'origin/dev' into pt_l2p
prabhuteja12 Aug 15, 2023
dda2a12
bug fix
prabhuteja12 Aug 15, 2023
b1bfd8c
transforms for cifar l2p
prabhuteja12 Aug 15, 2023
e57d5d6
flake8
prabhuteja12 Aug 15, 2023
c71db9a
bug fixes
prabhuteja12 Aug 15, 2023
796202e
disabling pooler for cifar vit
prabhuteja12 Aug 15, 2023
b731148
ViT parameter count adjustment
prabhuteja12 Aug 15, 2023
87dd8bb
infer embedding dim from model config
prabhuteja12 Aug 16, 2023
5dadf60
Merge remote-tracking branch 'origin/dev' into pt_l2p
prabhuteja12 Aug 16, 2023
1a0ee30
renaming input model type argument
prabhuteja12 Aug 17, 2023
71c8f2d
l2p for text transformers with a unified interface
prabhuteja12 Aug 17, 2023
d1a67c1
test for l2p learners
prabhuteja12 Aug 18, 2023
07605d1
removing comment
prabhuteja12 Aug 18, 2023
c667119
transforms for ViT
prabhuteja12 Aug 18, 2023
4a52a07
Merge remote-tracking branch 'origin/dev' into pt_l2p
prabhuteja12 Aug 18, 2023
9b938f6
reconciling memory batch size
prabhuteja12 Aug 18, 2023
8ac6f49
adding batch memory frac in l2p
prabhuteja12 Aug 20, 2023
bff8202
basic implementation of sprompt
prabhuteja12 Aug 21, 2023
03d046b
fixing augmentations
prabhuteja12 Aug 21, 2023
b2c1013
Merge branch 'pt_l2p' into pt_s_prompts
prabhuteja12 Aug 21, 2023
c698187
fixed sizes of centroid buffer
prabhuteja12 Aug 22, 2023
14ea3bd
Merge remote-tracking branch 'origin/dev' into pt_s_prompts
prabhuteja12 Aug 22, 2023
29082cb
Merge remote-tracking branch 'origin/dev' into pt_s_prompts
prabhuteja12 Aug 25, 2023
2727e3d
functioning sprompts
prabhuteja12 Aug 28, 2023
4878ade
normalization of features
prabhuteja12 Sep 7, 2023
2667aa0
abstracting prompting transformer
prabhuteja12 Sep 10, 2023
858c058
Merge branch 'pt_prompt_transformer' into pt_s_prompts
prabhuteja12 Sep 10, 2023
4766c50
abstractions for conlora
prabhuteja12 Sep 13, 2023
cf2f7f2
minor changes
prabhuteja12 Sep 19, 2023
39ce029
separated linear and task id-er
prabhuteja12 Sep 20, 2023
028ec7f
parsing argument changes
prabhuteja12 Sep 20, 2023
b6ff6e9
numpy concat to torch
prabhuteja12 Sep 20, 2023
37393f0
Merge remote-tracking branch 'origin/dev' into pt_s_prompts
prabhuteja12 Sep 22, 2023
8e78dca
adding tests and marking slow ones
prabhuteja12 Sep 22, 2023
cfd47c7
removed import
prabhuteja12 Sep 22, 2023
f05c383
adding requires_grad to new prompt
prabhuteja12 Sep 22, 2023
03fc742
multiple classifiers
prabhuteja12 Sep 24, 2023
bb3df13
bug fix in shared linear
prabhuteja12 Sep 25, 2023
aca2873
removing nested tensors
prabhuteja12 Sep 25, 2023
47345f0
flake8
prabhuteja12 Sep 25, 2023
93bd6b0
fixing missing prompt
prabhuteja12 Sep 25, 2023
e123788
changing init method for prompts to kaiming uniform
prabhuteja12 Sep 26, 2023
34e9f49
Merge remote-tracking branch 'origin/dev' into pt_s_prompts
prabhuteja12 Oct 4, 2023
39dbab7
renaming modules
prabhuteja12 Oct 4, 2023
cc3f622
renaming in parsing_functions
prabhuteja12 Oct 4, 2023
088945b
bug fix
prabhuteja12 Oct 4, 2023
b336e41
minor simplifications to model
prabhuteja12 Nov 29, 2023
7da478c
Merge remote-tracking branch 'origin/dev' into pt_s_prompts
prabhuteja12 Nov 29, 2023
c64fe4a
bug fix in parsing function and optimizer zero grad
prabhuteja12 Dec 1, 2023
4ccde7b
doc strings
prabhuteja12 Dec 1, 2023
0f963e2
doc strings
prabhuteja12 Dec 1, 2023
57f0465
Documentation update
prabhuteja12 Dec 1, 2023
0495640
modifed docs and docstrings
prabhuteja12 Dec 1, 2023
161da77
changing default value of per task classifier
prabhuteja12 Dec 1, 2023
41cbe1c
fixing updater name
prabhuteja12 Dec 4, 2023
16c185d
model cleanup
prabhuteja12 Dec 4, 2023
d462447
flake8
prabhuteja12 Dec 4, 2023
446b8b2
Addressing comments
prabhuteja12 Dec 4, 2023
7 changes: 7 additions & 0 deletions doc/benchmarking/renate_benchmarks.rst
@@ -105,6 +105,13 @@ The full list of models and model names including a short description is provide
* ``pool_selection_size``: Number of prompts to select for each input from the pool.
* ``prompt_size``: Number of input tokens each prompt is equivalent to.
* ``prompt_key_dim``: Dimensionality of the features used for prompt matching.
* - `~renate.benchmark.models.spromptmodel.SPromptTransformer`
- `S-Prompt Transformer <https://arxiv.org/abs/2207.12819>`_.
- * ``pretrained_model_name_or_path``: Hugging Face `transformer ID <https://huggingface.co/models>`__.
* ``num_outputs``: The number of classes.
* ``prompt_size``: Number of input tokens each prompt is equivalent to.
* ``clusters_per_task``: Number of clusters for K-Means in task identification.
* ``per_task_classifier``: Flag to use an individual classifier head per task instead of a single shared one.

.. _benchmarking-renate-benchmarks-datasets:

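A minimal usage sketch for the new benchmark model documented above (not part of the diff; the values are illustrative and the default model ID is taken from the constructor added in this PR):

# Illustrative sketch: constructing the new benchmark model with the arguments
# documented in renate_benchmarks.rst. All values are examples only.
from renate.benchmark.models.spromptmodel import SPromptTransformer

model = SPromptTransformer(
    pretrained_model_name_or_path="google/vit-base-patch16-224",  # Hugging Face transformer ID
    num_outputs=10,             # number of classes
    prompt_size=10,             # number of input tokens each prompt is equivalent to
    clusters_per_task=5,        # number of K-Means clusters used for task identification
    per_task_classifier=True,   # use an individual classifier head per task
)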
5 changes: 4 additions & 1 deletion doc/getting_started/supported_algorithms.rst
@@ -41,7 +41,10 @@ using Renate (e.g., using :py:func:`~renate.training.training.run_training_job`;
- A class that implements a Learning to Prompt method for ViTs. The method trains only the input prompts, which are sampled from a prompt pool in an input-dependent fashion.
* - ``"LearningToPromptReplay"``
- :py:class:`LearningToPromptLearner <renate.updaters.experimental.l2p.LearningToPromptReplayLearner>`
- A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER"
- A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER".
Contributor: since this is about supported algorithms, it should list S-Prompts and give guidance how to use it with SPeft.

Contributor Author: Done.

* - ``"SPeft"``
- :py:class:`SPeft <renate.updaters.experimental.speft.SPeftLearner>`
- A class that (currently) implements the S-Prompts method for memory-free continual learning.
* - ``"Avalanche-ER"``
- :py:class:`AvalancheReplayLearner <renate.updaters.avalanche.learner.AvalancheReplayLearner>`
- A wrapper which gives access to Experience Replay as implemented in the Avalanche library. This method is the equivalent to our Offline-ER.
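As rough guidance for the new ``"SPeft"`` entry above, the sketch below launches a training job with this updater. The updater name, the restriction to SPrompt* models, and the model keys (``prompt_size``, ``clusters_per_task``, ``per_task_classifier``) come from this PR; the remaining ``run_training_job`` arguments are assumptions and may differ from the actual API.

# Hedged sketch: running an update with the new SPeft updater. Argument names
# outside of config_space are assumptions, not confirmed by this PR.
from renate.training.training import run_training_job

config_space = {
    "model_name": "SPromptTransformer",  # SPeft is restricted to SPrompt* models
    "prompt_size": 10,
    "clusters_per_task": 5,
    "per_task_classifier": True,
}

run_training_job(
    config_space=config_space,
    updater="SPeft",                     # maps to SPeftModelUpdater in parsing_functions.py
    mode="max",
    metric="val_accuracy",
    config_file="renate_config.py",      # assumed path to the user's config module
)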
2 changes: 2 additions & 0 deletions src/renate/benchmark/datasets/vision_datasets.py
@@ -367,6 +367,8 @@ def __init__(
def prepare_data(self) -> None:
"""Download DomainNet dataset for given domain."""
file_name = f"{self.data_id}.zip"
# update dataset name:
self._dataset_name = self.data_id
Contributor: why is this here? prepare_data is called only once. therefore, one could set this already as part of the constructor. should this replace L365?

url = "http://csr.bu.edu/ftp/visda/2019/multi-source/"
if self.data_id in ["clipart", "painting"]:
url = os.path.join(url, "groundtruth")
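To illustrate the reviewer's point above, a hypothetical rearrangement is sketched below; the class and constructor are invented for the example and do not reproduce the actual DomainNet data module.

# Hypothetical sketch of the reviewer's suggestion: set the dataset name once in
# the constructor rather than re-assigning it on every prepare_data() call.
class DomainNetLikeDataModule:
    def __init__(self, data_path: str, data_id: str) -> None:
        self.data_path = data_path
        self.data_id = data_id
        self._dataset_name = data_id  # would replace the assignment inside prepare_data()

    def prepare_data(self) -> None:
        # download logic would remain unchanged and no longer needs to set the name
        ...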
16 changes: 16 additions & 0 deletions src/renate/benchmark/experiment_config.py
@@ -52,6 +52,8 @@
from renate.models import RenateModule
from renate.models.prediction_strategies import ICaRLClassificationStrategy

from renate.benchmark.models.spromptmodel import SPromptTransformer

models = {
"MultiLayerPerceptron": MultiLayerPerceptron,
"ResNet18CIFAR": ResNet18CIFAR,
@@ -68,6 +70,7 @@
"VisionTransformerH14": VisionTransformerH14,
"HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer,
"LearningToPromptTransformer": LearningToPromptTransformer,
"SPromptTransformer": SPromptTransformer,
}


Expand All @@ -81,6 +84,9 @@ def model_fn(
hidden_size: Optional[Tuple[int]] = None,
dataset_name: Optional[str] = None,
pretrained_model_name_or_path: Optional[str] = None,
prompt_size: int = 10,
clusters_per_task: int = 5,
per_task_classifier: bool = True,
) -> RenateModule:
"""Returns a model instance."""
if model_name not in models:
@@ -110,6 +116,16 @@ def model_fn(
f"LearningToPromptTransformer, but model name specified is {model_name}."
)
model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path
elif (updater is not None) and ("SPeft" in updater):
if not model_name.startswith("SPrompt"):
raise ValueError(
"SPrompt model updater is designed to work only with "
f"SPromptTransformer, but model name specified is {model_name}."
)
model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path
model_kwargs["prompt_size"] = prompt_size
model_kwargs["clusters_per_task"] = clusters_per_task
model_kwargs["per_task_classifier"] = per_task_classifier
if model_state_url is None:
model = model_class(**model_kwargs)
else:
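A sketch of how the new dispatch above might be exercised; only parameters visible in this hunk are passed, and the call is illustrative rather than a verified invocation.

# Sketch of the SPeft branch in model_fn (illustrative; other parameters keep
# their defaults). A mismatched model name would raise the ValueError added above.
from renate.benchmark.experiment_config import model_fn

model = model_fn(
    model_name="SPromptTransformer",
    updater="SPeft",
    pretrained_model_name_or_path="google/vit-base-patch16-224",
    prompt_size=10,
    clusters_per_task=5,
    per_task_classifier=True,
)

# model_fn(model_name="VisionTransformerB16", updater="SPeft")  # would raise ValueError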
5 changes: 4 additions & 1 deletion src/renate/benchmark/models/__init__.py
@@ -9,7 +9,8 @@
ResNet50,
ResNet50CIFAR,
)
from renate.benchmark.models.l2p import LearningToPromptTransformer
from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptedTransformer
from renate.benchmark.models.spromptmodel import SPromptTransformer
from renate.benchmark.models.vision_transformer import (
VisionTransformerB16,
VisionTransformerB32,
@@ -28,6 +29,8 @@
"ResNet50",
"ResNet50CIFAR",
"LearningToPromptTransformer",
"PromptedTransformer",
"SPromptTransformer",
"VisionTransformerB16",
"VisionTransformerB32",
"VisionTransformerCIFAR",
225 changes: 225 additions & 0 deletions src/renate/benchmark/models/spromptmodel.py
@@ -0,0 +1,225 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn

from renate.models.layers.shared_linear import SharedMultipleLinear
from renate.models.prediction_strategies import PredictionStrategy
from renate.models.task_identification_strategies import TaskPrototypes

from .l2p import PromptedTransformer
from .base import RenateBenchmarkingModule

logger = logging.getLogger(__name__)


class PromptPool(nn.Module):
"""Implements a pool of prompts to be used in for S-Prompts.

Args:
prompt_size: Equivalent to the number of input tokens used per update. Defaults to 10.
embedding_size: Hidden size of the transformer used. Defaults to 768.
current_update_id: Current update id. Used to initialize the number of prompts. Defaults to 0.
"""

def __init__(
self, prompt_size: int = 10, embedding_size: int = 768, current_update_id: int = 0
) -> None:
super().__init__()
self._M = prompt_size
self._embedding_size = embedding_size
self._curr_task = current_update_id

self._pool = nn.ParameterDict()
for id in range(self._curr_task):
# This always needs to be initialized since torch's state dict only restores an
# existing Parameter.
self._pool[f"{id}"] = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty((self._M, self._embedding_size)), a=math.sqrt(5)
)
)

self._pool.requires_grad_(True)

def forward(self, id: int) -> torch.nn.Parameter:
return self._pool[f"{id}"]

def get_params(self, id: int) -> List[torch.nn.Parameter]:
return [self._pool[f"{id}"]]

def increment_task(self) -> None:
self._pool[f"{len(self._pool)}"] = nn.Parameter(
torch.empty((self._M, self._embedding_size)).uniform_(-1, 1)
)
self._pool.requires_grad_(True)


class SPromptTransformer(RenateBenchmarkingModule):
"""Implements Transformer Model for S-Prompts as described in Wang, Yabin, et.al ."S-prompts
learning with pre-trained transformers: An occam’s razor for domain incremental learning."
Advances in Neural Information Processing Systems 35 (2022): 5682-5695.

Args:
pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
to use.
image_size: Image size. Used if `pretrained_model_name_or_path` is not set.
patch_size: Patch size to be extracted. Used if `pretrained_model_name_or_path` is not set.
num_layers: Number of transformer layers. Used only if `pretrained_model_name_or_path` is not
set.
num_heads: Number of heads in multi-head self-attention. Used only if
`pretrained_model_name_or_path` is not set.
hidden_dim: Hidden dimension of the transformer. Used only if `pretrained_model_name_or_path`
is not set.
mlp_dim: Dimension of the MLP block in the transformer. Used only if
`pretrained_model_name_or_path` is not set.
dropout: Dropout probability. Used only if `pretrained_model_name_or_path` is not set.
attention_dropout: Dropout probability in the attention layers. Used only if
`pretrained_model_name_or_path` is not set.
num_outputs: Number of output classes. Defaults to 10.
prediction_strategy: Continual learning strategies may alter the prediction at train or test
time. Defaults to None.
add_icarl_class_means: If ``True``, additional parameters used only by the
``ICaRLModelUpdater`` are added. Only required when using that updater.
prompt_size: Equivalent to the number of input tokens used per update. Defaults to 10.
task_id: Internal variable used to increment update id. Shouldn't be set by user.
Defaults to 0.
clusters_per_task: Number of clusters in the k-means used for task identification. Defaults to 5.
per_task_classifier: If ``True``, use an individual classifier head per task; otherwise share
a common head across all tasks. Defaults to False.
"""

def __init__(
self,
pretrained_model_name_or_path="google/vit-base-patch16-224",
image_size: int = 32,
patch_size: int = 4,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.1,
attention_dropout: float = 0.1,
num_outputs: int = 10,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
prompt_size: int = 10,
task_id: int = 0,
clusters_per_task: int = 5,
per_task_classifier: bool = False,
):
transformer = PromptedTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
super().__init__(
embedding_size=transformer.transformer._embedding_size,
num_outputs=num_outputs,
constructor_arguments=dict(
**transformer.transformer._constructor_arguments,
prompt_size=prompt_size,
task_id=task_id,
clusters_per_task=clusters_per_task,
per_task_classifier=per_task_classifier,
),
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
self._M = prompt_size
self._task_id = task_id
self._per_task_classifier = per_task_classifier
logger.warning(f"Task id is {self._task_id}")
Contributor: remove this line?

Contributor Author: Done.


prompt_pool = PromptPool(
prompt_size=self._M,
embedding_size=self._embedding_size,
current_update_id=self._task_id,
)

self._backbone = nn.ModuleDict({"transformer": transformer, "prompt_pool": prompt_pool})
self._task_id_method = TaskPrototypes(
task_id=task_id,
clusters_per_task=clusters_per_task,
embedding_size=self._embedding_size,
)
self._backbone["transformer"].requires_grad_(False)
self._backbone["prompt_pool"].requires_grad_(True)

# self._backbone["transformer"].transformer._backbone.enable_gradient_checkpointing()
Contributor: cleanup?

Contributor Author: Done.

# With this, the per-task params become identities and only this classifier is used.
self._backbone["classifier"] = SharedMultipleLinear(
self._embedding_size,
self._num_outputs,
share_parameters=not self._per_task_classifier,
num_updates=self._task_id + 1,
)

self._tasks_params = nn.ModuleDict(
{k: nn.Identity() for k, _ in self._tasks_params.items()}
)

self._backbone.forward = self.forward_for_monkey_patching
self.task_ids = None

def increment_task(self) -> None:
# This cannot be part of add_task_params because the super().__init__ call invokes
# add_task_params, and we would then be trying to add parameters to the non-existent
# self.s_prompts
self._backbone["prompt_pool"].increment_task()

def forward_for_monkey_patching(
Contributor: prepend _

Contributor Author: Done

self, x: Union[torch.Tensor, Dict[str, Any]], task_id: str = None
) -> torch.Tensor:
prompt = None
task_ids = None
if not self.training:
Contributor: move below for readability:

if self.training:
  ....
else:
  task_ids = ...
  if task_ids is not None:
    ...

Contributor Author: Done.

task_ids = self._task_id_method.infer_task(self._backbone["transformer"](x))
if self.training:
prompt = self._backbone["prompt_pool"](self._task_id)
elif task_ids is not None:
prompt = torch.stack([self._backbone["prompt_pool"](i) for i in task_ids])
self.task_ids = task_ids.detach().cpu().numpy()

features = self._backbone["transformer"](x, prompt)

# additional logic for separate classifiers
# a. This forward returns logits directly, and the RenateBenchmarkingModule's _task_params
# now are identities. Thus, the overall operation is still the network forward pass.
# b. Additional handling of params is not needed as backbone's params will return all the
# necessary elements.

if self.training:
logits = self._backbone["classifier"][f"{self._task_id}"](features)
elif task_ids is not None:
logits = torch.cat(
[
self._backbone["classifier"][f"{t}"](feat.unsqueeze(0))
for t, feat in zip(task_ids, features)
]
)
else:
logits = self._backbone["classifier"]["0"](features)
Contributor: can we remove the hard-coding of "0" somehow? what is defining that name?

Contributor Author: It is the update/task_id variable converted into a string. 0 just implies that we are in the first update. Removing it will just be a cosmetic first_task_name = "0", which doesn't seem to serve any purpose.


return logits

def update_task_identifier(self, features: torch.Tensor, labels: torch.Tensor) -> None:
self._task_id_method.update_task_prototypes(features, labels)

def set_extra_state(self, state: Any, decode=True):
super().set_extra_state(state, decode)
# Once this is set (after loading), increase the task id by one.
self._constructor_arguments["task_id"] = self._task_id + 1

def features(self, x: torch.Tensor) -> torch.Tensor:
return self._backbone["transformer"](x)
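To tie the pieces of this file together: at training time the forward pass uses the prompt and classifier head of the current update, while at evaluation time TaskPrototypes infers a task id per input and selects the matching prompt and head. The sketch below exercises only the PromptPool in isolation, using the defaults from this file.

# Hedged sketch: PromptPool behaviour on its own (illustrative, not a test from the PR).
from renate.benchmark.models.spromptmodel import PromptPool

pool = PromptPool(prompt_size=10, embedding_size=768, current_update_id=1)

# One prompt exists for update 0; it is a trainable (prompt_size, embedding_size) parameter.
prompt_0 = pool(0)
assert prompt_0.shape == (10, 768) and prompt_0.requires_grad

# Starting a new update adds a freshly initialized prompt, mirroring what
# SPromptTransformer.increment_task() triggers between updates.
pool.increment_task()
assert pool(1).shape == (10, 768)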
8 changes: 8 additions & 0 deletions src/renate/cli/parsing_functions.py
@@ -28,6 +28,7 @@
)
from renate.updaters.experimental.offline_er import OfflineExperienceReplayModelUpdater
from renate.updaters.experimental.repeated_distill import RepeatedDistillationModelUpdater
from renate.updaters.experimental.speft import SPeftModelUpdater
from renate.updaters.model_updater import ModelUpdater

REQUIRED_ARGS_GROUP = "Required Arguments"
@@ -70,6 +71,8 @@ def get_updater_and_learner_kwargs(
elif args.updater == "LearningToPromptReplay":
learner_args = learner_args + ["prompt_sim_loss_weight", "memory_size", "memory_batch_size"]
updater_class = LearningToPromptReplayModelUpdater
elif args.updater == "SPeft":
updater_class = SPeftModelUpdater
elif args.updater == "DER":
learner_args = base_er_args + ["alpha", "beta"]
updater_class = DarkExperienceReplayModelUpdater
@@ -488,6 +491,10 @@ def _add_l2preplay_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
_add_offline_er_arguments(arguments)


def _add_speft_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
pass


def _add_gdumb_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
"""A helper function that adds GDumb arguments."""
_add_replay_learner_arguments(arguments)
@@ -973,6 +980,7 @@ def get_scheduler_kwargs(
"ER": _add_experience_replay_arguments,
"LearningToPrompt": _add_l2p_arguments,
"LearningToPromptReplay": _add_l2preplay_arguments,
"SPeft": _add_speft_arguments,
"DER": _add_dark_experience_replay_arguments,
"POD-ER": _add_pod_experience_replay_arguments,
"CLS-ER": _add_cls_experience_replay_arguments,
3 changes: 3 additions & 0 deletions src/renate/defaults.py
@@ -112,6 +112,9 @@
# L2p
PROMPT_SIM_LOSS_WEIGHT = 0.5

# S-prompt
CLUSTERS_PER_TASK = 5


def scheduler(config_space: Dict[str, Any], mode: str, metric: str):
return FIFOScheduler(