Skip to content

Commit

Permalink
feat: added a way to run action with telemetry, no more 'dev-run-acti…
Browse files Browse the repository at this point in the history
…on-wrapper' (#1357)
  • Loading branch information
pavelgj authored Nov 22, 2024
1 parent 45eff97 commit 97ca6a7
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 56 deletions.
80 changes: 63 additions & 17 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ import {
export { Status, StatusCodes, StatusSchema } from './statusTypes.js';
export { JSONSchema7 };

/**
* Action metadata.
*/
export interface ActionMetadata<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
Expand All @@ -43,16 +46,32 @@ export interface ActionMetadata<
metadata?: M;
}

/**
* Results of an action run. Includes telemetry.
*/
export interface ActionResult<O> {
result: O;
telemetry: {
traceId: string;
spanId: string;
};
}

/**
* Self-describing, validating, observable, locally and remotely callable function.
*/
export type Action<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> = ((input: z.infer<I>) => Promise<z.infer<O>>) & {
__action: ActionMetadata<I, O, M>;
run(input: z.infer<I>): Promise<ActionResult<z.infer<O>>>;
};

export type SideChannelData = Record<string, any>;

/**
* Action factory params.
*/
type ActionParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
Expand All @@ -73,10 +92,16 @@ type ActionParams<
use?: Middleware<z.infer<I>, z.infer<O>>[];
};

/**
* Middleware function for actions.
*/
export interface Middleware<I = any, O = any> {
(req: I, next: (req?: I) => Promise<O>): Promise<O>;
}

/**
* Creates an action with provided middleware.
*/
export function actionWithMiddleware<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
Expand All @@ -86,10 +111,17 @@ export function actionWithMiddleware<
middleware: Middleware<z.infer<I>, z.infer<O>>[]
): Action<I, O, M> {
const wrapped = (async (req: z.infer<I>) => {
return (await wrapped.run(req)).result;
}) as Action<I, O, M>;
wrapped.__action = action.__action;
wrapped.run = async (req: z.infer<I>): Promise<ActionResult<z.infer<O>>> => {
let telemetry;
const dispatch = async (index: number, req: z.infer<I>) => {
if (index === middleware.length) {
// end of the chain, call the original model action
return await action(req);
const result = await action.run(req);
telemetry = result.telemetry;
return result.result;
}

const currentMiddleware = middleware[index];
Expand All @@ -98,9 +130,8 @@ export function actionWithMiddleware<
);
};

return await dispatch(0, req);
}) as Action<I, O, M>;
wrapped.__action = action.__action;
return { result: await dispatch(0, req), telemetry };
};
return wrapped;
}

Expand All @@ -120,18 +151,36 @@ export function action<
? config.name
: `${config.name.pluginId}/${config.name.actionId}`;
const actionFn = async (input: I) => {
return (await actionFn.run(input)).result;
};
actionFn.__action = {
name: actionName,
description: config.description,
inputSchema: config.inputSchema,
inputJsonSchema: config.inputJsonSchema,
outputSchema: config.outputSchema,
outputJsonSchema: config.outputJsonSchema,
metadata: config.metadata,
} as ActionMetadata<I, O, M>;
actionFn.run = async (
input: z.infer<I>
): Promise<ActionResult<z.infer<O>>> => {
input = parseSchema(input, {
schema: config.inputSchema,
jsonSchema: config.inputJsonSchema,
});
let traceId;
let spanId;
let output = await newTrace(
{
name: actionName,
labels: {
[SPAN_TYPE_ATTR]: 'action',
},
},
async (metadata) => {
async (metadata, span) => {
traceId = span.spanContext().traceId;
spanId = span.spanContext().spanId;
metadata.name = actionName;
metadata.input = input;

Expand All @@ -145,17 +194,14 @@ export function action<
schema: config.outputSchema,
jsonSchema: config.outputJsonSchema,
});
return output;
return {
result: output,
telemetry: {
traceId,
spanId,
},
};
};
actionFn.__action = {
name: actionName,
description: config.description,
inputSchema: config.inputSchema,
inputJsonSchema: config.inputJsonSchema,
outputSchema: config.outputSchema,
outputJsonSchema: config.outputJsonSchema,
metadata: config.metadata,
} as ActionMetadata<I, O, M>;

if (config.use) {
return actionWithMiddleware(actionFn, config.use);
Expand Down
53 changes: 15 additions & 38 deletions js/core/src/reflection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ import { GENKIT_VERSION } from './index.js';
import { logger } from './logging.js';
import { Registry } from './registry.js';
import { toJsonSchema } from './schema.js';
import {
flushTracing,
newTrace,
setCustomMetadataAttribute,
setTelemetryServerUrl,
} from './tracing.js';
import { flushTracing, setTelemetryServerUrl } from './tracing.js';

// TODO: Move this to common location for schemas.
export const RunActionResponseSchema = z.object({
Expand Down Expand Up @@ -169,48 +164,30 @@ export class ReflectionServer {
return;
}
if (stream === 'true') {
const result = await newTrace(
{ name: 'dev-run-action-wrapper' },
async (_, span) => {
setCustomMetadataAttribute('genkit-dev-internal', 'true');
traceId = span.spanContext().traceId;
return await runWithStreamingCallback(
(chunk) => {
response.write(JSON.stringify(chunk) + '\n');
},
async () => await action(input)
);
}
const result = await runWithStreamingCallback(
(chunk) => {
response.write(JSON.stringify(chunk) + '\n');
},
async () => await action.run(input)
);
await flushTracing();
response.write(
JSON.stringify({
result,
telemetry: traceId
? {
traceId,
}
: undefined,
result: result.result,
telemetry: {
traceId: result.telemetry.traceId,
},
} as RunActionResponse)
);
response.end();
} else {
const result = await newTrace(
{ name: 'dev-run-action-wrapper' },
async (_, span) => {
setCustomMetadataAttribute('genkit-dev-internal', 'true');
traceId = span.spanContext().traceId;
return await action(input);
}
);
const result = await action.run(input);
await flushTracing();
response.send({
result,
telemetry: traceId
? {
traceId,
}
: undefined,
result: result.result,
telemetry: {
traceId: result.telemetry.traceId,
},
} as RunActionResponse);
}
} catch (err) {
Expand Down
32 changes: 32 additions & 0 deletions js/core/tests/action_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,36 @@ describe('action', () => {
20 // "foomiddle1middle2".length + 1 + 2
);
});

it('returns telemetry info', async () => {
const act = action(
{
name: 'foo',
inputSchema: z.string(),
outputSchema: z.number(),
use: [
async (input, next) => (await next(input + 'middle1')) + 1,
async (input, next) => (await next(input + 'middle2')) + 2,
],
},
async (input) => {
return input.length;
}
);

const result = await act.run('foo');
assert.strictEqual(
result.result,
20 // "foomiddle1middle2".length + 1 + 2
);
assert.strictEqual(result.telemetry !== null, true);
assert.strictEqual(
result.telemetry.traceId !== null && result.telemetry.traceId.length > 0,
true
);
assert.strictEqual(
result.telemetry.spanId !== null && result.telemetry.spanId.length > 0,
true
);
});
});
1 change: 0 additions & 1 deletion js/genkit/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ export {
type Middleware,
type ReflectionServerOptions,
type RunActionResponse,
type SideChannelData,
type Status,
type StreamableFlow,
type StreamingCallback,
Expand Down

0 comments on commit 97ca6a7

Please sign in to comment.