diff --git a/reflex/components/component.py b/reflex/components/component.py index 399becee92..85db3906dc 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -480,6 +480,7 @@ def __init__(self, *args, **kwargs): kwargs["event_triggers"][key] = self._create_event_chain( value=value, # type: ignore args_spec=component_specific_triggers[key], + key=key, ) # Remove any keys that were added as events. @@ -540,12 +541,14 @@ def _create_event_chain( List[Union[EventHandler, EventSpec, EventVar]], Callable, ], + key: Optional[str] = None, ) -> Union[EventChain, Var]: """Create an event chain from a variety of input types. Args: args_spec: The args_spec of the event trigger being bound. value: The value to create the event chain from. + key: The key of the event trigger being bound. Returns: The event chain. @@ -560,7 +563,7 @@ def _create_event_chain( elif isinstance(value, EventVar): value = [value] elif issubclass(value._var_type, (EventChain, EventSpec)): - return self._create_event_chain(args_spec, value.guess_type()) + return self._create_event_chain(args_spec, value.guess_type(), key=key) else: raise ValueError( f"Invalid event chain: {str(value)} of type {value._var_type}" @@ -579,10 +582,10 @@ def _create_event_chain( for v in value: if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. - events.append(call_event_handler(v, args_spec)) + events.append(call_event_handler(v, args_spec, key=key)) elif isinstance(v, Callable): # Call the lambda to get the event chain. - result = call_event_fn(v, args_spec) + result = call_event_fn(v, args_spec, key=key) if isinstance(result, Var): raise ValueError( f"Invalid event chain: {v}. Cannot use a Var-returning " @@ -599,7 +602,7 @@ def _create_event_chain( result = call_event_fn(value, args_spec) if isinstance(result, Var): # Recursively call this function if the lambda returned an EventChain Var. - return self._create_event_chain(args_spec, result) + return self._create_event_chain(args_spec, result, key=key) events = [*result] # Otherwise, raise an error. @@ -1722,6 +1725,7 @@ def __init__(self, *args, **kwargs): args_spec=event_triggers_in_component_declaration.get( key, empty_event ), + key=key, ) self.props[format.to_camel_case(key)] = value continue diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 7cb776ee9a..4caf14b414 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -111,6 +111,15 @@ def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]: return (FORM_DATA,) +def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]: + """Event handler spec for the on_submit event. + + Returns: + The event handler spec. + """ + return (FORM_DATA,) + + class Form(BaseHTML): """Display the form element.""" @@ -150,7 +159,7 @@ class Form(BaseHTML): handle_submit_unique_name: Var[str] # Fired when the form is submitted - on_submit: EventHandler[on_submit_event_spec] + on_submit: EventHandler[on_submit_event_spec, on_submit_string_event_spec] @classmethod def create(cls, *children, **props): diff --git a/reflex/components/el/elements/forms.pyi b/reflex/components/el/elements/forms.pyi index bc9bc9689a..a8e9b6174a 100644 --- a/reflex/components/el/elements/forms.pyi +++ b/reflex/components/el/elements/forms.pyi @@ -271,6 +271,7 @@ class Fieldset(Element): ... def on_submit_event_spec() -> Tuple[Var[Dict[str, Any]]]: ... +def on_submit_string_event_spec() -> Tuple[Var[Dict[str, str]]]: ... class Form(BaseHTML): @overload @@ -337,7 +338,9 @@ class Form(BaseHTML): on_mouse_over: Optional[EventType[[]]] = None, on_mouse_up: Optional[EventType[[]]] = None, on_scroll: Optional[EventType[[]]] = None, - on_submit: Optional[EventType[Dict[str, Any]]] = None, + on_submit: Optional[ + Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]] + ] = None, on_unmount: Optional[EventType[[]]] = None, **props, ) -> "Form": diff --git a/reflex/components/radix/primitives/form.pyi b/reflex/components/radix/primitives/form.pyi index c4dce0a364..72595a9338 100644 --- a/reflex/components/radix/primitives/form.pyi +++ b/reflex/components/radix/primitives/form.pyi @@ -129,7 +129,9 @@ class FormRoot(FormComponent, HTMLForm): on_mouse_over: Optional[EventType[[]]] = None, on_mouse_up: Optional[EventType[[]]] = None, on_scroll: Optional[EventType[[]]] = None, - on_submit: Optional[EventType[Dict[str, Any]]] = None, + on_submit: Optional[ + Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]] + ] = None, on_unmount: Optional[EventType[[]]] = None, **props, ) -> "FormRoot": @@ -596,7 +598,9 @@ class Form(FormRoot): on_mouse_over: Optional[EventType[[]]] = None, on_mouse_up: Optional[EventType[[]]] = None, on_scroll: Optional[EventType[[]]] = None, - on_submit: Optional[EventType[Dict[str, Any]]] = None, + on_submit: Optional[ + Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]] + ] = None, on_unmount: Optional[EventType[[]]] = None, **props, ) -> "Form": @@ -720,7 +724,9 @@ class FormNamespace(ComponentNamespace): on_mouse_over: Optional[EventType[[]]] = None, on_mouse_up: Optional[EventType[[]]] = None, on_scroll: Optional[EventType[[]]] = None, - on_submit: Optional[EventType[Dict[str, Any]]] = None, + on_submit: Optional[ + Union[EventType[Dict[str, Any]], EventType[Dict[str, str]]] + ] = None, on_unmount: Optional[EventType[[]]] = None, **props, ) -> "Form": diff --git a/reflex/components/radix/themes/components/slider.py b/reflex/components/radix/themes/components/slider.py index bf0e5c454a..bb017ea736 100644 --- a/reflex/components/radix/themes/components/slider.py +++ b/reflex/components/radix/themes/components/slider.py @@ -2,11 +2,11 @@ from __future__ import annotations -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Union from reflex.components.component import Component from reflex.components.core.breakpoints import Responsive -from reflex.event import EventHandler +from reflex.event import EventHandler, identity_event from reflex.vars.base import Var from ..base import ( @@ -14,19 +14,11 @@ RadixThemesComponent, ) - -def on_value_event_spec( - value: Var[List[Union[int, float]]], -) -> Tuple[Var[List[Union[int, float]]]]: - """Event handler spec for the value event. - - Args: - value: The value of the event. - - Returns: - The event handler spec. - """ - return (value,) # type: ignore +on_value_event_spec = ( + identity_event(list[Union[int, float]]), + identity_event(list[int]), + identity_event(list[float]), +) class Slider(RadixThemesComponent): diff --git a/reflex/components/radix/themes/components/slider.pyi b/reflex/components/radix/themes/components/slider.pyi index 5ac3c275f9..b2f155fe6b 100644 --- a/reflex/components/radix/themes/components/slider.pyi +++ b/reflex/components/radix/themes/components/slider.pyi @@ -3,18 +3,20 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload +from typing import Any, Dict, List, Literal, Optional, Union, overload from reflex.components.core.breakpoints import Breakpoints -from reflex.event import EventType +from reflex.event import EventType, identity_event from reflex.style import Style from reflex.vars.base import Var from ..base import RadixThemesComponent -def on_value_event_spec( - value: Var[List[Union[int, float]]], -) -> Tuple[Var[List[Union[int, float]]]]: ... +on_value_event_spec = ( + identity_event(list[Union[int, float]]), + identity_event(list[int]), + identity_event(list[float]), +) class Slider(RadixThemesComponent): @overload @@ -138,7 +140,13 @@ class Slider(RadixThemesComponent): autofocus: Optional[bool] = None, custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, on_blur: Optional[EventType[[]]] = None, - on_change: Optional[EventType[List[Union[int, float]]]] = None, + on_change: Optional[ + Union[ + EventType[list[Union[int, float]]], + EventType[list[int]], + EventType[list[float]], + ] + ] = None, on_click: Optional[EventType[[]]] = None, on_context_menu: Optional[EventType[[]]] = None, on_double_click: Optional[EventType[[]]] = None, @@ -153,7 +161,13 @@ class Slider(RadixThemesComponent): on_mouse_up: Optional[EventType[[]]] = None, on_scroll: Optional[EventType[[]]] = None, on_unmount: Optional[EventType[[]]] = None, - on_value_commit: Optional[EventType[List[Union[int, float]]]] = None, + on_value_commit: Optional[ + Union[ + EventType[list[Union[int, float]]], + EventType[list[int]], + EventType[list[float]], + ] + ] = None, **props, ) -> "Slider": """Create a Slider component. diff --git a/reflex/event.py b/reflex/event.py index 14249091eb..c2e6955f61 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -29,8 +29,12 @@ from reflex import constants from reflex.utils import console, format -from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch -from reflex.utils.types import ArgsSpec, GenericType +from reflex.utils.exceptions import ( + EventFnArgMismatch, + EventHandlerArgMismatch, + EventHandlerArgTypeMismatch, +) +from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass from reflex.vars import VarData from reflex.vars.base import ( LiteralVar, @@ -401,7 +405,9 @@ class EventChain(EventActionsMixin): default_factory=list ) - args_spec: Optional[Callable] = dataclasses.field(default=None) + args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field( + default=None + ) invocation: Optional[Var] = dataclasses.field(default=None) @@ -1053,7 +1059,8 @@ def get_hydrate_event(state) -> str: def call_event_handler( event_handler: EventHandler | EventSpec, - arg_spec: ArgsSpec, + arg_spec: ArgsSpec | Sequence[ArgsSpec], + key: Optional[str] = None, ) -> EventSpec: """Call an event handler to get the event spec. @@ -1064,12 +1071,16 @@ def call_event_handler( Args: event_handler: The event handler. arg_spec: The lambda that define the argument(s) to pass to the event handler. + key: The key to pass to the event handler. Raises: EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec. Returns: The event spec from calling the event handler. + + # noqa: DAR401 failure + """ parsed_args = parse_args_spec(arg_spec) # type: ignore @@ -1077,19 +1088,113 @@ def call_event_handler( # Handle partial application of EventSpec args return event_handler.add_args(*parsed_args) - args = inspect.getfullargspec(event_handler.fn).args - n_args = len(args) - 1 # subtract 1 for bound self arg - if n_args == len(parsed_args): - return event_handler(*parsed_args) # type: ignore - else: + provided_callback_fullspec = inspect.getfullargspec(event_handler.fn) + + provided_callback_n_args = ( + len(provided_callback_fullspec.args) - 1 + ) # subtract 1 for bound self arg + + if provided_callback_n_args != len(parsed_args): raise EventHandlerArgMismatch( "The number of arguments accepted by " - f"{event_handler.fn.__qualname__} ({n_args}) " + f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) " "does not match the arguments passed by the event trigger: " f"{[str(v) for v in parsed_args]}\n" "See https://reflex.dev/docs/events/event-arguments/" ) + all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec + + event_spec_return_types = list( + filter( + lambda event_spec_return_type: event_spec_return_type is not None + and get_origin(event_spec_return_type) is tuple, + (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec), + ) + ) + + if event_spec_return_types: + failures = [] + + for event_spec_index, event_spec_return_type in enumerate( + event_spec_return_types + ): + args = get_args(event_spec_return_type) + + args_types_without_vars = [ + arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args + ] + + try: + type_hints_of_provided_callback = get_type_hints(event_handler.fn) + except NameError: + type_hints_of_provided_callback = {} + + failed_type_check = False + + # check that args of event handler are matching the spec if type hints are provided + for i, arg in enumerate(provided_callback_fullspec.args[1:]): + if arg not in type_hints_of_provided_callback: + continue + + try: + compare_result = typehint_issubclass( + args_types_without_vars[i], type_hints_of_provided_callback[arg] + ) + except TypeError: + # TODO: In 0.7.0, remove this block and raise the exception + # raise TypeError( + # f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." + # ) from e + console.warn( + f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." + ) + compare_result = False + + if compare_result: + continue + else: + failure = EventHandlerArgTypeMismatch( + f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead." + ) + failures.append(failure) + failed_type_check = True + break + + if not failed_type_check: + if event_spec_index: + args = get_args(event_spec_return_types[0]) + + args_types_without_vars = [ + arg if get_origin(arg) is not Var else get_args(arg)[0] + for arg in args + ] + + expect_string = ", ".join( + repr(arg) for arg in args_types_without_vars + ).replace("[", "\\[") + + given_string = ", ".join( + repr(type_hints_of_provided_callback.get(arg, Any)) + for arg in provided_callback_fullspec.args[1:] + ).replace("[", "\\[") + + console.warn( + f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. " + f"This may lead to unexpected behavior but is intentionally ignored for {key}." + ) + return event_handler(*parsed_args) + + if failures: + console.deprecate( + "Mismatched event handler argument types", + "\n".join([str(f) for f in failures]), + "0.6.5", + "0.7.0", + ) + + return event_handler(*parsed_args) # type: ignore + def unwrap_var_annotation(annotation: GenericType): """Unwrap a Var annotation or return it as is if it's not Var[X]. @@ -1128,7 +1233,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str): return annotation -def parse_args_spec(arg_spec: ArgsSpec): +def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]): """Parse the args provided in the ArgsSpec of an event trigger. Args: @@ -1137,6 +1242,8 @@ def parse_args_spec(arg_spec: ArgsSpec): Returns: The parsed args. """ + # if there's multiple, the first is the default + arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec spec = inspect.getfullargspec(arg_spec) annotations = get_type_hints(arg_spec) @@ -1152,13 +1259,18 @@ def parse_args_spec(arg_spec: ArgsSpec): ) -def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]: +def check_fn_match_arg_spec( + fn: Callable, + arg_spec: ArgsSpec, + key: Optional[str] = None, +) -> List[Var]: """Ensures that the function signature matches the passed argument specification or raises an EventFnArgMismatch if they do not. Args: fn: The function to be validated. arg_spec: The argument specification for the event trigger. + key: The key to pass to the event handler. Returns: The parsed arguments from the argument specification. @@ -1184,7 +1296,11 @@ def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]: return parsed_args -def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: +def call_event_fn( + fn: Callable, + arg_spec: ArgsSpec, + key: Optional[str] = None, +) -> list[EventSpec] | Var: """Call a function to a list of event specs. The function should return a single EventSpec, a list of EventSpecs, or a @@ -1193,6 +1309,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: Args: fn: The function to call. arg_spec: The argument spec for the event trigger. + key: The key to pass to the event handler. Returns: The event specs from calling the function or a Var. @@ -1205,7 +1322,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: from reflex.utils.exceptions import EventHandlerValueError # Check that fn signature matches arg_spec - parsed_args = check_fn_match_arg_spec(fn, arg_spec) + parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key) # Call the function with the parsed args. out = fn(*parsed_args) @@ -1223,7 +1340,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: for e in out: if isinstance(e, EventHandler): # An un-called EventHandler gets all of the args of the event trigger. - e = call_event_handler(e, arg_spec) + e = call_event_handler(e, arg_spec, key=key) # Make sure the event spec is valid. if not isinstance(e, EventSpec): @@ -1433,7 +1550,12 @@ def create( Returns: The created LiteralEventChainVar instance. """ - sig = inspect.signature(value.args_spec) # type: ignore + arg_spec = ( + value.args_spec[0] + if isinstance(value.args_spec, Sequence) + else value.args_spec + ) + sig = inspect.signature(arg_spec) # type: ignore if sig.parameters: arg_def = tuple((f"_{p}" for p in sig.parameters)) arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index f163385138..31af0e94a5 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -90,7 +90,11 @@ class MatchTypeError(ReflexError, TypeError): class EventHandlerArgMismatch(ReflexError, TypeError): - """Raised when the number of args accepted by an EventHandler is differs from that provided by the event trigger.""" + """Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger.""" + + +class EventHandlerArgTypeMismatch(ReflexError, TypeError): + """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger.""" class EventFnArgMismatch(ReflexError, TypeError): diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 667015768e..342277cadd 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -490,7 +490,7 @@ def _generate_component_create_functiondef( def figure_out_return_type(annotation: Any): if inspect.isclass(annotation) and issubclass(annotation, inspect._empty): - return ast.Name(id="Optional[EventType]") + return ast.Name(id="EventType") if not isinstance(annotation, str) and get_origin(annotation) is tuple: arguments = get_args(annotation) @@ -509,20 +509,13 @@ def figure_out_return_type(annotation: Any): # Create EventType using the joined string event_type = ast.Name(id=f"EventType[{args_str}]") - # Wrap in Optional - optional_type = ast.Subscript( - value=ast.Name(id="Optional"), - slice=ast.Index(value=event_type), - ctx=ast.Load(), - ) - - return ast.Name(id=ast.unparse(optional_type)) + return event_type if isinstance(annotation, str) and annotation.startswith("Tuple["): inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]") if inside_of_tuple == "()": - return ast.Name(id="Optional[EventType[[]]]") + return ast.Name(id="EventType[[]]") arguments = [""] @@ -548,10 +541,8 @@ def figure_out_return_type(annotation: Any): for argument in arguments ] - return ast.Name( - id=f"Optional[EventType[{', '.join(arguments_without_var)}]]" - ) - return ast.Name(id="Optional[EventType]") + return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]") + return ast.Name(id="EventType") event_triggers = clz().get_event_triggers() @@ -560,8 +551,33 @@ def figure_out_return_type(annotation: Any): ( ast.arg( arg=trigger, - annotation=figure_out_return_type( - inspect.signature(event_triggers[trigger]).return_annotation + annotation=ast.Subscript( + ast.Name("Optional"), + ast.Index( # type: ignore + value=ast.Name( + id=ast.unparse( + figure_out_return_type( + inspect.signature(event_specs).return_annotation + ) + if not isinstance( + event_specs := event_triggers[trigger], tuple + ) + else ast.Subscript( + ast.Name("Union"), + ast.Tuple( + [ + figure_out_return_type( + inspect.signature( + event_spec + ).return_annotation + ) + for event_spec in event_specs + ] + ), + ) + ) + ) + ), ), ), ast.Constant(value=None), diff --git a/reflex/utils/types.py b/reflex/utils/types.py index baedcc5a01..d58825ed55 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -774,3 +774,69 @@ def wrapper(*args, **kwargs): # Store this here for performance. StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar) + + +def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: + """Check if a type hint is a subclass of another type hint. + + Args: + possible_subclass: The type hint to check. + possible_superclass: The type hint to check against. + + Returns: + Whether the type hint is a subclass of the other type hint. + """ + if possible_superclass is Any: + return True + if possible_subclass is Any: + return False + + provided_type_origin = get_origin(possible_subclass) + accepted_type_origin = get_origin(possible_superclass) + + if provided_type_origin is None and accepted_type_origin is None: + # In this case, we are dealing with a non-generic type, so we can use issubclass + return issubclass(possible_subclass, possible_superclass) + + # Remove this check when Python 3.10 is the minimum supported version + if hasattr(types, "UnionType"): + provided_type_origin = ( + Union if provided_type_origin is types.UnionType else provided_type_origin + ) + accepted_type_origin = ( + Union if accepted_type_origin is types.UnionType else accepted_type_origin + ) + + # Get type arguments (e.g., [float, int] for Dict[float, int]) + provided_args = get_args(possible_subclass) + accepted_args = get_args(possible_superclass) + + if accepted_type_origin is Union: + if provided_type_origin is not Union: + return any( + typehint_issubclass(possible_subclass, accepted_arg) + for accepted_arg in accepted_args + ) + return all( + any( + typehint_issubclass(provided_arg, accepted_arg) + for accepted_arg in accepted_args + ) + for provided_arg in provided_args + ) + + # Check if the origin of both types is the same (e.g., list for List[int]) + # This probably should be issubclass instead of == + if (provided_type_origin or possible_subclass) != ( + accepted_type_origin or possible_superclass + ): + return False + + # Ensure all specific types are compatible with accepted types + # Note this is not necessarily correct, as it doesn't check against contravariance and covariance + # It also ignores when the length of the arguments is different + return all( + typehint_issubclass(provided_arg, accepted_arg) + for provided_arg, accepted_arg in zip(provided_args, accepted_args) + if accepted_arg is not Any + ) diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index c2d73aca52..a614fd7152 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -20,13 +20,17 @@ EventChain, EventHandler, empty_event, + identity_event, input_event, parse_args_spec, ) from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports -from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch +from reflex.utils.exceptions import ( + EventFnArgMismatch, + EventHandlerArgMismatch, +) from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -43,6 +47,18 @@ def do_something(self): def do_something_arg(self, arg): pass + def do_something_with_bool(self, arg: bool): + pass + + def do_something_with_int(self, arg: int): + pass + + def do_something_with_list_int(self, arg: list[int]): + pass + + def do_something_with_list_str(self, arg: list[str]): + pass + return TestState @@ -95,8 +111,10 @@ def get_event_triggers(self) -> Dict[str, Any]: """ return { **super().get_event_triggers(), - "on_open": lambda e0: [e0], - "on_close": lambda e0: [e0], + "on_open": identity_event(bool), + "on_close": identity_event(bool), + "on_user_visited_count_changed": identity_event(int), + "on_user_list_changed": identity_event(List[str]), } def _get_imports(self) -> ParsedImportDict: @@ -582,7 +600,14 @@ def test_get_event_triggers(component1, component2): assert component1().get_event_triggers().keys() == default_triggers assert ( component2().get_event_triggers().keys() - == {"on_open", "on_close", "on_prop_event"} | default_triggers + == { + "on_open", + "on_close", + "on_prop_event", + "on_user_visited_count_changed", + "on_user_list_changed", + } + | default_triggers ) @@ -903,6 +928,22 @@ def test_invalid_event_handler_args(component2, test_state): on_prop_event=[test_state.do_something_arg, test_state.do_something] ) + # Enable when 0.7.0 happens + # # Event Handler types must match + # with pytest.raises(EventHandlerArgTypeMismatch): + # component2.create( + # on_user_visited_count_changed=test_state.do_something_with_bool + # ) + # with pytest.raises(EventHandlerArgTypeMismatch): + # component2.create(on_user_list_changed=test_state.do_something_with_int) + # with pytest.raises(EventHandlerArgTypeMismatch): + # component2.create(on_user_list_changed=test_state.do_something_with_list_int) + + # component2.create(on_open=test_state.do_something_with_int) + # component2.create(on_open=test_state.do_something_with_bool) + # component2.create(on_user_visited_count_changed=test_state.do_something_with_int) + # component2.create(on_user_list_changed=test_state.do_something_with_list_str) + # lambda cannot return weird values. with pytest.raises(ValueError): component2.create(on_click=lambda: 1) diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 81579acc77..dd88138bfd 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -2,7 +2,7 @@ import typing from functools import cached_property from pathlib import Path -from typing import Any, ClassVar, List, Literal, Type, Union +from typing import Any, ClassVar, Dict, List, Literal, Type, Union import pytest import typer @@ -77,6 +77,47 @@ def test_is_generic_alias(cls: type, expected: bool): assert types.is_generic_alias(cls) == expected +@pytest.mark.parametrize( + ("subclass", "superclass", "expected"), + [ + *[ + (base_type, base_type, True) + for base_type in [int, float, str, bool, list, dict] + ], + *[ + (one_type, another_type, False) + for one_type in [int, float, str, list, dict] + for another_type in [int, float, str, list, dict] + if one_type != another_type + ], + (bool, int, True), + (int, bool, False), + (list, List, True), + (list, List[str], True), # this is wrong, but it's a limitation of the function + (List, list, True), + (List[int], list, True), + (List[int], List, True), + (List[int], List[str], False), + (List[int], List[int], True), + (List[int], List[float], False), + (List[int], List[Union[int, float]], True), + (List[int], List[Union[float, str]], False), + (Union[int, float], List[Union[int, float]], False), + (Union[int, float], Union[int, float, str], True), + (Union[int, float], Union[str, float], False), + (Dict[str, int], Dict[str, int], True), + (Dict[str, bool], Dict[str, int], True), + (Dict[str, int], Dict[str, bool], False), + (Dict[str, Any], dict[str, str], False), + (Dict[str, str], dict[str, str], True), + (Dict[str, str], dict[str, Any], True), + (Dict[str, Any], dict[str, Any], True), + ], +) +def test_typehint_issubclass(subclass, superclass, expected): + assert types.typehint_issubclass(subclass, superclass) == expected + + def test_validate_invalid_bun_path(mocker): """Test that an error is thrown when a custom specified bun path is not valid or does not exist.