Fixing recompiles in KV-cache + compile #1663
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1663
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 3110bf9 with merge base 3fddc56. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device)
self.size += seq_len
```

```python
assert (self.cache_pos[0] + seq_len) <= self.k_cache.shape[2]
```
Dynamo doesn't like control flow, so the error is not as informative :(
I don't 100% understand why, but in previous PRs I was instructed to use ValueError instead of assert. I don't know if this is worth an exception. Probably because the assertion msg is not as good as a ValueError's.
The error message without the assertion is even worse (raising like we did before was ideal)
Was the message worse, or is ValueError worse? If it's OK to use ValueError, can we keep it and use a nicer msg instead?
Sorry, there can't be any branching logic here because dynamo will complain about control flow (this is because we're checking the value of a tensor), so the plain assertion error is the best we can do.
We could do something experimental like torch._check, but I'll investigate that another time.
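For reference, a rough sketch of what the `torch._check` route could look like, purely to illustrate the API: it takes a bool/SymBool plus a lazy message callable, and under `torch.compile` (with `torch._dynamo.config.capture_scalar_outputs = True`) the `.item()` call yields an unbacked SymInt so the bound is recorded symbolically instead of branching in Python. Note it leans on `.item()`, which is exactly what is debated below, so treat it as a starting point rather than a recommendation.

```python
# Hypothetical sketch only; not the approach taken in this PR.
curr_pos = self.cache_pos[0].item()  # plain int in eager, SymInt under compile
torch._check(
    curr_pos + seq_len <= self.k_cache.shape[2],
    lambda: (
        f"KV-cache overflow: current position {curr_pos} plus {seq_len} new "
        f"tokens exceeds cache length {self.k_cache.shape[2]}."
    ),
)
```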
I did a quick pass and didn't think through whether this solution is more performant than the previous one (I worry about all of these pre-allocated tensors and their memory consumption). With that being said, would something like the below have fixed it?
`torch._dynamo.mark_dynamic(self.size)`
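For context, a minimal sketch of how `torch._dynamo.mark_dynamic` is typically used: it marks a dimension of a tensor as dynamic, so the exact call above (on an integer attribute) is hypothetical.

```python
import torch

x = torch.randn(8, 128)
# Tell dynamo to treat dim 0 of x as dynamic, so a change in that dimension's
# size does not trigger a recompile of the compiled graph consuming x.
torch._dynamo.mark_dynamic(x, 0)
```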
```python
@property
def size(self) -> int:
    return self.cache_pos[0].item()
```
As a rule of thumb, it's never good to call `.item()`. It is not performant and has issues with `torch.export`.
Noob q, how should I be doing this?
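One possible direction (a hypothetical sketch, not necessarily the approach this PR settles on): keep the position inside the registered `cache_pos` buffer and return the 0-dim tensor itself, so callers index and compare with tensor ops instead of syncing to a Python int via `.item()`.

```python
import torch
from torch import nn


class KVCacheSketch(nn.Module):
    """Illustrative only: track the cache position without calling .item()."""

    def __init__(self, max_seq_len: int) -> None:
        super().__init__()
        # cache_pos[0] always holds the current write position.
        self.register_buffer(
            "cache_pos", torch.arange(0, max_seq_len), persistent=False
        )

    @property
    def size(self) -> torch.Tensor:
        # Return the 0-dim tensor; no host/device sync, friendlier to
        # torch.compile and torch.export than returning a Python int.
        return self.cache_pos[0]
```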
Very slick. I think the PR summary is really clear; it might be good to add a bit more detail in the code itself (and/or docstring), especially around the incrementing of the cache_pos buffer. Personally, I had to think about it for a minute.
Hey @SalmanMohammadi, thank you so much for this PR! I think we are a bit afraid of changing things after already testing them. Do you mind holding it for a bit and doing a bit more testing later this week or next week, just for peace of mind?
Summary
generate_v2

| Branch | Time for inference (s) | Tokens/sec | Bandwidth achieved (GB/s) | Max memory allocated (GB) |
|---|---|---|---|---|
| main | 0.88 | 9.04 | 146.92 | 16.28 |
| fix_kv_compile | 0.88 | 9.09 | 147.69 | 16.29 |
eleuther_eval

| Branch | Time for completion (s) | s/it | Max memory allocated (GB) |
|---|---|---|---|
| main | 368.69 | 3.45 | 21.31 |
| fix_kv_compile | 368.70 | 3.45 | 21.32 |
Outputs for generation and metrics for evaluation were identical. The suite of tests in tests/torchtune/generation also guards against generation outputs changing when using KV-caching. Please let me know if there's more testing you'd like to see!
Raw logs
On main
generate_v2
INFO:torchtune.utils._logging:Running InferenceRecipe with resolved config:
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files:
- model-00001-of-00004.safetensors
- model-00002-of-00004.safetensors
- model-00003-of-00004.safetensors
- model-00004-of-00004.safetensors
model_type: LLAMA3
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
recipe_checkpoint: null
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 200
model:
_component_: torchtune.models.llama3_1.llama3_1_8b
prompt:
system: You are a helpful and creative AI assistant.
user: What is the capital of France?
seed: 1234
temperature: 0.6
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
max_seq_len: null
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
top_k: 300
INFO:torchtune.utils._logging:Model was initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:
The capital of France is Paris!
INFO:torchtune.utils._logging:Time for inference: 0.88 sec total, 9.04 tokens/sec
INFO:torchtune.utils._logging:Bandwidth achieved: 146.92 GB/s
INFO:torchtune.utils._logging:Max memory allocated: 16.29 GB
eleuther_eval
INFO:torchtune.utils._logging:Running EleutherEvalRecipe with resolved config:
batch_size: 2
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files:
- model-00001-of-00004.safetensors
- model-00002-of-00004.safetensors
- model-00003-of-00004.safetensors
- model-00004-of-00004.safetensors
model_type: LLAMA3
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
recipe_checkpoint: null
device: cuda
dtype: bf16
enable_kv_cache: true
limit: 100
max_seq_length: 4096
model:
_component_: torchtune.models.llama3_1.llama3_1_8b
quantizer: null
seed: 1234
tasks:
- truthfulqa_gen
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
max_seq_len: null
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
warnings.warn(
INFO:lm-eval:`group` and `group_alias` keys in TaskConfigs are deprecated and will be removed in v0.4.5 of lm_eval. The new `tag` field will be used to allow for a shortcut to a group of tasks one does not wish to aggregate metrics across. `group`s which aggregate across subtasks must be only defined in a separate group config file, which will be the official way to create groups that support cross-task aggregation as in `mmlu`. Please see the v0.4.4 patch notes and our documentation: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md#advanced-group-configs for more information.
INFO:torchtune.utils._logging:Running evaluation on the following tasks: ['truthfulqa_gen']
INFO:lm-eval:Building contexts for truthfulqa_gen on rank 0...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 4057.80it/s]
INFO:lm-eval:Running generate_until requests
Running generate_until requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:45<00:00, 3.45s/it]
INFO:torchtune.utils._logging:Eval completed in 368.69 seconds.
INFO:torchtune.utils._logging:Max memory allocated: 21.31 GB
INFO:torchtune.utils._logging:
| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
|---|---|---|---|---|---|---|---|---|
| truthfulqa_gen | 3 | none | 0 | bleu_acc | ↑ | 0.5700 | ± | 0.0498 |
| | | none | 0 | bleu_diff | ↑ | 14.1211 | ± | 3.2247 |
| | | none | 0 | bleu_max | ↑ | 36.3770 | ± | 2.5794 |
| | | none | 0 | rouge1_acc | ↑ | 0.5700 | ± | 0.0498 |
| | | none | 0 | rouge1_diff | ↑ | 20.3622 | ± | 4.4249 |
| | | none | 0 | rouge1_max | ↑ | 60.7925 | ± | 2.9996 |
| | | none | 0 | rouge2_acc | ↑ | 0.5300 | ± | 0.0502 |
| | | none | 0 | rouge2_diff | ↑ | 20.8200 | ± | 4.8412 |
| | | none | 0 | rouge2_max | ↑ | 49.8049 | ± | 3.5523 |
| | | none | 0 | rougeL_acc | ↑ | 0.5600 | ± | 0.0499 |
| | | none | 0 | rougeL_diff | ↑ | 19.7338 | ± | 4.4891 |
| | | none | 0 | rougeL_max | ↑ | 59.3272 | ± | 3.1075 |
On fix_kv_compile
generate_v2
INFO:torchtune.utils._logging:Running InferenceRecipe with resolved config:
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files:
- model-00001-of-00004.safetensors
- model-00002-of-00004.safetensors
- model-00003-of-00004.safetensors
- model-00004-of-00004.safetensors
model_type: LLAMA3
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
recipe_checkpoint: null
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 200
model:
_component_: torchtune.models.llama3_1.llama3_1_8b
prompt:
system: You are a helpful and creative AI assistant.
user: What is the capital of France?
seed: 1234
temperature: 0.6
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
max_seq_len: null
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
top_k: 300
INFO:torchtune.utils._logging:Model was initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:
The capital of France is Paris!
INFO:torchtune.utils._logging:Time for inference: 0.88 sec total, 9.09 tokens/sec
INFO:torchtune.utils._logging:Bandwidth achieved: 147.69 GB/s
INFO:torchtune.utils._logging:Max memory allocated: 16.29 GB
eleuther_eval
INFO:torchtune.utils._logging:Running EleutherEvalRecipe with resolved config:
batch_size: 2
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files:
- model-00001-of-00004.safetensors
- model-00002-of-00004.safetensors
- model-00003-of-00004.safetensors
- model-00004-of-00004.safetensors
model_type: LLAMA3
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
recipe_checkpoint: null
device: cuda
dtype: bf16
enable_kv_cache: true
limit: 100
max_seq_length: 4096
model:
_component_: torchtune.models.llama3_1.llama3_1_8b
quantizer: null
seed: 1234
tasks:
- truthfulqa_gen
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
max_seq_len: null
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
warnings.warn(
INFO:lm-eval:`group` and `group_alias` keys in TaskConfigs are deprecated and will be removed in v0.4.5 of lm_eval. The new `tag` field will be used to allow for a shortcut to a group of tasks one does not wish to aggregate metrics across. `group`s which aggregate across subtasks must be only defined in a separate group config file, which will be the official way to create groups that support cross-task aggregation as in `mmlu`. Please see the v0.4.4 patch notes and our documentation: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md#advanced-group-configs for more information.
INFO:torchtune.utils._logging:Running evaluation on the following tasks: ['truthfulqa_gen']
INFO:lm-eval:Building contexts for truthfulqa_gen on rank 0...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3991.08it/s]
INFO:lm-eval:Running generate_until requests
Running generate_until requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [05:45<00:00, 3.45s/it]
INFO:torchtune.utils._logging:Eval completed in 368.70 seconds.
INFO:torchtune.utils._logging:Max memory allocated: 21.31 GB
INFO:torchtune.utils._logging:
| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
|---|---|---|---|---|---|---|---|---|
| truthfulqa_gen | 3 | none | 0 | bleu_acc | ↑ | 0.5700 | ± | 0.0498 |
| | | none | 0 | bleu_diff | ↑ | 14.1211 | ± | 3.2247 |
| | | none | 0 | bleu_max | ↑ | 36.3770 | ± | 2.5794 |
| | | none | 0 | rouge1_acc | ↑ | 0.5700 | ± | 0.0498 |
| | | none | 0 | rouge1_diff | ↑ | 20.3622 | ± | 4.4249 |
| | | none | 0 | rouge1_max | ↑ | 60.7925 | ± | 2.9996 |
| | | none | 0 | rouge2_acc | ↑ | 0.5300 | ± | 0.0502 |
| | | none | 0 | rouge2_diff | ↑ | 20.8200 | ± | 4.8412 |
| | | none | 0 | rouge2_max | ↑ | 49.8049 | ± | 3.5523 |
| | | none | 0 | rougeL_acc | ↑ | 0.5600 | ± | 0.0499 |
| | | none | 0 | rougeL_diff | ↑ | 19.7338 | ± | 4.4891 |
| | | none | 0 | rougeL_max | ↑ | 59.3272 | ± | 3.1075 |
Context
What is the purpose of this PR?
Unbeknownst to me, when #1449 landed it broke compatibility with compile. This PR fixes it.
First, let's look at the problem. Running with `torch._logging.set_logs(recompiles=True)` shows a recompile being triggered on every decoding step. What's going on here? In `KVCache.update`, we're using `self.size` to track the current position of the cache and using it to retrieve subsequent positions, so we trigger a recompile every time we run the graph and increment `self.size`.

How can we get around this? No dynamism! Let's prefill the whole length of `cache_pos` we need, and then just index into it correctly. To get around having to keep track of an integer variable for our current cache position, we can just increment our positions every time we update. It seems kind of weird to do it this way (imagine if I used `torch.roll`??), but it ends up being very compile friendly. We're initializing `cache_pos` for the longest possible prefill, after which we'll only ever need `cache_pos[0]` to indicate the current position in the cache for next-token prediction.
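A rough sketch of what such a compile-friendly update can look like (hypothetical class and shapes; the actual implementation lives in torchtune's `KVCache`):

```python
import torch
from torch import nn


class KVCacheSketch(nn.Module):
    """Illustrative sketch: a KV-cache whose position state is a tensor buffer."""

    def __init__(self, batch_size: int, num_heads: int, max_seq_len: int, head_dim: int):
        super().__init__()
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(shape), persistent=False)
        # Pre-allocate every position we could ever need; cache_pos[0] is the
        # current insertion point.
        self.register_buffer("cache_pos", torch.arange(0, max_seq_len), persistent=False)

    def update(self, k_val: torch.Tensor, v_val: torch.Tensor):
        seq_len = k_val.shape[2]
        assert (self.cache_pos[0] + seq_len) <= self.k_cache.shape[2]

        # Scatter the new keys/values into the next seq_len slots.
        self.k_cache[:, :, self.cache_pos[:seq_len]] = k_val
        self.v_cache[:, :, self.cache_pos[:seq_len]] = v_val

        # Advance all positions in place. Because the position lives in a tensor
        # buffer rather than a Python int, dynamo sees the same graph on every
        # step and does not recompile.
        self.cache_pos += seq_len
        return self.k_cache, self.v_cache
```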
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.