Skip to content

Commit

Permalink
fix typing, add simple test
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher committed May 5, 2024
1 parent 5e0a784 commit 308b595
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 3 deletions.
182 changes: 182 additions & 0 deletions integration/test_computed_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Test computed vars."""

from __future__ import annotations

import time
from typing import Generator

import pytest
from selenium.webdriver.common.by import By

from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver


def ComputedVars():
"""Test app for computed vars."""
import reflex as rx

class State(rx.State):
count: int = 0

# cached var with dep on count
@rx.cached_var(interval=15)
def count1(self) -> int:
return self.count

# same as above, different notation
@rx.var(interval=15, cache=True)
def count2(self) -> int:
return self.count

# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int:
# this will not add deps, because auto_deps is False
print(self.count1)
print(self.count2)

return self.count

# explicit dependency on count1 var
@rx.cached_var(deps=[count1])
def depends_on_count1(self) -> int:
return self.count

def increment(self):
self.count += 1

def do_nothing(self):
pass

def index() -> rx.Component:
return rx.center(
rx.vstack(
rx.input(
id="token",
value=State.router.session.client_token,
is_read_only=True,
),
rx.button("Increment", on_click=State.increment, id="increment"),
rx.button("Do nothing", on_click=State.do_nothing, id="do_nothing"),
rx.text(State.count, id="count"),
rx.text(State.count1, id="count1"),
rx.text(State.count2, id="count2"),
rx.text(State.count3, id="count3"),
rx.text(
State.depends_on_count1,
id="depends_on_count1",
),
),
)

app = rx.App()
app.add_page(index)


@pytest.fixture(scope="module")
def computed_vars(
tmp_path_factory,
) -> Generator[AppHarness, None, None]:
"""Start ComputedVars app at tmp_path via AppHarness.
Args:
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp(f"computed_vars"),
app_source=ComputedVars, # type: ignore
) as harness:
yield harness


@pytest.fixture
def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the computed_vars app.
Args:
computed_vars: harness for ComputedVars app
Yields:
WebDriver instance.
"""
assert computed_vars.app_instance is not None, "app is not running"
driver = computed_vars.frontend()
try:
yield driver
finally:
driver.quit()


@pytest.fixture()
def token(computed_vars: AppHarness, driver: WebDriver) -> str:
"""Get a function that returns the active token.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
Returns:
The token for the connected client
"""
assert computed_vars.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input

# wait for the backend connection to send the token
token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
assert token is not None

return token


def test_computed_vars(
computed_vars: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that computed vars are working as expected.
Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert computed_vars.app_instance is not None

count = driver.find_element(By.ID, "count")
assert count
assert count.text == "0"

count1 = driver.find_element(By.ID, "count1")
assert count1
assert count1.text == "0"

count2 = driver.find_element(By.ID, "count2")
assert count2
assert count2.text == "0"

count3 = driver.find_element(By.ID, "count3")
assert count3
assert count3.text == "0"

increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()

do_nothing = driver.find_element(By.ID, "do_nothing")
assert do_nothing.is_enabled()

increment.click()
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"

do_nothing.click()
with pytest.raises(TimeoutError):
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")

time.sleep(10)
do_nothing.click()
assert computed_vars.poll_for_content(count3, timeout=2) == "1"
6 changes: 3 additions & 3 deletions reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,7 +1861,7 @@ def __init__(
fget: Callable[[BaseState], Any],
initial_value: Any | types.Unset = types.Unset(),
cache: bool = False,
deps: Optional[set[Union[str, Var]]] = None,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
**kwargs,
Expand All @@ -1883,7 +1883,7 @@ def __init__(
interval = datetime.timedelta(seconds=interval)
self._update_interval = interval
if deps is None:
deps = set()
deps = []
self._static_deps = {
dep._var_name if isinstance(dep, Var) else dep for dep in deps
}
Expand Down Expand Up @@ -2085,7 +2085,7 @@ def computed_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
cache: bool = False,
deps: Optional[set[Union[str, Var]]] = None,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs,
Expand Down
15 changes: 15 additions & 0 deletions reflex/vars.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import datetime
from dataclasses import dataclass
from _typeshed import Incomplete
from reflex import constants as constants
Expand Down Expand Up @@ -154,10 +155,24 @@ class ComputedVar(Var):
def computed_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload
def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ...
@overload
def cached_var(
fget: Callable[[BaseState], Any] | None = None,
initial_value: Any | None = None,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
**kwargs,
) -> Callable[[Callable[[Any], Any]], ComputedVar]: ...
@overload
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ...

class CallableVar(BaseVar):
Expand Down

0 comments on commit 308b595

Please sign in to comment.