Skip to content

Commit

Permalink
feat(Question and Answer Chain Node): Customize question and answer s…
Browse files Browse the repository at this point in the history
…ystem prompt (#10385)
  • Loading branch information
mprytoluk authored Sep 27, 2024
1 parent 7073ec6 commit 08a27b3
Showing 1 changed file with 61 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,21 @@ import {
import { RetrievalQAChain } from 'langchain/chains';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { BaseRetriever } from '@langchain/core/retrievers';
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
} from '@langchain/core/prompts';
import { getTemplateNoticeField } from '../../../utils/sharedFields';
import { getPromptInputByType } from '../../../utils/helpers';
import { getPromptInputByType, isChatInstance } from '../../../utils/helpers';
import { getTracingConfig } from '../../../utils/tracing';

const SYSTEM_PROMPT_TEMPLATE = `Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
{context}`;

export class ChainRetrievalQa implements INodeType {
description: INodeTypeDescription = {
displayName: 'Question and Answer Chain',
Expand Down Expand Up @@ -137,6 +148,26 @@ export class ChainRetrievalQa implements INodeType {
},
},
},
{
displayName: 'Options',
name: 'options',
type: 'collection',
default: {},
placeholder: 'Add Option',
options: [
{
displayName: 'System Prompt Template',
name: 'systemPromptTemplate',
type: 'string',
default: SYSTEM_PROMPT_TEMPLATE,
description:
'Template string used for the system prompt. This should include the variable `{context}` for the provided context. For text completion models, you should also include the variable `{question}` for the user’s query.',
typeOptions: {
rows: 6,
},
},
],
},
],
};

Expand All @@ -154,7 +185,6 @@ export class ChainRetrievalQa implements INodeType {
)) as BaseRetriever;

const items = this.getInputData();
const chain = RetrievalQAChain.fromLLM(model, retriever);

const returnData: INodeExecutionData[] = [];

Expand All @@ -178,6 +208,35 @@ export class ChainRetrievalQa implements INodeType {
throw new NodeOperationError(this.getNode(), 'The ‘query‘ parameter is empty.');
}

const options = this.getNodeParameter('options', itemIndex, {}) as {
systemPromptTemplate?: string;
};

const chainParameters = {} as {
prompt?: PromptTemplate | ChatPromptTemplate;
};

if (options.systemPromptTemplate !== undefined) {
if (isChatInstance(model)) {
const messages = [
SystemMessagePromptTemplate.fromTemplate(options.systemPromptTemplate),
HumanMessagePromptTemplate.fromTemplate('{question}'),
];
const chatPromptTemplate = ChatPromptTemplate.fromMessages(messages);

chainParameters.prompt = chatPromptTemplate;
} else {
const completionPromptTemplate = new PromptTemplate({
template: options.systemPromptTemplate,
inputVariables: ['context', 'question'],
});

chainParameters.prompt = completionPromptTemplate;
}
}

const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters);

const response = await chain.withConfig(getTracingConfig(this)).invoke({ query });
returnData.push({ json: { response } });
} catch (error) {
Expand Down

0 comments on commit 08a27b3

Please sign in to comment.