From 20041258f6eb5931280460e4f714803315c6edf2 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Thu, 21 Nov 2024 09:35:42 +0000 Subject: [PATCH 1/4] Refactor merged_options to exclude unset fields and update unset sampling_params handling in deployments --- aana/core/models/base.py | 2 +- aana/deployments/hf_text_generation_deployment.py | 7 +++++-- aana/deployments/idefics_2_deployment.py | 14 ++++++++++---- aana/deployments/vllm_deployment.py | 7 +++++-- aana/tests/units/test_merge_options.py | 15 +++++++++++++++ 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/aana/core/models/base.py b/aana/core/models/base.py index d129c53a..efb295e7 100644 --- a/aana/core/models/base.py +++ b/aana/core/models/base.py @@ -58,7 +58,7 @@ def merged_options(default_options: OptionType, options: OptionType) -> OptionTy if type(default_options) != type(options): raise ValueError("Option type mismatch.") # noqa: TRY003 default_options_dict = default_options.model_dump() - for k, v in options.model_dump().items(): + for k, v in options.model_dump(exclude_unset=True).items(): if v is not None: default_options_dict[k] = v return options.__class__.model_validate(default_options_dict) diff --git a/aana/deployments/hf_text_generation_deployment.py b/aana/deployments/hf_text_generation_deployment.py index 0d4be92b..a46c5d8e 100644 --- a/aana/deployments/hf_text_generation_deployment.py +++ b/aana/deployments/hf_text_generation_deployment.py @@ -68,8 +68,11 @@ async def generate_stream( prompt = str(prompt) if sampling_params is None: - sampling_params = SamplingParams() - sampling_params = merged_options(self.default_sampling_params, sampling_params) + sampling_params = self.default_sampling_params + else: + sampling_params = merged_options( + self.default_sampling_params, sampling_params + ) prompt_input = self.tokenizer( prompt, return_tensors="pt", add_special_tokens=False diff --git a/aana/deployments/idefics_2_deployment.py b/aana/deployments/idefics_2_deployment.py index 6536bf88..09069e84 100644 --- a/aana/deployments/idefics_2_deployment.py +++ b/aana/deployments/idefics_2_deployment.py @@ -105,8 +105,11 @@ async def chat_stream( transformers.set_seed(42) if sampling_params is None: - sampling_params = SamplingParams() - sampling_params = merged_options(self.default_sampling_params, sampling_params) + sampling_params = self.default_sampling_params + else: + sampling_params = merged_options( + self.default_sampling_params, sampling_params + ) messages, images = dialog.to_objects() text = self.processor.apply_chat_template(messages, add_generation_prompt=True) @@ -191,8 +194,11 @@ async def chat_batch( transformers.set_seed(42) if sampling_params is None: - sampling_params = SamplingParams() - sampling_params = merged_options(self.default_sampling_params, sampling_params) + sampling_params = self.default_sampling_params + else: + sampling_params = merged_options( + self.default_sampling_params, sampling_params + ) text_batch = [] image_batch = [] diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py index 708849b5..50386cc6 100644 --- a/aana/deployments/vllm_deployment.py +++ b/aana/deployments/vllm_deployment.py @@ -227,8 +227,11 @@ async def generate_stream( # noqa: C901 prompt_token_ids = prompt if sampling_params is None: - sampling_params = SamplingParams() - sampling_params = merged_options(self.default_sampling_params, sampling_params) + sampling_params = self.default_sampling_params + else: + sampling_params = merged_options( + self.default_sampling_params, sampling_params + ) json_schema = sampling_params.json_schema regex_string = sampling_params.regex_string diff --git a/aana/tests/units/test_merge_options.py b/aana/tests/units/test_merge_options.py index f95ddd09..6f813bfe 100644 --- a/aana/tests/units/test_merge_options.py +++ b/aana/tests/units/test_merge_options.py @@ -12,6 +12,7 @@ class MyOptions(BaseModel): field1: str field2: int | None = None field3: bool + field4: str = "default" def test_merged_options_same_type(): @@ -46,3 +47,17 @@ class AnotherOptions(BaseModel): with pytest.raises(ValueError): merged_options(default, to_merge) + + +def test_merged_options_unset(): + """Test merged_options with unset fields.""" + default = MyOptions(field1="default1", field2=2, field3=True, field4="new_default") + to_merge = MyOptions(field1="merge1", field3=False) # field4 is not set + merged = merged_options(default, to_merge) + + assert merged.field1 == "merge1" + assert merged.field2 == 2 + assert merged.field3 == False + assert ( + merged.field4 == "new_default" + ) # Should retain value from default_options as it's not set in options From 6c8dc1c5e8605818bd7ea851d5d4f7e69b1d1ae5 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Thu, 21 Nov 2024 09:43:56 +0000 Subject: [PATCH 2/4] Reverted changes to pyproject.toml --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0290e576..507d0d35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ bitblas = "^0.0.1.dev15" bitsandbytes = "^0.42.0" decord = "^0.6.0" fastapi = ">=0.111.0" -haystack-ai = ">=2.1.0" hf-transfer = "^0.1.6" hqq = "^0.2.2" mobius-faster-whisper = ">=1.1.1" @@ -43,7 +42,7 @@ pydantic = ">=2.0" pydantic-settings = "^2.1.0" python-multipart = "^0.0.9" psycopg = {extras = ["binary"], version = "^3.2.1"} -qdrant-haystack = ">=3.2.0" +qdrant-haystack = "^3.2.1" ray = {extras = ["serve"], version = ">=2.20"} rapidfuzz = "^3.4.0" scipy = "^1.11.3" From 90e3d83a8963c3d8b5c857a1c7078e168075ca54 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Thu, 21 Nov 2024 12:36:34 +0000 Subject: [PATCH 3/4] Reverted changes in pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c77fe874..f17870f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ bitblas = "^0.0.1.dev15" bitsandbytes = "^0.42.0" decord = "^0.6.0" fastapi = ">=0.111.0" +haystack-ai = ">=2.1.0" hf-transfer = "^0.1.6" hqq = "^0.2.2" mobius-faster-whisper = ">=1.1.1" @@ -42,7 +43,7 @@ pydantic = ">=2.0" pydantic-settings = "^2.1.0" python-multipart = "^0.0.9" psycopg = {extras = ["binary"], version = "^3.2.1"} -qdrant-haystack = "^3.2.1" +qdrant-haystack = ">=3.2.0" ray = {extras = ["serve"], version = ">=2.20"} rapidfuzz = "^3.4.0" scipy = "^1.11.3" From 3b5cca85b5028aa187104bfe2a886590fc9da153 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Thu, 21 Nov 2024 13:49:26 +0000 Subject: [PATCH 4/4] Updated test files --- ...ment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json b/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json index c44d4688..ba93432f 100644 --- a/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json +++ b/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json @@ -1 +1 @@ -" The painting is done by Vincent van Gogh." \ No newline at end of file +" Vincent van Gogh" \ No newline at end of file