failed to load whisper decoder engine with paged kv cache #1930

Closed
3 of 4 tasks
MahmoudAshraf97 opened this issue Jul 10, 2024 · 7 comments
Labels
bug (Something isn't working) · functionality issue

Comments

@MahmoudAshraf97
Contributor

MahmoudAshraf97 commented Jul 10, 2024

System Info

  • CPU architecture: x86_64
  • CPU/Host memory size: 32GB DDR4
  • GPU properties
    • GPU name: RTX 3070 Ti
    • GPU memory size: 8GB
  • Libraries
    • TensorRT-LLM version: 0.12.0.dev2024070900
  • OS: Ubuntu 22.04 on WSL

Who can help?

@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Build using the official example instructions, but with remove_input_padding and paged_kv_cache switched to enable:

INFERENCE_PRECISION=float16
WEIGHT_ONLY_PRECISION=int8
MAX_BEAM_WIDTH=4
MAX_BATCH_SIZE=8
checkpoint_dir=distil_whisper_medium_en_weights_${WEIGHT_ONLY_PRECISION}
output_dir=distil_whisper_medium_en_${WEIGHT_ONLY_PRECISION}
trtllm-build  --checkpoint_dir ${checkpoint_dir}/decoder \
              --output_dir ${output_dir}/decoder \
              --paged_kv_cache enable \
              --moe_plugin disable \
              --enable_xqa disable \
              --use_custom_all_reduce disable \
              --max_beam_width ${MAX_BEAM_WIDTH} \
              --max_batch_size ${MAX_BATCH_SIZE} \
              --max_seq_len 114 \
              --max_input_len 14 \
              --max_encoder_input_len 1500 \
              --gemm_plugin ${INFERENCE_PRECISION} \
              --bert_attention_plugin ${INFERENCE_PRECISION} \
              --gpt_attention_plugin ${INFERENCE_PRECISION} \
              --remove_input_padding enable

Then load the model using the class in run.py.
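
For reference, the loading step is roughly the following (a minimal sketch reconstructed from the traceback below; WhisperTRTLLM is the class from examples/whisper/run.py, and the engine/assets paths are placeholders):

# Minimal sketch of the failing load step, reconstructed from the
# traceback below; engine/assets paths are placeholders.
import tensorrt_llm
from run import WhisperTRTLLM  # examples/whisper/run.py

tensorrt_llm.logger.set_level("info")
engine_dir = "distil_whisper_medium_en_int8"  # output_dir from the build above
model = WhisperTRTLLM(engine_dir, False, "assets")  # raises the RuntimeError below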

Expected behavior

The model should load fine

Actual behavior

[07/10/2024-20:02:34] [TRT-LLM] [E] The following expected tensors are not found: {'past_key_value_1', 'cross_past_key_value_1', 'past_key_value_0', 'present_key_value_0', 'present_key_value_1', 'cross_past_key_value_0', 'cross_present_key_value_0', 'cross_present_key_value_1'}
[07/10/2024-20:02:34] [TRT-LLM] [E] Those tensors in engine are not expected: {'host_kv_cache_block_offsets', 'kv_cache_block_offsets', 'host_kv_cache_pool_pointers', 'host_cross_kv_cache_pool_pointers', 'host_cross_kv_cache_block_offsets', 'cross_kv_cache_block_offsets'}
[07/10/2024-20:02:34] [TRT-LLM] [E] Expected tensor names: ['input_ids', 'logits', 'last_token_ids', 'position_ids', 'cache_indirection', 'past_key_value_0', 'present_key_value_0', 'past_key_value_1', 'present_key_value_1', 'cross_present_key_value_0', 'cross_past_key_value_0', 'cross_present_key_value_1', 'cross_past_key_value_1', 'sequence_length', 'context_lengths', 'host_request_types', 'host_past_key_value_lengths', 'host_sink_token_length', 'host_max_attention_window_sizes', 'host_context_lengths', 'encoder_output', 'encoder_input_lengths', 'encoder_max_input_length', 'cross_kv_cache_gen']
[07/10/2024-20:02:34] [TRT-LLM] [E] Found tensor names: ['input_ids', 'position_ids', 'encoder_input_lengths', 'encoder_max_input_length', 'encoder_output', 'host_past_key_value_lengths', 'host_context_lengths', 'sequence_length', 'context_lengths', 'host_request_types', 'last_token_ids', 'cache_indirection', 'host_max_attention_window_sizes', 'host_sink_token_length', 'kv_cache_block_offsets', 'host_kv_cache_block_offsets', 'host_kv_cache_pool_pointers', 'cross_kv_cache_block_offsets', 'host_cross_kv_cache_block_offsets', 'host_cross_kv_cache_pool_pointers', 'cross_kv_cache_gen', 'logits']

{
	"name": "RuntimeError",
	"message": "Tensor names in engine are not the same as expected, to use this GenerationSession, you need to use PretrainedModel.prepare_inputs to create TRT Network inputs.",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 15
     12 accuracy_check = False  # Change to True for CI test accuracy check
     14 tensorrt_llm.logger.set_level(log_level)
---> 15 model = WhisperTRTLLM(engine_dir, debug, assets_dir)
     16 normalizer = EnglishTextNormalizer()

Cell In[2], line 172, in WhisperTRTLLM.__init__(self, engine_dir, debug_mode, assets_dir)
    169 engine_dir = Path(engine_dir)
    171 self.encoder = WhisperEncoding(engine_dir)
--> 172 self.decoder = WhisperDecoding(engine_dir, runtime_mapping, debug_mode=False)
    173 is_multilingual = self.decoder.decoder_config[\"vocab_size\"] >= 51865
    174 if is_multilingual:

Cell In[2], line 57, in WhisperDecoding.__init__(self, engine_dir, runtime_mapping, debug_mode)
     54 def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
     56     self.decoder_config = self.get_config(engine_dir)
---> 57     self.decoder_generation_session = self.get_session(
     58         engine_dir, runtime_mapping, debug_mode
     59     )

Cell In[2], line 93, in WhisperDecoding.get_session(self, engine_dir, runtime_mapping, debug_mode)
     73     decoder_engine_buffer = f.read()
     75 decoder_model_config = ModelConfig(
     76     max_batch_size=self.decoder_config[\"max_batch_size\"],
     77     max_beam_width=self.decoder_config[\"max_beam_width\"],
   (...)
     91     has_token_type_embedding=False,
     92 )
---> 93 decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
     94     decoder_model_config,
     95     decoder_engine_buffer,
     96     runtime_mapping,
     97     debug_mode=debug_mode,
     98 )
    100 return decoder_generation_session

File ~/.local/lib/python3.10/site-packages/tensorrt_llm/runtime/generation.py:863, in GenerationSession.__init__(self, model_config, engine_buffer, mapping, debug_mode, debug_tensors_to_save, cuda_graph_mode, stream)
    861     logger.error(f\"Expected tensor names: {expected_tensor_names}\")
    862     logger.error(f\"Found tensor names: {found_tensor_names}\")
--> 863     raise RuntimeError(
    864         \"Tensor names in engine are not the same as expected, to use this GenerationSession, \"
    865         \"you need to use PretrainedModel.prepare_inputs to create TRT Network inputs.\"
    866     )
    867 if self.debug_mode:
    868     self.debug_tensors = list(
    869         set(found_tensor_names) - set(expected_tensor_names))

RuntimeError: Tensor names in engine are not the same as expected, to use this GenerationSession, you need to use PretrainedModel.prepare_inputs to create TRT Network inputs."
}

Additional notes

I built with paged KV cache enabled in order to use in-flight batching. It's not in a usable state for now, but that is a separate issue; check #1909.

MahmoudAshraf97 added the bug label Jul 10, 2024
@yuekaizhang

@MahmoudAshraf97 I am investigating the issue now and will update here later.

@MahmoudAshraf97
Contributor Author

Reproduced with 0.12.0.dev2024071600

@yuekaizhang

Reproduced with 0.12.0.dev2024071600

@MahmoudAshraf97 Yeah, the fix has not been merged into main yet. I'll let you know here once it gets merged.

@qingquansong
Contributor

Faced the same issue with 0.13.0.dev2024090300, and also hit two other issues:

  1. The master-branch example convert_checkpoint.py has an issue in the weight loader as well when doing fp8 quantization (either use_fp8 or use_fp8_rowwise); I had to use the quantize.py function instead (rough sketch after this list).

  2. --use_paged_context_fmha enable, --use_fp8_context_fmha enable, and --gemm_swiglu_plugin fp8 all cannot be used at build time. I feel this should not be related to the modelopt version, but just in case, could you recommend the right modelopt version for 0.13.0.dev2024090300? I'm using version 0.15.1 with:

ENV CUDA_VERSION 12.6
ENV CUDNN_VERSION cuda-12-9.3.0.75-1
ENV NCCL_VERSION 2.22.3-1+cuda12.6
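
Regarding (1), the quantize.py path I used looks roughly like this (a sketch only; the flags are assumed from examples/quantization/quantize.py, and the paths are placeholders):

# Hypothetical invocation of the quantize.py workaround from (1);
# model and output paths are placeholders.
python examples/quantization/quantize.py \
    --model_dir /path/to/hf_model \
    --dtype float16 \
    --qformat fp8 \
    --output_dir /path/to/checkpoint_fp8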

@qingquansong
Contributor

qingquansong commented Sep 11, 2024

In my case, it seems like the issue happens here (in tensorrt_llm/runtime/generation.py):

    @property
    def paged_kv_cache(self):
        # When kv_cache_type is stored as a string, this comparison is always False
        return self._model_config.kv_cache_type == KVCacheType.PAGED

self._model_config.kv_cache_type is a string, "PAGED" (or another value), while KVCacheType.PAGED is a binding object, <KVCacheType.PAGED: 1>. The comparison therefore misbehaves, causing some tensor names to be added incorrectly (either missed for the paged case, or extra ones added) here:
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/runtime/generation.py#L809-L823

The right way should be something like the model_runner config builder here:
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/runtime/model_runner.py#L89

kv_cache_type = KVCacheType(builder_config.get('kv_cache_type'))

and here it would be KVCacheType(self._model_config.kv_cache_type) when self._model_config.kv_cache_type is a string.
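
To illustrate the failure mode, here is a small sketch using a plain Python IntEnum as a stand-in for the real binding (the actual KVCacheType accepts the string directly, as in the model_runner.py line above; the stand-in coerces by name):

# Stand-in enum, only to illustrate the string-vs-enum comparison bug;
# the real KVCacheType is a C++ binding with the same comparison semantics.
from enum import IntEnum

class KVCacheType(IntEnum):
    CONTINUOUS = 0
    PAGED = 1

stored = "PAGED"  # what self._model_config.kv_cache_type ends up holding

print(stored == KVCacheType.PAGED)               # False: a str never equals an enum member
print(KVCacheType[stored] == KVCacheType.PAGED)  # True once coerced (by name here)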

However, the root cause seems to be here:

https://github.com/NVIDIA/TensorRT-LLM/blob/main/benchmarks/python/gpt_benchmark.py#L83 somehow sets the string "PAGED", so maybe we should do kv_cache_type = KVCacheType(self.kv_cache_type) there?

I put up a small PR here: #2219; let me know if it's good to check in. There are some other parts of the benchmarks to fix as well (such as the quantization); I'll try to see if I can fix them locally first.

@gonggqing

@qingquansong

Many thanks! This works for me. I faced a similar error when calling benchmarks/python/benchmark.py with the following command while trying to benchmark Qwen2 performance on an RTX 4090:

python benchmarks/python/benchmark.py -m dec --engine_dir /data/trtllm_build/qwen2-7b-instruct-trtllm-build/ --batch_size "1;8" --input_output_len "256,64;512,64"

In my case, the error reported was:

root@7cb13ca13820:/code/tensorrt_llm# python benchmarks/python/benchmark.py -m dec --engine_dir /data/trtllm_build/qwen2-7b-instruct-trtllm-build/ --batch_size "1;8" --input_output_len "256,64;512,64"
[TensorRT-LLM] TensorRT-LLM version: 0.14.0.dev2024091700
[09/20/2024-03:59:59] [TRT-LLM] [E] The following expected tensors are not found: {'present_key_value_15', 'present_key_value_10', 'past_key_value_4', 'past_key_value_5', 'present_key_value_2', 'present_key_value_8', 'past_key_value_18', 'past_key_value_19', 'past_key_value_14', 'past_key_value_10', 'past_key_value_3', 'past_key_value_11', 'present_key_value_12', 'past_key_value_24', 'past_key_value_26', 'past_key_value_13', 'present_key_value_13', 'past_key_value_17', 'present_key_value_21', 'present_key_value_14', 'past_key_value_25', 'present_key_value_3', 'present_key_value_19', 'present_key_value_17', 'past_key_value_9', 'present_key_value_6', 'present_key_value_4', 'present_key_value_7', 'past_key_value_2', 'present_key_value_0', 'present_key_value_25', 'past_key_value_27', 'present_key_value_16', 'present_key_value_23', 'past_key_value_15', 'present_key_value_5', 'past_key_value_8', 'present_key_value_24', 'past_key_value_20', 'present_key_value_18', 'past_key_value_21', 'present_key_value_9', 'present_key_value_1', 'past_key_value_6', 'past_key_value_1', 'past_key_value_22', 'present_key_value_20', 'present_key_value_26', 'present_key_value_11', 'present_key_value_22', 'past_key_value_23', 'present_key_value_27', 'past_key_value_12', 'past_key_value_7', 'past_key_value_0', 'past_key_value_16'}
[09/20/2024-03:59:59] [TRT-LLM] [E] Those tensors in engine are not expected: {'kv_cache_block_offsets', 'host_kv_cache_pool_pointers', 'host_kv_cache_block_offsets'}
[09/20/2024-03:59:59] [TRT-LLM] [E] Expected tensor names: ['input_ids', 'logits', 'last_token_ids', 'position_ids', 'cache_indirection', 'past_key_value_0', 'present_key_value_0', 'past_key_value_1', 'present_key_value_1', 'past_key_value_2', 'present_key_value_2', 'past_key_value_3', 'present_key_value_3', 'past_key_value_4', 'present_key_value_4', 'past_key_value_5', 'present_key_value_5', 'past_key_value_6', 'present_key_value_6', 'past_key_value_7', 'present_key_value_7', 'past_key_value_8', 'present_key_value_8', 'past_key_value_9', 'present_key_value_9', 'past_key_value_10', 'present_key_value_10', 'past_key_value_11', 'present_key_value_11', 'past_key_value_12', 'present_key_value_12', 'past_key_value_13', 'present_key_value_13', 'past_key_value_14', 'present_key_value_14', 'past_key_value_15', 'present_key_value_15', 'past_key_value_16', 'present_key_value_16', 'past_key_value_17', 'present_key_value_17', 'past_key_value_18', 'present_key_value_18', 'past_key_value_19', 'present_key_value_19', 'past_key_value_20', 'present_key_value_20', 'past_key_value_21', 'present_key_value_21', 'past_key_value_22', 'present_key_value_22', 'past_key_value_23', 'present_key_value_23', 'past_key_value_24', 'present_key_value_24', 'past_key_value_25', 'present_key_value_25', 'past_key_value_26', 'present_key_value_26', 'past_key_value_27', 'present_key_value_27', 'sequence_length', 'host_past_key_value_lengths', 'context_lengths', 'host_request_types', 'host_sink_token_length', 'host_runtime_perf_knobs', 'host_max_attention_window_sizes', 'host_context_lengths']
[09/20/2024-03:59:59] [TRT-LLM] [E] Found tensor names: ['input_ids', 'position_ids', 'last_token_ids', 'kv_cache_block_offsets', 'host_kv_cache_block_offsets', 'host_kv_cache_pool_pointers', 'sequence_length', 'host_request_types', 'host_past_key_value_lengths', 'context_lengths', 'host_runtime_perf_knobs', 'host_context_lengths', 'host_max_attention_window_sizes', 'host_sink_token_length', 'cache_indirection', 'logits']
Traceback (most recent call last):
  File "/code/tensorrt_llm/benchmarks/python/benchmark.py", line 354, in <module>
    main(args)
  File "/code/tensorrt_llm/benchmarks/python/benchmark.py", line 209, in main
    benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options,
  File "/code/tensorrt_llm/benchmarks/python/gpt_benchmark.py", line 110, in __init__
    self.decoder = session_cls(model_config,
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/generation.py", line 930, in __init__
    raise RuntimeError(
RuntimeError: Tensor names in engine are not the same as expected, to use this GenerationSession, you need to use PretrainedModel.prepare_inputs to create TRT Network inputs.

And with this minor change applied locally (@qingquansong's PR), the problem was solved:

diff --git a/benchmarks/python/gpt_benchmark.py b/benchmarks/python/gpt_benchmark.py
index 04ba2ab0..ce06c9f9 100644
--- a/benchmarks/python/gpt_benchmark.py
+++ b/benchmarks/python/gpt_benchmark.py
@@ -80,7 +80,7 @@ class GPTBenchmark(BaseBenchmark):
 
         kv_cache_type = KVCacheType.CONTINUOUS
         if hasattr(self, 'kv_cache_type'):
-            kv_cache_type = self.kv_cache_type
+            kv_cache_type = KVCacheType(self.kv_cache_type)
         else:
             if hasattr(self, 'paged_kv_cache'):
                 kv_cache_type = KVCacheType.PAGED if self.paged_kv_cache == True else KVCacheType.CONTINUOUS

@MahmoudAshraf97
Contributor Author

Not reproducible in 0.14.0.dev2024100800, closing as solved.
