Skip to content

Commit

Permalink
correctly determine expired computed vars + minor typing improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher committed May 7, 2024
1 parent bb7fc06 commit 4e31200
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 23 deletions.
46 changes: 37 additions & 9 deletions integration/test_computed_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,19 @@ def count3(self) -> int:
return self.count

# explicit dependency on count1 var
@rx.cached_var(deps=[count1])
@rx.cached_var(deps=[count1], auto_deps=False)
def depends_on_count1(self) -> int:
return self.count

@rx.var(deps=[count3], auto_deps=False, cache=True)
def depends_on_count3(self) -> int:
return self.count

def increment(self):
self.count += 1

def do_nothing(self):
pass
def mark_dirty(self):
self._mark_dirty()

def index() -> rx.Component:
return rx.center(
Expand All @@ -57,18 +61,29 @@ def index() -> rx.Component:
is_read_only=True,
),
rx.button("Increment", on_click=State.increment, id="increment"),
rx.button("Do nothing", on_click=State.do_nothing, id="do_nothing"),
rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"),
rx.text("count:"),
rx.text(State.count, id="count"),
rx.text("count1:"),
rx.text(State.count1, id="count1"),
rx.text("count2:"),
rx.text(State.count2, id="count2"),
rx.text("count3:"),
rx.text(State.count3, id="count3"),
rx.text("depends_on_count1:"),
rx.text(
State.depends_on_count1,
id="depends_on_count1",
),
rx.text("depends_on_count3:"),
rx.text(
State.depends_on_count3,
id="depends_on_count3",
),
),
)

# raise Exception(State.count3._deps(objclass=State))
app = rx.App()
app.add_page(index)

Expand Down Expand Up @@ -162,21 +177,34 @@ def test_computed_vars(
assert count3
assert count3.text == "0"

depends_on_count1 = driver.find_element(By.ID, "depends_on_count1")
assert depends_on_count1
assert depends_on_count1.text == "0"

depends_on_count3 = driver.find_element(By.ID, "depends_on_count3")
assert depends_on_count3
assert depends_on_count3.text == "0"

increment = driver.find_element(By.ID, "increment")
assert increment.is_enabled()

do_nothing = driver.find_element(By.ID, "do_nothing")
assert do_nothing.is_enabled()
mark_dirty = driver.find_element(By.ID, "mark_dirty")
assert mark_dirty.is_enabled()

mark_dirty.click()

increment.click()
assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1"
assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1"

do_nothing.click()
mark_dirty.click()
with pytest.raises(TimeoutError):
computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0")

time.sleep(10)
do_nothing.click()
assert computed_vars.poll_for_content(count3, timeout=2) == "1"
assert count3.text == "0"
assert depends_on_count3.text == "0"
mark_dirty.click()
assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1"
assert depends_on_count3.text == "1"
18 changes: 17 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def _init_var_dependency_dicts(cls):
cls._always_dirty_computed_vars = set(
cvar_name
for cvar_name, cvar in cls.computed_vars.items()
if cvar.always_dirty
if not cvar._cache
)

# Any substate containing a ComputedVar with cache=False always needs to be recomputed
Expand Down Expand Up @@ -1633,6 +1633,18 @@ def _mark_dirty_computed_vars(self) -> None:
if actual_var is not None:
actual_var.mark_dirty(instance=self)

def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
Returns:
Set of computed vars to include in the delta.
"""
return set(
cvar
for cvar in self.computed_vars
if self.computed_vars[cvar].needs_update(instance=self)
)

def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
Expand Down Expand Up @@ -1685,6 +1697,7 @@ def get_delta(self) -> Delta:
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars))
.union(self._dirty_computed_vars())
.union(self._always_dirty_computed_vars)
)
Expand Down Expand Up @@ -1718,6 +1731,9 @@ def _mark_dirty(self):
self.parent_state.dirty_substates.add(self.get_name())
self.parent_state._mark_dirty()

# Append expired computed vars to dirty_vars to trigger recalculation
self.dirty_vars.update(self._expired_computed_vars())

# have to mark computed vars dirty to allow access to newly computed
# values within the same ComputedVar function
self._mark_dirty_computed_vars()
Expand Down
17 changes: 4 additions & 13 deletions reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,16 +1911,7 @@ def _last_updated_attr(self) -> str:
"""
return f"__last_updated_{self._var_name}"

@property
def always_dirty(self) -> bool:
"""Whether the computed var should always be marked as dirty.
Returns:
True if the computed var should always be marked as dirty, False otherwise.
"""
return not self._cache or self._update_interval is not None

def _needs_update(self, instance) -> bool:
def needs_update(self, instance: BaseState) -> bool:
"""Check if the computed var needs to be updated.
Args:
Expand All @@ -1936,7 +1927,7 @@ def _needs_update(self, instance) -> bool:
return True
return datetime.datetime.now() - last_updated > self._update_interval

def __get__(self, instance, owner):
def __get__(self, instance: BaseState | None, owner):
"""Get the ComputedVar value.
If the value is already cached on the instance, return the cached value.
Expand All @@ -1952,7 +1943,7 @@ def __get__(self, instance, owner):
return super().__get__(instance, owner)

# handle caching
if not hasattr(instance, self._cache_attr) or self._needs_update(instance):
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, super().__get__(instance, owner))
# Ensure the computed var gets serialized to redis.
Expand Down Expand Up @@ -2113,7 +2104,7 @@ def computed_var(
if fget is not None:
return ComputedVar(fget=fget, cache=cache)

def wrapper(fget):
def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar:
return ComputedVar(
fget=fget,
initial_value=initial_value,
Expand Down
1 change: 1 addition & 0 deletions reflex/vars.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class ComputedVar(Var):
def __get__(self, instance, owner): ...
def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ...
def mark_dirty(self, instance) -> None: ...
def needs_update(self, instance) -> bool: ...
def _determine_var_type(self) -> Type: ...
@overload
def __init__(
Expand Down

0 comments on commit 4e31200

Please sign in to comment.