chore: removed generate/generateStream from executable prompts (#1239)
pavelgj authored Nov 11, 2024
1 parent 8bb8f2b commit 6836a7e
Showing 10 changed files with 60 additions and 242 deletions.
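The upshot for callers: code that previously went through `.generate()` or `.generateStream()` now invokes the executable prompt directly, or calls `.stream()`; the prompt input moves from the `input` option to the first call argument, and the remaining generate options (config, model, docs, ...) become the second. A minimal before/after sketch in TypeScript, reusing the `hi` prompt from the tests in this diff (illustrative only):

// Before this commit:
//   const response = await hi.generate({
//     input: { name: 'Genkit' },
//     config: { version: 'abc' },
//   });

// After: the prompt object is directly callable.
const response = await hi({ name: 'Genkit' }, { config: { version: 'abc' } });

// Streaming moves from .generateStream() to .stream(), same argument shape.
const { response: streamingResponse, stream } = await hi.stream({ name: 'Genkit' });
for await (const chunk of stream) {
  console.log(chunk.text); // chunks as they arrive
}
console.log((await streamingResponse).text); // full response once complete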
50 changes: 13 additions & 37 deletions js/ai/src/prompt.ts
@@ -64,15 +64,9 @@ export function isPrompt(arg: any): boolean {
 }
 
 export type PromptGenerateOptions<
-  I = undefined,
+  O extends z.ZodTypeAny = z.ZodTypeAny,
   CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
-> = Omit<
-  GenerateOptions<z.ZodTypeAny, CustomOptions>,
-  'prompt' | 'input' | 'model'
-> & {
-  model?: ModelArgument<CustomOptions>;
-  input?: I;
-};
+> = Omit<GenerateOptions<O, CustomOptions>, 'prompt'>;
 
 /**
  * A prompt that can be executed as a function.
@@ -89,51 +89,33 @@ export interface ExecutablePrompt<
    * @param opt Options for the prompt template, including user input variables and custom model configuration options.
    * @returns the model response as a promise of `GenerateStreamResponse`.
    */
-  <Out extends O>(
+  (
     input?: I,
-    opts?: PromptGenerateOptions<I, CustomOptions>
-  ): Promise<GenerateResponse<z.infer<Out>>>;
+    opts?: PromptGenerateOptions<O, CustomOptions>
+  ): Promise<GenerateResponse<z.infer<O>>>;
 
   /**
    * Generates a response by rendering the prompt template with given user input and then calling the model.
    * @param input Prompt inputs.
    * @param opt Options for the prompt template, including user input variables and custom model configuration options.
    * @returns the model response as a promise of `GenerateStreamResponse`.
    */
-  stream<Out extends O>(
+  stream(
     input?: I,
-    opts?: PromptGenerateOptions<I, CustomOptions>
-  ): Promise<GenerateStreamResponse<z.infer<Out>>>;
-
-  /**
-   * Generates a response by rendering the prompt template with given user input and additional generate options and then calling the model.
-   *
-   * @param opt Options for the prompt template, including user input variables and custom model configuration options.
-   * @returns the model response as a promise of `GenerateResponse`.
-   */
-  generate<Out extends O>(
-    opt: PromptGenerateOptions<I, CustomOptions>
-  ): Promise<GenerateResponse<z.infer<Out>>>;
-
-  /**
-   * Generates a streaming response by rendering the prompt template with given user input and additional generate options and then calling the model.
-   *
-   * @param opt Options for the prompt template, including user input variables and custom model configuration options.
-   * @returns the model response as a promise of `GenerateStreamResponse`.
-   */
-  generateStream<Out extends O>(
-    opt: PromptGenerateOptions<I, CustomOptions>
-  ): Promise<GenerateStreamResponse<z.infer<Out>>>;
+    opts?: PromptGenerateOptions<O, CustomOptions>
+  ): Promise<GenerateStreamResponse<z.infer<O>>>;
 
   /**
    * Renders the prompt template based on user input.
    *
    * @param opt Options for the prompt template, including user input variables and custom model configuration options.
    * @returns a `GenerateOptions` object to be used with the `generate()` function from @genkit-ai/ai.
    */
-  render<Out extends O>(
-    opt: PromptGenerateOptions<I, CustomOptions>
-  ): Promise<GenerateOptions<CustomOptions, Out>>;
+  render(
+    opt: PromptGenerateOptions<O, CustomOptions> & {
+      input?: I;
+    }
+  ): Promise<GenerateOptions<O, CustomOptions>>;
 
   /**
    * Returns the prompt usable as a tool.
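Note the typing consequence: with the per-call `<Out extends O>` type parameter gone, the response type is pinned to the prompt's output schema `O` at definition or lookup time rather than per invocation. A hypothetical sketch (RecipeSchema and the input shape are assumptions; the typed-lookup form appears in the js/testapps/prompt-file change below):

// Output typing flows from the prompt handle itself:
const recipePrompt = ai.prompt<any, typeof RecipeSchema>('recipe');
const result = await recipePrompt({ food: 'pizza' });
const recipe = result.output; // inferred via z.infer<typeof RecipeSchema> (may be null if no output was produced)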
34 changes: 8 additions & 26 deletions js/genkit/src/genkit.ts
@@ -438,7 +438,7 @@ export class Genkit {
   ): ExecutablePrompt<I, O, CustomOptions> {
     const executablePrompt = async (
       input?: z.infer<I>,
-      opts?: PromptGenerateOptions<I, CustomOptions>
+      opts?: PromptGenerateOptions<O, CustomOptions>
     ): Promise<GenerateResponse> => {
       const renderedOpts = await (
         executablePrompt as ExecutablePrompt<I, O, CustomOptions>
@@ -460,29 +460,11 @@
       });
       return this.generateStream(renderedOpts);
     };
-    (executablePrompt as ExecutablePrompt<I, O, CustomOptions>).generate =
-      async (
-        opt: PromptGenerateOptions<I, CustomOptions>
-      ): Promise<GenerateResponse<O>> => {
-        const renderedOpts = await (
-          executablePrompt as ExecutablePrompt<I, O, CustomOptions>
-        ).render(opt);
-        return this.generate(renderedOpts);
-      };
-    (executablePrompt as ExecutablePrompt<I, O, CustomOptions>).generateStream =
-      async (
-        opt: PromptGenerateOptions<I, CustomOptions>
-      ): Promise<GenerateStreamResponse<O>> => {
-        const renderedOpts = await (
-          executablePrompt as ExecutablePrompt<I, O, CustomOptions>
-        ).render(opt);
-        return this.generateStream(renderedOpts);
-      };
-    (executablePrompt as ExecutablePrompt<I, O, CustomOptions>).render = async <
-      Out extends O,
-    >(
-      opt: PromptGenerateOptions<I, CustomOptions>
-    ): Promise<GenerateOptions<CustomOptions, Out>> => {
+    (executablePrompt as ExecutablePrompt<I, O, CustomOptions>).render = async (
+      opt: PromptGenerateOptions<O, CustomOptions> & {
+        input?: I;
+      }
+    ): Promise<GenerateOptions<O, CustomOptions>> => {
       let model: ModelAction | undefined;
       options = await options;
       try {
@@ -509,8 +491,8 @@
           ...opt.config,
         },
         model,
-      } as GenerateOptions<CustomOptions, Out>;
-      delete (resultOptions as PromptGenerateOptions<I, CustomOptions>).input;
+      } as GenerateOptions<O, CustomOptions>;
+      delete (resultOptions as any).input;
       return resultOptions;
     };
     (executablePrompt as ExecutablePrompt<I, O, CustomOptions>).asTool =
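With `generate()` and `generateStream()` removed, `render()` stays on as the escape hatch for custom generation: it now accepts the prompt input under an explicit `input` key (deleted from the options before they are returned) and resolves to a `GenerateOptions` that can be passed straight to `ai.generate()`. A sketch under those assumptions, again using the illustrative `hi` prompt:

// Render to plain GenerateOptions, then drive generation manually:
const renderedOpts = await hi.render({
  input: { name: 'Genkit' },
  config: { temperature: 11 },
});
const rendered = await ai.generate(renderedOpts);
console.log(rendered.text);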
123 changes: 0 additions & 123 deletions js/genkit/tests/prompts_test.ts
@@ -77,32 +77,6 @@ describe('definePrompt - dotprompt', () => {
     );
   });
 
-  it('calls dotprompt with .generate', async () => {
-    const hi = ai.definePrompt(
-      {
-        name: 'hi',
-        input: {
-          schema: z.object({
-            name: z.string(),
-          }),
-        },
-        config: {
-          temperature: 11,
-        },
-      },
-      'hi {{ name }}'
-    );
-
-    const response = await hi.generate({
-      input: { name: 'Genkit' },
-      config: { version: 'abc' },
-    });
-    assert.strictEqual(
-      response.text,
-      'Echo: hi Genkit; config: {"version":"abc","temperature":11}'
-    );
-  });
-
   it('calls dotprompt with default model via retrieved prompt', async () => {
     ai.definePrompt(
       {
@@ -215,39 +189,6 @@
     assert.deepStrictEqual(chunks, ['3', '2', '1']);
   });
 
-  it('streams dotprompt .generateStream', async () => {
-    const hi = ai.definePrompt(
-      {
-        name: 'hi',
-        input: {
-          schema: z.object({
-            name: z.string(),
-          }),
-        },
-        config: {
-          temperature: 11,
-        },
-      },
-      'hi {{ name }}'
-    );
-
-    const { response, stream } = await hi.generateStream({
-      input: { name: 'Genkit' },
-      config: { version: 'abc' },
-    });
-    const chunks: string[] = [];
-    for await (const chunk of stream) {
-      chunks.push(chunk.text);
-    }
-    const responseText = (await response).text;
-
-    assert.strictEqual(
-      responseText,
-      'Echo: hi Genkit; config: {"version":"abc","temperature":11}'
-    );
-    assert.deepStrictEqual(chunks, ['3', '2', '1']);
-  });
-
   it('calls dotprompt with default model via retrieved prompt', async () => {
     ai.definePrompt(
       {
@@ -621,70 +562,6 @@ describe('definePrompt', () => {
       'Echo: hi Genkit; config: {"version":"abc","temperature":11}'
     );
   });
-
-  it('works with .generate', async () => {
-    const hi = ai.definePrompt(
-      {
-        name: 'hi',
-        model: 'echoModel',
-        input: {
-          schema: z.object({
-            name: z.string(),
-          }),
-        },
-      },
-      async (input) => {
-        return {
-          messages: [
-            { role: 'user', content: [{ text: `hi ${input.name}` }] },
-          ],
-        };
-      }
-    );
-
-    const response = await hi.generate({ input: { name: 'Genkit' } });
-    assert.strictEqual(response.text, 'Echo: hi Genkit; config: {}');
-  });
-
-  it('streams dotprompt with .generateStream', async () => {
-    const hi = ai.definePrompt(
-      {
-        name: 'hi',
-        input: {
-          schema: z.object({
-            name: z.string(),
-          }),
-        },
-        config: {
-          temperature: 11,
-        },
-      },
-      async (input) => {
-        return {
-          messages: [
-            { role: 'user', content: [{ text: `hi ${input.name}` }] },
-          ],
-        };
-      }
-    );
-
-    const { response, stream } = await hi.generateStream({
-      model: 'echoModel',
-      input: { name: 'Genkit' },
-      config: { version: 'abc' },
-    });
-    const chunks: string[] = [];
-    for await (const chunk of stream) {
-      chunks.push(chunk.text);
-    }
-    const responseText = (await response).text;
-
-    assert.strictEqual(
-      responseText,
-      'Echo: hi Genkit; config: {"version":"abc","temperature":11}'
-    );
-    assert.deepStrictEqual(chunks, ['3', '2', '1']);
-  });
 });
 
 describe('render', () => {
10 changes: 6 additions & 4 deletions js/testapps/flow-simple-ai/src/index.ts
@@ -338,10 +338,12 @@ export const dotpromptContext = ai.defineFlow(
       },
     ];
 
-    const result = await ai.prompt('dotpromptContext').generate({
-      input: { question: question },
-      docs,
-    });
+    const result = await ai.prompt('dotpromptContext')(
+      { question: question },
+      {
+        docs,
+      }
+    );
     return result.output as any;
   }
 );
10 changes: 3 additions & 7 deletions js/testapps/menu/src/02/flows.ts
@@ -27,12 +27,8 @@ export const s02_menuQuestionFlow = ai.defineFlow(
     outputSchema: AnswerOutputSchema,
   },
   async (input) => {
-    return s02_dataMenuPrompt
-      .generate({
-        input: { question: input.question },
-      })
-      .then((response) => {
-        return { answer: response.text };
-      });
+    return s02_dataMenuPrompt({ question: input.question }).then((response) => {
+      return { answer: response.text };
+    });
   }
 );
8 changes: 3 additions & 5 deletions js/testapps/menu/src/04/flows.ts
@@ -74,11 +74,9 @@ export const s04_ragMenuQuestionFlow = ai.defineFlow(
     );
 
     // Generate the response
-    const response = await s04_ragDataMenuPrompt.generate({
-      input: {
-        menuData: menuData,
-        question: input.question,
-      },
+    const response = await s04_ragDataMenuPrompt({
+      menuData: menuData,
+      question: input.question,
     });
     return { answer: response.text };
   }
14 changes: 5 additions & 9 deletions js/testapps/menu/src/05/flows.ts
@@ -38,10 +38,8 @@
   },
   async (unused) => {
     const imageDataUrl = await inlineDataUrl('menu.jpeg', 'image/jpeg');
-    const response = await s05_readMenuPrompt.generate({
-      input: {
-        imageUrl: imageDataUrl,
-      },
+    const response = await s05_readMenuPrompt({
+      imageUrl: imageDataUrl,
     });
     return { menuText: response.text };
   }
@@ -57,11 +55,9 @@
     outputSchema: AnswerOutputSchema,
   },
   async (input) => {
-    const response = await s05_textMenuPrompt.generate({
-      input: {
-        menuText: input.menuText,
-        question: input.question,
-      },
+    const response = await s05_textMenuPrompt({
+      menuText: input.menuText,
+      question: input.question,
     });
     return { answer: response.text };
   }
13 changes: 6 additions & 7 deletions js/testapps/prompt-file/src/index.ts
@@ -60,8 +60,7 @@ ai.defineFlow(
     outputSchema: RecipeSchema,
   },
   async (input) =>
-    (await ai.prompt('recipe').generate<typeof RecipeSchema>({ input: input }))
-      .output!
+    (await ai.prompt<any, typeof RecipeSchema>('recipe')(input)).output!
 );
 
 ai.defineFlow(
@@ -73,8 +72,7 @@
     outputSchema: z.any(),
   },
   async (input) =>
-    (await ai.prompt('recipe', { variant: 'robot' }).generate({ input: input }))
-      .output
+    (await ai.prompt('recipe', { variant: 'robot' })(input)).output
 );
 
 // A variation that supports streaming, optionally
@@ -92,15 +90,16 @@ ai.defineStreamingFlow(
   async ({ subject, personality }, streamingCallback) => {
     const storyPrompt = ai.prompt('story');
     if (streamingCallback) {
-      const { response, stream } = await storyPrompt.generateStream({
-        input: { subject, personality },
+      const { response, stream } = await storyPrompt.stream({
+        subject,
+        personality,
       });
       for await (const chunk of stream) {
         streamingCallback(chunk.content[0]?.text!);
       }
       return (await response).text;
     } else {
-      const response = await storyPrompt.generate({ input: { subject } });
+      const response = await storyPrompt({ subject });
       return response.text;
     }
   }