Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/Prevent open connections on typeorm datasource #3652

Merged
merged 1 commit into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {

private async getDataSource(): Promise<DataSource> {
const { datasourceOptions } = this.config
if (!datasourceOptions) {
throw new Error('No datasource options provided')
}
// Prevent using default Postgres port, otherwise will throw uncaught error and crashing the app
if (datasourceOptions.port === 5432) {
throw new Error('Invalid port number')
}
const dataSource = new DataSource(datasourceOptions)
await dataSource.initialize()
return dataSource
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ export class PostgresSaver extends BaseCheckpointSaver implements MemoryMethods

private async getDataSource(): Promise<DataSource> {
const { datasourceOptions } = this.config
if (!datasourceOptions) {
throw new Error('No datasource options provided')
}
// Prevent using default MySQL port, otherwise will throw uncaught error and crashing the app
if (datasourceOptions.port === 3006) {
throw new Error('Invalid port number')
}
const dataSource = new DataSource(datasourceOptions)
await dataSource.initialize()
return dataSource
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
import { ListKeyOptions, RecordManagerInterface, UpdateOptions } from '@langchain/community/indexes/base'
import { DataSource, QueryRunner } from 'typeorm'
import { DataSource } from 'typeorm'

class MySQLRecordManager_RecordManager implements INode {
label: string
Expand Down Expand Up @@ -167,47 +167,58 @@ type MySQLRecordManagerOptions = {

class MySQLRecordManager implements RecordManagerInterface {
lc_namespace = ['langchain', 'recordmanagers', 'mysql']

datasource: DataSource

queryRunner: QueryRunner

config: MySQLRecordManagerOptions
tableName: string

namespace: string

constructor(namespace: string, config: MySQLRecordManagerOptions) {
const { mysqlOptions, tableName } = config
const { tableName } = config
this.namespace = namespace
this.tableName = tableName || 'upsertion_records'
this.datasource = new DataSource(mysqlOptions)
this.config = config
}

private async getDataSource(): Promise<DataSource> {
const { mysqlOptions } = this.config
if (!mysqlOptions) {
throw new Error('No datasource options provided')
}
// Prevent using default Postgres port, otherwise will throw uncaught error and crashing the app
if (mysqlOptions.port === 5432) {
throw new Error('Invalid port number')
}
const dataSource = new DataSource(mysqlOptions)
await dataSource.initialize()
return dataSource
}

async createSchema(): Promise<void> {
try {
const appDataSource = await this.datasource.initialize()

this.queryRunner = appDataSource.createQueryRunner()
const dataSource = await this.getDataSource()
const queryRunner = dataSource.createQueryRunner()

await this.queryRunner.manager.query(`create table if not exists \`${this.tableName}\` (
await queryRunner.manager.query(`create table if not exists \`${this.tableName}\` (
\`uuid\` varchar(36) primary key default (UUID()),
\`key\` varchar(255) not null,
\`namespace\` varchar(255) not null,
\`updated_at\` DOUBLE precision not null,
\`group_id\` longtext,
unique key \`unique_key_namespace\` (\`key\`,
\`namespace\`));`)

const columns = [`updated_at`, `key`, `namespace`, `group_id`]
for (const column of columns) {
// MySQL does not support 'IF NOT EXISTS' function for Index
const Check = await this.queryRunner.manager.query(
const Check = await queryRunner.manager.query(
`SELECT COUNT(1) IndexIsThere FROM INFORMATION_SCHEMA.STATISTICS
WHERE table_schema=DATABASE() AND table_name='${this.tableName}' AND index_name='${column}_index';`
)
if (Check[0].IndexIsThere === 0)
await this.queryRunner.manager.query(`CREATE INDEX \`${column}_index\`
await queryRunner.manager.query(`CREATE INDEX \`${column}_index\`
ON \`${this.tableName}\` (\`${column}\`);`)
}

await queryRunner.release()
} catch (e: any) {
// This error indicates that the table already exists
// Due to asynchronous nature of the code, it is possible that
Expand All @@ -221,12 +232,17 @@ class MySQLRecordManager implements RecordManagerInterface {
}

async getTime(): Promise<number> {
const dataSource = await this.getDataSource()
try {
const res = await this.queryRunner.manager.query(`SELECT UNIX_TIMESTAMP(NOW()) AS epoch`)
const queryRunner = dataSource.createQueryRunner()
const res = await queryRunner.manager.query(`SELECT UNIX_TIMESTAMP(NOW()) AS epoch`)
await queryRunner.release()
return Number.parseFloat(res[0].epoch)
} catch (error) {
console.error('Error getting time in MySQLRecordManager:')
throw error
} finally {
await dataSource.destroy()
}
}

Expand All @@ -235,6 +251,9 @@ class MySQLRecordManager implements RecordManagerInterface {
return
}

const dataSource = await this.getDataSource()
const queryRunner = dataSource.createQueryRunner()

const updatedAt = await this.getTime()
const { timeAtLeast, groupIds: _groupIds } = updateOptions ?? {}

Expand All @@ -261,9 +280,18 @@ class MySQLRecordManager implements RecordManagerInterface {
ON DUPLICATE KEY UPDATE \`updated_at\` = VALUES(\`updated_at\`)`

// To handle multiple files upsert
for (const record of recordsToUpsert) {
// Consider using a transaction for batch operations
await this.queryRunner.manager.query(query, record.flat())
try {
for (const record of recordsToUpsert) {
// Consider using a transaction for batch operations
await queryRunner.manager.query(query, record.flat())
}

await queryRunner.release()
} catch (error) {
console.error('Error updating in MySQLRecordManager:')
throw error
} finally {
await dataSource.destroy()
}
}

Expand All @@ -272,6 +300,9 @@ class MySQLRecordManager implements RecordManagerInterface {
return []
}

const dataSource = await this.getDataSource()
const queryRunner = dataSource.createQueryRunner()

// Prepare the placeholders and the query
const placeholders = keys.map(() => `?`).join(', ')
const query = `
Expand All @@ -284,21 +315,27 @@ class MySQLRecordManager implements RecordManagerInterface {

try {
// Execute the query
const rows = await this.queryRunner.manager.query(query, [this.namespace, ...keys.flat()])
const rows = await queryRunner.manager.query(query, [this.namespace, ...keys.flat()])
// Create a set of existing keys for faster lookup
const existingKeysSet = new Set(rows.map((row: { key: string }) => row.key))
// Map the input keys to booleans indicating if they exist
keys.forEach((key, index) => {
existsArray[index] = existingKeysSet.has(key)
})
await queryRunner.release()
return existsArray
} catch (error) {
console.error('Error checking existence of keys')
throw error // Allow the caller to handle the error
throw error
} finally {
await dataSource.destroy()
}
}

async listKeys(options?: ListKeyOptions): Promise<string[]> {
const dataSource = await this.getDataSource()
const queryRunner = dataSource.createQueryRunner()

try {
const { before, after, limit, groupIds } = options ?? {}
let query = `SELECT \`key\` FROM \`${this.tableName}\` WHERE \`namespace\` = ?`
Expand Down Expand Up @@ -330,11 +367,14 @@ class MySQLRecordManager implements RecordManagerInterface {
query += ';'

// Directly using try/catch with async/await for cleaner flow
const result = await this.queryRunner.manager.query(query, values)
const result = await queryRunner.manager.query(query, values)
await queryRunner.release()
return result.map((row: { key: string }) => row.key)
} catch (error) {
console.error('MySQLRecordManager listKeys Error: ')
throw error // Re-throw the error to be handled by the caller
throw error
} finally {
await dataSource.destroy()
}
}

Expand All @@ -343,16 +383,22 @@ class MySQLRecordManager implements RecordManagerInterface {
return
}

const dataSource = await this.getDataSource()
const queryRunner = dataSource.createQueryRunner()

const placeholders = keys.map(() => '?').join(', ')
const query = `DELETE FROM \`${this.tableName}\` WHERE \`namespace\` = ? AND \`key\` IN (${placeholders});`
const values = [this.namespace, ...keys].map((v) => (typeof v !== 'string' ? `${v}` : v))

// Directly using try/catch with async/await for cleaner flow
try {
await this.queryRunner.manager.query(query, values)
await queryRunner.manager.query(query, values)
await queryRunner.release()
} catch (error) {
console.error('Error deleting keys')
throw error // Re-throw the error to be handled by the caller
throw error
} finally {
await dataSource.destroy()
}
}
}
Expand Down
Loading
Loading