Skip to content

Commit

Permalink
Explicit deps and interval for computed vars (#3231)
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher authored May 28, 2024
1 parent ac1c660 commit 93de407
Show file tree
Hide file tree
Showing 4 changed files with 321 additions and 5 deletions.
210 changes: 210 additions & 0 deletions integration/test_computed_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Test computed vars."""

from __future__ import annotations

import time
from typing import Generator

import pytest
from selenium.webdriver.common.by import By

from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver


def ComputedVars():
"""Test app for computed vars."""
import reflex as rx

class State(rx.State):
count: int = 0

# cached var with dep on count
@rx.cached_var(interval=15)
def count1(self) -> int:
return self.count

# same as above, different notation
@rx.var(interval=15, cache=True)
def count2(self) -> int:
return self.count

# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int:
# this will not add deps, because auto_deps is False
print(self.count1)
print(self.count2)

return self.count

# explicit dependency on count1 var
@rx.cached_var(deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int:
return self.count

@rx.var(deps=[count3], auto_deps=False, cache=True)
def depends_on_count3(self) -> int:
return self.count

def increment(self):
self.count += 1

def mark_dirty(self):
self._mark_dirty()

def index() -> rx.Component:
return rx.center(
rx.vstack(
rx.input(
id="token",
value=State.router.session.client_token,
is_read_only=True,
),
rx.button("Increment", on_click=State.increment, id="increment"),
rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
rx.text("count:"),
rx.text(State.count, id="count"),
rx.text("count1:"),
rx.text(State.count1, id="count1"),
rx.text("count2:"),
rx.text(State.count2, id="count2"),
rx.text("count3:"),
rx.text(State.count3, id="count3"),
rx.text("depends_on_count1:"),
rx.text(
State.depends_on_count1,
id="depends_on_count1",
),
rx.text("depends_on_count3:"),
rx.text(
State.depends_on_count3,
id="depends_on_count3",
),
),
)

# raise Exception(State.count3._deps(objclass=State))
app = rx.App()
app.add_page(index)


@pytest.fixture(scope="module")
def computed_vars(
tmp_path_factory,
) -> Generator[AppHarness, None, None]:
"""Start ComputedVars app at tmp_path via AppHarness.
Args:
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp(f"computed_vars"),
app_source=ComputedVars, # type: ignore
) as harness:
yield harness


@pytest.fixture
def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the computed_vars app.
Args:
computed_vars: harness for ComputedVars app
Yields:
WebDriver instance.
"""
assert computed_vars.app_instance is not None, "app is not running"
driver = computed_vars.frontend()
try:
yield driver
finally:
driver.quit()


@pytest.fixture()
def token(computed_vars: AppHarness, driver: WebDriver) -> str:
"""Get a function that returns the active token.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
Returns:
The token for the connected client
"""
assert computed_vars.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input

# wait for the backend connection to send the token
token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
assert token is not None

return token


def test_computed_vars(
computed_vars: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that computed vars are working as expected.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert computed_vars.app_instance is not None

count = driver.find_element(By.ID, "count")
assert count
assert count.text == "0"

count1 = driver.find_element(By.ID, "count1")
assert count1
assert count1.text == "0"

count2 = driver.find_element(By.ID, "count2")
assert count2
assert count2.text == "0"

count3 = driver.find_element(By.ID, "count3")
assert count3
assert count3.text == "0"

depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
assert depends_on_count1
assert depends_on_count1.text == "0"

depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
assert depends_on_count3
assert depends_on_count3.text == "0"

increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()

mark_dirty = driver.find_element(By.ID, "mark_dirty")
assert mark_dirty.is_enabled()

mark_dirty.click()

increment.click()
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"

mark_dirty.click()
with pytest.raises(TimeoutError):
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")

time.sleep(10)
assert count3.text == "0"
assert depends_on_count3.text == "0"
mark_dirty.click()
assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
assert depends_on_count3.text == "1"
16 changes: 16 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,18 @@ def _mark_dirty_computed_vars(self) -> None:
if actual_var is not None:
actual_var.mark_dirty(instance=self)

def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
Returns:
Set of computed vars to include in the delta.
"""
return set(
cvar
for cvar in self.computed_vars
if self.computed_vars[cvar].needs_update(instance=self)
)

def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
Expand Down Expand Up @@ -1588,6 +1600,7 @@ def get_delta(self) -> Delta:
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars))
.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars)
)
Expand Down Expand Up @@ -1621,6 +1634,9 @@ def _mark_dirty(self):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state._mark_dirty()

# Append expired computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._expired_computed_vars())

# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()
Expand Down
Loading

0 comments on commit 93de407

Please sign in to comment.