From 97ca6a73a6d4bf8e876b0b5b7f0c64d64c4b3dab Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 21 Nov 2024 21:14:22 -0500 Subject: [PATCH] feat: added a way to run action with telemetry, no more 'dev-run-action-wrapper' (#1357) --- js/core/src/action.ts | 80 ++++++++++++++++++++++++++++-------- js/core/src/reflection.ts | 53 +++++++----------------- js/core/tests/action_test.ts | 32 +++++++++++++++ js/genkit/src/index.ts | 1 - 4 files changed, 110 insertions(+), 56 deletions(-) diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 382b80d14..e8360549a 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -28,6 +28,9 @@ import { export { Status, StatusCodes, StatusSchema } from './statusTypes.js'; export { JSONSchema7 }; +/** + * Action metadata. + */ export interface ActionMetadata< I extends z.ZodTypeAny, O extends z.ZodTypeAny, @@ -43,16 +46,32 @@ export interface ActionMetadata< metadata?: M; } +/** + * Results of an action run. Includes telemetry. + */ +export interface ActionResult { + result: O; + telemetry: { + traceId: string; + spanId: string; + }; +} + +/** + * Self-describing, validating, observable, locally and remotely callable function. + */ export type Action< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, M extends Record = Record, > = ((input: z.infer) => Promise>) & { __action: ActionMetadata; + run(input: z.infer): Promise>>; }; -export type SideChannelData = Record; - +/** + * Action factory params. + */ type ActionParams< I extends z.ZodTypeAny, O extends z.ZodTypeAny, @@ -73,10 +92,16 @@ type ActionParams< use?: Middleware, z.infer>[]; }; +/** + * Middleware function for actions. + */ export interface Middleware { (req: I, next: (req?: I) => Promise): Promise; } +/** + * Creates an action with provided middleware. + */ export function actionWithMiddleware< I extends z.ZodTypeAny, O extends z.ZodTypeAny, @@ -86,10 +111,17 @@ export function actionWithMiddleware< middleware: Middleware, z.infer>[] ): Action { const wrapped = (async (req: z.infer) => { + return (await wrapped.run(req)).result; + }) as Action; + wrapped.__action = action.__action; + wrapped.run = async (req: z.infer): Promise>> => { + let telemetry; const dispatch = async (index: number, req: z.infer) => { if (index === middleware.length) { // end of the chain, call the original model action - return await action(req); + const result = await action.run(req); + telemetry = result.telemetry; + return result.result; } const currentMiddleware = middleware[index]; @@ -98,9 +130,8 @@ export function actionWithMiddleware< ); }; - return await dispatch(0, req); - }) as Action; - wrapped.__action = action.__action; + return { result: await dispatch(0, req), telemetry }; + }; return wrapped; } @@ -120,10 +151,26 @@ export function action< ? config.name : `${config.name.pluginId}/${config.name.actionId}`; const actionFn = async (input: I) => { + return (await actionFn.run(input)).result; + }; + actionFn.__action = { + name: actionName, + description: config.description, + inputSchema: config.inputSchema, + inputJsonSchema: config.inputJsonSchema, + outputSchema: config.outputSchema, + outputJsonSchema: config.outputJsonSchema, + metadata: config.metadata, + } as ActionMetadata; + actionFn.run = async ( + input: z.infer + ): Promise>> => { input = parseSchema(input, { schema: config.inputSchema, jsonSchema: config.inputJsonSchema, }); + let traceId; + let spanId; let output = await newTrace( { name: actionName, @@ -131,7 +178,9 @@ export function action< [SPAN_TYPE_ATTR]: 'action', }, }, - async (metadata) => { + async (metadata, span) => { + traceId = span.spanContext().traceId; + spanId = span.spanContext().spanId; metadata.name = actionName; metadata.input = input; @@ -145,17 +194,14 @@ export function action< schema: config.outputSchema, jsonSchema: config.outputJsonSchema, }); - return output; + return { + result: output, + telemetry: { + traceId, + spanId, + }, + }; }; - actionFn.__action = { - name: actionName, - description: config.description, - inputSchema: config.inputSchema, - inputJsonSchema: config.inputJsonSchema, - outputSchema: config.outputSchema, - outputJsonSchema: config.outputJsonSchema, - metadata: config.metadata, - } as ActionMetadata; if (config.use) { return actionWithMiddleware(actionFn, config.use); diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index 135ca60e3..4ebf30a69 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -25,12 +25,7 @@ import { GENKIT_VERSION } from './index.js'; import { logger } from './logging.js'; import { Registry } from './registry.js'; import { toJsonSchema } from './schema.js'; -import { - flushTracing, - newTrace, - setCustomMetadataAttribute, - setTelemetryServerUrl, -} from './tracing.js'; +import { flushTracing, setTelemetryServerUrl } from './tracing.js'; // TODO: Move this to common location for schemas. export const RunActionResponseSchema = z.object({ @@ -169,48 +164,30 @@ export class ReflectionServer { return; } if (stream === 'true') { - const result = await newTrace( - { name: 'dev-run-action-wrapper' }, - async (_, span) => { - setCustomMetadataAttribute('genkit-dev-internal', 'true'); - traceId = span.spanContext().traceId; - return await runWithStreamingCallback( - (chunk) => { - response.write(JSON.stringify(chunk) + '\n'); - }, - async () => await action(input) - ); - } + const result = await runWithStreamingCallback( + (chunk) => { + response.write(JSON.stringify(chunk) + '\n'); + }, + async () => await action.run(input) ); await flushTracing(); response.write( JSON.stringify({ - result, - telemetry: traceId - ? { - traceId, - } - : undefined, + result: result.result, + telemetry: { + traceId: result.telemetry.traceId, + }, } as RunActionResponse) ); response.end(); } else { - const result = await newTrace( - { name: 'dev-run-action-wrapper' }, - async (_, span) => { - setCustomMetadataAttribute('genkit-dev-internal', 'true'); - traceId = span.spanContext().traceId; - return await action(input); - } - ); + const result = await action.run(input); await flushTracing(); response.send({ - result, - telemetry: traceId - ? { - traceId, - } - : undefined, + result: result.result, + telemetry: { + traceId: result.telemetry.traceId, + }, } as RunActionResponse); } } catch (err) { diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index c76549bd4..439e3562b 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -41,4 +41,36 @@ describe('action', () => { 20 // "foomiddle1middle2".length + 1 + 2 ); }); + + it('returns telemetry info', async () => { + const act = action( + { + name: 'foo', + inputSchema: z.string(), + outputSchema: z.number(), + use: [ + async (input, next) => (await next(input + 'middle1')) + 1, + async (input, next) => (await next(input + 'middle2')) + 2, + ], + }, + async (input) => { + return input.length; + } + ); + + const result = await act.run('foo'); + assert.strictEqual( + result.result, + 20 // "foomiddle1middle2".length + 1 + 2 + ); + assert.strictEqual(result.telemetry !== null, true); + assert.strictEqual( + result.telemetry.traceId !== null && result.telemetry.traceId.length > 0, + true + ); + assert.strictEqual( + result.telemetry.spanId !== null && result.telemetry.spanId.length > 0, + true + ); + }); }); diff --git a/js/genkit/src/index.ts b/js/genkit/src/index.ts index 8002e7bd7..bc7dd3e0d 100644 --- a/js/genkit/src/index.ts +++ b/js/genkit/src/index.ts @@ -143,7 +143,6 @@ export { type Middleware, type ReflectionServerOptions, type RunActionResponse, - type SideChannelData, type Status, type StreamableFlow, type StreamingCallback,