Skip to content

Commit

Permalink
Use base URL for ktor client (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
franciscodr authored Apr 27, 2023
1 parent e8a4eca commit 216682a
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 26 deletions.
7 changes: 4 additions & 3 deletions src/commonMain/kotlin/com/xebia/functional/env/config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ data class InvalidConfig(val message: String)

data class Env(val openAI: OpenAIConfig, val huggingFace: HuggingFaceConfig)

data class OpenAIConfig(val token: String, val chunkSize: Int, val retryConfig: RetryConfig)
data class OpenAIConfig(val token: String, val baseUrl: KUrl, val chunkSize: Int, val retryConfig: RetryConfig)

data class RetryConfig(val backoff: Duration, val maxRetries: Long) {
fun schedule(): Schedule<Throwable, Unit> =
Expand All @@ -37,15 +37,16 @@ fun Raise<InvalidConfig>.Env(): Env =
fun Raise<NonEmptyList<String>>.OpenAIConfig(token: String? = null) =
zipOrAccumulate(
{ token ?: env("OPENAI_TOKEN") },
{ env("OPENAI_BASE_URI", default = Url("https://api.openai.com/v1/")) { Url(it) } },
{ env("OPENAI_CHUNK_SIZE", default = 1000) { it.toIntOrNull() } },
{ env("OPENAI_BACKOFF", default = 5.seconds) { it.toIntOrNull()?.seconds } },
{ env("OPENAI_MAX_RETRIES", default = 5) { it.toLongOrNull() } },
) { token2, chunkSize, backoff, maxRetries -> OpenAIConfig(token2, chunkSize, RetryConfig(backoff, maxRetries)) }
) { token2, baseUrl, chunkSize, backoff, maxRetries -> OpenAIConfig(token2, baseUrl, chunkSize, RetryConfig(backoff, maxRetries)) }

fun Raise<NonEmptyList<String>>.HuggingFaceConfig(token: String? = null) =
zipOrAccumulate(
{ token ?: env("HF_TOKEN") },
{ env("HF_BASE_URI", default = Url("https://api-inference.huggingface.co")) { Url(it) } }
{ env("HF_BASE_URI", default = Url("https://api-inference.huggingface.co/")) { Url(it) } }
) { token2, baseUrl -> HuggingFaceConfig(token2, baseUrl) }

fun Raise<String>.Url(urlString: String): KUrl =
Expand Down
14 changes: 6 additions & 8 deletions src/commonMain/kotlin/com/xebia/functional/ktor.kt
Original file line number Diff line number Diff line change
@@ -1,32 +1,30 @@
package com.xebia.functional

import arrow.core.Either
import arrow.core.nonFatalOrThrow
import arrow.fx.coroutines.ResourceScope
import arrow.resilience.Schedule
import arrow.resilience.ScheduleStep
import com.xebia.functional.env.RetryConfig
import io.ktor.client.HttpClient
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.plugins.HttpRequestRetry
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.defaultRequest
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.setBody
import io.ktor.http.ContentType
import io.ktor.http.Url
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlin.time.Duration

inline fun <reified A> HttpRequestBuilder.configure(token: String, request: A): Unit {
header("Authorization", "Bearer $token")
contentType(ContentType.Application.Json)
setBody(request)
}

suspend fun ResourceScope.httpClient(engine: HttpClientEngine): HttpClient =
suspend fun ResourceScope.httpClient(engine: HttpClientEngine, baseUrl: Url): HttpClient =
install({
HttpClient(engine) {
install(ContentNegotiation) { json() }
defaultRequest {
url(baseUrl.toString())
}
}
}) { client, _ -> client.close() }
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@ import com.xebia.functional.httpClient
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.request.url
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.path
import io.ktor.serialization.kotlinx.json.json

interface HuggingFaceClient {
suspend fun generate(request: InferenceRequest, model: Model): List<Generation>
Expand All @@ -25,15 +17,15 @@ interface HuggingFaceClient {
suspend fun ResourceScope.KtorHuggingFaceClient(
engine: HttpClientEngine,
config: HuggingFaceConfig
): HuggingFaceClient = KtorHuggingFaceClient(httpClient(engine), config)
): HuggingFaceClient = KtorHuggingFaceClient(httpClient(engine, config.baseUrl), config)

private class KtorHuggingFaceClient(
private val httpClient: HttpClient,
private val config: HuggingFaceConfig
) : HuggingFaceClient {

override suspend fun generate(request: InferenceRequest, model: Model): List<Generation> {
val response = httpClient.post(config.baseUrl) {
val response = httpClient.post {
url { path("models", model.name) }
configure(config.token, request)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.request.post
import io.ktor.http.path

interface OpenAIClient {
suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice>
Expand All @@ -18,25 +19,29 @@ interface OpenAIClient {
suspend fun ResourceScope.KtorOpenAIClient(
engine: HttpClientEngine,
config: OpenAIConfig
): OpenAIClient = KtorOpenAIClient(httpClient(engine), config)
): OpenAIClient = KtorOpenAIClient(httpClient(engine, config.baseUrl), config)

private class KtorOpenAIClient(
private val httpClient: HttpClient,
private val config: OpenAIConfig
) : OpenAIClient {

private val baseUrl = "https://api.openai.com/v1"

override suspend fun createCompletion(request: CompletionRequest): List<CompletionChoice> {
val response = config.retryConfig.schedule().retry {
httpClient.post("$baseUrl/completions") { configure(config.token, request) }
httpClient.post {
url { path("completions") }
configure(config.token, request)
}
}
return response.body()
}

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
val response = config.retryConfig.schedule().retry {
httpClient.post("$baseUrl/embeddings") { configure(config.token, request) }
httpClient.post {
url { path("embeddings") }
configure(config.token, request)
}
}
return response.body()
}
Expand Down

0 comments on commit 216682a

Please sign in to comment.