diff --git a/integration/test_computed_vars.py b/integration/test_computed_vars.py new file mode 100644 index 00000000000..c0f0cfb51fe --- /dev/null +++ b/integration/test_computed_vars.py @@ -0,0 +1,182 @@ +"""Test computed vars.""" + +from __future__ import annotations + +import time +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver + + +def ComputedVars(): + """Test app for computed vars.""" + import reflex as rx + + class State(rx.State): + count: int = 0 + + # cached var with dep on count + @rx.cached_var(interval=15) + def count1(self) -> int: + return self.count + + # same as above, different notation + @rx.var(interval=15, cache=True) + def count2(self) -> int: + return self.count + + # explicit disabled auto_deps + @rx.var(interval=15, cache=True, auto_deps=False) + def count3(self) -> int: + # this will not add deps, because auto_deps is False + print(self.count1) + print(self.count2) + + return self.count + + # explicit dependency on count1 var + @rx.cached_var(deps=[count1]) + def depends_on_count1(self) -> int: + return self.count + + def increment(self): + self.count += 1 + + def do_nothing(self): + pass + + def index() -> rx.Component: + return rx.center( + rx.vstack( + rx.input( + id="token", + value=State.router.session.client_token, + is_read_only=True, + ), + rx.button("Increment", on_click=State.increment, id="increment"), + rx.button("Do nothing", on_click=State.do_nothing, id="do_nothing"), + rx.text(State.count, id="count"), + rx.text(State.count1, id="count1"), + rx.text(State.count2, id="count2"), + rx.text(State.count3, id="count3"), + rx.text( + State.depends_on_count1, + id="depends_on_count1", + ), + ), + ) + + app = rx.App() + app.add_page(index) + + +@pytest.fixture(scope="module") +def computed_vars( + tmp_path_factory, +) -> Generator[AppHarness, None, None]: + """Start ComputedVars app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp(f"computed_vars"), + app_source=ComputedVars, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the computed_vars app. + + Args: + computed_vars: harness for ComputedVars app + + Yields: + WebDriver instance. + """ + assert computed_vars.app_instance is not None, "app is not running" + driver = computed_vars.frontend() + try: + yield driver + finally: + driver.quit() + + +@pytest.fixture() +def token(computed_vars: AppHarness, driver: WebDriver) -> str: + """Get a function that returns the active token. + + Args: + computed_vars: harness for ComputedVars app. + driver: WebDriver instance. + + Returns: + The token for the connected client + """ + assert computed_vars.app_instance is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2) + assert token is not None + + return token + + +def test_computed_vars( + computed_vars: AppHarness, + driver: WebDriver, + token: str, +): + """Test that computed vars are working as expected. + + Args: + computed_vars: harness for ComputedVars app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert computed_vars.app_instance is not None + + count = driver.find_element(By.ID, "count") + assert count + assert count.text == "0" + + count1 = driver.find_element(By.ID, "count1") + assert count1 + assert count1.text == "0" + + count2 = driver.find_element(By.ID, "count2") + assert count2 + assert count2.text == "0" + + count3 = driver.find_element(By.ID, "count3") + assert count3 + assert count3.text == "0" + + increment = driver.find_element(By.ID, "increment") + assert increment.is_enabled() + + do_nothing = driver.find_element(By.ID, "do_nothing") + assert do_nothing.is_enabled() + + increment.click() + assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1" + assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1" + assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1" + + do_nothing.click() + with pytest.raises(TimeoutError): + computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0") + + time.sleep(10) + do_nothing.click() + assert computed_vars.poll_for_content(count3, timeout=2) == "1" diff --git a/reflex/vars.py b/reflex/vars.py index dfa3f09e8a7..7a9583a2055 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1861,7 +1861,7 @@ def __init__( fget: Callable[[BaseState], Any], initial_value: Any | types.Unset = types.Unset(), cache: bool = False, - deps: Optional[set[Union[str, Var]]] = None, + deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[int, datetime.timedelta]] = None, **kwargs, @@ -1883,7 +1883,7 @@ def __init__( interval = datetime.timedelta(seconds=interval) self._update_interval = interval if deps is None: - deps = set() + deps = [] self._static_deps = { dep._var_name if isinstance(dep, Var) else dep for dep in deps } @@ -2085,7 +2085,7 @@ def computed_var( fget: Callable[[BaseState], Any] | None = None, initial_value: Any | None = None, cache: bool = False, - deps: Optional[set[Union[str, Var]]] = None, + deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[datetime.timedelta, int]] = None, **kwargs, diff --git a/reflex/vars.pyi b/reflex/vars.pyi index fb2ed465734..d10fee847d0 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -2,6 +2,7 @@ from __future__ import annotations +import datetime from dataclasses import dataclass from _typeshed import Incomplete from reflex import constants as constants @@ -154,10 +155,24 @@ class ComputedVar(Var): def computed_var( fget: Callable[[BaseState], Any] | None = None, initial_value: Any | None = None, + cache: bool = False, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[datetime.timedelta, int]] = None, **kwargs, ) -> Callable[[Callable[[Any], Any]], ComputedVar]: ... @overload def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ... +@overload +def cached_var( + fget: Callable[[BaseState], Any] | None = None, + initial_value: Any | None = None, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[datetime.timedelta, int]] = None, + **kwargs, +) -> Callable[[Callable[[Any], Any]], ComputedVar]: ... +@overload def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... class CallableVar(BaseVar):