From 868f26736b18f6d7fa50e2c4b863ab8104cb4799 Mon Sep 17 00:00:00 2001 From: Joao Paulo Date: Tue, 26 Nov 2024 12:39:24 -0300 Subject: [PATCH 1/4] fix: change data source lifecycle on agent memory mysql saver --- .../nodes/memory/AgentMemory/mysqlSaver.ts | 185 +++++++++--------- 1 file changed, 89 insertions(+), 96 deletions(-) diff --git a/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts b/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts index 05bfcfc0b0d..7ee4c6589d8 100644 --- a/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts +++ b/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts @@ -1,50 +1,46 @@ import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph' import { RunnableConfig } from '@langchain/core/runnables' import { BaseMessage } from '@langchain/core/messages' -import { DataSource, QueryRunner } from 'typeorm' +import { DataSource } from 'typeorm' import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface' import { IMessage, MemoryMethods } from '../../../src/Interface' import { mapChatMessageToBaseMessage } from '../../../src/utils' export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { protected isSetup: boolean - - datasource: DataSource - - queryRunner: QueryRunner - config: SaverOptions - threadId: string - tableName = 'checkpoints' constructor(config: SaverOptions, serde?: SerializerProtocol) { super(serde) this.config = config - const { datasourceOptions, threadId } = config + const { threadId } = config this.threadId = threadId - this.datasource = new DataSource(datasourceOptions) } - private async setup(): Promise { - if (this.isSetup) { - return - } + private async getDataSource(): Promise { + const { datasourceOptions } = this.config + const dataSource = new DataSource(datasourceOptions) + await dataSource.initialize() + return dataSource + } + + private async setup(dataSource: DataSource): Promise { + if (this.isSetup) return try { - const appDataSource = await this.datasource.initialize() - - this.queryRunner = appDataSource.createQueryRunner() - await this.queryRunner.manager.query(` -CREATE TABLE IF NOT EXISTS ${this.tableName} ( - thread_id VARCHAR(255) NOT NULL, - checkpoint_id VARCHAR(255) NOT NULL, - parent_id VARCHAR(255), - checkpoint BLOB, - metadata BLOB, - PRIMARY KEY (thread_id, checkpoint_id) -);`) + const queryRunner = dataSource.createQueryRunner() + await queryRunner.manager.query(` + CREATE TABLE IF NOT EXISTS ${this.tableName} ( + thread_id VARCHAR(255) NOT NULL, + checkpoint_id VARCHAR(255) NOT NULL, + parent_id VARCHAR(255), + checkpoint BLOB, + metadata BLOB, + PRIMARY KEY (thread_id, checkpoint_id) + );`) + await queryRunner.release() } catch (error) { console.error(`Error creating ${this.tableName} table`, error) throw new Error(`Error creating ${this.tableName} table`) @@ -54,80 +50,64 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } async getTuple(config: RunnableConfig): Promise { - await this.setup() + const dataSource = await this.getDataSource() + await this.setup(dataSource) + const thread_id = config.configurable?.thread_id || this.threadId const checkpoint_id = config.configurable?.checkpoint_id - if (checkpoint_id) { - try { - const keys = [thread_id, checkpoint_id] - const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?` - - const rows = await this.queryRunner.manager.query(sql, keys) - - if (rows && rows.length > 0) { - return { - config, - checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint, - metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata, - parentConfig: rows[0].parent_id - ? { - configurable: { - thread_id, - checkpoint_id: rows[0].parent_id - } - } - : undefined - } - } - } catch (error) { - console.error(`Error retrieving ${this.tableName}`, error) - throw new Error(`Error retrieving ${this.tableName}`) - } - } else { - const keys = [thread_id] - const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1` + try { + const queryRunner = dataSource.createQueryRunner() + const sql = checkpoint_id + ? `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?` + : `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1` - const rows = await this.queryRunner.manager.query(sql, keys) + const rows = await queryRunner.manager.query(sql, checkpoint_id ? [thread_id, checkpoint_id] : [thread_id]) + await queryRunner.release() if (rows && rows.length > 0) { + const row = rows[0] return { config: { configurable: { - thread_id: rows[0].thread_id, - checkpoint_id: rows[0].checkpoint_id + thread_id: row.thread_id || thread_id, + checkpoint_id: row.checkpoint_id || checkpoint_id } }, - checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint, - metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata, - parentConfig: rows[0].parent_id + checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint, + metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata, + parentConfig: row.parent_id ? { configurable: { - thread_id: rows[0].thread_id, - checkpoint_id: rows[0].parent_id + thread_id, + checkpoint_id: row.parent_id } } : undefined } } + } catch (error) { + console.error(`Error retrieving ${this.tableName}`, error) + throw new Error(`Error retrieving ${this.tableName}`) + } finally { + await dataSource.destroy() } return undefined } - async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator { - await this.setup() - const thread_id = config.configurable?.thread_id || this.threadId - let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${ - before ? 'AND checkpoint_id < ?' : '' - } ORDER BY checkpoint_id DESC` - if (limit) { - sql += ` LIMIT ${limit}` - } - const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean) - + async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator { + const dataSource = await this.getDataSource() try { - const rows = await this.queryRunner.manager.query(sql, args) + const threadId = config.configurable?.thread_id || this.threadId + let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${ + before ? 'AND checkpoint_id < ?' : '' + } ORDER BY checkpoint_id DESC` + if (limit) { + sql += ` LIMIT ${limit}` + } + const args = [threadId, before?.configurable?.checkpoint_id].filter(Boolean) + const rows = await dataSource.manager.query(sql, args) if (rows && rows.length > 0) { for (const row of rows) { yield { @@ -151,15 +131,19 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } } } catch (error) { - console.error(`Error listing ${this.tableName}`, error) - throw new Error(`Error listing ${this.tableName}`) + console.error(`Error listing checkpoints`, error) + throw new Error(`Error listing checkpoints`) + } finally { + await dataSource.destroy() } } async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise { - await this.setup() - if (!config.configurable?.checkpoint_id) return {} + const dataSource = await this.getDataSource() + await this.setup(dataSource) + try { + const queryRunner = dataSource.createQueryRunner() const row = [ config.configurable?.thread_id || this.threadId, checkpoint.id, @@ -172,10 +156,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)` - await this.queryRunner.manager.query(query, row) + await queryRunner.manager.query(query, row) + await queryRunner.release() } catch (error) { console.error('Error saving checkpoint', error) throw new Error('Error saving checkpoint') + } finally { + await dataSource.destroy() } return { @@ -186,20 +173,6 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } } - async delete(threadId: string): Promise { - if (!threadId) { - return - } - await this.setup() - const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;` - - try { - await this.queryRunner.manager.query(query, [threadId]) - } catch (error) { - console.error(`Error deleting thread_id ${threadId}`, error) - } - } - async getChatMessages( overrideSessionId = '', returnBaseMessages = false, @@ -232,14 +205,34 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( type: m.role }) } + return returnIMessages } + async delete(threadId: string): Promise { + if (!threadId) return + + const dataSource = await this.getDataSource() + await this.setup(dataSource) + + try { + const queryRunner = dataSource.createQueryRunner() + const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;` + await queryRunner.manager.query(query, [threadId]) + await queryRunner.release() + } catch (error) { + console.error(`Error deleting thread_id ${threadId}`, error) + } finally { + await dataSource.destroy() + } + } + async addChatMessages(): Promise { - // Empty as it's not being used + // Implementação vazia porque o método não está sendo usado } async clearChatMessages(overrideSessionId = ''): Promise { + if (!overrideSessionId) return await this.delete(overrideSessionId) } } From f41f2a95f9a7a3c2a97d0f1ba1984a176cf7343f Mon Sep 17 00:00:00 2001 From: Henry Heng Date: Fri, 6 Dec 2024 19:16:35 +0000 Subject: [PATCH 2/4] Update mysqlSaver.ts --- .../nodes/memory/AgentMemory/mysqlSaver.ts | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts b/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts index 7ee4c6589d8..147acace2e9 100644 --- a/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts +++ b/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts @@ -97,6 +97,8 @@ export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator { const dataSource = await this.getDataSource() + await this.setup(dataSource) + const queryRunner = dataSource.createQueryRunner() try { const threadId = config.configurable?.thread_id || this.threadId let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${ @@ -107,7 +109,9 @@ export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { } const args = [threadId, before?.configurable?.checkpoint_id].filter(Boolean) - const rows = await dataSource.manager.query(sql, args) + const rows = await queryRunner.manager.query(sql, args) + await queryRunner.release() + if (rows && rows.length > 0) { for (const row of rows) { yield { @@ -142,6 +146,7 @@ export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { const dataSource = await this.getDataSource() await this.setup(dataSource) + if (!config.configurable?.checkpoint_id) return {} try { const queryRunner = dataSource.createQueryRunner() const row = [ @@ -173,6 +178,24 @@ export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { } } + async delete(threadId: string): Promise { + if (!threadId) return + + const dataSource = await this.getDataSource() + await this.setup(dataSource) + + try { + const queryRunner = dataSource.createQueryRunner() + const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;` + await queryRunner.manager.query(query, [threadId]) + await queryRunner.release() + } catch (error) { + console.error(`Error deleting thread_id ${threadId}`, error) + } finally { + await dataSource.destroy() + } + } + async getChatMessages( overrideSessionId = '', returnBaseMessages = false, @@ -209,26 +232,8 @@ export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { return returnIMessages } - async delete(threadId: string): Promise { - if (!threadId) return - - const dataSource = await this.getDataSource() - await this.setup(dataSource) - - try { - const queryRunner = dataSource.createQueryRunner() - const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;` - await queryRunner.manager.query(query, [threadId]) - await queryRunner.release() - } catch (error) { - console.error(`Error deleting thread_id ${threadId}`, error) - } finally { - await dataSource.destroy() - } - } - async addChatMessages(): Promise { - // Implementação vazia porque o método não está sendo usado + // Empty as it's not being used } async clearChatMessages(overrideSessionId = ''): Promise { From 65ee1ec3de1957a56eebb3a9dddd652fe57c2730 Mon Sep 17 00:00:00 2001 From: Henry Heng Date: Fri, 6 Dec 2024 19:17:27 +0000 Subject: [PATCH 3/4] Update pgSaver.ts --- .../nodes/memory/AgentMemory/pgSaver.ts | 118 +++++++++++------- 1 file changed, 75 insertions(+), 43 deletions(-) diff --git a/packages/components/nodes/memory/AgentMemory/pgSaver.ts b/packages/components/nodes/memory/AgentMemory/pgSaver.ts index 27e236a7f7a..098b9dc3c73 100644 --- a/packages/components/nodes/memory/AgentMemory/pgSaver.ts +++ b/packages/components/nodes/memory/AgentMemory/pgSaver.ts @@ -8,35 +8,32 @@ import { mapChatMessageToBaseMessage } from '../../../src/utils' export class PostgresSaver extends BaseCheckpointSaver implements MemoryMethods { protected isSetup: boolean - - datasource: DataSource - - queryRunner: QueryRunner - config: SaverOptions - threadId: string - tableName = 'checkpoints' constructor(config: SaverOptions, serde?: SerializerProtocol) { super(serde) this.config = config - const { datasourceOptions, threadId } = config + const { threadId } = config this.threadId = threadId - this.datasource = new DataSource(datasourceOptions) + } + + private async getDataSource(): Promise { + const { datasourceOptions } = this.config + const dataSource = new DataSource(datasourceOptions) + await dataSource.initialize() + return dataSource } - private async setup(): Promise { + private async setup(dataSource: DataSource): Promise { if (this.isSetup) { return } try { - const appDataSource = await this.datasource.initialize() - - this.queryRunner = appDataSource.createQueryRunner() - await this.queryRunner.manager.query(` + const queryRunner = dataSource.createQueryRunner() + await queryRunner.manager.query(` CREATE TABLE IF NOT EXISTS ${this.tableName} ( thread_id TEXT NOT NULL, checkpoint_id TEXT NOT NULL, @@ -44,6 +41,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( checkpoint BYTEA, metadata BYTEA, PRIMARY KEY (thread_id, checkpoint_id));`) + await queryRunner.release() } catch (error) { console.error(`Error creating ${this.tableName} table`, error) throw new Error(`Error creating ${this.tableName} table`) @@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } async getTuple(config: RunnableConfig): Promise { - await this.setup() + const dataSource = await this.getDataSource() + await this.setup(dataSource) + const thread_id = config.configurable?.thread_id || this.threadId const checkpoint_id = config.configurable?.checkpoint_id if (checkpoint_id) { try { + const queryRunner = dataSource.createQueryRunner() const keys = [thread_id, checkpoint_id] const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = $1 AND checkpoint_id = $2` - const rows = await this.queryRunner.manager.query(sql, keys) + const rows = await queryRunner.manager.query(sql, keys) + await queryRunner.release() if (rows && rows.length > 0) { return { @@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } catch (error) { console.error(`Error retrieving ${this.tableName}`, error) throw new Error(`Error retrieving ${this.tableName}`) + } finally { + await dataSource.destroy() } } else { - const keys = [thread_id] - const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1 ORDER BY checkpoint_id DESC LIMIT 1` + try { + const queryRunner = dataSource.createQueryRunner() + const keys = [thread_id] + const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1 ORDER BY checkpoint_id DESC LIMIT 1` - const rows = await this.queryRunner.manager.query(sql, keys) + const rows = await queryRunner.manager.query(sql, keys) + await queryRunner.release() - if (rows && rows.length > 0) { - return { - config: { - configurable: { - thread_id: rows[0].thread_id, - checkpoint_id: rows[0].checkpoint_id - } - }, - checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint, - metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata, - parentConfig: rows[0].parent_id - ? { - configurable: { - thread_id: rows[0].thread_id, - checkpoint_id: rows[0].parent_id - } - } - : undefined + if (rows && rows.length > 0) { + return { + config: { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].checkpoint_id + } + }, + checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint, + metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata, + parentConfig: rows[0].parent_id + ? { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].parent_id + } + } + : undefined + } } + } catch (error) { + console.error(`Error retrieving ${this.tableName}`, error) + throw new Error(`Error retrieving ${this.tableName}`) + } finally { + await dataSource.destroy() } } return undefined } async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator { - await this.setup() + const dataSource = await this.getDataSource() + await this.setup(dataSource) + + const queryRunner = dataSource.createQueryRunner() const thread_id = config.configurable?.thread_id || this.threadId let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1` const args = [thread_id] @@ -130,7 +146,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } try { - const rows = await this.queryRunner.manager.query(sql, args) + const rows = await queryRunner.manager.query(sql, args) + await queryRunner.release() if (rows && rows.length > 0) { for (const row of rows) { @@ -157,13 +174,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( } catch (error) { console.error(`Error listing ${this.tableName}`, error) throw new Error(`Error listing ${this.tableName}`) + } finally { + await dataSource.destroy() } } async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise { - await this.setup() + const dataSource = await this.getDataSource() + await this.setup(dataSource) + if (!config.configurable?.checkpoint_id) return {} try { + const queryRunner = dataSource.createQueryRunner() const row = [ config.configurable?.thread_id || this.threadId, checkpoint.id, @@ -177,10 +199,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( ON CONFLICT (thread_id, checkpoint_id) DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata` - await this.queryRunner.manager.query(query, row) + await queryRunner.manager.query(query, row) + await queryRunner.release() } catch (error) { console.error('Error saving checkpoint', error) throw new Error('Error saving checkpoint') + } finally { + await dataSource.destroy() } return { @@ -195,13 +220,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( if (!threadId) { return } - await this.setup() + + const dataSource = await this.getDataSource() + await this.setup(dataSource) + const query = `DELETE FROM "${this.tableName}" WHERE thread_id = $1;` try { - await this.queryRunner.manager.query(query, [threadId]) + const queryRunner = dataSource.createQueryRunner() + await queryRunner.manager.query(query, [threadId]) + await queryRunner.release() } catch (error) { console.error(`Error deleting thread_id ${threadId}`, error) + } finally { + await dataSource.destroy() } } From 17afad73186a6ab392f207dcf553b35f13449bb7 Mon Sep 17 00:00:00 2001 From: Henry Heng Date: Fri, 6 Dec 2024 19:27:30 +0000 Subject: [PATCH 4/4] linting fix --- .../nodes/memory/AgentMemory/pgSaver.ts | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/packages/components/nodes/memory/AgentMemory/pgSaver.ts b/packages/components/nodes/memory/AgentMemory/pgSaver.ts index 098b9dc3c73..7913825aa9a 100644 --- a/packages/components/nodes/memory/AgentMemory/pgSaver.ts +++ b/packages/components/nodes/memory/AgentMemory/pgSaver.ts @@ -1,7 +1,7 @@ import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph' import { RunnableConfig } from '@langchain/core/runnables' import { BaseMessage } from '@langchain/core/messages' -import { DataSource, QueryRunner } from 'typeorm' +import { DataSource } from 'typeorm' import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface' import { IMessage, MemoryMethods } from '../../../src/Interface' import { mapChatMessageToBaseMessage } from '../../../src/utils' @@ -18,7 +18,7 @@ export class PostgresSaver extends BaseCheckpointSaver implements MemoryMethods const { threadId } = config this.threadId = threadId } - + private async getDataSource(): Promise { const { datasourceOptions } = this.config const dataSource = new DataSource(datasourceOptions) @@ -108,11 +108,11 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata, parentConfig: rows[0].parent_id ? { - configurable: { - thread_id: rows[0].thread_id, - checkpoint_id: rows[0].parent_id - } - } + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].parent_id + } + } : undefined } } @@ -220,7 +220,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} ( if (!threadId) { return } - + const dataSource = await this.getDataSource() await this.setup(dataSource)