S-Prompts for ViT and Text Transformers #388
Conversation
@@ -41,7 +41,10 @@ using Renate (e.g., using :py:func:`~renate.training.training.run_training_job`;
   - A class that implements a Learning to Prompt method for ViTs. The method trains only the input prompts that are sampled from a prompt pool in an input-dependent fashion.
 * - ``"LearningToPromptReplay"``
   - :py:class:`LearningToPromptLearner <renate.updaters.experimental.l2p.LearningToPromptReplayLearner>`
-  - A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER"
+  - A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER".
Since this is about supported algorithms, it should list S-Prompts and give guidance on how to use it with SPeft.
Done.
@@ -367,6 +367,8 @@ def __init__(
     def prepare_data(self) -> None:
         """Download DomainNet dataset for given domain."""
         file_name = f"{self.data_id}.zip"
+        # update dataset name:
+        self._dataset_name = self.data_id
Why is this here? `prepare_data` is called only once, so one could set this already as part of the constructor. Should this replace L365?
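For illustration, a minimal sketch of the suggested alternative; the class skeleton and attribute names here are assumptions, not the actual Renate code:

```python
# Hypothetical sketch: set the dataset name once in the constructor instead of
# re-assigning it on every prepare_data() call.
class DomainNetDataModule:
    def __init__(self, data_path: str, data_id: str) -> None:
        self._data_path = data_path
        self.data_id = data_id
        # The value never changes after construction, so it can live here.
        self._dataset_name = data_id

    def prepare_data(self) -> None:
        """Download DomainNet dataset for given domain."""
        file_name = f"{self.data_id}.zip"
        # ... download and extract file_name into self._data_path ...
```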
        self._M = prompt_size
        self._task_id = task_id
        self._per_task_classifier = per_task_classifier
        logger.warning(f"Task id is {self._task_id}")
remove this line?
Done.
        self._backbone["transformer"].requires_grad_(False)
        self._backbone["prompt_pool"].requires_grad_(True)

        # self._backbone["transformer"].transformer._backbone.enable_gradient_checkpointing()
cleanup?
Done.
        # self.s_prompts
        self._backbone["prompt_pool"].increment_task()

    def forward_for_monkey_patching(
prepend `_`
Done
                ]
            )
        else:
            logits = self._backbone["classifier"]["0"](features)
Can we remove the hard-coding of `"0"` somehow? What is defining that name?
It is the update/task_id variable converted into a string; 0 just means that we are in the first update. Removing it would only amount to a cosmetic `first_task_name = "0"`, which doesn't seem to serve any purpose.
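For readers following along, a minimal sketch of the pattern under discussion, assuming the per-task heads live in a `torch.nn.ModuleDict` keyed by the task id converted to a string (the sizes and names below are illustrative, not the actual Renate code):

```python
import torch

# Hypothetical per-task classifier heads, keyed by str(task_id).
classifier = torch.nn.ModuleDict(
    {
        "0": torch.nn.Linear(768, 10),  # head created during the first update
        "1": torch.nn.Linear(768, 10),  # head created during the second update
    }
)

features = torch.randn(4, 768)
# In the branch above there is only the first head, hence the hard-coded key "0";
# naming it would just be first_task_name = str(0) with no behavioral change.
logits = classifier["0"](features)
```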
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            bias: If set to ``False``, the layer will not learn an additive bias.
missing documentation of args
Done.
        }
        super().__init__(all_layers)

    def increment_task(self) -> None:
Is there a nice way to reuse this function in the constructor to populate `all_layers`?
Done.
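One possible shape of that refactor, sketched under the assumption that the pool is a `torch.nn.ModuleDict` subclass with one prompt layer per task; the class and attribute names are illustrative, not the actual implementation:

```python
import torch


class PromptPool(torch.nn.ModuleDict):
    """Hypothetical prompt pool that grows by one prompt layer per task."""

    def __init__(self, prompt_size: int, embedding_dim: int, num_initial_tasks: int = 1) -> None:
        super().__init__()
        self._prompt_size = prompt_size
        self._embedding_dim = embedding_dim
        self._task_id = -1
        # Reuse increment_task() here instead of building all_layers by hand,
        # so the layer-creation logic lives in exactly one place.
        for _ in range(num_initial_tasks):
            self.increment_task()

    def increment_task(self) -> None:
        """Create and register the prompt layer for the next task."""
        self._task_id += 1
        self[str(self._task_id)] = torch.nn.Embedding(self._prompt_size, self._embedding_dim)
```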
        self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
    ) -> None:
        """Explicitly setting grads to None instead of zero."""
        optimizer.zero_grad(set_to_none=True)
why?
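For context: `zero_grad(set_to_none=True)` is documented PyTorch behavior that drops the `.grad` tensors instead of zero-filling them, which skips a memset and lets parameters that receive no gradient be ignored by the optimizer step. A minimal sketch of the hook above, assuming it lives in a PyTorch Lightning module (the host class name is hypothetical):

```python
from torch.optim import Optimizer


class SPromptsLearner:  # hypothetical host class; in Lightning this would be a LightningModule
    def optimizer_zero_grad(
        self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
    ) -> None:
        """Explicitly setting grads to None instead of zero.

        With set_to_none=True the optimizer frees the .grad tensors rather than
        filling them with zeros; the next backward pass allocates fresh ones.
        """
        optimizer.zero_grad(set_to_none=True)
```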
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.