Skip to content

Commit

Permalink
feat(client): add streaming support (#56)
Browse files Browse the repository at this point in the history
* feat(client): add streaming support

* fix: variable type definition

* chore: add docs

* chore: bump to client 0.9.0
  • Loading branch information
drochetti authored Mar 6, 2024
1 parent 6d112c8 commit 335b817
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 26 deletions.
87 changes: 87 additions & 0 deletions apps/demo-nextjs-app-router/app/streaming/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
'use client';

import * as fal from '@fal-ai/serverless-client';
import { useState } from 'react';

// Configure the fal client once, at module load, to send its requests to the
// `/api/fal/proxy` path of this app instead of calling the API directly.
// NOTE(review): presumably this is a server-side route that attaches the
// credentials so they never reach the browser — confirm against the proxy route.
fal.config({
proxyUrl: '/api/fal/proxy',
});

/**
 * Input payload this demo passes to `fal.stream` for the
 * `fal-ai/llavav15-13b` app.
 */
type LlavaInput = {
// Text question sent along with the image.
prompt: string;
// URL of the image the prompt refers to (the demo passes an https URL).
image_url: string;
// Optional generation settings; the demo sends 100 / 0.2 / 1 respectively.
// NOTE(review): exact semantics and defaults are defined by the model's API,
// not by this file — confirm against the model documentation.
max_new_tokens?: number;
temperature?: number;
top_p?: number;
};

/**
 * Shape used for both the chunks yielded while iterating the stream and the
 * final value resolved by `stream.done()`.
 */
type LlavaOutput = {
// Text rendered in the "Answer" panel.
output: string;
// Not read by this demo.
// NOTE(review): presumably true for intermediate chunks — confirm.
partial: boolean;
// Token counters reported with the response; not read by this demo.
stats: {
num_input_tokens: number;
num_output_tokens: number;
};
};

/**
 * Demo page for the streaming API of the fal client.
 *
 * Clicking "Run inference" opens a stream against `fal-ai/llavav15-13b`,
 * renders every partial answer as it arrives and finally shows the completed
 * result. The current phase is displayed as `idle`, `running`, `done` or
 * `error`.
 */
export default function StreamingDemo() {
  const [answer, setAnswer] = useState<string>('');
  const [streamStatus, setStreamStatus] = useState<string>('idle');

  const isRunning = streamStatus === 'running';

  const runInference = async () => {
    // Guard against overlapping runs: two concurrent streams would interleave
    // their `setAnswer` calls and race on the final status.
    if (isRunning) {
      return;
    }
    // Clear the previous answer so a re-run never shows stale text as if it
    // belonged to the new request.
    setAnswer('');
    setStreamStatus('running');

    try {
      const stream = await fal.stream<LlavaInput, LlavaOutput>(
        'fal-ai/llavav15-13b',
        {
          input: {
            prompt:
              'Do you know who drew this picture and what is the name of it?',
            image_url: 'https://llava-vl.github.io/static/images/monalisa.jpg',
            max_new_tokens: 100,
            temperature: 0.2,
            top_p: 1,
          },
        }
      );

      // Each yielded chunk carries the answer text to display for now.
      for await (const partial of stream) {
        setAnswer(partial.output);
      }

      const result = await stream.done();
      setAnswer(result.output);
      setStreamStatus('done');
    } catch (error: unknown) {
      // Without this the rejection would go unhandled (React does not await
      // event handlers) and the status would stay on "running" forever.
      console.error(error);
      setStreamStatus('error');
    }
  };

  return (
    <div className="min-h-screen dark:bg-gray-900 bg-gray-100">
      <main className="container dark:text-gray-50 text-gray-900 flex flex-col items-center justify-center w-full flex-1 py-10 space-y-8">
        <h1 className="text-4xl font-bold mb-8">
          Hello <code className="text-pink-600">fal</code> +{' '}
          <code className="text-indigo-500">streaming</code>
        </h1>

        <div className="flex flex-row space-x-2">
          <button
            onClick={runInference}
            disabled={isRunning}
            className="bg-indigo-600 hover:bg-indigo-700 text-white font-bold text-lg py-3 px-6 mx-auto rounded focus:outline-none focus:shadow-outline disabled:opacity-70"
          >
            Run inference
          </button>
        </div>

        <div className="w-full flex flex-col space-y-4">
          <div className="flex flex-row items-center justify-between">
            <h2 className="text-2xl font-bold">Answer</h2>
            <span>
              streaming: <code className="font-semibold">{streamStatus}</code>
            </span>
          </div>
          <p className="text-lg p-4 border min-h-[12rem] border-gray-300 bg-gray-200 dark:bg-gray-800 dark:border-gray-700 rounded">
            {answer}
          </p>
        </div>
      </main>
    </div>
  );
}
3 changes: 2 additions & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/serverless-client",
"description": "The fal serverless JS/TS client",
"version": "0.8.6",
"version": "0.9.0",
"license": "MIT",
"repository": {
"type": "git",
Expand All @@ -17,6 +17,7 @@
],
"dependencies": {
"@msgpack/msgpack": "^3.0.0-beta2",
"eventsource-parser": "^1.1.2",
"robot3": "^0.4.1",
"uuid-random": "^1.3.2"
},
Expand Down
26 changes: 26 additions & 0 deletions libs/client/src/auth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { getRestApiUrl } from './config';
import { dispatchRequest } from './request';
import { ensureAppIdFormat } from './utils';

/**
 * Value sent as `token_expiration` when requesting a temporary auth token;
 * per the name, expressed in seconds.
 */
export const TOKEN_EXPIRATION_SECONDS = 120;

/**
 * Get a token to connect to the realtime endpoint.
 *
 * Sends a `POST` to `<rest api url>/tokens/` asking for a token restricted to
 * the alias segment of the given app id and expiring after
 * `TOKEN_EXPIRATION_SECONDS`.
 *
 * @param app the app id; it is normalized with `ensureAppIdFormat` and the
 *   segment after the first `/` is used as the allowed app alias.
 * @returns the token string.
 * @throws Error when the response is neither a string nor an object wrapping
 *   a string under `detail`, instead of handing a non-string value to callers
 *   that were promised a `string`.
 */
export async function getTemporaryAuthToken(app: string): Promise<string> {
  const [, appAlias] = ensureAppIdFormat(app).split('/');
  // Typed as `unknown` on purpose: the declared response type is `string`, but
  // the payload may arrive wrapped, so it is validated at runtime below.
  const response: unknown = await dispatchRequest<any, string>(
    'POST',
    `${getRestApiUrl()}/tokens/`,
    {
      allowed_apps: [appAlias],
      token_expiration: TOKEN_EXPIRATION_SECONDS,
    }
  );
  if (typeof response === 'string') {
    return response;
  }
  // keep this in case the response was wrapped (old versions of the proxy do that)
  // should be safe to remove in the future
  if (typeof response === 'object' && response !== null) {
    const detail = (response as { detail?: unknown }).detail;
    if (typeof detail === 'string') {
      return detail;
    }
  }
  throw new Error(
    'Unexpected response from the token endpoint: expected a token string'
  );
}
1 change: 1 addition & 0 deletions libs/client/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export { realtimeImpl as realtime } from './realtime';
export { ApiError, ValidationError } from './response';
export type { ResponseHandler } from './response';
export { storageImpl as storage } from './storage';
export { stream } from './streaming';
export type {
QueueStatus,
ValidationErrorInfo,
Expand Down
27 changes: 2 additions & 25 deletions libs/client/src/realtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import {
transition,
} from 'robot3';
import uuid from 'uuid-random';
import { getRestApiUrl } from './config';
import { dispatchRequest } from './request';
import { TOKEN_EXPIRATION_SECONDS, getTemporaryAuthToken } from './auth';
import { ApiError } from './response';
import { isBrowser } from './runtime';
import { ensureAppIdFormat, isReact, throttle } from './utils';
Expand Down Expand Up @@ -280,7 +279,6 @@ function buildRealtimeUrl(
return `wss://fal.run/${appId}/${suffix}?${queryParams.toString()}`;
}

const TOKEN_EXPIRATION_SECONDS = 120;
const DEFAULT_THROTTLE_INTERVAL = 128;

function shouldSendBinary(message: any): boolean {
Expand All @@ -292,27 +290,6 @@ function shouldSendBinary(message: any): boolean {
);
}

/**
 * Request a short-lived token for connecting to the realtime endpoint.
 *
 * The token is scoped to the alias segment of `app` and requested with a
 * lifetime of `TOKEN_EXPIRATION_SECONDS`.
 */
async function getToken(app: string): Promise<string> {
  const appAlias = ensureAppIdFormat(app).split('/')[1];
  const tokensUrl = `${getRestApiUrl()}/tokens/`;
  const payload = {
    allowed_apps: [appAlias],
    token_expiration: TOKEN_EXPIRATION_SECONDS,
  };
  const response: string | object = await dispatchRequest<any, string>(
    'POST',
    tokensUrl,
    payload
  );
  if (typeof response === 'string') {
    return response;
  }
  // Old versions of the proxy wrap the token in an object under `detail`;
  // unwrap it when present. Should be safe to remove in the future.
  const wrapped = response['detail'];
  return wrapped ? wrapped : response;
}

function isUnauthorizedError(message: any): boolean {
// TODO we need better protocol definition with error codes
return message['status'] === 'error' && message['error'] === 'Unauthorized';
Expand Down Expand Up @@ -441,7 +418,7 @@ export const realtimeImpl: RealtimeClient = {
previousState !== machine.current
) {
send({ type: 'initiateAuth' });
getToken(app)
getTemporaryAuthToken(app)
.then((token) => {
send({ type: 'authenticated', token });
const tokenExpirationTimeout = Math.round(
Expand Down
Loading

0 comments on commit 335b817

Please sign in to comment.