
S-Prompts for ViT and Text Transformers #388

Merged: prabhuteja12 merged 101 commits into dev from pt_s_prompts on Dec 4, 2023
Conversation

prabhuteja12 (Contributor):

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

prabhuteja12 marked this pull request as ready for review on December 1, 2023 10:33
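For orientation, a minimal sketch of the training setup this PR introduces; class and attribute names here are illustrative, not the exact ones in the diff. The pretrained backbone stays frozen, and each task (update) trains only its own small prompt plus a per-task classifier head.

import torch.nn as nn

class SPromptsModel(nn.Module):
    """Illustrative only: per-task prompts and heads on top of a frozen backbone."""

    def __init__(self, backbone: nn.Module, embedding_dim: int, prompt_size: int) -> None:
        super().__init__()
        self.backbone = backbone.requires_grad_(False)  # frozen for all tasks
        self.prompts = nn.ModuleDict()      # one prompt per task, keyed by str(task_id)
        self.classifiers = nn.ModuleDict()  # one classification head per task
        self._embedding_dim = embedding_dim
        self._prompt_size = prompt_size

    def add_task(self, task_id: int, num_classes: int) -> None:
        # Only these newly created modules are trainable for the new task.
        self.prompts[str(task_id)] = nn.Embedding(self._prompt_size, self._embedding_dim)
        self.classifiers[str(task_id)] = nn.Linear(self._embedding_dim, num_classes)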
@@ -41,7 +41,10 @@ using Renate (e.g., using :py:func:`~renate.training.training.run_training_job`;
       - A class that implements a Learning to Prompt method for ViTs. The method trains only the input prompts that are sampled from a prompt pool in an input-dependent fashion.
     * - ``"LearningToPromptReplay"``
       - :py:class:`LearningToPromptLearner <renate.updaters.experimental.l2p.LearningToPromptReplayLearner>`
-      - A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER"
+      - A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER".
Contributor:

Since this is about supported algorithms, it should list S-Prompts and give guidance on how to use it with SPeft.

Contributor Author:

Done.

@@ -367,6 +367,8 @@ def __init__(
     def prepare_data(self) -> None:
         """Download DomainNet dataset for given domain."""
         file_name = f"{self.data_id}.zip"
+        # update dataset name:
+        self._dataset_name = self.data_id
Contributor:

Why is this here? `prepare_data` is called only once; therefore, one could set this already as part of the constructor. Should this replace L365?
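For concreteness, a sketch of the suggested alternative; the class name and constructor signature are assumptions, not the exact code in this PR.

class DomainNetDataModule:  # name assumed from the docstring above
    def __init__(self, data_id: str) -> None:
        self.data_id = data_id
        self._dataset_name = data_id  # set once at construction instead of in prepare_data()

    def prepare_data(self) -> None:
        """Download DomainNet dataset for given domain."""
        file_name = f"{self.data_id}.zip"
        # ... download and extract as before ...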

self._M = prompt_size
self._task_id = task_id
self._per_task_classifier = per_task_classifier
logger.warning(f"Task id is {self._task_id}")
Contributor:

remove this line?

Contributor Author:

Done.

self._backbone["transformer"].requires_grad_(False)
self._backbone["prompt_pool"].requires_grad_(True)

# self._backbone["transformer"].transformer._backbone.enable_gradient_checkpointing()
Contributor:

cleanup?

Contributor Author:

Done.
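The two `requires_grad_` calls above are the crux of the update step. A quick sanity check of the split they establish, assuming `backbone` is the `ModuleDict` from the snippet and that, at this point in the update, only the prompt pool should be trainable:

import torch.nn as nn

def assert_only_prompt_pool_trainable(backbone: nn.ModuleDict) -> None:
    # Every parameter under "prompt_pool" must be trainable; everything else frozen.
    for name, param in backbone.named_parameters():
        assert param.requires_grad == name.startswith("prompt_pool"), name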

# self.s_prompts
self._backbone["prompt_pool"].increment_task()

def forward_for_monkey_patching(
Contributor:

prepend _

Contributor Author:

Done.

]
)
else:
    logits = self._backbone["classifier"]["0"](features)
Contributor:

can we remove the hard-coding of "0" somehow? what is defining that name?

Contributor Author:

It is the update/task_id variable converted into a string; "0" just means that we are in the first update. Removing it would just be a cosmetic `first_task_name = "0"`, which doesn't seem to serve any purpose.
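To illustrate the naming scheme described in this reply (the layer type and sizes below are made up, only the string keys mirror the PR):

import torch.nn as nn

classifiers = nn.ModuleDict()
for update_id in range(3):               # one head per update/task
    classifiers[str(update_id)] = nn.Linear(768, 10)

first_update_head = classifiers["0"]     # the hard-coded key from the snippet above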

Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to ``False``, the layer will not learn an additive bias.
Contributor:

missing documentation of args

Contributor Author:

Done.

}
super().__init__(all_layers)

def increment_task(self) -> None:
Contributor:

is there a nice way to reuse this function in the constructor to populate all_layers?

Contributor Author:

Done.
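One way the constructor can reuse the same code path, as suggested above; this is a sketch, and the real layer types and names in this PR may differ:

import torch.nn as nn

class PromptPool(nn.ModuleDict):
    def __init__(self, prompt_size: int, embedding_dim: int) -> None:
        super().__init__()
        self._prompt_size = prompt_size
        self._embedding_dim = embedding_dim
        self._task_id = -1
        self.increment_task()  # creates the first task's layer instead of a hand-built dict

    def increment_task(self) -> None:
        self._task_id += 1
        self[str(self._task_id)] = nn.Embedding(self._prompt_size, self._embedding_dim)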

src/renate/updaters/experimental/speft.py (resolved)
self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
) -> None:
"""Explicitly setting grads to None instead of zero."""
optimizer.zero_grad(set_to_none=True)
Contributor:

why?
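For context on the question, the following reflects PyTorch's documented behavior rather than anything specific to this PR: `zero_grad(set_to_none=True)` drops gradient tensors instead of filling them with zeros, which skips a per-parameter memset and lets the optimizer step ignore parameters that received no gradient at all; that is relevant here because the frozen backbone never produces gradients.

import torch
from torch.optim import SGD

w = torch.nn.Parameter(torch.randn(4))
opt = SGD([w], lr=0.1)
w.sum().backward()
opt.zero_grad(set_to_none=True)
assert w.grad is None  # with set_to_none=False this would be a zeros tensor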

prabhuteja12 merged commit c54887c into dev on Dec 4, 2023
18 checks passed
prabhuteja12 deleted the pt_s_prompts branch on December 4, 2023 17:08