Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduced parameter storeConversationsToDefaultThread #149

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions core/embedjs-interfaces/src/interfaces/base-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { HumanMessage, AIMessage, SystemMessage } from '@langchain/core/messages
import createDebugMessages from 'debug';
import { v4 as uuidv4 } from 'uuid';

import { Chunk, QueryResponse, Message, SourceDetail, ModelResponse } from '../types.js';
import { Chunk, QueryResponse, Message, SourceDetail, ModelResponse, Conversation } from '../types.js';
import { BaseCache } from './base-cache.js';

export abstract class BaseModel {
Expand Down Expand Up @@ -78,23 +78,32 @@ export abstract class BaseModel {
system: string,
userQuery: string,
supportingContext: Chunk[],
conversationId = 'default',
conversationId?: string,
): Promise<QueryResponse> {
if (!(await BaseModel.cache.hasConversation(conversationId))) {
this.baseDebug(`Conversation with id '${conversationId}' is new`);
await BaseModel.cache.addConversation(conversationId);
}
let conversation: Conversation;

const conversation = await BaseModel.cache.getConversation(conversationId);
this.baseDebug(`${conversation.entries.length} history entries found for conversationId '${conversationId}'`);
if (conversationId) {
if (!(await BaseModel.cache.hasConversation(conversationId))) {
this.baseDebug(`Conversation with id '${conversationId}' is new`);
await BaseModel.cache.addConversation(conversationId);
}

// Add user query to history
await BaseModel.cache.addEntryToConversation(conversationId, {
id: uuidv4(),
timestamp: new Date(),
actor: 'HUMAN',
content: userQuery,
});
conversation = await BaseModel.cache.getConversation(conversationId);
this.baseDebug(
`${conversation.entries.length} history entries found for conversationId '${conversationId}'`,
);

// Add user query to history
await BaseModel.cache.addEntryToConversation(conversationId, {
id: uuidv4(),
timestamp: new Date(),
actor: 'HUMAN',
content: userQuery,
});
} else {
this.baseDebug('Conversation history is disabled as no conversationId was provided');
conversation = { conversationId: 'default', entries: [] };
}

const messages = await this.prepare(system, userQuery, supportingContext, conversation.entries.slice(0, -1));
const uniqueSources = this.extractUniqueSources(supportingContext);
Expand All @@ -112,8 +121,11 @@ export abstract class BaseModel {
sources: uniqueSources,
};

// Add AI response to history
await BaseModel.cache.addEntryToConversation(conversationId, newEntry);
if (conversationId) {
// Add AI response to history
await BaseModel.cache.addEntryToConversation(conversationId, newEntry);
}

return {
...newEntry,
tokenUse: {
Expand Down
15 changes: 15 additions & 0 deletions core/embedjs/src/core/rag-application-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export class RAGApplicationBuilder {
private searchResultCount: number;
private embeddingModel: BaseEmbeddings;
private embeddingRelevanceCutOff: number;
private storeConversationsToDefaultThread: boolean;

constructor() {
this.loaders = [];
Expand All @@ -24,6 +25,7 @@ export class RAGApplicationBuilder {

Do not use words like context or training data when responding. You can say you do not have all the information but do not indicate that you are not a reliable source.`;

this.storeConversationsToDefaultThread = true;
this.embeddingRelevanceCutOff = 0;
this.cache = new MemoryCache();
}
Expand Down Expand Up @@ -101,6 +103,15 @@ export class RAGApplicationBuilder {
return this;
}

/**
 * Configures whether the conversation history for queries made without an
 * explicit conversationId should be stored in the 'default' thread.
 * This is set to true by default.
 *
 * @param storeConversationsToDefaultThread - when false, queries that omit a
 * conversationId are answered without persisting any conversation history.
 * @returns this builder instance, for call chaining.
 */
setParamStoreConversationsToDefaultThread(storeConversationsToDefaultThread: boolean) {
this.storeConversationsToDefaultThread = storeConversationsToDefaultThread;
return this;
}

/**
 * @returns the loaders currently registered with this builder.
 */
getLoaders() {
return this.loaders;
}
Expand Down Expand Up @@ -136,4 +147,8 @@ export class RAGApplicationBuilder {
/**
 * @returns the LLM model configured on this builder.
 */
getModel() {
return this.model;
}

/**
 * @returns whether queries made without a conversationId have their history
 * stored in the 'default' conversation thread.
 */
getParamStoreConversationsToDefaultThread() {
return this.storeConversationsToDefaultThread;
}
}
9 changes: 8 additions & 1 deletion core/embedjs/src/core/rag-application.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { cleanString, getUnique } from '@llm-tools/embedjs-utils';

export class RAGApplication {
private readonly debug = createDebugMessages('embedjs:core');
private readonly storeConversationsToDefaultThread: boolean;
private readonly embeddingRelevanceCutOff: number;
private readonly searchResultCount: number;
private readonly systemMessage: string;
Expand All @@ -30,6 +31,7 @@ export class RAGApplication {
constructor(llmBuilder: RAGApplicationBuilder) {
if (!llmBuilder.getEmbeddingModel()) throw new Error('Embedding model must be set!');

this.storeConversationsToDefaultThread = llmBuilder.getParamStoreConversationsToDefaultThread();
this.cache = llmBuilder.getCache();
BaseLoader.setCache(this.cache);
BaseModel.setCache(this.cache);
Expand Down Expand Up @@ -382,11 +384,16 @@ export class RAGApplication {
let context = options?.customContext;
if (!context) context = await this.search(userQuery);

let conversationId = options?.conversationId;
if (!conversationId && this.storeConversationsToDefaultThread) {
conversationId = 'default';
}

const sources = [...new Set(context.map((chunk) => chunk.metadata.source))];
this.debug(
`Query resulted in ${context.length} chunks after filteration; chunks from ${sources.length} unique sources.`,
);

return this.model.query(this.systemMessage, userQuery, context, options?.conversationId);
return this.model.query(this.systemMessage, userQuery, context, conversationId);
}
}
Loading
Loading