Skip to content

Commit

Permalink
added a check to order AI and prompt content in memory (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
Montagon authored Sep 22, 2023
1 parent db898c1 commit d2a6114
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ interface Chat : LLM {

@AiDsl
suspend fun promptMessages(prompt: Prompt, scope: Conversation): List<String> {
val requestedMemories = prompt.messages.toMemory(this@Chat, scope)
val promptMemories = prompt.messages.toMemory(this@Chat, scope)
val adaptedPrompt = PromptCalculator.adaptPromptToConversationAndModel(prompt, scope, this@Chat)

val request =
Expand All @@ -60,7 +60,7 @@ interface Chat : LLM {

return createChatCompletion(request)
.choices
.addMessagesToMemory(this@Chat, scope, requestedMemories)
.addMessagesToMemory(this@Chat, scope, promptMemories)
.mapNotNull { it.message?.content }
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.xebia.functional.xef.llm

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.chat.Choice
import com.xebia.functional.xef.llm.models.chat.ChoiceWithFunctions
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.store.ConversationId
import com.xebia.functional.xef.store.Memory
import io.ktor.util.date.*

Expand All @@ -12,20 +16,18 @@ internal suspend fun List<Message>.addToMemory(chat: LLM, scope: Conversation) {
}
}

/**
 * Wraps this [Message] in a [Memory] for the given conversation.
 *
 * @param cid the conversation this memory belongs to
 * @param chat the model used to estimate the token footprint of the message
 * @param delta offset added to the wall-clock timestamp so that messages stored
 *   in the same batch get unique, strictly increasing timestamps (preserves order)
 */
internal fun Message.toMemory(cid: ConversationId, chat: LLM, delta: Int = 0): Memory {
  val approxTokenCount = chat.tokensFromMessages(listOf(this))
  return Memory(
    conversationId = cid,
    content = this,
    timestamp = getTimeMillis() + delta,
    approxTokens = approxTokenCount
  )
}

/**
 * Converts every message in this list into a [Memory] for the conversation in [scope].
 *
 * Returns an empty list when the scope has no conversation id (nothing to attach to).
 */
internal fun List<Message>.toMemory(chat: LLM, scope: Conversation): List<Memory> {
  val cid = scope.conversationId
  return if (cid != null) {
    // The element index is passed as a timestamp delta so that every message
    // receives a unique, strictly increasing timestamp. This guarantees the
    // order of the messages is preserved when they are read back from the store.
    mapIndexed { index, it -> it.toMemory(cid, chat, index) }
  } else emptyList()
}

Expand All @@ -47,7 +49,15 @@ internal suspend fun List<ChoiceWithFunctions>.addMessagesToMemory(
name = role.name
)

val newMessages = previousMemories + listOf(firstChoiceMessage).toMemory(chat, scope)
// Temporary solution to avoid duplicate timestamps when calling the AI.
val aiMemory =
firstChoiceMessage.toMemory(cid, chat).let {
if (previousMemories.isNotEmpty() && previousMemories.last().timestamp >= it.timestamp) {
it.copy(timestamp = previousMemories.last().timestamp + 1)
} else it
}

val newMessages = previousMemories + aiMemory
scope.store.addMemories(newMessages)
}
}
Expand All @@ -65,7 +75,15 @@ internal suspend fun List<Choice>.addMessagesToMemory(
val firstChoiceMessage =
Message(role = role, content = firstChoice.message?.content ?: "", name = role.name)

val newMessages = previousMemories + listOf(firstChoiceMessage).toMemory(chat, scope)
// Temporary solution to avoid duplicate timestamps when calling the AI.
val aiMemory =
firstChoiceMessage.toMemory(cid, chat).let {
if (previousMemories.isNotEmpty() && previousMemories.last().timestamp >= it.timestamp) {
it.copy(timestamp = previousMemories.last().timestamp + 1)
} else it
}

val newMessages = previousMemories + aiMemory
scope.store.addMemories(newMessages)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.xebia.functional.xef.llm.models.chat.*
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.usage.Usage
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow

class TestModel(
Expand All @@ -22,7 +21,6 @@ class TestModel(
request: ChatCompletionRequest
): ChatCompletionResponse {
requests.add(request)
delay(100) // Simulating an AI's response delay
return ChatCompletionResponse(
id = "fake-id",
`object` = "fake-object",
Expand Down

0 comments on commit d2a6114

Please sign in to comment.