Skip to content

Commit

Permalink
Iterate
Browse files Browse the repository at this point in the history
  • Loading branch information
serefyarar committed Jun 1, 2024
1 parent da54d2a commit a4ec36d
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 31 deletions.
2 changes: 1 addition & 1 deletion api/src/controllers/discovery.js
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export const chat = async (req, res, next) => {
});
} catch (error) {
// Handle the exception
console.error("An error occurred:", error);
console.error("An error occurred:", error.message);
res.status(500).json({ error: "Internal Server Error" });
}
};
Expand Down
9 changes: 1 addition & 8 deletions api/src/libs/indexer.js
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,7 @@ class Indexer {
const chatRequest = {
indexIds,
input: {
question: `
Determine if the following information is relevant to the previous conversation.
If it is relevant, output a conversation simulating that you received real-time news for the user.
Use conversational output format suitable to data model, use bold texts and links when available.
Do not mention relevancy check, just share it as an update.
If it is not relevant, simply say "NOT_RELEVANT".
Information: ${JSON.stringify(item)}
`,
information: JSON.stringify(item),
chat_history: messages,
},
};
Expand Down
43 changes: 30 additions & 13 deletions indexer/src/app/modules/agent.module.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { ChatOpenAI, OpenAI, OpenAIEmbeddings } from '@langchain/openai';
import { loadSummarizationChain } from 'langchain/chains';
import {
RunnableConfig,
RunnableLambda,
RunnablePassthrough,
RunnableSequence,
Expand Down Expand Up @@ -42,10 +43,11 @@ export class Agent {
indexIds: string[],
model_type: string = 'OpenAI',
model_args: any,
prompt: any,
): Promise<any> {
switch (chain_type) {
case 'rag-v0':
return this.createRAGChain(indexIds, model_type, model_args);
return this.createRAGChain(indexIds, model_type, model_args, prompt);

default:
throw new Error('Chain type not supported');
Expand Down Expand Up @@ -133,7 +135,9 @@ export class Agent {
chroma_indices: string[],
model_type: string,
model_args: any = { temperature: 0.0, max_tokens: 1000, max_retries: 4 },
prompt: any,
): Promise<any> {
console.log(prompt);
// TODO: Add prior filtering for questions such as "What is new today?" (with date filter)
// TODO: Add self-ask prompt for fact-checking
const argv = model_args ?? {
Expand Down Expand Up @@ -181,8 +185,6 @@ export class Agent {

const retriever = vectorStore.asRetriever();

const answerPrompt = await pull(process.env.PROMPT_ANSWER_TAG);

const formatChatHistory = (chatHistory: string | string[]) => {
if (Array.isArray(chatHistory)) {
const updatedChat = chatHistory
Expand All @@ -206,7 +208,9 @@ export class Agent {
context: RunnableSequence.from([
{
docs: async (input) => {
const docs = await retriever.getRelevantDocuments(input.question);
const docs = await retriever.getRelevantDocuments(
input.question || input.information,
);
return docs;
},
},
Expand All @@ -219,8 +223,10 @@ export class Agent {
},
]),
question: (input) => {
Logger.log(input.question, 'ChatService:answerChain:inputQuestion');
return input.question;
if (input.question) {
Logger.log(input.question, 'ChatService:answerChain:inputQuestion');
return input.question;
}
},
chat_history: (input) => input.chat_history,
},
Expand All @@ -235,16 +241,27 @@ export class Agent {
return serialized;
},
question: (input) => {
Logger.log(
input.question,
'ChatService:answerChain:inputQuestion',
);
return input.question;
if (input.question) {
Logger.log(
input.question,
'ChatService:answerChain:inputQuestion',
);
return input.question;
}
},
information: (input) => {
if (input.information) {
Logger.log(
input.information,
'ChatService:answerChain:inputInformation',
);
return input.inormation;
}
},
chat_history: (input) => formatChatHistory(input.chat_history),
relation: (input) => '',
chat_history: (input) => formatChatHistory(input.chat_history),
},
answerPrompt as any,
prompt as any,
model,
new StringOutputParser(),
]),
Expand Down
2 changes: 1 addition & 1 deletion indexer/src/chat/schema/chat.schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export class RetrievalQuestionInput {
chat_history: [],
},
})
input: { question: string; chat_history: [] };
input: { question?: string; information?: string; chat_history: [] };

@ApiPropertyOptional({
description: 'Model arguments',
Expand Down
28 changes: 21 additions & 7 deletions indexer/src/chat/service/chat.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
import { Agent } from 'src/app/modules/agent.module';
import { RunnableSequence } from '@langchain/core/runnables';
import { OpenAIEmbeddings } from '@langchain/openai';
import { pull } from 'langchain/hub';

@Injectable()
export class ChatService {
Expand All @@ -37,20 +38,33 @@ export class ChatService {

try {
// Initialize the agent

const answerPrompt = await pull(
body.input.question
? process.env.PROMPT_ANSWER_TAG
: process.env.PROMPT_RELEVANCY_CHECK_TAG,
);
const chain: RunnableSequence = await this.agentClient.createAgentChain(
body.chain_type,
body.indexIds,
body.model_type,
body?.model_args,
answerPrompt,
);

if (body.input.question) {
const stream = await chain.stream({
question: body.input.question,
chat_history: body.input.chat_history,
});
return stream;
} else if (body.input.information) {
const stream = await chain.stream({
information: body.input.information,
chat_history: body.input.chat_history,
});
return stream;
}
// Invoke the agent
const stream = await chain.stream({
question: body.input.question,
chat_history: body.input.chat_history,
});

return stream;
} catch (e) {
Logger.log(
`Cannot process ${body.input.question} ${e}`,
Expand Down
3 changes: 2 additions & 1 deletion web-app/src/components/site/indexes/AskIndexes/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ const AskIndexes: FC<AskIndexesProps> = ({ chatID, sources }) => {
return () => {
ws.close();
};
}, [chatID, setMessages, messages]);
// eslint-disable-next-line
}, [chatID]);
if (leftSectionIndexes.length === 0) {
return <NoIndexes tabKey={leftTabKey} />;
}
Expand Down

0 comments on commit a4ec36d

Please sign in to comment.