Adapt OpenVINO CPU plugin implementation #4
@@ -12,6 +12,7 @@
 from vllm.utils import is_openvino_optimum_intel

 import openvino as ov
+from openvino import Type


 def _flattenize_inputs(inputs):
@@ -56,7 +57,8 @@ def ov_wrapper(self, *args, **kwargs) -> torch.Tensor:

 def patch_stateful_model(
         model: ov.Model,
-        factory):
+        factory,
+        kv_cache_dtype: Type):
     print('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM')
     from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher, AnyInput, Or
     from openvino.runtime import opset13
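For context, `patch_stateful_model` is built on OpenVINO's graph-transformation API imported above (`Manager`, `MatcherPass`, `Matcher`, `WrapType`). Below is a minimal, self-contained sketch of that mechanism; the pass and the stand-in model are illustrative only, not the PR's actual transformations:

```python
import openvino as ov
from openvino.runtime import opset13
from openvino.runtime.passes import Manager, Matcher, MatcherPass, WrapType

# Hypothetical pass: matches every opset13 Parameter node and reports it.
class ReportParameters(MatcherPass):
    def __init__(self):
        MatcherPass.__init__(self)
        pattern = WrapType("opset13.Parameter")

        def callback(m: Matcher) -> bool:
            print("matched:", m.get_match_root().get_friendly_name())
            return False  # returning False means the graph was not changed

        self.register_matcher(Matcher(pattern, "ReportParameters"), callback)

# Stand-in model, just to have something to run the pass over.
p = opset13.parameter(shape=[-1, 4], dtype=ov.Type.f32)
model = ov.Model([opset13.relu(p)], [p])

manager = Manager()
manager.register_pass(ReportParameters())
manager.run_passes(model)
```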
@@ -128,8 +130,8 @@ def callback(m: Matcher) -> bool:
     real_v = mapping[v_current]
     hidden_shape = real_q.get_partial_shape()
     hidden_dim = hidden_shape[hidden_shape.rank.get_length() - 1].get_length()  # TODO: What if it is dynamic? Need to insert a ShapeOf sub-graph instead
-    k_parameter = opset13.parameter(shape=[-1, -1, -1, -1, -1], dtype=np.float32)
-    v_parameter = opset13.parameter(shape=[-1, -1, -1, -1], dtype=np.float32)
+    k_parameter = opset13.parameter(shape=[-1, -1, -1, -1, -1], dtype=kv_cache_dtype)
+    v_parameter = opset13.parameter(shape=[-1, -1, -1, -1], dtype=kv_cache_dtype)
Comment on lines +133 to +134:

> Is it a correct expectation that the CPU doesn't require any specific dimensions in this part, only the correct data type, while for the GPU specific dimensions in static form are strongly required?

> Yes, the CPU can work with dynamic shapes.

> Why? I expect the GPU can also work with dynamic shapes.
     kv_parameters.append(k_parameter)
     kv_parameters.append(v_parameter)
     # TODO: Rank 4 is assumed in the following code, but it is not guaranteed for all models; adapt to other ranks.
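To make the shape discussion in the thread above concrete, here is a sketch contrasting the fully dynamic cache parameter created by this patch with a partially static variant; the static layout and the concrete numbers are purely illustrative:

```python
import openvino as ov
from openvino.runtime import opset13

# Fully dynamic rank-4 value-cache parameter, as in this patch: every
# dimension is -1 and is resolved only at inference time.
v_dynamic = opset13.parameter(shape=[-1, -1, -1, -1], dtype=ov.Type.f16)
print(v_dynamic.get_partial_shape())  # [?,?,?,?]

# Partially static variant that a device preferring static shapes might
# need; the assumed [batch, num_kv_heads, seq, head_size] layout and the
# concrete numbers are hypothetical.
v_static = opset13.parameter(shape=[-1, 32, -1, 128], dtype=ov.Type.f16)
print(v_static.get_partial_shape())  # [?,32,?,128]
```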
@@ -274,7 +276,8 @@ def has_parameter(model, name):

 def _patch_model_with_openvino(
         pt_model: torch.nn.Module,
-        model_config: ModelConfig):
+        model_config: ModelConfig,
+        kv_cache_dtype: Type):
     print(' ============= PATCHING MODEL =============')
     from vllm.model_executor.layers.attention.attention import Attention
     from openvino.frontend.pytorch import ModuleExtension
@@ -294,7 +297,15 @@ def _patch_model_with_openvino(

     # Prepare example inputs

-    kv_cache_dtype = torch.float32
+    torch_dtype_mapping = {
+        Type.boolean: torch.bool,
+        Type.f32: torch.float32,
+        Type.f16: torch.float16,
+        Type.bf16: torch.bfloat16,
+        Type.i32: torch.int32,
+        Type.i64: torch.int64
+    }
+    kv_cache_dtype = torch_dtype_mapping[kv_cache_dtype]
     num_heads = pt_model.config.num_attention_heads
     num_kv_heads = num_heads
     head_size = pt_model.config.hidden_size // num_kv_heads
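One note on the mapping above: it covers only the listed element types, so an unlisted type (for example `Type.u8`) raises a bare `KeyError`. A sketch of a slightly more defensive variant; the helper name is hypothetical:

```python
import torch
from openvino import Type

torch_dtype_mapping = {
    Type.boolean: torch.bool,
    Type.f32: torch.float32,
    Type.f16: torch.float16,
    Type.bf16: torch.bfloat16,
    Type.i32: torch.int32,
    Type.i64: torch.int64,
}

def to_torch_dtype(ov_type: Type) -> torch.dtype:
    # Fail with an explicit message instead of a bare KeyError.
    try:
        return torch_dtype_mapping[ov_type]
    except KeyError:
        raise ValueError(f"unsupported KV cache element type: {ov_type}")

print(to_torch_dtype(Type.f16))  # torch.float16
```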
@@ -423,6 +434,7 @@ def ov_sample(

 def get_model(model_config: ModelConfig,
               device_config: DeviceConfig,
+              kv_cache_dtype: Type,
               **kwargs) -> torch.nn.Module:
     lora_config = kwargs.get("lora_config", None)
     if lora_config:
@@ -443,7 +455,7 @@ def get_model(model_config: ModelConfig,
     # Keep the factory to destroy it at the particular moment when all other objects referencing custom nodes are destroyed
     pt_model.ov_node_factory = NodeFactory()
     pt_model.ov_node_factory.add_extension('libuser_ov_extensions.so')
-    patch_stateful_model(pt_model.model, pt_model.ov_node_factory)
+    patch_stateful_model(pt_model.model, pt_model.ov_node_factory, kv_cache_dtype)
     core = ov.Core()
     ov_compiled = core.compile_model(pt_model.model, "CPU")
     pt_model._ov_request = ov_compiled.create_infer_request()
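For readers unfamiliar with the compile-and-infer flow used above, a self-contained sketch with a stand-in model (the PR compiles the patched optimum-intel model the same way):

```python
import numpy as np
import openvino as ov
from openvino.runtime import opset13

# Tiny stand-in model: one dynamic-batch input followed by a ReLU.
p = opset13.parameter(shape=[-1, 4], dtype=ov.Type.f32)
model = ov.Model([opset13.relu(p)], [p])

core = ov.Core()
compiled = core.compile_model(model, "CPU")
request = compiled.create_infer_request()

# Run one inference; inputs can be passed positionally as a list.
result = request.infer([np.random.rand(2, 4).astype(np.float32)])
print(list(result.values())[0].shape)  # (2, 4)
```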
@@ -457,6 +469,6 @@ def get_model(model_config: ModelConfig,
     else:
         from vllm.model_executor.model_loader import get_model
         pt_model = get_model(model_config, device_config, **kwargs)
-        _patch_model_with_openvino(pt_model, model_config)
+        _patch_model_with_openvino(pt_model, model_config, kv_cache_dtype)

     return pt_model
> For me, the optimized CPU kernel does not work with the FP16 cache dtype; I had to use FP32 to get correct behavior.

> Could you please try again with the newest version?
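If the f16 problem reported above still reproduces, one possible workaround is to fall back to f32 on CPU. The helper below is a hypothetical sketch, not part of this PR:

```python
import openvino as ov

def safe_kv_cache_type(requested: ov.Type, device: str) -> ov.Type:
    # Hypothetical policy: the optimized CPU kernel reportedly misbehaves
    # with f16 caches, so use f32 on CPU until the fix is verified.
    if device == "CPU" and requested == ov.Type.f16:
        return ov.Type.f32
    return requested

print(safe_kv_cache_type(ov.Type.f16, "CPU"))  # prints the f32 type
```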