Abstracting prompting transformer for use in L2P and S-Prompt #420
@@ -108,6 +108,123 @@ def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTen
        return selected_prompts, loss_value


class PromptedTransformer(nn.Module):
    """This generic module is the basic prompted transformer. It takes in a model string and creates
    the appropriate transformer. If not prompted, it returns features, and if prompted, it returns
    the full feature using those prompts and the input image/text.

    Args:
        pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
            to use.
        num_outputs: Size of the output.
        prediction_strategy: Continual learning strategies may alter the prediction at train or test
            time.
        add_icarl_class_means: If ``True``, additional parameters used only by the
            ``ICaRLModelUpdater`` are added. Only required when using that updater.
    """

    def __init__(
        self,
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        image_size: int = 32,
        patch_size: int = 4,
        num_layers: int = 12,
        num_heads: int = 12,
        hidden_dim: int = 768,
        mlp_dim: int = 3072,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        num_outputs: int = 10,
        prediction_strategy: Optional[PredictionStrategy] = None,
        add_icarl_class_means: bool = True,
    ) -> None:
        super().__init__()
        if "vit" in pretrained_model_name_or_path:
            self.transformer = VisionTransformer(
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                image_size=image_size,
                patch_size=patch_size,
                num_layers=num_layers,
                num_heads=num_heads,
                hidden_dim=hidden_dim,
                mlp_dim=mlp_dim,
                dropout=dropout,
                attention_dropout=attention_dropout,
                num_outputs=num_outputs,
                prediction_strategy=prediction_strategy,
                add_icarl_class_means=add_icarl_class_means,
            )
            self._is_text_transformer = False
        else:
            self.transformer = HuggingFaceSequenceClassificationTransformer(
                pretrained_model_name_or_path=pretrained_model_name_or_path,
                num_outputs=num_outputs,
                prediction_strategy=prediction_strategy,
                add_icarl_class_means=add_icarl_class_means,
            )
            for named_param, value in self.transformer.named_parameters():
                if value.shape[0] == self.transformer._backbone.config.vocab_size:
                    self.word_embeddings = self.transformer.get_submodule(
                        named_param.replace(".weight", "")
                    )
                    break

            self._is_text_transformer = True

        self.transformer._tasks_params.clear()
        self.transformer.eval()
        for p in self.transformer.parameters():
            p.requires_grad_(False)

    def forward(
        self, x: torch.Tensor, prompt: Optional[torch.Tensor] = None, cls_feat: bool = True
    ) -> torch.Tensor:
        """
        Args:
            x: Input torch tensor.
            prompt: Prompt tensor. Defaults to None.
            cls_feat: Whether to extract [CLS] token or to return full feature tensor.
                Ignored for text transformer. Defaults to True.
        """
        if prompt is None:
            return (
                self.transformer.get_features(x)
                if self._is_text_transformer
                else self.transformer.get_features(x, cls_feat=cls_feat)
            )
        # Text transformers don't support cls_feat.
        else:
            if self._is_text_transformer:
Review comment: we can flatten the if/else (if/else) part to if, elif, else.
Reply: Fixed in d8ef1c8.
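A minimal sketch of the flattened structure the comment suggests; the helper names `_prompted_text_features` and `_prompted_vision_features` are hypothetical and do not exist in this PR:

```python
# Hypothetical flattening of forward(); the two helpers are illustrative names only.
def forward(self, x, prompt=None, cls_feat=True):
    if prompt is None:
        return (
            self.transformer.get_features(x)
            if self._is_text_transformer
            else self.transformer.get_features(x, cls_feat=cls_feat)
        )
    elif self._is_text_transformer:
        # Prepend the prompt to the word embeddings and encode.
        return self._prompted_text_features(x, prompt)
    else:
        # Append the prompt to the patch embeddings and encode.
        return self._prompted_vision_features(x, prompt, cls_feat)
```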
                # The implicit assumption here is that x for text transformers is the input_ids.
                # This simplified forward pass has 4 steps:
                # 1. Get prompts
                # 2. Get embeddings from inputs.
                # 3. Concat prompt and inputs
                # 4. Forward prop inputs_embeds to get the features.
                inputs_embeds = self.word_embeddings(x["input_ids"])
                if prompt.size(0) != inputs_embeds.size(0):
                    prompt = prompt.unsqueeze(0).expand(
                        inputs_embeds.size(0), -1, -1
                    )  # Expand one prompt to batch size
                inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1)
                return self.transformer.get_features({"inputs_embeds": inputs_embeds})
            else:
                patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x)
                if prompt.size(0) != x.size(0):
                    prompt = prompt.unsqueeze(0).expand(
                        x.size(0), -1, -1
                    )  # Expand one prompt to batch size
                input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1)

                encoded_features = self.transformer.get_submodule("_backbone.encoder")(
                    input_concat_prompt, return_dict=False
                )[0]
                encoded_features = self.transformer.get_submodule("_backbone.layernorm")(
                    encoded_features
                )
                return encoded_features[:, 0, :] if cls_feat else encoded_features

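For orientation, a minimal usage sketch of `PromptedTransformer` as defined above; the checkpoint name matches the constructor default, while the image size, batch size, and prompt length are illustrative assumptions rather than values taken from this PR:

```python
import torch

# Hypothetical usage; all shapes below are illustrative only.
model = PromptedTransformer(
    pretrained_model_name_or_path="google/vit-base-patch16-224",
    image_size=224,
    patch_size=16,
)
images = torch.randn(8, 3, 224, 224)  # batch of 8 RGB images
prompt = torch.randn(5, 768)          # 5 prompt tokens with hidden_dim=768

features = model(images)                                      # no prompt: backbone features
cls_features = model(images, prompt=prompt)                   # prompted: [CLS] features
full_sequence = model(images, prompt=prompt, cls_feat=False)  # prompted: full token sequence
```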
class LearningToPromptTransformer(RenateBenchmarkingModule):
    """
    Implements the vision transformer with prompt pool described in

@@ -166,34 +283,22 @@ def __init__(
        prompt_embedding_features: str = "cls",
        patch_pooler: str = "prompt_mean",
    ) -> None:
if "vit" in pretrained_model_name_or_path: | ||
transformer = VisionTransformer( | ||
pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
image_size=image_size, | ||
patch_size=patch_size, | ||
num_layers=num_layers, | ||
num_heads=num_heads, | ||
hidden_dim=hidden_dim, | ||
mlp_dim=mlp_dim, | ||
dropout=dropout, | ||
attention_dropout=attention_dropout, | ||
prediction_strategy=prediction_strategy, | ||
add_icarl_class_means=add_icarl_class_means, | ||
num_outputs=num_outputs, | ||
) | ||
self._is_text_transformer = False | ||
else: | ||
transformer = HuggingFaceSequenceClassificationTransformer( | ||
pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
prediction_strategy=prediction_strategy, | ||
add_icarl_class_means=add_icarl_class_means, | ||
num_outputs=num_outputs, | ||
) | ||
|
||
self._is_text_transformer = True | ||
transformer._tasks_params.clear() | ||
transformer = PromptedTransformer( | ||
pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
image_size=image_size, | ||
patch_size=patch_size, | ||
num_layers=num_layers, | ||
num_heads=num_heads, | ||
hidden_dim=hidden_dim, | ||
mlp_dim=mlp_dim, | ||
dropout=dropout, | ||
attention_dropout=attention_dropout, | ||
num_outputs=num_outputs, | ||
add_icarl_class_means=add_icarl_class_means, | ||
prediction_strategy=prediction_strategy, | ||
) | ||
prompter = PromptPool( | ||
embedding_dim=transformer._embedding_size, | ||
embedding_dim=transformer.transformer._embedding_size, | ||
Review comment: in the following, we access a lot of protected attributes of the transformer. Do we want to keep it that way or rather make them public?
Reply: Fixed in 3896135.
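One way to avoid reaching into protected attributes, sketched here as a hypothetical change rather than the actual content of 3896135, is to forward them through public properties on `PromptedTransformer`:

```python
# Hypothetical sketch: re-expose the wrapped transformer's protected attributes as
# public properties, so callers write transformer.embedding_size instead of
# transformer.transformer._embedding_size.
class PromptedTransformer(nn.Module):
    ...

    @property
    def embedding_size(self) -> int:
        return self.transformer._embedding_size

    @property
    def constructor_arguments(self) -> dict:
        return self.transformer._constructor_arguments
```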
            pool_size=pool_size,
            pool_selection_size=pool_selection_size,
            prompt_size=prompt_size,

@@ -204,10 +309,10 @@ def __init__(
        )

        super().__init__(
            embedding_size=transformer._embedding_size,
            embedding_size=transformer.transformer._embedding_size,
            num_outputs=num_outputs,
            constructor_arguments=dict(
                **transformer._constructor_arguments,
                **transformer.transformer._constructor_arguments,
                pool_size=pool_size,
                pool_selection_size=pool_selection_size,
                prompt_size=prompt_size,

@@ -221,6 +326,7 @@ def __init__(
        )

        self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter})
        self._is_text_transformer = transformer._is_text_transformer
        self.prompt_embedding_features = prompt_embedding_features
        self.patch_pooler = patch_pooler
        self.similarity_score: Optional[torch.Tensor] = None

@@ -236,56 +342,32 @@ def __init__(
            "prompt_mean",
        ], f"Invalid method to extract prompt embedding features. Got {patch_pooler}"

        for n, p in self._backbone["transformer"].named_parameters():
            p.requires_grad = False
        self._backbone["transformer"].eval()
        for p in self._backbone["prompter"].parameters():
            p.requires_grad = True

        if self._is_text_transformer:
            ## This is to find the Embedding layer.
            for named_param, value in self._backbone["transformer"].named_parameters():
                if value.shape[0] == self._backbone["transformer"]._backbone.config.vocab_size:
                    self.word_embeddings = self._backbone["transformer"].get_submodule(
                        named_param.replace(".weight", "")
                    )
                    break
        # The backbone's forward is monkey-patched to allow the parent class' forward to work
        # without any manual management.
        self._backbone.forward = self.forward_for_monkey_patching

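The monkey-patching above relies on a general PyTorch behavior: an instance attribute named `forward` shadows the class method that `Module.__call__` dispatches to. A self-contained sketch of the pattern with toy classes, not the ones in this PR:

```python
import torch
from torch import nn


class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


class Wrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = Toy()
        # Monkey-patch: the instance attribute shadows Toy.forward, so calling
        # self.backbone(x) now runs prompt_aware_forward instead.
        self.backbone.forward = self.prompt_aware_forward

    def prompt_aware_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 10


wrapper = Wrapper()
print(wrapper.backbone(torch.ones(2)))  # tensor([10., 10.])
```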
    def forward_for_monkey_patching(
        self, x: torch.Tensor, task_id: str = defaults.TASK_ID
    ) -> torch.Tensor:
        with torch.no_grad():
            prompt_pool_input = self._backbone["transformer"](x, cls_feat=False)
        if not self._is_text_transformer:
            # The vision transformer code is manual strapping in.
            with torch.no_grad():
                prompt_pool_input = self._backbone["transformer"].get_features(x, cls_feat=False)
            if self.prompt_embedding_features == "cls":
                # retrieve cls token features. This is used in L2P paper.
                prompt_pool_input = prompt_pool_input[:, 0, :]
            elif self.prompt_embedding_features == "mean":
                # compute mean patch features.
                prompt_pool_input = prompt_pool_input[:, 1:, :].mean(1)
            # Compute the prompts to be stacked
            prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)
            # compute patch embeddings
            patch_embeddings = self._backbone["transformer"].get_submodule("_backbone.embeddings")(
                x
            )
            # concatenate both.
            input_concat_prompt = torch.cat([patch_embeddings, prompts], dim=1)
            ## rest of processing. this code is part of the ViTModel class in HF Transformers.
            encoded_features = self._backbone["transformer"].get_submodule("_backbone.encoder")(
                input_concat_prompt, return_dict=False
            )[0]
            encoded_features = self._backbone["transformer"].get_submodule("_backbone.layernorm")(
                encoded_features
            )

            ## Save similarity
            self.similarity_score = prompt_similarity

        prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)
        self.similarity_score = prompt_similarity
        encoded_features = self._backbone["transformer"](x, prompts, cls_feat=False)
        if self._is_text_transformer:
            return encoded_features
        else:
            if self.patch_pooler == "cls":
                seq_cls_token = encoded_features[:, 0, :]
            elif self.patch_pooler == "mean":

@@ -294,19 +376,3 @@ def forward_for_monkey_patching(
                num_prompts = prompts.size(1)
                seq_cls_token = encoded_features[:, -num_prompts:, :].mean(1)
            return seq_cls_token

        else:
            ## The implicit assumption here is that x for text transformers is the input_ids.
            # This simplified forward pass has 4 steps:
            # 1. Get prompts
            # 2. Get embeddings from inputs.
            # 3. Concat prompt and inputs
            # 4. Forward prop inputs_embeds to get the features. The forward of the RenateBM applies
            # the classifier and gets logits.
            with torch.no_grad():
                prompt_pool_input = self._backbone["transformer"].get_features(x)
            prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)  # 1
            self.similarity_score = prompt_similarity
            inputs_embeds = self.word_embeddings(x["input_ids"])  # 2
            inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)  # 3
            return self._backbone["transformer"].get_features({"inputs_embeds": inputs_embeds})  # 4
Review comment: is this description still accurate? I don't see the input being returned. Maybe clarify the difference between features and the full feature; without context, it might not even be clear that this is the model output.
Reply: Fixed in 3896135.
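A possible rewording of the `PromptedTransformer` docstring along the lines the reviewer asks for, offered only as a sketch; the actual change landed in 3896135 and is not shown here:

```python
class PromptedTransformer(nn.Module):
    """Wraps a frozen vision or text transformer so it can be queried with soft prompts.

    The pretrained model name determines which backbone is created; its parameters are
    frozen. Called without a prompt, the module returns the backbone's feature tensor for
    the input. Called with a prompt, it prepends (text) or appends (vision) the prompt
    tokens to the input embeddings and returns the encoded features: either the [CLS]
    token or the full token sequence, depending on ``cls_feat`` (ignored for text).
    """
```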