Skip to content

Commit

Permalink
feat(client): add async streaming methods (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] committed Nov 20, 2024
1 parent 2a1c09d commit c077dca
Show file tree
Hide file tree
Showing 10 changed files with 459 additions and 4 deletions.
3 changes: 3 additions & 0 deletions openai-java-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,7 @@ dependencies {
testImplementation("org.assertj:assertj-core:3.25.3")
testImplementation("org.junit.jupiter:junit-jupiter-api:5.9.3")
testImplementation("org.junit.jupiter:junit-jupiter-params:5.9.3")
testImplementation("org.mockito:mockito-core:5.14.2")
testImplementation("org.mockito:mockito-junit-jupiter:5.14.2")
testImplementation("org.mockito.kotlin:mockito-kotlin:4.1.0")
}
26 changes: 26 additions & 0 deletions openai-java-core/src/main/kotlin/com/openai/core/ClientOptions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ import com.openai.core.http.PhantomReachableClosingHttpClient
import com.openai.core.http.QueryParams
import com.openai.core.http.RetryingHttpClient
import java.time.Clock
import java.util.concurrent.Executor
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicLong

class ClientOptions
private constructor(
private val originalHttpClient: HttpClient,
@get:JvmName("httpClient") val httpClient: HttpClient,
@get:JvmName("jsonMapper") val jsonMapper: JsonMapper,
@get:JvmName("streamHandlerExecutor") val streamHandlerExecutor: Executor,
@get:JvmName("clock") val clock: Clock,
@get:JvmName("baseUrl") val baseUrl: String,
@get:JvmName("headers") val headers: Headers,
Expand All @@ -41,6 +46,7 @@ private constructor(

private var httpClient: HttpClient? = null
private var jsonMapper: JsonMapper = jsonMapper()
private var streamHandlerExecutor: Executor? = null
private var clock: Clock = Clock.systemUTC()
private var baseUrl: String = PRODUCTION_URL
private var headers: Headers.Builder = Headers.builder()
Expand All @@ -55,6 +61,7 @@ private constructor(
internal fun from(clientOptions: ClientOptions) = apply {
httpClient = clientOptions.originalHttpClient
jsonMapper = clientOptions.jsonMapper
streamHandlerExecutor = clientOptions.streamHandlerExecutor
clock = clientOptions.clock
baseUrl = clientOptions.baseUrl
headers = clientOptions.headers.toBuilder()
Expand All @@ -70,6 +77,10 @@ private constructor(

fun jsonMapper(jsonMapper: JsonMapper) = apply { this.jsonMapper = jsonMapper }

/**
 * Sets the [Executor] used to run subscribed stream handlers when no per-subscription
 * executor is supplied. Defaults to a named cached thread pool when unset.
 */
fun streamHandlerExecutor(streamHandlerExecutor: Executor) = apply {
    this.streamHandlerExecutor = streamHandlerExecutor
}

fun clock(clock: Clock) = apply { this.clock = clock }

fun baseUrl(baseUrl: String) = apply { this.baseUrl = baseUrl }
Expand Down Expand Up @@ -205,6 +216,21 @@ private constructor(
.build()
),
jsonMapper,
streamHandlerExecutor
?: Executors.newCachedThreadPool(
object : ThreadFactory {

private val threadFactory: ThreadFactory =
Executors.defaultThreadFactory()
private val count = AtomicLong(0)

override fun newThread(runnable: Runnable): Thread =
threadFactory.newThread(runnable).also {
it.name =
"openai-stream-handler-thread-${count.getAndIncrement()}"
}
}
),
clock,
baseUrl,
headers.build(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,30 @@ internal fun closeWhenPhantomReachable(observed: Any, closeable: AutoCloseable)
check(observed !== closeable) {
"`observed` cannot be the same object as `closeable` because it would never become phantom reachable"
}
closeWhenPhantomReachable?.let { it(observed, closeable::close) }
closeWhenPhantomReachable(observed, closeable::close)
}

private val closeWhenPhantomReachable: ((Any, AutoCloseable) -> Unit)? by lazy {
/**
 * Invokes [close] once [observed] becomes only phantom reachable.
 *
 * Backed by a Java 9+ [java.lang.ref.Cleaner] when one is available at runtime; on older Java
 * versions this is a no-op.
 */
@JvmSynthetic
internal fun closeWhenPhantomReachable(observed: Any, close: () -> Unit) {
    // `closeWhenPhantomReachable` is null when no Cleaner implementation could be loaded.
    val register = closeWhenPhantomReachable
    if (register != null) {
        register(observed, close)
    }
}

private val closeWhenPhantomReachable: ((Any, () -> Unit) -> Unit)? by lazy {
try {
val cleanerClass = Class.forName("java.lang.ref.Cleaner")
val cleanerCreate = cleanerClass.getMethod("create")
val cleanerRegister =
cleanerClass.getMethod("register", Any::class.java, Runnable::class.java)
val cleanerObject = cleanerCreate.invoke(null);

{ observed, closeable ->
{ observed, close ->
try {
cleanerRegister.invoke(cleanerObject, observed, Runnable { closeable.close() })
cleanerRegister.invoke(cleanerObject, observed, Runnable { close() })
} catch (e: ReflectiveOperationException) {
if (e is InvocationTargetException) {
when (val cause = e.cause) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package com.openai.core.http

import com.openai.core.http.AsyncStreamResponse.Handler
import java.util.Optional
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Executor
import java.util.concurrent.atomic.AtomicReference

/**
 * A response whose streamed values are delivered asynchronously to a subscribed [Handler].
 */
interface AsyncStreamResponse<T> {

    /**
     * Registers [handler] to receive the streamed values on a default executor.
     *
     * NOTE(review): the implementation in this file enforces at-most-one subscription per
     * response; other implementations should preserve that contract.
     */
    fun subscribe(handler: Handler<T>): AsyncStreamResponse<T>

    /** Registers [handler] to receive the streamed values, dispatched on [executor]. */
    fun subscribe(handler: Handler<T>, executor: Executor): AsyncStreamResponse<T>

    /**
     * Closes this resource, relinquishing any underlying resources.
     *
     * This is purposefully not inherited from [AutoCloseable] because this response should not be
     * synchronously closed via try-with-resources.
     */
    fun close()

    /** Receives streamed values and a single terminal completion signal. */
    fun interface Handler<in T> {

        /** Called once for each streamed value. */
        fun onNext(value: T)

        /**
         * Called exactly once when the stream finishes; [error] is present if the stream (or the
         * request that produced it) terminated exceptionally.
         */
        fun onComplete(error: Optional<Throwable>) {}
    }
}

/**
 * Adapts a future [StreamResponse] into an [AsyncStreamResponse] that pushes each chunk to a
 * subscribed [Handler], dispatched on [streamHandlerExecutor] unless the subscriber supplies its
 * own executor.
 *
 * The result is wrapped in [PhantomReachableClosingAsyncStreamResponse] so the underlying stream
 * is eventually closed even if the caller drops the reference without calling `close`.
 */
@JvmSynthetic
internal fun <T> CompletableFuture<StreamResponse<T>>.toAsync(streamHandlerExecutor: Executor) =
    PhantomReachableClosingAsyncStreamResponse(
        object : AsyncStreamResponse<T> {

            // Lifecycle guard: NEW -> SUBSCRIBED -> CLOSED. A handler may be attached at most
            // once, and only before the response is closed.
            private val state = AtomicReference(State.NEW)

            override fun subscribe(handler: Handler<T>): AsyncStreamResponse<T> =
                subscribe(handler, streamHandlerExecutor)

            override fun subscribe(
                handler: Handler<T>,
                executor: Executor
            ): AsyncStreamResponse<T> = apply {
                // TODO(JDK): Use `compareAndExchange` once targeting JDK 9.
                check(state.compareAndSet(State.NEW, State.SUBSCRIBED)) {
                    if (state.get() == State.SUBSCRIBED) "Cannot subscribe more than once"
                    else "Cannot subscribe after the response is closed"
                }

                this@toAsync.whenCompleteAsync(
                    { streamResponse, futureError ->
                        if (state.get() == State.CLOSED) {
                            // Avoid doing any work if `close` was called before the future
                            // completed.
                            return@whenCompleteAsync
                        }

                        if (futureError != null) {
                            // An error occurred before we started passing chunks to the handler.
                            handler.onComplete(Optional.of(futureError))
                            return@whenCompleteAsync
                        }

                        // Drain the stream, capturing (not rethrowing) any failure so the
                        // handler still gets a terminal `onComplete` signal.
                        var streamError: Throwable? = null
                        try {
                            streamResponse.stream().forEach(handler::onNext)
                        } catch (e: Throwable) {
                            streamError = e
                        }

                        // Release the underlying response even if `onComplete` itself throws.
                        try {
                            handler.onComplete(Optional.ofNullable(streamError))
                        } finally {
                            close()
                        }
                    },
                    executor
                )
            }

            override fun close() {
                // Idempotent: only the first call transitions to CLOSED and closes the stream.
                val previousState = state.getAndSet(State.CLOSED)
                if (previousState == State.CLOSED) {
                    return
                }

                // The future may still be in flight; close the stream whenever it completes.
                this@toAsync.whenComplete { streamResponse, _ -> streamResponse?.close() }
            }
        }
    )

/** Lifecycle states for the [AsyncStreamResponse] adapter created by `toAsync`. */
private enum class State {
    NEW,
    SUBSCRIBED,
    CLOSED
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.openai.core.http

import com.openai.core.closeWhenPhantomReachable
import com.openai.core.http.AsyncStreamResponse.Handler
import java.util.concurrent.Executor

/**
 * Decorator around an [AsyncStreamResponse] that additionally closes the delegate once this
 * wrapper becomes only phantom reachable, so an unreferenced response does not leak its
 * underlying stream.
 */
internal class PhantomReachableClosingAsyncStreamResponse<T>(
    private val asyncStreamResponse: AsyncStreamResponse<T>
) : AsyncStreamResponse<T> {
    init {
        // Register cleanup against the wrapper (not the delegate) so that simply dropping the
        // wrapper is enough to trigger the close.
        closeWhenPhantomReachable(this, asyncStreamResponse::close)
    }

    override fun subscribe(handler: Handler<T>): AsyncStreamResponse<T> {
        asyncStreamResponse.subscribe(handler)
        return this
    }

    override fun subscribe(handler: Handler<T>, executor: Executor): AsyncStreamResponse<T> {
        asyncStreamResponse.subscribe(handler, executor)
        return this
    }

    override fun close() {
        asyncStreamResponse.close()
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package com.openai.services.async

import com.openai.core.RequestOptions
import com.openai.core.http.AsyncStreamResponse
import com.openai.models.Completion
import com.openai.models.CompletionCreateParams
import java.util.concurrent.CompletableFuture
Expand All @@ -17,4 +18,11 @@ interface CompletionServiceAsync {
params: CompletionCreateParams,
requestOptions: RequestOptions = RequestOptions.none()
): CompletableFuture<Completion>

/**
 * Creates a completion for the provided prompt and parameters, delivered asynchronously as a
 * stream of [Completion] chunks.
 */
@JvmOverloads
fun createStreaming(
    params: CompletionCreateParams,
    requestOptions: RequestOptions = RequestOptions.none()
): AsyncStreamResponse<Completion>
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@
package com.openai.services.async

import com.openai.core.ClientOptions
import com.openai.core.JsonValue
import com.openai.core.RequestOptions
import com.openai.core.handlers.errorHandler
import com.openai.core.handlers.jsonHandler
import com.openai.core.handlers.map
import com.openai.core.handlers.mapJson
import com.openai.core.handlers.sseHandler
import com.openai.core.handlers.withErrorHandler
import com.openai.core.http.AsyncStreamResponse
import com.openai.core.http.HttpMethod
import com.openai.core.http.HttpRequest
import com.openai.core.http.HttpResponse.Handler
import com.openai.core.http.StreamResponse
import com.openai.core.http.toAsync
import com.openai.core.json
import com.openai.errors.OpenAIError
import com.openai.models.Completion
Expand Down Expand Up @@ -52,4 +59,47 @@ constructor(
}
}
}

// Parses the server-sent-event response into a stream of `Completion` chunks and routes error
// responses through `errorHandler`.
private val createStreamingHandler: Handler<StreamResponse<Completion>> =
    sseHandler(clientOptions.jsonMapper).mapJson<Completion>().withErrorHandler(errorHandler)

/**
 * Creates a completion for the provided prompt and parameters, returning the result as an
 * asynchronous stream of [Completion] chunks.
 */
override fun createStreaming(
    params: CompletionCreateParams,
    requestOptions: RequestOptions
): AsyncStreamResponse<Completion> {
    // Force server-sent-event streaming by injecting `"stream": true` into the request body.
    val streamingBody =
        params
            .getBody()
            .toBuilder()
            .putAdditionalProperty("stream", JsonValue.from(true))
            .build()
    val request =
        HttpRequest.builder()
            .method(HttpMethod.POST)
            .addPathSegments("completions")
            .putAllQueryParams(clientOptions.queryParams)
            .replaceAllQueryParams(params.getQueryParams())
            .putAllHeaders(clientOptions.headers)
            .replaceAllHeaders(params.getHeaders())
            .body(json(clientOptions.jsonMapper, streamingBody))
            .build()
    return clientOptions.httpClient
        .executeAsync(request, requestOptions)
        .thenApply { response ->
            val streamResponse = createStreamingHandler.handle(response)
            // Per-request validation setting wins over the client-level default.
            if (requestOptions.responseValidation ?: clientOptions.responseValidation) {
                streamResponse.map { it.validate() }
            } else {
                streamResponse
            }
        }
        .toAsync(clientOptions.streamHandlerExecutor)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
package com.openai.services.async.chat

import com.openai.core.RequestOptions
import com.openai.core.http.AsyncStreamResponse
import com.openai.models.ChatCompletion
import com.openai.models.ChatCompletionChunk
import com.openai.models.ChatCompletionCreateParams
import java.util.concurrent.CompletableFuture

Expand All @@ -22,4 +24,16 @@ interface CompletionServiceAsync {
params: ChatCompletionCreateParams,
requestOptions: RequestOptions = RequestOptions.none()
): CompletableFuture<ChatCompletion>

/**
 * Creates a model response for the given chat conversation. Learn more in the
 * [text generation](https://platform.openai.com/docs/guides/text-generation),
 * [vision](https://platform.openai.com/docs/guides/vision), and
 * [audio](https://platform.openai.com/docs/guides/audio) guides.
 *
 * The response is delivered asynchronously as a stream of [ChatCompletionChunk]s.
 */
@JvmOverloads
fun createStreaming(
    params: ChatCompletionCreateParams,
    requestOptions: RequestOptions = RequestOptions.none()
): AsyncStreamResponse<ChatCompletionChunk>
}
Loading

0 comments on commit c077dca

Please sign in to comment.