From 9ef3b0499263344d466ad3af29f4d68d622b6400 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 6 Aug 2024 15:06:47 +0100
Subject: [PATCH] Generate: fix end to end compilation (#32465)

---
 src/transformers/cache_utils.py      | 27 +++++++++++++++------------
 src/transformers/generation/utils.py | 17 +++++++++--------
 2 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 141964c1add6e3..1ddc3516ba2167 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1024,19 +1024,22 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            # Note: `torch.export()`` requires mutations to be registered as buffers.
-            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            key_cache = getattr(self, f"key_cache_{idx}")
-            value_cache = getattr(self, f"value_cache_{idx}")
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            # it is not needed anyway)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            # Notes:
+            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
+            #    it is not needed anyway)
+            # 2. `torch.export()` requires mutations to be registered as buffers.
             if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(key_cache)
-                torch._dynamo.mark_static_address(value_cache)
-            self.key_cache.append(key_cache)
-            self.value_cache.append(value_cache)
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
+                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
 
     def update(
         self,
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ccaa1d80e3f8e1..3e3c7803a61189 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1429,7 +1429,9 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         model_kwargs["cache_position"] = cache_position
         return model_kwargs
 
-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
+    def _get_cache(
+        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
+    ) -> Cache:
         """
         Sets a cache for `generate`, that will persist across calls.
         A new cache will only be initialized a new `generate` call requires a larger cache or uses a different batch size.
@@ -1477,7 +1479,7 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
                 "config": self.config,
                 "max_batch_size": max_batch_size,
                 "max_cache_len": max_cache_len,
-                "device": self.device,
+                "device": device,
                 "dtype": cache_dtype,
             }
             self._cache = cache_cls(**cache_kwargs)
@@ -1813,12 +1815,11 @@ def generate(
                     "issue: https://github.com/huggingface/transformers/issues/28981"
                 )
                 model_kwargs[cache_name] = self._get_cache(
-                    generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1)
-                    * getattr(generation_config, "num_return_sequences", 1)
-                    * batch_size,
-                    generation_config.max_length,
-                    model_kwargs,
+                    cache_implementation=generation_config.cache_implementation,
+                    max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+                    max_cache_len=generation_config.max_length,
+                    device=device,
+                    model_kwargs=model_kwargs,
                 )
             elif generation_config.cache_implementation == "quantized":
                 if not self._supports_quantized_cache:
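For context, here is a minimal usage sketch (not part of the commit) of the workflow this patch targets: compiling `generate` end to end with a static KV cache. The checkpoint name, dtype, and generation settings are illustrative assumptions, and whether the whole call compiles without graph breaks depends on the model and on the transformers/PyTorch versions in use.

# Illustrative sketch only -- not taken from the patch. Assumes a CUDA machine,
# PyTorch 2.x, and a model architecture that supports `StaticCache`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical choice of checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# "End to end compilation" here means compiling the whole `generate` call rather than
# only the forward pass; the patch makes the static cache setup trace-friendly
# (device passed explicitly, no buffer registration while dynamo is compiling).
compiled_generate = torch.compile(model.generate)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
output_ids = compiled_generate(
    **inputs,
    cache_implementation="static",  # selects the StaticCache path touched by this patch
    max_new_tokens=16,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))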