From 3f8d7106749ef44b14d140dac3777ad2521834ec Mon Sep 17 00:00:00 2001 From: Andy Brenneke Date: Wed, 4 Sep 2024 11:11:46 -0700 Subject: [PATCH] Add "mark last message as cache breakpoint" to assemble prompt node --- .../src/model/nodes/AssemblePromptNode.ts | 34 +++++++++++++++++-- packages/core/src/model/nodes/PromptNode.ts | 2 +- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/packages/core/src/model/nodes/AssemblePromptNode.ts b/packages/core/src/model/nodes/AssemblePromptNode.ts index be13fac9f..5c8bb32cc 100644 --- a/packages/core/src/model/nodes/AssemblePromptNode.ts +++ b/packages/core/src/model/nodes/AssemblePromptNode.ts @@ -7,8 +7,8 @@ import { type PortId, } from '../NodeBase.js'; import { nanoid } from 'nanoid/non-secure'; -import { NodeImpl, type NodeUIData } from '../NodeImpl.js'; -import { type ChatMessage, arrayizeDataValue, unwrapDataValue } from '../DataValue.js'; +import { NodeImpl, type NodeBody, type NodeUIData } from '../NodeImpl.js'; +import { type ChatMessage, arrayizeDataValue, unwrapDataValue, type ChatMessageDataValue } from '../DataValue.js'; import { type Inputs, type Outputs } from '../GraphProcessor.js'; import { coerceType } from '../../utils/coerceType.js'; import { orderBy } from 'lodash-es'; @@ -17,11 +17,15 @@ import { nodeDefinition } from '../NodeDefinition.js'; import type { EditorDefinition } from '../EditorDefinition.js'; import type { RivetUIContext } from '../RivetUIContext.js'; import type { InternalProcessContext } from '../ProcessContext.js'; +import { getInputOrData } from '../../utils/inputs.js'; export type AssemblePromptNode = ChartNode<'assemblePrompt', AssemblePromptNodeData>; export type AssemblePromptNodeData = { computeTokenCount?: boolean; + + isLastMessageCacheBreakpoint?: boolean; + useIsLastMessageCacheBreakpointInput?: boolean; }; export class AssemblePromptNodeImpl extends NodeImpl { @@ -45,6 +49,15 @@ export class AssemblePromptNodeImpl extends NodeImpl { const inputs: NodeInputDefinition[] = []; const messageCount = this.#getMessagePortCount(connections); + if (this.data.useIsLastMessageCacheBreakpointInput) { + inputs.push({ + dataType: 'boolean', + id: 'isLastMessageCacheBreakpoint' as PortId, + title: 'Is Last Message Cache Breakpoint', + description: 'Whether the last message in the prompt should be a cache breakpoint.', + }); + } + for (let i = 1; i <= messageCount; i++) { inputs.push({ dataType: ['chat-message', 'chat-message[]'] as const, @@ -118,12 +131,25 @@ export class AssemblePromptNodeImpl extends NodeImpl { label: 'Compute Token Count', dataKey: 'computeTokenCount', }, + { + type: 'toggle', + label: 'Is Last Message Cache Breakpoint', + dataKey: 'isLastMessageCacheBreakpoint', + helperMessage: + 'For Anthropic, marks the last message as a cache breakpoint - this message and every message before it will be cached using Prompt Caching.', + }, ]; } + getBody(_context: RivetUIContext): NodeBody | Promise { + return this.data.isLastMessageCacheBreakpoint ? 'Last message is cache breakpoint' : ''; + } + async process(inputs: Inputs, context: InternalProcessContext): Promise { const output: Outputs = {}; + const isLastMessageCacheBreakpoint = getInputOrData(this.data, inputs, 'isLastMessageCacheBreakpoint', 'boolean'); + const outMessages: ChatMessage[] = []; const inputMessages = orderBy( @@ -151,6 +177,10 @@ export class AssemblePromptNodeImpl extends NodeImpl { } } + if (isLastMessageCacheBreakpoint && outMessages.length > 1) { + outMessages.at(-1)!.isCacheBreakpoint = true; + } + output['prompt' as PortId] = { type: 'chat-message[]', value: outMessages, diff --git a/packages/core/src/model/nodes/PromptNode.ts b/packages/core/src/model/nodes/PromptNode.ts index d04d01387..e40a2b74c 100644 --- a/packages/core/src/model/nodes/PromptNode.ts +++ b/packages/core/src/model/nodes/PromptNode.ts @@ -194,7 +194,7 @@ export class PromptNodeImpl extends NodeImpl { { type: 'markdown', text: dedent` - _${typeDisplay[this.data.type]}${this.data.name ? ` (${this.data.name})` : ''}_ + _${typeDisplay[this.data.type]}${this.data.name ? ` (${this.data.name})` : ''}_ ${this.data.isCacheBreakpoint ? ' (Cache Breakpoint)' : ''} `, }, {