Skip to content

Commit

Permalink
Updated tests.
Browse files: browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Sep 10, 2024
1 parent 4b23c22 commit 81b3bf9
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 14 deletions.
12 changes: 6 additions & 6 deletions aana/core/models/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class SamplingParams(BaseModel):
kwargs (dict): Extra keyword arguments to pass as sampling parameters.
"""

temperature: float | None = Field(
default=None,
temperature: float = Field(
default=1.0,
ge=0.0,
description=(
"Float that controls the randomness of the sampling. "
Expand All @@ -32,8 +32,8 @@ class SamplingParams(BaseModel):
"Zero means greedy sampling."
),
)
top_p: float | None = Field(
default=None,
top_p: float = Field(
default=1.0,
gt=0.0,
le=1.0,
description=(
Expand All @@ -51,8 +51,8 @@ class SamplingParams(BaseModel):
max_tokens: int | None = Field(
default=None, ge=1, description="The maximum number of tokens to generate."
)
repetition_penalty: float | None = Field(
default=None,
repetition_penalty: float = Field(
default=1.0,
description=(
"Float that penalizes new tokens based on whether they appear in the "
"prompt and the generated text so far. Values > 1 encourage the model "
Expand Down
1 change: 1 addition & 0 deletions aana/deployments/hf_text_generation_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ async def generate_stream(
temperature=sampling_params.temperature,
num_return_sequences=1,
eos_token_id=self.tokenizer.eos_token_id,
repetition_penalty=sampling_params.repetition_penalty,
**sampling_params.kwargs,
)
if sampling_params.temperature == 0:
Expand Down
2 changes: 1 addition & 1 deletion aana/deployments/vllm_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def generate_stream(
try:
# convert SamplingParams to VLLMSamplingParams
sampling_params_vllm = VLLMSamplingParams(
**sampling_params.model_dump(exclude_unset=True),
**sampling_params.model_dump(exclude_unset=True, exclude=["kwargs"]),
**sampling_params.kwargs,
)
# start the request
Expand Down
4 changes: 4 additions & 0 deletions aana/tests/deployments/test_idefics2_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aana.core.models.chat import ChatMessage
from aana.core.models.image import Image
from aana.core.models.image_chat import ImageChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.deployments.idefics_2_deployment import Idefics2Config, Idefics2Deployment
Expand All @@ -21,6 +22,9 @@
user_config=Idefics2Config(
model="HuggingFaceM4/idefics2-8b",
dtype=Dtype.FLOAT16,
default_sampling_params=SamplingParams(
temperature=0.0, kwargs={"diversity_penalty": 0.0}
),
).model_dump(mode="json"),
),
)
Expand Down
9 changes: 8 additions & 1 deletion aana/tests/deployments/test_text_generation_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
model_kwargs={
"trust_remote_code": True,
},
default_sampling_params=SamplingParams(
temperature=0.0, kwargs={"diversity_penalty": 0.0}
),
).model_dump(mode="json"),
),
),
Expand All @@ -46,7 +49,11 @@
gpu_memory_reserved=10000,
enforce_eager=True,
default_sampling_params=SamplingParams(
temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
temperature=0.0,
top_p=1.0,
top_k=-1,
max_tokens=1024,
kwargs={"frequency_penalty": 0.0},
),
engine_args={
"trust_remote_code": True,
Expand Down
12 changes: 6 additions & 6 deletions aana/tests/units/test_sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ def test_valid_sampling_params():
assert params.max_tokens == 50
assert params.repetition_penalty == 1.5

# Test valid params with default values (None)
# Test valid params with default values
params = SamplingParams()
assert params.temperature is None
assert params.top_p is None
assert params.temperature == 1.0
assert params.top_p == 1.0
assert params.top_k is None
assert params.max_tokens is None
assert params.repetition_penalty is None
assert params.repetition_penalty == 1.0


def test_invalid_temperature():
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_kwargs():
)
assert params.kwargs == {"presence_penalty": 2.0, "frequency_penalty": 1.0}
assert params.temperature == 0.5
assert params.top_p is None
assert params.top_p == 1.0
assert params.top_k is None
assert params.max_tokens is None
assert params.repetition_penalty is None
assert params.repetition_penalty == 1.0

0 comments on commit 81b3bf9

Please sign in to comment.