Skip to content

Commit

Permalink
feat(chat): add logprob and topLogprobs (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkohls-indeed authored Apr 28, 2024
1 parent 2d870f9 commit 036bd79
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,43 @@ class TestChatCompletions : TestOpenAI() {
assertNotNull(answer.question)
assertNotNull(answer.response)
}

@Test
fun logprobs() = test {
val request = chatCompletionRequest {
model = ModelId("gpt-3.5-turbo-0125")
messages {
message {
role = ChatRole.User
content = "What's the weather like in Boston?"
}
}
logprobs = true
}
val response = openAI.chatCompletion(request)
val logprobs = response.choices.first().logprobs
assertNotNull(logprobs)
assertEquals(response.usage!!.completionTokens, logprobs.content!!.size)
}

@Test
fun top_logprobs() = test {
val expectedTopLogProbs = 5
val request = chatCompletionRequest {
model = ModelId("gpt-3.5-turbo-0125")
messages {
message {
role = ChatRole.User
content = "What's the weather like in Boston?"
}
}
logprobs = true
topLogprobs = expectedTopLogProbs
}
val response = openAI.chatCompletion(request)
val logprobs = response.choices.first().logprobs
assertNotNull(logprobs)
assertEquals(response.usage!!.completionTokens, logprobs.content!!.size)
assertEquals(logprobs.content!![0].topLogprobs?.size, expectedTopLogProbs)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.aallam.openai.api.chat;
package com.aallam.openai.api.chat

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.core.FinishReason
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand All @@ -20,9 +19,12 @@ public data class ChatChoice(
* The generated chat message.
*/
@SerialName("message") public val message: ChatMessage,

/**
* The reason why OpenAI stopped generating.
*/
@SerialName("finish_reason") public val finishReason: FinishReason? = null,
/**
* Log probability information for the choice.
*/
@SerialName("logprobs") public val logprobs: Logprobs? = null,
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.aallam.openai.api.chat

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.core.Usage
import com.aallam.openai.api.model.ModelId
import kotlinx.serialization.SerialName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ public class ChatCompletionRequest(
*/
@property:BetaOpenAI
@SerialName("seed") public val seed: Int? = null,

/**
* Whether to return log probabilities of the output tokens or not. If true,
* returns the log probabilities of each output token returned in the content of message.
*/
@SerialName("logprobs") public val logprobs: Boolean? = null,

/**
* An integer between 0 and 20 specifying the number of most likely tokens to return at each token position,
* each with an associated log probability. logprobs must be set to true if this parameter is used.
*/
@SerialName("top_logprobs") public val topLogprobs: Int? = null,
)

/**
Expand Down Expand Up @@ -282,6 +294,18 @@ public class ChatCompletionRequestBuilder {
*/
public var toolChoice: ToolChoice? = null

/**
* Whether to return log probabilities of the output tokens or not. If true,
* returns the log probabilities of each output token returned in the content of message.
*/
public var logprobs: Boolean? = null

/**
* An integer between 0 and 20 specifying the number of most likely tokens to return at each token position,
* each with an associated log probability. logprobs must be set to true if this parameter is used.
*/
public var topLogprobs: Int? = null

/**
* The messages to generate chat completions for.
*/
Expand Down Expand Up @@ -323,7 +347,9 @@ public class ChatCompletionRequestBuilder {
functionCall = functionCall,
responseFormat = responseFormat,
toolChoice = toolChoice,
tools = tools
tools = tools,
logprobs = logprobs,
topLogprobs = topLogprobs
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.aallam.openai.api.chat

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
* An object containing log probability information for the choice.
*
* [documentation](https://platform.openai.com/docs/api-reference/chat/object)
*/
@Serializable
public data class Logprobs(
/**
* A list of message content tokens with log probability information.
*/
@SerialName("content") public val content: List<LogprobsContent>? = null,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.aallam.openai.api.chat

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
* An object containing logprobs for a single token
*
* [documentation](https://platform.openai.com/docs/api-reference/chat/object)
*/
@Serializable
public data class LogprobsContent(
/**
* The token.
*/
@SerialName("token") public val token: String,
/**
* The log probability of this token, if it is within the top 20 most likely tokens.
* Otherwise, the value -9999.0 is used to signify that the token is very unlikely.
*/
@SerialName("logprob") public val logprob: Double,
/**
* A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where
* characters are represented by multiple tokens and their byte representations must be combined to generate
* the correct text representation. Can be `null` if there is no bytes representation for the token.
*/
@SerialName("bytes") public val bytes: List<Int>? = null,
/**
* List of the most likely tokens and their log probability, at this token position.
* In rare cases, there may be fewer than the number of requested top_logprobs returned.
*/
@SerialName("top_logprobs") public val topLogprobs: List<TopLogprob>,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.aallam.openai.api.chat

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
* An object containing a token and their log probability.
*
* [documentation](https://platform.openai.com/docs/api-reference/chat/object)
*/
@Serializable
public data class TopLogprob(
/**
* The token
*/
@SerialName("token") public val token: String,
/**
* The log probability of this token, if it is within the top 20 most likely tokens.
* Otherwise, the value `-9999.0` is used to signify that the token is very unlikely.
*/
@SerialName("logprob") public val logprob: Double,
/**
* A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where
* characters are represented by multiple tokens and their byte representations must be combined to generate
* the correct text representation. Can be `null` if there is no bytes representation for the token.
*/
@SerialName("bytes") public val bytes: List<Int>? = null,
)

0 comments on commit 036bd79

Please sign in to comment.