diff --git a/aana/core/models/chat.py b/aana/core/models/chat.py
index e6ea82e0..c102f3d6 100644
--- a/aana/core/models/chat.py
+++ b/aana/core/models/chat.py
@@ -132,6 +132,7 @@ class ChatCompletionRequest(BaseModel):
         temperature (float): float that controls the randomness of the sampling
         top_p (float): float that controls the cumulative probability of the top tokens to consider
         max_tokens (int): the maximum number of tokens to generate
+        repetition_penalty (float): float that penalizes new tokens based on whether they appear in the prompt and the generated text so far
         stream (bool): if set, partial message deltas will be sent
     """
 
@@ -139,8 +140,8 @@ class ChatCompletionRequest(BaseModel):
     messages: list[ChatMessage] = Field(
         ..., description="A list of messages comprising the conversation so far."
     )
-    temperature: float | None = Field(
-        default=None,
+    temperature: float = Field(
+        default=1.0,
         ge=0.0,
         description=(
             "Float that controls the randomness of the sampling. "
@@ -149,8 +150,8 @@
             "Zero means greedy sampling."
         ),
     )
-    top_p: float | None = Field(
-        default=None,
+    top_p: float = Field(
+        default=1.0,
         gt=0.0,
         le=1.0,
         description=(
@@ -161,6 +162,15 @@
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
+    repetition_penalty: float = Field(
+        default=1.0,
+        description=(
+            "Float that penalizes new tokens based on whether they appear in the "
+            "prompt and the generated text so far. Values > 1 encourage the model "
+            "to use new tokens, while values < 1 encourage the model to repeat tokens. "
+            "Default is 1.0 (no penalty)."
+        ),
+    )
     stream: bool | None = Field(
         default=False,
diff --git a/aana/core/models/sampling.py b/aana/core/models/sampling.py
index 0c76d6e5..7af91429 100644
--- a/aana/core/models/sampling.py
+++ b/aana/core/models/sampling.py
@@ -15,10 +15,15 @@ class SamplingParams(BaseModel):
         top_k (int): Integer that controls the number of top tokens to consider.
             Set to -1 to consider all tokens.
         max_tokens (int): The maximum number of tokens to generate.
+        repetition_penalty (float): Float that penalizes new tokens based on whether
+            they appear in the prompt and the generated text so far. Values > 1 encourage
+            the model to use new tokens, while values < 1 encourage the model to repeat
+            tokens. Default is 1.0 (no penalty).
+        kwargs (dict): Extra keyword arguments to pass as sampling parameters.
     """
 
-    temperature: float | None = Field(
-        default=None,
+    temperature: float = Field(
+        default=1.0,
         ge=0.0,
         description=(
             "Float that controls the randomness of the sampling. "
@@ -27,8 +32,8 @@
             "Zero means greedy sampling."
         ),
     )
-    top_p: float | None = Field(
-        default=None,
+    top_p: float = Field(
+        default=1.0,
         gt=0.0,
         le=1.0,
         description=(
@@ -46,6 +51,19 @@
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
+    repetition_penalty: float = Field(
+        default=1.0,
+        description=(
+            "Float that penalizes new tokens based on whether they appear in the "
+            "prompt and the generated text so far. Values > 1 encourage the model "
+            "to use new tokens, while values < 1 encourage the model to repeat tokens. "
+            "Default is 1.0 (no penalty)."
+        ),
+    )
+    kwargs: dict = Field(
+        default_factory=dict,
+        description="Extra keyword arguments to pass as sampling parameters.",
+    )
 
     @field_validator("top_k")
     def check_top_k(cls, v: int):
diff --git a/aana/deployments/hf_text_generation_deployment.py b/aana/deployments/hf_text_generation_deployment.py
index eed29827..bc147786 100644
--- a/aana/deployments/hf_text_generation_deployment.py
+++ b/aana/deployments/hf_text_generation_deployment.py
@@ -132,6 +132,8 @@ async def generate_stream(
             temperature=sampling_params.temperature,
             num_return_sequences=1,
             eos_token_id=self.tokenizer.eos_token_id,
+            repetition_penalty=sampling_params.repetition_penalty,
+            **sampling_params.kwargs,
         )
         if sampling_params.temperature == 0:
             generation_kwargs["do_sample"] = False
diff --git a/aana/deployments/idefics_2_deployment.py b/aana/deployments/idefics_2_deployment.py
index e7f7becd..e6c304fa 100644
--- a/aana/deployments/idefics_2_deployment.py
+++ b/aana/deployments/idefics_2_deployment.py
@@ -145,6 +145,7 @@ async def chat_stream(
             temperature=sampling_params.temperature,
             num_return_sequences=1,
             eos_token_id=self.processor.tokenizer.eos_token_id,
+            **sampling_params.kwargs,
         )
         if sampling_params.temperature == 0:
             generation_kwargs["do_sample"] = False
diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py
index 61382d62..59223779 100644
--- a/aana/deployments/vllm_deployment.py
+++ b/aana/deployments/vllm_deployment.py
@@ -148,7 +148,8 @@ async def generate_stream(
         try:
             # convert SamplingParams to VLLMSamplingParams
             sampling_params_vllm = VLLMSamplingParams(
-                **sampling_params.model_dump(exclude_unset=True)
+                **sampling_params.model_dump(exclude_unset=True, exclude=["kwargs"]),
+                **sampling_params.kwargs,
             )
             # start the request
             request_id = random_uuid()
diff --git a/aana/tests/deployments/test_idefics2_deployment.py b/aana/tests/deployments/test_idefics2_deployment.py
index cbdecb51..03fba4c9 100644
--- a/aana/tests/deployments/test_idefics2_deployment.py
+++ b/aana/tests/deployments/test_idefics2_deployment.py
@@ -6,6 +6,7 @@
 from aana.core.models.chat import ChatMessage
 from aana.core.models.image import Image
 from aana.core.models.image_chat import ImageChatDialog
+from aana.core.models.sampling import SamplingParams
 from aana.core.models.types import Dtype
 from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
 from aana.deployments.idefics_2_deployment import Idefics2Config, Idefics2Deployment
@@ -21,6 +22,9 @@
             user_config=Idefics2Config(
                 model="HuggingFaceM4/idefics2-8b",
                 dtype=Dtype.FLOAT16,
+                default_sampling_params=SamplingParams(
+                    temperature=1.0, max_tokens=256, kwargs={"diversity_penalty": 0.0}
+                ),
             ).model_dump(mode="json"),
         ),
     )
diff --git a/aana/tests/deployments/test_text_generation_deployment.py b/aana/tests/deployments/test_text_generation_deployment.py
index bb9924bc..48a447b1 100644
--- a/aana/tests/deployments/test_text_generation_deployment.py
+++ b/aana/tests/deployments/test_text_generation_deployment.py
@@ -28,6 +28,9 @@
             model_kwargs={
                 "trust_remote_code": True,
             },
+            default_sampling_params=SamplingParams(
+                temperature=0.0, kwargs={"diversity_penalty": 0.0}
+            ),
         ).model_dump(mode="json"),
     ),
 ),
@@ -46,7 +49,11 @@
             gpu_memory_reserved=10000,
             enforce_eager=True,
             default_sampling_params=SamplingParams(
-                temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
+                temperature=0.0,
+                top_p=1.0,
+                top_k=-1,
+                max_tokens=1024,
+                kwargs={"frequency_penalty": 0.0},
             ),
             engine_args={
                 "trust_remote_code": True,
diff --git a/aana/tests/units/test_sampling_params.py b/aana/tests/units/test_sampling_params.py
index aaaf137b..b66136a8 100644
--- a/aana/tests/units/test_sampling_params.py
+++ b/aana/tests/units/test_sampling_params.py
@@ -6,18 +6,22 @@
 
 def test_valid_sampling_params():
     """Test valid sampling parameters."""
-    params = SamplingParams(temperature=0.5, top_p=0.9, top_k=10, max_tokens=50)
+    params = SamplingParams(
+        temperature=0.5, top_p=0.9, top_k=10, max_tokens=50, repetition_penalty=1.5
+    )
     assert params.temperature == 0.5
     assert params.top_p == 0.9
     assert params.top_k == 10
     assert params.max_tokens == 50
+    assert params.repetition_penalty == 1.5
 
-    # Test valid params with default values (None)
+    # Test valid params with default values
     params = SamplingParams()
-    assert params.temperature is None
-    assert params.top_p is None
+    assert params.temperature == 1.0
+    assert params.top_p == 1.0
     assert params.top_k is None
     assert params.max_tokens is None
+    assert params.repetition_penalty == 1.0
 
 
 def test_invalid_temperature():
@@ -46,3 +50,16 @@ def test_invalid_max_tokens():
     """Test invalid max_tokens values."""
     with pytest.raises(ValueError):
         SamplingParams(max_tokens=0)
+
+
+def test_kwargs():
+    """Test extra keyword arguments."""
+    params = SamplingParams(
+        temperature=0.5, kwargs={"presence_penalty": 2.0, "frequency_penalty": 1.0}
+    )
+    assert params.kwargs == {"presence_penalty": 2.0, "frequency_penalty": 1.0}
+    assert params.temperature == 0.5
+    assert params.top_p == 1.0
+    assert params.top_k is None
+    assert params.max_tokens is None
+    assert params.repetition_penalty == 1.0
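Usage sketch (not part of the diff; the parameter values and extra keys shown are illustrative only): typed fields such as repetition_penalty are validated by SamplingParams with a default of 1.0, while backend-specific options travel in kwargs and are splatted into the backend call, mirroring the vllm_deployment.py change above.

# Illustrative example of the new SamplingParams fields; values are assumptions.
from aana.core.models.sampling import SamplingParams

params = SamplingParams(
    temperature=0.7,
    max_tokens=256,
    repetition_penalty=1.2,             # > 1 discourages repeating tokens
    kwargs={"frequency_penalty": 0.5},  # extra option passed through to the engine as-is
)

# Same dump-and-splat pattern as vllm_deployment.py: typed fields come from
# model_dump() with kwargs excluded, the untyped extras from params.kwargs.
engine_kwargs = {
    **params.model_dump(exclude_unset=True, exclude=["kwargs"]),
    **params.kwargs,
}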