Skip to content

Commit

Permalink
Open audio/image input stream only when queue is ready (#9149)
Browse files Browse the repository at this point in the history
* fix

* submit logic happens in Blocks

* add changeset

* trigger ci

* trigger ci

* Add code

* Add code

* Fix retrigger refactor

* Add code

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot authored Aug 20, 2024
1 parent f3652eb commit 3d7a9b8
Show file tree
Hide file tree
Showing 17 changed files with 239 additions and 41 deletions.
10 changes: 10 additions & 0 deletions .changeset/easy-files-serve.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
"@gradio/audio": minor
"@gradio/client": minor
"@gradio/core": minor
"@gradio/icons": minor
"@gradio/image": minor
"gradio": minor
---

feat:Open audio/image input stream only when queue is ready
4 changes: 3 additions & 1 deletion client/js/src/helpers/api_info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ export function handle_message(
| "unexpected_error";
data?: any;
status?: Status;
original_msg?: string;
} {
const queue = true;
switch (data.msg) {
Expand Down Expand Up @@ -373,7 +374,8 @@ export function handle_message(
position: 0,
success: data.success,
eta: data.eta
}
},
original_msg: "process_starts"
};
}

Expand Down
1 change: 1 addition & 0 deletions client/js/src/test/api_info.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ describe("handle_message", () => {
const result = handle_message(data, last_status);
expect(result).toEqual({
type: "update",
original_msg: "process_starts",
status: {
queue: true,
stage: "pending",
Expand Down
1 change: 1 addition & 0 deletions client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ export interface StatusMessage extends Status {
type: "status";
endpoint: string;
fn_index: number;
original_msg?: string;
}

export interface PayloadMessage extends Payload {
Expand Down
3 changes: 2 additions & 1 deletion client/js/src/utils/submit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ export function submit(
event_id_final = event_id;
let callback = async function (_data: object): Promise<void> {
try {
const { type, status, data } = handle_message(
const { type, status, data, original_msg } = handle_message(
_data,
last_status[fn_index]
);
Expand All @@ -614,6 +614,7 @@ export function submit(
endpoint: _endpoint,
fn_index,
time: new Date(),
original_msg: original_msg,
...status
});
} else if (type === "complete") {
Expand Down
12 changes: 10 additions & 2 deletions js/audio/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@
export let streaming: boolean;
export let stream_every: number;
export let close_stream: () => void;
// Most recently requested stream state for this component. It is mirrored here
// so it can be read back through `get_stream_state`.
let stream_state: "open" | "closed" | "waiting" = "closed";
// Populated through `bind:modify_stream` on the child component; until that
// binding has run this variable holds no function.
let _modify_stream: (state: "open" | "closed" | "waiting") => void;

/**
 * Records the requested stream state and forwards it to the bound child.
 *
 * @param state - The state the stream should move into.
 */
export function modify_stream_state(
	state: "open" | "closed" | "waiting"
): void {
	stream_state = state;
	// The state is recorded even when no child has been bound yet; forwarding
	// is skipped in that case instead of throwing on an undefined callback.
	if (_modify_stream) _modify_stream(state);
}

/**
 * Reads back the state last recorded by `modify_stream_state`.
 *
 * @returns The current stream state ("closed" until it is first modified).
 */
export const get_stream_state: () => "open" | "closed" | "waiting" = () =>
	stream_state;
export let set_time_limit: (time: number) => void;
export let gradio: Gradio<{
input: never;
Expand Down Expand Up @@ -245,7 +253,7 @@
{waveform_options}
{trim_region_settings}
{stream_every}
bind:close_stream
bind:modify_stream={_modify_stream}
bind:set_time_limit
upload={gradio.client.upload}
stream_handler={gradio.client.stream}
Expand Down
18 changes: 16 additions & 2 deletions js/audio/interactive/InteractiveAudio.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,21 @@
export let stream_every: number;
let time_limit: number | null = null;
let stream_state: "open" | "waiting" | "closed" = "closed";
export const close_stream: () => void = () => {
time_limit = null;
/**
 * Moves this component's stream into the requested state.
 *
 * Closing the stream also clears the current time limit; any value other than
 * "closed" or "waiting" is treated as "open".
 */
export const modify_stream: (state: "open" | "closed" | "waiting") => void = (
	state: "open" | "closed" | "waiting"
) => {
	switch (state) {
		case "closed":
			// Reset the time limit before marking the stream as closed.
			time_limit = null;
			stream_state = "closed";
			break;
		case "waiting":
			stream_state = "waiting";
			break;
		default:
			stream_state = "open";
	}
};
/**
 * Stores a time limit for the stream; ignored unless a recording is in
 * progress.
 *
 * @param time - The time limit to store.
 */
export const set_time_limit = (time: number): void => {
	if (!recording) return;
	time_limit = time;
};
Expand All @@ -60,6 +71,7 @@
let pending_stream: Uint8Array[] = [];
let submit_pending_stream_on_pending_end = false;
let inited = false;
let stream_open = false;
const NUM_HEADER_BYTES = 44;
let audio_chunks: Blob[] = [];
Expand Down Expand Up @@ -167,6 +179,7 @@
pending_stream.push(payload);
} else {
let blobParts = [header].concat(pending_stream, [payload]);
if (!recording || stream_state === "waiting") return;
dispatch_blob(blobParts, "stream");
pending_stream = [];
}
Expand Down Expand Up @@ -240,6 +253,7 @@
{i18n}
{waveform_settings}
{waveform_options}
waiting={stream_state === "waiting"}
/>
{:else}
<AudioRecorder
Expand Down
35 changes: 34 additions & 1 deletion js/audio/streaming/StreamAudio.svelte
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<script lang="ts">
import { onMount } from "svelte";
import type { I18nFormatter } from "@gradio/utils";
import { Spinner } from "@gradio/icons";
import WaveSurfer from "wavesurfer.js";
import RecordPlugin from "wavesurfer.js/dist/plugins/record.js";
import type { WaveformOptions } from "../shared/types";
Expand All @@ -15,6 +16,7 @@
export let waveform_options: WaveformOptions = {
show_recording_waveform: true
};
export let waiting = false;
let micWaveform: WaveSurfer;
let waveformRecord: RecordPlugin;
Expand Down Expand Up @@ -48,7 +50,7 @@
/>
{/if}
<div class="controls">
{#if recording}
{#if recording && !waiting}
<button
class={paused_recording ? "stop-button-paused" : "stop-button"}
on:click={() => {
Expand All @@ -62,6 +64,18 @@
</span>
{paused_recording ? i18n("audio.pause") : i18n("audio.stop")}
</button>
{:else if recording && waiting}
<button
class="spinner-button"
on:click={() => {
stop();
}}
>
<div class="icon">
<Spinner />
</div>
{i18n("audio.waiting")}
</button>
{:else}
<button
class="record-button"
Expand Down Expand Up @@ -95,6 +109,13 @@
margin: var(--spacing-xl);
}
.icon {
width: var(--size-4);
height: var(--size-4);
fill: var(--primary-600);
stroke: var(--primary-600);
}
.stop-button-paused {
display: none;
height: var(--size-8);
Expand Down Expand Up @@ -136,6 +157,18 @@
display: flex;
}
.spinner-button {
height: var(--size-8);
width: var(--size-24);
background-color: var(--block-background-fill);
border-radius: var(--radius-3xl);
align-items: center;
border: 1px solid var(--primary-600);
margin: 0 var(--spacing-xl);
display: flex;
justify-content: space-evenly;
}
.record-button::before {
content: "";
height: var(--size-4);
Expand Down
47 changes: 38 additions & 9 deletions js/core/src/Blocks.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
targets,
update_value,
get_data,
close_stream,
modify_stream,
get_stream_state,
set_time_limit,
loading_status,
scheduled_updates,
Expand Down Expand Up @@ -280,13 +281,25 @@
let submission: ReturnType<typeof app.submit>;
app.set_current_payload(payload);
if (streaming && submit_map.has(dep_index)) {
await app.post_data(
// @ts-ignore
`${app.config.root}/stream/${submit_map.get(dep_index).event_id()}`,
{ ...payload, session_hash: app.session_hash }
);
return;
if (streaming) {
if (!submit_map.has(dep_index)) {
dep.inputs.forEach((id) => modify_stream(id, "waiting"));
} else if (
submit_map.has(dep_index) &&
dep.inputs.some((id) => get_stream_state(id) === "waiting")
) {
return;
} else if (
submit_map.has(dep_index) &&
dep.inputs.some((id) => get_stream_state(id) === "open")
) {
await app.post_data(
// @ts-ignore
`${app.config.root}/stream/${submit_map.get(dep_index).event_id()}`,
{ ...payload, session_hash: app.session_hash }
);
return;
}
}
try {
submission = app.submit(
Expand Down Expand Up @@ -371,13 +384,29 @@
];
}
/**
 * Opens the stream of component `id` once a streaming dependency starts
 * processing.
 *
 * The stream is only opened when the status message originated from a
 * "process_starts" message and the dependency's connection is "stream".
 *
 * @param status - Status message received for the dependency.
 * @param id - Id of the input component whose stream may be opened.
 * @param dep - The dependency the status message belongs to.
 */
function open_stream_events(
	status: StatusMessage,
	id: number,
	dep: Dependency
): void {
	if (status.original_msg !== "process_starts") return;
	if (dep.connection !== "stream") return;
	modify_stream(id, "open");
}
function handle_status_update(message: StatusMessage): void {
const { fn_index, ...status } = message;
if (status.stage === "streaming" && status.time_limit) {
dep.inputs.forEach((id) => {
set_time_limit(id, status.time_limit);
});
}
dep.inputs.forEach((id) => {
open_stream_events(message, id, dep);
});
//@ts-ignore
loading_status.update({
...status,
Expand Down Expand Up @@ -428,7 +457,7 @@
}
});
dep.inputs.forEach((id) => {
close_stream(id);
modify_stream(id, "closed");
});
submit_map.delete(dep_index);
}
Expand Down
24 changes: 19 additions & 5 deletions js/core/src/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ export function create_components(): {
targets: Writable<TargetMap>;
update_value: (updates: UpdateTransaction[]) => void;
get_data: (id: number) => any | Promise<any>;
close_stream: (id: number) => void;
modify_stream: (id: number, state: "open" | "waiting" | "closed") => void;
get_stream_state: (id: number) => "open" | "waiting" | "closed" | "not_set";
set_time_limit: (id: number, time_limit: number | undefined) => void;
loading_status: ReturnType<typeof create_loading_status_store>;
scheduled_updates: Writable<boolean>;
Expand Down Expand Up @@ -347,13 +348,25 @@ export function create_components(): {
return comp.props.value;
}

function close_stream(id: number): void {
function modify_stream(
id: number,
state: "open" | "closed" | "waiting"
): void {
const comp = _component_map.get(id);
if (comp && comp.instance.close_stream) {
comp.instance.close_stream();
if (comp && comp.instance.modify_stream_state) {
comp.instance.modify_stream_state(state);
}
}

/**
 * Reads the stream state of the component registered under `id`.
 *
 * @param id - Id of the component to query.
 * @returns Whatever the component instance's `get_stream_state` returns, or
 * "not_set" when no such component exists or it exposes no
 * `get_stream_state`.
 */
function get_stream_state(
	id: number
): "open" | "closed" | "waiting" | "not_set" {
	const comp = _component_map.get(id);
	if (!comp || !comp.instance.get_stream_state) {
		return "not_set";
	}
	return comp.instance.get_stream_state();
}

function set_time_limit(id: number, time_limit: number | undefined): void {
const comp = _component_map.get(id);
if (comp && comp.instance.set_time_limit) {
Expand All @@ -366,7 +379,8 @@ export function create_components(): {
targets: target_map,
update_value,
get_data,
close_stream,
modify_stream,
get_stream_state,
set_time_limit,
loading_status,
scheduled_updates: update_scheduled_store,
Expand Down
3 changes: 2 additions & 1 deletion js/core/src/lang/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"record": "Record",
"no_microphone": "No microphone found",
"pause": "Pause",
"play": "Play"
"play": "Play",
"waiting": "Waiting"
},
"blocks": {
"connection_can_break": "On mobile, the connection can break if this tab is unfocused or the device sleeps, losing your position in queue.",
Expand Down
3 changes: 2 additions & 1 deletion js/core/src/lang/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"record": "录制",
"no_microphone": "找不到麦克风",
"pause": "暂停",
"play": "播放"
"play": "播放",
"waiting": "等待"
},
"blocks": {
"connection_can_break": "在移动设备上,如果此标签页失去焦点或设备休眠,连接可能会中断,导致您在队列中失去位置。",
Expand Down
Loading

0 comments on commit 3d7a9b8

Please sign in to comment.