fix incorrect impl of flow (streaming gen.) for local LLMs (#381)
* improve and fix local llm generation as flow

* change reduce to reduceNotNull to avoid exception

* addressing comment regarding thread instantiation

---------

Co-authored-by: Raúl Raja Martínez <raulraja@gmail.com>
Intex32 and raulraja authored Sep 6, 2023
1 parent 4a8d3d0 commit d171ed0
Showing 2 changed files with 43 additions and 59 deletions.
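The second bullet in the commit message refers to Kotlin's empty-collection semantics: `reduce` throws `UnsupportedOperationException` when there is nothing to fold, which can happen when a streamed chunk carries no content deltas. The stdlib operator the diff actually adopts is `reduceOrNull`, which returns null instead and composes with `mapNotNull`. A minimal sketch of the difference (the `deltas` value is illustrative, not from the diff):

```kotlin
fun main() {
  val deltas = emptyList<String>() // e.g. a chunk whose choices carry no content

  // deltas.reduce(String::plus)  // throws UnsupportedOperationException
  val joined: String? = deltas.reduceOrNull(String::plus) // returns null instead

  println(joined) // null, which mapNotNull can then filter out of the flow
}
```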
20 changes: 6 additions & 14 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -7,10 +7,7 @@ import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.prompt.templates.assistant
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.flow.onCompletion
-import kotlinx.coroutines.flow.onEach
+import kotlinx.coroutines.flow.*

interface Chat : LLM {
val modelType: ModelType
@@ -37,19 +34,14 @@ interface Chat : LLM {
streamToStandardOut = true
)

-      val buffer = StringBuilder()
       createChatCompletions(request)
-        .onEach {
-          it.choices.forEach { choice ->
-            val text = choice.delta?.content ?: ""
-            buffer.append(text)
-          }
-        }
-        .onCompletion {
-          val message = assistant(buffer.toString())
+        .mapNotNull { it.choices.mapNotNull { it.delta?.content }.reduceOrNull(String::plus) }
+        .onEach { emit(it) }
+        .fold("", String::plus)
+        .also { finalText ->
+          val message = assistant(finalText)
           MemoryManagement.addMemoriesAfterStream(this@Chat, request, scope, listOf(message))
         }
-        .collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) }
}

@AiDsl
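The reworked streaming body above re-emits each delta while folding the same stream into the final reply, so the memory write happens once with the complete text instead of accumulating into a shared `StringBuilder`. A self-contained sketch of that operator chain, with `deltas()` standing in for `createChatCompletions` and a `println` standing in for `MemoryManagement.addMemoriesAfterStream` (both stand-ins are hypothetical):

```kotlin
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Stand-in for createChatCompletions: a finite stream of content deltas.
fun deltas(): Flow<String> = flowOf("Hel", "lo", ", ", "wor", "ld")

fun answer(): Flow<String> = flow {
  deltas()
    .onEach { emit(it) }    // re-emit every delta to the downstream collector
    .fold("", String::plus) // fold() is terminal: it drives the stream to completion
    .also { finalText ->    // runs exactly once, with the fully accumulated reply
      println("\nstored: $finalText") // stand-in for the memory write
    }
}

fun main() = runBlocking {
  answer().collect(::print) // streams "Hello, world", then prints the stored line
}
```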
82 changes: 37 additions & 45 deletions gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
@@ -19,18 +19,13 @@ import com.xebia.functional.xef.store.VectorStore
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
-import kotlinx.coroutines.coroutineScope
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.consumeAsFlow
-import kotlinx.coroutines.flow.map
-import kotlinx.coroutines.flow.onCompletion
-import kotlinx.coroutines.launch
+import kotlinx.coroutines.flow.*
import java.io.OutputStream
import java.io.PrintStream
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Path
-import java.util.UUID
+import java.util.*
import kotlin.io.path.name


@@ -73,7 +68,6 @@ interface GPT4All : AutoCloseable, Chat, Completion {

override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
with(request) {

val config = LLModel.config()
.withTopP(request.topP?.toFloat() ?: 0.4f)
.withTemp(request.temperature?.toFloat() ?: 0f)
@@ -94,8 +88,8 @@ interface GPT4All : AutoCloseable, Chat, Completion {
with(request) {
val prompt: String = messages.buildPrompt()
val config = LLModel.config()
-        .withTopP(request.topP.toFloat() ?: 0.4f)
-        .withTemp(request.temperature.toFloat() ?: 0f)
+        .withTopP(request.topP.toFloat())
+        .withTemp(request.temperature.toFloat())
.withRepeatPenalty(request.frequencyPenalty.toFloat())
.build()
val response: String = generateCompletion(prompt, config, request.streamToStandardOut)
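The two-line change above drops an Elvis operator that could never fire: unlike `CompletionRequest`, whose `topP` appears nullable (hence the `?.toFloat() ?: 0.4f` earlier in the file), the chat request's `topP` and `temperature` look non-nullable, so `toFloat()` cannot return null and the default branch is dead code. A sketch of the distinction, with the nullability assumptions inferred from the diff rather than stated in it:

```kotlin
fun main() {
  // CompletionRequest-style: a nullable knob, so the Elvis default is meaningful.
  val topP: Double? = null
  println(topP?.toFloat() ?: 0.4f) // 0.4

  // ChatCompletionRequest-style: non-nullable, so an Elvis branch is unreachable.
  val chatTopP: Double = 1.0
  println(chatTopP.toFloat()) // 1.0 -- a trailing `?: 0.4f` here would never be taken
}
```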
@@ -117,51 +111,49 @@ interface GPT4All : AutoCloseable, Chat, Completion {
* @param request The ChatCompletionRequest containing the necessary information for creating completions.
* @return A Flow of ChatCompletionChunk objects representing the generated chat completions.
*/
-  override suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> =
-    with(request) {
-      val prompt: String = messages.buildPrompt()
-      val config = LLModel.config()
-        .withTopP(request.topP.toFloat())
-        .withTemp(request.temperature.toFloat())
-        .withRepeatPenalty(request.frequencyPenalty.toFloat())
+  override suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
+    val prompt: String = request.messages.buildPrompt()
+    val config = with(request) {
+      LLModel.config()
+        .withTopP(topP.toFloat())
+        .withTemp(temperature.toFloat())
+        .withRepeatPenalty(frequencyPenalty.toFloat())
         .build()
+    }

-    val originalOut = System.out // Save the original standard output
-
-    return coroutineScope {
-      val channel = Channel<String>(capacity = UNLIMITED)
-
-      val outputStream = object : OutputStream() {
-        override fun write(b: Int) {
-          val c = b.toChar()
-          channel.trySend(c.toString())
-        }
-      }
-
-      val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
+    val channel = Channel<String>(capacity = UNLIMITED)
+    val outputStream = object : OutputStream() {
+      override fun write(b: Int) {
+        val c = b.toChar()
+        channel.trySend(c.toString())
+      }
+    }

-      fun toChunk(text: String?): ChatCompletionChunk =
-        ChatCompletionChunk(
-          UUID.randomUUID().toString(),
-          System.currentTimeMillis().toInt(),
-          path.name,
-          listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
-          Usage.ZERO,
-        )
+    val originalOut = System.out // Save the original standard output
+
-      val flow = channel.consumeAsFlow().map { toChunk(it) }
+    fun toChunk(text: String?): ChatCompletionChunk = ChatCompletionChunk(
+      UUID.randomUUID().toString(),
+      System.currentTimeMillis().toInt(),
+      path.name,
+      listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
+      Usage.ZERO,
+    )

-      launch(Dispatchers.IO) {
+    return merge(
+      emptyFlow<ChatCompletionChunk>()
+        .onStart {
+          val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
          System.setOut(printStream) // Set the standard output to the print stream
          generateCompletion(prompt, config, request.streamToStandardOut)
          channel.close()
        }

-      flow.onCompletion {
-        System.setOut(originalOut) // Restore the original standard output
-      }
-    }
-  }
+        .flowOn(Dispatchers.IO),
+      channel
+        .consumeAsFlow()
+        .map(::toChunk)
+    )
}

override fun tokensFromMessages(messages: List<Message>): Int {
return 0
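The new `createChatCompletions` above turns a push-to-stdout native binding into a `Flow`: stdout is redirected to an `OutputStream` that forwards bytes into a `Channel`, the blocking generation runs inside `onStart` of an empty flow shifted to `Dispatchers.IO` (replacing the earlier `coroutineScope`/`launch` thread handling the review comment mentioned), and `merge` surfaces the channel's contents as emissions. A runnable sketch of the same pattern; `blockingGenerate` and `tokenFlow` are hypothetical stand-ins for the GPT4All binding, and the `finally`-based stdout restore is this sketch's own simplification:

```kotlin
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking
import java.io.OutputStream
import java.io.PrintStream
import java.nio.charset.StandardCharsets

// Hypothetical stand-in for the native binding: it can only "stream" by printing.
fun blockingGenerate() = listOf("one ", "two ", "three").forEach(::print)

fun tokenFlow(): Flow<String> {
  val channel = Channel<String>(capacity = UNLIMITED)
  // Forward every byte written to stdout into the channel (byte-per-byte, as in the diff).
  val sink = object : OutputStream() {
    override fun write(b: Int) {
      channel.trySend(b.toChar().toString())
    }
  }
  return merge(
    // An empty flow whose onStart hook drives the blocking producer off the caller's thread.
    emptyFlow<String>()
      .onStart {
        val originalOut = System.out
        System.setOut(PrintStream(sink, true, StandardCharsets.UTF_8))
        try {
          blockingGenerate()
        } finally {
          System.setOut(originalOut) // restore stdout before anything else prints
          channel.close()            // lets the consuming flow complete
        }
      }
      .flowOn(Dispatchers.IO),
    channel.consumeAsFlow() // consumer side: captured output becomes emissions
  )
}

fun main() = runBlocking {
  val reply = tokenFlow().toList().joinToString("")
  println(reply) // one two three
}
```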
