Skip to content

Commit

Permalink
Merge branch 'dynamic' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
serefyarar committed Jun 19, 2024
2 parents e5ba9b5 + f2fcaf7 commit 3758a47
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 6 deletions.
152 changes: 152 additions & 0 deletions indexer/src/app/modules/agent.module.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { loadSummarizationChain } from 'langchain/chains';
import { PromptTemplate } from '@langchain/core/prompts';

import {
RunnableConfig,
RunnableLambda,
RunnableLike,
RunnablePassthrough,
RunnableSequence,
} from '@langchain/core/runnables';
Expand All @@ -16,12 +19,39 @@ import { ChatMistralAI } from '@langchain/mistralai';
import { ListOutputParser } from '@langchain/core/output_parsers';
import { TokenTextSplitter } from 'langchain/text_splitter';

/**
 * Formats an OpenAI-style message array into a "Human:/AI:" transcript.
 * Non-array input (already a plain string, or anything else) yields ''.
 */
const formatChatHistory = (chatHistory: string | string[]) => {
  if (Array.isArray(chatHistory)) {
    const updatedChat = chatHistory
      .map((dialogTurn: any) => {
        if (dialogTurn['role'] == 'user') {
          return `Human: ${dialogTurn['content']}`;
        }
        if (dialogTurn['role'] == 'assistant') {
          return `AI: ${dialogTurn['content']}`;
        }
        // BUG FIX: other roles (e.g. 'system') previously fell through and
        // produced the literal string "undefined" in the joined transcript.
        return undefined;
      })
      .filter((line): line is string => line !== undefined)
      .join('\n');
    Logger.log(updatedChat, 'ChatService:formatChatHistory');
    return updatedChat;
  }
  return '';
};
// Summarization strategies accepted by langchain's loadSummarizationChain.
// NOTE(review): not referenced anywhere in the visible part of this file —
// confirm it is used elsewhere before removing.
enum SummarizationType {
  map_reduce = 'map_reduce',
  stuff = 'stuff',
  refine = 'refine',
}

/** Arguments for Agent.createDynamicChain. */
interface CreateDynamicChainParams {
  // Chroma index ids used to scope retrieval.
  indexIds: string[];
  // Prompt template string handed to PromptTemplate.fromTemplate.
  // NOTE(review): typed `any` but presumably always a string — confirm callers.
  prompt: any;
  // Prompt-variable name -> value; `context` and `chat_history` are special.
  inputs: FieldMappings;
}

// NOTE(review): the index signature declares string values, but downstream
// code treats `chat_history` as a message array — the declared type looks
// too narrow; confirm before tightening it.
interface FieldMappings {
  [key: string]: string;
}

export class Agent {
// OpenAI API key. NOTE(review): its assignment is outside this chunk — verify
// where it is set before relying on it being non-empty.
private apiKey: string;

Expand Down Expand Up @@ -131,6 +161,128 @@ export class Agent {
return questions;
}

/**
 * Builds a retrieval-augmented chain whose prompt variables are supplied
 * dynamically at call time, then returns its token stream.
 *
 * Special keys in `inputs`:
 *  - `context`: a retrieval query; matching docs from the Chroma collection
 *    (filtered to `indexIds`) are serialized into the prompt's `context`.
 *  - `chat_history`: an OpenAI-style message array, formatted ONCE into a
 *    "Human:/AI:" transcript for the prompt's `chat_history` variable.
 * Every other key is passed through to the prompt verbatim.
 *
 * @param indexIds  Chroma index ids used to scope retrieval.
 * @param prompt    Prompt template string; its variables must match `inputs`.
 * @param inputs    Values for the prompt variables.
 * @returns the stream produced by `chain.stream(...)`; chunks expose
 *          `answer` (token text) and `sources` (document metadata).
 */
public async createDynamicChain({
  indexIds,
  prompt,
  inputs,
}: CreateDynamicChainParams): Promise<any> {
  const model = new ChatOpenAI({
    modelName: process.env.MODEL_CHAT as string,
    streaming: true,
    temperature: 0.0,
    maxRetries: 4,
  });

  const vectorStore = await Chroma.fromExistingCollection(
    new OpenAIEmbeddings({
      modelName: process.env.MODEL_EMBEDDING as string,
    }),
    {
      url: process.env.CHROMA_URL as string,
      collectionName: process.env.CHROMA_COLLECTION_NAME as string,
      filter: { indexId: { $in: indexIds } },
    },
  );

  const myPrompt = PromptTemplate.fromTemplate(prompt);
  const retriever = vectorStore.asRetriever();

  const baseChainSequence: { [key: string]: any } = {};

  // Add context sequence if context is defined in inputs.
  if (inputs.context) {
    baseChainSequence['context'] = RunnableSequence.from([
      {
        docs: async (input: any) => {
          const docs = await retriever.getRelevantDocuments(inputs.context);
          return docs;
        },
      },
      {
        docs: (input: any) => input.docs,
      },
    ]);
  }

  // Dynamically pass through every plain prompt variable.
  Object.keys(inputs).forEach((key) => {
    if (key !== 'context' && key !== 'chat_history') {
      baseChainSequence[key] = (input: any) => input[key];
    }
  });

  // chat_history is formatted once here; downstream steps receive the
  // already-formatted transcript string under the same key.
  if (inputs.chat_history) {
    baseChainSequence['chat_history'] = (input: any) => {
      return formatChatHistory(inputs.chat_history);
    };
  }

  // Keep only short (<= 256 chars), single-line string metadata so it can be
  // inlined into the prompt without bloating it.
  // BUG FIX: removed a leftover debug console.log of every metadata value.
  const filterMultilineMetadata = (metadata) =>
    Object.fromEntries(
      Object.entries(metadata).filter(([key, value]) => {
        const result =
          typeof value === 'string' &&
          !(value.includes('\n') || value.toString().length > 256);
        return result;
      }),
    );

  const answerChainSequence: [
    RunnableLike<any, any>,
    ...RunnableLike<any, any>[],
    RunnableLike<any, any>,
  ] = [
    baseChainSequence,
    {
      answer: RunnableSequence.from([
        {
          // Serialize retrieved docs as "<metadata json>\n<content>" lines.
          context: (input: any) => {
            if (input.context && input.context.docs) {
              const docs_with_metadata = input.context.docs.map(
                (doc: any) => {
                  const metaData = filterMultilineMetadata(doc.metadata);
                  return `${JSON.stringify(metaData)}\n${doc.pageContent}`;
                },
              );
              return docs_with_metadata.join('\n');
            }
            return '';
          },
          ...Object.keys(inputs).reduce(
            (acc, key) => {
              if (key !== 'context' && key !== 'chat_history') {
                acc[key] = (input: any) => input[key];
              }
              return acc;
            },
            {} as { [key: string]: (input: any) => any },
          ),
          ...(inputs.chat_history
            ? {
                // BUG FIX: `input.chat_history` is already the transcript
                // string produced by baseChainSequence; re-running it through
                // formatChatHistory (which returns '' for non-array input)
                // silently erased the chat history from the prompt.
                chat_history: (input: any) => input.chat_history,
              }
            : {}),
        },
        myPrompt as any,
        model,
        new StringOutputParser(),
      ]),
      sources: RunnableLambda.from((input: any) => {
        if (input.context && input.context.docs) {
          return input.context.docs.map((doc: any) => doc.metadata);
        }
        return [];
      }),
    },
  ];

  const chain = RunnableSequence.from(answerChainSequence);
  // context and chat_history are captured via closures above, so they are
  // stripped from the runtime input handed to the chain.
  const { context, chat_history, ...filteredInputs } = inputs;
  return await chain.stream(filteredInputs);
}

private async createRAGChain(
chroma_indices: string[],
model_type: string,
Expand Down
21 changes: 21 additions & 0 deletions indexer/src/chat/controller/chat.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,27 @@ import { ApiBody, ApiQuery } from '@nestjs/swagger';
export class ChatController {
// ChatService is provided by Nest's dependency-injection container.
constructor(private readonly chatService: ChatService) {}

/**
 * Streams a dynamic-chain answer for an externally-supplied prompt.
 * The response is written incrementally as an event-stream; only chunks
 * carrying an `answer` field are forwarded to the client.
 */
@ApiBody({ type: RetrievalQuestionInput })
@Post('/external')
async dynamic(@Body() body: any, @Res() res: any) {
  try {
    res.setHeader('Content-Type', 'text/event-stream');
    res.setHeader('Cache-Control', 'no-cache');
    res.setHeader('Connection', 'keep-alive');
    res.setHeader('Content-Encoding', 'none');

    Logger.log(`Processing ${JSON.stringify(body)}`, 'chatController:stream');
    const stream = await this.chatService.streamExternal(body);

    for await (const chunk of stream) {
      if (chunk.answer) {
        res.write(chunk.answer);
      }
    }

    res.end();
  } catch (e) {
    console.error(e);
    // BUG FIX: the response was previously left open on error, hanging the
    // client indefinitely. Close it, with a 500 when nothing was sent yet.
    if (!res.headersSent) {
      res.status(500);
    }
    res.end();
  }
}
@ApiBody({ type: RetrievalQuestionInput })
@Post('/stream')
async stream(@Body() body: RetrievalQuestionInput, @Res() res: any) {
Expand Down
30 changes: 24 additions & 6 deletions indexer/src/chat/service/chat.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,13 @@ import {
Inject,
Injectable,
Logger,
StreamableFile,
} from '@nestjs/common';

import { Chroma } from '@langchain/community/vectorstores/chroma';
import {
QuestionGenerationInput,
RetrievalQuestionInput,
} from '../schema/chat.schema';
import { RetrievalQuestionInput } from '../schema/chat.schema';

import { Agent } from 'src/app/modules/agent.module';
import { RunnableSequence } from '@langchain/core/runnables';
import { OpenAIEmbeddings } from '@langchain/openai';
import { pull } from 'langchain/hub';

@Injectable()
Expand Down Expand Up @@ -74,6 +69,29 @@ export class ChatService {
}
}

/**
 * @description Streams an externally-supplied dynamic-chain request
 * through the agent.
 *
 * @param body payload forwarded verbatim to Agent.createDynamicChain;
 *             expected to carry indexIds, prompt and inputs.
 * @returns the agent's streaming response.
 * @throws rethrows any agent failure after logging it.
 */
async streamExternal(body: any) {
  Logger.log(
    `Processing ${JSON.stringify(body)}`,
    'chatService:streamExternal',
  );
  try {
    return await this.agentClient.createDynamicChain(body);
  } catch (e) {
    // BUG FIX: external payloads may lack `input`, so the unguarded
    // `body.input.question` access could itself throw inside this catch
    // and mask the real error. Also corrected the log context tag, which
    // previously said 'chatService:stream:error'.
    Logger.log(
      `Cannot process ${body?.input?.question} ${e}`,
      'chatService:streamExternal:error',
    );
    throw e;
  }
}

async questions(indexIds: string[]) {
const chain =
await this.agentClient.createQuestionGenerationChain('OpenAI');
Expand Down

0 comments on commit 3758a47

Please sign in to comment.