Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/redis connection is closed #3591

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import { Redis, RedisOptions } from 'ioredis'
import { isEqual } from 'lodash'
import { BufferMemory, BufferMemoryInput } from 'langchain/memory'
import { RedisChatMessageHistory, RedisChatMessageHistoryInput } from '@langchain/community/stores/message/ioredis'
import { mapStoredMessageToChatMessage, BaseMessage, AIMessage, HumanMessage } from '@langchain/core/messages'
import { INode, INodeData, INodeParams, ICommonObject, MessageType, IMessage, MemoryMethods, FlowiseMemory } from '../../../src/Interface'
import {
Expand All @@ -12,42 +10,6 @@ import {
mapChatMessageToBaseMessage
} from '../../../src/utils'

let redisClientSingleton: Redis
let redisClientOption: RedisOptions
let redisClientUrl: string

const getRedisClientbyOption = (option: RedisOptions) => {
if (!redisClientSingleton) {
// if client doesn't exists
redisClientSingleton = new Redis(option)
redisClientOption = option
return redisClientSingleton
} else if (redisClientSingleton && !isEqual(option, redisClientOption)) {
// if client exists but option changed
redisClientSingleton.quit()
redisClientSingleton = new Redis(option)
redisClientOption = option
return redisClientSingleton
}
return redisClientSingleton
}

const getRedisClientbyUrl = (url: string) => {
if (!redisClientSingleton) {
// if client doesn't exists
redisClientSingleton = new Redis(url)
redisClientUrl = url
return redisClientSingleton
} else if (redisClientSingleton && url !== redisClientUrl) {
// if client exists but option changed
redisClientSingleton.quit()
redisClientSingleton = new Redis(url)
redisClientUrl = url
return redisClientSingleton
}
return redisClientSingleton
}

class RedisBackedChatMemory_Memory implements INode {
label: string
name: string
Expand Down Expand Up @@ -114,11 +76,11 @@ class RedisBackedChatMemory_Memory implements INode {
}

async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
return await initalizeRedis(nodeData, options)
return await initializeRedis(nodeData, options)
}
}

const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Promise<BufferMemory> => {
const initializeRedis = async (nodeData: INodeData, options: ICommonObject): Promise<BufferMemory> => {
const sessionTTL = nodeData.inputs?.sessionTTL as number
const memoryKey = nodeData.inputs?.memoryKey as string
const sessionId = nodeData.inputs?.sessionId as string
Expand All @@ -127,120 +89,102 @@ const initalizeRedis = async (nodeData: INodeData, options: ICommonObject): Prom
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const redisUrl = getCredentialParam('redisUrl', credentialData, nodeData)

let client: Redis

if (!redisUrl || redisUrl === '') {
const username = getCredentialParam('redisCacheUser', credentialData, nodeData)
const password = getCredentialParam('redisCachePwd', credentialData, nodeData)
const portStr = getCredentialParam('redisCachePort', credentialData, nodeData)
const host = getCredentialParam('redisCacheHost', credentialData, nodeData)
const sslEnabled = getCredentialParam('redisCacheSslEnabled', credentialData, nodeData)

const tlsOptions = sslEnabled === true ? { tls: { rejectUnauthorized: false } } : {}

client = getRedisClientbyOption({
port: portStr ? parseInt(portStr) : 6379,
host,
username,
password,
...tlsOptions
})
} else {
client = getRedisClientbyUrl(redisUrl)
}

let obj: RedisChatMessageHistoryInput = {
sessionId,
client
}

if (sessionTTL) {
obj = {
...obj,
sessionTTL
}
}

const redisChatMessageHistory = new RedisChatMessageHistory(obj)
const redisOptions = redisUrl
? redisUrl
: ({
port: parseInt(getCredentialParam('redisCachePort', credentialData, nodeData) || '6379'),
host: getCredentialParam('redisCacheHost', credentialData, nodeData),
username: getCredentialParam('redisCacheUser', credentialData, nodeData),
password: getCredentialParam('redisCachePwd', credentialData, nodeData),
tls: getCredentialParam('redisCacheSslEnabled', credentialData, nodeData) ? { rejectUnauthorized: false } : undefined
} as RedisOptions)

const memory = new BufferMemoryExtended({
memoryKey: memoryKey ?? 'chat_history',
chatHistory: redisChatMessageHistory,
sessionId,
windowSize,
redisClient: client,
sessionTTL
sessionTTL,
redisOptions
})

return memory
}

interface BufferMemoryExtendedInput {
redisClient: Redis
sessionId: string
windowSize?: number
sessionTTL?: number
redisOptions: RedisOptions | string
}

class BufferMemoryExtended extends FlowiseMemory implements MemoryMethods {
sessionId = ''
redisClient: Redis
windowSize?: number
sessionTTL?: number
redisOptions: RedisOptions | string

constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
super(fields)
this.sessionId = fields.sessionId
this.redisClient = fields.redisClient
this.windowSize = fields.windowSize
this.sessionTTL = fields.sessionTTL
this.redisOptions = fields.redisOptions
}

private async withRedisClient<T>(fn: (client: Redis) => Promise<T>): Promise<T> {
const client = typeof this.redisOptions === 'string' ? new Redis(this.redisOptions) : new Redis(this.redisOptions)
try {
return await fn(client)
} finally {
await client.quit()
}
}

async getChatMessages(
overrideSessionId = '',
returnBaseMessages = false,
prependMessages?: IMessage[]
): Promise<IMessage[] | BaseMessage[]> {
if (!this.redisClient) return []

const id = overrideSessionId ? overrideSessionId : this.sessionId
const rawStoredMessages = await this.redisClient.lrange(id, this.windowSize ? this.windowSize * -1 : 0, -1)
const orderedMessages = rawStoredMessages.reverse().map((message) => JSON.parse(message))
const baseMessages = orderedMessages.map(mapStoredMessageToChatMessage)
if (prependMessages?.length) {
baseMessages.unshift(...(await mapChatMessageToBaseMessage(prependMessages)))
}
return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
return this.withRedisClient(async (client) => {
const id = overrideSessionId ? overrideSessionId : this.sessionId
const rawStoredMessages = await client.lrange(id, this.windowSize ? this.windowSize * -1 : 0, -1)
const orderedMessages = rawStoredMessages.reverse().map((message) => JSON.parse(message))
const baseMessages = orderedMessages.map(mapStoredMessageToChatMessage)
if (prependMessages?.length) {
baseMessages.unshift(...(await mapChatMessageToBaseMessage(prependMessages)))
}
return returnBaseMessages ? baseMessages : convertBaseMessagetoIMessage(baseMessages)
})
}

async addChatMessages(msgArray: { text: string; type: MessageType }[], overrideSessionId = ''): Promise<void> {
if (!this.redisClient) return

const id = overrideSessionId ? overrideSessionId : this.sessionId
const input = msgArray.find((msg) => msg.type === 'userMessage')
const output = msgArray.find((msg) => msg.type === 'apiMessage')

if (input) {
const newInputMessage = new HumanMessage(input.text)
const messageToAdd = [newInputMessage].map((msg) => msg.toDict())
await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL)
}
await this.withRedisClient(async (client) => {
const id = overrideSessionId ? overrideSessionId : this.sessionId
const input = msgArray.find((msg) => msg.type === 'userMessage')
const output = msgArray.find((msg) => msg.type === 'apiMessage')

if (input) {
const newInputMessage = new HumanMessage(input.text)
const messageToAdd = [newInputMessage].map((msg) => msg.toDict())
await client.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await client.expire(id, this.sessionTTL)
}

if (output) {
const newOutputMessage = new AIMessage(output.text)
const messageToAdd = [newOutputMessage].map((msg) => msg.toDict())
await this.redisClient.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await this.redisClient.expire(id, this.sessionTTL)
}
if (output) {
const newOutputMessage = new AIMessage(output.text)
const messageToAdd = [newOutputMessage].map((msg) => msg.toDict())
await client.lpush(id, JSON.stringify(messageToAdd[0]))
if (this.sessionTTL) await client.expire(id, this.sessionTTL)
}
})
}

async clearChatMessages(overrideSessionId = ''): Promise<void> {
if (!this.redisClient) return

const id = overrideSessionId ? overrideSessionId : this.sessionId
await this.redisClient.del(id)
await this.clear()
await this.withRedisClient(async (client) => {
const id = overrideSessionId ? overrideSessionId : this.sessionId
await client.del(id)
await this.clear()
})
}
}

Expand Down
Loading