From 4e312004926bbc1150de9e9449248f1b10f7879a Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Tue, 7 May 2024 02:37:28 +0200 Subject: [PATCH] correctly determine expired computed vars + minor typing improvements --- integration/test_computed_vars.py | 46 +++++++++++++++++++++++++------ reflex/state.py | 18 +++++++++++- reflex/vars.py | 17 +++--------- reflex/vars.pyi | 1 + 4 files changed, 59 insertions(+), 23 deletions(-) diff --git a/integration/test_computed_vars.py b/integration/test_computed_vars.py index c0f0cfb51fe..2523248a905 100644 --- a/integration/test_computed_vars.py +++ b/integration/test_computed_vars.py @@ -38,15 +38,19 @@ def count3(self) -> int: return self.count # explicit dependency on count1 var - @rx.cached_var(deps=[count1]) + @rx.cached_var(deps=[count1], auto_deps=False) def depends_on_count1(self) -> int: return self.count + @rx.var(deps=[count3], auto_deps=False, cache=True) + def depends_on_count3(self) -> int: + return self.count + def increment(self): self.count += 1 - def do_nothing(self): - pass + def mark_dirty(self): + self._mark_dirty() def index() -> rx.Component: return rx.center( @@ -57,18 +61,29 @@ def index() -> rx.Component: is_read_only=True, ), rx.button("Increment", on_click=State.increment, id="increment"), - rx.button("Do nothing", on_click=State.do_nothing, id="do_nothing"), + rx.button("Do nothing", on_click=State.mark_dirty, id="mark_dirty"), + rx.text("count:"), rx.text(State.count, id="count"), + rx.text("count1:"), rx.text(State.count1, id="count1"), + rx.text("count2:"), rx.text(State.count2, id="count2"), + rx.text("count3:"), rx.text(State.count3, id="count3"), + rx.text("depends_on_count1:"), rx.text( State.depends_on_count1, id="depends_on_count1", ), + rx.text("depends_on_count3:"), + rx.text( + State.depends_on_count3, + id="depends_on_count3", + ), ), ) + # raise Exception(State.count3._deps(objclass=State)) app = rx.App() app.add_page(index) @@ -162,21 +177,34 @@ def test_computed_vars( assert count3 assert count3.text == "0" + depends_on_count1 = driver.find_element(By.ID, "depends_on_count1") + assert depends_on_count1 + assert depends_on_count1.text == "0" + + depends_on_count3 = driver.find_element(By.ID, "depends_on_count3") + assert depends_on_count3 + assert depends_on_count3.text == "0" + increment = driver.find_element(By.ID, "increment") assert increment.is_enabled() - do_nothing = driver.find_element(By.ID, "do_nothing") - assert do_nothing.is_enabled() + mark_dirty = driver.find_element(By.ID, "mark_dirty") + assert mark_dirty.is_enabled() + + mark_dirty.click() increment.click() assert computed_vars.poll_for_content(count, timeout=2, exp_not_equal="0") == "1" assert computed_vars.poll_for_content(count1, timeout=2, exp_not_equal="0") == "1" assert computed_vars.poll_for_content(count2, timeout=2, exp_not_equal="0") == "1" - do_nothing.click() + mark_dirty.click() with pytest.raises(TimeoutError): computed_vars.poll_for_content(count3, timeout=5, exp_not_equal="0") time.sleep(10) - do_nothing.click() - assert computed_vars.poll_for_content(count3, timeout=2) == "1" + assert count3.text == "0" + assert depends_on_count3.text == "0" + mark_dirty.click() + assert computed_vars.poll_for_content(count3, timeout=2, exp_not_equal="0") == "1" + assert depends_on_count3.text == "1" diff --git a/reflex/state.py b/reflex/state.py index 904f0168743..fbec6bacab7 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -660,7 +660,7 @@ def _init_var_dependency_dicts(cls): cls._always_dirty_computed_vars = set( cvar_name for cvar_name, cvar in cls.computed_vars.items() - if cvar.always_dirty + if not cvar._cache ) # Any substate containing a ComputedVar with cache=False always needs to be recomputed @@ -1633,6 +1633,18 @@ def _mark_dirty_computed_vars(self) -> None: if actual_var is not None: actual_var.mark_dirty(instance=self) + def _expired_computed_vars(self) -> set[str]: + """Determine ComputedVars that need to be recalculated based on the expiration time. + + Returns: + Set of computed vars to include in the delta. + """ + return set( + cvar + for cvar in self.computed_vars + if self.computed_vars[cvar].needs_update(instance=self) + ) + def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]: """Determine ComputedVars that need to be recalculated based on the given vars. @@ -1685,6 +1697,7 @@ def get_delta(self) -> Delta: # and always dirty computed vars (cache=False) delta_vars = ( self.dirty_vars.intersection(self.base_vars) + .union(self.dirty_vars.intersection(self.computed_vars)) .union(self._dirty_computed_vars()) .union(self._always_dirty_computed_vars) ) @@ -1718,6 +1731,9 @@ def _mark_dirty(self): self.parent_state.dirty_substates.add(self.get_name()) self.parent_state._mark_dirty() + # Append expired computed vars to dirty_vars to trigger recalculation + self.dirty_vars.update(self._expired_computed_vars()) + # have to mark computed vars dirty to allow access to newly computed # values within the same ComputedVar function self._mark_dirty_computed_vars() diff --git a/reflex/vars.py b/reflex/vars.py index 7a9583a2055..fd6da766df4 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1911,16 +1911,7 @@ def _last_updated_attr(self) -> str: """ return f"__last_updated_{self._var_name}" - @property - def always_dirty(self) -> bool: - """Whether the computed var should always be marked as dirty. - - Returns: - True if the computed var should always be marked as dirty, False otherwise. - """ - return not self._cache or self._update_interval is not None - - def _needs_update(self, instance) -> bool: + def needs_update(self, instance: BaseState) -> bool: """Check if the computed var needs to be updated. Args: @@ -1936,7 +1927,7 @@ def _needs_update(self, instance) -> bool: return True return datetime.datetime.now() - last_updated > self._update_interval - def __get__(self, instance, owner): + def __get__(self, instance: BaseState | None, owner): """Get the ComputedVar value. If the value is already cached on the instance, return the cached value. @@ -1952,7 +1943,7 @@ def __get__(self, instance, owner): return super().__get__(instance, owner) # handle caching - if not hasattr(instance, self._cache_attr) or self._needs_update(instance): + if not hasattr(instance, self._cache_attr) or self.needs_update(instance): # Set cache attr on state instance. setattr(instance, self._cache_attr, super().__get__(instance, owner)) # Ensure the computed var gets serialized to redis. @@ -2113,7 +2104,7 @@ def computed_var( if fget is not None: return ComputedVar(fget=fget, cache=cache) - def wrapper(fget): + def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar: return ComputedVar( fget=fget, initial_value=initial_value, diff --git a/reflex/vars.pyi b/reflex/vars.pyi index d10fee847d0..bbfb846263a 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -141,6 +141,7 @@ class ComputedVar(Var): def __get__(self, instance, owner): ... def _deps(self, objclass: Type, obj: Optional[FunctionType] = ...) -> Set[str]: ... def mark_dirty(self, instance) -> None: ... + def needs_update(self, instance) -> bool: ... def _determine_var_type(self) -> Type: ... @overload def __init__(