fix incorrect impl of flow (streaming gen.) for local LLMs #381

Merged (7 commits) on Sep 6, 2023
20 changes: 6 additions & 14 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -7,10 +7,7 @@ import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.*

interface Chat : LLM {
val modelType: ModelType
@@ -37,19 +34,14 @@ interface Chat : LLM {
streamToStandardOut = true
)

val buffer = StringBuilder()
createChatCompletions(request)
.onEach {
it.choices.forEach { choice ->
val text = choice.delta?.content ?: ""
buffer.append(text)
}
}
.onCompletion {
val message = assistant(buffer.toString())
.mapNotNull { it.choices.mapNotNull { it.delta?.content }.reduceOrNull(String::plus) }
.onEach { emit(it) }
.fold("", String::plus)
Member Author:
Is there a better way to achieve the same thing as the onEach and fold? Currently, each item is emitted to another, outer flow.
This is required because we have to fold the flow to get the full generated string in order to add it to memory. On the other hand, we want to pass each item forward to the caller of the function so that they can do whatever they want with it.
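A minimal, self-contained sketch of the pattern being discussed; the chunks() source and store callback are hypothetical stand-ins for the chat-completion stream and the memory call, not xef APIs:

import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Hypothetical stand-in for the upstream of chat-completion chunks.
fun chunks(): Flow<String> = flow {
    listOf("Hel", "lo ", "wor", "ld").forEach {
        delay(10)
        emit(it)
    }
}

// Forward every chunk to the outer flow while also folding the full text
// so it can be stored once the stream completes.
fun streamAndRemember(store: (String) -> Unit): Flow<String> = flow {
    chunks()
        .onEach { emit(it) }          // pass each item on to the collector
        .fold("", String::plus)       // terminal operator: accumulate the whole text
        .also { full -> store(full) } // e.g. add the message to conversation memory
}

fun main() = runBlocking {
    streamAndRemember { println("\nstored: $it") }
        .collect { print(it) }        // prints the chunks as they arrive
}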

Contributor:
It kind of depends on the desired behavior of the resulting Flow.
onEach will back-pressure the fold, so the resulting Flow can still back-pressure the fold, which is desired.

onEach is the same as map { f(it); it }.

My concern is that fold will keep everything in memory, but I guess that's inherent to MemoryManagement.addMemoriesAfterStream. Just going by the name, it sounds like it puts everything in memory anyway.
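A toy illustration of that equivalence (not xef code): onEach performs a side effect and passes the element through unchanged, which is observably the same as mapping each element to itself.

import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val source = flowOf(1, 2, 3)

    // Side effect, element passed through unchanged.
    val a = source.onEach { println("saw $it") }

    // Same behavior expressed with map.
    val b = source.map { println("saw $it"); it }

    println(a.toList() == b.toList()) // true
}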

Member Author:
Yes, the goal is indeed to get the full string of generated text and store it in the conversation history. At the same time, I want to forward every individual item to the outer flow.
I came to the conclusion that there is no better way.

If you have time, can you explain what you mean by the term back-pressure?

Contributor:
@Intex32 I explain back-pressure a little bit here. https://youtu.be/n724AdSAkOI
Let me know if this clears it up for you; if not, we can discuss it further ☺️
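For anyone skipping the video, a small sketch of what back-pressure means for a cold Flow: emit suspends until the collector has finished processing the value, so a slow consumer naturally slows the producer. The timings are illustrative only.

import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    flow {
        repeat(3) {
            println("producing $it")
            emit(it) // suspends until the collector below has handled the value
        }
    }.collect {
        delay(100) // slow consumer
        println("consumed $it")
    }
}
// Output interleaves strictly: producing 0, consumed 0, producing 1, consumed 1, ...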

Member Author:
@nomisRev I have more questions than before 😂. I especially liked the part where you talked about back pressure.
The pen has some classy dance moves tbh.

.also { finalText ->
val message = assistant(finalText)
MemoryManagement.addMemoriesAfterStream(this@Chat, request, scope, listOf(message))
}
.collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) }
}

@AiDsl
GPT4All.kt
@@ -19,18 +19,13 @@ import com.xebia.functional.xef.store.VectorStore
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.flow.*
import java.io.OutputStream
import java.io.PrintStream
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Path
import java.util.UUID
import java.util.*
import kotlin.io.path.name


@@ -73,7 +68,6 @@ interface GPT4All : AutoCloseable, Chat, Completion {

override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
with(request) {

val config = LLModel.config()
.withTopP(request.topP?.toFloat() ?: 0.4f)
.withTemp(request.temperature?.toFloat() ?: 0f)
@@ -94,8 +88,8 @@ interface GPT4All : AutoCloseable, Chat, Completion {
with(request) {
val prompt: String = messages.buildPrompt()
val config = LLModel.config()
.withTopP(request.topP.toFloat() ?: 0.4f)
.withTemp(request.temperature.toFloat() ?: 0f)
.withTopP(request.topP.toFloat())
.withTemp(request.temperature.toFloat())
.withRepeatPenalty(request.frequencyPenalty.toFloat())
.build()
val response: String = generateCompletion(prompt, config, request.streamToStandardOut)
@@ -117,51 +111,49 @@ interface GPT4All : AutoCloseable, Chat, Completion {
* @param request The ChatCompletionRequest containing the necessary information for creating completions.
* @return A Flow of ChatCompletionChunk objects representing the generated chat completions.
*/
override suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> =
with(request) {
val prompt: String = messages.buildPrompt()
val config = LLModel.config()
.withTopP(request.topP.toFloat())
.withTemp(request.temperature.toFloat())
.withRepeatPenalty(request.frequencyPenalty.toFloat())
override suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
val prompt: String = request.messages.buildPrompt()
val config = with(request) {
LLModel.config()
.withTopP(topP.toFloat())
.withTemp(temperature.toFloat())
.withRepeatPenalty(frequencyPenalty.toFloat())
.build()
}

val originalOut = System.out // Save the original standard output

return coroutineScope {
val channel = Channel<String>(capacity = UNLIMITED)

val outputStream = object : OutputStream() {
override fun write(b: Int) {
val c = b.toChar()
channel.trySend(c.toString())
}
}

val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
val channel = Channel<String>(capacity = UNLIMITED)
val outputStream = object : OutputStream() {
override fun write(b: Int) {
val c = b.toChar()
channel.trySend(c.toString())
}
}

fun toChunk(text: String?): ChatCompletionChunk =
ChatCompletionChunk(
UUID.randomUUID().toString(),
System.currentTimeMillis().toInt(),
path.name,
listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
Usage.ZERO,
)
val originalOut = System.out // Save the original standard output

val flow = channel.consumeAsFlow().map { toChunk(it) }
fun toChunk(text: String?): ChatCompletionChunk = ChatCompletionChunk(
UUID.randomUUID().toString(),
System.currentTimeMillis().toInt(),
path.name,
listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
Usage.ZERO,
)

launch(Dispatchers.IO) {
return merge(
emptyFlow<ChatCompletionChunk>()
.onStart {
val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
System.setOut(printStream) // Set the standard output to the print stream
generateCompletion(prompt, config, request.streamToStandardOut)
channel.close()
}

flow.onCompletion {
System.setOut(originalOut) // Restore the original standard output
}
}
}
.flowOn(Dispatchers.IO),
channel
.consumeAsFlow()
.map(::toChunk)
)
}

override fun tokensFromMessages(messages: List<Message>): Int {
return 0