[Performance]: High Latency in Stateful Inference for Obtaining State (get_state()) in OpenVINO #28474

Open

piDack opened this issue Jan 16, 2025 · 5 comments

Labels: performance (Performance related topics), support_request

piDack commented Jan 16, 2025

OpenVINO Version

2024.4.6

Operating System

Windows System

Device used for inference

iGPU

OpenVINO installation

PyPi

Programming Language

C++

Hardware Architecture

x86 (64 bits)

Model used

llama

Model quantization

No

Target Platform

No response

Performance issue description

We have identified a significant issue with the latency of stateful inference when obtaining the state in OpenVINO. The delay is excessively high, which impacts the overall performance of our application. The following sample code illustrates the problem:

auto start = std::chrono::high_resolution_clock::now();
ov::Tensor old_tensor = state.get_state();
// [BATCH_SIZE, num_kv_heads, seq_len, head_size]
auto shape = old_tensor.get_shape();
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double, std::milli> duration = end - start;
std::cout << "time " << duration.count() << " ms" << std::endl;

In our tests, we have observed that the latency for obtaining the state is consistently around 40ms, which is unacceptable for our real-time application requirements.

Are there any good suggestions for optimization? Being able to modify the state efficiently is crucial for LLM optimizations such as Medusa and similar methods.
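For context, here is roughly the read-modify-write pattern we need over the KV-cache variables, e.g. rolling the cache back to only the tokens a draft step actually accepted. This is only a minimal sketch: truncate_kv_cache is a hypothetical helper, and it assumes a packed, byte-aligned [batch, num_kv_heads, seq_len, head_size] state layout. The get_state() call at the top of the loop is where the ~40 ms is spent in our measurements.

#include <openvino/openvino.hpp>
#include <cstring>

// Hypothetical helper: keep only the first `keep_len` positions of every KV-cache variable.
void truncate_kv_cache(ov::InferRequest& request, size_t keep_len) {
    for (auto& state : request.query_state()) {
        ov::Tensor old_tensor = state.get_state();      // slow: copies device memory to the host
        ov::Shape shape = old_tensor.get_shape();       // [batch, num_kv_heads, seq_len, head_size]
        if (shape.size() != 4 || shape[2] <= keep_len)
            continue;

        ov::Shape new_shape = shape;
        new_shape[2] = keep_len;
        ov::Tensor new_tensor(old_tensor.get_element_type(), new_shape);

        // Copy the first `keep_len` positions of every [batch, head] slice.
        const size_t head_size = shape[3];
        const size_t elem_size = old_tensor.get_element_type().size();
        const auto* src = static_cast<const uint8_t*>(old_tensor.data());
        auto* dst = static_cast<uint8_t*>(new_tensor.data());
        for (size_t bh = 0; bh < shape[0] * shape[1]; ++bh) {
            std::memcpy(dst + bh * keep_len * head_size * elem_size,
                        src + bh * shape[2] * head_size * elem_size,
                        keep_len * head_size * elem_size);
        }

        state.set_state(new_tensor);                    // copies back to the device
    }
}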

Step-by-step reproduction

No response

Issue submission checklist

  • I'm reporting a performance issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
dnkurek (Contributor) commented Jan 16, 2025

Hi, may we ask for an isolated reproducer so we may debug it?

dnkurek (Contributor) commented Jan 16, 2025

Also, I have checked the code, and it looks like get_state for the GPU plugin returns a copy of its internal memory buffer:

ov::SoPtr<ov::ITensor> VariableState::get_state() const {
    if (m_memory == nullptr) {
        const auto& pshape = m_layout.get_partial_shape();
        const auto& shape = get_tensor_shape(pshape);
        return m_context->create_host_tensor(get_user_specified_type(), shape);
    }
    auto tensor = m_context->create_host_tensor(get_user_specified_type(), m_memory->get_layout().get_shape());
    convert_and_copy(m_memory, tensor._ptr.get(), m_context->get_engine().get_service_stream());
    return tensor;
}

That may explain why it's slow, but it's hard for me to know without broader context.
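If it helps to narrow that down, here is a small sketch from the application side that times get_state() for each variable and reports how many bytes the returned host copy holds, using only the public API (the per-variable loop is just an illustration, not taken from the plugin sources):

#include <openvino/openvino.hpp>
#include <chrono>
#include <iostream>

// Time get_state() for every variable and print the size of the returned host copy.
void profile_states(ov::InferRequest& request) {
    for (auto& state : request.query_state()) {
        auto start = std::chrono::high_resolution_clock::now();
        ov::Tensor tensor = state.get_state();  // host copy of the variable's device memory
        auto end = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double, std::milli> ms = end - start;
        std::cout << state.get_name() << ": "
                  << tensor.get_byte_size() / (1024.0 * 1024.0) << " MiB in "
                  << ms.count() << " ms" << std::endl;
    }
}

If the time scales with the reported size, that would be consistent with the copy in create_host_tensor/convert_and_copy being the bottleneck.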

dnkurek (Contributor) commented Jan 16, 2025

Instead, if m_memory or the layout is all you need, you can use these:

cldnn::memory::ptr VariableState::get_memory() const {
    return m_memory;
}

const cldnn::layout& VariableState::get_layout() const {
    return m_layout;
}

piDack (Author) commented Jan 17, 2025

Hi, may we ask for an isolated reproducer so we may debug it?

Thanks for your reply.
Here is the minimal reproducible code.
The OpenVINO XML file is the Llama IR, which you can download (with weights) from Hugging Face: https://huggingface.co/rajatkrishna/Meta-Llama-3-8B-OpenVINO-INT4/tree/main

import time

import numpy as np
import openvino as ov
import torch

# Dummy LLM inputs for a 3-token prompt
input_ids = torch.from_numpy(np.array([[13127, 10, 2467]], dtype=np.int64))
position_ids = torch.from_numpy(np.array([[1, 2, 3]], dtype=np.int64))
attn_one = torch.from_numpy(np.array([[1, 1, 1]], dtype=np.int64))
beam_idx = torch.from_numpy(np.array([0], dtype=int))

inputs_1 = {}
inputs_1["input_ids"] = input_ids
inputs_1["attention_mask"] = attn_one
inputs_1["position_ids"] = position_ids
inputs_1["beam_idx"] = beam_idx

core = ov.Core()
model_1 = core.compile_model("path\\llama\\openvino_model.xml", "GPU")

infer_request_1 = model_1.create_infer_request()
infer_request_1.infer(inputs_1)
infer_request_1.infer(inputs_1)

# Time how long it takes to read back each KV-cache state
states = infer_request_1.query_state()
for state in states:
    start_time = time.time()
    state_buf = state.state.data
    end_time = time.time()
    elapsed_time = (end_time - start_time) * 1000  # milliseconds
    print(f"Elapsed time: {elapsed_time} ms")

The output on a Lunar Lake (LNL) Core Ultra 9:

Elapsed time: 18.038511276245117 ms
...
Elapsed time: 19.02484893798828 ms
...
Elapsed time: 17.007827758789062 ms
Elapsed time: 16.999244689941406 ms
Elapsed time: 46.535491943359375 ms
....

On a Meteor Lake (MTL) Core Ultra 5, the output is over 40 ms.

We cannot afford this latency.

piDack (Author) commented Jan 17, 2025

When we use get_state(), it automatically goes through convert_and_copy_padded_source() because of the padded layout, which is very time-consuming. How can we configure it to use a no_pad layout?
