Commit

iterate
serefyarar committed Dec 2, 2024
1 parent 3acb361 commit 6fc4037
Showing 3 changed files with 18 additions and 17 deletions.
api/src/controllers/discovery.js (15 changes: 6 additions & 9 deletions)

@@ -35,32 +35,27 @@ export const search = async (req, res, next) => {
 export const completions = async (req, res, next) => {
   const definition = req.app.get("runtimeDefinition");
 
-  const { messages, sources, timeFilter, stream = true, schema } = req.body;
+  const { messages, sources, timeFilter, stream = true, prompt, schema } = req.body;
 
   try {
     const didService = new DIDService(definition);
     const reqIndexIds = await flattenSources(sources, didService);
 
     // Only set SSE headers if streaming is enabled
     if (stream) {
       res.setHeader('Content-Type', 'text/event-stream');
       res.setHeader('Cache-Control', 'no-cache');
       res.setHeader('Connection', 'keep-alive');
       res.setHeader('Content-Encoding', 'none');
     }
 
-    const completionsPrompt = await hub.pull("v2_completions");
-    const completionsPromptText = completionsPrompt.promptMessages[0].prompt.template;
 
     const response = await handleCompletions({
-      messages: [{
-        role: 'system',
-        content: completionsPromptText
-      }, ...messages],
+      messages,
       indexIds: reqIndexIds,
       stream,
       timeFilter,
       schema,
+      prompt,
     });
 
     // Handle streaming response

@@ -78,7 +73,9 @@ export const completions = async (req, res, next) => {
     }
   } catch (error) {
     console.error("An error occurred:", error);
-    res.status(500).json({ error: "Internal Server Error" });
+    if (!res.headersSent) {
+      res.status(500).json({ error: "Internal Server Error" });
+    }
   }
 };
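With this change, the controller forwards the caller's `messages` untouched and passes the optional `prompt` name through to `handleCompletions`, so a request can pick a different LangChain hub prompt without a server change. A minimal sketch of such a request, assuming a `/discovery/completions` route and illustrative field values (neither the path nor the host appears in this commit):

```js
// Hypothetical client call; the endpoint path, host, and values are illustrative, not taken from this commit.
const res = await fetch("https://api.example.com/discovery/completions", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    messages: [{ role: "user", content: "Summarize the latest items" }], // forwarded as-is to handleCompletions
    sources: ["did:pkh:example"],          // flattened into index IDs via flattenSources
    prompt: "my-team/custom-completions",  // optional; handleCompletions falls back to "v2_completions"
    stream: false,                         // false skips the SSE headers and returns a single JSON body
  }),
});
const data = await res.json();
```

When `stream` is true, the SSE headers are flushed before the model call, which is why the error handler now checks `res.headersSent` before attempting to send a JSON 500.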
api/src/language/completions.js (19 changes: 11 additions & 8 deletions)

@@ -6,6 +6,7 @@ import { searchItems } from "./search_item.js";
 import { zodResponseFormat } from "openai/helpers/zod";
 import { jsonSchemaToZod } from "json-schema-to-zod";
 import tiktoken from 'tiktoken';
+import * as hub from "langchain/hub";
 import { getModelInfo } from '../utils/mode.js';
 
 const openai = wrapOpenAI(new OpenAI({

@@ -59,7 +60,7 @@ const getDocText = (doc, metadata, runtimeDefinition) => {
   return JSON.stringify(doc);
 };
 
-export const handleCompletions = traceable(async ({ messages, indexIds, maxDocs=500, stream, schema, timeFilter }) => {
+export const handleCompletions = traceable(async ({ messages, indexIds, maxDocs=500, stream, prompt, schema, timeFilter }) => {
   console.time('handleCompletions:total');
   const MAX_TOKENS = 100000;
   let totalTokens = 0;

@@ -94,17 +95,19 @@ export const handleCompletions = traceable(async ({ messages, indexIds, maxDocs=
 
   console.log('totalTokens', totalTokens)
 
+  const completionsPrompt = await hub.pull(prompt || "v2_completions");
+  const completionsPromptText = completionsPrompt.promptMessages[0].prompt.template;
 
-  if (retrievedDocs) {
-    messages.push({
-      role: 'system',
-      content: `Context information:\n${retrievedDocs}`
-    });
-  }
 
   const completionOptions = {
     model: process.env.MODEL_CHAT,
-    messages,
+    messages: [ {
+      role: 'system',
+      content: completionsPromptText
+    },{
+      role: 'system',
+      content: `Context information:\n${retrievedDocs || 'No context found'}`
+    }, ...messages],
     temperature: 0,
     stream: stream,
   };
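Pulling the hunks above together: the prompt template is now resolved inside `handleCompletions` (using the caller-supplied `prompt` or the `v2_completions` default) and is prepended, along with the retrieval context, ahead of the caller's messages, instead of being pushed into the `messages` array by the controller. A condensed sketch of that assembly, with retrieval, token counting, and tracing omitted:

```js
// Condensed from the diff above; `prompt`, `messages`, and `retrievedDocs` come from the surrounding function.
const completionsPrompt = await hub.pull(prompt || "v2_completions");
const completionsPromptText = completionsPrompt.promptMessages[0].prompt.template;

const completionOptions = {
  model: process.env.MODEL_CHAT,
  messages: [
    { role: 'system', content: completionsPromptText },                                          // hub prompt, default or caller-supplied
    { role: 'system', content: `Context information:\n${retrievedDocs || 'No context found'}` }, // retrieved docs, or an explicit fallback
    ...messages,                                                                                 // caller's conversation, passed through unchanged
  ],
  temperature: 0,
  stream,
};
```

One effect of this move is that the caller's `messages` array is no longer mutated in place by a `messages.push(...)` when context is retrieved.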
api/src/packages/api.js (1 change: 1 addition & 0 deletions)

@@ -357,6 +357,7 @@ app.post(
   validator.body(
     Joi.object({
       messages: Joi.array().items(Joi.any()).min(1).required(),
+      prompt: Joi.string().optional(),
      sources: Joi.array().items(Joi.string()).min(1).required(),
       timeFilter: Joi.object({
         from: Joi.date().iso().optional(),
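For a quick sanity check of the new field, the snippet below exercises `prompt` against just the part of the schema visible in this hunk (the remaining fields, such as `timeFilter`, are elided here):

```js
// Validation check limited to the fields visible in the hunk; the full schema has more fields.
import Joi from "joi";

const body = Joi.object({
  messages: Joi.array().items(Joi.any()).min(1).required(),
  prompt: Joi.string().optional(),
  sources: Joi.array().items(Joi.string()).min(1).required(),
});

console.log(body.validate({ messages: [{}], sources: ["idx"] }).error);                           // undefined: prompt may be omitted
console.log(body.validate({ messages: [{}], sources: ["idx"], prompt: "v2_completions" }).error); // undefined: a string prompt passes
console.log(body.validate({ messages: [{}], sources: ["idx"], prompt: 42 }).error.message);       // '"prompt" must be a string'
```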
