From 421cca2127380ad9af826f37590cfa0a4b41a50b Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Tue, 10 Sep 2024 14:00:02 +0000
Subject: [PATCH] Fix for OpenAI integration

---
 aana/core/models/chat.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/aana/core/models/chat.py b/aana/core/models/chat.py
index e6ea82e0..c102f3d6 100644
--- a/aana/core/models/chat.py
+++ b/aana/core/models/chat.py
@@ -132,6 +132,7 @@ class ChatCompletionRequest(BaseModel):
         temperature (float): float that controls the randomness of the sampling
         top_p (float): float that controls the cumulative probability of the top tokens to consider
         max_tokens (int): the maximum number of tokens to generate
+        repetition_penalty (float): float that penalizes new tokens based on whether they appear in the prompt and the generated text so far
         stream (bool): if set, partial message deltas will be sent
     """
 
@@ -139,8 +140,8 @@ class ChatCompletionRequest(BaseModel):
     messages: list[ChatMessage] = Field(
         ..., description="A list of messages comprising the conversation so far."
     )
-    temperature: float | None = Field(
-        default=None,
+    temperature: float = Field(
+        default=1.0,
         ge=0.0,
         description=(
             "Float that controls the randomness of the sampling. "
@@ -149,8 +150,8 @@ class ChatCompletionRequest(BaseModel):
             "Zero means greedy sampling."
         ),
     )
-    top_p: float | None = Field(
-        default=None,
+    top_p: float = Field(
+        default=1.0,
         gt=0.0,
         le=1.0,
         description=(
@@ -161,6 +162,15 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: int | None = Field(
         default=None, ge=1, description="The maximum number of tokens to generate."
     )
+    repetition_penalty: float = Field(
+        default=1.0,
+        description=(
+            "Float that penalizes new tokens based on whether they appear in the "
+            "prompt and the generated text so far. Values > 1 encourage the model "
+            "to use new tokens, while values < 1 encourage the model to repeat tokens. "
+            "Default is 1.0 (no penalty)."
+        ),
+    )
     stream: bool | None = Field(
         default=False,
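
For reference, a minimal client-side sketch of how the updated request model could be exercised through an OpenAI-compatible /chat/completions endpoint. The base URL, API key, and model name below are placeholders, and repetition_penalty is not part of the standard OpenAI request schema, so the sketch forwards it via the openai Python client's extra_body parameter:

    # Sketch only: the endpoint URL, API key, and model name are assumed values,
    # not taken from this patch.
    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8000",  # placeholder: address of the Aana server
        api_key="not-needed",              # placeholder: may be ignored locally
    )

    response = client.chat.completions.create(
        model="llm_deployment",  # placeholder deployment name
        messages=[{"role": "user", "content": "Tell me a short story."}],
        temperature=1.0,  # matches the new field default
        top_p=1.0,        # matches the new field default
        max_tokens=256,
        # repetition_penalty is a non-standard parameter; extra_body merges it
        # into the JSON request payload sent to the server.
        extra_body={"repetition_penalty": 1.2},
    )
    print(response.choices[0].message.content)

A value above 1.0 discourages repeated tokens, while 1.0 leaves sampling unchanged, matching the field's default in the patch.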