Skip to content

Commit

Permalink
feat(chat): add function calling (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenGuckSnk authored Jun 16, 2023
1 parent a1a0c4c commit 409d5bb
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 3 deletions.
1 change: 1 addition & 0 deletions openai-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ kotlin {
val commonMain by getting {
dependencies {
api(libs.okio)
api(libs.serialization.json)
implementation(libs.serialization.core)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.aallam.openai.api.chat

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
public data class ChatCompletionFunction(
/**
* The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum
* length of 64.
*/
@SerialName("name") val name: String,
/**
* The description of what the function does.
*/
@SerialName("description") val description: String? = null,
/**
* The parameters the functions accepts, described as a JSON Schema object. See the guide for examples and the
* JSON Schema reference for documentation about the format.
*/
@SerialName("parameters") val parameters: JsonData? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,21 @@ public class ChatCompletionRequest(
* A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
*/
@SerialName("user") public val user: String? = null,
)

/**
* A list of functions the model may generate JSON inputs for.
*/
@SerialName("functions") public val functions: List<ChatCompletionFunction>? = null,

/**
* Controls how the model responds to function calls. "none" means the model does not call a function, and responds
* to the end-user. "auto" means the model can pick between an end-user or calling a function. Specifying a
* particular function via {"name":\ "my_function"} forces the model to call that function. "none" is the default
* when no functions are present. "auto" is the default if functions are present.
*/
@SerialName("function_call") public val functionCall: FunctionCall? = null,

)

/**
* The messages to generate chat completions for.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ public data class ChatMessage(
/**
* The contents of the message.
*/
@SerialName("content") public val content: String,
@SerialName("content") public val content: String? = null,

/**
* The name of the user in a multi-user chat.
*/
@SerialName("name") public val name: String? = null
@SerialName("name") public val name: String? = null,

/**
* The name and arguments of a function that should be called, as generated by the model.
*/
@SerialName("function_call") public val functionCall: JsonData? = null
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ public value class ChatRole(public val role: String) {
public val System: ChatRole = ChatRole("system")
public val User: ChatRole = ChatRole("user")
public val Assistant: ChatRole = ChatRole("assistant")
public val Function: ChatRole = ChatRole("function")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.aallam.openai.api.chat

import com.aallam.openai.api.BetaOpenAI
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonEncoder
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlin.jvm.JvmInline

@BetaOpenAI
@Serializable(with = FunctionCallSerializer::class)
public sealed interface FunctionCall {
public val name: String

public companion object {
public val Auto: FunctionCall = FunctionCallString("auto")
public val None: FunctionCall = FunctionCallString("none")
public fun forceCall(name: String): FunctionCall = FunctionCallObject(name)
}
}
@OptIn(BetaOpenAI::class)
internal object FunctionCallSerializer: KSerializer<FunctionCall>{
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("FunctionCall") {}
private val objectSerializer = FunctionCallObject.serializer()
override fun deserialize(decoder: Decoder): FunctionCall {
if(decoder is JsonDecoder){
return when(val json = decoder.decodeJsonElement()){
is JsonPrimitive -> FunctionCallString(json.content)
is JsonObject -> objectSerializer.deserialize(decoder)
else -> throw UnsupportedOperationException("Cannot deserialize Parameters")
}
}
throw UnsupportedOperationException("Cannot deserialize Parameters")
}

override fun serialize(encoder: Encoder, value: FunctionCall) {
if(encoder is JsonEncoder){
when(value){
is FunctionCallString -> encoder.encodeString(value.name)
is FunctionCallObject -> objectSerializer.serialize(encoder, value)
}
return
}
throw UnsupportedOperationException("Cannot deserialize Parameters")
}
}

@OptIn(BetaOpenAI::class)
@JvmInline
@Serializable
internal value class FunctionCallString(override val name: String): FunctionCall

@OptIn(BetaOpenAI::class)
@Serializable
internal data class FunctionCallObject(override val name: String): FunctionCall

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.aallam.openai.api.chat

import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonEncoder
import kotlinx.serialization.json.JsonObjectBuilder
import kotlinx.serialization.json.buildJsonObject
import kotlin.jvm.JvmInline

@JvmInline
@Serializable(with = JsonData.JsonDataSerializer::class)
public value class JsonData(public val json: JsonElement){
public object JsonDataSerializer: KSerializer<JsonData>{
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("JsonData") {}

override fun deserialize(decoder: Decoder): JsonData {
if(decoder is JsonDecoder){
return JsonData(decoder.decodeJsonElement())
}
throw UnsupportedOperationException("Cannot deserialize Parameters")
}

override fun serialize(encoder: Encoder, value: JsonData) {
if(encoder is JsonEncoder){
encoder.encodeJsonElement(value.json)
return
}
}
}
public companion object{
public fun fromString(json: String): JsonData = fromJsonElement(Json.parseToJsonElement(json))

public fun fromJsonElement(json: JsonElement): JsonData = JsonData(json)

public fun builder(block: JsonObjectBuilder.() -> Unit): JsonData{
val json = buildJsonObject (
block
)
return fromJsonElement(json)
}

}
}
69 changes: 69 additions & 0 deletions sample/jvm/src/main/kotlin/com/aallam/openai/sample/jvm/App.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import com.aallam.openai.api.audio.TranslationRequest
import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.chat.FunctionCall
import com.aallam.openai.api.chat.ChatCompletionFunction
import com.aallam.openai.api.chat.JsonData
import com.aallam.openai.api.completion.CompletionRequest
import com.aallam.openai.api.file.FileSource
import com.aallam.openai.api.image.ImageCreation
Expand All @@ -22,6 +25,7 @@ import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.runBlocking
import okio.FileSystem
import okio.Path.Companion.toPath
import kotlinx.serialization.json.put

@OptIn(BetaOpenAI::class)
fun main() = runBlocking {
Expand Down Expand Up @@ -101,6 +105,71 @@ fun main() = runBlocking {
)
openAI.chatCompletion(chatCompletionRequest).choices.forEach(::println)

println("> Create Chat Completion function call...")
val chatCompletionCreateFunctionCall = ChatCompletionRequest(
model = ModelId("gpt-3.5-turbo-0613"),
messages = listOf(
ChatMessage(
role = ChatRole.System,
content = "You are a helpful assistant that translates English to French."
),
ChatMessage(
role = ChatRole.User,
content = "Translate the following English text to French: “OpenAI is awesome!”"
)
),
functionCall = FunctionCall.forceCall("translate"),
functions = listOf(
ChatCompletionFunction(
name = "translate",
description = "Translate English to French",
parameters = JsonData.fromString(
"""
{
"type": "object",
"properties": {
"text": {
"type": "string"
}
}
}
"""
),
)
)
)
openAI.chatCompletion(chatCompletionCreateFunctionCall).choices.forEach(::println)
println("> Process Chat Completion function call...")
val chatFunctionReturn = ChatCompletionRequest(
model = ModelId("gpt-3.5-turbo-0613"),
messages = listOf(
ChatMessage(
role = ChatRole.System,
content = "You are a helpful assistant that uses a function to translates English to French.\n" +
"Use only the result of the function call as the response."
),
ChatMessage(
role = ChatRole.User,
content = "Translate the following English text to French: “OpenAI is awesome!”"
),
ChatMessage(
role = ChatRole.Assistant,
content = "None",
functionCall = JsonData.builder {
put("name", "translate")
put("arguments", """{"text": "OpenAI is awesome!"}""")
}

),
ChatMessage(
role = ChatRole.Function,
content = "openai est super !",
name = "translate",
)
),
)
openAI.chatCompletion(chatFunctionReturn).choices.forEach(::println)

println("\n>️ Creating chat completions stream...")
openAI.chatCompletions(chatCompletionRequest)
.onEach { print(it.choices.first().delta?.content.orEmpty()) }
Expand Down

0 comments on commit 409d5bb

Please sign in to comment.