Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(chat): add tool call index #256

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import com.aallam.openai.api.chat.internal.ToolType
/**
* A class to help assemble chat messages from chat chunks.
*/
internal class ChatMessageAssembler() {
internal class ChatMessageAssembler {
private val chatFuncName = StringBuilder()
private val chatFuncArgs = StringBuilder()
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null
private var toolCallsAssemblers = mutableListOf<ToolCallAssembler>()
private val toolCallsAssemblers = mutableMapOf<Int, ToolCallAssembler>()

/**
* Merges a chat chunk into the chat message being assembled.
Expand All @@ -24,12 +24,9 @@ internal class ChatMessageAssembler() {
call.nameOrNull?.let { chatFuncName.append(it) }
call.argumentsOrNull?.let { chatFuncArgs.append(it) }
}
toolCalls?.first()?.let { toolCall -> // TBC: tool calls come one by one
if (toolCall.idOrNull != null) {
val toolCallAssembler = ToolCallAssembler()
toolCallsAssemblers.add(toolCallAssembler)
}
toolCallsAssemblers.last().merge(toolCall)
toolCalls?.onEach { toolCall ->
val assembler = toolCallsAssemblers.getOrPut(toolCall.index) { ToolCallAssembler() }
assembler.merge(toolCall)
}
}
return this
Expand All @@ -46,18 +43,20 @@ internal class ChatMessageAssembler() {
this.name = chatFuncName.toString()
}
if (toolCallsAssemblers.isNotEmpty()) {
this.toolCalls = toolCallsAssemblers.map { it.build() }.toList()
this.toolCalls = toolCallsAssemblers.map { (_, value) -> value.build() }.toList()
}
}
}

internal class ToolCallAssembler {
private var toolIndex: Int? = null
private var toolId: ToolId? = null
private var toolType: ToolType? = null
private var funcName: String? = null
private val funcArgs = StringBuilder()

fun merge(toolCall: ToolCall): ToolCallAssembler {
toolCall.indexOrNull?.let { toolIndex = it }
toolCall.idOrNull?.let { toolId = it }
toolCall.typeOrNull?.let { toolType = it }
toolCall.functionOrNull?.let { call ->
Expand All @@ -71,6 +70,7 @@ internal class ToolCallAssembler {
* Builds and returns the assembled chat message.
*/
fun build(): ToolCall = toolCall {
this.index = toolIndex
this.id = toolId
this.type = toolType
if (funcName?.isNotEmpty() == true || funcArgs.isNotEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import kotlinx.serialization.Serializable
/**
* The tool call generated by the model, such as function call.
*
* In case of streaming variant of the chat API, the object parameters can be `null`.
* In case of streaming variant of the chat API, [index] is required, other object parameters can be `null`.
*/
@Serializable
public data class ToolCall(
/** Tool call index. Required in the case of chat stream variant **/
@SerialName("index") val indexOrNull: Int? = null,
/** The ID of the tool call. **/
@SerialName("id") val idOrNull: ToolId? = null,
/** The type of the tool. **/
Expand All @@ -20,6 +22,9 @@ public data class ToolCall(
@SerialName("function") val functionOrNull: FunctionCall? = null,
) {

val index: Int
get() = requireNotNull(indexOrNull)

/** The ID of the tool call. **/
val id: ToolId
get() = requireNotNull(idOrNull)
Expand All @@ -40,6 +45,10 @@ public fun toolCall(block: ToolCallBuilder.() -> Unit): ToolCall = ToolCallBuild

@OpenAIDsl
public class ToolCallBuilder {

/** Tool call index. Required in the case of chat stream variant **/
public var index: Int? = null

/** The ID of the tool call. **/
public var id: ToolId? = null

Expand All @@ -53,6 +62,7 @@ public class ToolCallBuilder {
* Create [ToolCall] instance.
*/
public fun build(): ToolCall = ToolCall(
indexOrNull = index,
idOrNull = id,
typeOrNull = type,
functionOrNull = function,
Expand Down