diff --git a/reflex/event.py b/reflex/event.py index 85a2541a59..312c9887f6 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -45,6 +45,8 @@ from reflex.vars.base import LiteralVar, Var from reflex.vars.function import ( ArgsFunctionOperation, + ArgsFunctionOperationBuilder, + BuilderFunctionVar, FunctionArgs, FunctionStringVar, FunctionVar, @@ -797,8 +799,7 @@ def scroll_to(elem_id: str, align_to_top: bool | Var[bool] = True) -> EventSpec: get_element_by_id = FunctionStringVar.create("document.getElementById") return run_script( - get_element_by_id(elem_id) - .call(elem_id) + get_element_by_id.call(elem_id) .to(ObjectVar) .scrollIntoView.to(FunctionVar) .call(align_to_top), @@ -1580,7 +1581,7 @@ def create( ) -class EventChainVar(FunctionVar, python_types=EventChain): +class EventChainVar(BuilderFunctionVar, python_types=EventChain): """Base class for event chain vars.""" @@ -1592,7 +1593,7 @@ class EventChainVar(FunctionVar, python_types=EventChain): # Note: LiteralVar is second in the inheritance list allowing it act like a # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the # _cached_var_name property. -class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar): +class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainVar): """A literal event chain var.""" _var_value: EventChain = dataclasses.field(default=None) # type: ignore diff --git a/reflex/utils/telemetry.py b/reflex/utils/telemetry.py index 806b916fcb..815d37a1be 100644 --- a/reflex/utils/telemetry.py +++ b/reflex/utils/telemetry.py @@ -51,7 +51,8 @@ def get_python_version() -> str: Returns: The Python version. """ - return platform.python_version() + # Remove the "+" from the version string in case user is using a pre-release version. + return platform.python_version().rstrip("+") def get_reflex_version() -> str: diff --git a/reflex/vars/base.py b/reflex/vars/base.py index b9aa55eb37..200f693def 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -361,21 +361,29 @@ def _var_is_string(self) -> bool: return False def __init_subclass__( - cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset, **kwargs + cls, + python_types: Tuple[GenericType, ...] | GenericType = types.Unset(), + default_type: GenericType = types.Unset(), + **kwargs, ): """Initialize the subclass. Args: python_types: The python types that the var represents. + default_type: The default type of the var. Defaults to the first python type. **kwargs: Additional keyword arguments. """ super().__init_subclass__(**kwargs) - if python_types is not types.Unset: + if python_types or default_type: python_types = ( - python_types if isinstance(python_types, tuple) else (python_types,) + (python_types if isinstance(python_types, tuple) else (python_types,)) + if python_types + else () ) + default_type = default_type or (python_types[0] if python_types else Any) + @dataclasses.dataclass( eq=False, frozen=True, @@ -388,7 +396,7 @@ class ToVarOperation(ToOperation, cls): default=Var(_js_expr="null", _var_type=None), ) - _default_var_type: ClassVar[GenericType] = python_types[0] + _default_var_type: ClassVar[GenericType] = default_type ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation' @@ -588,6 +596,12 @@ def to( output: type[list] | type[tuple] | type[set], ) -> ArrayVar: ... + @overload + def to( + self, + output: type[dict], + ) -> ObjectVar[dict]: ... + @overload def to( self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE] diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 98f3b23358..c65b38f707 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -4,32 +4,177 @@ import dataclasses import sys -from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload + +from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar from reflex.utils import format from reflex.utils.types import GenericType from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock +P = ParamSpec("P") +V1 = TypeVar("V1") +V2 = TypeVar("V2") +V3 = TypeVar("V3") +V4 = TypeVar("V4") +V5 = TypeVar("V5") +V6 = TypeVar("V6") +R = TypeVar("R") + + +class ReflexCallable(Protocol[P, R]): + """Protocol for a callable.""" + + __call__: Callable[P, R] + -class FunctionVar(Var[Callable], python_types=Callable): +CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True) +OTHER_CALLABLE_TYPE = TypeVar( + "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True +) + + +class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): """Base class for immutable function vars.""" - def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: - """Call the function with the given arguments. + @overload + def partial(self) -> FunctionVar[CALLABLE_TYPE]: ... + + @overload + def partial( + self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]], + arg1: Union[V1, Var[V1]], + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + @overload + def partial( + self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + @overload + def partial( + self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + @overload + def partial( + self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + @overload + def partial( + self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + arg5: Union[V5, Var[V5]], + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + @overload + def partial( + self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + arg5: Union[V5, Var[V5]], + arg6: Union[V6, Var[V6]], + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + @overload + def partial( + self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + @overload + def partial(self, *args: Var | Any) -> FunctionVar: ... + + def partial(self, *args: Var | Any) -> FunctionVar: # type: ignore + """Partially apply the function with the given arguments. Args: - *args: The arguments to call the function with. + *args: The arguments to partially apply the function with. Returns: - The function call operation. + The partially applied function. """ + if not args: + return ArgsFunctionOperation.create((), self) return ArgsFunctionOperation.create( ("...args",), VarOperationCall.create(self, *args, Var(_js_expr="...args")), ) - def call(self, *args: Var | Any) -> VarOperationCall: + @overload + def call( + self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]] + ) -> VarOperationCall[[V1], R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> VarOperationCall[[V1, V2], R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> VarOperationCall[[V1, V2, V3], R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> VarOperationCall[[V1, V2, V3, V4], R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + arg5: Union[V5, Var[V5]], + ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + arg5: Union[V5, Var[V5]], + arg6: Union[V6, Var[V6]], + ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any + ) -> VarOperationCall[P, R]: ... + + @overload + def call(self, *args: Var | Any) -> Var: ... + + def call(self, *args: Var | Any) -> Var: # type: ignore """Call the function with the given arguments. Args: @@ -38,19 +183,29 @@ def call(self, *args: Var | Any) -> VarOperationCall: Returns: The function call operation. """ - return VarOperationCall.create(self, *args) + return VarOperationCall.create(self, *args).guess_type() + + __call__ = call + + +class BuilderFunctionVar( + FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any] +): + """Base class for immutable function vars with the builder pattern.""" + + __call__ = FunctionVar.partial -class FunctionStringVar(FunctionVar): +class FunctionStringVar(FunctionVar[CALLABLE_TYPE]): """Base class for immutable function vars from a string.""" @classmethod def create( cls, func: str, - _var_type: Type[Callable] = Callable, + _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any], _var_data: VarData | None = None, - ) -> FunctionStringVar: + ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]: """Create a new function var from a string. Args: @@ -60,7 +215,7 @@ def create( Returns: The function var. """ - return cls( + return FunctionStringVar( _js_expr=func, _var_type=_var_type, _var_data=_var_data, @@ -72,10 +227,10 @@ def create( frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class VarOperationCall(CachedVarOperation, Var): +class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]): """Base class for immutable vars that are the result of a function call.""" - _func: Optional[FunctionVar] = dataclasses.field(default=None) + _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None) _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) @cached_property_no_lock @@ -103,7 +258,7 @@ def _cached_get_all_var_data(self) -> VarData | None: @classmethod def create( cls, - func: FunctionVar, + func: FunctionVar[ReflexCallable[P, R]], *args: Var | Any, _var_type: GenericType = Any, _var_data: VarData | None = None, @@ -118,9 +273,15 @@ def create( Returns: The function call var. """ + function_return_type = ( + func._var_type.__args__[1] + if getattr(func._var_type, "__args__", None) + else Any + ) + var_type = _var_type if _var_type is not Any else function_return_type return cls( _js_expr="", - _var_type=_var_type, + _var_type=var_type, _var_data=_var_data, _func=func, _args=args, @@ -157,6 +318,33 @@ class FunctionArgs: rest: Optional[str] = None +def format_args_function_operation( + args: FunctionArgs, return_expr: Var | Any, explicit_return: bool +) -> str: + """Format an args function operation. + + Args: + args: The function arguments. + return_expr: The return expression. + explicit_return: Whether to use explicit return syntax. + + Returns: + The formatted args function operation. + """ + arg_names_str = ", ".join( + [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args] + ) + (f", ...{args.rest}" if args.rest else "") + + return_expr_str = str(LiteralVar.create(return_expr)) + + # Wrap return expression in curly braces if explicit return syntax is used. + return_expr_str_wrapped = ( + format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str + ) + + return f"(({arg_names_str}) => {return_expr_str_wrapped})" + + @dataclasses.dataclass( eq=False, frozen=True, @@ -176,24 +364,10 @@ def _cached_var_name(self) -> str: Returns: The name of the var. """ - arg_names_str = ", ".join( - [ - arg if isinstance(arg, str) else arg.to_javascript() - for arg in self._args.args - ] - ) + (f", ...{self._args.rest}" if self._args.rest else "") - - return_expr_str = str(LiteralVar.create(self._return_expr)) - - # Wrap return expression in curly braces if explicit return syntax is used. - return_expr_str_wrapped = ( - format.wrap(return_expr_str, "{", "}") - if self._explicit_return - else return_expr_str + return format_args_function_operation( + self._args, self._return_expr, self._explicit_return ) - return f"(({arg_names_str}) => {return_expr_str_wrapped})" - @classmethod def create( cls, @@ -203,7 +377,7 @@ def create( explicit_return: bool = False, _var_type: GenericType = Callable, _var_data: VarData | None = None, - ) -> ArgsFunctionOperation: + ): """Create a new function var. Args: @@ -226,8 +400,80 @@ def create( ) -JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify") -ARRAY_ISARRAY = FunctionStringVar.create("Array.isArray") -PROTOTYPE_TO_STRING = FunctionStringVar.create( - "((__to_string) => __to_string.toString())" +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, ) +class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): + """Base class for immutable function defined via arguments and return expression with the builder pattern.""" + + _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) + _return_expr: Union[Var, Any] = dataclasses.field(default=None) + _explicit_return: bool = dataclasses.field(default=False) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return format_args_function_operation( + self._args, self._return_expr, self._explicit_return + ) + + @classmethod + def create( + cls, + args_names: Sequence[Union[str, DestructuredArg]], + return_expr: Var | Any, + rest: str | None = None, + explicit_return: bool = False, + _var_type: GenericType = Callable, + _var_data: VarData | None = None, + ): + """Create a new function var. + + Args: + args_names: The names of the arguments. + return_expr: The return expression of the function. + rest: The name of the rest argument. + explicit_return: Whether to use explicit return syntax. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The function var. + """ + return cls( + _js_expr="", + _var_type=_var_type, + _var_data=_var_data, + _args=FunctionArgs(args=tuple(args_names), rest=rest), + _return_expr=return_expr, + _explicit_return=explicit_return, + ) + + +if python_version := sys.version_info[:2] >= (3, 10): + JSON_STRINGIFY = FunctionStringVar.create( + "JSON.stringify", _var_type=ReflexCallable[[Any], str] + ) + ARRAY_ISARRAY = FunctionStringVar.create( + "Array.isArray", _var_type=ReflexCallable[[Any], bool] + ) + PROTOTYPE_TO_STRING = FunctionStringVar.create( + "((__to_string) => __to_string.toString())", + _var_type=ReflexCallable[[Any], str], + ) +else: + JSON_STRINGIFY = FunctionStringVar.create( + "JSON.stringify", _var_type=ReflexCallable[Any, str] + ) + ARRAY_ISARRAY = FunctionStringVar.create( + "Array.isArray", _var_type=ReflexCallable[Any, bool] + ) + PROTOTYPE_TO_STRING = FunctionStringVar.create( + "((__to_string) => __to_string.toString())", + _var_type=ReflexCallable[Any, str], + ) diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 5944739213..4940246e79 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -928,7 +928,7 @@ def test_function_var(): == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))' ) - increment_func = addition_func(1) + increment_func = addition_func.partial(1) assert ( str(increment_func.call(2)) == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"