Skip device placement for past key values in decoder models #23919

Merged 1 commit on May 31, 2023
6 changes: 5 additions & 1 deletion src/transformers/modeling_utils.py
@@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _skip_keys_device_placement = None
     _keep_in_fp32_modules = None

     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
@@ -2887,7 +2888,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
+            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **kwargs)

         if output_loading_info:
             if loading_info is None:
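For context on the `from_pretrained` hunk above: `skip_keys` is only forwarded when the installed Accelerate release exposes it on `dispatch_model`, so older versions keep working unchanged. A minimal, self-contained sketch of that signature-introspection pattern (the `library_function` below is a hypothetical stand-in, not the real Accelerate API):

import inspect


def library_function(data, *, offload_dir=None, skip_keys=None):
    """Hypothetical stand-in for a third-party function whose signature grows over time."""
    return {"offload_dir": offload_dir, "skip_keys": skip_keys}


def call_with_optional_kwargs(data, skip_keys=None):
    # Always pass the arguments that every supported version understands.
    kwargs = {"offload_dir": "/tmp/offload"}
    # Forward the newer argument only if the installed version accepts it,
    # mirroring the `skip_keys` check added to `from_pretrained` above.
    if "skip_keys" in inspect.signature(library_function).parameters:
        kwargs["skip_keys"] = skip_keys
    return library_function(data, **kwargs)


print(call_with_optional_kwargs([1, 2, 3], skip_keys="past_key_values"))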
1 change: 1 addition & 0 deletions src/transformers/models/bart/modeling_bart.py
@@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
1 change: 1 addition & 0 deletions src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         std = self.config.init_std
1 change: 1 addition & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.lm_head.weight",
     ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keep_in_fp32_modules = ["wo"]

     def _init_weights(self, module):
1 change: 1 addition & 0 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
1 change: 1 addition & 0 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPT2Block"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTBigCodeBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox_japanese"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXJapaneseLayer"]
+    _skip_keys_device_placement = "past_key_values"

     def _init_weights(self, module):
         """Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/gptj/modeling_gptj.py
@@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTJBlock"]
+    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gptsan_japanese"
     supports_gradient_checkpointing = False
     _no_split_modules = ["GPTSanJapaneseBlock"]
+    _skip_keys_device_placement = "past_key_values"

     @property
     def dummy_inputs(self):
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

     def _init_weights(self, module):
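Taken together, these changes mean that when a decoder model is loaded with a `device_map`, Accelerate's dispatch hooks can leave the cached `past_key_values` where they were produced instead of moving them on every generation step, provided the installed Accelerate version accepts `skip_keys`. A rough usage sketch, assuming the `gpt2` checkpoint and a recent Accelerate install (both illustrative choices, not requirements of this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Any model whose *PreTrainedModel class sets
# _skip_keys_device_placement = "past_key_values" is covered by this PR.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

inputs = tokenizer("Device placement for the cache is", return_tensors="pt").to(model.device)
# generate() reuses past_key_values across decoding steps; with skip_keys set,
# the per-module hooks no longer try to move that cache between devices.
output_ids = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))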