Server support stream (#346)
* supporting stream

* OpenAI host and SpaceCraftLocal

---------

Co-authored-by: Javi Pacheco <javi.pacheco@gmail.com>
Montagon and javipacheco authored Aug 28, 2023
1 parent be2b748 commit 0da389c
Showing 7 changed files with 182 additions and 40 deletions.
@@ -0,0 +1,39 @@
package com.xebia.functional.xef.conversation.streaming

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.store.LocalVectorStore

/*
 * This example is the same as SpaceCraft.kt but uses a local server.
 *
 * To run this example, you need to:
 * - Start xef-server locally with the command: ./gradlew server
 */
suspend fun main() {

  val model = OpenAI(host = "http://localhost:8081/").DEFAULT_SERIALIZATION

  val scope =
    Conversation(LocalVectorStore(OpenAIEmbeddings(OpenAI.FromEnvironment.DEFAULT_EMBEDDING)))

  model
    .promptStreaming(
      Prompt("Make a spacecraft with a mission to Mars"),
      scope = scope,
      serializer = InterstellarCraft.serializer()
    )
    .collect { element ->
      when (element) {
        is StreamedFunction.Property -> {
          println("${element.path} = ${element.value}")
        }
        is StreamedFunction.Result -> {
          println(element.value)
        }
      }
    }
}
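
The InterstellarCraft type used above comes from the existing SpaceCraft.kt example and is not part of this diff. A minimal sketch of what such a serializable type could look like, assuming kotlinx.serialization (the field names here are illustrative, not the actual definition):

import kotlinx.serialization.Serializable

// Hypothetical shape of the InterstellarCraft type referenced by the example above;
// the real definition lives in SpaceCraft.kt and may differ.
@Serializable
data class InterstellarCraft(
  val name: String,
  val destination: String,
  val crew: Int,
  val missionObjectives: List<String>
)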
3 changes: 3 additions & 0 deletions gradle/libs.versions.toml
@@ -59,10 +59,13 @@ kotlinx-coroutines-jdk8 = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-j
ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" }
ktor-http = { module = "io.ktor:ktor-http", version.ref = "ktor" }
ktor-client ={ module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-auth = { module = "io.ktor:ktor-client-auth", version.ref = "ktor" }
ktor-client-content-negotiation ={ module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" }
ktor-client-serialization = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" }
ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" }
ktor-client-logging = { module = "io.ktor:ktor-client-logging", version.ref = "ktor" }
ktor-client-js = { module = "io.ktor:ktor-client-js", version.ref = "ktor" }
ktor-client-json = { module = "io.ktor:ktor-client-json", version.ref = "ktor" }
ktor-client-winhttp = { module = "io.ktor:ktor-client-winhttp", version.ref = "ktor" }
ktor-server-auth = { module = "io.ktor:ktor-server-auth", version.ref = "ktor" }
ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor" }
@@ -15,21 +15,37 @@ import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic
import kotlin.jvm.JvmSynthetic

class OpenAI(internal var token: String? = null) : AutoCloseable, AutoClose by autoClose() {
class OpenAI(internal var token: String? = null, internal var host: String? = null) :
AutoCloseable, AutoClose by autoClose() {

private fun openAITokenFromEnv(): String {
return getenv("OPENAI_TOKEN")
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var"))
}

private fun openAIHostFromEnv(): String? {
return getenv("OPENAI_HOST")
}

fun getToken(): String {
return token ?: openAITokenFromEnv()
}

fun getHost(): String? {
return host
?: run {
host = openAIHostFromEnv()
host
}
}

init {
if (token == null) {
token = openAITokenFromEnv()
}
if (host == null) {
host = openAIHostFromEnv()
}
}

val GPT_4 by lazy { autoClose(OpenAIModel(this, "gpt-4", ModelType.GPT_4)) }
@@ -127,7 +143,7 @@ class OpenAI(internal var token: String? = null) : AutoCloseable, AutoClose by a
}
}

fun String.toOpenAIModel(token: String): OpenAIModel {
val openAI = OpenAI(token)
fun String.toOpenAIModel(token: String, host: String? = null): OpenAIModel {
val openAI = OpenAI(token, host)
return openAI.supportedModels().find { it.name == this } ?: openAI.GPT_3_5_TURBO_16K
}
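
With the new optional host parameter, callers can point the OpenAI client at a custom endpoint such as the local xef-server, or let it fall back to the OPENAI_HOST environment variable. A brief usage sketch (token and URL are illustrative placeholders):

import com.xebia.functional.xef.conversation.llm.openai.*

fun hostExamples() {
  // Explicit host, e.g. the local xef-server started with ./gradlew server
  val local = OpenAI(token = "my-token", host = "http://localhost:8081/")

  // Token and host resolved from the OPENAI_TOKEN / OPENAI_HOST environment variables
  val fromEnv = OpenAI()

  // The String extension now also accepts an optional host
  val model = "gpt-3.5-turbo-16k".toOpenAIModel(token = "my-token", host = "http://localhost:8081/")
}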
@@ -20,6 +20,7 @@ import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.LoggingConfig
import com.aallam.openai.client.OpenAI as OpenAIClient
import com.aallam.openai.client.OpenAIHost
import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.*
@@ -51,6 +52,7 @@ class OpenAIModel(

private val client =
OpenAIClient(
host = openAI.getHost()?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
token = openAI.getToken(),
logging = LoggingConfig(LogLevel.None),
headers = mapOf("Authorization" to " Bearer $openAI.token")
13 changes: 13 additions & 0 deletions server/build.gradle.kts
@@ -27,6 +27,11 @@ dependencies {
implementation(libs.kotlinx.serialization.json)
implementation(libs.kotlinx.serialization.hocon)
implementation(libs.ktor.serialization.json)
implementation(libs.ktor.client)
implementation(libs.ktor.client.auth)
implementation(libs.ktor.client.content.negotiation)
implementation(libs.ktor.client.logging)
implementation(libs.ktor.client.json)
implementation(libs.ktor.server.auth)
implementation(libs.ktor.server.netty)
implementation(libs.ktor.server.core)
@@ -58,3 +63,11 @@ task<JavaExec>("web-app") {
classpath = sourceSets.main.get().runtimeClasspath
mainClass.set("com.xebia.functional.xef.server.WebApp")
}

task<JavaExec>("server") {
dependsOn("compileKotlin")
group = "Execution"
description = "xef-server server application"
classpath = sourceSets.main.get().runtimeClasspath
mainClass.set("com.xebia.functional.xef.server.Server")
}
@@ -1,6 +1,5 @@
package com.xebia.functional.xef.server


import arrow.continuations.SuspendApp
import arrow.fx.coroutines.resourceScope
import arrow.continuations.ktor.server
@@ -10,6 +9,11 @@ import com.xebia.functional.xef.server.db.psql.Migrate
import com.xebia.functional.xef.server.db.psql.XefVectorStoreConfig
import com.xebia.functional.xef.server.db.psql.XefVectorStoreConfig.Companion.getPersistenceService
import com.xebia.functional.xef.server.http.routes.routes
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation
import io.ktor.client.plugins.auth.*
import io.ktor.client.plugins.logging.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
@@ -20,7 +24,7 @@ import io.ktor.server.resources.*
import io.ktor.server.routing.*
import kotlinx.coroutines.awaitCancellation

object Main {
object Server {
@JvmStatic
fun main(args: Array<String>) = SuspendApp {
resourceScope {
@@ -32,7 +36,15 @@ object Main {
val persistenceService = vectorStoreConfig.getPersistenceService(config)
persistenceService.addCollection()

server(factory = Netty, port = 8080, host = "0.0.0.0") {
val ktorClient = HttpClient(CIO){
install(Auth)
install(Logging) {
level = LogLevel.INFO
}
install(ClientContentNegotiation)
}

server(factory = Netty, port = 8081, host = "0.0.0.0") {
install(CORS) {
allowNonSimpleContentTypes = true
anyHost()
@@ -46,7 +58,7 @@ }
}
}
}
routing { routes(persistenceService) }
routing { routes(ktorClient, persistenceService) }
}
awaitCancellation()
}
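
Note that the proxy's HttpClient above installs ContentNegotiation without registering a serializer, which appears sufficient for the pass-through usage in the routes below. If client-side JSON (de)serialization were needed, it could be configured explicitly; a sketch assuming the ktor-serialization-kotlinx-json artifact already present in the version catalog:

import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation
import io.ktor.client.plugins.logging.*
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.json.Json

// Sketch: the same client as in the Server object, with an explicit JSON configuration.
fun buildProxyClient(): HttpClient =
  HttpClient(CIO) {
    install(Logging) { level = LogLevel.INFO }
    install(ClientContentNegotiation) {
      json(Json { ignoreUnknownKeys = true })
    }
  }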
@@ -5,65 +5,122 @@ import com.aallam.openai.api.chat.ChatCompletionRequest
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.prompt.configuration.PromptConfiguration
import com.xebia.functional.xef.conversation.llm.openai.*
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.llm.models.chat.ChatCompletionRequest as XefChatCompletionRequest
import com.xebia.functional.xef.llm.models.chat.ChatCompletionResponse
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.server.services.PersistenceService
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.util.cio.*
import io.ktor.util.pipeline.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.boolean
import kotlinx.serialization.json.jsonPrimitive

enum class Provider {
OPENAI, GPT4ALL, GCP
OPENAI, GPT4ALL, GCP
}

fun String.toProvider(): Provider? = when (this) {
"openai" -> Provider.OPENAI
"gpt4all" -> Provider.GPT4ALL
"gcp" -> Provider.GCP
else -> null
"openai" -> Provider.OPENAI
"gpt4all" -> Provider.GPT4ALL
"gcp" -> Provider.GCP
else -> Provider.OPENAI
}


@OptIn(BetaOpenAI::class)
fun Routing.routes(persistenceService: PersistenceService) {
authenticate("auth-bearer") {
post("/chat/completions") {
val provider: Provider = call.request.headers["xef-provider"]?.toProvider()
?: throw IllegalArgumentException("Not a valid provider")
val token = call.principal<UserIdPrincipal>()?.name ?: throw IllegalArgumentException("No token found")
val scope = Conversation(
persistenceService.getVectorStore(provider, token)
)
val data = call.receive<ChatCompletionRequest>().toCore()
val model: OpenAIModel = data.model.toOpenAIModel(token)
response<String, Throwable> {
model.promptMessage(
prompt = Prompt(
messages = data.messages,
configuration = PromptConfiguration(
temperature = data.temperature,
numberOfPredictions = data.n,
user = data.user ?: ""
)
),
scope = scope
)
}
fun Routing.routes(
client: HttpClient,
persistenceService: PersistenceService
) {
val openAiUrl = "https://api.openai.com/v1"

authenticate("auth-bearer") {
post("/chat/completions") {
val provider: Provider = call.getProvider()
val token = call.getToken()
val scope = Conversation(persistenceService.getVectorStore(provider, token))
val context = call.receive<String>()
val data = Json.decodeFromString<JsonObject>(context)
if (!data.containsKey("model")) {
call.respondText("No model found", status = HttpStatusCode.BadRequest)
return@post
}
val model: OpenAIModel = data["model"]?.jsonPrimitive?.content?.toOpenAIModel(token) ?: run {
call.respondText("No model found", status = HttpStatusCode.BadRequest)
return@post
}

val isStream = data["stream"]?.jsonPrimitive?.boolean ?: false

if (!isStream) {
val response = client.request("$openAiUrl/chat/completions") {
headers {
bearerAuth(token)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(context)
}
call.respond(response.body<String>())
} else {
runBlocking {
client.preparePost("$openAiUrl/chat/completions") {
headers {
bearerAuth(token)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(context)
}.execute { httpResponse ->
val channel: ByteReadChannel = httpResponse.body()
call.respondBytesWriter(contentType = ContentType.Application.Json) {
while (!channel.isClosedForRead) {
val packet = channel.readRemaining(DEFAULT_BUFFER_SIZE.toLong())
while (!packet.isEmpty) {
val bytes = packet.readBytes()
writeStringUtf8(bytes.decodeToString())
}
}
}
}
}
}
}
}
}
}

private fun ApplicationCall.getProvider(): Provider =
request.headers["xef-provider"]?.toProvider()
?: Provider.OPENAI

private fun ApplicationCall.getToken(): String =
principal<UserIdPrincipal>()?.name ?: throw IllegalArgumentException("No token found")


/**
* Responds with the data and converts any potential Throwable into a 404.
*/
private suspend inline fun <reified T : Any, E : Throwable> PipelineContext<*, ApplicationCall>.response(
block: () -> T
block: () -> T
) = arrow.core.raise.recover<E, Unit>({
call.respond(block())
call.respond(block())
}) {
call.respondText(it.message ?: "Response not found", status = HttpStatusCode.NotFound)
call.respondText(it.message ?: "Response not found", status = HttpStatusCode.NotFound)
}
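
For context, a caller of this endpoint sends a regular OpenAI-style chat completion body together with a bearer token, and the server proxies the request to the OpenAI API. A hedged sketch of exercising the local server with a Ktor client (URL, token, and body are illustrative):

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.engine.cio.*
import io.ktor.client.request.*
import io.ktor.http.*
import kotlinx.coroutines.runBlocking

// Sketch: call the local xef-server proxy started with ./gradlew server.
fun main() = runBlocking {
  val client = HttpClient(CIO)
  val response = client.post("http://localhost:8081/chat/completions") {
    bearerAuth("my-openai-token") // becomes the UserIdPrincipal token on the server
    contentType(ContentType.Application.Json)
    setBody("""{"model": "gpt-3.5-turbo-16k", "stream": false, "messages": [{"role": "user", "content": "Say hello"}]}""")
  }
  println(response.body<String>())
  client.close()
}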
