Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit deps and interval for computed vars #3231

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions integration/test_computed_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Test computed vars."""

from __future__ import annotations

import time
from typing import Generator

import pytest
from selenium.webdriver.common.by import By

from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver


def ComputedVars():
"""Test app for computed vars."""
import reflex as rx

class State(rx.State):
count: int = 0

# cached var with dep on count
@rx.cached_var(interval=15)
def count1(self) -> int:
return self.count

# same as above, different notation
@rx.var(interval=15, cache=True)
def count2(self) -> int:
return self.count

# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int:
# this will not add deps, because auto_deps is False
print(self.count1)
print(self.count2)

return self.count

# explicit dependency on count1 var
@rx.cached_var(deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int:
return self.count

@rx.var(deps=[count3], auto_deps=False, cache=True)
def depends_on_count3(self) -> int:
return self.count

def increment(self):
self.count += 1

def mark_dirty(self):
self._mark_dirty()

def index() -> rx.Component:
return rx.center(
rx.vstack(
rx.input(
id="token",
value=State.router.session.client_token,
is_read_only=True,
),
rx.button("Increment", on_click=State.increment, id="increment"),
rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
rx.text("count:"),
rx.text(State.count, id="count"),
rx.text("count1:"),
rx.text(State.count1, id="count1"),
rx.text("count2:"),
rx.text(State.count2, id="count2"),
rx.text("count3:"),
rx.text(State.count3, id="count3"),
rx.text("depends_on_count1:"),
rx.text(
State.depends_on_count1,
id="depends_on_count1",
),
rx.text("depends_on_count3:"),
rx.text(
State.depends_on_count3,
id="depends_on_count3",
),
),
)

# raise Exception(State.count3._deps(objclass=State))
app = rx.App()
app.add_page(index)


@pytest.fixture(scope="module")
def computed_vars(
tmp_path_factory,
) -> Generator[AppHarness, None, None]:
"""Start ComputedVars app at tmp_path via AppHarness.

Args:
tmp_path_factory: pytest tmp_path_factory fixture

Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp(f"computed_vars"),
app_source=ComputedVars, # type: ignore
) as harness:
yield harness


@pytest.fixture
def driver(computed_vars: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the computed_vars app.

Args:
computed_vars: harness for ComputedVars app

Yields:
WebDriver instance.
"""
assert computed_vars.app_instance is not None, "app is not running"
driver = computed_vars.frontend()
try:
yield driver
finally:
driver.quit()


@pytest.fixture()
def token(computed_vars: AppHarness, driver: WebDriver) -> str:
"""Get a function that returns the active token.

Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.

Returns:
The token for the connected client
"""
assert computed_vars.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input

# wait for the backend connection to send the token
token = computed_vars.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
assert token is not None

return token


def test_computed_vars(
computed_vars: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that computed vars are working as expected.

Args:
computed_vars: harness for ComputedVars app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert computed_vars.app_instance is not None

count = driver.find_element(By.ID, "count")
assert count
assert count.text == "0"

count1 = driver.find_element(By.ID, "count1")
assert count1
assert count1.text == "0"

count2 = driver.find_element(By.ID, "count2")
assert count2
assert count2.text == "0"

count3 = driver.find_element(By.ID, "count3")
assert count3
assert count3.text == "0"

depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
assert depends_on_count1
assert depends_on_count1.text == "0"

depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
assert depends_on_count3
assert depends_on_count3.text == "0"

increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()

mark_dirty = driver.find_element(By.ID, "mark_dirty")
assert mark_dirty.is_enabled()

mark_dirty.click()

increment.click()
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"

mark_dirty.click()
with pytest.raises(TimeoutError):
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")

time.sleep(10)
assert count3.text == "0"
assert depends_on_count3.text == "0"
mark_dirty.click()
assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
assert depends_on_count3.text == "1"
16 changes: 16 additions & 0 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,18 @@ def _mark_dirty_computed_vars(self) -> None:
if actual_var is not None:
actual_var.mark_dirty(instance=self)

def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.

Returns:
Set of computed vars to include in the delta.
"""
return set(
cvar
for cvar in self.computed_vars
if self.computed_vars[cvar].needs_update(instance=self)
)

def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.

Expand Down Expand Up @@ -1586,6 +1598,7 @@ def get_delta(self) -> Delta:
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@masenf did I find a bug here? You mentioned that one could define a computed var without dependencies and manually mark it as dirty. If I am not mistaken, this line is needed for this to work.

.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars)
)
Expand Down Expand Up @@ -1619,6 +1632,9 @@ def _mark_dirty(self):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state._mark_dirty()

# Append expired computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._expired_computed_vars())

# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()
Expand Down
Loading
Loading