Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: change data source lifecycle on agent memory mysql saver #3578

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 84 additions & 86 deletions packages/components/nodes/memory/AgentMemory/mysqlSaver.ts
Original file line number Diff line number Diff line change
@@ -1,50 +1,46 @@
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
import { RunnableConfig } from '@langchain/core/runnables'
import { BaseMessage } from '@langchain/core/messages'
import { DataSource, QueryRunner } from 'typeorm'
import { DataSource } from 'typeorm'
import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface'
import { IMessage, MemoryMethods } from '../../../src/Interface'
import { mapChatMessageToBaseMessage } from '../../../src/utils'

export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {
protected isSetup: boolean

datasource: DataSource

queryRunner: QueryRunner

config: SaverOptions

threadId: string

tableName = 'checkpoints'

constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
super(serde)
this.config = config
const { datasourceOptions, threadId } = config
const { threadId } = config
this.threadId = threadId
this.datasource = new DataSource(datasourceOptions)
}

private async setup(): Promise<void> {
if (this.isSetup) {
return
}
private async getDataSource(): Promise<DataSource> {
const { datasourceOptions } = this.config
const dataSource = new DataSource(datasourceOptions)
await dataSource.initialize()
return dataSource
}

private async setup(dataSource: DataSource): Promise<void> {
if (this.isSetup) return

try {
const appDataSource = await this.datasource.initialize()

this.queryRunner = appDataSource.createQueryRunner()
await this.queryRunner.manager.query(`
CREATE TABLE IF NOT EXISTS ${this.tableName} (
thread_id VARCHAR(255) NOT NULL,
checkpoint_id VARCHAR(255) NOT NULL,
parent_id VARCHAR(255),
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_id)
);`)
const queryRunner = dataSource.createQueryRunner()
await queryRunner.manager.query(`
CREATE TABLE IF NOT EXISTS ${this.tableName} (
thread_id VARCHAR(255) NOT NULL,
checkpoint_id VARCHAR(255) NOT NULL,
parent_id VARCHAR(255),
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_id)
);`)
await queryRunner.release()
} catch (error) {
console.error(`Error creating ${this.tableName} table`, error)
throw new Error(`Error creating ${this.tableName} table`)
Expand All @@ -54,79 +50,67 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

const thread_id = config.configurable?.thread_id || this.threadId
const checkpoint_id = config.configurable?.checkpoint_id

if (checkpoint_id) {
try {
const keys = [thread_id, checkpoint_id]
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`

const rows = await this.queryRunner.manager.query(sql, keys)

if (rows && rows.length > 0) {
return {
config,
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id,
checkpoint_id: rows[0].parent_id
}
}
: undefined
}
}
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
}
} else {
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
try {
const queryRunner = dataSource.createQueryRunner()
const sql = checkpoint_id
? `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
: `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`

const rows = await this.queryRunner.manager.query(sql, keys)
const rows = await queryRunner.manager.query(sql, checkpoint_id ? [thread_id, checkpoint_id] : [thread_id])
await queryRunner.release()

if (rows && rows.length > 0) {
const row = rows[0]
return {
config: {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].checkpoint_id
thread_id: row.thread_id || thread_id,
checkpoint_id: row.checkpoint_id || checkpoint_id
}
},
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
parentConfig: rows[0].parent_id
checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata,
parentConfig: row.parent_id
? {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].parent_id
thread_id,
checkpoint_id: row.parent_id
}
}
: undefined
}
}
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
} finally {
await dataSource.destroy()
}
return undefined
}

async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
await this.setup()
const thread_id = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
before ? 'AND checkpoint_id < ?' : ''
} ORDER BY checkpoint_id DESC`
if (limit) {
sql += ` LIMIT ${limit}`
}
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)

async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple, void, unknown> {
const dataSource = await this.getDataSource()
await this.setup(dataSource)
const queryRunner = dataSource.createQueryRunner()
try {
const rows = await this.queryRunner.manager.query(sql, args)
const threadId = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
before ? 'AND checkpoint_id < ?' : ''
} ORDER BY checkpoint_id DESC`
if (limit) {
sql += ` LIMIT ${limit}`
}
const args = [threadId, before?.configurable?.checkpoint_id].filter(Boolean)

const rows = await queryRunner.manager.query(sql, args)
await queryRunner.release()

if (rows && rows.length > 0) {
for (const row of rows) {
Expand All @@ -151,15 +135,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
}
}
} catch (error) {
console.error(`Error listing ${this.tableName}`, error)
throw new Error(`Error listing ${this.tableName}`)
console.error(`Error listing checkpoints`, error)
throw new Error(`Error listing checkpoints`)
} finally {
await dataSource.destroy()
}
}

async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
await this.setup()
const dataSource = await this.getDataSource()
await this.setup(dataSource)

if (!config.configurable?.checkpoint_id) return {}
try {
const queryRunner = dataSource.createQueryRunner()
const row = [
config.configurable?.thread_id || this.threadId,
checkpoint.id,
Expand All @@ -172,10 +161,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
VALUES (?, ?, ?, ?, ?)
ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)`

await this.queryRunner.manager.query(query, row)
await queryRunner.manager.query(query, row)
await queryRunner.release()
} catch (error) {
console.error('Error saving checkpoint', error)
throw new Error('Error saving checkpoint')
} finally {
await dataSource.destroy()
}

return {
Expand All @@ -187,16 +179,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
}

async delete(threadId: string): Promise<void> {
if (!threadId) {
return
}
await this.setup()
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
if (!threadId) return

const dataSource = await this.getDataSource()
await this.setup(dataSource)

try {
await this.queryRunner.manager.query(query, [threadId])
const queryRunner = dataSource.createQueryRunner()
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`
await queryRunner.manager.query(query, [threadId])
await queryRunner.release()
} catch (error) {
console.error(`Error deleting thread_id ${threadId}`, error)
} finally {
await dataSource.destroy()
}
}

Expand Down Expand Up @@ -232,6 +228,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
type: m.role
})
}

return returnIMessages
}

Expand All @@ -240,6 +237,7 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
}

async clearChatMessages(overrideSessionId = ''): Promise<void> {
if (!overrideSessionId) return
await this.delete(overrideSessionId)
}
}
Loading
Loading