Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Changes to prompts, chat agent, endpoints to support Multimodal I/O #57

Merged
merged 6 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@oconva/qvikchat",
"version": "2.0.0-alpha.0",
"version": "2.0.0-alpha.1",
"repository": {
"type": "git",
"url": "https://github.com/oconva/qvikchat.git"
Expand Down
67 changes: 55 additions & 12 deletions src/agents/chat-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@ import type {GenerateResponse} from '@genkit-ai/ai';
import {ChatHistoryStore} from '../history/chat-history-store';
import {
ModelConfig,
OutputSchemaType,
SupportedModelNames,
SupportedModels,
} from '../models/models';
import {
closeEndedSystemPrompt,
openEndedSystemPrompt,
ragSystemPrompt,
getOpenEndedSystemPrompt,
getCloseEndedSystemPrompt,
getRagSystemPrompt,
} from '../prompts/system-prompts';
import {MessageData} from '@genkit-ai/ai/model';
import {ToolArgument} from '@genkit-ai/ai/tool';
import {Dotprompt} from '@genkit-ai/dotprompt';
import {PromptOutputSchema} from '../prompts/prompts';

/**
* Represents the type of chat agent.
Expand Down Expand Up @@ -44,13 +46,15 @@ export type AgentTypeConfig =
* @property tools - Tools for the chat agent.
* @property model - The supported model to use for chat completion.
* @property modelConfig - The model configuration.
* @property responseOutputSchema - The output schema for the response.
*/
export type ChatAgentConfig = {
systemPrompt?: Dotprompt;
chatPrompt?: Dotprompt;
tools?: ToolArgument[];
model?: SupportedModels;
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;
} & AgentTypeConfig;

export type DefaultChatAgentConfigType = {
Expand All @@ -67,7 +71,6 @@ export const defaultChatAgentConfig: DefaultChatAgentConfigType = {
};

/**

* Represents the attributes of the chat agent.
*/
export interface ChatAgentAttributes {
Expand All @@ -77,6 +80,7 @@ export interface ChatAgentAttributes {
chatPrompt?: Dotprompt;
tools?: ToolArgument[];
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;
}

/**
Expand Down Expand Up @@ -117,6 +121,7 @@ export type GenerateResponseProps = {
modelConfig?: ModelConfig;
systemPrompt?: Dotprompt;
chatPrompt?: Dotprompt;
responseOutputSchema?: OutputSchemaType;
} & GenerateResponseHistoryProps;

/**
Expand Down Expand Up @@ -179,6 +184,7 @@ export class ChatAgent implements ChatAgentInterface {
tools?: ToolArgument[];
private modelName: string;
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;

/**
* Creates a new instance of the chat agent.
Expand All @@ -194,33 +200,38 @@ export class ChatAgent implements ChatAgentInterface {
*/
constructor(config: ChatAgentConfig = {}) {
this.agentType = config.agentType ?? defaultChatAgentConfig.agentType;
if ('topic' in config) {
this.topic = config.topic;
}
this.systemPrompt = config.systemPrompt;
this.chatPrompt = config.chatPrompt;
this.tools = config.tools;
this.modelName = config.model
? SupportedModelNames[config.model]
: SupportedModelNames[defaultChatAgentConfig.model];
this.modelConfig = config.modelConfig;
if ('topic' in config) {
this.topic = config.topic;
}
this.responseOutputSchema = config.responseOutputSchema;
}

/**
* Gets the system prompt based on the agent type.
* @param agentType - The type of agent.
* @param outputSchema - The output schema for the system prompt.
* @returns Returns the system prompt.
* @throws Throws an error if the agent type is invalid.
*/
private static getSystemPrompt(agentType?: ChatAgentType) {
private static getSystemPrompt(
agentType?: ChatAgentType,
outputSchema?: PromptOutputSchema
) {
// get the system prompt based on the agent type
switch (agentType) {
case 'open-ended':
return openEndedSystemPrompt;
return getOpenEndedSystemPrompt({outputSchema});
case 'close-ended':
return closeEndedSystemPrompt;
return getCloseEndedSystemPrompt({outputSchema});
case 'rag':
return ragSystemPrompt;
return getRagSystemPrompt({outputSchema});
default:
throw new Error('Invalid agent type');
}
Expand Down Expand Up @@ -300,6 +311,33 @@ export class ChatAgent implements ChatAgentInterface {
return res;
}

/**
* Method to get prompt output schema based on the response type.
* @param responseOutputSchema Response type specified by user
* @returns PromptOutputSchema object based on the response type
*/
static getPromptOutputSchema(
responseOutputSchema?: OutputSchemaType
): PromptOutputSchema {
if (!responseOutputSchema) {
return {format: 'text'};
} else if (responseOutputSchema.responseType === 'text') {
return {format: 'text'};
} else if (responseOutputSchema.responseType === 'json') {
return {
format: 'json',
schema: responseOutputSchema.schema,
jsonSchema: responseOutputSchema.jsonSchema,
};
} else if (responseOutputSchema.responseType === 'media') {
return {format: 'media'};
} else {
throw new Error(
`Invalid response type ${responseOutputSchema.responseType}`
);
}
}

/**
* Generates a response based on the given properties.
* @param query - The query string.
Expand All @@ -320,7 +358,12 @@ export class ChatAgent implements ChatAgentInterface {
const prompt =
params.systemPrompt ??
this.systemPrompt ??
ChatAgent.getSystemPrompt(this.agentType);
ChatAgent.getSystemPrompt(
this.agentType,
ChatAgent.getPromptOutputSchema(
params.responseOutputSchema ?? this.responseOutputSchema // responseOutputSchema provided as argument takes priority
)
);
// if not using chat history
if (!params.enableChatHistory) {
// return response in specified format
Expand Down
70 changes: 39 additions & 31 deletions src/endpoints/endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ import {getDataRetriever} from '../rag/data-retrievers/data-retrievers';
import {ChatHistoryStore} from '../history/chat-history-store';
import {Dotprompt} from '@genkit-ai/dotprompt';
import {ToolArgument} from '@genkit-ai/ai/tool';
import {ModelConfig, SupportedModels} from '../models/models';
import {
ModelConfig,
OutputSchema,
OutputSchemaType,
SupportedModels,
} from '../models/models';
import {getSystemPromptText} from '../prompts/system-prompts';

type ChatHistoryParams =
Expand Down Expand Up @@ -87,15 +92,6 @@ type EndpointChatAgentConfig = {
modelConfig?: ModelConfig;
};

type ResponseTypeParams =
| {
responseType?: 'json' | 'text';
}
| {
responseType?: 'media';
contentType: string;
};

type VerboseDetails = {
usage: GenerationUsage;
request?: GenerateRequest;
Expand All @@ -106,8 +102,8 @@ export type DefineChatEndpointConfig = {
endpoint: string;
chatAgentConfig?: EndpointChatAgentConfig;
verbose?: boolean;
} & ResponseTypeParams &
ChatAgentTypeParams &
outputSchema?: OutputSchemaType;
} & ChatAgentTypeParams &
ChatHistoryParams &
AuthParams &
CacheParams &
Expand All @@ -117,8 +113,7 @@ export type DefineChatEndpointConfig = {
* Method to define a chat endpoint using the provided chat agent and endpoint, with support for chat history.
* @param chatAgentConfig Configurations for the chat agent, like LLM model, system prompt, chat prompt, and tools.
* @param endpoint Server endpoint to which queries should be sent to run this chat flow.
* @param responseType Type of response to return. Can be "json", "text", or "media".
* @param contentType If response type is "media", provide the content type of the media response.
* @param outputSchema Output schema for the chat endpoint. Can be "text", "json" or "media". By default, the output format is text.
* @param verbose A flag to indicate whether to return a verbose response or not.
* @param agentType Type of chat agent to use for this endpoint. Can be "open-ended" or "close-ended".
* @param topic Topic for the close-ended or RAG chat agent. Required if agentType is "close-ended" or if RAG is enabled.
Expand All @@ -141,13 +136,16 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
query: z.string(),
chatId: z.string().optional(),
uid: config.enableAuth ? z.string() : z.string().optional(),
outputSchema: OutputSchema.optional(),
}),
outputSchema: z.union([
z.object({
response:
config.responseType === 'text' || config.responseType === undefined
!config.outputSchema ||
!config.outputSchema.responseType ||
config.outputSchema?.responseType === 'text'
? z.string()
: config.responseType === 'media'
: config.outputSchema.responseType === 'media'
? z.object({
contentType: z.string(),
url: z.string(),
Expand Down Expand Up @@ -194,12 +192,18 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
},
},
async ({query, chatId}) => {
async ({query, chatId, outputSchema}) => {
if (query === '') return {response: 'How can I help you today?'};

// set default response type
if (!config.responseType) {
config.responseType = 'text';
if (!config.outputSchema || !config.outputSchema.responseType) {
config.outputSchema = {responseType: 'text'};
}

// set output schema
// Output schema provided in the request takes precedence over output schema configured for the endpoint
if (!outputSchema) {
outputSchema = config.outputSchema;
}

// store chat agent
Expand All @@ -221,10 +225,12 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
? new ChatAgent({
agentType: 'close-ended',
topic: config.topic,
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
})
: new ChatAgent({
agentType: 'open-ended',
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
});
}
Expand All @@ -234,6 +240,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
chatAgent = new ChatAgent({
agentType: 'rag',
topic: config.topic,
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
});
}
Expand Down Expand Up @@ -297,7 +304,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
// also check if response type matches the expected response type
if (
cachedQuery.response &&
cachedQuery.responseType === config.responseType
cachedQuery.responseType === outputSchema.responseType
) {
// increment cache hits
config.cacheStore.incrementCacheHits(queryHash);
Expand All @@ -307,7 +314,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
let cachedModelResponse: MessageData;
// if expected response type is "text" and cached response type is "text"
if (
config.responseType === 'text' &&
outputSchema.responseType === 'text' &&
cachedQuery.responseType === 'text'
) {
cachedModelResponse = {
Expand All @@ -317,7 +324,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
// else if expected response type is "json" and cached response type is "json"
else if (
config.responseType === 'json' &&
outputSchema.responseType === 'json' &&
cachedQuery.responseType === 'json'
) {
cachedModelResponse = {
Expand All @@ -327,7 +334,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
// else if expected response type is "media" and cached response type is "media"
else if (
config.responseType === 'media' &&
outputSchema.responseType === 'media' &&
cachedQuery.responseType === 'media'
) {
cachedModelResponse = {
Expand Down Expand Up @@ -416,10 +423,10 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
} catch (error) {
// if query is not in cache, add it to cache to track the number of times this query is received
// sending hash is optional. Sending so hash doesn't have to be recalculated
// remeber to add the query with context
// remember to add the query with context
config.cacheStore.addQuery(
queryWithContext,
config.responseType,
outputSchema.responseType ?? 'text', // default to text
queryHash
);
}
Expand Down Expand Up @@ -454,10 +461,11 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
let queryWithParams: GenerateResponseProps = {
query,
context,
responseOutputSchema: outputSchema,
};

// If chat history is enabled and response type is not media
if (config.enableChatHistory && config.responseType !== 'media') {
// If chat history is enabled
if (config.enableChatHistory) {
// Prepare history props for generating response
const historyProps: GenerateResponseHistoryProps = {
enableChatHistory: config.enableChatHistory,
Expand All @@ -477,14 +485,14 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
// Not supported for media response type
if (config.enableCache && config.cacheStore && cacheThresholdReached) {
// cache response based on response type
if (config.responseType === 'json') {
if (outputSchema.responseType === 'json') {
config.cacheStore.cacheResponse(queryHash, {
responseType: 'json',
response: JSON.stringify(response.res.output()),
});
}
// if media
else if (config.responseType === 'media') {
else if (outputSchema.responseType === 'media') {
const mediaContent = response.res.media();
// if we have valid data
if (mediaContent?.contentType && mediaContent?.url) {
Expand All @@ -508,9 +516,9 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>

// return response based on response type
let res;
if (config.responseType === 'json') {
if (outputSchema.responseType === 'json') {
res = response.res.output();
} else if (config.responseType === 'media') {
} else if (outputSchema.responseType === 'media') {
const mediaContent = response.res.media();
// if we have valid data
if (mediaContent?.contentType && mediaContent?.url) {
Expand Down
Loading