Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Function Calling #200

Merged
merged 5 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openai-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ kotlin {
val commonMain by getting {
dependencies {
api(libs.okio)
api(libs.serialization.json)
implementation(libs.serialization.core)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.aallam.openai.api.chat

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
public data class ChatCompletionFunction(
/**
* The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum
* length of 64.
*/
@SerialName("name") val name: String,
/**
* The description of what the function does.
*/
@SerialName("description") val description: String? = null,
/**
* The parameters the functions accepts, described as a JSON Schema object. See the guide for examples and the
* JSON Schema reference for documentation about the format.
*/
@SerialName("parameters") val parameters: JsonData? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,21 @@ public class ChatCompletionRequest(
* A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
*/
@SerialName("user") public val user: String? = null,
)

/**
* A list of functions the model may generate JSON inputs for.
*/
@SerialName("functions") public val functions: List<ChatCompletionFunction>? = null,

/**
* Controls how the model responds to function calls. "none" means the model does not call a function, and responds
* to the end-user. "auto" means the model can pick between an end-user or calling a function. Specifying a
* particular function via {"name":\ "my_function"} forces the model to call that function. "none" is the default
* when no functions are present. "auto" is the default if functions are present.
*/
@SerialName("function_call") public val functionCall: FunctionCall? = null,

)

/**
* The messages to generate chat completions for.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@ public data class ChatMessage(
/**
* The contents of the message.
*/
@SerialName("content") public val content: String,
@SerialName("content") public val content: String? = null,

/**
* The name of the user in a multi-user chat.
*/
@SerialName("name") public val name: String? = null
@SerialName("name") public val name: String? = null,

/**
* The name and arguments of a function that should be called, as generated by the model.
*/
@SerialName("function_call") public val functionCall: JsonData? = null
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ public value class ChatRole(public val role: String) {
public val System: ChatRole = ChatRole("system")
public val User: ChatRole = ChatRole("user")
public val Assistant: ChatRole = ChatRole("assistant")
public val Function: ChatRole = ChatRole("function")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.aallam.openai.api.chat

import com.aallam.openai.api.BetaOpenAI
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonEncoder
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlin.jvm.JvmInline

@BetaOpenAI
@Serializable(with = FunctionCallSerializer::class)
public sealed interface FunctionCall {
public val name: String

public companion object {
public val Auto: FunctionCall = FunctionCallString("auto")
public val None: FunctionCall = FunctionCallString("none")
public fun forceCall(name: String): FunctionCall = FunctionCallObject(name)
}
}
@OptIn(BetaOpenAI::class)
internal object FunctionCallSerializer: KSerializer<FunctionCall>{
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("FunctionCall") {}
private val objectSerializer = FunctionCallObject.serializer()
override fun deserialize(decoder: Decoder): FunctionCall {
if(decoder is JsonDecoder){
return when(val json = decoder.decodeJsonElement()){
is JsonPrimitive -> FunctionCallString(json.content)
is JsonObject -> objectSerializer.deserialize(decoder)
else -> throw UnsupportedOperationException("Cannot deserialize Parameters")
}
}
throw UnsupportedOperationException("Cannot deserialize Parameters")
}

override fun serialize(encoder: Encoder, value: FunctionCall) {
if(encoder is JsonEncoder){
when(value){
is FunctionCallString -> encoder.encodeString(value.name)
is FunctionCallObject -> objectSerializer.serialize(encoder, value)
}
return
}
throw UnsupportedOperationException("Cannot deserialize Parameters")
}
}

@OptIn(BetaOpenAI::class)
@JvmInline
@Serializable
internal value class FunctionCallString(override val name: String): FunctionCall

@OptIn(BetaOpenAI::class)
@Serializable
internal data class FunctionCallObject(override val name: String): FunctionCall

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.aallam.openai.api.chat

import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonDecoder
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonEncoder
import kotlinx.serialization.json.JsonObjectBuilder
import kotlinx.serialization.json.buildJsonObject
import kotlin.jvm.JvmInline

@JvmInline
@Serializable(with = JsonData.JsonDataSerializer::class)
public value class JsonData(public val json: JsonElement){
public object JsonDataSerializer: KSerializer<JsonData>{
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("JsonData") {}

override fun deserialize(decoder: Decoder): JsonData {
if(decoder is JsonDecoder){
return JsonData(decoder.decodeJsonElement())
}
throw UnsupportedOperationException("Cannot deserialize Parameters")
}

override fun serialize(encoder: Encoder, value: JsonData) {
if(encoder is JsonEncoder){
encoder.encodeJsonElement(value.json)
return
}
}
}
public companion object{
public fun fromString(json: String): JsonData = fromJsonElement(Json.parseToJsonElement(json))

public fun fromJsonElement(json: JsonElement): JsonData = JsonData(json)

public fun builder(block: JsonObjectBuilder.() -> Unit): JsonData{
val json = buildJsonObject (
block
)
return fromJsonElement(json)
}

}
}
62 changes: 61 additions & 1 deletion sample/jvm/src/main/kotlin/com/aallam/openai/sample/jvm/App.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import com.aallam.openai.api.audio.TranslationRequest
import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.chat.FunctionCall
import com.aallam.openai.api.chat.ChatCompletionFunction
import com.aallam.openai.api.chat.JsonData
import com.aallam.openai.api.completion.CompletionRequest
import com.aallam.openai.api.file.FileSource
import com.aallam.openai.api.image.ImageCreation
Expand All @@ -22,9 +25,10 @@ import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.runBlocking
import okio.FileSystem
import okio.Path.Companion.toPath
import kotlinx.serialization.json.put

@OptIn(BetaOpenAI::class)
fun main() = runBlocking {
fun mainOld() = runBlocking {
JochenGuckSnk marked this conversation as resolved.
Show resolved Hide resolved
val apiKey = System.getenv("OPENAI_API_KEY")
val token = requireNotNull(apiKey) { "OPENAI_API_KEY environment variable must be set." }
val openAI = OpenAI(token = token, logging = LoggingConfig(LogLevel.All))
Expand Down Expand Up @@ -101,6 +105,62 @@ fun main() = runBlocking {
)
openAI.chatCompletion(chatCompletionRequest).choices.forEach(::println)

println("> Create Chat Completion function call...")
val chatCompletionCreateFunctionCall = ChatCompletionRequest(
model = ModelId("gpt-3.5-turbo-0613"),
messages = listOf(
ChatMessage(
role = ChatRole.System,
content = "You are a helpful assistant that translates English to French."
),
ChatMessage(
role = ChatRole.User,
content = "Translate the following English text to French: “OpenAI is awesome!”"
)
),
functionCall = FunctionCall.forceCall("translate"),
functions = listOf(
ChatCompletionFunction(
name = "translate",
description = "Translate English to French",
parameters = JsonData.fromString(
"{\"type\": \"object\", \"properties\": {\"text\": {\"type\": \"string\"}}}"
JochenGuckSnk marked this conversation as resolved.
Show resolved Hide resolved
),
)
)
)
openAI.chatCompletion(chatCompletionCreateFunctionCall).choices.forEach(::println)
println("> Process Chat Completion function call...")
val chatFunctionReturn = ChatCompletionRequest(
model = ModelId("gpt-3.5-turbo-0613"),
messages = listOf(
ChatMessage(
role = ChatRole.System,
content = "You are a helpful assistant that uses a function to translates English to French.\n" +
"Use only the result of the function call as the response."
),
ChatMessage(
role = ChatRole.User,
content = "Translate the following English text to French: “OpenAI is awesome!”"
),
ChatMessage(
role = ChatRole.Assistant,
content = "None",
functionCall = JsonData.builder {
put("name", "translate")
put("arguments", "{\"text\": \"OpenAI is awesome!\"}")
JochenGuckSnk marked this conversation as resolved.
Show resolved Hide resolved
}

),
ChatMessage(
role = ChatRole.Function,
content = "openai est super !",
name = "translate",
)
),
)
openAI.chatCompletion(chatFunctionReturn).choices.forEach(::println)

println("\n>️ Creating chat completions stream...")
openAI.chatCompletions(chatCompletionRequest)
.onEach { print(it.choices.first().delta?.content.orEmpty()) }
Expand Down