Added Changes to prompts, chat agent, endpoints to support Multimodal I/O (#57)

* added DALL-E 3 to supported models

* Updated prompts with partials and custom output schema #43

* updated prompts for multimodality #43

* Updated chat agent class for multimodal I/O #43

* Updated defineChatEndpoint to support multimodal I/O #43

* bumped NPM package for alpha release
pranav-kural authored Jul 29, 2024
1 parent 8581cdc commit 60f116a
Showing 8 changed files with 328 additions and 263 deletions.
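For orientation, here is a hedged sketch of the request and response shapes a media-returning endpoint would exchange after this change. The field names (`query`, `chatId`, `outputSchema`, `response`, `contentType`, `url`) follow the zod schemas in the endpoints.ts diff below; the types and values themselves are illustrative assumptions, not the library's exported API.

```ts
// Illustrative request/response payloads for a media-type chat endpoint (assumed shapes).
type MediaChatRequest = {
  query: string;
  chatId?: string;
  outputSchema?: {responseType: 'text' | 'json' | 'media'};
};

type MediaChatResponse = {
  response: {contentType: string; url: string};
};

const exampleRequest: MediaChatRequest = {
  query: 'Generate an illustration of a lighthouse at dusk',
  outputSchema: {responseType: 'media'}, // per the diff, a request-level schema overrides the endpoint default
};

const exampleResponse: MediaChatResponse = {
  response: {
    contentType: 'image/png',
    url: 'data:image/png;base64,<generated-image-data>', // placeholder for the data URL or link the model returns
  },
};
```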
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "@oconva/qvikchat",
"version": "2.0.0-alpha.0",
"version": "2.0.0-alpha.1",
"repository": {
"type": "git",
"url": "https://github.com/oconva/qvikchat.git"
67 changes: 55 additions & 12 deletions src/agents/chat-agent.ts
@@ -3,17 +3,19 @@ import type {GenerateResponse} from '@genkit-ai/ai';
import {ChatHistoryStore} from '../history/chat-history-store';
import {
ModelConfig,
OutputSchemaType,
SupportedModelNames,
SupportedModels,
} from '../models/models';
import {
closeEndedSystemPrompt,
openEndedSystemPrompt,
ragSystemPrompt,
getOpenEndedSystemPrompt,
getCloseEndedSystemPrompt,
getRagSystemPrompt,
} from '../prompts/system-prompts';
import {MessageData} from '@genkit-ai/ai/model';
import {ToolArgument} from '@genkit-ai/ai/tool';
import {Dotprompt} from '@genkit-ai/dotprompt';
import {PromptOutputSchema} from '../prompts/prompts';

/**
* Represents the type of chat agent.
@@ -44,13 +46,15 @@ export type AgentTypeConfig =
* @property tools - Tools for the chat agent.
* @property model - The supported model to use for chat completion.
* @property modelConfig - The model configuration.
* @property responseOutputSchema - The output schema for the response.
*/
export type ChatAgentConfig = {
systemPrompt?: Dotprompt;
chatPrompt?: Dotprompt;
tools?: ToolArgument[];
model?: SupportedModels;
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;
} & AgentTypeConfig;

export type DefaultChatAgentConfigType = {
@@ -67,7 +71,6 @@ export const defaultChatAgentConfig: DefaultChatAgentConfigType = {
};

/**
* Represents the attributes of the chat agent.
*/
export interface ChatAgentAttributes {
@@ -77,6 +80,7 @@ export interface ChatAgentAttributes {
chatPrompt?: Dotprompt;
tools?: ToolArgument[];
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;
}

/**
@@ -117,6 +121,7 @@ export type GenerateResponseProps = {
modelConfig?: ModelConfig;
systemPrompt?: Dotprompt;
chatPrompt?: Dotprompt;
responseOutputSchema?: OutputSchemaType;
} & GenerateResponseHistoryProps;

/**
@@ -179,6 +184,7 @@ export class ChatAgent implements ChatAgentInterface {
tools?: ToolArgument[];
private modelName: string;
modelConfig?: ModelConfig;
responseOutputSchema?: OutputSchemaType;

/**
* Creates a new instance of the chat agent.
@@ -194,33 +200,38 @@ export class ChatAgent implements ChatAgentInterface {
*/
constructor(config: ChatAgentConfig = {}) {
this.agentType = config.agentType ?? defaultChatAgentConfig.agentType;
if ('topic' in config) {
this.topic = config.topic;
}
this.systemPrompt = config.systemPrompt;
this.chatPrompt = config.chatPrompt;
this.tools = config.tools;
this.modelName = config.model
? SupportedModelNames[config.model]
: SupportedModelNames[defaultChatAgentConfig.model];
this.modelConfig = config.modelConfig;
if ('topic' in config) {
this.topic = config.topic;
}
this.responseOutputSchema = config.responseOutputSchema;
}

/**
* Gets the system prompt based on the agent type.
* @param agentType - The type of agent.
* @param outputSchema - The output schema for the system prompt.
* @returns Returns the system prompt.
* @throws Throws an error if the agent type is invalid.
*/
private static getSystemPrompt(agentType?: ChatAgentType) {
private static getSystemPrompt(
agentType?: ChatAgentType,
outputSchema?: PromptOutputSchema
) {
// get the system prompt based on the agent type
switch (agentType) {
case 'open-ended':
return openEndedSystemPrompt;
return getOpenEndedSystemPrompt({outputSchema});
case 'close-ended':
return closeEndedSystemPrompt;
return getCloseEndedSystemPrompt({outputSchema});
case 'rag':
return ragSystemPrompt;
return getRagSystemPrompt({outputSchema});
default:
throw new Error('Invalid agent type');
}
@@ -300,6 +311,33 @@ export class ChatAgent implements ChatAgentInterface {
return res;
}

/**
* Method to get prompt output schema based on the response type.
* @param responseOutputSchema Response type specified by user
* @returns PromptOutputSchema object based on the response type
*/
static getPromptOutputSchema(
responseOutputSchema?: OutputSchemaType
): PromptOutputSchema {
if (!responseOutputSchema) {
return {format: 'text'};
} else if (responseOutputSchema.responseType === 'text') {
return {format: 'text'};
} else if (responseOutputSchema.responseType === 'json') {
return {
format: 'json',
schema: responseOutputSchema.schema,
jsonSchema: responseOutputSchema.jsonSchema,
};
} else if (responseOutputSchema.responseType === 'media') {
return {format: 'media'};
} else {
throw new Error(
`Invalid response type ${responseOutputSchema.responseType}`
);
}
}

/**
* Generates a response based on the given properties.
* @param query - The query string.
@@ -320,7 +358,12 @@
const prompt =
params.systemPrompt ??
this.systemPrompt ??
ChatAgent.getSystemPrompt(this.agentType);
ChatAgent.getSystemPrompt(
this.agentType,
ChatAgent.getPromptOutputSchema(
params.responseOutputSchema ?? this.responseOutputSchema // responseOutputSchema provided as argument takes priority
)
);
// if not using chat history
if (!params.enableChatHistory) {
// return response in specified format
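A minimal sketch of how the new `responseOutputSchema` option might be passed to a `ChatAgent`, and what the static `getPromptOutputSchema` helper maps it to. The import path is an assumption; the option shapes and the mapping follow the diff above.

```ts
import {z} from 'zod';
// Import path is an assumption; adjust to wherever ChatAgent is exported in your installed version.
import {ChatAgent} from '@oconva/qvikchat';

// Agent configured to produce structured JSON answers instead of plain text.
const agent = new ChatAgent({
  agentType: 'open-ended',
  responseOutputSchema: {
    responseType: 'json',
    schema: z.object({answer: z.string(), sources: z.array(z.string())}),
  },
});

// getPromptOutputSchema maps the response type onto the prompt's output format:
//   undefined or 'text'            -> {format: 'text'}
//   'json' (+ schema / jsonSchema) -> {format: 'json', schema, jsonSchema}
//   'media'                        -> {format: 'media'}
const promptSchema = ChatAgent.getPromptOutputSchema({responseType: 'media'});
console.log(promptSchema); // -> {format: 'media'}
```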
70 changes: 39 additions & 31 deletions src/endpoints/endpoints.ts
@@ -23,7 +23,12 @@ import {getDataRetriever} from '../rag/data-retrievers/data-retrievers';
import {ChatHistoryStore} from '../history/chat-history-store';
import {Dotprompt} from '@genkit-ai/dotprompt';
import {ToolArgument} from '@genkit-ai/ai/tool';
import {ModelConfig, SupportedModels} from '../models/models';
import {
ModelConfig,
OutputSchema,
OutputSchemaType,
SupportedModels,
} from '../models/models';
import {getSystemPromptText} from '../prompts/system-prompts';

type ChatHistoryParams =
@@ -87,15 +92,6 @@ type EndpointChatAgentConfig = {
modelConfig?: ModelConfig;
};

type ResponseTypeParams =
| {
responseType?: 'json' | 'text';
}
| {
responseType?: 'media';
contentType: string;
};

type VerboseDetails = {
usage: GenerationUsage;
request?: GenerateRequest;
@@ -106,8 +102,8 @@ export type DefineChatEndpointConfig = {
endpoint: string;
chatAgentConfig?: EndpointChatAgentConfig;
verbose?: boolean;
} & ResponseTypeParams &
ChatAgentTypeParams &
outputSchema?: OutputSchemaType;
} & ChatAgentTypeParams &
ChatHistoryParams &
AuthParams &
CacheParams &
@@ -117,8 +113,7 @@
* Method to define a chat endpoint using the provided chat agent and endpoint, with support for chat history.
* @param chatAgentConfig Configurations for the chat agent, like LLM model, system prompt, chat prompt, and tools.
* @param endpoint Server endpoint to which queries should be sent to run this chat flow.
* @param responseType Type of response to return. Can be "json", "text", or "media".
* @param contentType If response type is "media", provide the content type of the media response.
* @param outputSchema Output schema for the chat endpoint. Can be "text", "json" or "media". By default, the output format is text.
* @param verbose A flag to indicate whether to return a verbose response or not.
* @param agentType Type of chat agent to use for this endpoint. Can be "open-ended" or "close-ended".
* @param topic Topic for the close-ended or RAG chat agent. Required if agentType is "close-ended" or if RAG is enabled.
@@ -141,13 +136,16 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
query: z.string(),
chatId: z.string().optional(),
uid: config.enableAuth ? z.string() : z.string().optional(),
outputSchema: OutputSchema.optional(),
}),
outputSchema: z.union([
z.object({
response:
config.responseType === 'text' || config.responseType === undefined
!config.outputSchema ||
!config.outputSchema.responseType ||
config.outputSchema?.responseType === 'text'
? z.string()
: config.responseType === 'media'
: config.outputSchema.responseType === 'media'
? z.object({
contentType: z.string(),
url: z.string(),
@@ -194,12 +192,18 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
},
},
async ({query, chatId}) => {
async ({query, chatId, outputSchema}) => {
if (query === '') return {response: 'How can I help you today?'};

// set default response type
if (!config.responseType) {
config.responseType = 'text';
if (!config.outputSchema || !config.outputSchema.responseType) {
config.outputSchema = {responseType: 'text'};
}

// set output schema
// Output schema provided in the request takes precedence over output schema configured for the endpoint
if (!outputSchema) {
outputSchema = config.outputSchema;
}

// store chat agent
@@ -221,10 +225,12 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
? new ChatAgent({
agentType: 'close-ended',
topic: config.topic,
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
})
: new ChatAgent({
agentType: 'open-ended',
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
});
}
@@ -234,6 +240,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
chatAgent = new ChatAgent({
agentType: 'rag',
topic: config.topic,
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
});
}
@@ -297,7 +304,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
// also check if response type matches the expected response type
if (
cachedQuery.response &&
cachedQuery.responseType === config.responseType
cachedQuery.responseType === outputSchema.responseType
) {
// increment cache hits
config.cacheStore.incrementCacheHits(queryHash);
@@ -307,7 +314,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
let cachedModelResponse: MessageData;
// if expected response type is "text" and cached response type is "text"
if (
config.responseType === 'text' &&
outputSchema.responseType === 'text' &&
cachedQuery.responseType === 'text'
) {
cachedModelResponse = {
@@ -317,7 +324,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
// else if expected response type is "json" and cached response type is "json"
else if (
config.responseType === 'json' &&
outputSchema.responseType === 'json' &&
cachedQuery.responseType === 'json'
) {
cachedModelResponse = {
@@ -327,7 +334,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
// else if expected response type is "media" and cached response type is "media"
else if (
config.responseType === 'media' &&
outputSchema.responseType === 'media' &&
cachedQuery.responseType === 'media'
) {
cachedModelResponse = {
@@ -416,10 +423,10 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
} catch (error) {
// if query is not in cache, add it to cache to track the number of times this query is received
// sending hash is optional. Sending so hash doesn't have to be recalculated
// remeber to add the query with context
// remember to add the query with context
config.cacheStore.addQuery(
queryWithContext,
config.responseType,
outputSchema.responseType ?? 'text', // default to text
queryHash
);
}
@@ -454,10 +461,11 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
let queryWithParams: GenerateResponseProps = {
query,
context,
responseOutputSchema: outputSchema,
};

// If chat history is enabled and response type is not media
if (config.enableChatHistory && config.responseType !== 'media') {
// If chat history is enabled
if (config.enableChatHistory) {
// Prepare history props for generating response
const historyProps: GenerateResponseHistoryProps = {
enableChatHistory: config.enableChatHistory,
@@ -477,14 +485,14 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
// Not supported for media response type
if (config.enableCache && config.cacheStore && cacheThresholdReached) {
// cache response based on response type
if (config.responseType === 'json') {
if (outputSchema.responseType === 'json') {
config.cacheStore.cacheResponse(queryHash, {
responseType: 'json',
response: JSON.stringify(response.res.output()),
});
}
// if media
else if (config.responseType === 'media') {
else if (outputSchema.responseType === 'media') {
const mediaContent = response.res.media();
// if we have valid data
if (mediaContent?.contentType && mediaContent?.url) {
Expand All @@ -508,9 +516,9 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>

// return response based on response type
let res;
if (config.responseType === 'json') {
if (outputSchema.responseType === 'json') {
res = response.res.output();
} else if (config.responseType === 'media') {
} else if (outputSchema.responseType === 'media') {
const mediaContent = response.res.media();
// if we have valid data
if (mediaContent?.contentType && mediaContent?.url) {
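Finally, a hedged sketch of defining an endpoint that uses the new `outputSchema` option end to end. The entry-point import and the DALL·E 3 model key are assumptions based on the commit message; the option names follow the diff above.

```ts
// Import path and model key are assumptions; adjust to your installed QvikChat version.
import {defineChatEndpoint} from '@oconva/qvikchat';

// Endpoint that returns media, e.g. an image produced by the newly supported DALL-E 3 model.
defineChatEndpoint({
  endpoint: 'image-chat',
  chatAgentConfig: {
    model: 'dall-e-3', // assumption: the key under which DALL-E 3 was registered in SupportedModels
  },
  outputSchema: {
    responseType: 'media',
  },
});
```

Per the flow above, a request may also carry its own `outputSchema`, which takes precedence over this endpoint-level setting; when neither is provided, the response defaults to plain text.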
