Memory updated #467

Merged: 4 commits, Oct 2, 2023
@@ -36,7 +36,7 @@ interface Chat : LLM {
       .also { finalText ->
         val aiResponseMessage = assistant(finalText)
         val newMessages = prompt.messages + listOf(aiResponseMessage)
-        newMessages.addToMemory(this@Chat, scope)
+        newMessages.addToMemory(scope)
       }
   }

@@ -46,7 +46,7 @@ interface Chat : LLM {

   @AiDsl
   suspend fun promptMessages(prompt: Prompt, scope: Conversation): List<String> {
-    val promptMemories = prompt.messages.toMemory(this@Chat, scope)
+    val promptMemories = prompt.messages.toMemory(scope)
     val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

     val request =
@@ -60,7 +60,7 @@ interface Chat : LLM {

     return createChatCompletion(request)
       .choices
-      .addMessagesToMemory(this@Chat, scope, promptMemories)
+      .addChoiceToMemory(scope, promptMemories)
       .mapNotNull { it.message?.content }
   }
 }
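The practical effect of these hunks is that persisting messages no longer needs the model itself: addToMemory and toMemory reach the store through scope.store, and the response-side helper is now named for what it stores (addChoiceToMemory). A minimal sketch of the new call shape from inside the llm module (the surrounding persistTurn function and its parameters are hypothetical; only addToMemory(scope) comes from this PR):

    // Sketch only: assumes xef's Conversation, Prompt and Message types are on the classpath.
    // addToMemory is internal to the llm module, so a helper like this would live next to Chat.
    internal suspend fun persistTurn(scope: Conversation, prompt: Prompt, answer: Message) {
      val turn = prompt.messages + listOf(answer)
      // Before: turn.addToMemory(this@Chat, scope); the LLM was needed to precompute token counts.
      // After: only the Conversation is needed; ordering comes from the store's index counter.
      turn.addToMemory(scope)
    }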
@@ -88,10 +88,10 @@ interface ChatWithFunctions : LLM {
         serializer,
         promptWithFunctions.configuration.maxDeserializationAttempts
       ) {
-        val requestedMemories = prompt.messages.toMemory(this@ChatWithFunctions, scope)
+        val requestedMemories = prompt.messages.toMemory(scope)
         createChatCompletionWithFunctions(request)
           .choices
-          .addMessagesToMemory(this@ChatWithFunctions, scope, requestedMemories)
+          .addChoiceWithFunctionsToMemory(scope, requestedMemories)
           .mapNotNull { it.message?.functionCall?.arguments }
       }
   }
@@ -4,85 +4,51 @@ import com.xebia.functional.xef.conversation.Conversation
 import com.xebia.functional.xef.llm.models.chat.Choice
 import com.xebia.functional.xef.llm.models.chat.ChoiceWithFunctions
 import com.xebia.functional.xef.llm.models.chat.Message
-import com.xebia.functional.xef.llm.models.chat.Role
 import com.xebia.functional.xef.store.ConversationId
 import com.xebia.functional.xef.store.Memory
-import io.ktor.util.date.*

-internal suspend fun List<Message>.addToMemory(chat: LLM, scope: Conversation) {
-  val memories = toMemory(chat, scope)
-  if (memories.isNotEmpty()) {
-    scope.store.addMemories(memories)
+internal suspend fun List<Message>.addToMemory(scope: Conversation) {
+  val cid = scope.conversationId
+  if (cid != null) {
+    val memories = toMemory(scope)
+    if (memories.isNotEmpty()) {
+      scope.store.addMemories(memories)
+    }
   }
 }

-internal fun Message.toMemory(cid: ConversationId, chat: LLM, delta: Int = 0): Memory =
-  Memory(
-    conversationId = cid,
-    content = this,
-    timestamp = getTimeMillis() + delta,
-    approxTokens = chat.tokensFromMessages(listOf(this))
-  )
+internal fun Message.toMemory(cid: ConversationId, index: Int): Memory =
+  Memory(conversationId = cid, content = this, index = index)

-internal fun List<Message>.toMemory(chat: LLM, scope: Conversation): List<Memory> {
+internal fun List<Message>.toMemory(scope: Conversation): List<Memory> {
   val cid = scope.conversationId
   return if (cid != null) {
-    mapIndexed { index, it -> it.toMemory(cid, chat, index) }
+    mapIndexed { index, it -> it.toMemory(cid, scope.store.incrementIndexAndGet()) }
   } else emptyList()
 }
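A small observation on the hunk above: the lambda passed to mapIndexed no longer uses its index parameter, because ordering now comes from scope.store.incrementIndexAndGet() rather than from the position in the list, so a plain map is equivalent. A sketch of that simplification (same behaviour, not part of this PR):

    internal fun List<Message>.toMemory(scope: Conversation): List<Memory> {
      val cid = scope.conversationId
      return if (cid != null) {
        // map is enough: each element just asks the store for the next index.
        map { it.toMemory(cid, scope.store.incrementIndexAndGet()) }
      } else emptyList()
    }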

-internal suspend fun List<ChoiceWithFunctions>.addMessagesToMemory(
-  chat: LLM,
+internal suspend fun List<ChoiceWithFunctions>.addChoiceWithFunctionsToMemory(
   scope: Conversation,
   previousMemories: List<Memory>
 ): List<ChoiceWithFunctions> = also {
-  val firstChoice = firstOrNull()
   val cid = scope.conversationId
-  if (firstChoice != null && cid != null) {
-    val role = firstChoice.message?.role?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER
-
-    val firstChoiceMessage =
-      Message(
-        role = role,
-        content = firstChoice.message?.content
-            ?: firstChoice.message?.functionCall?.arguments ?: "",
-        name = role.name
-      )
-
-    // Temporary solution to avoid duplicate timestamps when calling the AI.
+  if (isNotEmpty() && cid != null) {
     val aiMemory =
-      firstChoiceMessage.toMemory(cid, chat).let {
-        if (previousMemories.isNotEmpty() && previousMemories.last().timestamp >= it.timestamp) {
-          it.copy(timestamp = previousMemories.last().timestamp + 1)
-        } else it
-      }
-
+      this.mapNotNull { it.message }
+        .map { it.toMessage().toMemory(cid, scope.store.incrementIndexAndGet()) }
     val newMessages = previousMemories + aiMemory
     scope.store.addMemories(newMessages)
   }
 }

-internal suspend fun List<Choice>.addMessagesToMemory(
-  chat: Chat,
+internal suspend fun List<Choice>.addChoiceToMemory(
   scope: Conversation,
   previousMemories: List<Memory>
 ): List<Choice> = also {
-  val firstChoice = firstOrNull()
   val cid = scope.conversationId
-  if (firstChoice != null && cid != null) {
-    val role = firstChoice.message?.role?.name?.uppercase()?.let { Role.valueOf(it) } ?: Role.USER
-
-    val firstChoiceMessage =
-      Message(role = role, content = firstChoice.message?.content ?: "", name = role.name)
-
-    // Temporary solution to avoid duplicate timestamps when calling the AI.
+  if (isNotEmpty() && cid != null) {
     val aiMemory =
-      firstChoiceMessage.toMemory(cid, chat).let {
-        if (previousMemories.isNotEmpty() && previousMemories.last().timestamp >= it.timestamp) {
-          it.copy(timestamp = previousMemories.last().timestamp + 1)
-        } else it
-      }
-
+      this.mapNotNull { it.message }.map { it.toMemory(cid, scope.store.incrementIndexAndGet()) }
     val newMessages = previousMemories + aiMemory
     scope.store.addMemories(newMessages)
   }
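Besides switching to the index counter, the two renamed helpers change behaviour in one visible way: the old addMessagesToMemory stored a single Memory built from firstOrNull(), while addChoiceWithFunctionsToMemory and addChoiceToMemory persist one Memory per returned choice via mapNotNull { it.message }. The "temporary solution to avoid duplicate timestamps" also disappears, since a strictly increasing index cannot collide. A rough usage sketch (storeExchange is hypothetical; the helpers are internal to the llm module):

    // Prompt messages get indices first, then each choice, so memories interleave in order.
    internal suspend fun storeExchange(scope: Conversation, prompt: Prompt, choices: List<Choice>) {
      val promptMemories = prompt.messages.toMemory(scope)   // e.g. indices 1..n
      choices.addChoiceToMemory(scope, promptMemories)        // choice memories get n+1, n+2, ...
    }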
@@ -13,11 +13,11 @@ internal object PromptCalculator {
   suspend fun adaptPromptToConversationAndModel(
     prompt: Prompt,
     scope: Conversation,
-    chat: LLM
+    llm: LLM
   ): Prompt {

     // calculate tokens for history and context
-    val remainingTokensForContexts = calculateRemainingTokensForContext(chat, prompt)
+    val remainingTokensForContexts = calculateRemainingTokensForContext(llm, prompt)

     val maxHistoryTokens = calculateMaxHistoryTokens(prompt, remainingTokensForContexts)

@@ -26,9 +26,12 @@
     // calculate messages for history based on tokens

     val memories: List<Memory> =
-      scope.memories(maxHistoryTokens + prompt.configuration.messagePolicy.historyPaddingTokens)
+      scope.memories(
+        llm,
+        maxHistoryTokens + prompt.configuration.messagePolicy.historyPaddingTokens
+      )

-    val historyAllowed = calculateMessagesFromHistory(chat, memories, maxHistoryTokens)
+    val historyAllowed = calculateMessagesFromHistory(llm, memories, maxHistoryTokens)

     // calculate messages for context based on tokens
     val ctxInfo =
@@ -41,7 +44,7 @@
       if (ctxInfo.isNotEmpty()) {
         val ctx: String = ctxInfo.joinToString("\n")

-        val ctxTruncated: String = chat.modelType.encoding.truncateText(ctx, maxContextTokens)
+        val ctxTruncated: String = llm.modelType.encoding.truncateText(ctx, maxContextTokens)

         Prompt { +assistant(ctxTruncated) }.messages
       } else {
@@ -55,7 +58,7 @@
     memories.map { it.content }

   private fun calculateMessagesFromHistory(
-    chat: LLM,
+    llm: LLM,
     memories: List<Memory>,
     maxHistoryTokens: Int
   ) =
@@ -64,11 +67,10 @@

       // since we have the approximate tokens in memory, we need to fit the messages back to the
       // number of tokens if necessary
-      val historyTokens = chat.tokensFromMessages(history)
+      val historyTokens = llm.tokensFromMessages(history)
       if (historyTokens <= maxHistoryTokens) history
       else {
-        val historyMessagesWithTokens =
-          history.map { Pair(it, chat.tokensFromMessages(listOf(it))) }
+        val historyMessagesWithTokens = history.map { Pair(it, llm.tokensFromMessages(listOf(it))) }

         val totalTokenWithMessages =
           historyMessagesWithTokens.foldRight(Pair(0, emptyList<Message>())) { pair, acc ->
@@ -94,11 +96,11 @@
     return maxHistoryTokens
   }

-  private fun calculateRemainingTokensForContext(chat: LLM, prompt: Prompt): Int {
-    val maxContextLength: Int = chat.modelType.maxContextLength
+  private fun calculateRemainingTokensForContext(llm: LLM, prompt: Prompt): Int {
+    val maxContextLength: Int = llm.modelType.maxContextLength
     val remainingTokens: Int = maxContextLength - prompt.configuration.minResponseTokens

-    val messagesTokens = chat.tokensFromMessages(prompt.messages)
+    val messagesTokens = llm.tokensFromMessages(prompt.messages)

     if (messagesTokens >= remainingTokens) {
       throw AIError.PromptExceedsMaxRemainingTokenLength(messagesTokens, remainingTokens)
@@ -108,10 +110,10 @@
     return remainingTokensForContexts
   }

-  private suspend fun Conversation.memories(limitTokens: Int): List<Memory> {
+  private suspend fun Conversation.memories(llm: LLM, limitTokens: Int): List<Memory> {
     val cid = conversationId
     return if (cid != null) {
-      store.memories(cid, limitTokens)
+      store.memories(llm, cid, limitTokens)
     } else {
       emptyList()
     }
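For intuition on the budget this object manages, a small worked example using only the quantities visible above (all numbers invented; the final subtraction is an assumption, since the line that computes remainingTokensForContexts is collapsed in this view):

    fun budgetExample() {
      // Invented numbers, for illustration only.
      val maxContextLength = 4_096   // llm.modelType.maxContextLength
      val minResponseTokens = 500    // prompt.configuration.minResponseTokens
      val messagesTokens = 600       // llm.tokensFromMessages(prompt.messages)

      val remainingTokens = maxContextLength - minResponseTokens   // 3_596
      // messagesTokens >= remainingTokens would throw PromptExceedsMaxRemainingTokenLength.
      // Assumption: the collapsed line subtracts the prompt's own tokens, leaving roughly
      // 3_596 - 600 = 2_996 tokens to split between injected context and history.
    }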
@@ -65,7 +65,7 @@ sealed class StreamedFunction<out A> {
         .createChatCompletionsWithFunctions(request)
         .onCompletion {
           val newMessages = promptMessages + messages
-          newMessages.addToMemory(chat, scope)
+          newMessages.addToMemory(scope)
         }
         .collect { responseChunk ->
           // Each chunk is emitted from the LLM and it will include a delta.parameters with
@@ -7,4 +7,13 @@ data class MessageWithFunctionCall(
   val content: String? = null,
   val functionCall: FunctionCall?,
   val name: String? = Role.ASSISTANT.name
-)
+) {
+  fun toMessage(): Message {
+    val role = role.uppercase().let { Role.valueOf(it) } // TODO valueOf is unsafe
+    return Message(
+      role = role,
+      content = content ?: functionCall?.arguments ?: "",
+      name = role.name
+    )
+  }
+}
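The TODO above is worth keeping in mind: Role.valueOf throws for any role string the enum does not contain. If that ever needs hardening, one possible shape is below (a sketch, not part of this PR; it assumes Role is the enum already referenced here, with an ASSISTANT entry, and that role is a String property of this class):

    fun MessageWithFunctionCall.toMessageSafe(): Message {
      // Fall back to ASSISTANT instead of throwing on an unexpected role string.
      val parsedRole =
        Role.values().firstOrNull { it.name.equals(role, ignoreCase = true) } ?: Role.ASSISTANT
      return Message(
        role = parsedRole,
        content = content ?: functionCall?.arguments ?: "",
        name = parsedRole.name
      )
    }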
@@ -1,5 +1,6 @@
 package com.xebia.functional.xef.store

+import com.xebia.functional.xef.llm.LLM
 import com.xebia.functional.xef.llm.models.embeddings.Embedding

 /**
@@ -11,13 +12,17 @@ import com.xebia.functional.xef.llm.models.embeddings.Embedding
 class CombinedVectorStore(private val top: VectorStore, private val bottom: VectorStore) :
   VectorStore by top {

-  override suspend fun memories(conversationId: ConversationId, limitTokens: Int): List<Memory> {
-    val bottomResults = bottom.memories(conversationId, limitTokens)
-    val topResults = top.memories(conversationId, limitTokens)
+  override suspend fun memories(
+    llm: LLM,
+    conversationId: ConversationId,
+    limitTokens: Int
+  ): List<Memory> {
+    val bottomResults = bottom.memories(llm, conversationId, limitTokens)
+    val topResults = top.memories(llm, conversationId, limitTokens)

     return (topResults + bottomResults)
-      .sortedByDescending { it.timestamp }
-      .reduceByLimitToken(limitTokens)
+      .sortedByDescending { it.index }
+      .reduceByLimitToken(llm, limitTokens)
       .reversed()
   }

@@ -1,9 +1,11 @@
 package com.xebia.functional.xef.store

 import arrow.atomic.Atomic
+import arrow.atomic.AtomicInt
 import arrow.atomic.getAndUpdate
 import arrow.atomic.update
 import com.xebia.functional.xef.llm.Embeddings
+import com.xebia.functional.xef.llm.LLM
 import com.xebia.functional.xef.llm.models.embeddings.Embedding
 import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
 import kotlin.math.sqrt
@@ -27,6 +29,14 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi

   private val requestConfig = RequestConfig(RequestConfig.Companion.User("user"))

+  override val indexValue: AtomicInt = AtomicInt(0)
+
+  override fun updateIndexByConversationId(conversationId: ConversationId) {
+    state.get().orderedMemories[conversationId]?.let { memories ->
+      memories.maxByOrNull { it.index }?.let { lastMemory -> indexValue.set(lastMemory.index) }
+    }
+  }
+
   override suspend fun addMemories(memories: List<Memory>) {
     state.update { prevState ->
       prevState.copy(
@@ -44,12 +54,16 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
     }
   }

-  override suspend fun memories(conversationId: ConversationId, limitTokens: Int): List<Memory> {
+  override suspend fun memories(
+    llm: LLM,
+    conversationId: ConversationId,
+    limitTokens: Int
+  ): List<Memory> {
     val memories = state.get().orderedMemories[conversationId]
     return memories
       .orEmpty()
-      .sortedByDescending { it.timestamp }
-      .reduceByLimitToken(limitTokens)
+      .sortedByDescending { it.index }
+      .reduceByLimitToken(llm, limitTokens)
       .reversed()
   }
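incrementIndexAndGet is called throughout this PR but its definition is not visible in this diff; presumably it sits on the VectorStore interface next to the indexValue and updateIndexByConversationId members overridden here. A sketch of what such a member could look like (an assumption, not code from this PR; the interface name is hypothetical):

    // Hypothetical shape, consistent with the overrides shown above (indexValue is arrow.atomic.AtomicInt).
    interface IndexedStore {
      val indexValue: AtomicInt

      // Atomically bump the shared counter and return the new value, so two concurrent
      // writers can never be handed the same index.
      fun incrementIndexAndGet(): Int = indexValue.incrementAndGet()

      // Re-seed the counter from the highest index already stored for a conversation.
      fun updateIndexByConversationId(conversationId: ConversationId)
    }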

@@ -7,11 +7,6 @@ import com.xebia.functional.xef.llm.models.chat.Message
  *
  * @property content message sent.
  * @property conversationId uniquely identifies the conversation in which the message took place.
- * @property timestamp in milliseconds.
+ * @property index autoincrement index.
  */
-data class Memory(
-  val conversationId: ConversationId,
-  val content: Message,
-  val timestamp: Long,
-  val approxTokens: Int
-)
+data class Memory(val conversationId: ConversationId, val content: Message, val index: Int)
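Memory therefore shrinks from four fields to three: ordering moves from a wall-clock timestamp to the store-assigned index, and token counts are no longer cached on the record but recomputed from content when trimming (see reduceByLimitToken below). A tiny sketch of the new shape (the id string and the two Message values are invented, and ConversationId is assumed to wrap a plain String):

    val cid = ConversationId("demo-conversation")   // invented id
    val first = Memory(conversationId = cid, content = userMessage, index = 1)
    val second = Memory(conversationId = cid, content = assistantMessage, index = 2)
    // Stores now sort by index, newest first, instead of by timestamp:
    val newestFirst = listOf(first, second).sortedByDescending { it.index }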
@@ -1,12 +1,19 @@
 package com.xebia.functional.xef.store

-fun List<Memory>.reduceByLimitToken(limitTokens: Int): List<Memory> =
-  fold(Pair(0, emptyList<Memory>())) { (accTokens, list), memory ->
-      val totalTokens = accTokens + memory.approxTokens
-      if (totalTokens <= limitTokens) {
-        Pair(totalTokens, list + memory)
-      } else {
-        Pair(accTokens, list)
-      }
-    }
-    .second
+import com.xebia.functional.xef.llm.LLM
+
+fun List<Memory>.reduceByLimitToken(llm: LLM, limitTokens: Int): List<Memory> {
+  val tokensFromMessages = llm.tokensFromMessages(map { it.content })
+  return if (tokensFromMessages <= limitTokens) this
+  else
+    fold(Pair(0, emptyList<Memory>())) { (accTokens, list), memory ->
+        val tokensFromMessage = llm.tokensFromMessages(listOf(memory.content))
+        val totalTokens = accTokens + tokensFromMessage
+        if (totalTokens <= limitTokens) {
+          Pair(totalTokens, list + memory)
+        } else {
+          Pair(accTokens, list)
+        }
+      }
+      .second
+}
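Because Memory no longer carries approxTokens, trimming to a token budget now needs the LLM to count tokens from content, which is why the llm parameter appears here and ripples through the store interfaces above. A usage sketch (llm and memories are assumed to exist; the 1_000-token limit is invented):

    // Both stores call this on an index-descending (newest-first) list and reverse afterwards.
    val newestFirst = memories.sortedByDescending { it.index }
    val trimmed = newestFirst.reduceByLimitToken(llm, 1_000).reversed()
    // trimmed holds the most recent memories whose combined token count fits the limit,
    // back in chronological order, ready to be replayed as history.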