Skip to content

Commit

Permalink
Play Audio node, and refactor how large blobs are stored in memory fo…
Browse files Browse the repository at this point in the history
…r the app, improve performance
  • Loading branch information
abrenneke committed Aug 2, 2024
1 parent 21fefdb commit 3a2ab69
Show file tree
Hide file tree
Showing 40 changed files with 508 additions and 108 deletions.
16 changes: 16 additions & 0 deletions .pnp.cjs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions packages/app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,10 @@
"immer": "^10.0.3",
"jsonpath-plus": "^7.2.0",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.0",
"majesticons": "^2.1.2",
"marked": "^9.1.2",
"mime": "^4.0.4",
"minimatch": "^9.0.3",
"monaco-editor": "^0.44.0",
"nanoid": "^3.3.6",
Expand Down
20 changes: 13 additions & 7 deletions packages/app/src/components/ChatViewer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ import {
arrayizeDataValue,
coerceTypeOptional,
} from '@ironclad/rivet-core';
import { type NodeRunData, lastRunDataByNodeState, graphRunningState } from '../state/dataFlow';
import {
type NodeRunData,
lastRunDataByNodeState,
graphRunningState,
type NodeRunDataWithRefs,
type DataValueWithRefs,
} from '../state/dataFlow';
import { projectState } from '../state/savedGraphs';
import { ErrorBoundary } from 'react-error-boundary';
import TextField from '@atlaskit/textfield';
Expand Down Expand Up @@ -324,7 +330,7 @@ const ChatBubble: FC<{
nodeId: NodeId;
nodeTitle: string;
processId: ProcessId;
data: NodeRunData;
data: NodeRunDataWithRefs;
splitIndex: number;
style?: CSSProperties;
onGoToNode?: (nodeId: NodeId) => void;
Expand All @@ -333,16 +339,16 @@ const ChatBubble: FC<{
const responseRef = useRef<HTMLDivElement>(null);
const [expanded, toggleExpanded] = useToggle();

let prompt: DataValue;
let prompt: DataValueWithRefs;

if (splitIndex === -1) {
prompt = data.inputData?.['prompt' as PortId]!;
} else {
const values = arrayizeDataValue(data.inputData?.['prompt' as PortId] as ScalarOrArrayDataValue);
if (values.length === 1) {
prompt = values[0]!;
prompt = values[0]! as DataValueWithRefs;
} else {
prompt = values[splitIndex]!;
prompt = values[splitIndex]! as DataValueWithRefs;
}
}

Expand All @@ -351,8 +357,8 @@ const ChatBubble: FC<{
? data.outputData?.['response' as PortId]
: data.splitOutputData![splitIndex]!['response' as PortId];

const promptText = coerceTypeOptional(prompt, 'string');
const responseText = coerceTypeOptional(chatOutput, 'string');
const promptText = coerceTypeOptional(prompt as DataValue, 'string');
const responseText = coerceTypeOptional(chatOutput as DataValue, 'string');

useLayoutEffect(() => {
if (promptRef.current) {
Expand Down
17 changes: 11 additions & 6 deletions packages/app/src/components/NodeOutput.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
import { type NodeRunData, type ProcessDataForNode, lastRunData, selectedProcessPage } from '../state/dataFlow.js';
import {
type NodeRunData,
type ProcessDataForNode,
lastRunData,
selectedProcessPage,
type NodeRunDataWithRefs,
} from '../state/dataFlow.js';
import { type FC, type ReactNode, memo, useMemo, useState, type MouseEvent } from 'react';
import { useUnknownNodeComponentDescriptorFor } from '../hooks/useNodeTypes.js';
import { useStableCallback } from '../hooks/useStableCallback.js';
import { copyToClipboard } from '../utils/copyToClipboard.js';
import { type ChartNode, type PortId, type ProcessId, getWarnings } from '@ironclad/rivet-core';
import { type ChartNode, type PortId, type ProcessId, getWarnings, type Outputs } from '@ironclad/rivet-core';
import { css } from '@emotion/react';
import CopyIcon from 'majesticons/line/clipboard-line.svg?react';
import ExpandIcon from 'majesticons/line/maximize-line.svg?react';
Expand Down Expand Up @@ -350,7 +356,6 @@ const NodeOutputBase: FC<{ node: ChartNode; children?: ReactNode; onOpenFullscre
onOpenFullscreenModal,
}) => {
const output = useRecoilValue(lastRunData(node.id));

if (!output?.length) {
return null;
}
Expand All @@ -377,7 +382,7 @@ const NodeOutputBase: FC<{ node: ChartNode; children?: ReactNode; onOpenFullscre

const NodeOutputSingleProcess: FC<{
node: ChartNode;
data: NodeRunData;
data: NodeRunDataWithRefs;
processId: ProcessId;
onOpenFullscreenModal?: () => void;
}> = ({ node, data, processId, onOpenFullscreenModal }) => {
Expand Down Expand Up @@ -482,9 +487,9 @@ const NodeOutputSingleProcess: FC<{
</div>
</div>
{body}
{getWarnings(data.outputData) && (
{getWarnings(data.outputData as Outputs) && (
<div className="node-output-warnings">
{getWarnings(data.outputData)!.map((warning) => (
{getWarnings(data.outputData as Outputs)!.map((warning) => (
<div className="node-output-warning" key={warning}>
{warning}
</div>
Expand Down
11 changes: 8 additions & 3 deletions packages/app/src/components/PromptDesigner.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
promptDesignerTestGroupResultsByNodeIdState,
} from '../state/promptDesigner';
import { nodesByIdState, nodesState } from '../state/graph.js';
import { lastRunDataByNodeState } from '../state/dataFlow.js';
import { type InputsOrOutputsWithRefs, lastRunDataByNodeState } from '../state/dataFlow.js';
import {
type ChatMessage,
type ChatNode,
Expand All @@ -31,6 +31,9 @@ import {
getError,
isArrayDataValue,
openai,
type DataValue,
type ScalarDataValue,
type Inputs,
} from '@ironclad/rivet-core';
import TextField from '@atlaskit/textfield';
import { Field } from '@atlaskit/form';
Expand Down Expand Up @@ -354,9 +357,11 @@ export const PromptDesigner: FC<PromptDesignerProps> = ({ onClose }) => {
let inputData = nodeDataForAttachedNodeProcess.inputData;
// If node is a split run, just grab the first input data.
if (attachedNode.isSplitRun) {
inputData = mapValues(inputData, (val) => (isArrayDataValue(val) ? arrayizeDataValue(val)[0] : val));
inputData = mapValues(inputData, (val) =>
isArrayDataValue(val as DataValue) ? arrayizeDataValue(val as ScalarDataValue)[0] : val,
) as InputsOrOutputsWithRefs;
}
const { messages } = getChatNodeMessages(inputData);
const { messages } = getChatNodeMessages(inputData as Inputs);
setMessages({
messages,
});
Expand Down
88 changes: 63 additions & 25 deletions packages/app/src/components/RenderDataValue.tsx
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import { type FC } from 'react';
import { useMemo, type FC } from 'react';
import {
type DataValue,
type Outputs,
type ScalarDataType,
type ScalarDataValue,
arrayizeDataValue,
getScalarTypeOf,
inferType,
isArrayDataValue,
isFunctionDataValue,
coerceTypeOptional,
type NodeOutputDefinition,
type ChatMessageMessagePart,
type DataType,
type ImageDataValue,
type BinaryDataValue,
type AudioDataValue,
isArrayDataType,
type ScalarOrArrayDataValue,
} from '@ironclad/rivet-core';
import { css } from '@emotion/react';
import { keys } from '../../../core/src/utils/typeSafety';
import { useMarkdown } from '../hooks/useMarkdown';
import ColorizedPreformattedText from './ColorizedPreformattedText';
import { P, match } from 'ts-pattern';
import clsx from 'clsx';
import { type InputsOrOutputsWithRefs, type DataValueWithRefs, type ScalarDataValueWithRefs } from '../state/dataFlow';
import { getGlobalDataRef } from '../utils/globals';

const styles = css`
.chat-message.user header em {
Expand Down Expand Up @@ -81,7 +84,7 @@ const multiOutput = css`
`;

type ScalarRendererProps<T extends DataType = DataType> = {
value: Extract<ScalarDataValue, { type: T }>;
value: Extract<ScalarDataValueWithRefs, { type: T }>;
depth?: number;
renderMarkdown?: boolean;
truncateLength?: number;
Expand Down Expand Up @@ -148,7 +151,7 @@ const scalarRenderers: {
<div className="pre-wrap">
{message.function_calls.map((fc, i) => (
<div key={i}>
<RenderDataValue value={inferType(fc)} />
<RenderDataValue value={inferType(fc) as DataValueWithRefs} />
</div>
))}
</div>
Expand All @@ -158,7 +161,7 @@ const scalarRenderers: {
<div className="function-call">
<h4>Function Call:</h4>
<div className="pre-wrap">
<RenderDataValue value={inferType(message.function_call)} />
<RenderDataValue value={inferType(message.function_call) as DataValueWithRefs} />
</div>
</div>
)
Expand Down Expand Up @@ -191,7 +194,9 @@ const scalarRenderers: {
if (inferred.type === 'any') {
return <>{JSON.stringify(inferred.value)}</>;
}
return <RenderDataValue value={inferred} depth={(depth ?? 0) + 1} renderMarkdown={renderMarkdown} />;
return (
<RenderDataValue value={inferred as DataValueWithRefs} depth={(depth ?? 0) + 1} renderMarkdown={renderMarkdown} />
);
},
object: ({ value }) => (
<div className="rendered-object-type">
Expand All @@ -205,12 +210,19 @@ const scalarRenderers: {
),
vector: ({ value }) => <>Vector (length {value.value.length})</>,
image: ({ value }) => {
const resolved = getGlobalDataRef(value.value.ref);
if (!resolved) {
return <div>Could not find data.</div>;
}

const {
value: { data, mediaType },
} = value;
} = resolved as ImageDataValue;

const blob = new Blob([data], { type: mediaType });
const imageUrl = URL.createObjectURL(blob);
const imageUrl = useMemo(() => {
const blob = new Blob([data], { type: mediaType });
return URL.createObjectURL(blob);
}, [data, mediaType]);

return (
<div>
Expand All @@ -219,19 +231,39 @@ const scalarRenderers: {
);
},
binary: ({ value }) => {
const resolved = getGlobalDataRef(value.value.ref);
if (!resolved) {
return <div>Could not find data.</div>;
}

// FIXME: Coercing `value.value` into a `Uint8Array` here because `Uint8Array` gets parsed as an
// object of shape `{ [index: number]: number }` when stringified via `JSON.stringify()`.
// Consider coercing it back to `Uint8Array` at the entrypoints of the boundaries between
// browser and node.js instead.
const coercedValue = new Uint8Array(Object.values(value.value));
const coercedValue = useMemo(() => {
const resolved = getGlobalDataRef(value.value.ref);
if (resolved!.value instanceof Uint8Array) {
return resolved!.value;
}
return new Uint8Array(Object.values((resolved as BinaryDataValue).value));
}, [value.value.ref]);

return <>Binary (length {coercedValue.length.toLocaleString()})</>;
},
audio: ({ value }) => {
const resolved = getGlobalDataRef(value.value.ref);
if (!resolved) {
return <div>Could not find data.</div>;
}

const {
value: { data },
} = value;
value: { data, mediaType },
} = resolved as AudioDataValue;

const dataUri = `data:audio/mp4;base64,${data}`;
const dataUri = useMemo(() => {
const blob = new Blob([data], { type: mediaType });
return URL.createObjectURL(blob);
}, [data, mediaType]);

return (
<div>
Expand All @@ -257,8 +289,14 @@ const RenderChatMessagePart: FC<{ part: ChatMessageMessagePart; renderMarkdown?:
return <Renderer value={{ type: 'string', value: part }} renderMarkdown={renderMarkdown} />;
})
.with({ type: 'image' }, (part) => {
const Renderer = scalarRenderers.image;
return <Renderer value={{ type: 'image', value: part }} renderMarkdown={renderMarkdown} />;
const blob = new Blob([part.data], { type: part.mediaType });
const imageUrl = URL.createObjectURL(blob);

return (
<div>
<img src={imageUrl} alt="" />
</div>
);
})
.with({ type: 'url' }, (part) => {
return <img className="chat-message-url-image" src={part.url} alt={part.url} />;
Expand All @@ -267,7 +305,7 @@ const RenderChatMessagePart: FC<{ part: ChatMessageMessagePart; renderMarkdown?:
};

export const RenderDataValue: FC<{
value: DataValue | undefined;
value: DataValueWithRefs | undefined;
depth?: number;
renderMarkdown?: boolean;
truncateLength?: number;
Expand All @@ -280,8 +318,8 @@ export const RenderDataValue: FC<{
}

const keys = Object.keys(value?.value ?? {});
if (isArrayDataValue(value)) {
const items = arrayizeDataValue(value);
if (isArrayDataType(value.type)) {
const items = arrayizeDataValue(value as ScalarOrArrayDataValue);
return (
<div
css={multiOutput}
Expand All @@ -296,7 +334,7 @@ export const RenderDataValue: FC<{
<div className="multi-output-item" key={i}>
<RenderDataValue
key={i}
value={v}
value={v as DataValueWithRefs}
depth={(depth ?? 0) + 1}
renderMarkdown={renderMarkdown}
truncateLength={truncateLength}
Expand All @@ -307,7 +345,7 @@ export const RenderDataValue: FC<{
);
}

if (isFunctionDataValue(value)) {
if (isFunctionDataValue(value as DataValue)) {
const type = getScalarTypeOf(value.type);
return (
<div>
Expand All @@ -325,7 +363,7 @@ export const RenderDataValue: FC<{
return (
<div css={styles}>
<Renderer
value={value}
value={value as ScalarDataValueWithRefs}
depth={(depth ?? 0) + 1}
renderMarkdown={renderMarkdown}
truncateLength={truncateLength}
Expand All @@ -336,7 +374,7 @@ export const RenderDataValue: FC<{

export const RenderDataOutputs: FC<{
definitions?: NodeOutputDefinition[];
outputs: Outputs;
outputs: InputsOrOutputsWithRefs;
renderMarkdown?: boolean;
}> = ({ definitions, outputs, renderMarkdown }) => {
const outputPorts = keys(outputs);
Expand Down
Loading

0 comments on commit 3a2ab69

Please sign in to comment.