Skip to content

Commit

Permalink
return a Dependency instance from Blocks.load event listener (#4304)
Browse files Browse the repository at this point in the history
* return a Dependency instance from Blocks.load event listener

* a test case for chaining then from load event

* update CHANGELOG

* add test for load.then with blocks re-used

* fixes

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
2 people authored and dawoodkhan82 committed Jun 2, 2023
1 parent aec21de commit 944ac1f
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## New Features:

No changes to highlight.
* Make `Blocks.load` behave like other event listeners (allows chaining `then` off of it) [@anentropic](https://github.com/anentropic/) in [PR 4304](https://github.com/gradio-app/gradio/pull/4304)

## Bug Fixes:

Expand Down
5 changes: 4 additions & 1 deletion gradio/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:

for x in blocks.dependencies:
targets_telemetry = targets_telemetry + [
str(blocks.blocks[y]) for y in x["targets"]
# Sometimes the target can be the Blocks object itself, so we need to check if its in blocks.blocks
str(blocks.blocks[y])
for y in x["targets"]
if y in blocks.blocks
]
inputs_telemetry = inputs_telemetry + [
str(blocks.blocks[y]) for y in x["inputs"]
Expand Down
7 changes: 5 additions & 2 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,9 @@ def get_time():
name=name, src=src, hf_token=api_key, alias=alias, **kwargs
)
else:
return self_or_cls.set_event_trigger(
from gradio.events import Dependency

dep, dep_index = self_or_cls.set_event_trigger(
event_name="load",
fn=fn,
inputs=inputs,
Expand All @@ -1498,7 +1500,8 @@ def get_time():
max_batch_size=max_batch_size,
every=every,
no_target=True,
)[0]
)
return Dependency(self_or_cls, dep, dep_index)

def clear(self):
"""Resets the layout of the Blocks object."""
Expand Down
2 changes: 1 addition & 1 deletion gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ async def run_predict(
dependency = app.get_blocks().dependencies[fn_index_inferred]
target = dependency["targets"][0] if len(dependency["targets"]) else None
event_data = EventData(
app.get_blocks().blocks[target] if target else None,
app.get_blocks().blocks.get(target) if target else None,
body.event_data,
)
batch = dependency["batch"]
Expand Down
45 changes: 45 additions & 0 deletions test/test_events.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os

import pytest
from fastapi.testclient import TestClient

import gradio as gr

os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"


class TestEvent:
def test_clear_event(self):
Expand Down Expand Up @@ -69,6 +73,47 @@ def clear():
assert not parent.config["dependencies"][2]["trigger_only_on_success"]
assert parent.config["dependencies"][3]["trigger_only_on_success"]

def test_load_chaining(self):
calls = 0

def increment():
nonlocal calls
calls += 1
return str(calls)

with gr.Blocks() as demo:
out = gr.Textbox(label="Call counter")
demo.load(increment, inputs=None, outputs=out).then(
increment, inputs=None, outputs=out
)

assert demo.config["dependencies"][0]["trigger"] == "load"
assert demo.config["dependencies"][0]["trigger_after"] is None
assert demo.config["dependencies"][1]["trigger"] == "then"
assert demo.config["dependencies"][1]["trigger_after"] == 0

def test_load_chaining_reuse(self):
calls = 0

def increment():
nonlocal calls
calls += 1
return str(calls)

with gr.Blocks() as demo:
out = gr.Textbox(label="Call counter")
demo.load(increment, inputs=None, outputs=out).then(
increment, inputs=None, outputs=out
)

with gr.Blocks() as demo2:
demo.render()

assert demo2.config["dependencies"][0]["trigger"] == "load"
assert demo2.config["dependencies"][0]["trigger_after"] is None
assert demo2.config["dependencies"][1]["trigger"] == "then"
assert demo2.config["dependencies"][1]["trigger_after"] == 0


class TestEventErrors:
def test_event_defined_invalid_scope(self):
Expand Down

0 comments on commit 944ac1f

Please sign in to comment.