Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Postgres Chat Memory, Redis Chat Memory, Xata): Add support for context window length #10203

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import type { BufferWindowMemoryInput } from 'langchain/memory';
import { BufferWindowMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

class MemoryChatBufferSingleton {
Expand Down Expand Up @@ -130,13 +130,7 @@ export class MemoryBufferWindow implements INodeType {
},
},
sessionKeyProperty,
{
displayName: 'Context Window Length',
name: 'contextWindowLength',
type: 'number',
default: 5,
description: 'The number of previous messages to consider for context',
},
contextWindowLengthProperty,
],
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow';
import { NodeConnectionType } from 'n8n-workflow';
import { BufferMemory } from 'langchain/memory';
import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import { PostgresChatMessageHistory } from '@langchain/community/stores/message/postgres';
import type pg from 'pg';
import { configurePostgres } from 'n8n-nodes-base/dist/nodes/Postgres/v2/transport';
import type { PostgresNodeCredentials } from 'n8n-nodes-base/dist/nodes/Postgres/v2/helpers/interfaces';
import { postgresConnectionTest } from 'n8n-nodes-base/dist/nodes/Postgres/v2/methods/credentialTest';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

export class MemoryPostgresChat implements INodeType {
Expand All @@ -18,7 +18,7 @@ export class MemoryPostgresChat implements INodeType {
name: 'memoryPostgresChat',
icon: 'file:postgres.svg',
group: ['transform'],
version: [1],
version: [1, 1.1],
description: 'Stores the chat history in Postgres table.',
defaults: {
name: 'Postgres Chat Memory',
Expand Down Expand Up @@ -60,6 +60,10 @@ export class MemoryPostgresChat implements INodeType {
description:
'The table name to store the chat history in. If table does not exist, it will be created.',
},
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.1 } }] } },
},
],
};

Expand All @@ -83,12 +87,19 @@ export class MemoryPostgresChat implements INodeType {
tableName,
});

const memory = new BufferMemory({
const memClass = this.getNode().typeVersion < 1.1 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.1
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };

const memory = new memClass({
memoryKey: 'chat_history',
chatHistory: pgChatHistory,
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
...kOptions,
});

async function closeFunction() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import {
type SupplyData,
NodeConnectionType,
} from 'n8n-workflow';
import { BufferMemory } from 'langchain/memory';
import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import type { RedisChatMessageHistoryInput } from '@langchain/redis';
import { RedisChatMessageHistory } from '@langchain/redis';
import type { RedisClientOptions } from 'redis';
import { createClient } from 'redis';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

export class MemoryRedisChat implements INodeType {
Expand All @@ -23,7 +23,7 @@ export class MemoryRedisChat implements INodeType {
name: 'memoryRedisChat',
icon: 'file:redis.svg',
group: ['transform'],
version: [1, 1.1, 1.2],
version: [1, 1.1, 1.2, 1.3],
description: 'Stores the chat history in Redis.',
defaults: {
name: 'Redis Chat Memory',
Expand Down Expand Up @@ -95,6 +95,10 @@ export class MemoryRedisChat implements INodeType {
description:
'For how long the session should be stored in seconds. If set to 0 it will not expire.',
},
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } },
},
],
};

Expand Down Expand Up @@ -143,12 +147,19 @@ export class MemoryRedisChat implements INodeType {
}
const redisChatHistory = new RedisChatMessageHistory(redisChatConfig);

const memory = new BufferMemory({
const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.3
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };

const memory = new memClass({
memoryKey: 'chat_history',
chatHistory: redisChatHistory,
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
...kOptions,
});

async function closeFunction() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow';
import { XataChatMessageHistory } from '@langchain/community/stores/message/xata';
import { BufferMemory } from 'langchain/memory';
import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import { BaseClient } from '@xata.io/client';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

export class MemoryXata implements INodeType {
Expand All @@ -15,7 +15,7 @@ export class MemoryXata implements INodeType {
name: 'memoryXata',
icon: 'file:xata.svg',
group: ['transform'],
version: [1, 1.1, 1.2],
version: [1, 1.1, 1.2, 1.3],
description: 'Use Xata Memory',
defaults: {
name: 'Xata',
Expand Down Expand Up @@ -81,6 +81,10 @@ export class MemoryXata implements INodeType {
},
},
sessionKeyProperty,
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } },
},
],
};

Expand Down Expand Up @@ -120,12 +124,19 @@ export class MemoryXata implements INodeType {
apiKey: credentials.apiKey as string,
});

const memory = new BufferMemory({
const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.3
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };

const memory = new memClass({
chatHistory,
memoryKey: 'chat_history',
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
...kOptions,
});

return {
Expand Down
8 changes: 8 additions & 0 deletions packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ export const sessionKeyProperty: INodeProperties = {
},
},
};

export const contextWindowLengthProperty: INodeProperties = {
displayName: 'Context Window Length',
name: 'contextWindowLength',
type: 'number',
default: 5,
hint: 'How many past interactions the model receives as context',
};
Loading