Updated defineChatEndpoint to support multimodal I/O #43
pranav-kural committed Jul 29, 2024
1 parent 679057e commit 479ce37
Showing 2 changed files with 63 additions and 31 deletions.
70 changes: 39 additions & 31 deletions src/endpoints/endpoints.ts
@@ -23,7 +23,12 @@ import {getDataRetriever} from '../rag/data-retrievers/data-retrievers';
import {ChatHistoryStore} from '../history/chat-history-store';
import {Dotprompt} from '@genkit-ai/dotprompt';
import {ToolArgument} from '@genkit-ai/ai/tool';
import {ModelConfig, SupportedModels} from '../models/models';
import {
ModelConfig,
OutputSchema,
OutputSchemaType,
SupportedModels,
} from '../models/models';
import {getSystemPromptText} from '../prompts/system-prompts';

type ChatHistoryParams =
@@ -87,15 +92,6 @@ type EndpointChatAgentConfig = {
modelConfig?: ModelConfig;
};

type ResponseTypeParams =
| {
responseType?: 'json' | 'text';
}
| {
responseType?: 'media';
contentType: string;
};

type VerboseDetails = {
usage: GenerationUsage;
request?: GenerateRequest;
@@ -106,8 +102,8 @@ export type DefineChatEndpointConfig = {
endpoint: string;
chatAgentConfig?: EndpointChatAgentConfig;
verbose?: boolean;
} & ResponseTypeParams &
ChatAgentTypeParams &
outputSchema?: OutputSchemaType;
} & ChatAgentTypeParams &
ChatHistoryParams &
AuthParams &
CacheParams &
@@ -117,8 +113,7 @@
* Method to define a chat endpoint using the provided chat agent and endpoint, with support for chat history.
* @param chatAgentConfig Configurations for the chat agent, like LLM model, system prompt, chat prompt, and tools.
* @param endpoint Server endpoint to which queries should be sent to run this chat flow.
* @param responseType Type of response to return. Can be "json", "text", or "media".
* @param contentType If response type is "media", provide the content type of the media response.
* @param outputSchema Output schema for the chat endpoint. Can be "text", "json" or "media". By default, the output format is text.
* @param verbose A flag to indicate whether to return a verbose response or not.
* @param agentType Type of chat agent to use for this endpoint. Can be "open-ended" or "close-ended".
* @param topic Topic for the close-ended or RAG chat agent. Required if agentType is "close-ended" or if RAG is enabled.
@@ -141,13 +136,16 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
query: z.string(),
chatId: z.string().optional(),
uid: config.enableAuth ? z.string() : z.string().optional(),
outputSchema: OutputSchema.optional(),
}),
outputSchema: z.union([
z.object({
response:
config.responseType === 'text' || config.responseType === undefined
!config.outputSchema ||
!config.outputSchema.responseType ||
config.outputSchema?.responseType === 'text'
? z.string()
: config.responseType === 'media'
: config.outputSchema.responseType === 'media'
? z.object({
contentType: z.string(),
url: z.string(),
@@ -194,12 +192,18 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
},
},
async ({query, chatId}) => {
async ({query, chatId, outputSchema}) => {
if (query === '') return {response: 'How can I help you today?'};

// set default response type
if (!config.responseType) {
config.responseType = 'text';
if (!config.outputSchema || !config.outputSchema.responseType) {
config.outputSchema = {responseType: 'text'};
}

// set output schema
// Output schema provided in the request takes precedence over output schema configured for the endpoint
if (!outputSchema) {
outputSchema = config.outputSchema;
}

// store chat agent
@@ -221,10 +225,12 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
? new ChatAgent({
agentType: 'close-ended',
topic: config.topic,
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
})
: new ChatAgent({
agentType: 'open-ended',
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
});
}
@@ -234,6 +240,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
chatAgent = new ChatAgent({
agentType: 'rag',
topic: config.topic,
responseOutputSchema: outputSchema,
...config.chatAgentConfig,
});
}
@@ -297,7 +304,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
// also check if response type matches the expected response type
if (
cachedQuery.response &&
cachedQuery.responseType === config.responseType
cachedQuery.responseType === outputSchema.responseType
) {
// increment cache hits
config.cacheStore.incrementCacheHits(queryHash);
@@ -307,7 +314,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
let cachedModelResponse: MessageData;
// if expected response type is "text" and cached response type is "text"
if (
config.responseType === 'text' &&
outputSchema.responseType === 'text' &&
cachedQuery.responseType === 'text'
) {
cachedModelResponse = {
@@ -317,7 +324,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
// else if expected response type is "json" and cached response type is "json"
else if (
config.responseType === 'json' &&
outputSchema.responseType === 'json' &&
cachedQuery.responseType === 'json'
) {
cachedModelResponse = {
Expand All @@ -327,7 +334,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
}
// else if expected response type is "media" and cached response type is "media"
else if (
config.responseType === 'media' &&
outputSchema.responseType === 'media' &&
cachedQuery.responseType === 'media'
) {
cachedModelResponse = {
@@ -416,10 +423,10 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
} catch (error) {
// if query is not in cache, add it to cache to track the number of times this query is received
// sending hash is optional. Sending so hash doesn't have to be recalculated
// remeber to add the query with context
// remember to add the query with context
config.cacheStore.addQuery(
queryWithContext,
config.responseType,
outputSchema.responseType ?? 'text', // default to text
queryHash
);
}
@@ -454,10 +461,11 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
let queryWithParams: GenerateResponseProps = {
query,
context,
responseOutputSchema: outputSchema,
};

// If chat history is enabled and response type is not media
if (config.enableChatHistory && config.responseType !== 'media') {
// If chat history is enabled
if (config.enableChatHistory) {
// Prepare history props for generating response
const historyProps: GenerateResponseHistoryProps = {
enableChatHistory: config.enableChatHistory,
@@ -477,14 +485,14 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>
// Not supported for media response type
if (config.enableCache && config.cacheStore && cacheThresholdReached) {
// cache response based on response type
if (config.responseType === 'json') {
if (outputSchema.responseType === 'json') {
config.cacheStore.cacheResponse(queryHash, {
responseType: 'json',
response: JSON.stringify(response.res.output()),
});
}
// if media
else if (config.responseType === 'media') {
else if (outputSchema.responseType === 'media') {
const mediaContent = response.res.media();
// if we have valid data
if (mediaContent?.contentType && mediaContent?.url) {
@@ -508,9 +516,9 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) =>

// return response based on response type
let res;
if (config.responseType === 'json') {
if (outputSchema.responseType === 'json') {
res = response.res.output();
} else if (config.responseType === 'media') {
} else if (outputSchema.responseType === 'media') {
const mediaContent = response.res.media();
// if we have valid data
if (mediaContent?.contentType && mediaContent?.url) {
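
For context (not part of the diff): a minimal usage sketch of the new outputSchema endpoint option introduced above. The relative import path, the agentType field, and the assumption that an open-ended endpoint needs no further required config are inferred from the doc comments in this file and may not match the package's public API exactly.

// Illustrative sketch only — not part of this commit.
import {z} from 'zod';
import {defineChatEndpoint} from './endpoints/endpoints'; // assumed import path

// Default: omitting outputSchema behaves like {responseType: 'text'}.
defineChatEndpoint({
  endpoint: 'chat',
  agentType: 'open-ended',
});

// Structured output: an optional schema may constrain the JSON response.
defineChatEndpoint({
  endpoint: 'chat-json',
  agentType: 'open-ended',
  outputSchema: {
    responseType: 'json',
    schema: z.object({answer: z.string()}),
  },
});

// Multimodal output: contentType is required when responseType is 'media'.
defineChatEndpoint({
  endpoint: 'chat-image',
  agentType: 'open-ended',
  outputSchema: {
    responseType: 'media',
    contentType: 'image/png',
  },
});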
24 changes: 24 additions & 0 deletions src/models/models.ts
@@ -12,6 +12,7 @@ import {
gpt4Vision,
dallE3,
} from 'genkitx-openai';
import {z} from 'zod';

/**
* Names of supported models.
@@ -62,3 +63,26 @@ export type ModelConfig = {
}[]
| undefined;
};

/**
* Output schema for model responses.
*/
export const OutputSchema = z.union([
z.object({
responseType: z.literal('text').optional(),
}),
z.object({
responseType: z.literal('json').optional(),
schema: z.any().optional(),
jsonSchema: z.any().optional(),
}),
z.object({
responseType: z.literal('media').optional(),
contentType: z.string(),
}),
]);

/**
* Possible output schemas for model responses.
*/
export type OutputSchemaType = z.infer<typeof OutputSchema>;
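
Also for illustration (not in the diff): how the OutputSchema union above validates the three response-type variants. The relative import path is an assumption.

// Illustrative sketch only — validating request-level output schemas
// against the OutputSchema union defined above.
import {OutputSchema, type OutputSchemaType} from './models/models'; // assumed import path

// Media responses must specify a contentType.
const media: OutputSchemaType = OutputSchema.parse({
  responseType: 'media',
  contentType: 'image/png',
});

// JSON responses may optionally carry a schema or jsonSchema describing the output shape.
const json: OutputSchemaType = OutputSchema.parse({
  responseType: 'json',
  jsonSchema: {type: 'object', properties: {answer: {type: 'string'}}},
});

// Omitting responseType falls back to the text variant.
const text: OutputSchemaType = OutputSchema.parse({});

console.log(media, json, text);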
