Skip to content

Commit

Permalink
Add "mark last message as cache breakpoint" to assemble prompt node
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Sep 4, 2024
1 parent b71c26f commit 3f8d710
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
34 changes: 32 additions & 2 deletions packages/core/src/model/nodes/AssemblePromptNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import {
type PortId,
} from '../NodeBase.js';
import { nanoid } from 'nanoid/non-secure';
import { NodeImpl, type NodeUIData } from '../NodeImpl.js';
import { type ChatMessage, arrayizeDataValue, unwrapDataValue } from '../DataValue.js';
import { NodeImpl, type NodeBody, type NodeUIData } from '../NodeImpl.js';
import { type ChatMessage, arrayizeDataValue, unwrapDataValue, type ChatMessageDataValue } from '../DataValue.js';
import { type Inputs, type Outputs } from '../GraphProcessor.js';
import { coerceType } from '../../utils/coerceType.js';
import { orderBy } from 'lodash-es';
Expand All @@ -17,11 +17,15 @@ import { nodeDefinition } from '../NodeDefinition.js';
import type { EditorDefinition } from '../EditorDefinition.js';
import type { RivetUIContext } from '../RivetUIContext.js';
import type { InternalProcessContext } from '../ProcessContext.js';
import { getInputOrData } from '../../utils/inputs.js';

export type AssemblePromptNode = ChartNode<'assemblePrompt', AssemblePromptNodeData>;

export type AssemblePromptNodeData = {
computeTokenCount?: boolean;

isLastMessageCacheBreakpoint?: boolean;
useIsLastMessageCacheBreakpointInput?: boolean;
};

export class AssemblePromptNodeImpl extends NodeImpl<AssemblePromptNode> {
Expand All @@ -45,6 +49,15 @@ export class AssemblePromptNodeImpl extends NodeImpl<AssemblePromptNode> {
const inputs: NodeInputDefinition[] = [];
const messageCount = this.#getMessagePortCount(connections);

if (this.data.useIsLastMessageCacheBreakpointInput) {
inputs.push({
dataType: 'boolean',
id: 'isLastMessageCacheBreakpoint' as PortId,
title: 'Is Last Message Cache Breakpoint',
description: 'Whether the last message in the prompt should be a cache breakpoint.',
});
}

for (let i = 1; i <= messageCount; i++) {
inputs.push({
dataType: ['chat-message', 'chat-message[]'] as const,
Expand Down Expand Up @@ -118,12 +131,25 @@ export class AssemblePromptNodeImpl extends NodeImpl<AssemblePromptNode> {
label: 'Compute Token Count',
dataKey: 'computeTokenCount',
},
{
type: 'toggle',
label: 'Is Last Message Cache Breakpoint',
dataKey: 'isLastMessageCacheBreakpoint',
helperMessage:
'For Anthropic, marks the last message as a cache breakpoint - this message and every message before it will be cached using Prompt Caching.',
},
];
}

getBody(_context: RivetUIContext): NodeBody | Promise<NodeBody> {
return this.data.isLastMessageCacheBreakpoint ? 'Last message is cache breakpoint' : '';
}

async process(inputs: Inputs, context: InternalProcessContext): Promise<Outputs> {
const output: Outputs = {};

const isLastMessageCacheBreakpoint = getInputOrData(this.data, inputs, 'isLastMessageCacheBreakpoint', 'boolean');

const outMessages: ChatMessage[] = [];

const inputMessages = orderBy(
Expand Down Expand Up @@ -151,6 +177,10 @@ export class AssemblePromptNodeImpl extends NodeImpl<AssemblePromptNode> {
}
}

if (isLastMessageCacheBreakpoint && outMessages.length > 1) {
outMessages.at(-1)!.isCacheBreakpoint = true;
}

output['prompt' as PortId] = {
type: 'chat-message[]',
value: outMessages,
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/model/nodes/PromptNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ export class PromptNodeImpl extends NodeImpl<PromptNode> {
{
type: 'markdown',
text: dedent`
_${typeDisplay[this.data.type]}${this.data.name ? ` (${this.data.name})` : ''}_
_${typeDisplay[this.data.type]}${this.data.name ? ` (${this.data.name})` : ''}_ ${this.data.isCacheBreakpoint ? ' (Cache Breakpoint)' : ''}
`,
},
{
Expand Down

0 comments on commit 3f8d710

Please sign in to comment.