
Commit

Generate: fix end to end compilation (#32465)
gante authored and nbroad1881 committed Aug 7, 2024
1 parent: 5c9420d · commit: 9ef3b04
Showing 2 changed files with 24 additions and 20 deletions.
src/transformers/cache_utils.py: 27 changes (15 additions, 12 deletions)
@@ -1024,19 +1024,22 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            # Note: `torch.export()`` requires mutations to be registered as buffers.
-            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            key_cache = getattr(self, f"key_cache_{idx}")
-            value_cache = getattr(self, f"value_cache_{idx}")
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            # it is not needed anyway)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            # Notes:
+            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
+            #    it is not needed anyway)
+            # 2. `torch.export()` requires mutations to be registered as buffers.
             if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(key_cache)
-                torch._dynamo.mark_static_address(value_cache)
-            self.key_cache.append(key_cache)
-            self.value_cache.append(value_cache)
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
+                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
 
     def update(
         self,
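For context, the rewritten loop registers each layer's cache as a named buffer and tags it with `torch._dynamo.mark_static_address` only when the code is not being traced by dynamo; under compilation, plain tensors are appended instead. Below is a minimal, self-contained sketch of that branch using a toy module; the shapes, layer count, and the `compiling` flag standing in for `is_torchdynamo_compiling()` are illustrative, not taken from the PR.

import torch
import torch.nn as nn

# Toy stand-in for StaticCache, illustrating the buffer-vs-plain-tensor branch above.
class ToyStaticCache(nn.Module):
    def __init__(self, num_layers=2, cache_shape=(1, 4, 16, 8), dtype=torch.float32, compiling=False):
        super().__init__()
        self.key_cache, self.value_cache = [], []
        for idx in range(num_layers):
            new_key = torch.zeros(cache_shape, dtype=dtype)
            new_value = torch.zeros(cache_shape, dtype=dtype)
            if not compiling:  # stand-in for is_torchdynamo_compiling()
                # Registering as buffers keeps the in-place cache mutations visible to torch.export(),
                # and mark_static_address pins the data pointer so cuda graphs can be reused.
                self.register_buffer(f"key_cache_{idx}", new_key)
                self.register_buffer(f"value_cache_{idx}", new_value)
                new_key = getattr(self, f"key_cache_{idx}")
                new_value = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_key)
                torch._dynamo.mark_static_address(new_value)
            self.key_cache.append(new_key)
            self.value_cache.append(new_value)

cache = ToyStaticCache()
print([name for name, _ in cache.named_buffers()])  # ['key_cache_0', 'value_cache_0', 'key_cache_1', 'value_cache_1']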
src/transformers/generation/utils.py: 17 changes (9 additions, 8 deletions)
@@ -1429,7 +1429,9 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         model_kwargs["cache_position"] = cache_position
         return model_kwargs
 
-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
+    def _get_cache(
+        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
+    ) -> Cache:
         """
         Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
         new `generate` call requires a larger cache or uses a different batch size.
@@ -1477,7 +1479,7 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
             "config": self.config,
             "max_batch_size": max_batch_size,
             "max_cache_len": max_cache_len,
-            "device": self.device,
+            "device": device,
             "dtype": cache_dtype,
         }
         self._cache = cache_cls(**cache_kwargs)
@@ -1813,12 +1815,11 @@ def generate(
                         "issue: https://github.com/huggingface/transformers/issues/28981"
                     )
                 model_kwargs[cache_name] = self._get_cache(
-                    generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1)
-                    * getattr(generation_config, "num_return_sequences", 1)
-                    * batch_size,
-                    generation_config.max_length,
-                    model_kwargs,
+                    cache_implementation=generation_config.cache_implementation,
+                    max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+                    max_cache_len=generation_config.max_length,
+                    device=device,
+                    model_kwargs=model_kwargs,
                 )
             elif generation_config.cache_implementation == "quantized":
                 if not self._supports_quantized_cache:
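Taken together, the two changes let `generate` allocate the static cache on the device the generation loop actually uses and keep the cache setup compatible with dynamo tracing. A usage sketch follows; the checkpoint name, device, and generation settings are placeholders, and the `torch.compile` call shows the commonly documented pattern of compiling the model's forward pass with a static cache rather than anything prescribed by this PR.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Compile the forward pass; the static cache keeps shapes fixed so cuda graphs can be reused.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
# cache_implementation="static" routes through _get_cache(); after this commit the cache
# tensors are created on the `device` passed in by generate() rather than self.device.
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))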
