diff --git a/.pnp.cjs b/.pnp.cjs index 1179779d7..d431cf37d 100755 --- a/.pnp.cjs +++ b/.pnp.cjs @@ -6635,8 +6635,10 @@ const RAW_RUNTIME_STATE = ["immer", "npm:10.0.3"],\ ["jsonpath-plus", "npm:7.2.0"],\ ["lodash-es", "npm:4.17.21"],\ + ["lru-cache", "npm:11.0.0"],\ ["majesticons", "npm:2.1.2"],\ ["marked", "npm:9.1.2"],\ + ["mime", "npm:4.0.4"],\ ["minimatch", "npm:9.0.3"],\ ["monaco-editor", "npm:0.44.0"],\ ["nanoid", "npm:3.3.7"],\ @@ -19184,6 +19186,13 @@ const RAW_RUNTIME_STATE = ],\ "linkType": "HARD"\ }],\ + ["npm:11.0.0", {\ + "packageLocation": "./.yarn/cache/lru-cache-npm-11.0.0-92d560d9d6-41f36fbff8.zip/node_modules/lru-cache/",\ + "packageDependencies": [\ + ["lru-cache", "npm:11.0.0"]\ + ],\ + "linkType": "HARD"\ + }],\ ["npm:5.1.1", {\ "packageLocation": "./.yarn/cache/lru-cache-npm-5.1.1-f475882a51-951d2673dc.zip/node_modules/lru-cache/",\ "packageDependencies": [\ @@ -19463,6 +19472,13 @@ const RAW_RUNTIME_STATE = ["mime", "npm:1.6.0"]\ ],\ "linkType": "HARD"\ + }],\ + ["npm:4.0.4", {\ + "packageLocation": "./.yarn/cache/mime-npm-4.0.4-03acf1c40a-28e41053ae.zip/node_modules/mime/",\ + "packageDependencies": [\ + ["mime", "npm:4.0.4"]\ + ],\ + "linkType": "HARD"\ }]\ ]],\ ["mime-db", [\ diff --git a/.yarn/cache/lru-cache-npm-11.0.0-92d560d9d6-41f36fbff8.zip b/.yarn/cache/lru-cache-npm-11.0.0-92d560d9d6-41f36fbff8.zip new file mode 100644 index 000000000..406193197 Binary files /dev/null and b/.yarn/cache/lru-cache-npm-11.0.0-92d560d9d6-41f36fbff8.zip differ diff --git a/.yarn/cache/mime-npm-4.0.4-03acf1c40a-28e41053ae.zip b/.yarn/cache/mime-npm-4.0.4-03acf1c40a-28e41053ae.zip new file mode 100644 index 000000000..1f2f40b8c Binary files /dev/null and b/.yarn/cache/mime-npm-4.0.4-03acf1c40a-28e41053ae.zip differ diff --git a/packages/app/package.json b/packages/app/package.json index 5e617d115..2a92af6e2 100644 --- a/packages/app/package.json +++ b/packages/app/package.json @@ -88,8 +88,10 @@ "immer": "^10.0.3", "jsonpath-plus": "^7.2.0", "lodash-es": "^4.17.21", + "lru-cache": "^11.0.0", "majesticons": "^2.1.2", "marked": "^9.1.2", + "mime": "^4.0.4", "minimatch": "^9.0.3", "monaco-editor": "^0.44.0", "nanoid": "^3.3.6", diff --git a/packages/app/src/components/ChatViewer.tsx b/packages/app/src/components/ChatViewer.tsx index ba6671b45..f7fa0819a 100644 --- a/packages/app/src/components/ChatViewer.tsx +++ b/packages/app/src/components/ChatViewer.tsx @@ -15,7 +15,13 @@ import { arrayizeDataValue, coerceTypeOptional, } from '@ironclad/rivet-core'; -import { type NodeRunData, lastRunDataByNodeState, graphRunningState } from '../state/dataFlow'; +import { + type NodeRunData, + lastRunDataByNodeState, + graphRunningState, + type NodeRunDataWithRefs, + type DataValueWithRefs, +} from '../state/dataFlow'; import { projectState } from '../state/savedGraphs'; import { ErrorBoundary } from 'react-error-boundary'; import TextField from '@atlaskit/textfield'; @@ -324,7 +330,7 @@ const ChatBubble: FC<{ nodeId: NodeId; nodeTitle: string; processId: ProcessId; - data: NodeRunData; + data: NodeRunDataWithRefs; splitIndex: number; style?: CSSProperties; onGoToNode?: (nodeId: NodeId) => void; @@ -333,16 +339,16 @@ const ChatBubble: FC<{ const responseRef = useRef(null); const [expanded, toggleExpanded] = useToggle(); - let prompt: DataValue; + let prompt: DataValueWithRefs; if (splitIndex === -1) { prompt = data.inputData?.['prompt' as PortId]!; } else { const values = arrayizeDataValue(data.inputData?.['prompt' as PortId] as ScalarOrArrayDataValue); if (values.length === 1) { - prompt = values[0]!; + prompt = values[0]! as DataValueWithRefs; } else { - prompt = values[splitIndex]!; + prompt = values[splitIndex]! as DataValueWithRefs; } } @@ -351,8 +357,8 @@ const ChatBubble: FC<{ ? data.outputData?.['response' as PortId] : data.splitOutputData![splitIndex]!['response' as PortId]; - const promptText = coerceTypeOptional(prompt, 'string'); - const responseText = coerceTypeOptional(chatOutput, 'string'); + const promptText = coerceTypeOptional(prompt as DataValue, 'string'); + const responseText = coerceTypeOptional(chatOutput as DataValue, 'string'); useLayoutEffect(() => { if (promptRef.current) { diff --git a/packages/app/src/components/NodeOutput.tsx b/packages/app/src/components/NodeOutput.tsx index fefd63997..80f5d0af3 100644 --- a/packages/app/src/components/NodeOutput.tsx +++ b/packages/app/src/components/NodeOutput.tsx @@ -1,10 +1,16 @@ import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'; -import { type NodeRunData, type ProcessDataForNode, lastRunData, selectedProcessPage } from '../state/dataFlow.js'; +import { + type NodeRunData, + type ProcessDataForNode, + lastRunData, + selectedProcessPage, + type NodeRunDataWithRefs, +} from '../state/dataFlow.js'; import { type FC, type ReactNode, memo, useMemo, useState, type MouseEvent } from 'react'; import { useUnknownNodeComponentDescriptorFor } from '../hooks/useNodeTypes.js'; import { useStableCallback } from '../hooks/useStableCallback.js'; import { copyToClipboard } from '../utils/copyToClipboard.js'; -import { type ChartNode, type PortId, type ProcessId, getWarnings } from '@ironclad/rivet-core'; +import { type ChartNode, type PortId, type ProcessId, getWarnings, type Outputs } from '@ironclad/rivet-core'; import { css } from '@emotion/react'; import CopyIcon from 'majesticons/line/clipboard-line.svg?react'; import ExpandIcon from 'majesticons/line/maximize-line.svg?react'; @@ -350,7 +356,6 @@ const NodeOutputBase: FC<{ node: ChartNode; children?: ReactNode; onOpenFullscre onOpenFullscreenModal, }) => { const output = useRecoilValue(lastRunData(node.id)); - if (!output?.length) { return null; } @@ -377,7 +382,7 @@ const NodeOutputBase: FC<{ node: ChartNode; children?: ReactNode; onOpenFullscre const NodeOutputSingleProcess: FC<{ node: ChartNode; - data: NodeRunData; + data: NodeRunDataWithRefs; processId: ProcessId; onOpenFullscreenModal?: () => void; }> = ({ node, data, processId, onOpenFullscreenModal }) => { @@ -482,9 +487,9 @@ const NodeOutputSingleProcess: FC<{ {body} - {getWarnings(data.outputData) && ( + {getWarnings(data.outputData as Outputs) && (
- {getWarnings(data.outputData)!.map((warning) => ( + {getWarnings(data.outputData as Outputs)!.map((warning) => (
{warning}
diff --git a/packages/app/src/components/PromptDesigner.tsx b/packages/app/src/components/PromptDesigner.tsx index f141ac167..abcd15b40 100644 --- a/packages/app/src/components/PromptDesigner.tsx +++ b/packages/app/src/components/PromptDesigner.tsx @@ -12,7 +12,7 @@ import { promptDesignerTestGroupResultsByNodeIdState, } from '../state/promptDesigner'; import { nodesByIdState, nodesState } from '../state/graph.js'; -import { lastRunDataByNodeState } from '../state/dataFlow.js'; +import { type InputsOrOutputsWithRefs, lastRunDataByNodeState } from '../state/dataFlow.js'; import { type ChatMessage, type ChatNode, @@ -31,6 +31,9 @@ import { getError, isArrayDataValue, openai, + type DataValue, + type ScalarDataValue, + type Inputs, } from '@ironclad/rivet-core'; import TextField from '@atlaskit/textfield'; import { Field } from '@atlaskit/form'; @@ -354,9 +357,11 @@ export const PromptDesigner: FC = ({ onClose }) => { let inputData = nodeDataForAttachedNodeProcess.inputData; // If node is a split run, just grab the first input data. if (attachedNode.isSplitRun) { - inputData = mapValues(inputData, (val) => (isArrayDataValue(val) ? arrayizeDataValue(val)[0] : val)); + inputData = mapValues(inputData, (val) => + isArrayDataValue(val as DataValue) ? arrayizeDataValue(val as ScalarDataValue)[0] : val, + ) as InputsOrOutputsWithRefs; } - const { messages } = getChatNodeMessages(inputData); + const { messages } = getChatNodeMessages(inputData as Inputs); setMessages({ messages, }); diff --git a/packages/app/src/components/RenderDataValue.tsx b/packages/app/src/components/RenderDataValue.tsx index 4aee6fec4..d16a5d4cb 100644 --- a/packages/app/src/components/RenderDataValue.tsx +++ b/packages/app/src/components/RenderDataValue.tsx @@ -1,18 +1,19 @@ -import { type FC } from 'react'; +import { useMemo, type FC } from 'react'; import { type DataValue, - type Outputs, type ScalarDataType, - type ScalarDataValue, arrayizeDataValue, getScalarTypeOf, inferType, - isArrayDataValue, isFunctionDataValue, - coerceTypeOptional, type NodeOutputDefinition, type ChatMessageMessagePart, type DataType, + type ImageDataValue, + type BinaryDataValue, + type AudioDataValue, + isArrayDataType, + type ScalarOrArrayDataValue, } from '@ironclad/rivet-core'; import { css } from '@emotion/react'; import { keys } from '../../../core/src/utils/typeSafety'; @@ -20,6 +21,8 @@ import { useMarkdown } from '../hooks/useMarkdown'; import ColorizedPreformattedText from './ColorizedPreformattedText'; import { P, match } from 'ts-pattern'; import clsx from 'clsx'; +import { type InputsOrOutputsWithRefs, type DataValueWithRefs, type ScalarDataValueWithRefs } from '../state/dataFlow'; +import { getGlobalDataRef } from '../utils/globals'; const styles = css` .chat-message.user header em { @@ -81,7 +84,7 @@ const multiOutput = css` `; type ScalarRendererProps = { - value: Extract; + value: Extract; depth?: number; renderMarkdown?: boolean; truncateLength?: number; @@ -148,7 +151,7 @@ const scalarRenderers: {
{message.function_calls.map((fc, i) => (
- +
))}
@@ -158,7 +161,7 @@ const scalarRenderers: {

Function Call:

- +
) @@ -191,7 +194,9 @@ const scalarRenderers: { if (inferred.type === 'any') { return <>{JSON.stringify(inferred.value)}; } - return ; + return ( + + ); }, object: ({ value }) => (
@@ -205,12 +210,19 @@ const scalarRenderers: { ), vector: ({ value }) => <>Vector (length {value.value.length}), image: ({ value }) => { + const resolved = getGlobalDataRef(value.value.ref); + if (!resolved) { + return
Could not find data.
; + } + const { value: { data, mediaType }, - } = value; + } = resolved as ImageDataValue; - const blob = new Blob([data], { type: mediaType }); - const imageUrl = URL.createObjectURL(blob); + const imageUrl = useMemo(() => { + const blob = new Blob([data], { type: mediaType }); + return URL.createObjectURL(blob); + }, [data, mediaType]); return (
@@ -219,19 +231,39 @@ const scalarRenderers: { ); }, binary: ({ value }) => { + const resolved = getGlobalDataRef(value.value.ref); + if (!resolved) { + return
Could not find data.
; + } + // FIXME: Coercing `value.value` into a `Uint8Array` here because `Uint8Array` gets parsed as an // object of shape `{ [index: number]: number }` when stringified via `JSON.stringify()`. // Consider coercing it back to `Uint8Array` at the entrypoints of the boundaries between // browser and node.js instead. - const coercedValue = new Uint8Array(Object.values(value.value)); + const coercedValue = useMemo(() => { + const resolved = getGlobalDataRef(value.value.ref); + if (resolved!.value instanceof Uint8Array) { + return resolved!.value; + } + return new Uint8Array(Object.values((resolved as BinaryDataValue).value)); + }, [value.value.ref]); + return <>Binary (length {coercedValue.length.toLocaleString()}); }, audio: ({ value }) => { + const resolved = getGlobalDataRef(value.value.ref); + if (!resolved) { + return
Could not find data.
; + } + const { - value: { data }, - } = value; + value: { data, mediaType }, + } = resolved as AudioDataValue; - const dataUri = `data:audio/mp4;base64,${data}`; + const dataUri = useMemo(() => { + const blob = new Blob([data], { type: mediaType }); + return URL.createObjectURL(blob); + }, [data, mediaType]); return (
@@ -257,8 +289,14 @@ const RenderChatMessagePart: FC<{ part: ChatMessageMessagePart; renderMarkdown?: return ; }) .with({ type: 'image' }, (part) => { - const Renderer = scalarRenderers.image; - return ; + const blob = new Blob([part.data], { type: part.mediaType }); + const imageUrl = URL.createObjectURL(blob); + + return ( +
+ +
+ ); }) .with({ type: 'url' }, (part) => { return {part.url}; @@ -267,7 +305,7 @@ const RenderChatMessagePart: FC<{ part: ChatMessageMessagePart; renderMarkdown?: }; export const RenderDataValue: FC<{ - value: DataValue | undefined; + value: DataValueWithRefs | undefined; depth?: number; renderMarkdown?: boolean; truncateLength?: number; @@ -280,8 +318,8 @@ export const RenderDataValue: FC<{ } const keys = Object.keys(value?.value ?? {}); - if (isArrayDataValue(value)) { - const items = arrayizeDataValue(value); + if (isArrayDataType(value.type)) { + const items = arrayizeDataValue(value as ScalarOrArrayDataValue); return (
@@ -325,7 +363,7 @@ export const RenderDataValue: FC<{ return (
= ({ definitions, outputs, renderMarkdown }) => { const outputPorts = keys(outputs); diff --git a/packages/app/src/components/editors/FileBrowserEditor.tsx b/packages/app/src/components/editors/FileBrowserEditor.tsx index 3d2b89629..43903f20d 100644 --- a/packages/app/src/components/editors/FileBrowserEditor.tsx +++ b/packages/app/src/components/editors/FileBrowserEditor.tsx @@ -16,6 +16,7 @@ import { projectDataState } from '../../state/savedGraphs'; import { ioProvider } from '../../utils/globals'; import { type SharedEditorProps } from './SharedEditorProps'; import { getHelperMessage } from './editorUtils'; +import mime from 'mime'; export const DefaultFileBrowserEditor: FC< SharedEditorProps & { @@ -27,7 +28,7 @@ export const DefaultFileBrowserEditor: FC< const helperMessage = getHelperMessage(editor, node.data); const pickFile = async () => { - await ioProvider.readFileAsBinary(async (binaryData) => { + await ioProvider.readFileAsBinary(async (binaryData, fileName) => { const dataId = nanoid() as DataId; onChange( { @@ -37,6 +38,7 @@ export const DefaultFileBrowserEditor: FC< [editor.dataKey]: { refId: dataId, } satisfies DataRef, + [editor.mediaTypeDataKey]: mime.getType(fileName) ?? 'application/octet-stream', }, }, { diff --git a/packages/app/src/components/editors/ImageBrowserEditor.tsx b/packages/app/src/components/editors/ImageBrowserEditor.tsx index d0e206fc3..18ae4acbf 100644 --- a/packages/app/src/components/editors/ImageBrowserEditor.tsx +++ b/packages/app/src/components/editors/ImageBrowserEditor.tsx @@ -14,6 +14,7 @@ import { projectDataState } from '../../state/savedGraphs'; import { ioProvider } from '../../utils/globals'; import { type SharedEditorProps } from './SharedEditorProps'; import { getHelperMessage } from './editorUtils'; +import mime from 'mime'; export const DefaultImageBrowserEditor: FC< SharedEditorProps & { @@ -36,6 +37,7 @@ export const DefaultImageBrowserEditor: FC< [editor.dataKey]: { refId: dataId, } satisfies DataRef, + [editor.mediaTypeDataKey]: mime.getType(editor.dataKey) ?? 'image/png', }, }, { diff --git a/packages/app/src/components/nodes/AudioNode.tsx b/packages/app/src/components/nodes/AudioNode.tsx index f8b1e82f1..a8802ba0d 100644 --- a/packages/app/src/components/nodes/AudioNode.tsx +++ b/packages/app/src/components/nodes/AudioNode.tsx @@ -1,4 +1,4 @@ -import { type FC, useLayoutEffect, useRef } from 'react'; +import { type FC, useLayoutEffect, useRef, useMemo } from 'react'; import { type NodeComponentDescriptor } from '../../hooks/useNodeTypes'; import { type AudioNode } from '@ironclad/rivet-core'; import { css } from '@emotion/react'; @@ -16,11 +16,20 @@ type AudioNodeBodyProps = { }; export const AudioNodeBody: FC = ({ node }) => { + if (node.data.useDataInput) { + return
Audio data from input
; + } + const projectData = useRecoilValue(projectDataState); const dataRef = node.data.data; + const b64Data = dataRef ? projectData?.[dataRef.refId] : undefined; - const dataUri = b64Data ? `data:audio/mp4;base64,${b64Data}` : undefined; + + const dataUri = useMemo( + () => `data:${node.data.mediaType ?? 'audio/wav'};base64,${b64Data}`, + [b64Data, node.data.mediaType], + ); const audioSourceRef = useRef(null); diff --git a/packages/app/src/components/nodes/ChatNode.tsx b/packages/app/src/components/nodes/ChatNode.tsx index 35585de5a..30d4998ae 100644 --- a/packages/app/src/components/nodes/ChatNode.tsx +++ b/packages/app/src/components/nodes/ChatNode.tsx @@ -8,11 +8,13 @@ import { coerceTypeOptional, inferType, isArrayDataValue, + type DataValue, } from '@ironclad/rivet-core'; import { type NodeComponentDescriptor } from '../../hooks/useNodeTypes.js'; import styled from '@emotion/styled'; import clsx from 'clsx'; import { useMarkdown } from '../../hooks/useMarkdown.js'; +import { type InputsOrOutputsWithRefs, type DataValueWithRefs } from '../../state/dataFlow'; const bodyStyles = css` display: flex; @@ -28,19 +30,23 @@ const bodyStyles = css` `; export const ChatNodeOutput: FC<{ - outputs: Outputs; + outputs: InputsOrOutputsWithRefs; fullscreen?: boolean; renderMarkdown?: boolean; }> = ({ outputs, fullscreen, renderMarkdown }) => { - if (isArrayDataValue(outputs['response' as PortId]) || isArrayDataValue(outputs['requestTokens' as PortId])) { - const outputTextAll = coerceTypeOptional(outputs['response' as PortId], 'string[]') ?? []; - - const requestTokensAll = coerceTypeOptional(outputs['requestTokens' as PortId], 'number[]') ?? []; - const responseTokensAll = coerceTypeOptional(outputs['responseTokens' as PortId], 'number[]') ?? []; - const costAll = coerceTypeOptional(outputs['cost' as PortId], 'number[]') ?? []; - const durationAll = coerceTypeOptional(outputs['duration' as PortId], 'number[]') ?? []; - - const functionCallOutput = outputs['function-call' as PortId] ?? outputs['function-calls' as PortId]; + if ( + isArrayDataValue(outputs['response' as PortId] as DataValue) || + isArrayDataValue(outputs['requestTokens' as PortId] as DataValue) + ) { + const outputTextAll = coerceTypeOptional(outputs['response' as PortId] as DataValue, 'string[]') ?? []; + + const requestTokensAll = coerceTypeOptional(outputs['requestTokens' as PortId] as DataValue, 'number[]') ?? []; + const responseTokensAll = coerceTypeOptional(outputs['responseTokens' as PortId] as DataValue, 'number[]') ?? []; + const costAll = coerceTypeOptional(outputs['cost' as PortId] as DataValue, 'number[]') ?? []; + const durationAll = coerceTypeOptional(outputs['duration' as PortId] as DataValue, 'number[]') ?? []; + + const functionCallOutput = + (outputs['function-call' as PortId] as DataValue) ?? (outputs['function-calls' as PortId] as DataValue); const functionCallAll = functionCallOutput?.type === 'object[]' ? functionCallOutput.value @@ -72,14 +78,15 @@ export const ChatNodeOutput: FC<{
); } else { - const outputText = coerceTypeOptional(outputs['response' as PortId], 'string'); + const outputText = coerceTypeOptional(outputs['response' as PortId] as DataValue, 'string'); - const requestTokens = coerceTypeOptional(outputs['requestTokens' as PortId], 'number'); - const responseTokens = coerceTypeOptional(outputs['responseTokens' as PortId], 'number'); - const cost = coerceTypeOptional(outputs['cost' as PortId], 'number'); - const duration = coerceTypeOptional(outputs['duration' as PortId], 'number'); + const requestTokens = coerceTypeOptional(outputs['requestTokens' as PortId] as DataValue, 'number'); + const responseTokens = coerceTypeOptional(outputs['responseTokens' as PortId] as DataValue, 'number'); + const cost = coerceTypeOptional(outputs['cost' as PortId] as DataValue, 'number'); + const duration = coerceTypeOptional(outputs['duration' as PortId] as DataValue, 'number'); - const functionCallOutput = outputs['function-call' as PortId] ?? outputs['function-calls' as PortId]; + const functionCallOutput = + (outputs['function-call' as PortId] as DataValue) ?? (outputs['function-calls' as PortId] as DataValue); return ( ) : (
- +
)}
@@ -180,7 +187,7 @@ export const ChatNodeOutputSingle: FC<{

{Array.isArray(functionCall) ? 'Function Calls' : 'Function Call'}:

- +
)} @@ -189,7 +196,7 @@ export const ChatNodeOutputSingle: FC<{ }; const ChatNodeFullscreenOutput: FC<{ - outputs: Outputs; + outputs: InputsOrOutputsWithRefs; renderMarkdown: boolean; }> = ({ outputs, renderMarkdown }) => { return ; diff --git a/packages/app/src/components/nodes/LoopControllerNode.tsx b/packages/app/src/components/nodes/LoopControllerNode.tsx index d92e15e4c..6cd2ca595 100644 --- a/packages/app/src/components/nodes/LoopControllerNode.tsx +++ b/packages/app/src/components/nodes/LoopControllerNode.tsx @@ -2,8 +2,11 @@ import { type FC } from 'react'; import { type Outputs, type PortId } from '@ironclad/rivet-core'; import { RenderDataValue } from '../RenderDataValue.js'; import { type NodeComponentDescriptor } from '../../hooks/useNodeTypes.js'; +import { type InputsOrOutputsWithRefs, type DataValueWithRefs } from '../../state/dataFlow'; -export const LoopControllerNodeOutput: FC<{ outputs: Outputs; renderMarkdown?: boolean }> = ({ outputs }) => { +export const LoopControllerNodeOutput: FC<{ outputs: InputsOrOutputsWithRefs; renderMarkdown?: boolean }> = ({ + outputs, +}) => { const outputKeys = Object.keys(outputs).filter((key) => key.startsWith('output')); const breakLoop = outputs['break' as PortId] != null && outputs['break' as PortId]!.type !== 'control-flow-excluded'; @@ -19,7 +22,7 @@ export const LoopControllerNodeOutput: FC<{ outputs: Outputs; renderMarkdown?: b
Output {i + 1}
- +
))}
diff --git a/packages/app/src/components/nodes/ReadDirectoryNode.tsx b/packages/app/src/components/nodes/ReadDirectoryNode.tsx index 566358646..1d0471655 100644 --- a/packages/app/src/components/nodes/ReadDirectoryNode.tsx +++ b/packages/app/src/components/nodes/ReadDirectoryNode.tsx @@ -2,12 +2,20 @@ import { type FC } from 'react'; import { css } from '@emotion/react'; import Toggle from '@atlaskit/toggle'; import Button from '@atlaskit/button'; -import { type ChartNode, type Outputs, type PortId, type ReadDirectoryNode, expectType } from '@ironclad/rivet-core'; +import { + type ChartNode, + type Outputs, + type PortId, + type ReadDirectoryNode, + expectType, + type DataValue, +} from '@ironclad/rivet-core'; import { type NodeComponentDescriptor } from '../../hooks/useNodeTypes.js'; import { ioProvider } from '../../utils/globals.js'; +import { type InputsOrOutputsWithRefs } from '../../state/dataFlow'; -export const ReadDirectoryNodeOutput: FC<{ outputs: Outputs }> = ({ outputs }) => { - const outputPaths = expectType(outputs['paths' as PortId], 'string[]'); +export const ReadDirectoryNodeOutput: FC<{ outputs: InputsOrOutputsWithRefs }> = ({ outputs }) => { + const outputPaths = expectType(outputs['paths' as PortId] as DataValue, 'string[]'); return (
{outputPaths.length === 0 ? ( diff --git a/packages/app/src/components/nodes/SubGraphNode.tsx b/packages/app/src/components/nodes/SubGraphNode.tsx index c440efcc0..7ca47fb79 100644 --- a/packages/app/src/components/nodes/SubGraphNode.tsx +++ b/packages/app/src/components/nodes/SubGraphNode.tsx @@ -1,10 +1,11 @@ import { type FC } from 'react'; import { useRecoilValue } from 'recoil'; import { projectState } from '../../state/savedGraphs.js'; -import { type Outputs, type PortId, type SubGraphNode, coerceTypeOptional } from '@ironclad/rivet-core'; +import { type Outputs, type PortId, type SubGraphNode, coerceTypeOptional, type DataValue } from '@ironclad/rivet-core'; import { type NodeComponentDescriptor } from '../../hooks/useNodeTypes.js'; import { RenderDataOutputs } from '../RenderDataValue.js'; import { omit } from 'lodash-es'; +import { type InputsOrOutputsWithRefs } from '../../state/dataFlow'; export const SubGraphNodeBody: FC<{ node: SubGraphNode; @@ -21,11 +22,11 @@ export const SubGraphNodeBody: FC<{ }; export const SubGraphNodeOutputSimple: FC<{ - outputs: Outputs; + outputs: InputsOrOutputsWithRefs; renderMarkdown?: boolean; }> = ({ outputs, renderMarkdown }) => { - const cost = coerceTypeOptional(outputs['cost' as PortId], 'number'); - const duration = coerceTypeOptional(outputs['duration' as PortId], 'number'); + const cost = coerceTypeOptional(outputs['cost' as PortId] as DataValue, 'number'); + const duration = coerceTypeOptional(outputs['duration' as PortId] as DataValue, 'number'); return (
@@ -42,14 +43,17 @@ export const SubGraphNodeOutputSimple: FC<{ )}
- +
); }; export const FullscreenSubGraphNodeOutputSimple: FC<{ - outputs: Outputs; + outputs: InputsOrOutputsWithRefs; renderMarkdown: boolean; }> = ({ outputs, renderMarkdown }) => { return ; diff --git a/packages/app/src/components/nodes/UserInputNode.tsx b/packages/app/src/components/nodes/UserInputNode.tsx index e95675018..474defcb2 100644 --- a/packages/app/src/components/nodes/UserInputNode.tsx +++ b/packages/app/src/components/nodes/UserInputNode.tsx @@ -1,7 +1,8 @@ import { type FC } from 'react'; import { css } from '@emotion/react'; -import { type Outputs, type PortId, expectType, getScalarTypeOf } from '@ironclad/rivet-core'; +import { type Outputs, type PortId, expectType, getScalarTypeOf, type DataValue } from '@ironclad/rivet-core'; import { type NodeComponentDescriptor } from '../../hooks/useNodeTypes.js'; +import { type InputsOrOutputsWithRefs } from '../../state/dataFlow'; const questionsAndAnswersStyles = css` display: flex; @@ -13,14 +14,14 @@ const questionsAndAnswersStyles = css` } `; -export const UserInputNodeOutput: FC<{ outputs: Outputs }> = ({ outputs }) => { +export const UserInputNodeOutput: FC<{ outputs: InputsOrOutputsWithRefs }> = ({ outputs }) => { const questionsAndAnswers = outputs['questionsAndAnswers' as PortId]; if (!questionsAndAnswers || getScalarTypeOf(questionsAndAnswers.type) === 'control-flow-excluded') { return null; } - const qa = expectType(questionsAndAnswers, 'string[]'); + const qa = expectType(questionsAndAnswers as DataValue, 'string[]'); return (
diff --git a/packages/app/src/hooks/useCurrentExecution.ts b/packages/app/src/hooks/useCurrentExecution.ts index d6d3df2db..11fd5faf6 100644 --- a/packages/app/src/hooks/useCurrentExecution.ts +++ b/packages/app/src/hooks/useCurrentExecution.ts @@ -7,9 +7,11 @@ import { type ProcessEvents, type ProcessId, coerceTypeOptional, + getScalarTypeOf, + isArrayDataValue, } from '@ironclad/rivet-core'; import { produce } from 'immer'; -import { cloneDeep } from 'lodash-es'; +import { cloneDeep, mapValues } from 'lodash-es'; import { useRecoilValue, useSetRecoilState } from 'recoil'; import { type NodeRunData, @@ -20,6 +22,9 @@ import { runningGraphsState, selectedProcessPageNodesState, graphStartTimeState, + type InputsOrOutputsWithRefs, + type DataValueWithRefs, + type NodeRunDataWithRefs, } from '../state/dataFlow'; import { type ProcessQuestions, userInputModalQuestionsState } from '../state/userInput'; import { lastRecordingState } from '../state/execution'; @@ -28,6 +33,8 @@ import { useLatest } from 'ahooks'; import { entries, keys } from '../../../core/src/utils/typeSafety'; import { match } from 'ts-pattern'; import { previousDataPerNodeToKeepState } from '../state/settings'; +import { nanoid } from 'nanoid'; +import { setGlobalDataRef } from '../utils/globals'; function sanitizeDataValueForLength(value: DataValue | undefined) { return match(value) @@ -84,6 +91,44 @@ function sanitizeDataValueForLength(value: DataValue | undefined) { .otherwise((value): DataValue | undefined => value); } +function cloneNodeDataForHistory(data: Partial): Partial { + return { + ...data, + inputData: cloneNodeInputOrOutputDataForHistory(data.inputData), + outputData: cloneNodeInputOrOutputDataForHistory(data.outputData), + splitOutputData: data.splitOutputData + ? (mapValues(data.splitOutputData, (val) => cloneNodeInputOrOutputDataForHistory(val)) as { + [index: number]: InputsOrOutputsWithRefs; + }) + : undefined, + }; +} + +function cloneNodeInputOrOutputDataForHistory(data: Inputs | Outputs | undefined): InputsOrOutputsWithRefs | undefined { + if (data == null) { + return undefined; + } + + return mapValues(data as Record, (val) => { + if (!val) { + return cloneDeep(val); + } + + return convertToRef(val); + }) as InputsOrOutputsWithRefs; +} + +function convertToRef(value: DataValue): DataValueWithRefs { + const scalarType = getScalarTypeOf(value.type); + if (scalarType !== 'audio' && scalarType !== 'binary' && scalarType !== 'image') { + return cloneDeep(value) as DataValueWithRefs; + } + + const refId = nanoid(); + setGlobalDataRef(refId, value); + return { type: value.type, value: { ref: refId } } as DataValueWithRefs; +} + export function useCurrentExecution() { const setLastRunData = useSetRecoilState(lastRunDataByNodeState); const setSelectedPage = useSetRecoilState(selectedProcessPageNodesState); @@ -108,7 +153,7 @@ export function useCurrentExecution() { if (existingProcess) { existingProcess.data = { ...existingProcess.data, - ...cloneDeep(data), + ...cloneNodeDataForHistory(data), }; } else { if (previousDataPerNodeToKeep > -1) { @@ -131,7 +176,7 @@ export function useCurrentExecution() { draft[nodeId]!.push({ processId, - data: cloneDeep(data), + data: cloneNodeDataForHistory(data)!, }); } }), @@ -314,14 +359,14 @@ export function useCurrentExecution() { if (existingProcess) { existingProcess.data.splitOutputData = { ...existingProcess.data.splitOutputData, - [index]: cloneDeep(outputs), + [index]: cloneNodeInputOrOutputDataForHistory(outputs)!, }; } else { draft[node.id]!.push({ processId, data: { splitOutputData: { - [index]: cloneDeep(outputs), + [index]: cloneNodeInputOrOutputDataForHistory(outputs)!, }, }, }); diff --git a/packages/app/src/hooks/useGetAdHocInternalProcessContext.ts b/packages/app/src/hooks/useGetAdHocInternalProcessContext.ts index 211fcbee0..3ce7a0ad9 100644 --- a/packages/app/src/hooks/useGetAdHocInternalProcessContext.ts +++ b/packages/app/src/hooks/useGetAdHocInternalProcessContext.ts @@ -13,7 +13,7 @@ import { nanoid } from 'nanoid/non-secure'; import { useRecoilValue } from 'recoil'; import { settingsState } from '../state/settings'; import { useDependsOnPlugins } from './useDependsOnPlugins'; -import { datasetProvider } from '../utils/globals'; +import { audioProvider, datasetProvider } from '../utils/globals'; export function useGetAdHocInternalProcessContext() { const settings = useRecoilValue(settingsState); @@ -33,6 +33,7 @@ export function useGetAdHocInternalProcessContext() { settings: await fillMissingSettingsFromEnvironmentVariables(settings, plugins), nativeApi: new TauriNativeApi(), datasetProvider, + audioProvider, processId: nanoid() as ProcessId, executionCache: new Map(), externalFunctions: {}, diff --git a/packages/app/src/hooks/useLocalExecutor.ts b/packages/app/src/hooks/useLocalExecutor.ts index 1762c32df..fffadad60 100644 --- a/packages/app/src/hooks/useLocalExecutor.ts +++ b/packages/app/src/hooks/useLocalExecutor.ts @@ -25,7 +25,7 @@ import { lastRecordingState, loadedRecordingState } from '../state/execution'; import { fillMissingSettingsFromEnvironmentVariables } from '../utils/tauri'; import { trivetState } from '../state/trivet'; import { runTrivet } from '@ironclad/trivet'; -import { datasetProvider } from '../utils/globals'; +import { audioProvider, datasetProvider } from '../utils/globals'; import { entries } from '../../../core/src/utils/typeSafety'; export function useLocalExecutor() { @@ -147,6 +147,7 @@ export function useLocalExecutor() { ), nativeApi: new TauriNativeApi(), datasetProvider, + audioProvider, }, {}, contextValues, @@ -208,6 +209,7 @@ export function useLocalExecutor() { ), nativeApi: new TauriNativeApi(), datasetProvider, + audioProvider, }, inputs, ); diff --git a/packages/app/src/hooks/useNodeTypes.ts b/packages/app/src/hooks/useNodeTypes.ts index 89cd72c24..9a4e800fb 100644 --- a/packages/app/src/hooks/useNodeTypes.ts +++ b/packages/app/src/hooks/useNodeTypes.ts @@ -22,14 +22,15 @@ import { loadDatasetNodeDescriptor } from '../components/nodes/LoadDatasetNode'; import { datasetNearestNeighborsNodeDescriptor } from '../components/nodes/DatasetNearestNeighborsNode'; import { getDatasetRowNodeDescriptor } from '../components/nodes/GetDatasetRowNode'; import { replaceDatasetNodeDescriptor } from '../components/nodes/ReplaceDatasetNode'; +import { type InputsOrOutputsWithRefs } from '../state/dataFlow'; export type UnknownNodeComponentDescriptor = { Body?: FC<{ node: ChartNode }>; Output?: FC<{ node: ChartNode }>; Editor?: FC<{ node: ChartNode; onChange?: (node: ChartNode) => void }>; FullscreenOutput?: FC<{ node: ChartNode }>; - OutputSimple?: FC<{ outputs: Outputs }>; - FullscreenOutputSimple?: FC<{ outputs: Outputs; renderMarkdown: boolean }>; + OutputSimple?: FC<{ outputs: InputsOrOutputsWithRefs }>; + FullscreenOutputSimple?: FC<{ outputs: InputsOrOutputsWithRefs; renderMarkdown: boolean }>; defaultRenderMarkdown?: boolean; }; @@ -38,8 +39,8 @@ export type NodeComponentDescriptor = { Output?: FC<{ node: NodeOfType }>; Editor?: FC<{ node: NodeOfType; onChange?: (node: NodeOfType) => void }>; FullscreenOutput?: FC<{ node: NodeOfType }>; - OutputSimple?: FC<{ outputs: Outputs }>; - FullscreenOutputSimple?: FC<{ outputs: Outputs; renderMarkdown: boolean }>; + OutputSimple?: FC<{ outputs: InputsOrOutputsWithRefs }>; + FullscreenOutputSimple?: FC<{ outputs: InputsOrOutputsWithRefs; renderMarkdown: boolean }>; defaultRenderMarkdown?: boolean; }; diff --git a/packages/app/src/io/BrowserIOProvider.ts b/packages/app/src/io/BrowserIOProvider.ts index 16a3eb43a..04c003533 100644 --- a/packages/app/src/io/BrowserIOProvider.ts +++ b/packages/app/src/io/BrowserIOProvider.ts @@ -90,18 +90,18 @@ export class BrowserIOProvider implements IOProvider { await writable.close(); } - async readFileAsString(callback: (data: string) => void): Promise { + async readFileAsString(callback: (data: string, fileName: string) => void): Promise { const [fileHandle] = await window.showOpenFilePicker(); const file = await fileHandle.getFile(); const text = await file.text(); - callback(text); + callback(text, file.name); } - async readFileAsBinary(callback: (data: Uint8Array) => void): Promise { + async readFileAsBinary(callback: (data: Uint8Array, fileName: string) => void): Promise { const [fileHandle] = await window.showOpenFilePicker(); const file = await fileHandle.getFile(); const arrayBuffer = await file.arrayBuffer(); - callback(new Uint8Array(arrayBuffer)); + callback(new Uint8Array(arrayBuffer), file.name); } async readPathAsString(path: string): Promise { diff --git a/packages/app/src/io/IOProvider.ts b/packages/app/src/io/IOProvider.ts index 425d2c27a..13397325a 100644 --- a/packages/app/src/io/IOProvider.ts +++ b/packages/app/src/io/IOProvider.ts @@ -22,9 +22,9 @@ export interface IOProvider { saveString(content: string, defaultFileName: string): Promise; - readFileAsString(callback: (data: string) => void): Promise; + readFileAsString(callback: (data: string, fileName: string) => void): Promise; - readFileAsBinary(callback: (data: Uint8Array) => void): Promise; + readFileAsBinary(callback: (data: Uint8Array, fileName: string) => void): Promise; readPathAsString(path: string): Promise; diff --git a/packages/app/src/io/LegacyBrowserIOProvider.ts b/packages/app/src/io/LegacyBrowserIOProvider.ts index 3a9b8074c..f9430256c 100644 --- a/packages/app/src/io/LegacyBrowserIOProvider.ts +++ b/packages/app/src/io/LegacyBrowserIOProvider.ts @@ -116,25 +116,25 @@ export class LegacyBrowserIOProvider implements IOProvider { link.click(); } - async readFileAsString(callback: (data: string) => void): Promise { + async readFileAsString(callback: (data: string, fileName: string) => void): Promise { const input = document.createElement('input'); input.type = 'file'; input.onchange = async (event) => { const file = (event.target as HTMLInputElement)!.files![0]!; const text = await file.text(); - callback(text); + callback(text, file.name); }; input.click(); } - async readFileAsBinary(callback: (data: Uint8Array) => void): Promise { + async readFileAsBinary(callback: (data: Uint8Array, fileName: string) => void): Promise { const input = document.createElement('input'); input.type = 'file'; input.onchange = async (event) => { const file = (event.target as HTMLInputElement)!.files![0]!; const reader = new FileReader(); reader.onload = () => { - callback(new Uint8Array(reader.result as ArrayBuffer)); + callback(new Uint8Array(reader.result as ArrayBuffer), file.name); }; reader.readAsArrayBuffer(file); }; diff --git a/packages/app/src/io/TauriBrowserAudioProvider.ts b/packages/app/src/io/TauriBrowserAudioProvider.ts new file mode 100644 index 000000000..ac4e00f68 --- /dev/null +++ b/packages/app/src/io/TauriBrowserAudioProvider.ts @@ -0,0 +1,23 @@ +import { type AudioDataValue, type AudioProvider } from '@ironclad/rivet-core'; + +export class TauriBrowserAudioProvider implements AudioProvider { + async playAudio(audio: AudioDataValue, abort: AbortSignal): Promise { + const blob = new Blob([audio.value.data], { type: audio.value.mediaType ?? 'audio/wav' }); + const audioNode = new Audio(URL.createObjectURL(blob)); + + const finished = new Promise((resolve, reject) => { + audioNode.onended = () => { + resolve(); + }; + + abort.onabort = () => { + audioNode.pause(); + reject(new Error('Audio playback aborted')); + }; + }); + + await audioNode.play(); + + return finished; + } +} diff --git a/packages/app/src/io/TauriIOProvider.ts b/packages/app/src/io/TauriIOProvider.ts index 55feaf30d..b4ebd0a62 100644 --- a/packages/app/src/io/TauriIOProvider.ts +++ b/packages/app/src/io/TauriIOProvider.ts @@ -203,25 +203,29 @@ export class TauriIOProvider implements IOProvider { } } - async readFileAsString(callback: (data: string) => void): Promise { + async readFileAsString(callback: (data: string, fileName: string) => void): Promise { const path = await open({ multiple: false, }); if (path) { + const fileName = (path as string).split('/').pop() as string; + const contents = await readTextFile(path as string); - callback(contents); + callback(contents, fileName); } } - async readFileAsBinary(callback: (data: Uint8Array) => void): Promise { + async readFileAsBinary(callback: (data: Uint8Array, fileName: string) => void): Promise { const path = await open({ multiple: false, }); if (path) { + const fileName = (path as string).split('/').pop() as string; + const contents = await readBinaryFile(path as string); - callback(contents); + callback(contents, fileName); } } diff --git a/packages/app/src/state/dataFlow.ts b/packages/app/src/state/dataFlow.ts index 79a145e13..1c3d92127 100644 --- a/packages/app/src/state/dataFlow.ts +++ b/packages/app/src/state/dataFlow.ts @@ -1,16 +1,26 @@ import { atom, selectorFamily } from 'recoil'; -import { type GraphId, type Inputs, type NodeId, type Outputs, type ProcessId } from '@ironclad/rivet-core'; +import { + type PortId, + type GraphId, + type Inputs, + type NodeId, + type Outputs, + type ProcessId, + type DataType, + type DataValue, + type ScalarDataType, +} from '@ironclad/rivet-core'; export type ProcessDataForNode = { processId: ProcessId; - data: NodeRunData; + data: NodeRunDataWithRefs; }; export type RunDataByNodeId = { [nodeId: NodeId]: ProcessDataForNode[]; }; -export type NodeRunData = { +export type NodeRunDataBase = { startedAt?: number; finishedAt?: number; @@ -20,7 +30,9 @@ export type NodeRunData = { | { type: 'running' } | { type: 'interrupted' } | { type: 'notRan'; reason: string }; +}; +export type NodeRunData = NodeRunDataBase & { inputData?: Inputs; outputData?: Outputs; @@ -30,6 +42,29 @@ export type NodeRunData = { }; }; +export type NodeRunDataWithRefs = NodeRunDataBase & { + inputData?: InputsOrOutputsWithRefs; + + outputData?: InputsOrOutputsWithRefs; + + splitOutputData?: { + [index: number]: InputsOrOutputsWithRefs; + }; +}; + +export type InputsOrOutputsWithRefs = { + [portId: PortId]: DataValueWithRefs; +}; + +export type DataValueWithRefs = { + [P in DataType]: { + type: P; + value: P extends 'binary' | 'audio' | 'image' ? { ref: string } : Extract['value']; + }; +}[DataType]; + +export type ScalarDataValueWithRefs = Extract; + export const lastRunDataByNodeState = atom({ key: 'lastData', default: {}, diff --git a/packages/app/src/utils/globals.ts b/packages/app/src/utils/globals.ts index 6779a916a..8161f565c 100644 --- a/packages/app/src/utils/globals.ts +++ b/packages/app/src/utils/globals.ts @@ -1,2 +1,6 @@ +import { type DataValue } from '@ironclad/rivet-core'; + export * from './globals/datasetProvider.js'; export * from './globals/ioProvider.js'; +export * from './globals/audioProvider.js'; +export * from './globals/globalDataRefs.js'; diff --git a/packages/app/src/utils/globals/audioProvider.ts b/packages/app/src/utils/globals/audioProvider.ts new file mode 100644 index 000000000..945020706 --- /dev/null +++ b/packages/app/src/utils/globals/audioProvider.ts @@ -0,0 +1,5 @@ +import { TauriBrowserAudioProvider } from '../../io/TauriBrowserAudioProvider'; + +const audioProvider = new TauriBrowserAudioProvider(); + +export { audioProvider }; diff --git a/packages/app/src/utils/globals/globalDataRefs.ts b/packages/app/src/utils/globals/globalDataRefs.ts new file mode 100644 index 000000000..c482c00da --- /dev/null +++ b/packages/app/src/utils/globals/globalDataRefs.ts @@ -0,0 +1,25 @@ +import { type DataValue } from '@ironclad/rivet-core'; +import { LRUCache } from 'lru-cache'; +import { match } from 'ts-pattern'; + +const globalDataRefs = new LRUCache({ + maxSize: 500 * 1024 * 1024, // 500MB + sizeCalculation: (value) => { + return match(value) + .with({ type: 'image' }, (v) => v.value.data.byteLength) + .with({ type: 'binary' }, (v) => v.value.byteLength) + .with({ type: 'audio' }, (v) => v.value.data.byteLength) + .with({ type: 'image[]' }, (v) => v.value.reduce((acc, img) => acc + img.data.byteLength, 0)) + .with({ type: 'binary[]' }, (v) => v.value.reduce((acc, bin) => acc + bin.byteLength, 0)) + .with({ type: 'audio[]' }, (v) => v.value.reduce((acc, audio) => acc + audio.data.byteLength, 0)) + .otherwise((v) => JSON.stringify(v).length); + }, +}); + +export function getGlobalDataRef(key: string): DataValue | undefined { + return globalDataRefs.get(key); +} + +export function setGlobalDataRef(key: string, value: DataValue): void { + globalDataRefs.set(key, value); +} diff --git a/packages/core/src/api/createProcessor.ts b/packages/core/src/api/createProcessor.ts index bb2b519da..b44ce2d18 100644 --- a/packages/core/src/api/createProcessor.ts +++ b/packages/core/src/api/createProcessor.ts @@ -1,6 +1,7 @@ import type { PascalCase } from 'type-fest'; import { type AttachedData, + type AudioProvider, type DataValue, type DatasetProvider, type ExternalFunction, @@ -27,6 +28,7 @@ export type RunGraphOptions = { context?: Record; nativeApi?: NativeApi; datasetProvider?: DatasetProvider; + audioProvider?: AudioProvider; externalFunctions?: { [key: string]: ExternalFunction; }; @@ -166,6 +168,7 @@ export function coreCreateProcessor(project: Project, options: RunGraphOptions) { nativeApi: options.nativeApi, datasetProvider: options.datasetProvider, + audioProvider: options.audioProvider, settings: { openAiKey: options.openAiKey ?? '', openAiOrganization: options.openAiOrganization ?? '', diff --git a/packages/core/src/exports.ts b/packages/core/src/exports.ts index 995c02637..7738bd8de 100644 --- a/packages/core/src/exports.ts +++ b/packages/core/src/exports.ts @@ -28,6 +28,7 @@ export * from './integrations/DatasetProvider.js'; export * from './model/Dataset.js'; export * from './api/streaming.js'; export * from './api/createProcessor.js'; +export * from './integrations/AudioProvider.js'; import * as openai from './utils/openai.js'; export { openai }; diff --git a/packages/core/src/integrations/AudioProvider.ts b/packages/core/src/integrations/AudioProvider.ts new file mode 100644 index 000000000..3163869a4 --- /dev/null +++ b/packages/core/src/integrations/AudioProvider.ts @@ -0,0 +1,5 @@ +import type { AudioDataValue } from '../model/DataValue.js'; + +export interface AudioProvider { + playAudio(audio: AudioDataValue, abort: AbortSignal): Promise; +} diff --git a/packages/core/src/model/DataValue.ts b/packages/core/src/model/DataValue.ts index 7bbb2db4f..b7222887e 100644 --- a/packages/core/src/model/DataValue.ts +++ b/packages/core/src/model/DataValue.ts @@ -68,7 +68,7 @@ export type ObjectDataValue = DataValueDef<'object', Record>; export type VectorDataValue = DataValueDef<'vector', number[]>; export type BinaryDataValue = DataValueDef<'binary', Uint8Array>; export type ImageDataValue = DataValueDef<'image', { mediaType: SupportedMediaTypes; data: Uint8Array }>; -export type AudioDataValue = DataValueDef<'audio', { data: Uint8Array }>; +export type AudioDataValue = DataValueDef<'audio', { mediaType?: string; data: Uint8Array }>; export type GraphReferenceValue = DataValueDef<'graph-reference', { graphId: GraphId; graphName: string }>; /** GPT function definition */ diff --git a/packages/core/src/model/EditorDefinition.ts b/packages/core/src/model/EditorDefinition.ts index 9873e94c3..fb2f0abc0 100644 --- a/packages/core/src/model/EditorDefinition.ts +++ b/packages/core/src/model/EditorDefinition.ts @@ -132,6 +132,7 @@ export type FileBrowserEditorDefinition = SharedEditorDefin type: 'fileBrowser'; dataKey: DataOfType; + mediaTypeDataKey: DataOfType; useInputToggleDataKey?: DataOfType; accept?: string; diff --git a/packages/core/src/model/Nodes.ts b/packages/core/src/model/Nodes.ts index 12ce3b07d..4e2993b66 100644 --- a/packages/core/src/model/Nodes.ts +++ b/packages/core/src/model/Nodes.ts @@ -210,7 +210,14 @@ import { graphReferenceNode } from './nodes/GraphReferenceNode.js'; export * from './nodes/GraphReferenceNode.js'; import { callGraphNode } from './nodes/CallGraphNode.js'; +export * from './nodes/CallGraphNode.js'; + import { delegateFunctionCallNode } from './nodes/DelegateFunctionCallNode.js'; +export * from './nodes/DelegateFunctionCallNode.js'; + +import { playAudioNode } from './nodes/PlayAudioNode.js'; +export * from './nodes/PlayAudioNode.js'; + export * from './nodes/CallGraphNode.js'; export const registerBuiltInNodes = (registry: NodeRegistration) => { @@ -285,7 +292,8 @@ export const registerBuiltInNodes = (registry: NodeRegistration) => { .register(listGraphsNode) .register(graphReferenceNode) .register(callGraphNode) - .register(delegateFunctionCallNode); + .register(delegateFunctionCallNode) + .register(playAudioNode); }; let globalRivetNodeRegistry = registerBuiltInNodes(new NodeRegistration()); diff --git a/packages/core/src/model/ProcessContext.ts b/packages/core/src/model/ProcessContext.ts index 75d3cda28..8dd27e0d6 100644 --- a/packages/core/src/model/ProcessContext.ts +++ b/packages/core/src/model/ProcessContext.ts @@ -12,6 +12,7 @@ import { type DatasetProvider, type ChartNode, type AttachedNodeData, + type AudioProvider, } from '../index.js'; import type { Tokenizer } from '../integrations/Tokenizer.js'; @@ -22,6 +23,9 @@ export type ProcessContext = { /** Sets the dataset provider to be used for all dataset node calls. */ datasetProvider?: DatasetProvider; + /** The provider responsible for being able to play audio. Undefined if unsupported in this context. */ + audioProvider?: AudioProvider; + /** Sets the tokenizer that will be used for all nodes. If unset, the default GptTokenizerTokenizer will be used. */ tokenizer?: Tokenizer; diff --git a/packages/core/src/model/nodes/AudioNode.ts b/packages/core/src/model/nodes/AudioNode.ts index 17f7883dd..9d0193d05 100644 --- a/packages/core/src/model/nodes/AudioNode.ts +++ b/packages/core/src/model/nodes/AudioNode.ts @@ -16,12 +16,16 @@ import { } from '../../index.js'; import { base64ToUint8Array, expectType } from '../../utils/index.js'; import { nodeDefinition } from '../NodeDefinition.js'; +import { getInputOrData } from '../../utils/inputs.js'; export type AudioNode = ChartNode<'audio', AudioNodeData>; type AudioNodeData = { data?: DataRef; useDataInput: boolean; + + mediaType?: 'audio/wav' | 'audio/mp3' | 'audio/ogg'; + useMediaTypeInput: boolean; }; export class AudioNodeImpl extends NodeImpl { @@ -33,6 +37,7 @@ export class AudioNodeImpl extends NodeImpl { visualData: { x: 0, y: 0, width: 300 }, data: { useDataInput: false, + useMediaTypeInput: false, }, }; } @@ -49,6 +54,15 @@ export class AudioNodeImpl extends NodeImpl { }); } + if (this.chartNode.data.useMediaTypeInput) { + inputDefinitions.push({ + id: 'mediaType' as PortId, + title: 'Media Type', + dataType: 'string', + coerced: false, + }); + } + return inputDefinitions; } @@ -68,9 +82,16 @@ export class AudioNodeImpl extends NodeImpl { type: 'fileBrowser', label: 'Audio File', dataKey: 'data', + mediaTypeDataKey: 'mediaType', useInputToggleDataKey: 'useDataInput', accept: 'audio/*', }, + { + type: 'string', + label: 'Media Type', + dataKey: 'mediaType', + useInputToggleDataKey: 'useMediaTypeInput', + }, ]; } @@ -86,6 +107,8 @@ export class AudioNodeImpl extends NodeImpl { async process(inputData: Inputs, context: InternalProcessContext): Promise { let data: Uint8Array; + const mediaType = getInputOrData(this.data, inputData, 'mediaType', 'string') || 'audio/wav'; + if (this.chartNode.data.useDataInput) { data = expectType(inputData['data' as PortId], 'binary'); } else { @@ -106,7 +129,7 @@ export class AudioNodeImpl extends NodeImpl { return { ['data' as PortId]: { type: 'audio', - value: { data }, + value: { data, mediaType }, }, }; } diff --git a/packages/core/src/model/nodes/PlayAudioNode.ts b/packages/core/src/model/nodes/PlayAudioNode.ts new file mode 100644 index 000000000..4834497e4 --- /dev/null +++ b/packages/core/src/model/nodes/PlayAudioNode.ts @@ -0,0 +1,83 @@ +import { + type ChartNode, + type NodeId, + type PortId, + type NodeInputDefinition, + type NodeOutputDefinition, +} from '../NodeBase.js'; +import { NodeImpl, type NodeUIData } from '../NodeImpl.js'; +import { nanoid } from 'nanoid/non-secure'; +import { type EditorDefinition, type Inputs, type InternalProcessContext, type Outputs } from '../../index.js'; +import { expectType } from '../../utils/index.js'; +import { nodeDefinition } from '../NodeDefinition.js'; + +export type PlayAudioNode = ChartNode<'playAudio', PlayAudioNodeData>; + +type PlayAudioNodeData = {}; + +export class PlayAudioNodeImpl extends NodeImpl { + static create(): PlayAudioNode { + return { + id: nanoid() as NodeId, + type: 'playAudio', + title: 'Play Audio', + visualData: { x: 0, y: 0, width: 200 }, + data: {}, + }; + } + + getInputDefinitions(): NodeInputDefinition[] { + const inputDefinitions: NodeInputDefinition[] = []; + + inputDefinitions.push({ + id: 'data' as PortId, + title: 'Data', + dataType: 'audio', + coerced: false, + }); + + return inputDefinitions; + } + + getOutputDefinitions(): NodeOutputDefinition[] { + return [ + { + id: 'data' as PortId, + title: 'Audio Data', + dataType: 'audio', + }, + ]; + } + + getEditors(): EditorDefinition[] { + return []; + } + + static getUIData(): NodeUIData { + return { + contextMenuTitle: 'Play Audio', + group: 'Input/Output', + infoBoxTitle: 'Play Audio Node', + infoBoxBody: 'Plays audio data to the speakers.', + }; + } + + async process(inputData: Inputs, context: InternalProcessContext): Promise { + if (!context.audioProvider) { + throw new Error('Playing audio is not supported in this context'); + } + + const data = expectType(inputData['data' as PortId], 'audio'); + + await context.audioProvider.playAudio({ type: 'audio', value: data }, context.signal); + + return { + ['data' as PortId]: { + type: 'audio', + value: data, + }, + }; + } +} + +export const playAudioNode = nodeDefinition(PlayAudioNodeImpl, 'Play Audio'); diff --git a/packages/node/src/api.ts b/packages/node/src/api.ts index ebd8bbad7..63e090ea6 100644 --- a/packages/node/src/api.ts +++ b/packages/node/src/api.ts @@ -65,6 +65,7 @@ export function createProcessor( { nativeApi: options.nativeApi ?? new NodeNativeApi(), datasetProvider: options.datasetProvider, + audioProvider: options.audioProvider, settings: { openAiKey: options.openAiKey ?? process.env.OPENAI_API_KEY ?? '', openAiOrganization: options.openAiOrganization ?? process.env.OPENAI_ORG_ID ?? '', diff --git a/yarn.lock b/yarn.lock index ec7c59a1e..eed91eb84 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3821,8 +3821,10 @@ __metadata: immer: "npm:^10.0.3" jsonpath-plus: "npm:^7.2.0" lodash-es: "npm:^4.17.21" + lru-cache: "npm:^11.0.0" majesticons: "npm:^2.1.2" marked: "npm:^9.1.2" + mime: "npm:^4.0.4" minimatch: "npm:^9.0.3" monaco-editor: "npm:^0.44.0" nanoid: "npm:^3.3.6" @@ -13263,6 +13265,13 @@ __metadata: languageName: node linkType: hard +"lru-cache@npm:^11.0.0": + version: 11.0.0 + resolution: "lru-cache@npm:11.0.0" + checksum: 41f36fbff8b6f199cce3e9cb2b625714f97a535dfd7f16d0988c2627f9ed4c38b6dc8f9ea7fdba19262a7c917ba41c89cad15ca3e3791fc9a2068af472b5bc8d + languageName: node + linkType: hard + "lru-cache@npm:^5.1.1": version: 5.1.1 resolution: "lru-cache@npm:5.1.1" @@ -13538,6 +13547,15 @@ __metadata: languageName: node linkType: hard +"mime@npm:^4.0.4": + version: 4.0.4 + resolution: "mime@npm:4.0.4" + bin: + mime: bin/cli.js + checksum: 28e41053ae09cbf4186c551d7cc3cdda10c04fdf447cfdb66db096d83279889a0e0589805b15e36c37ca8b0eedfa6317f25d0514462525271f0cffa5cb0514b4 + languageName: node + linkType: hard + "mimic-fn@npm:^2.1.0": version: 2.1.0 resolution: "mimic-fn@npm:2.1.0"