Skip to content

Commit

Permalink
OpenVINO Serialization fix (#73)
Browse files Browse the repository at this point in the history
Related to #65.

Co-authored-by: Nicolas Oliver <dario.n.oliver@intel.com>
  • Loading branch information
danielfleischer and dnoliver authored Nov 11, 2024
1 parent 4592240 commit 0fb03c7
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions fastrag/generators/openvino.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from typing import Any, Callable, Dict, List, Literal, Optional

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret
from haystack.utils import (
ComponentDevice,
Secret,
deserialize_callable,
deserialize_secrets_inplace,
serialize_callable,
)
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from transformers import AutoConfig, AutoTokenizer

with LazyImport("Install openvino using 'pip install -e .[openvino]'") as ov_import:
Expand All @@ -12,6 +20,7 @@
DEFAULT_OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}


@component
class OpenVINOGenerator(HuggingFaceLocalGenerator):
"""
Generator based on a Hugging Face model loaded with OpenVINO.
Expand Down Expand Up @@ -89,7 +98,7 @@ def __init__(
:param streaming_callback: An optional callable for handling streaming responses.
"""
ov_import.check()
super().__init__(
super(OpenVINOGenerator, self).__init__(
model=model,
task=task,
device=device,
Expand All @@ -103,7 +112,6 @@ def __init__(
self.compressed_model_dir = compressed_model_dir
self.device_openvino = device_openvino
self.ov_config = ov_config
self.compressed_model_dir = compressed_model_dir

def warm_up(self):
"""
Expand All @@ -124,4 +132,35 @@ def warm_up(self):
self.huggingface_pipeline_kwargs["tokenizer"] = AutoTokenizer.from_pretrained(
self.model
)
super().warm_up()
super(OpenVINOGenerator, self).warm_up()

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
callback_name = (
serialize_callable(self.streaming_callback) if self.streaming_callback else None
)
serialization_dict = default_to_dict(
self,
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
stop_words=self.stop_words,
token=self.token.to_dict() if self.token else None,
model=self.model,
compressed_model_dir=self.compressed_model_dir,
device_openvino=self.device_openvino,
ov_config=self.ov_config,
)

huggingface_pipeline_kwargs = serialization_dict["init_parameters"][
"huggingface_pipeline_kwargs"
]
huggingface_pipeline_kwargs.pop("token", None)

serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict

0 comments on commit 0fb03c7

Please sign in to comment.