Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add huggingface client #5

Merged
merged 3 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ repositories {
plugins {
base
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.kotlinx.serialization)
}

java {
Expand Down Expand Up @@ -38,7 +39,7 @@ kotlin {
}
val hostOs = System.getProperty("os.name")
val isMingwX64 = hostOs.startsWith("Windows")
val nativeTarget = when {
when {
hostOs == "Mac OS X" -> macosX64("native")
hostOs == "Linux" -> linuxX64("native")
isMingwX64 -> mingwX64("native")
Expand All @@ -47,22 +48,17 @@ kotlin {


sourceSets {
val commonMain by getting {
commonMain {
dependencies {
implementation(libs.arrow.core)
implementation(libs.open.ai)
implementation(libs.arrow.fx)
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.ktor.client)
}
}
val commonTest by getting {
commonTest {
dependencies {
implementation(kotlin("test"))
}
}
val jvmMain by getting
val jvmTest by getting
val jsMain by getting
val jsTest by getting
val nativeMain by getting
val nativeTest by getting
}
}
18 changes: 17 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,25 @@
arrow = "1.2.0-RC"
kotlin = "1.8.20"
openai = "0.12.0"
kotlinx-json = "1.5.0"
ktor = "2.2.2"

[libraries]
arrow-core = { module = "io.arrow-kt:arrow-core", version.ref = "arrow" }
arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
open-ai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }

ktor-client = { module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-content-negotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" }
ktor-client-serialization = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" }

[bundles]
ktor-client = [
"ktor-client",
"ktor-client-content-negotiation",
"ktor-client-serialization"
]

[plugins]
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
kotlinx-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
25 changes: 25 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/ktor.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.xebia.functional

import arrow.fx.coroutines.ResourceScope
import io.ktor.client.HttpClient
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.setBody
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json

inline fun <reified A> HttpRequestBuilder.configure(token: String, request: A): Unit {
header("Authorization", "Bearer $token")
contentType(ContentType.Application.Json)
setBody(request)
}

suspend fun ResourceScope.httpClient(engine: HttpClientEngine): HttpClient =
install({
HttpClient(engine) {
install(ContentNegotiation) { json() }
}
}) { client, _ -> client.close() }
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.xebia.functional.llm.huggingface

import arrow.fx.coroutines.ResourceScope
import com.xebia.functional.configure
import com.xebia.functional.httpClient
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json

interface HuggingFaceClient {
suspend fun generate(request: InferenceRequest, model: Model): List<Generation>
}

suspend fun ResourceScope.KtorHuggingFaceClient(
engine: HttpClientEngine,
token: String
): HuggingFaceClient = KtorHuggingFaceClient(httpClient(engine), token)

private class KtorHuggingFaceClient(
private val httpClient: HttpClient,
private val token: String
) : HuggingFaceClient {

// TODO move to config
private val baseUrl = "https://api-inference.huggingface.co"

override suspend fun generate(request: InferenceRequest, model: Model): List<Generation> {
val response = httpClient.post("$baseUrl/models/${model.name}") {
configure(token, request)
}
return response.body()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.xebia.functional.llm.huggingface

import kotlin.jvm.JvmInline
import kotlinx.serialization.Serializable

@Serializable
data class Generation(val generatedText: String)

@Serializable
data class InferenceRequest(val inputs: String, val maxLength: Int = 1000)

@Serializable
@JvmInline
value class Model(val name: String)
3 changes: 0 additions & 3 deletions src/commonMain/kotlin/llm/models/CompletionChoice.kt

This file was deleted.

12 changes: 0 additions & 12 deletions src/commonMain/kotlin/llm/openai/OpenAIClient.kt

This file was deleted.

13 changes: 0 additions & 13 deletions src/jsMain/kotlin/llm/openai/OpenAIClientFactory.kt

This file was deleted.

14 changes: 0 additions & 14 deletions src/jvmMain/kotlin/llm/openai/OpenAIClientFactory.kt

This file was deleted.

14 changes: 0 additions & 14 deletions src/nativeMain/kotlin/llm/openai/OpenAIClientFactory.kt

This file was deleted.