Skip to content

Commit

Permalink
Merge pull request #206 from mobiusml/merge_options_bug_fix
Browse files Browse the repository at this point in the history
Fix Sampling Parameters Merging
  • Loading branch information
movchan74 authored Nov 21, 2024
2 parents 5d01dc9 + 3b5cca8 commit a45d79e
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 10 deletions.
2 changes: 1 addition & 1 deletion aana/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def merged_options(default_options: OptionType, options: OptionType) -> OptionTy
if type(default_options) != type(options):
raise ValueError("Option type mismatch.") # noqa: TRY003
default_options_dict = default_options.model_dump()
for k, v in options.model_dump().items():
for k, v in options.model_dump(exclude_unset=True).items():
if v is not None:
default_options_dict[k] = v
return options.__class__.model_validate(default_options_dict)
Expand Down
7 changes: 5 additions & 2 deletions aana/deployments/hf_text_generation_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ async def generate_stream(
prompt = str(prompt)

if sampling_params is None:
sampling_params = SamplingParams()
sampling_params = merged_options(self.default_sampling_params, sampling_params)
sampling_params = self.default_sampling_params
else:
sampling_params = merged_options(
self.default_sampling_params, sampling_params
)

prompt_input = self.tokenizer(
prompt, return_tensors="pt", add_special_tokens=False
Expand Down
14 changes: 10 additions & 4 deletions aana/deployments/idefics_2_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,11 @@ async def chat_stream(
transformers.set_seed(42)

if sampling_params is None:
sampling_params = SamplingParams()
sampling_params = merged_options(self.default_sampling_params, sampling_params)
sampling_params = self.default_sampling_params
else:
sampling_params = merged_options(
self.default_sampling_params, sampling_params
)

messages, images = dialog.to_objects()
text = self.processor.apply_chat_template(messages, add_generation_prompt=True)
Expand Down Expand Up @@ -191,8 +194,11 @@ async def chat_batch(
transformers.set_seed(42)

if sampling_params is None:
sampling_params = SamplingParams()
sampling_params = merged_options(self.default_sampling_params, sampling_params)
sampling_params = self.default_sampling_params
else:
sampling_params = merged_options(
self.default_sampling_params, sampling_params
)

text_batch = []
image_batch = []
Expand Down
7 changes: 5 additions & 2 deletions aana/deployments/vllm_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,11 @@ async def generate_stream( # noqa: C901
prompt_token_ids = prompt

if sampling_params is None:
sampling_params = SamplingParams()
sampling_params = merged_options(self.default_sampling_params, sampling_params)
sampling_params = self.default_sampling_params
else:
sampling_params = merged_options(
self.default_sampling_params, sampling_params
)

json_schema = sampling_params.json_schema
regex_string = sampling_params.regex_string
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
" The painting is done by Vincent van Gogh."
" Vincent van Gogh"
15 changes: 15 additions & 0 deletions aana/tests/units/test_merge_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class MyOptions(BaseModel):
field1: str
field2: int | None = None
field3: bool
field4: str = "default"


def test_merged_options_same_type():
Expand Down Expand Up @@ -46,3 +47,17 @@ class AnotherOptions(BaseModel):

with pytest.raises(ValueError):
merged_options(default, to_merge)


def test_merged_options_unset():
"""Test merged_options with unset fields."""
default = MyOptions(field1="default1", field2=2, field3=True, field4="new_default")
to_merge = MyOptions(field1="merge1", field3=False) # field4 is not set
merged = merged_options(default, to_merge)

assert merged.field1 == "merge1"
assert merged.field2 == 2
assert merged.field3 == False
assert (
merged.field4 == "new_default"
) # Should retain value from default_options as it's not set in options

0 comments on commit a45d79e

Please sign in to comment.