Skip to content

Commit

Permalink
Allow building multipage Gradio apps (#10433)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* changes

* chnages

* changes

* changes

* add changeset

* Update gradio/blocks.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/blocks.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* changes

* changes

* changes

* chagnes

* Update js/core/src/Blocks.svelte

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* Update js/core/src/Blocks.svelte

Co-authored-by: Hannah <hannahblair@users.noreply.github.com>

* changes

* chagnes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* docs

* changes

* changes

* changes

* rename guide

* rename guide

* changes

* chagnes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* try skipping

* format

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
  • Loading branch information
5 people authored Feb 5, 2025
1 parent 35fda36 commit 2e8dc74
Show file tree
Hide file tree
Showing 21 changed files with 510 additions and 118 deletions.
10 changes: 10 additions & 0 deletions .changeset/large-beans-retire.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
"@gradio/client": minor
"@gradio/core": minor
"@gradio/lite": minor
"@self/app": minor
"@self/spa": minor
"gradio": minor
---

feat:Allow building multipage Gradio apps
34 changes: 34 additions & 0 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,40 @@ export class Client {
current_payload: any;
ws_map: Record<string, WebSocket | "failed"> = {};

get_url_config(url: string | null = null): Config {
if (!this.config) {
throw new Error(CONFIG_ERROR_MSG);
}
if (url === null) {
url = window.location.href;
}
const stripSlashes = (str: string): string => str.replace(/^\/+|\/+$/g, "");
let root_path = stripSlashes(new URL(this.config.root).pathname);
let url_path = stripSlashes(new URL(url).pathname);
let page = stripSlashes(url_path.substring(root_path.length));
return this.get_page_config(page);
}
get_page_config(page: string): Config {
if (!this.config) {
throw new Error(CONFIG_ERROR_MSG);
}
let config = this.config;
if (!(page in config.page)) {
throw new Error(`Page ${page} not found`);
}
return {
...config,
current_page: page,
layout: config.page[page].layout,
components: config.components.filter((c) =>
config.page[page].components.includes(c.id)
),
dependencies: this.config.dependencies.filter((d) =>
config.page[page].dependencies.includes(d.id)
)
};
}

fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
const headers = new Headers(init?.headers || {});
if (this && this.cookies) {
Expand Down
10 changes: 10 additions & 0 deletions client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
current_page: string;
page: Record<
string,
{
components: number[];
dependencies: number[];
layout: any;
}
>;
pages: [string, string][];
protocol: "sse_v3" | "sse_v2.1" | "sse_v2" | "sse_v1" | "sse" | "ws";
max_file_size?: number;
theme_hash?: number;
Expand Down
1 change: 1 addition & 0 deletions demo/multipage/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: multipage"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import time\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " greet_btn = gr.Button(\"Greet\")\n", " @gr.on([greet_btn.click, name.submit], inputs=name, outputs=output)\n", " def greet(name):\n", " return \"Hello \" + name + \"!\"\n", " \n", " @gr.render(inputs=name, triggers=[output.change])\n", " def spell_out(name):\n", " with gr.Row():\n", " for letter in name:\n", " gr.Textbox(letter)\n", "\n", "with demo.route(\"Up\") as incrementer_demo:\n", " num = gr.Number()\n", " incrementer_demo.load(lambda: time.sleep(1) or random.randint(10, 40), None, num)\n", "\n", " with gr.Row():\n", " inc_btn = gr.Button(\"Increase\")\n", " dec_btn = gr.Button(\"Decrease\")\n", " inc_btn.click(fn=lambda x: x + 1, inputs=num, outputs=num, api_name=\"increment\")\n", " dec_btn.click(fn=lambda x: x - 1, inputs=num, outputs=num, api_name=\"decrement\")\n", " for i in range(100):\n", " gr.Textbox()\n", "\n", "def wait(x):\n", " time.sleep(2)\n", " return x\n", "\n", "identity_iface = gr.Interface(wait, \"image\", \"image\")\n", "\n", "with demo.route(\"Interface\") as incrementer_demo:\n", " identity_iface.render()\n", " gr.Interface(lambda x, y: x * y, [\"number\", \"number\"], \"number\")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
42 changes: 42 additions & 0 deletions demo/multipage/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import gradio as gr
import random
import time

with gr.Blocks() as demo:
name = gr.Textbox(label="Name")
output = gr.Textbox(label="Output Box")
greet_btn = gr.Button("Greet")
@gr.on([greet_btn.click, name.submit], inputs=name, outputs=output)
def greet(name):
return "Hello " + name + "!"

@gr.render(inputs=name, triggers=[output.change])
def spell_out(name):
with gr.Row():
for letter in name:
gr.Textbox(letter)

with demo.route("Up") as incrementer_demo:
num = gr.Number()
incrementer_demo.load(lambda: time.sleep(1) or random.randint(10, 40), None, num)

with gr.Row():
inc_btn = gr.Button("Increase")
dec_btn = gr.Button("Decrease")
inc_btn.click(fn=lambda x: x + 1, inputs=num, outputs=num, api_name="increment")
dec_btn.click(fn=lambda x: x - 1, inputs=num, outputs=num, api_name="decrement")
for i in range(100):
gr.Textbox()

def wait(x):
time.sleep(2)
return x

identity_iface = gr.Interface(wait, "image", "image")

with demo.route("Interface") as incrementer_demo:
identity_iface.render()
gr.Interface(lambda x, y: x * y, ["number", "number"], "number")

if __name__ == "__main__":
demo.launch()
97 changes: 88 additions & 9 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import os
import random
import re
import secrets
import string
import sys
Expand Down Expand Up @@ -54,6 +55,7 @@
FileData,
GradioModel,
GradioRootModel,
Layout,
)
from gradio.events import (
EventData,
Expand All @@ -69,7 +71,7 @@
from gradio.helpers import create_tracker, skip, special_args
from gradio.node_server import start_node_server
from gradio.route_utils import API_PREFIX, MediaStream
from gradio.routes import VERSION, App, Request
from gradio.routes import INTERNAL_ROUTES, VERSION, App, Request
from gradio.state_holder import SessionState, StateHolder
from gradio.themes import Default as DefaultTheme
from gradio.themes import ThemeClass as Theme
Expand Down Expand Up @@ -137,6 +139,7 @@ def __init__(
self.share_token = secrets.token_urlsafe(32)
self.parent: BlockContext | None = None
self.rendered_in: Renderable | None = None
self.page: str
self.is_rendered: bool = False
self._constructor_args: list[dict]
self.state_session_capacity = 10000
Expand Down Expand Up @@ -187,6 +190,8 @@ def render(self):
f"A block with id: {self._id} has already been rendered in the current Blocks."
)
if render_context is not None:
if root_context:
self.page = root_context.root_block.current_page
render_context.add(self)
if root_context is not None:
root_context.blocks[self._id] = self
Expand Down Expand Up @@ -467,6 +472,7 @@ def fill_expected_parents(self):
pseudo_parent.parent = self
children.append(pseudo_parent)
pseudo_parent.add_child(child)
pseudo_parent.page = child.page
if root_context:
root_context.blocks[pseudo_parent._id] = pseudo_parent
child.parent = pseudo_parent
Expand Down Expand Up @@ -521,6 +527,7 @@ def __init__(
stream_every: float = 0.5,
like_user_message: bool = False,
event_specific_args: list[str] | None = None,
page: str = "",
):
self.fn = fn
self._id = _id
Expand Down Expand Up @@ -554,6 +561,7 @@ def __init__(
) or inspect.isasyncgenfunction(self.fn)
self.renderable = renderable
self.rendered_in = rendered_in
self.page = page

# We need to keep track of which events are cancel events
# so that the client can call the /cancel route directly
Expand Down Expand Up @@ -871,19 +879,32 @@ def set_event_trigger(
stream_every=stream_every,
like_user_message=like_user_message,
event_specific_args=event_specific_args,
page=self.root_block.current_page,
)

self.fns[self.fn_id] = block_fn
self.fn_id += 1
return block_fn, block_fn._id

def get_config(self, renderable: Renderable | None = None):
config = {}
config = {
"page": {},
"components": [],
"dependencies": [],
}

for page, _ in self.root_block.pages:
if page not in config["page"]:
config["page"][page] = {
"layout": {"id": self.root_block._id, "children": []},
"components": [],
"dependencies": [],
}

rendered_ids = []
sidebar_count = [0]

def get_layout(block: Block):
def get_layout(block: Block) -> Layout:
rendered_ids.append(block._id)
if block.get_block_name() == "sidebar":
sidebar_count[0] += 1
Expand All @@ -895,16 +916,22 @@ def get_layout(block: Block):
return {"id": block._id}
children_layout = []
for child in block.children:
children_layout.append(get_layout(child))
layout = get_layout(child)
children_layout.append(layout)
return {"id": block._id, "children": children_layout}

if renderable:
root_block = self.blocks[renderable.container_id]
else:
root_block = self.root_block
config["layout"] = get_layout(root_block)
layout = get_layout(root_block)
config["layout"] = layout

for root_child in layout.get("children", []):
if isinstance(root_child, dict) and root_child["id"] in self.blocks:
block = self.blocks[root_child["id"]]
config["page"][block.page]["layout"]["children"].append(root_child)

config["components"] = []
blocks_items = list(
self.blocks.items()
) # freeze as list to prevent concurrent re-renders from changing the dict during loop, see https://github.com/gradio-app/gradio/issues/9991
Expand Down Expand Up @@ -937,11 +964,15 @@ def get_layout(block: Block):
block_config["api_info_as_output"] = block.api_info() # type: ignore
block_config["example_inputs"] = block.example_inputs() # type: ignore
config["components"].append(block_config)
config["page"][block.page]["components"].append(block._id)

dependencies = []
for fn in self.fns.values():
if renderable is None or fn.rendered_in == renderable:
dependencies.append(fn.get_config())
dependency_config = fn.get_config()
dependencies.append(dependency_config)
config["page"][fn.page]["dependencies"].append(dependency_config["id"])

config["dependencies"] = dependencies
return config

Expand Down Expand Up @@ -1143,6 +1174,9 @@ def __init__(
self.root_path = os.environ.get("GRADIO_ROOT_PATH", "")
self.proxy_urls = set()

self.pages: list[tuple[str, str]] = [("", "Home")]
self.current_page = ""

if self.analytics_enabled:
is_custom_theme = not any(
self.theme.to_dict() == built_in_theme.to_dict()
Expand Down Expand Up @@ -1263,7 +1297,7 @@ def iterate_over_children(children_list):
original_mapping[0] = root_block = Context.root_block or blocks

if "layout" in config:
iterate_over_children(config["layout"]["children"])
iterate_over_children(config["layout"].get("children", []))

first_dependency = None

Expand Down Expand Up @@ -1427,6 +1461,8 @@ def render(self):
"At least one block in this Blocks has already been rendered."
)

for block in self.blocks.values():
block.page = Context.root_block.current_page
root_context.blocks.update(self.blocks)
dependency_offset = max(root_context.fns.keys(), default=-1) + 1
existing_api_names = [
Expand All @@ -1435,6 +1471,7 @@ def render(self):
if isinstance(dep.api_name, str)
]
for dependency in self.fns.values():
dependency.page = Context.root_block.current_page
dependency._id += dependency_offset
# Any event -- e.g. Blocks.load() -- that is triggered by this Blocks
# should now be triggered by the root Blocks instead.
Expand Down Expand Up @@ -2179,6 +2216,8 @@ def get_config_file(self) -> BlocksConfigDict:
"fill_width": self.fill_width,
"theme_hash": self.theme_hash,
"pwa": self.pwa,
"pages": self.pages,
"page": {},
}
config.update(self.default_config.get_config()) # type: ignore
config["connect_heartbeat"] = utils.connect_heartbeat(
Expand Down Expand Up @@ -2213,6 +2252,7 @@ def __exit__(self, exc_type: type[BaseException] | None = None, *args):
self.progress_tracking = any(
block_fn.tracks_progress for block_fn in self.fns.values()
)
self.page = ""
self.exited = True

def clear(self):
Expand Down Expand Up @@ -2261,7 +2301,6 @@ def queue(
blocks=self,
default_concurrency_limit=default_concurrency_limit,
)
self.config = self.get_config_file()
self.app = App.create_app(self)
return self

Expand Down Expand Up @@ -3039,3 +3078,43 @@ def get_event_targets(
event = getattr(block, event_name)
target_events.append(event)
return target_events

@document()
def route(self, name: str, path: str | None = None) -> Blocks:
"""
Adds a new page to the Blocks app.
Parameters:
name: The name of the page as it appears in the nav bar.
path: The URL suffix appended after your Gradio app's root URL to access this page (e.g. if path="/test", the page may be accessible e.g. at http://localhost:7860/test). If not provided, the path is generated from the name by converting to lowercase and replacing spaces with hyphens. Any leading or trailing forward slashes are stripped.
Example:
with gr.Blocks() as demo:
name = gr.Textbox(label="Name")
...
with demo.route("Test", "/test"):
num = gr.Number()
...
"""
if get_blocks_context():
raise ValueError(
"You cannot create a route while inside a Blocks() context. Call route() outside the Blocks() context (unindent this line)."
)

if path:
path = path.strip("/")
valid_path_regex = re.compile(r"^[a-zA-Z0-9-._~!$&'()*+,;=:@\[\]]+$")
if not valid_path_regex.match(path):
raise ValueError(
f"Path '{path}' contains invalid characters. Paths can only contain alphanumeric characters and the following special characters: -._~!$&'()*+,;=:@[]"
)
if path in INTERNAL_ROUTES:
raise ValueError(f"Route with path '{path}' already exists")
if path is None:
path = name.lower().replace(" ", "-")
path = "".join(
[letter for letter in path if letter.isalnum() or letter == "-"]
)
while path in INTERNAL_ROUTES or path in [page[0] for page in self.pages]:
path = "_" + path
self.pages.append((path, name))
self.current_page = path
return self
11 changes: 10 additions & 1 deletion gradio/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,13 @@ class BodyCSS(TypedDict):

class Layout(TypedDict):
id: int
children: list[int | Layout]
children: NotRequired[list[int | Layout]]


class Page(TypedDict):
components: list[int]
dependencies: list[int]
layout: Layout


class BlocksConfigDict(TypedDict):
Expand Down Expand Up @@ -386,6 +392,9 @@ class BlocksConfigDict(TypedDict):
username: NotRequired[str | None]
api_prefix: str
pwa: NotRequired[bool]
page: dict[str, Page]
pages: list[tuple[str, str]]
current_page: NotRequired[str]


class MediaStreamChunk(TypedDict):
Expand Down
Loading

0 comments on commit 2e8dc74

Please sign in to comment.