From 4438921f054650865e92fdcae5b42262471e1707 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 31 May 2023 15:32:21 -0400
Subject: [PATCH] Skip device placement for past key values in decoder models
 (#23919)

---
 src/transformers/modeling_utils.py                           | 6 +++++-
 src/transformers/models/bart/modeling_bart.py                | 1 +
 .../models/bigbird_pegasus/modeling_bigbird_pegasus.py       | 1 +
 src/transformers/models/blip_2/modeling_blip_2.py            | 1 +
 src/transformers/models/bloom/modeling_bloom.py              | 1 +
 src/transformers/models/bridgetower/modeling_bridgetower.py  | 1 +
 src/transformers/models/codegen/modeling_codegen.py          | 1 +
 src/transformers/models/gpt2/modeling_gpt2.py                | 1 +
 src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py  | 1 +
 src/transformers/models/gpt_neo/modeling_gpt_neo.py          | 1 +
 src/transformers/models/gpt_neox/modeling_gpt_neox.py        | 1 +
 .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py   | 1 +
 src/transformers/models/gptj/modeling_gptj.py                | 1 +
 .../models/gptsan_japanese/modeling_gptsan_japanese.py       | 1 +
 src/transformers/models/llama/modeling_llama.py              | 1 +
 15 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index aa982873664e04..1c4ce11b52366e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _skip_keys_device_placement = None
     _keep_in_fp32_modules = None
 
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
@@ -2887,7 +2888,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
+            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **kwargs)
 
         if output_loading_info:
             if loading_info is None:
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index c1da4eb288e838..59d99f9aa1339e 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         std = self.config.init_std
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index e4c64e12b55461..61972c53f01730 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         std = self.config.init_std
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 381a3fe0ed7ada..2cb77f44a28bda 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
         r"language_model.lm_head.weight",
     ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
     _keep_in_fp32_modules = ["wo"]
 
     def _init_weights(self, module):
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 5c0d570cbe9c21..a3d2242f8a6acb 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index 29bd5b581d9241..eb95f12c854e31 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bridgetower"
     supports_gradient_checkpointing = False
     _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         if isinstance(module, BridgeTowerVisionModel):
diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index 82a4f95dde1f69..aff556d2f14462 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 5fc5451141aabf..00d92f0bb23c2b 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPT2Block"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index f6ec24d7773c73..0bbe7648237902 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTBigCodeBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 02f7d5534bde60..f98a73373b4a2d 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 488093594637b2..5ae2807608a934 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index d18554480bd7a8..aeab9434e51576 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gpt_neox_japanese"
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTNeoXJapaneseLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 82cb280caba9c4..91b32d1a8242bf 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["GPTJBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
index 4343340a7f7392..a0b68543b3da95 100644
--- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
+++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
@@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
     base_model_prefix = "gptsan_japanese"
     supports_gradient_checkpointing = False
     _no_split_modules = ["GPTSanJapaneseBlock"]
+    _skip_keys_device_placement = "past_key_values"
 
     @property
     def dummy_inputs(self):
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 80cfdfa5f06645..346da82d86f5b3 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
 
     def _init_weights(self, module):