Add repetition_penalty and kwargs parameters to SamplingParams #174

Merged · 4 commits · Sep 11, 2024
18 changes: 14 additions & 4 deletions aana/core/models/chat.py
@@ -132,15 +132,16 @@ class ChatCompletionRequest(BaseModel):
         temperature (float): float that controls the randomness of the sampling
         top_p (float): float that controls the cumulative probability of the top tokens to consider
         max_tokens (int): the maximum number of tokens to generate
+        repetition_penalty (float): float that penalizes new tokens based on whether they appear in the prompt and the generated text so far
         stream (bool): if set, partial message deltas will be sent
     """

     model: str = Field(..., description="The model name (name of the LLM deployment).")
     messages: list[ChatMessage] = Field(
         ..., description="A list of messages comprising the conversation so far."
     )
-    temperature: float | None = Field(
-        default=None,
+    temperature: float = Field(
+        default=1.0,
         ge=0.0,
         description=(
             "Float that controls the randomness of the sampling. "
@@ -149,8 +150,8 @@ class ChatCompletionRequest(BaseModel):
             "Zero means greedy sampling."
         ),
     )
-    top_p: float | None = Field(
-        default=None,
+    top_p: float = Field(
+        default=1.0,
         gt=0.0,
         le=1.0,
         description=(
@@ -161,6 +162,15 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
+    repetition_penalty: float = Field(
+        default=1.0,
+        description=(
+            "Float that penalizes new tokens based on whether they appear in the "
+            "prompt and the generated text so far. Values > 1 encourage the model "
+            "to use new tokens, while values < 1 encourage the model to repeat tokens. "
+            "Default is 1.0 (no penalty)."
+        ),
+    )

     stream: bool | None = Field(
         default=False,
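Review sketch (not part of the diff): how the new field surfaces on the request side. The deployment name and message content below are made up, and `ChatMessage` is assumed to accept `role` and `content`:

```python
from aana.core.models.chat import ChatCompletionRequest, ChatMessage

# repetition_penalty is a plain field with a default of 1.0 (no penalty),
# so existing clients that omit it keep the previous behavior.
request = ChatCompletionRequest(
    model="llm_deployment",  # hypothetical deployment name
    messages=[ChatMessage(role="user", content="Tell me a story.")],
    temperature=0.7,
    repetition_penalty=1.2,  # > 1 discourages reusing already-generated tokens
)
```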
26 changes: 22 additions & 4 deletions aana/core/models/sampling.py
@@ -15,10 +15,15 @@ class SamplingParams(BaseModel):
         top_k (int): Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
         max_tokens (int): The maximum number of tokens to generate.
+        repetition_penalty (float): Float that penalizes new tokens based on whether
+            they appear in the prompt and the generated text so far. Values > 1 encourage
+            the model to use new tokens, while values < 1 encourage the model to repeat
+            tokens. Default is 1.0 (no penalty).
+        kwargs (dict): Extra keyword arguments to pass as sampling parameters.
     """

-    temperature: float | None = Field(
-        default=None,
+    temperature: float = Field(
+        default=1.0,
         ge=0.0,
         description=(
             "Float that controls the randomness of the sampling. "
@@ -27,8 +32,8 @@ class SamplingParams(BaseModel):
             "Zero means greedy sampling."
         ),
     )
-    top_p: float | None = Field(
-        default=None,
+    top_p: float = Field(
+        default=1.0,
         gt=0.0,
         le=1.0,
         description=(
@@ -46,6 +51,19 @@ class SamplingParams(BaseModel):
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
+    repetition_penalty: float = Field(
+        default=1.0,
+        description=(
+            "Float that penalizes new tokens based on whether they appear in the "
+            "prompt and the generated text so far. Values > 1 encourage the model "
+            "to use new tokens, while values < 1 encourage the model to repeat tokens. "
+            "Default is 1.0 (no penalty)."
+        ),
+    )
+    kwargs: dict = Field(
+        default_factory=dict,
+        description="Extra keyword arguments to pass as sampling parameters.",
+    )

     @field_validator("top_k")
     def check_top_k(cls, v: int):
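Review sketch (not part of the diff): the model now splits sampling options into two tiers. Values below are illustrative:

```python
from aana.core.models.sampling import SamplingParams

# Common options are typed, validated fields; anything the backend
# understands but SamplingParams does not model goes into `kwargs`
# and is splatted into the engine call by each deployment.
params = SamplingParams(
    temperature=0.5,
    max_tokens=256,
    repetition_penalty=1.3,
    kwargs={"frequency_penalty": 0.5},  # e.g. a vLLM-specific knob
)
```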
2 changes: 2 additions & 0 deletions aana/deployments/hf_text_generation_deployment.py
@@ -132,6 +132,8 @@ async def generate_stream(
             temperature=sampling_params.temperature,
             num_return_sequences=1,
             eos_token_id=self.tokenizer.eos_token_id,
+            repetition_penalty=sampling_params.repetition_penalty,
+            **sampling_params.kwargs,
         )
         if sampling_params.temperature == 0:
             generation_kwargs["do_sample"] = False
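One edge case worth flagging here: `**sampling_params.kwargs` is splatted into a call that already passes `temperature` and `repetition_penalty` explicitly, and Python rejects duplicate keyword arguments rather than letting the splat override them. A standalone illustration of that semantics:

```python
# If kwargs repeats a key that is also passed explicitly, the call fails
# at call time instead of silently overriding the explicit value.
try:
    generation_kwargs = dict(
        repetition_penalty=1.1,
        **{"repetition_penalty": 2.0},
    )
except TypeError as e:
    print(e)  # dict() got multiple values for keyword argument 'repetition_penalty'
```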
1 change: 1 addition & 0 deletions aana/deployments/idefics_2_deployment.py
@@ -145,6 +145,7 @@ async def chat_stream(
             temperature=sampling_params.temperature,
             num_return_sequences=1,
             eos_token_id=self.processor.tokenizer.eos_token_id,
+            **sampling_params.kwargs,
         )
         if sampling_params.temperature == 0:
             generation_kwargs["do_sample"] = False
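Note that this file only gains the `**sampling_params.kwargs` splat; the new `repetition_penalty` field is not forwarded into `generation_kwargs` here. If that is intentional, a caller wanting a repetition penalty on Idefics2 would route it through `kwargs` instead, e.g. (sketch, assuming the underlying `generate()` accepts the key as Transformers models usually do):

```python
from aana.core.models.sampling import SamplingParams

params = SamplingParams(
    temperature=1.0,
    kwargs={"repetition_penalty": 1.2},  # forwarded verbatim into generate()
)
```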
3 changes: 2 additions & 1 deletion aana/deployments/vllm_deployment.py
@@ -148,7 +148,8 @@ async def generate_stream(
         try:
             # convert SamplingParams to VLLMSamplingParams
             sampling_params_vllm = VLLMSamplingParams(
-                **sampling_params.model_dump(exclude_unset=True)
+                **sampling_params.model_dump(exclude_unset=True, exclude=["kwargs"]),
+                **sampling_params.kwargs,
             )
             # start the request
             request_id = random_uuid()
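A quick sketch of what the two splats produce, assuming standard Pydantic v2 `model_dump` semantics (values illustrative):

```python
from aana.core.models.sampling import SamplingParams

params = SamplingParams(temperature=0.2, kwargs={"frequency_penalty": 0.5})

# exclude_unset drops fields the caller never set; exclude=["kwargs"] strips
# the extras dict so it is not forwarded as a literal "kwargs" argument.
dumped = params.model_dump(exclude_unset=True, exclude=["kwargs"])
assert dumped == {"temperature": 0.2}

# The constructor call is therefore equivalent to:
# VLLMSamplingParams(temperature=0.2, frequency_penalty=0.5)
```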
4 changes: 4 additions & 0 deletions aana/tests/deployments/test_idefics2_deployment.py
@@ -6,6 +6,7 @@
 from aana.core.models.chat import ChatMessage
 from aana.core.models.image import Image
 from aana.core.models.image_chat import ImageChatDialog
+from aana.core.models.sampling import SamplingParams
 from aana.core.models.types import Dtype
 from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
 from aana.deployments.idefics_2_deployment import Idefics2Config, Idefics2Deployment
@@ -21,6 +22,9 @@
         user_config=Idefics2Config(
             model="HuggingFaceM4/idefics2-8b",
             dtype=Dtype.FLOAT16,
+            default_sampling_params=SamplingParams(
+                temperature=1.0, max_tokens=256, kwargs={"diversity_penalty": 0.0}
+            ),
         ).model_dump(mode="json"),
     ),
 )
9 changes: 8 additions & 1 deletion aana/tests/deployments/test_text_generation_deployment.py
@@ -28,6 +28,9 @@
             model_kwargs={
                 "trust_remote_code": True,
             },
+            default_sampling_params=SamplingParams(
+                temperature=0.0, kwargs={"diversity_penalty": 0.0}
+            ),
         ).model_dump(mode="json"),
     ),
 ),
@@ -46,7 +49,11 @@
             gpu_memory_reserved=10000,
             enforce_eager=True,
             default_sampling_params=SamplingParams(
-                temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
+                temperature=0.0,
+                top_p=1.0,
+                top_k=-1,
+                max_tokens=1024,
+                kwargs={"frequency_penalty": 0.0},
             ),
             engine_args={
                 "trust_remote_code": True,
25 changes: 21 additions & 4 deletions aana/tests/units/test_sampling_params.py
@@ -6,18 +6,22 @@

 def test_valid_sampling_params():
     """Test valid sampling parameters."""
-    params = SamplingParams(temperature=0.5, top_p=0.9, top_k=10, max_tokens=50)
+    params = SamplingParams(
+        temperature=0.5, top_p=0.9, top_k=10, max_tokens=50, repetition_penalty=1.5
+    )
     assert params.temperature == 0.5
     assert params.top_p == 0.9
     assert params.top_k == 10
     assert params.max_tokens == 50
+    assert params.repetition_penalty == 1.5

-    # Test valid params with default values (None)
+    # Test valid params with default values
     params = SamplingParams()
-    assert params.temperature is None
-    assert params.top_p is None
+    assert params.temperature == 1.0
+    assert params.top_p == 1.0
     assert params.top_k is None
     assert params.max_tokens is None
+    assert params.repetition_penalty == 1.0


 def test_invalid_temperature():
@@ -46,3 +50,16 @@ def test_invalid_max_tokens():
     """Test invalid max_tokens values."""
     with pytest.raises(ValueError):
         SamplingParams(max_tokens=0)
+
+
+def test_kwargs():
+    """Test extra keyword arguments."""
+    params = SamplingParams(
+        temperature=0.5, kwargs={"presence_penalty": 2.0, "frequency_penalty": 1.0}
+    )
+    assert params.kwargs == {"presence_penalty": 2.0, "frequency_penalty": 1.0}
+    assert params.temperature == 0.5
+    assert params.top_p == 1.0
+    assert params.top_k is None
+    assert params.max_tokens is None
+    assert params.repetition_penalty == 1.0