Skip to content

Commit

Permalink
feat(chat): add tool calls, response format and system fingerprint (#253
Browse files Browse the repository at this point in the history
)
  • Loading branch information
aallam authored Nov 8, 2023
1 parent 90b07db commit 39762ac
Show file tree
Hide file tree
Showing 20 changed files with 639 additions and 98 deletions.
3 changes: 3 additions & 0 deletions guides/ChatToolCalls.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Chat (with Tool Calls)

_WIP_
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.aallam.openai.client

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.chat.ChatCompletion
import com.aallam.openai.api.chat.ChatCompletionChunk
import com.aallam.openai.api.chat.ChatCompletionRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ import com.aallam.openai.client.extension.internal.ChatMessageAssembler
*/
@ExperimentalOpenAI
public fun List<ChatChunk>.mergeToChatMessage(): ChatMessage {
require(isNotEmpty()) { "ChatChunks List must not be empty" }
return fold(ChatMessageAssembler()) { assembler, chatChunk -> assembler.merge(chatChunk) }.build()
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package com.aallam.openai.client.extension.internal

import com.aallam.openai.api.chat.*
import com.aallam.openai.api.chat.internal.ToolType

/**
* A class to help assemble chat messages from chat chunks.
*/
internal class ChatMessageAssembler {
internal class ChatMessageAssembler() {
private val chatFuncName = StringBuilder()
private val chatFuncArgs = StringBuilder()
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null
private var toolCallsAssemblers = mutableListOf<ToolCallAssembler>()

/**
* Merges a chat chunk into the chat message being assembled.
Expand All @@ -22,6 +24,13 @@ internal class ChatMessageAssembler {
call.nameOrNull?.let { chatFuncName.append(it) }
call.argumentsOrNull?.let { chatFuncArgs.append(it) }
}
toolCalls?.first()?.let { toolCall -> // TBC: tool calls come one by one
if (toolCall.idOrNull != null) {
val toolCallAssembler = ToolCallAssembler()
toolCallsAssemblers.add(toolCallAssembler)
}
toolCallsAssemblers.last().merge(toolCall)
}
}
return this
}
Expand All @@ -36,5 +45,36 @@ internal class ChatMessageAssembler {
this.functionCall = FunctionCall(chatFuncName.toString(), chatFuncArgs.toString())
this.name = chatFuncName.toString()
}
if (toolCallsAssemblers.isNotEmpty()) {
this.toolCalls = toolCallsAssemblers.map { it.build() }.toList()
}
}
}

internal class ToolCallAssembler {
private var toolId: ToolId? = null
private var toolType: ToolType? = null
private var funcName: String? = null
private val funcArgs = StringBuilder()

fun merge(toolCall: ToolCall): ToolCallAssembler {
toolCall.idOrNull?.let { toolId = it }
toolCall.typeOrNull?.let { toolType = it }
toolCall.functionOrNull?.let { call ->
call.nameOrNull?.let { funcName = it }
call.argumentsOrNull?.let { funcArgs.append(it) }
}
return this
}

/**
* Builds and returns the assembled chat message.
*/
fun build(): ToolCall = toolCall {
this.id = toolId
this.type = toolType
if (funcName?.isNotEmpty() == true || funcArgs.isNotEmpty()) {
this.function = FunctionCall(funcName, funcArgs.toString())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ class TestChatChunk {
role = ChatRole.Assistant,
content = "The World Series in 2020 is being held in Texas.",
name = null,
functionCall = null
)
assertEquals(chatMessage, message)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package com.aallam.openai.client

import com.aallam.openai.api.chat.*
import com.aallam.openai.api.chat.internal.ToolType
import com.aallam.openai.api.model.ModelId
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonPrimitive
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
import kotlin.test.assertTrue
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlin.test.*

class TestChatCompletions : TestOpenAI() {

Expand Down Expand Up @@ -61,40 +60,73 @@ class TestChatCompletions : TestOpenAI() {
val request = chatCompletionRequest {
model = modelId
messages = chatMessages
functions {
function {
name = "currentWeather"
description = "Get the current weather in a given location"
parameters = Parameters.fromJsonString(
"""
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
]
}
},
"required": [
"location"
tools {
function(
name = "currentWeather",
parameters =
"""
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
]
}
"""
)
}
},
"required": [
"location"
]
}
"""
)
}
functionCall = FunctionMode.Named("currentWeather")

toolChoice = ToolChoice.function("currentWeather")
}

val response = openAI.chatCompletion(request)
val message = response.choices.first().message
assertEquals("currentWeather", message.functionCall?.name)
val toolCall = message.toolCalls?.first()
assertNotNull(toolCall)
assertEquals(ToolType.Function, toolCall.type)
assertEquals("currentWeather", toolCall.function?.name)
assertEquals(buildJsonObject { put("location", "Boston, MA") }, toolCall.function?.argumentsAsJson())
}

@Test
fun json() = test {
val request = chatCompletionRequest {
model = ModelId("gpt-3.5-turbo-1106")
responseFormat = ChatResponseFormat.JsonObject
messages {
message {
role = ChatRole.System
content = "You are a helpful assistant.!"
}
message {
role = ChatRole.System
content = """All your answers should be a valid JSON, and the format: {"question": <question>, "response": <response>}"""
}
message {
role = ChatRole.User
content = "Who won the world cup in 1998?"
}
}
}
val response = openAI.chatCompletion(request)
val content = response.choices.first().message.content.orEmpty()

@Serializable
data class Answer(val question: String? = null, val response: String? = null)
val answer = Json.decodeFromString<Answer>(content)
assertNotNull(answer.question)
assertNotNull(answer.response)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,12 @@ public data class ChatCompletion(
* Text completion usage data.
*/
@SerialName("usage") public val usage: Usage? = null,

/**
* This fingerprint represents the backend configuration that the model runs with.
*
* Can be used in conjunction with the seed request parameter to understand when backend changes have been made that
* might impact determinism.
*/
@SerialName("system_fingerprint") public val systemFingerprint: String? = null,
)
Loading

0 comments on commit 39762ac

Please sign in to comment.