Skip to content

Commit

Permalink
OpenAI host and SpaceCraftLocal
Browse files Browse the repository at this point in the history
  • Loading branch information
javipacheco committed Aug 28, 2023
1 parent bc69f61 commit 116205b
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.xebia.functional.xef.conversation.streaming

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.store.LocalVectorStore

/*
* This example is the same as SpaceCraft.kt but uses a local server
*
* To run this example, you need to:
* - Run xef-server locally with the command: ./gradlew server
*/
/** Streams a structured completion from a locally hosted xef-server, printing partial properties and the final result. */
suspend fun main() {

  // Point the OpenAI client at the local xef-server instead of the public API.
  val localModel = OpenAI(host = "http://localhost:8081/").DEFAULT_SERIALIZATION

  // Conversation backed by an in-memory vector store with OpenAI embeddings.
  val conversation =
    Conversation(LocalVectorStore(OpenAIEmbeddings(OpenAI.FromEnvironment.DEFAULT_EMBEDDING)))

  val stream =
    localModel.promptStreaming(
      Prompt("Make a spacecraft with a mission to Mars"),
      scope = conversation,
      serializer = InterstellarCraft.serializer()
    )

  stream.collect { chunk ->
    when (chunk) {
      // Partial field emitted while the model is still generating.
      is StreamedFunction.Property -> println("${chunk.path} = ${chunk.value}")
      // Fully deserialized InterstellarCraft once the stream completes.
      is StreamedFunction.Result -> println(chunk.value)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,37 @@ import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic
import kotlin.jvm.JvmSynthetic

class OpenAI(internal var token: String? = null) : AutoCloseable, AutoClose by autoClose() {
class OpenAI(internal var token: String? = null, internal var host: String? = null) :
AutoCloseable, AutoClose by autoClose() {

/** Reads the OpenAI token from the OPENAI_TOKEN environment variable, failing when it is absent. */
private fun openAITokenFromEnv(): String =
  getenv("OPENAI_TOKEN")
    ?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var"))

/** Reads an optional host override from the OPENAI_HOST environment variable; null when unset. */
private fun openAIHostFromEnv(): String? = getenv("OPENAI_HOST")

/** Returns the explicitly supplied token, otherwise resolves it from the environment (throws if neither exists). */
fun getToken(): String = token ?: openAITokenFromEnv()

/** Returns the configured host, lazily falling back to OPENAI_HOST and caching the result in [host]. */
fun getHost(): String? {
  val current = host
  if (current != null) return current
  val resolved = openAIHostFromEnv()
  host = resolved
  return resolved
}

init {
  // Resolve any missing credentials/host from the environment at construction time;
  // a missing OPENAI_TOKEN fails fast here, while a missing host is allowed (stays null).
  token = token ?: openAITokenFromEnv()
  host = host ?: openAIHostFromEnv()
}

val GPT_4 by lazy { autoClose(OpenAIModel(this, "gpt-4", ModelType.GPT_4)) }
Expand Down Expand Up @@ -127,7 +143,7 @@ class OpenAI(internal var token: String? = null) : AutoCloseable, AutoClose by a
}
}

fun String.toOpenAIModel(token: String): OpenAIModel {
val openAI = OpenAI(token)
/**
 * Resolves this model name to an [OpenAIModel] for the given token (and optional host override),
 * falling back to GPT_3_5_TURBO_16K when the name is not among the supported models.
 */
fun String.toOpenAIModel(token: String, host: String? = null): OpenAIModel {
  val provider = OpenAI(token, host)
  val match = provider.supportedModels().firstOrNull { it.name == this }
  return match ?: provider.GPT_3_5_TURBO_16K
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.LoggingConfig
import com.aallam.openai.client.OpenAI as OpenAIClient
import com.aallam.openai.client.OpenAIHost
import com.xebia.functional.tokenizer.Encoding
import com.xebia.functional.tokenizer.ModelType
import com.xebia.functional.xef.llm.*
Expand Down Expand Up @@ -51,6 +52,7 @@ class OpenAIModel(

// HTTP client for the OpenAI-compatible API; honors a configured host override.
private val client =
OpenAIClient(
// Fall back to the public OpenAI endpoint when no host override is configured.
host = openAI.getHost()?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
token = openAI.getToken(),
logging = LoggingConfig(LogLevel.None),
// BUG(review): "$openAI.token" interpolates openAI.toString() and then appends the
// LITERAL text ".token" — it does not read the token property. Should be
// "${openAI.getToken()}" (note also the stray space before "Bearer"); the `token`
// parameter above already carries the credential, so this header looks redundant too.
headers = mapOf("Authorization" to " Bearer $openAI.token")
Expand Down
8 changes: 8 additions & 0 deletions server/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,11 @@ task<JavaExec>("web-app") {
classpath = sourceSets.main.get().runtimeClasspath
mainClass.set("com.xebia.functional.xef.server.WebApp")
}

// Runs the xef-server application from the compiled main source set: ./gradlew server
task<JavaExec>("server") {
dependsOn("compileKotlin")
group = "Execution"
description = "xef-server server application"
classpath = sourceSets.main.get().runtimeClasspath
// Entry point: the Server object's @JvmStatic main (renamed from Main in this commit).
mainClass.set("com.xebia.functional.xef.server.Server")
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.xebia.functional.xef.server


import arrow.continuations.SuspendApp
import arrow.fx.coroutines.resourceScope
import arrow.continuations.ktor.server
Expand All @@ -25,7 +24,7 @@ import io.ktor.server.resources.*
import io.ktor.server.routing.*
import kotlinx.coroutines.awaitCancellation

object Main {
object Server {
@JvmStatic
fun main(args: Array<String>) = SuspendApp {
resourceScope {
Expand All @@ -45,7 +44,7 @@ object Main {
install(ClientContentNegotiation)
}

server(factory = Netty, port = 8080, host = "0.0.0.0") {
server(factory = Netty, port = 8081, host = "0.0.0.0") {
install(CORS) {
allowNonSimpleContentTypes = true
anyHost()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fun String.toProvider(): Provider? = when (this) {
"openai" -> Provider.OPENAI
"gpt4all" -> Provider.GPT4ALL
"gcp" -> Provider.GCP
else -> null
else -> Provider.OPENAI
}


Expand All @@ -56,8 +56,8 @@ fun Routing.routes(
val provider: Provider = call.getProvider()
val token = call.getToken()
val scope = Conversation(persistenceService.getVectorStore(provider, token))
val darta = call.receive<String>()
val data = Json.decodeFromString<JsonObject>(darta)// call.receive<JsonObject>()
val context = call.receive<String>()
val data = Json.decodeFromString<JsonObject>(context)
if (!data.containsKey("model")) {
call.respondText("No model found", status = HttpStatusCode.BadRequest)
return@post
Expand All @@ -67,8 +67,6 @@ fun Routing.routes(
return@post
}

println(darta)

val isStream = data["stream"]?.jsonPrimitive?.boolean ?: false

if (!isStream) {
Expand All @@ -78,7 +76,7 @@ fun Routing.routes(
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(darta)
setBody(context)
}
call.respond(response.body<String>())
} else {
Expand All @@ -89,7 +87,7 @@ fun Routing.routes(
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(darta)
setBody(context)
}.execute { httpResponse ->
val channel: ByteReadChannel = httpResponse.body()
call.respondBytesWriter(contentType = ContentType.Application.Json) {
Expand All @@ -110,7 +108,7 @@ fun Routing.routes(

private fun ApplicationCall.getProvider(): Provider =
request.headers["xef-provider"]?.toProvider()
?: throw IllegalArgumentException("Not a valid provider")
?: Provider.OPENAI

/** Returns the authenticated principal's name as the provider token; rejects unauthenticated calls. */
private fun ApplicationCall.getToken(): String {
    val principalName = principal<UserIdPrincipal>()?.name
    return principalName ?: throw IllegalArgumentException("No token found")
}
Expand Down

0 comments on commit 116205b

Please sign in to comment.