Skip to content

Commit

Permalink
unionize base var fields types (#4153)
Browse files Browse the repository at this point in the history
* unionize base var fields types

* add tests

* fix union types for vars (#4152)

* remove 3.11 special casing

* special case on version

* fix old versions of python

---------

Co-authored-by: Masen Furer <m_github@0x26.net>
  • Loading branch information
adhami3310 and masenf authored Oct 12, 2024
1 parent 0889276 commit b1d4498
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 27 deletions.
28 changes: 23 additions & 5 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,26 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes)


def unionize(*args: GenericType) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]


def is_none(cls: GenericType) -> bool:
"""Check if a class is None.
Expand Down Expand Up @@ -358,11 +378,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
return type_
elif is_union(cls):
# Check in each arg of the annotation.
for arg in get_args(cls):
type_ = get_attribute_access_type(arg, name)
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
return unionize(
*(get_attribute_access_type(arg, name) for arg in get_args(cls))
)
elif isinstance(cls, type):
# Bare class
if sys.version_info >= (3, 10):
Expand Down
22 changes: 1 addition & 21 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
ParsedImportDict,
parse_imports,
)
from reflex.utils.types import GenericType, Self, get_origin, has_args
from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize

if TYPE_CHECKING:
from reflex.state import BaseState
Expand Down Expand Up @@ -1237,26 +1237,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]:
return wrapper


def unionize(*args: Type) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
if len(args) == 1:
return args[0]
# We are bisecting the args list here to avoid hitting the recursion limit
# In Python versions >= 3.11, we can simply do `return Union[*args]`
midpoint = len(args) // 2
first_half, second_half = args[:midpoint], args[midpoint:]
return Union[unionize(*first_half), unionize(*second_half)]


def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value.
Expand Down
4 changes: 3 additions & 1 deletion reflex/vars/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ def __getattr__(self, name) -> Var:
var_type = get_args(var_type)[0]

fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if isclass(fixed_type) and not issubclass(fixed_type, dict):
if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
fixed_type in types.UnionTypes
):
attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None:
raise VarAttributeError(
Expand Down
39 changes: 39 additions & 0 deletions tests/units/test_var.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import math
import sys
import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast

Expand Down Expand Up @@ -398,6 +399,44 @@ def test_list_tuple_contains(var, expected):
assert str(var.contains(other_var)) == f"{expected}.includes(other)"


class Foo(rx.Base):
"""Foo class."""

bar: int
baz: str


class Bar(rx.Base):
"""Bar class."""

bar: str
baz: str
foo: int


@pytest.mark.parametrize(
("var", "var_type"),
(
[
(Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
(Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
]
if sys.version_info >= (3, 10)
else []
)
+ [
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
(
Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
Union[int, None],
),
],
)
def test_var_types(var, var_type):
assert var._var_type == var_type


@pytest.mark.parametrize(
"var, expected",
[
Expand Down

0 comments on commit b1d4498

Please sign in to comment.