Skip to content

Commit

Permalink
Get default for backend var defined in mixin (#4060)
Browse files Browse the repository at this point in the history
* Get default for backend var defined in mixin

If the backend var is defined in a mixin class, it won't appear in
`cls.__dict__`, but the value is still retrievable via `getattr` on `cls`.
Prefer to use the actual defined default before using
`Var.get_default_value()`.

If `Var.get_default_value()` fails, set the default to `None` such that the
backend var still gets recognized as a backend var when it is used on `self`.

----

Update test_component_state to include backend vars

Extra coverage for backend vars with and without defaults, defined in a
ComponentState/mixin class.

* fix integration test
  • Loading branch information
masenf authored and simon committed Oct 23, 2024
1 parent 1d71cbb commit 765a281
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
24 changes: 22 additions & 2 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,10 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
for name, value in cls.__dict__.items()
if types.is_backend_base_variable(name, cls)
}
# Add annotated backend vars that do not have a default value.
# Add annotated backend vars that may not have a default value.
new_backend_vars.update(
{
name: Var("", _var_type=annotation_value).get_default_value()
name: cls._get_var_default(name, annotation_value)
for name, annotation_value in get_type_hints(cls).items()
if name not in new_backend_vars
and types.is_backend_base_variable(name, cls)
Expand Down Expand Up @@ -990,6 +990,26 @@ def _set_default_value(cls, prop: Var):
# Ensure frontend uses null coalescing when accessing.
object.__setattr__(prop, "_var_type", Optional[prop._var_type])

@classmethod
def _get_var_default(cls, name: str, annotation_value: Any) -> Any:
"""Get the default value of a (backend) var.
Args:
name: The name of the var.
annotation_value: The annotation value of the var.
Returns:
The default value of the var or None.
"""
try:
return getattr(cls, name)
except AttributeError:
try:
return Var("", _var_type=annotation_value).get_default_value()
except TypeError:
pass
return None

@staticmethod
def _get_base_functions() -> dict[str, FunctionType]:
"""Get all functions of the state class excluding dunder methods.
Expand Down
46 changes: 45 additions & 1 deletion tests/integration/test_component_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,29 @@
import pytest
from selenium.webdriver.common.by import By

from reflex.state import State, _substate_key
from reflex.testing import AppHarness

from . import utils


def ComponentStateApp():
"""App using per component state."""
from typing import Generic, TypeVar

import reflex as rx

class MultiCounter(rx.ComponentState):
E = TypeVar("E")

class MultiCounter(rx.ComponentState, Generic[E]):
count: int = 0
_be: E
_be_int: int
_be_str: str = "42"

def increment(self):
self.count += 1
self._be = self.count # type: ignore

@classmethod
def get_component(cls, *children, **props):
Expand Down Expand Up @@ -48,6 +57,14 @@ def index():
on_click=mc_a.State.increment, # type: ignore
id="inc-a",
),
rx.text(
mc_a.State.get_name() if mc_a.State is not None else "",
id="a_state_name",
),
rx.text(
mc_b.State.get_name() if mc_b.State is not None else "",
id="b_state_name",
),
)


Expand Down Expand Up @@ -80,13 +97,26 @@ async def test_component_state_app(component_state_app: AppHarness):

ss = utils.SessionStorage(driver)
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
root_state_token = _substate_key(ss.get("token"), State)

count_a = driver.find_element(By.ID, "count-a")
count_b = driver.find_element(By.ID, "count-b")
button_a = driver.find_element(By.ID, "button-a")
button_b = driver.find_element(By.ID, "button-b")
button_inc_a = driver.find_element(By.ID, "inc-a")

# Check that backend vars in mixins are okay
a_state_name = driver.find_element(By.ID, "a_state_name").text
b_state_name = driver.find_element(By.ID, "b_state_name").text
root_state = await component_state_app.get_state(root_state_token)
a_state = root_state.substates[a_state_name]
b_state = root_state.substates[b_state_name]
assert a_state._backend_vars == a_state.backend_vars
assert a_state._backend_vars == b_state._backend_vars
assert a_state._backend_vars["_be"] is None
assert a_state._backend_vars["_be_int"] == 0
assert a_state._backend_vars["_be_str"] == "42"

assert count_a.text == "0"

button_a.click()
Expand All @@ -98,10 +128,24 @@ async def test_component_state_app(component_state_app: AppHarness):
button_inc_a.click()
assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3"

root_state = await component_state_app.get_state(root_state_token)
a_state = root_state.substates[a_state_name]
b_state = root_state.substates[b_state_name]
assert a_state._backend_vars != a_state.backend_vars
assert a_state._be == a_state._backend_vars["_be"] == 3
assert b_state._be is None
assert b_state._backend_vars["_be"] is None

assert count_b.text == "0"

button_b.click()
assert component_state_app.poll_for_content(count_b, exp_not_equal="0") == "1"

button_b.click()
assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2"

root_state = await component_state_app.get_state(root_state_token)
a_state = root_state.substates[a_state_name]
b_state = root_state.substates[b_state_name]
assert b_state._backend_vars != b_state.backend_vars
assert b_state._be == b_state._backend_vars["_be"] == 2

0 comments on commit 765a281

Please sign in to comment.