Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenVINO Serialization fix #73

Merged
merged 2 commits
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions fastrag/generators/openvino.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret
from haystack.utils import (
ComponentDevice,
Secret,
deserialize_callable,
deserialize_secrets_inplace,
serialize_callable,
)
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from transformers import AutoConfig, AutoTokenizer

with LazyImport("Install openvino using 'pip install -e .[openvino]'") as ov_import:
Expand All @@ -11,6 +20,7 @@
DEFAULT_OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}


@component
class OpenVINOGenerator(HuggingFaceLocalGenerator):
"""
Generator based on a Hugging Face model loaded with OpenVINO.
Expand Down Expand Up @@ -48,6 +58,7 @@ def __init__(
generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
"""
Creates an instance of a OpenVINOGenerator.
Expand Down Expand Up @@ -84,22 +95,23 @@ def __init__(
If you provide this parameter, you should not specify the `stopping_criteria` in `generation_kwargs`.
For some chat models, the output includes both the new text and the original prompt.
In these cases, it's important to make sure your prompt has no stop words.
:param streaming_callback: An optional callable for handling streaming responses.
"""
ov_import.check()
super().__init__(
super(OpenVINOGenerator, self).__init__(
model=model,
task=task,
device=device,
token=token,
generation_kwargs=generation_kwargs,
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs,
stop_words=stop_words,
streaming_callback=streaming_callback,
)
self.model = model
self.compressed_model_dir = compressed_model_dir
self.device_openvino = device_openvino
self.ov_config = ov_config
self.compressed_model_dir = compressed_model_dir

def warm_up(self):
"""
Expand All @@ -120,4 +132,35 @@ def warm_up(self):
self.huggingface_pipeline_kwargs["tokenizer"] = AutoTokenizer.from_pretrained(
self.model
)
super().warm_up()
super(OpenVINOGenerator, self).warm_up()

def to_dict(self) -> Dict[str, Any]:
    """
    Serializes the component to a dictionary.

    :returns:
        Dictionary with serialized data.
    """
    # A streaming callback cannot be serialized directly; store its
    # importable dotted path (or None when no callback is set).
    callback_name = None
    if self.streaming_callback:
        callback_name = serialize_callable(self.streaming_callback)

    # Secrets carry their own serialization; keep None when unset.
    token_repr = self.token.to_dict() if self.token else None

    data = default_to_dict(
        self,
        huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
        generation_kwargs=self.generation_kwargs,
        streaming_callback=callback_name,
        stop_words=self.stop_words,
        token=token_repr,
        model=self.model,
        compressed_model_dir=self.compressed_model_dir,
        device_openvino=self.device_openvino,
        ov_config=self.ov_config,
    )

    # Drop the token from the pipeline kwargs so the secret never lands
    # in the serialized output, then make the remaining HF model kwargs
    # (e.g. torch dtypes) JSON-friendly.
    pipeline_kwargs = data["init_parameters"]["huggingface_pipeline_kwargs"]
    pipeline_kwargs.pop("token", None)
    serialize_hf_model_kwargs(pipeline_kwargs)
    return data
10 changes: 5 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[options]
install_requires =
haystack-ai==2.1.2
haystack-ai>=2.1.2
transformers>=4.35.2
datasets
evaluate
Expand All @@ -10,14 +10,14 @@ install_requires =
numba
openpyxl
numpy
protobuf==5.28.3
protobuf>=5.28.3
ujson
accelerate
fastapi
uvicorn
Pillow==10.1.0
chainlit==1.0.506
sentence-transformers==2.3.0
Pillow>=10.1.0
chainlit>=1.0.506
sentence-transformers>=2.3.0
events

[options.extras_require]
Expand Down