From 73aa1267cd35c84c19bb2eb53c1c551a68d4fff0 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 21 Oct 2024 00:16:57 +0100 Subject: [PATCH] Toggling KV-caches (#1763) --- docs/source/api_ref_modules.rst | 3 + recipes/eleuther_eval.py | 108 +++++---- tests/recipes/test_eleuther_eval.py | 9 - .../modules/model_fusion/test_fusion_layer.py | 20 +- .../model_fusion/test_fusion_models.py | 6 +- tests/torchtune/modules/test_attention.py | 3 + tests/torchtune/modules/test_common_utils.py | 193 ++++++++++++++++ torchtune/models/clip/_position_embeddings.py | 2 +- torchtune/models/gemma/transformer.py | 4 +- torchtune/modules/__init__.py | 10 +- torchtune/modules/attention.py | 10 +- torchtune/modules/common_utils.py | 207 +++++++++++++++++- torchtune/modules/model_fusion/_fusion.py | 42 +++- torchtune/modules/transformer.py | 61 ++++-- 14 files changed, 569 insertions(+), 109 deletions(-) create mode 100644 tests/torchtune/modules/test_common_utils.py diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index 70e3870c99..cc9a493147 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -103,6 +103,9 @@ These are utilities that are common to and can be used by all modules. :nosignatures: common_utils.reparametrize_as_dtype_state_dict_post_hook + common_utils.local_kv_cache + common_utils.disable_kv_cache + common_utils.delete_kv_caches Vision Transforms diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py index 65e0c4eba8..590e4f902a 100644 --- a/recipes/eleuther_eval.py +++ b/recipes/eleuther_eval.py @@ -13,7 +13,7 @@ import torch -from lm_eval.evaluator import evaluate, get_task_list +from lm_eval.evaluator import evaluate from lm_eval.models.hf_vlms import HFMultimodalLM from lm_eval.models.huggingface import HFLM from lm_eval.tasks import get_task_dict, TaskManager @@ -29,6 +29,7 @@ ) from torchtune.generation import generate, sample from torchtune.modules import TransformerDecoder +from torchtune.modules.common_utils import local_kv_cache from torchtune.modules.model_fusion import DeepFusionModel from torchtune.modules.tokenizers import ModelTokenizer from torchtune.modules.transforms import Transform @@ -224,18 +225,11 @@ def _model_multimodal_generate( "multimodal generation." ) - # 2. Setup KV cache and masks for bsz 1 + encoder_max_seq_len = ( + self.model_transform.image_seq_len * self._max_images_per_sample + ) + # Setup masks for bsz 1 with self.device: - if self.model.caches_are_enabled(): - self.model.reset_caches() - else: - self.model.setup_caches( - batch_size=1, - dtype=self._dtype, - encoder_max_seq_len=self.model_transform.image_seq_len - * self._max_images_per_sample, - decoder_max_seq_len=self.max_length, - ) causal_mask = torch.tril( torch.ones( size=(self.max_length, self.max_length), @@ -247,28 +241,37 @@ def _model_multimodal_generate( batch["input_pos"] = input_pos[None, :seq_len] batch["mask"] = causal_mask[None, :seq_len] - # 3. Prefill step - generated_tokens = [] - logits = self.model(prompt, **batch)[:, -1] - token = sample(logits, temperature=0.0, top_k=None) - generated_tokens.append(token.item()) - - cache_mask = batch["encoder_mask"][:, -1:] - - # 4. Continue generating - for _ in range(max_length): - if token.item() in self.model_transform.stop_tokens: - break - logits = self.model( - token, - mask=causal_mask[None, seq_len, None, :], - encoder_input=None, - encoder_mask=cache_mask, - input_pos=input_pos[None, seq_len], - )[:, -1] + # 2. 
Setup KV cache + with local_kv_cache( + self.model, + batch_size=self.batch_size, + device=self.device, + dtype=self._dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=self.max_length, + ): + # 3. Prefill step + generated_tokens = [] + logits = self.model(prompt, **batch)[:, -1] token = sample(logits, temperature=0.0, top_k=None) generated_tokens.append(token.item()) - seq_len += 1 + + cache_mask = batch["encoder_mask"][:, -1:] + + # 4. Continue generating + for _ in range(max_length): + if token.item() in self.model_transform.stop_tokens: + break + logits = self.model( + token, + mask=causal_mask[None, seq_len, None, :], + encoder_input=None, + encoder_mask=cache_mask, + input_pos=input_pos[None, seq_len], + )[:, -1] + token = sample(logits, temperature=0.0, top_k=None) + generated_tokens.append(token.item()) + seq_len += 1 # 5. Return generated tokens return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0) @@ -388,18 +391,6 @@ def _model_generate( "Any decoding strategy other than greedy is not supported." ) - # Setup KV caches OR reset them if they're already set up - if self.enable_kv_cache: - if self.model.caches_are_enabled(): - self.model.reset_caches() - else: - with self.device: - self.model.setup_caches( - batch_size=self.batch_size, - dtype=self._dtype, - decoder_max_seq_len=self.max_length, - ) - # if we've recieved fewer than self._batch_size samples in the current # batch we need to pad the batch out. here we're padding the end of the # current batch to the correct length. this is because when we use static @@ -409,15 +400,21 @@ def _model_generate( (0, 0, 0, self._batch_size - bsz), value=self._tokenizer.eos_id, # pad with one of the tokenizer's stop tokens so generation can stop early ) - - toks, _ = generate( + with local_kv_cache( self.model, - maybe_padded_context, - max_generated_tokens=self.max_gen_toks, - temperature=temperature, - top_k=None, - stop_tokens=self._tokenizer.stop_tokens, - ) + batch_size=self.batch_size, + device=self.device, + dtype=self._dtype, + decoder_max_seq_len=self.max_length, + ): + toks, _ = generate( + self.model, + maybe_padded_context, + max_generated_tokens=self.max_gen_toks, + temperature=temperature, + top_k=None, + stop_tokens=self._tokenizer.stop_tokens, + ) return toks[:bsz] @@ -536,13 +533,6 @@ def evaluate(self) -> None: # Initialize tasks for the harness task_manager = TaskManager(include_path=self.include_path) task_dict = get_task_dict(self.tasks, task_manager) - task_types = set([t.task.OUTPUT_TYPE for t in get_task_list(task_dict)]) - if len(task_types) > 1 and "generate_until" in task_types: - raise RuntimeError( - "Evaluating on multiple task types where any one task involves " - "generation is currently not supported. 
See the issue below for more info: " - "https://github.com/pytorch/torchtune/issues/1621" - ) # Run evaluation t0 = time.time() diff --git a/tests/recipes/test_eleuther_eval.py b/tests/recipes/test_eleuther_eval.py index 29f8e9f123..1c3a7bb65f 100644 --- a/tests/recipes/test_eleuther_eval.py +++ b/tests/recipes/test_eleuther_eval.py @@ -194,12 +194,3 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir match="QAT quantizers should only be used during quantization aware training", ): runpy.run_path(TUNE_PATH, run_name="__main__") - - @pytest.mark.integration_test - def test_eval_recipe_errors_with_generate_until_and_mc_tasks( - self, caplog, capsys, monkeypatch, tmpdir - ): - # We can't currently specify both generate_until and mc_tasks in the same run - # b/c the KV cache won't be reset and the result will be different. This test - # catches that error - pass diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layer.py index 94ca29085e..a2fc0715eb 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layer.py @@ -25,10 +25,13 @@ def __init__(self, dim): self.cache_enabled = False self.encoder_max_seq_len = None - def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): + def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): self.cache_enabled = True self.encoder_max_seq_len = encoder_max_seq_len + def caches_are_enabled(self): + return self.cache_enabled + def reset_cache(self): self.cache_enabled = False @@ -43,10 +46,13 @@ def __init__(self, dim): self.cache_enabled = False self.decoder_max_seq_len = None - def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): + def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len): self.cache_enabled = True self.decoder_max_seq_len = decoder_max_seq_len + def caches_are_enabled(self): + return self.cache_enabled + def reset_cache(self): self.cache_enabled = False @@ -131,22 +137,20 @@ def test_fusion_params(self, fused_layer): "fusion_layer.linear.bias", } - def test_setup_cache(self, fused_layer): + def test_setup_caches(self, fused_layer): """ Test that the cache methods works as expected. """ - fused_layer.setup_cache( + fused_layer.setup_caches( 2, torch.float32, encoder_max_seq_len=10, decoder_max_seq_len=10 ) - assert fused_layer.cache_enabled - fused_layer.reset_cache() - assert not fused_layer.cache_enabled + assert fused_layer.caches_are_enabled() def test_setup_cache_different_cache_seq_len(self, fused_layer): """ Test that the cache methods works as expected. """ - fused_layer.setup_cache( + fused_layer.setup_caches( 2, torch.float32, encoder_max_seq_len=5, decoder_max_seq_len=10 ) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 01cac982c3..322616276e 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -32,7 +32,7 @@ def __init__(self, dim, vocab_size): def setup_caches(self, batch_size, dtype, *args, **kwargs): self.cache_enabled = True - def caches_are_enabled(self): + def caches_are_setup(self): return self.cache_enabled def reset_caches(self): @@ -144,9 +144,9 @@ def test_setup_cache(self, fused_model): Test that the cache methods works as expected. 
""" fused_model.setup_caches(2, torch.float32) - assert fused_model.caches_are_enabled() + assert fused_model.caches_are_setup() fused_model.reset_caches() - assert not fused_model.caches_are_enabled() + assert not fused_model.caches_are_setup() def test_set_trainable_params(self, fused_model, encoder, decoder): """ diff --git a/tests/torchtune/modules/test_attention.py b/tests/torchtune/modules/test_attention.py index 4fdef88bd7..872f6684de 100644 --- a/tests/torchtune/modules/test_attention.py +++ b/tests/torchtune/modules/test_attention.py @@ -141,6 +141,7 @@ def gqa_kv_cache( kv_cache=kv_cache, max_seq_len=max_seq_len, ) + attn.cache_enabled = True fixed_init_model(attn) attn.eval() return attn @@ -195,6 +196,7 @@ def mha_kv_cache( kv_cache=kv_cache, max_seq_len=max_seq_len, ) + attn.cache_enabled = True fixed_init_model(attn) attn.eval() return attn @@ -249,6 +251,7 @@ def mqa_kv_cache( kv_cache=kv_cache, max_seq_len=max_seq_len, ) + attn.cache_enabled = True fixed_init_model(attn) attn.eval() return attn diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py new file mode 100644 index 0000000000..41dc472f00 --- /dev/null +++ b/tests/torchtune/modules/test_common_utils.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch +from tests.test_utils import fixed_init_model +from torchtune.models.llama3_2._component_builders import llama3_2 +from torchtune.models.llama3_2_vision._component_builders import ( + llama3_2_vision_decoder, + llama3_2_vision_encoder, +) +from torchtune.modules import delete_kv_caches, disable_kv_cache, local_kv_cache +from torchtune.modules.model_fusion import DeepFusionModel + + +@pytest.fixture +def llama_vision_model(): + vision_encoder = llama3_2_vision_encoder( + clip_embed_dim=32, + clip_num_layers=4, + num_heads=4, + tile_size=49, + patch_size=9, + max_num_tiles=4, + in_channels=3, + clip_hidden_states=[0, 1], + num_layers_projection=2, + decoder_embed_dim=128, + ).eval() + vision_decoder = llama3_2_vision_decoder( + vocab_size=256, + num_layers=4, + fusion_interval=2, + num_special_tokens=2, + num_heads=8, + num_kv_heads=4, + embed_dim=128, + max_seq_len=4096, + encoder_max_seq_len=4096, + ).eval() + fixed_init_model(vision_encoder, min_val=-1, max_val=1) + fixed_init_model(vision_decoder, min_val=-1, max_val=1) + model = DeepFusionModel( + encoder=vision_encoder, + decoder=vision_decoder, + encoder_trainable=False, + decoder_trainable=False, + fusion_trainable=False, + ) + return model + + +@pytest.fixture +def llama_decoder_model(): + model = llama3_2( + vocab_size=256, + num_layers=2, + num_heads=8, + num_kv_heads=4, + embed_dim=256, + max_seq_len=4096, + ) + fixed_init_model(model, min_val=-1, max_val=1) + model.eval() + return model + + +@pytest.fixture +def device(): + return torch.device("cpu") + + +@pytest.fixture +def inputs(): + return torch.randint(low=0, high=256, size=(4, 2048)) + + +@pytest.fixture +def causal_mask(): + return torch.tril(torch.ones((2048, 4096))).unsqueeze(0).repeat(4, 1, 1) + + +@pytest.fixture +def input_pos(): + return torch.arange(0, 2048).unsqueeze(0).repeat(4, 1) + + +class TestLocalKVCache: + @pytest.mark.parametrize("model", ["llama_decoder_model", "llama_vision_model"]) + def test_local_kv_cache( + self, device, inputs, causal_mask, input_pos, model, request 
+ ): + model = request.getfixturevalue(model) + outs = model(inputs) + + with local_kv_cache(model, batch_size=4, device=device, dtype=torch.float32): + outs_cached = model(inputs, mask=causal_mask, input_pos=input_pos) + assert model.caches_are_setup() + assert model.caches_are_enabled() + + for module in model.modules(): + if hasattr(module, "kv_cache"): + assert module.kv_cache is None + + assert not model.caches_are_setup() + assert not model.caches_are_enabled() + + torch.testing.assert_close( + outs_cached.mean(), outs.mean(), atol=1e-4, rtol=1e-6 + ) + + @pytest.mark.parametrize("model", ["llama_decoder_model", "llama_vision_model"]) + def test_local_kv_cache_raises_error_caches_setup(self, device, model, request): + + model = request.getfixturevalue(model) + model.setup_caches(batch_size=4, dtype=torch.float32) + with pytest.raises(ValueError, match="Model caches must be not setup"): + with local_kv_cache( + model, batch_size=4, device=device, dtype=torch.float32 + ): + pass + + +class TestDeleteKVCaches: + @pytest.mark.parametrize("model", ["llama_decoder_model", "llama_vision_model"]) + def test_delete_kv_cache(self, model, request): + model = request.getfixturevalue(model) + model.setup_caches(batch_size=4, dtype=torch.float32) + + delete_kv_caches(model) + + assert not model.caches_are_setup() + assert not model.caches_are_enabled() + + for module in model.modules(): + if hasattr(module, "kv_cache"): + assert module.kv_cache is None + assert not module.cache_enabled + + @pytest.mark.parametrize("model", ["llama_decoder_model", "llama_vision_model"]) + def test_delete_kv_cache_raises_error_without_caches_setup(self, model, request): + model = request.getfixturevalue(model) + with pytest.raises(ValueError, match="You have tried to delete model caches"): + delete_kv_caches(model) + + +class TestDisableKVCaches: + @pytest.mark.parametrize("model", ["llama_decoder_model", "llama_vision_model"]) + def test_disable_kv_cache(self, inputs, causal_mask, input_pos, model, request): + + # firstly, setup kv-caches and update the cache state + model = request.getfixturevalue(model) + model.setup_caches(batch_size=4, dtype=torch.float32) + model(inputs, mask=causal_mask, input_pos=input_pos) + + # let's grab this initial cache state for later + expected_kv_cache_states = [] + for module in model.modules(): + if hasattr(module, "kv_cache") and callable(module.kv_cache): + expected_kv_cache_states.append(module.kv_cache.k_cache.clone()) + + with disable_kv_cache(model): + assert model.caches_are_setup() + assert not model.caches_are_enabled() + + # these model forward passes should *not* be updating the cache + model(inputs) + model(inputs) + + # grab the cache states after exiting the context manager + kv_cache_states = [] + for module in model.modules(): + if hasattr(module, "kv_cache") and callable(module.kv_cache): + assert module.cache_enabled + kv_cache_states.append(module.kv_cache.k_cache.clone()) + + # should be the same! 
+ for expected, output in zip(expected_kv_cache_states, kv_cache_states): + assert torch.equal(expected, output) + + assert model.caches_are_setup() + assert model.caches_are_enabled() + + @pytest.mark.parametrize("model", ["llama_decoder_model", "llama_vision_model"]) + def test_disable_kv_cache_raises_error_caches_not_setup(self, model, request): + model = request.getfixturevalue(model) + with pytest.raises(ValueError, match="Model caches must be setup"): + with disable_kv_cache(model): + pass diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py index 8bc7797757..cd1ea5947c 100644 --- a/torchtune/models/clip/_position_embeddings.py +++ b/torchtune/models/clip/_position_embeddings.py @@ -570,7 +570,7 @@ def _load_state_dict_hook( if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb: raise ValueError( "Expected embedding shape to be (..., num_tokens, tgt_emb) to match" - f" but found shapes {self.embedding.shape} and {state_dict[prefix+'embedding'].shape}" + f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}" ) if inpt_max_num_tiles_x != inpt_max_num_tiles_y: diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index e4cb212e8c..e6310e198e 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -70,7 +70,7 @@ def __init__( self.norm_embeddings = norm_embeddings self.num_output_chunks = 0 - def caches_are_enabled(self) -> bool: + def caches_are_setup(self) -> bool: """Check if the key value caches are setup.""" return self.layers[0].cache_enabled @@ -104,7 +104,7 @@ def setup_caches( if decoder_max_seq_len is not None: self.decoder_max_seq_len = decoder_max_seq_len for layer in self.layers: - layer.setup_cache( + layer.setup_caches( batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len, diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 29540d2b15..32af70f8e5 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -6,7 +6,12 @@ from .attention import MultiHeadAttention # noqa from .attention_utils import create_block_causal_mask, packed_block_causal_mask -from .common_utils import reparametrize_as_dtype_state_dict_post_hook +from .common_utils import ( + delete_kv_caches, + disable_kv_cache, + local_kv_cache, + reparametrize_as_dtype_state_dict_post_hook, +) from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa from .layer_norm import Fp32LayerNorm # noqa @@ -42,5 +47,8 @@ "reparametrize_as_dtype_state_dict_post_hook", "create_block_causal_mask", "packed_block_causal_mask", + "local_kv_cache", + "delete_kv_caches", + "disable_kv_cache", "get_cosine_schedule_with_warmup", ] diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index ec04a752e4..879f0679cf 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -70,7 +70,7 @@ class MultiHeadAttention(nn.Module): This is needed to compute the RoPE Cache. Default: 4096. is_causal (bool): sets the default mask to causal when no mask is provided attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function. - This argument is ignored if self.training is False. Default value is 0.0. + Default value is 0.0. 
Raises: ValueError: If ``num_heads % num_kv_heads != 0`` @@ -139,6 +139,11 @@ def __init__( # Use flex attention if supported and we are sample packing self._attention_call = _sdpa_or_flex_attention() + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + def setup_cache( self, batch_size: int, dtype: torch.dtype, max_seq_len: int ) -> None: @@ -163,6 +168,7 @@ def setup_cache( head_dim=self.head_dim, dtype=dtype, ) + self.cache_enabled = True def reset_cache(self): """Reset the key value caches.""" @@ -290,7 +296,7 @@ def forward( k = self.k_norm(k) # Update key-value cache - if self.kv_cache is not None: + if self.kv_cache is not None and self.cache_enabled: k, v = self.kv_cache.update(k, v) output = self._attention_call( diff --git a/torchtune/modules/common_utils.py b/torchtune/modules/common_utils.py index 60cc3222c8..ead3c7ad1e 100644 --- a/torchtune/modules/common_utils.py +++ b/torchtune/modules/common_utils.py @@ -4,11 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import mmap import sys from collections import OrderedDict from functools import partial -from typing import Any, Dict, Tuple +from typing import Any, Dict, Generator, Optional, Tuple +from warnings import warn import torch @@ -163,3 +165,206 @@ def _register_reparametrize_state_dict_hooks( module._register_state_dict_hook( partial(hook, dtype=dtype, offload_to_cpu=offload_to_cpu) ) + + +@contextlib.contextmanager +def disable_kv_cache(model: nn.Module) -> Generator[None, None, None]: + """ + This context manager temporarily disables KV-cacheing on a given model, which must already + already have KV-caches setup. All forward passes using the model within this context manager + will not use KV-caches. + + KV-caches will be disabled when entering the context manager, and will be enabled upon exit, + without being modified. + + This is useful in cases where we might wish to alternate between model calls which use KV-cacheing, + and model calls which do not use KV-cacheing, without the additional overhead of deleting and setting caches up + every time. + + Example: + >>> from torchtune.models.llama3_2 import llama3_2_1b + >>> from torchtune.modules import disable_kv_cache + >>> import torch + >>> model = llama3_2_1b() + >>> # setup caches + >>> model.setup_caches(batch_size=1, + >>> dtype=torch.float32, + >>> decoder_max_seq_len=1024) + >>> print(model.caches_are_setup()) + True + >>> print(model.caches_are_enabled()) + True + >>> print(model.layers[0].attn.kv_cache) + KVCache() + >>> # now temporarily disable caches + >>> with disable_kv_cache(model): + >>> print(model.caches_are_setup()) + >>> True + >>> print(model.caches_are_enabled()) + >>> False + >>> print(model.layers[0].attn.kv_cache) + >>> # KVCache() + >>> # caches are now re-enabled, and their state is untouched + >>> print(model.caches_are_setup()) + True + >>> print(model.caches_are_enabled()) + True + >>> print(model.layers[0].attn.kv_cache) + >>> KVCache() + + Args: + model (nn.Module): model to disable KV-cacheing for. + + Yields: + None: Returns control to the caller with KV-caches disabled on the given model. + + Raises: + ValueError: If the model does not have caches setup. + """ + if not model.caches_are_setup(): + raise ValueError( + "Model caches must be setup before calling disable_kv_cache! 
" + "Please use model.setup_caches() to setup model caches." + ) + if not model.caches_are_enabled(): + warn( + "You are using disable_kv_cache with a model that does not " + "have caches enabled. This is a no-op and the expected behaviour " + "may not occur." + ) + for module in model.modules(): + if hasattr(module, "kv_cache") and callable(module.kv_cache): + module.cache_enabled = False + try: + yield + finally: + for module in model.modules(): + if hasattr(module, "kv_cache") and callable(module.kv_cache): + module.cache_enabled = True + + +@contextlib.contextmanager +def local_kv_cache( + model: nn.Module, + *, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, +) -> Generator[None, None, None]: + """ + This context manager temporarily enables KV-cacheing on a given model, which does not + already have KV-caches setup. All forward passes using the model within this context manager + will use KV-caches. + + KV-caches will be set-up with the given ``batch_size``, ``dtype``, and ``max_seq_len`` when + entering the context manager, and will be deleted on exit. + + Example: + >>> from torchtune.models.llama3_2 import llama3_2_1b + >>> from torchtune.modules import local_kv_cache + >>> import torch + >>> model = llama3_2_1b() + >>> print(model.caches_are_setup()) + False + >>> print(model.caches_are_enabled()) + False + >>> print(model.layers[0].attn.kv_cache) + None + >>> # entering cacheing mode + >>> with local_kv_cache(model, + >>> batch_size=1, + >>> device=torch.device("cpu"), + >>> dtype=torch.float32, + >>> decoder_max_seq_len=1024): + >>> print(model.caches_are_setup()) + True + >>> print(model.caches_are_enabled()) + True + >>> print(model.layers[0].attn.kv_cache) + KVCache() + >>> # exited cacheing mode + >>> print(model.caches_are_setup()) + False + >>> print(model.caches_are_enabled()) + False + >>> print(model.layers[0].attn.kv_cache) + None + + Args: + model (nn.Module): model to enable KV-cacheing for. + batch_size (int): batch size for the caches. + device (torch.device): device to setup caches on. this should be the same device + the model is on. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. + + Yields: + None: Returns control to the caller with KV-caches setup and enabled on the given model. + + Raises: + ValueError: If the model already has caches setup. + """ + if model.caches_are_setup(): + raise ValueError( + "Model caches must be not setup prior to entering this context manager! " + "Please use delete_kv_caches(model) to delete model caches." + ) + # ensure caches are setup on the same device as the model + with device: + model.setup_caches( + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, + ) + try: + yield + finally: + delete_kv_caches(model) + + +def delete_kv_caches(model: nn.Module): + """ + Deletes KV caches from all attention layers in a model, + and also ensures ``cache_enabled`` is set to False. 
+ + Example: + >>> from torchtune.models.llama3_2 import llama3_2_1b + >>> from torchtune.modules import delete_kv_caches + >>> import torch + >>> model = llama3_2_1b() + >>> model.setup_caches(batch_size=1, + >>> dtype=torch.float32, + >>> decoder_max_seq_len=1024) + >>> print(model.caches_are_setup()) + >>> True + >>> print(model.caches_are_enabled()) + >>> True + >>> print(model.layers[0].attn.kv_cache) + >>> KVCache() + >>> delete_kv_caches(model) + >>> print(model.caches_are_setup()) + >>> False + >>> print(model.caches_are_enabled()) + >>> False + >>> print(model.layers[0].attn.kv_cache) + >>> None + Args: + model (nn.Module): model to enable KV-cacheing for. + + Raises: + ValueError: if ``delete_kv_caches`` is called on a model which does not have + caches setup. + """ + if not model.caches_are_setup(): + raise ValueError( + "You have tried to delete model caches, but `model.caches_are_setup()` " + "is False!" + ) + for module in model.modules(): + if hasattr(module, "kv_cache") and callable(module.kv_cache): + module.cache_enabled = False + module.kv_cache = None diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index ea1f01c383..40ede4feec 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -91,7 +91,7 @@ def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs): state_dict[new_key] = state_dict[key] del state_dict[key] - def setup_cache( + def setup_caches( self, batch_size: int, dtype: torch.dtype, @@ -107,24 +107,33 @@ def setup_cache( encoder_max_seq_len (int): maximum cache sequence length for cross-attention layer. decoder_max_seq_len (int): maximum cache sequence length for self-attention layer. """ - self.layer.setup_cache( + self.layer.setup_caches( batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len, decoder_max_seq_len=decoder_max_seq_len, ) - self.fusion_layer.setup_cache( + self.fusion_layer.setup_caches( batch_size, dtype, encoder_max_seq_len=encoder_max_seq_len, decoder_max_seq_len=decoder_max_seq_len, ) - @property - def cache_enabled(self) -> bool: - """Check if the key value caches are setup.""" - return self.layer.cache_enabled + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup on ``self.layer``. + See :func:~torchtune.modules.TransformerDecoder.caches_are_setup`. + """ + return self.layer.caches_are_setup() + + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches on ``self.layer`` are enabled. + See :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`. + """ + return self.layer.caches_are_enabled() def reset_cache(self): """Reset both layers' key value caches.""" @@ -384,12 +393,27 @@ def setup_caches( decoder_max_seq_len=decoder_max_seq_len, ) + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + def caches_are_enabled(self) -> bool: - """Check if the key value caches are setup.""" + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False. 
+ """ return self.decoder.caches_are_enabled() def reset_caches(self): - """Reset the key value caches.""" + """ + Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, + without deleting or reallocating cache tensors. + """ self.decoder.reset_caches() def forward( diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index ded4d96672..910cb8273b 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -45,7 +45,7 @@ def __init__( self.sa_scale = sa_scale or nn.Identity() self.mlp_scale = mlp_scale or nn.Identity() - def setup_cache( + def setup_caches( self, batch_size: int, dtype: torch.dtype, @@ -63,11 +63,20 @@ def setup_cache( """ self.attn.setup_cache(batch_size, dtype, max_seq_len=decoder_max_seq_len) - @property - def cache_enabled(self) -> bool: - """Check if the key value caches are setup.""" + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup on ``self.attn``. + See :func:~torchtune.modules.TransformerDecoder.caches_are_setup`. + """ return self.attn.kv_cache is not None + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches on ``self.attn`` are enabled. + See :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`. + """ + return self.attn.cache_enabled + def reset_cache(self): """Reset the key value caches.""" self.attn.reset_cache() @@ -165,7 +174,7 @@ def __init__( self.ca_scale = ca_scale or nn.Identity() self.mlp_scale = mlp_scale or nn.Identity() - def setup_cache( + def setup_caches( self, batch_size: int, dtype: torch.dtype, @@ -183,11 +192,20 @@ def setup_cache( """ self.attn.setup_cache(batch_size, dtype, encoder_max_seq_len) - @property - def cache_enabled(self) -> bool: - """Check if the key value caches are setup.""" + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup on ``self.attn``. + See :func:~torchtune.modules.TransformerDecoder.caches_are_setup`. + """ return self.attn.kv_cache is not None + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches on ``self.attn`` are enabled. + See :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`. + """ + return self.attn.cache_enabled + def reset_cache(self): """Reset the key value caches.""" self.attn.reset_cache() @@ -253,7 +271,7 @@ def forward( """ # During decoding, it's possible encoder_input is None because the embeds # are already stored in the kv cache. - empty_cache = not self.cache_enabled or self.attn.kv_cache.size == 0 + empty_cache = not self.caches_are_enabled() or self.attn.kv_cache.size == 0 # Skip cross attention when no secondary input as it's primary purpose # is to attend between x and encoder_input. if encoder_input is None and empty_cache: @@ -423,19 +441,34 @@ def setup_caches( self.decoder_max_cache_seq_len = self.max_seq_len for layer in self.layers: - layer.setup_cache( + layer.setup_caches( batch_size, dtype, encoder_max_seq_len=self.encoder_max_cache_seq_len, decoder_max_seq_len=self.decoder_max_cache_seq_len, ) + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.layers[0].caches_are_setup() + def caches_are_enabled(self) -> bool: - """Check if the key value caches are setup. 
This is useful to efficient inference."""
-        return self.layers[0].cache_enabled
+        """
+        Checks if the key value caches are enabled. Once KV-caches have been set up, the relevant
+        attention modules will be "enabled" and all forward passes will update the caches. This behaviour
+        can be disabled without altering the state of the KV-caches by "disabling" the KV-caches
+        using ``torchtune.modules.disable_kv_cache``, upon which ``caches_are_enabled`` would return False.
+        """
+        return self.layers[0].caches_are_enabled()
 
     def reset_caches(self):
-        """Reset the key value caches."""
+        """
+        Resets KV-cache buffers on relevant attention modules to zero, and resets cache positions to zero,
+        without deleting or reallocating cache tensors.
+        """
         if not self.caches_are_enabled():
             raise RuntimeError(
                 "Key value caches are not setup. Call ``setup_caches()`` first."
             )
@@ -759,7 +792,7 @@ def setup_caches(
             self.decoder_max_cache_seq_len = self.decoder_max_cache_seq_len
 
         for layer in self.layers:
-            layer.setup_cache(
+            layer.setup_caches(
                 batch_size,
                 dtype,
                 self.encoder_max_cache_seq_len,