From 4b23c22a22a44ac3ca95838e6e585e4847c3f7b1 Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Tue, 10 Sep 2024 10:09:23 +0000
Subject: [PATCH 1/4] Add `repetition_penalty` and `kwargs` parameters to `SamplingParams`

---
 aana/core/models/sampling.py | 18 ++++++++++++++++++
 .../hf_text_generation_deployment.py | 1 +
 aana/deployments/idefics_2_deployment.py | 1 +
 aana/deployments/vllm_deployment.py | 3 ++-
 aana/tests/units/test_sampling_params.py | 19 ++++++++++++++++++-
 5 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/aana/core/models/sampling.py b/aana/core/models/sampling.py
index 0c76d6e5..843f7dc6 100644
--- a/aana/core/models/sampling.py
+++ b/aana/core/models/sampling.py
@@ -15,6 +15,11 @@ class SamplingParams(BaseModel):
         top_k (int): Integer that controls the number of top tokens to consider.
             Set to -1 to consider all tokens.
         max_tokens (int): The maximum number of tokens to generate.
+        repetition_penalty (float): Float that penalizes new tokens based on whether
+            they appear in the prompt and the generated text so far. Values > 1 encourage
+            the model to use new tokens, while values < 1 encourage the model to repeat
+            tokens. Default is 1.0 (no penalty).
+        kwargs (dict): Extra keyword arguments to pass as sampling parameters.
     """
 
     temperature: float | None = Field(
@@ -46,6 +51,19 @@ class SamplingParams(BaseModel):
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
+    repetition_penalty: float | None = Field(
+        default=None,
+        description=(
+            "Float that penalizes new tokens based on whether they appear in the "
+            "prompt and the generated text so far. Values > 1 encourage the model "
+            "to use new tokens, while values < 1 encourage the model to repeat tokens. "
+            "Default is 1.0 (no penalty)."
+        ),
+    )
+    kwargs: dict = Field(
+        default_factory=dict,
+        description="Extra keyword arguments to pass as sampling parameters.",
+    )
 
     @field_validator("top_k")
     def check_top_k(cls, v: int):
diff --git a/aana/deployments/hf_text_generation_deployment.py b/aana/deployments/hf_text_generation_deployment.py
index eed29827..021f97a2 100644
--- a/aana/deployments/hf_text_generation_deployment.py
+++ b/aana/deployments/hf_text_generation_deployment.py
@@ -132,6 +132,7 @@ async def generate_stream(
             temperature=sampling_params.temperature,
             num_return_sequences=1,
             eos_token_id=self.tokenizer.eos_token_id,
+            **sampling_params.kwargs,
         )
         if sampling_params.temperature == 0:
             generation_kwargs["do_sample"] = False
diff --git a/aana/deployments/idefics_2_deployment.py b/aana/deployments/idefics_2_deployment.py
index e7f7becd..e6c304fa 100644
--- a/aana/deployments/idefics_2_deployment.py
+++ b/aana/deployments/idefics_2_deployment.py
@@ -145,6 +145,7 @@ async def chat_stream(
             temperature=sampling_params.temperature,
             num_return_sequences=1,
             eos_token_id=self.processor.tokenizer.eos_token_id,
+            **sampling_params.kwargs,
         )
         if sampling_params.temperature == 0:
             generation_kwargs["do_sample"] = False
diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py
index 61382d62..520cdb0c 100644
--- a/aana/deployments/vllm_deployment.py
+++ b/aana/deployments/vllm_deployment.py
@@ -148,7 +148,8 @@ async def generate_stream(
         try:
             # convert SamplingParams to VLLMSamplingParams
             sampling_params_vllm = VLLMSamplingParams(
-                **sampling_params.model_dump(exclude_unset=True)
+                **sampling_params.model_dump(exclude_unset=True),
+                **sampling_params.kwargs,
             )
             # start the request
             request_id = random_uuid()
diff --git a/aana/tests/units/test_sampling_params.py b/aana/tests/units/test_sampling_params.py
index aaaf137b..253309f0 100644
--- a/aana/tests/units/test_sampling_params.py
+++ b/aana/tests/units/test_sampling_params.py
@@ -6,11 +6,14 @@
 
 def test_valid_sampling_params():
     """Test valid sampling parameters."""
-    params = SamplingParams(temperature=0.5, top_p=0.9, top_k=10, max_tokens=50)
+    params = SamplingParams(
+        temperature=0.5, top_p=0.9, top_k=10, max_tokens=50, repetition_penalty=1.5
+    )
     assert params.temperature == 0.5
     assert params.top_p == 0.9
     assert params.top_k == 10
     assert params.max_tokens == 50
+    assert params.repetition_penalty == 1.5
 
     # Test valid params with default values (None)
     params = SamplingParams()
@@ -18,6 +21,7 @@
     assert params.top_p is None
     assert params.top_k is None
     assert params.max_tokens is None
+    assert params.repetition_penalty is None
 
 
 def test_invalid_temperature():
@@ -46,3 +50,16 @@
     """Test invalid max_tokens values."""
     with pytest.raises(ValueError):
         SamplingParams(max_tokens=0)
+
+
+def test_kwargs():
+    """Test extra keyword arguments."""
+    params = SamplingParams(
+        temperature=0.5, kwargs={"presence_penalty": 2.0, "frequency_penalty": 1.0}
+    )
+    assert params.kwargs == {"presence_penalty": 2.0, "frequency_penalty": 1.0}
+    assert params.temperature == 0.5
+    assert params.top_p is None
+    assert params.top_k is None
+    assert params.max_tokens is None
+    assert params.repetition_penalty is None

From 81b3bf93e9c179dd133f1294e6c7ecf25ed4d354 Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Tue, 10 Sep 2024 10:56:34 +0000
Subject: [PATCH 2/4] Updated tests.
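
Besides the test updates, this commit gives `temperature`, `top_p` and
`repetition_penalty` concrete defaults (1.0 instead of None) and has the vLLM
deployment exclude `kwargs` from `model_dump()` before splatting it back in as
extra sampling arguments. A rough usage sketch, not part of the patch itself
(the extra keys passed through `kwargs` are illustrative, not a fixed API):

    from aana.core.models.sampling import SamplingParams

    # Named, backend-agnostic fields plus backend-specific extras via `kwargs`.
    params = SamplingParams(
        temperature=0.7,
        max_tokens=128,
        repetition_penalty=1.2,
        kwargs={"presence_penalty": 0.5},  # forwarded to the backend as-is
    )

    # With the new defaults, unset fields come back as concrete values, not None.
    assert params.top_p == 1.0

    # Simplified version of how the deployments assemble a generate call:
    # named fields explicitly, everything in `kwargs` splatted on top.
    generation_kwargs = dict(
        temperature=params.temperature,
        repetition_penalty=params.repetition_penalty,
        **params.kwargs,
    )
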
---
 aana/core/models/sampling.py | 12 ++++++------
 aana/deployments/hf_text_generation_deployment.py | 1 +
 aana/deployments/vllm_deployment.py | 2 +-
 aana/tests/deployments/test_idefics2_deployment.py | 4 ++++
 .../deployments/test_text_generation_deployment.py | 9 ++++++++-
 aana/tests/units/test_sampling_params.py | 12 ++++++------
 6 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/aana/core/models/sampling.py b/aana/core/models/sampling.py
index 843f7dc6..7af91429 100644
--- a/aana/core/models/sampling.py
+++ b/aana/core/models/sampling.py
@@ -22,8 +22,8 @@ class SamplingParams(BaseModel):
         kwargs (dict): Extra keyword arguments to pass as sampling parameters.
     """
 
-    temperature: float | None = Field(
-        default=None,
+    temperature: float = Field(
+        default=1.0,
         ge=0.0,
         description=(
             "Float that controls the randomness of the sampling. "
@@ -32,8 +32,8 @@ class SamplingParams(BaseModel):
             "Zero means greedy sampling."
         ),
     )
-    top_p: float | None = Field(
-        default=None,
+    top_p: float = Field(
+        default=1.0,
         gt=0.0,
         le=1.0,
         description=(
@@ -51,8 +51,8 @@ class SamplingParams(BaseModel):
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
-    repetition_penalty: float | None = Field(
-        default=None,
+    repetition_penalty: float = Field(
+        default=1.0,
         description=(
             "Float that penalizes new tokens based on whether they appear in the "
             "prompt and the generated text so far. Values > 1 encourage the model "
diff --git a/aana/deployments/hf_text_generation_deployment.py b/aana/deployments/hf_text_generation_deployment.py
index 021f97a2..bc147786 100644
--- a/aana/deployments/hf_text_generation_deployment.py
+++ b/aana/deployments/hf_text_generation_deployment.py
@@ -132,6 +132,7 @@ async def generate_stream(
             temperature=sampling_params.temperature,
             num_return_sequences=1,
             eos_token_id=self.tokenizer.eos_token_id,
+            repetition_penalty=sampling_params.repetition_penalty,
             **sampling_params.kwargs,
         )
         if sampling_params.temperature == 0:
diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py
index 520cdb0c..59223779 100644
--- a/aana/deployments/vllm_deployment.py
+++ b/aana/deployments/vllm_deployment.py
@@ -148,7 +148,7 @@ async def generate_stream(
         try:
             # convert SamplingParams to VLLMSamplingParams
             sampling_params_vllm = VLLMSamplingParams(
-                **sampling_params.model_dump(exclude_unset=True),
+                **sampling_params.model_dump(exclude_unset=True, exclude=["kwargs"]),
                 **sampling_params.kwargs,
             )
             # start the request
diff --git a/aana/tests/deployments/test_idefics2_deployment.py b/aana/tests/deployments/test_idefics2_deployment.py
index cbdecb51..80f91fe6 100644
--- a/aana/tests/deployments/test_idefics2_deployment.py
+++ b/aana/tests/deployments/test_idefics2_deployment.py
@@ -6,6 +6,7 @@
 from aana.core.models.chat import ChatMessage
 from aana.core.models.image import Image
 from aana.core.models.image_chat import ImageChatDialog
+from aana.core.models.sampling import SamplingParams
 from aana.core.models.types import Dtype
 from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
 from aana.deployments.idefics_2_deployment import Idefics2Config, Idefics2Deployment
@@ -21,6 +22,9 @@
         user_config=Idefics2Config(
             model="HuggingFaceM4/idefics2-8b",
             dtype=Dtype.FLOAT16,
+            default_sampling_params=SamplingParams(
+                temperature=0.0, kwargs={"diversity_penalty": 0.0}
+            ),
         ).model_dump(mode="json"),
     ),
 )
diff --git a/aana/tests/deployments/test_text_generation_deployment.py b/aana/tests/deployments/test_text_generation_deployment.py
index bb9924bc..48a447b1 100644
--- a/aana/tests/deployments/test_text_generation_deployment.py
+++ b/aana/tests/deployments/test_text_generation_deployment.py
@@ -28,6 +28,9 @@
             model_kwargs={
                 "trust_remote_code": True,
             },
+            default_sampling_params=SamplingParams(
+                temperature=0.0, kwargs={"diversity_penalty": 0.0}
+            ),
         ).model_dump(mode="json"),
     ),
 ),
@@ -46,7 +49,11 @@
             gpu_memory_reserved=10000,
             enforce_eager=True,
             default_sampling_params=SamplingParams(
-                temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
+                temperature=0.0,
+                top_p=1.0,
+                top_k=-1,
+                max_tokens=1024,
+                kwargs={"frequency_penalty": 0.0},
             ),
             engine_args={
                 "trust_remote_code": True,
diff --git a/aana/tests/units/test_sampling_params.py b/aana/tests/units/test_sampling_params.py
index 253309f0..b66136a8 100644
--- a/aana/tests/units/test_sampling_params.py
+++ b/aana/tests/units/test_sampling_params.py
@@ -15,13 +15,13 @@ def test_valid_sampling_params():
     assert params.max_tokens == 50
     assert params.repetition_penalty == 1.5
 
-    # Test valid params with default values (None)
+    # Test valid params with default values
     params = SamplingParams()
-    assert params.temperature is None
-    assert params.top_p is None
+    assert params.temperature == 1.0
+    assert params.top_p == 1.0
     assert params.top_k is None
     assert params.max_tokens is None
-    assert params.repetition_penalty is None
+    assert params.repetition_penalty == 1.0
 
 
 def test_invalid_temperature():
@@ -59,7 +59,7 @@ def test_kwargs():
     )
     assert params.kwargs == {"presence_penalty": 2.0, "frequency_penalty": 1.0}
     assert params.temperature == 0.5
-    assert params.top_p is None
+    assert params.top_p == 1.0
     assert params.top_k is None
     assert params.max_tokens is None
-    assert params.repetition_penalty is None
+    assert params.repetition_penalty == 1.0

From 24ba6dcbbc15349843f03a78afee35537ed8b7ba Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Tue, 10 Sep 2024 11:09:00 +0000
Subject: [PATCH 3/4] Adjusted idefics 2 test.
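
The idefics2 test previously pinned its default sampling parameters to greedy
decoding (temperature=0.0); this swaps that for temperature=1.0 with a
256-token cap. A rough sketch of configuring a deployment the same way, not
part of the patch itself (only options visible in these diffs are used; any
other `Idefics2Config` fields are out of scope here):

    from aana.core.models.sampling import SamplingParams
    from aana.core.models.types import Dtype
    from aana.deployments.idefics_2_deployment import Idefics2Config

    user_config = Idefics2Config(
        model="HuggingFaceM4/idefics2-8b",
        dtype=Dtype.FLOAT16,
        default_sampling_params=SamplingParams(
            temperature=1.0,
            max_tokens=256,
            kwargs={"diversity_penalty": 0.0},  # extra kwarg passed through to generation
        ),
    ).model_dump(mode="json")
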
---
 aana/tests/deployments/test_idefics2_deployment.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aana/tests/deployments/test_idefics2_deployment.py b/aana/tests/deployments/test_idefics2_deployment.py
index 80f91fe6..03fba4c9 100644
--- a/aana/tests/deployments/test_idefics2_deployment.py
+++ b/aana/tests/deployments/test_idefics2_deployment.py
@@ -23,7 +23,7 @@
             model="HuggingFaceM4/idefics2-8b",
             dtype=Dtype.FLOAT16,
             default_sampling_params=SamplingParams(
-                temperature=0.0, kwargs={"diversity_penalty": 0.0}
+                temperature=1.0, max_tokens=256, kwargs={"diversity_penalty": 0.0}
             ),
         ).model_dump(mode="json"),
     ),

From 421cca2127380ad9af826f37590cfa0a4b41a50b Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Tue, 10 Sep 2024 14:00:02 +0000
Subject: [PATCH 4/4] Fix for OpenAI integration

---
 aana/core/models/chat.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/aana/core/models/chat.py b/aana/core/models/chat.py
index e6ea82e0..c102f3d6 100644
--- a/aana/core/models/chat.py
+++ b/aana/core/models/chat.py
@@ -132,6 +132,7 @@ class ChatCompletionRequest(BaseModel):
         temperature (float): float that controls the randomness of the sampling
         top_p (float): float that controls the cumulative probability of the top tokens to consider
         max_tokens (int): the maximum number of tokens to generate
+        repetition_penalty (float): float that penalizes new tokens based on whether they appear in the prompt and the generated text so far
         stream (bool): if set, partial message deltas will be sent
     """
 
@@ -139,8 +140,8 @@
     messages: list[ChatMessage] = Field(
         ..., description="A list of messages comprising the conversation so far."
     )
-    temperature: float | None = Field(
-        default=None,
+    temperature: float = Field(
+        default=1.0,
         ge=0.0,
         description=(
             "Float that controls the randomness of the sampling. "
@@ -149,8 +150,8 @@
             "Zero means greedy sampling."
         ),
     )
-    top_p: float | None = Field(
-        default=None,
+    top_p: float = Field(
+        default=1.0,
         gt=0.0,
         le=1.0,
         description=(
@@ -161,6 +162,15 @@
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
+    repetition_penalty: float = Field(
+        default=1.0,
+        description=(
+            "Float that penalizes new tokens based on whether they appear in the "
+            "prompt and the generated text so far. Values > 1 encourage the model "
+            "to use new tokens, while values < 1 encourage the model to repeat tokens. "
+            "Default is 1.0 (no penalty)."
+        ),
+    )
     stream: bool | None = Field(
         default=False,