From 036bd79cb7c6c1433c2d4b865468ed7439e9e500 Mon Sep 17 00:00:00 2001 From: jkohls-indeed <108553270+jkohls-indeed@users.noreply.github.com> Date: Sun, 28 Apr 2024 14:30:04 -0700 Subject: [PATCH] feat(chat): add `logprob` and `topLogprobs` (#328) --- .../openai/client/TestChatCompletions.kt | 39 +++++++++++++++++++ .../com.aallam.openai.api/chat/ChatChoice.kt | 8 ++-- .../chat/ChatCompletion.kt | 1 - .../chat/ChatCompletionRequest.kt | 28 ++++++++++++- .../com.aallam.openai.api/chat/Logprobs.kt | 17 ++++++++ .../chat/LogprobsContent.kt | 33 ++++++++++++++++ .../com.aallam.openai.api/chat/TopLogprob.kt | 28 +++++++++++++ 7 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/Logprobs.kt create mode 100644 openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/LogprobsContent.kt create mode 100644 openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/TopLogprob.kt diff --git a/openai-client/src/commonTest/kotlin/com/aallam/openai/client/TestChatCompletions.kt b/openai-client/src/commonTest/kotlin/com/aallam/openai/client/TestChatCompletions.kt index f26289e6..bdb58639 100644 --- a/openai-client/src/commonTest/kotlin/com/aallam/openai/client/TestChatCompletions.kt +++ b/openai-client/src/commonTest/kotlin/com/aallam/openai/client/TestChatCompletions.kt @@ -126,4 +126,43 @@ class TestChatCompletions : TestOpenAI() { assertNotNull(answer.question) assertNotNull(answer.response) } + + @Test + fun logprobs() = test { + val request = chatCompletionRequest { + model = ModelId("gpt-3.5-turbo-0125") + messages { + message { + role = ChatRole.User + content = "What's the weather like in Boston?" + } + } + logprobs = true + } + val response = openAI.chatCompletion(request) + val logprobs = response.choices.first().logprobs + assertNotNull(logprobs) + assertEquals(response.usage!!.completionTokens, logprobs.content!!.size) + } + + @Test + fun top_logprobs() = test { + val expectedTopLogProbs = 5 + val request = chatCompletionRequest { + model = ModelId("gpt-3.5-turbo-0125") + messages { + message { + role = ChatRole.User + content = "What's the weather like in Boston?" + } + } + logprobs = true + topLogprobs = expectedTopLogProbs + } + val response = openAI.chatCompletion(request) + val logprobs = response.choices.first().logprobs + assertNotNull(logprobs) + assertEquals(response.usage!!.completionTokens, logprobs.content!!.size) + assertEquals(logprobs.content!![0].topLogprobs?.size, expectedTopLogProbs) + } } diff --git a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatChoice.kt b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatChoice.kt index 8c78ed23..27cb8ad5 100644 --- a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatChoice.kt +++ b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatChoice.kt @@ -1,6 +1,5 @@ -package com.aallam.openai.api.chat; +package com.aallam.openai.api.chat -import com.aallam.openai.api.BetaOpenAI import com.aallam.openai.api.core.FinishReason import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -20,9 +19,12 @@ public data class ChatChoice( * The generated chat message. */ @SerialName("message") public val message: ChatMessage, - /** * The reason why OpenAI stopped generating. */ @SerialName("finish_reason") public val finishReason: FinishReason? = null, + /** + * Log probability information for the choice. + */ + @SerialName("logprobs") public val logprobs: Logprobs? = null, ) diff --git a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletion.kt b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletion.kt index 9f2a87f9..eb0c94b8 100644 --- a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletion.kt +++ b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletion.kt @@ -1,6 +1,5 @@ package com.aallam.openai.api.chat -import com.aallam.openai.api.BetaOpenAI import com.aallam.openai.api.core.Usage import com.aallam.openai.api.model.ModelId import kotlinx.serialization.SerialName diff --git a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletionRequest.kt b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletionRequest.kt index ab67fa93..8d6407dd 100644 --- a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletionRequest.kt +++ b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/ChatCompletionRequest.kt @@ -146,6 +146,18 @@ public class ChatCompletionRequest( */ @property:BetaOpenAI @SerialName("seed") public val seed: Int? = null, + + /** + * Whether to return log probabilities of the output tokens or not. If true, + * returns the log probabilities of each output token returned in the content of message. + */ + @SerialName("logprobs") public val logprobs: Boolean? = null, + + /** + * An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, + * each with an associated log probability. logprobs must be set to true if this parameter is used. + */ + @SerialName("top_logprobs") public val topLogprobs: Int? = null, ) /** @@ -282,6 +294,18 @@ public class ChatCompletionRequestBuilder { */ public var toolChoice: ToolChoice? = null + /** + * Whether to return log probabilities of the output tokens or not. If true, + * returns the log probabilities of each output token returned in the content of message. + */ + public var logprobs: Boolean? = null + + /** + * An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, + * each with an associated log probability. logprobs must be set to true if this parameter is used. + */ + public var topLogprobs: Int? = null + /** * The messages to generate chat completions for. */ @@ -323,7 +347,9 @@ public class ChatCompletionRequestBuilder { functionCall = functionCall, responseFormat = responseFormat, toolChoice = toolChoice, - tools = tools + tools = tools, + logprobs = logprobs, + topLogprobs = topLogprobs ) } diff --git a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/Logprobs.kt b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/Logprobs.kt new file mode 100644 index 00000000..8a324443 --- /dev/null +++ b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/Logprobs.kt @@ -0,0 +1,17 @@ +package com.aallam.openai.api.chat + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * An object containing log probability information for the choice. + * + * [documentation](https://platform.openai.com/docs/api-reference/chat/object) + */ +@Serializable +public data class Logprobs( + /** + * A list of message content tokens with log probability information. + */ + @SerialName("content") public val content: List? = null, +) diff --git a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/LogprobsContent.kt b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/LogprobsContent.kt new file mode 100644 index 00000000..17e91206 --- /dev/null +++ b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/LogprobsContent.kt @@ -0,0 +1,33 @@ +package com.aallam.openai.api.chat + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * An object containing logprobs for a single token + * + * [documentation](https://platform.openai.com/docs/api-reference/chat/object) + */ +@Serializable +public data class LogprobsContent( + /** + * The token. + */ + @SerialName("token") public val token: String, + /** + * The log probability of this token, if it is within the top 20 most likely tokens. + * Otherwise, the value -9999.0 is used to signify that the token is very unlikely. + */ + @SerialName("logprob") public val logprob: Double, + /** + * A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where + * characters are represented by multiple tokens and their byte representations must be combined to generate + * the correct text representation. Can be `null` if there is no bytes representation for the token. + */ + @SerialName("bytes") public val bytes: List? = null, + /** + * List of the most likely tokens and their log probability, at this token position. + * In rare cases, there may be fewer than the number of requested top_logprobs returned. + */ + @SerialName("top_logprobs") public val topLogprobs: List, +) diff --git a/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/TopLogprob.kt b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/TopLogprob.kt new file mode 100644 index 00000000..c2afc3f1 --- /dev/null +++ b/openai-core/src/commonMain/kotlin/com.aallam.openai.api/chat/TopLogprob.kt @@ -0,0 +1,28 @@ +package com.aallam.openai.api.chat + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * An object containing a token and their log probability. + * + * [documentation](https://platform.openai.com/docs/api-reference/chat/object) + */ +@Serializable +public data class TopLogprob( + /** + * The token + */ + @SerialName("token") public val token: String, + /** + * The log probability of this token, if it is within the top 20 most likely tokens. + * Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. + */ + @SerialName("logprob") public val logprob: Double, + /** + * A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where + * characters are represented by multiple tokens and their byte representations must be combined to generate + * the correct text representation. Can be `null` if there is no bytes representation for the token. + */ + @SerialName("bytes") public val bytes: List? = null, +)