Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use base URL for Ktor client #10

Merged
merged 1 commit into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/commonMain/kotlin/com/xebia/functional/env/config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ data class InvalidConfig(val message: String)

data class Env(val openAI: OpenAIConfig, val huggingFace: HuggingFaceConfig)

data class OpenAIConfig(val token: String, val chunkSize: Int, val retryConfig: RetryConfig)
data class OpenAIConfig(val token: String, val baseUrl: KUrl, val chunkSize: Int, val retryConfig: RetryConfig)

data class RetryConfig(val backoff: Duration, val maxRetries: Long) {
fun schedule(): Schedule<Throwable, Unit> =
Expand All @@ -37,15 +37,16 @@ fun Raise<InvalidConfig>.Env(): Env =
fun Raise<NonEmptyList<String>>.OpenAIConfig(token: String? = null) =
zipOrAccumulate(
{ token ?: env("OPENAI_TOKEN") },
{ env("OPENAI_BASE_URI", default = Url("https://api.openai.com/v1/")) { Url(it) } },
{ env("OPENAI_CHUNK_SIZE", default = 1000) { it.toIntOrNull() } },
{ env("OPENAI_BACKOFF", default = 5.seconds) { it.toIntOrNull()?.seconds } },
{ env("OPENAI_MAX_RETRIES", default = 5) { it.toLongOrNull() } },
) { token2, chunkSize, backoff, maxRetries -> OpenAIConfig(token2, chunkSize, RetryConfig(backoff, maxRetries)) }
) { token2, baseUrl, chunkSize, backoff, maxRetries -> OpenAIConfig(token2, baseUrl, chunkSize, RetryConfig(backoff, maxRetries)) }

fun Raise<NonEmptyList<String>>.HuggingFaceConfig(token: String? = null) =
zipOrAccumulate(
{ token ?: env("HF_TOKEN") },
{ env("HF_BASE_URI", default = Url("https://api-inference.huggingface.co")) { Url(it) } }
{ env("HF_BASE_URI", default = Url("https://api-inference.huggingface.co/")) { Url(it) } }
) { token2, baseUrl -> HuggingFaceConfig(token2, baseUrl) }

fun Raise<String>.Url(urlString: String): KUrl =
Expand Down
14 changes: 6 additions & 8 deletions src/commonMain/kotlin/com/xebia/functional/ktor.kt
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
package com.xebia.functional

import arrow.core.Either
import arrow.core.nonFatalOrThrow
import arrow.fx.coroutines.ResourceScope
import arrow.resilience.Schedule
import arrow.resilience.ScheduleStep
import com.xebia.functional.env.RetryConfig
import io.ktor.client.HttpClient
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.plugins.HttpRequestRetry
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.defaultRequest
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.setBody
import io.ktor.http.ContentType
import io.ktor.http.Url
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration

inline fun <reified A> HttpRequestBuilder.configure(token: String, request: A): Unit {
header("Authorization", "Bearer $token")
contentType(ContentType.Application.Json)
setBody(request)
}

suspend fun ResourceScope.httpClient(engine: HttpClientEngine): HttpClient =
suspend fun ResourceScope.httpClient(engine: HttpClientEngine, baseUrl: Url): HttpClient =
install({
HttpClient(engine) {
install(ContentNegotiation) { json() }
defaultRequest {
url(baseUrl.toString())
}
}
}) { client, _ -> client.close() }
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@ import com.xebia.functional.httpClient
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.request.url
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.path
import io.ktor.serialization.kotlinx.json.json

interface HuggingFaceClient {
suspend fun generate(request: InferenceRequest, model: Model): List<Generation>
Expand All @@ -25,15 +17,15 @@ interface HuggingFaceClient {
suspend fun ResourceScope.KtorHuggingFaceClient(
engine: HttpClientEngine,
config: HuggingFaceConfig
): HuggingFaceClient = KtorHuggingFaceClient(httpClient(engine), config)
): HuggingFaceClient = KtorHuggingFaceClient(httpClient(engine, config.baseUrl), config)

private class KtorHuggingFaceClient(
private val httpClient: HttpClient,
private val config: HuggingFaceConfig
) : HuggingFaceClient {

override suspend fun generate(request: InferenceRequest, model: Model): List<Generation> {
val response = httpClient.post(config.baseUrl) {
val response = httpClient.post {
url { path("models", model.name) }
configure(config.token, request)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.request.post
import io.ktor.http.path

interface OpenAIClient {
suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice>
Expand All @@ -18,25 +19,29 @@ interface OpenAIClient {
suspend fun ResourceScope.KtorOpenAIClient(
engine: HttpClientEngine,
config: OpenAIConfig
): OpenAIClient = KtorOpenAIClient(httpClient(engine), config)
): OpenAIClient = KtorOpenAIClient(httpClient(engine, config.baseUrl), config)

private class KtorOpenAIClient(
private val httpClient: HttpClient,
private val config: OpenAIConfig
) : OpenAIClient {

private val baseUrl = "https://api.openai.com/v1"

override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> {
val response = config.retryConfig.schedule().retry {
httpClient.post("$baseUrl/completions") { configure(config.token, request) }
httpClient.post {
url { path("completions") }
configure(config.token, request)
}
}
return response.body()
}

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
val response = config.retryConfig.schedule().retry {
httpClient.post("$baseUrl/embeddings") { configure(config.token, request) }
httpClient.post {
url { path("embeddings") }
configure(config.token, request)
}
}
return response.body()
}
Expand Down