Fix T5 adapter layer input #479

Merged (2 commits) on Jan 18, 2023

49 changes: 34 additions & 15 deletions src/transformers/adapters/layer.py
```diff
@@ -512,42 +512,61 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_tensor
         hidden_states = torch.cat(children_hidden, 0)
         return hidden_states
 
-    def adapter_layer_forward(self, hidden_states, input_tensor, layer_norm):
-        """
-        Called for each forward pass through adapters.
+    def adapter_layer_forward(self, hidden_states, residual_input, layer_norm):
+        """Forward pass through the adapter layer.
 
         NOTE: This method should only be called if the calling module directly inherits from AdapterLayer. Otherwise,
         call the regular forward() method.
 
+        Args:
+            hidden_states (torch.Tensor): Input hidden states to the adapter layer.
+            residual_input (torch.Tensor): Residual input to the adapter layer.
+            layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer.
+
+        Returns:
+            torch.Tensor: Output hidden states of the adapter layer.
         """
         adapter_setup = self.get_active_setup(self.adapters)
         if adapter_setup is not None:
             input_hidden_states = hidden_states
 
             if isinstance(adapter_setup, Stack):
-                hidden_states, _, input_tensor = self.adapter_stack(
-                    adapter_setup, hidden_states, input_tensor, layer_norm
+                hidden_states, _, residual_input = self.adapter_stack(
+                    adapter_setup, hidden_states, residual_input, layer_norm
                 )
             elif isinstance(adapter_setup, Fuse):
-                hidden_states = self.adapter_fusion(adapter_setup, hidden_states, input_tensor, layer_norm)
+                hidden_states = self.adapter_fusion(adapter_setup, hidden_states, residual_input, layer_norm)
             elif isinstance(adapter_setup, Split):
-                hidden_states = self.adapter_split(adapter_setup, hidden_states, input_tensor, layer_norm)
+                hidden_states = self.adapter_split(adapter_setup, hidden_states, residual_input, layer_norm)
             elif isinstance(adapter_setup, Parallel):
                 # notice that we are overriding input tensor here to keep the same dim as hidden_states for the residual
                 # in case we were blowing up the batch for parallel processing of multiple adapters for the same input
-                hidden_states, input_tensor = self.adapter_parallel(
-                    adapter_setup, hidden_states, input_tensor, layer_norm
+                hidden_states, residual_input = self.adapter_parallel(
+                    adapter_setup, hidden_states, residual_input, layer_norm
                 )
             elif isinstance(adapter_setup, BatchSplit):
-                hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, input_tensor, layer_norm)
+                hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, residual_input, layer_norm)
             else:
                 raise ValueError(f"Invalid adapter setup {adapter_setup}")
 
             last_adapter = self.adapters[adapter_setup.last()]
-            hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, input_tensor, layer_norm)
+            hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm)
 
         elif layer_norm:
-            hidden_states = layer_norm(hidden_states + input_tensor)
+            hidden_states = layer_norm(hidden_states + residual_input)
         else:
-            hidden_states = hidden_states + input_tensor
+            hidden_states = hidden_states + residual_input
 
         return hidden_states
 
-    def forward(self, hidden_states, input_tensor, layer_norm):
-        return self.adapter_layer_forward(hidden_states, input_tensor, layer_norm)
+    def forward(self, hidden_states, residual_input, layer_norm):
+        """Forward pass through the adapter layer.
+
+        Args:
+            hidden_states (torch.Tensor): Input hidden states to the adapter layer.
+            residual_input (torch.Tensor): Residual input to the adapter layer.
+            layer_norm (torch.nn.Module): Transformer layer normalization module to be used by the adapter layer.
+
+        Returns:
+            torch.Tensor: Output hidden states of the adapter layer.
+        """
+        return self.adapter_layer_forward(hidden_states, residual_input, layer_norm)
```
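For orientation, a minimal sketch (not part of the PR) of what the renamed arguments mean: `hidden_states` is the sub-layer output that the active adapters transform, while `residual_input` is only added back afterwards, through `layer_norm` when one is passed. The toy function below mirrors just the no-adapter fallback branch at the end of the hunk; every name other than `hidden_states`, `residual_input`, and `layer_norm` is made up for illustration.

```python
import torch

def adapter_fallback(hidden_states, residual_input, layer_norm=None):
    """Toy stand-in for the no-adapter branch of adapter_layer_forward above."""
    if layer_norm is not None:
        # Sub-layer output plus residual, then normalization.
        return layer_norm(hidden_states + residual_input)
    return hidden_states + residual_input

# The T5 call sites below pass layer_norm=None, so the fallback reduces to a plain residual add.
block_input = torch.randn(2, 4, 8)    # what entered the T5 sub-layer (the residual stream)
sublayer_out = torch.randn(2, 4, 8)   # e.g. dropout(DenseReluDense(layer_norm(block_input)))
out = adapter_fallback(sublayer_out, block_input, layer_norm=None)
assert out.shape == block_input.shape
```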
6 changes: 3 additions & 3 deletions src/transformers/models/t5/modeling_t5.py
```diff
@@ -340,7 +340,7 @@ def __init__(self, config: T5Config):
     def forward(self, hidden_states):
         forwarded_states = self.layer_norm(hidden_states)
         forwarded_states = self.DenseReluDense(forwarded_states)
-        hidden_states = self.adapter_layer_forward(hidden_states, self.dropout(forwarded_states), None)
+        hidden_states = self.adapter_layer_forward(self.dropout(forwarded_states), hidden_states, None)
         return hidden_states
```
A member commented on the updated call:
We could make it super obvious which one the residual connection is by using hidden_states=..., residual_connection=... in the method call. This could help people adding new models when they have a look at the currently implemented models.
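Spelled out with the parameter names that `adapter_layer_forward` actually defines (`residual_input` rather than `residual_connection`), the suggestion could look roughly like this for `T5LayerFF.forward` — a sketch only, not part of the merged change:

```python
def forward(self, hidden_states):
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.DenseReluDense(forwarded_states)
    # Keyword arguments make it obvious which tensor is the residual.
    hidden_states = self.adapter_layer_forward(
        hidden_states=self.dropout(forwarded_states),
        residual_input=hidden_states,
        layer_norm=None,
    )
    return hidden_states
```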

```diff
@@ -609,7 +609,7 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        hidden_states = self.adapter_layer_forward(hidden_states, self.dropout(attention_output[0]), None)
+        hidden_states = self.adapter_layer_forward(self.dropout(attention_output[0]), hidden_states, None)
         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
         return outputs
```

```diff
@@ -647,7 +647,7 @@ def forward(
             query_length=query_length,
             output_attentions=output_attentions,
         )
-        layer_output = self.adapter_layer_forward(hidden_states, self.dropout(attention_output[0]), None)
+        layer_output = self.adapter_layer_forward(self.dropout(attention_output[0]), hidden_states, None)
         outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
         return outputs
```
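To see that the argument order is more than cosmetic, a toy check (not from the PR): anything that transforms only its first argument before adding the second one back gives a different result when the two are swapped, which is what went wrong at the three T5 call sites above.

```python
import torch

def toy_adapter_layer(hidden_states, residual_input):
    adapted = hidden_states * 2.0    # stand-in for an adapter transforming the sub-layer output
    return adapted + residual_input  # residual added back untouched, as in layer.py

block_input = torch.ones(1, 3)           # residual stream entering the T5 block
sublayer_out = torch.full((1, 3), 5.0)   # attention / feed-forward output after dropout
correct = toy_adapter_layer(sublayer_out, block_input)  # new argument order
swapped = toy_adapter_layer(block_input, sublayer_out)  # old, buggy order
assert not torch.equal(correct, swapped)
```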
