Skip to content

Commit

Permalink
Allow custom mappings for different models
Browse files Browse the repository at this point in the history
  • Loading branch information
fedefernandez committed Oct 31, 2023
1 parent 21f0f4a commit 39cd65b
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.xebia.functional.xef.server.http.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonPrimitive

class ModelUriAdapter
internal constructor(private val urlMap: Map<OpenAIPathType, Map<String, String>>) {

val logger = KotlinLogging.logger {}

fun isDefined(path: OpenAIPathType): Boolean = urlMap.containsKey(path)

fun findPath(path: OpenAIPathType, model: String): String? = urlMap[path]?.get(model)

companion object : HttpClientPlugin<ModelUriAdapterBuilder, ModelUriAdapter> {

override val key: AttributeKey<ModelUriAdapter> = AttributeKey("ModelAuthAdapter")

override fun prepare(block: ModelUriAdapterBuilder.() -> Unit): ModelUriAdapter =
ModelUriAdapterBuilder().apply(block).build()

override fun install(plugin: ModelUriAdapter, scope: HttpClient) {
installModelAuthAdapter(plugin, scope)
}

private fun readModelFromRequest(originalRequest: OutgoingContent.ByteArrayContent?): String? {
val requestBody = originalRequest?.bytes()?.toString(Charsets.UTF_8)
val json = requestBody?.let { Json.decodeFromString<JsonObject>(it) }
return json?.get("model")?.jsonPrimitive?.content
}

private fun installModelAuthAdapter(plugin: ModelUriAdapter, scope: HttpClient) {
val adaptAuthRequestPhase = PipelinePhase("ModelAuthAdaptRequest")
scope.sendPipeline.insertPhaseAfter(HttpSendPipeline.State, adaptAuthRequestPhase)
scope.sendPipeline.intercept(adaptAuthRequestPhase) { content ->
val originalPath = OpenAIPathType.from(context.url.encodedPath) ?: return@intercept
if (plugin.isDefined(originalPath)) {
val originalRequest = content as? OutgoingContent.ByteArrayContent
if (originalRequest == null) {
plugin.logger.warn {
"""
|Can't adapt the model auth.
|The body type is: ${content::class}, with Content-Type: ${context.contentType()}.
|
|If you expect serialized body, please check that you have installed the corresponding
|plugin(like `ContentNegotiation`) and set `Content-Type` header."""
.trimMargin()
}
return@intercept
}
val model = readModelFromRequest(originalRequest)
val newURL = model?.let { plugin.findPath(originalPath, it) }
if (newURL == null) {
plugin.logger.info {
"Model auth didn't found a new url for path $originalPath and model $model"
}
} else {
val baseBuilder = URLBuilder(newURL).build()
context.url.set(
scheme = baseBuilder.protocol.name,
host = baseBuilder.host,
port = baseBuilder.port,
path = baseBuilder.encodedPath
)
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.xebia.functional.xef.server.http.client

class ModelUriAdapterBuilder {

private var pathMap: Map<OpenAIPathType, Map<String, String>> = LinkedHashMap()

fun setPathMap(pathMap: Map<OpenAIPathType, Map<String, String>>) {
this.pathMap = pathMap
}

fun addToPath(path: OpenAIPathType, vararg modelUriPaths: Pair<String, String>) {
val newPathTypeMap = mapOf(*modelUriPaths.map { Pair(it.first, it.second) }.toTypedArray())
this.pathMap += mapOf(path to newPathTypeMap)
}

internal fun build(): ModelUriAdapter = ModelUriAdapter(pathMap)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xebia.functional.xef.server.http.client

enum class OpenAIPathType(val value: String) {
CHAT("/v1/chat/completions"),
EMBEDDINGS("/v1/embeddings"),
FINE_TUNING("/v1/fine_tuning/jobs"),
FILES("/v1/files"),
IMAGES("/v1/images/generations"),
MODELS("/v1/models"),
MODERATION("/v1/moderations");

companion object {
fun from(v: String): OpenAIPathType? = entries.find { it.value == v }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private suspend fun HttpClient.makeRequest(
method = HttpMethod.Post
setBody(body)
}
call.response.headers.copyFrom(response.headers)
call.response.headers.copyFrom(response.headers, "Content-Length")
call.respond(response.status, response.readBytes())
}

Expand All @@ -91,17 +91,18 @@ private suspend fun HttpClient.makeStreaming(
setBody(body)
}
.execute { httpResponse ->
call.response.headers.copyFrom(httpResponse.headers)
call.response.headers.copyFrom(httpResponse.headers, "Content-Length")
call.respondOutputStream { httpResponse.bodyAsChannel().copyTo(this@respondOutputStream) }
}
}

private fun ResponseHeaders.copyFrom(headers: Headers) =
private fun ResponseHeaders.copyFrom(headers: Headers, vararg filterOut: String) =
headers
.entries()
.filter { (key, _) ->
!HttpHeaders.isUnsafe(key)
} // setting unsafe headers results in exception
.filterNot { (key, _) -> filterOut.any { it.equals(key, true) } }
.forEach { (key, values) -> values.forEach { value -> this.appendIfAbsent(key, value) } }

internal fun HeadersBuilder.copyFrom(headers: Headers) =
Expand Down

0 comments on commit 39cd65b

Please sign in to comment.