diff --git a/fastrag/generators/openvino.py b/fastrag/generators/openvino.py
index 34c4a4e..95b1265 100644
--- a/fastrag/generators/openvino.py
+++ b/fastrag/generators/openvino.py
@@ -1,9 +1,17 @@
 from typing import Any, Callable, Dict, List, Literal, Optional
 
+from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.components.generators import HuggingFaceLocalGenerator
 from haystack.dataclasses import StreamingChunk
 from haystack.lazy_imports import LazyImport
-from haystack.utils import ComponentDevice, Secret
+from haystack.utils import (
+    ComponentDevice,
+    Secret,
+    deserialize_callable,
+    deserialize_secrets_inplace,
+    serialize_callable,
+)
+from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
 from transformers import AutoConfig, AutoTokenizer
 
 with LazyImport("Install openvino using 'pip install -e .[openvino]'") as ov_import:
@@ -12,6 +20,7 @@
 DEFAULT_OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}
 
 
+@component
 class OpenVINOGenerator(HuggingFaceLocalGenerator):
     """
     Generator based on a Hugging Face model loaded with OpenVINO.
@@ -89,7 +98,7 @@ def __init__(
         :param streaming_callback: An optional callable for handling streaming responses.
         """
         ov_import.check()
-        super().__init__(
+        super(OpenVINOGenerator, self).__init__(
             model=model,
             task=task,
             device=device,
@@ -103,7 +112,6 @@ def __init__(
         self.compressed_model_dir = compressed_model_dir
         self.device_openvino = device_openvino
         self.ov_config = ov_config
-        self.compressed_model_dir = compressed_model_dir
 
     def warm_up(self):
         """
@@ -124,4 +132,44 @@ def warm_up(self):
             self.huggingface_pipeline_kwargs["tokenizer"] = AutoTokenizer.from_pretrained(
                 self.model
             )
-        super().warm_up()
+        super(OpenVINOGenerator, self).warm_up()
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        The pipeline kwargs are sanitized on a copy, so calling this method does not
+        remove the token from, or otherwise rewrite, the component's own kwargs.
+
+        :returns:
+            Dictionary with serialized data.
+        """
+
+        def _copy_dicts(value: Any) -> Any:
+            # Duplicate nested dicts only; any other value stays shared by reference.
+            if isinstance(value, dict):
+                return {key: _copy_dicts(item) for key, item in value.items()}
+            return value
+
+        callback_name = (
+            serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        )
+
+        # pop() and serialize_hf_model_kwargs() modify their dict in place, so they are
+        # applied to a copy rather than to self.huggingface_pipeline_kwargs itself.
+        huggingface_pipeline_kwargs = _copy_dicts(self.huggingface_pipeline_kwargs)
+        huggingface_pipeline_kwargs.pop("token", None)
+        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
+
+        return default_to_dict(
+            self,
+            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs,
+            generation_kwargs=self.generation_kwargs,
+            streaming_callback=callback_name,
+            stop_words=self.stop_words,
+            token=self.token.to_dict() if self.token else None,
+            model=self.model,
+            compressed_model_dir=self.compressed_model_dir,
+            device_openvino=self.device_openvino,
+            ov_config=self.ov_config,
+        )