Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow for event handlers to ignore args #4282

Merged
merged 6 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Iterator,
List,
Optional,
Sequence,
Set,
Type,
Union,
Expand All @@ -38,6 +39,7 @@
PageNames,
)
from reflex.constants.compiler import SpecialAttributes
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.event import (
EventCallback,
EventChain,
Expand Down Expand Up @@ -533,7 +535,7 @@ def __init__(self, *args, **kwargs):

def _create_event_chain(
self,
args_spec: Any,
args_spec: types.ArgsSpec | Sequence[types.ArgsSpec],
value: Union[
Var,
EventHandler,
Expand Down Expand Up @@ -599,7 +601,7 @@ def _create_event_chain(

# If the input is a callable, create an event chain.
elif isinstance(value, Callable):
result = call_event_fn(value, args_spec)
result = call_event_fn(value, args_spec, key=key)
if isinstance(result, Var):
# Recursively call this function if the lambda returned an EventChain Var.
return self._create_event_chain(args_spec, result, key=key)
Expand Down Expand Up @@ -629,14 +631,16 @@ def _create_event_chain(
event_actions={},
)

def get_event_triggers(self) -> Dict[str, Any]:
def get_event_triggers(
self,
) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]:
"""Get the event triggers for the component.

Returns:
The event triggers.

"""
default_triggers = {
default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = {
EventTriggers.ON_FOCUS: no_args_event_spec,
EventTriggers.ON_BLUR: no_args_event_spec,
EventTriggers.ON_CLICK: no_args_event_spec,
Expand Down Expand Up @@ -1142,7 +1146,10 @@ def _event_trigger_values_use_state(self) -> bool:
if isinstance(event, EventCallback):
continue
if isinstance(event, EventSpec):
if event.handler.state_full_name:
if (
event.handler.state_full_name
and event.handler.state_full_name != FRONTEND_EVENT_STATE
):
return True
else:
if event._var_state:
Expand Down
4 changes: 4 additions & 0 deletions reflex/constants/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class StateManagerMode(str, Enum):
DISK = "disk"
MEMORY = "memory"
REDIS = "redis"


# Used for things like console_log, etc.
FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state"
137 changes: 73 additions & 64 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from typing_extensions import ParamSpec, Protocol, get_args, get_origin

from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import console, format
from reflex.utils.exceptions import (
EventFnArgMismatch,
EventHandlerArgMismatch,
EventHandlerArgTypeMismatch,
)
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
Expand Down Expand Up @@ -662,7 +662,7 @@ def fn():
fn.__qualname__ = name
fn.__signature__ = sig
return EventSpec(
handler=EventHandler(fn=fn),
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
args=tuple(
(
Var(_js_expr=k),
Expand Down Expand Up @@ -1092,8 +1092,8 @@ def get_hydrate_event(state) -> str:


def call_event_handler(
event_handler: EventHandler | EventSpec,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
event_callback: EventHandler | EventSpec,
event_spec: ArgsSpec | Sequence[ArgsSpec],
adhami3310 marked this conversation as resolved.
Show resolved Hide resolved
key: Optional[str] = None,
) -> EventSpec:
"""Call an event handler to get the event spec.
Expand All @@ -1103,53 +1103,57 @@ def call_event_handler(
Otherwise, the event handler will be called with no args.

Args:
event_handler: The event handler.
arg_spec: The lambda that define the argument(s) to pass to the event handler.
event_callback: The event handler.
event_spec: The lambda that define the argument(s) to pass to the event handler.
key: The key to pass to the event handler.

Raises:
EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec.

Returns:
The event spec from calling the event handler.

# noqa: DAR401 failure

"""
parsed_args = parse_args_spec(arg_spec) # type: ignore

if isinstance(event_handler, EventSpec):
# Handle partial application of EventSpec args
return event_handler.add_args(*parsed_args)

provided_callback_fullspec = inspect.getfullargspec(event_handler.fn)

provided_callback_n_args = (
len(provided_callback_fullspec.args) - 1
) # subtract 1 for bound self arg

if provided_callback_n_args != len(parsed_args):
raise EventHandlerArgMismatch(
"The number of arguments accepted by "
f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
event_spec_args = parse_args_spec(event_spec) # type: ignore

if isinstance(event_callback, EventSpec):
check_fn_match_arg_spec(
event_callback.handler.fn,
event_spec,
key,
bool(event_callback.handler.state_full_name) + len(event_callback.args),
event_callback.handler.fn.__qualname__,
)
# Handle partial application of EventSpec args
return event_callback.add_args(*event_spec_args)

check_fn_match_arg_spec(
event_callback.fn,
event_spec,
key,
bool(event_callback.state_full_name),
event_callback.fn.__qualname__,
)

all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
all_acceptable_specs = (
[event_spec] if not isinstance(event_spec, Sequence) else event_spec
)

event_spec_return_types = list(
filter(
lambda event_spec_return_type: event_spec_return_type is not None
and get_origin(event_spec_return_type) is tuple,
(get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec),
(
get_type_hints(arg_spec).get("return", None)
for arg_spec in all_acceptable_specs
),
)
)

if event_spec_return_types:
failures = []

event_callback_spec = inspect.getfullargspec(event_callback.fn)

for event_spec_index, event_spec_return_type in enumerate(
event_spec_return_types
):
Expand All @@ -1160,14 +1164,14 @@ def call_event_handler(
]

try:
type_hints_of_provided_callback = get_type_hints(event_handler.fn)
type_hints_of_provided_callback = get_type_hints(event_callback.fn)
except NameError:
type_hints_of_provided_callback = {}

failed_type_check = False

# check that args of event handler are matching the spec if type hints are provided
for i, arg in enumerate(provided_callback_fullspec.args[1:]):
for i, arg in enumerate(event_callback_spec.args[1:]):
if arg not in type_hints_of_provided_callback:
continue

Expand All @@ -1181,15 +1185,15 @@ def call_event_handler(
# f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
# ) from e
console.warn(
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}."
f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}."
)
compare_result = False

if compare_result:
continue
else:
failure = EventHandlerArgTypeMismatch(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead."
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
)
failures.append(failure)
failed_type_check = True
Expand All @@ -1210,14 +1214,14 @@ def call_event_handler(

given_string = ", ".join(
repr(type_hints_of_provided_callback.get(arg, Any))
for arg in provided_callback_fullspec.args[1:]
for arg in event_callback_spec.args[1:]
).replace("[", "\\[")

console.warn(
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. "
f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. "
f"This may lead to unexpected behavior but is intentionally ignored for {key}."
)
return event_handler(*parsed_args)
return event_callback(*event_spec_args)

if failures:
console.deprecate(
Expand All @@ -1227,7 +1231,7 @@ def call_event_handler(
"0.7.0",
)

return event_handler(*parsed_args) # type: ignore
return event_callback(*event_spec_args) # type: ignore


def unwrap_var_annotation(annotation: GenericType):
Expand Down Expand Up @@ -1294,45 +1298,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):


def check_fn_match_arg_spec(
fn: Callable,
arg_spec: ArgsSpec,
key: Optional[str] = None,
) -> List[Var]:
user_func: Callable,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: str | None = None,
number_of_bound_args: int = 0,
func_name: str | None = None,
):
"""Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not.

Args:
fn: The function to be validated.
user_func: The function to be validated.
arg_spec: The argument specification for the event trigger.
key: The key to pass to the event handler.

Returns:
The parsed arguments from the argument specification.
key: The key of the event trigger.
number_of_bound_args: The number of bound arguments to the function.
func_name: The name of the function to be validated.

Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
"""
fn_args = inspect.getfullargspec(fn).args
fn_defaults_args = inspect.getfullargspec(fn).defaults
n_fn_args = len(fn_args)
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
user_args = inspect.getfullargspec(user_func).args
user_default_args = inspect.getfullargspec(user_func).defaults
number_of_user_args = len(user_args) - number_of_bound_args
number_of_user_default_args = len(user_default_args) if user_default_args else 0

parsed_event_args = parse_args_spec(arg_spec)

number_of_event_args = len(parsed_event_args)

if number_of_user_args - number_of_user_default_args > number_of_event_args:
raise EventFnArgMismatch(
"The number of mandatory arguments accepted by "
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
f"Event {key} only provides {number_of_event_args} arguments, but "
f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
"arguments to be passed to the event handler.\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
return parsed_args


def call_event_fn(
fn: Callable,
arg_spec: ArgsSpec,
arg_spec: ArgsSpec | Sequence[ArgsSpec],
key: Optional[str] = None,
) -> list[EventSpec] | Var:
"""Call a function to a list of event specs.
Expand All @@ -1356,10 +1361,14 @@ def call_event_fn(
from reflex.utils.exceptions import EventHandlerValueError

# Check that fn signature matches arg_spec
parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key)
check_fn_match_arg_spec(fn, arg_spec, key=key)

parsed_args = parse_args_spec(arg_spec)

number_of_fn_args = len(inspect.getfullargspec(fn).args)

# Call the function with the parsed args.
out = fn(*parsed_args)
out = fn(*[*parsed_args][:number_of_fn_args])

# If the function returns a Var, assume it's an EventChain and render it directly.
if isinstance(out, Var):
Expand Down Expand Up @@ -1478,7 +1487,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
"""
signature = inspect.signature(fn)
new_param = inspect.Parameter(
"state", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any
)
return signature.replace(parameters=(new_param, *signature.parameters.values()))

Expand Down
6 changes: 1 addition & 5 deletions reflex/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError):
"""Raised when the return types of match cases are different."""


class EventHandlerArgMismatch(ReflexError, TypeError):
"""Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger."""


class EventHandlerArgTypeMismatch(ReflexError, TypeError):
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""


class EventFnArgMismatch(ReflexError, TypeError):
"""Raised when the number of args accepted by a lambda differs from that provided by the event trigger."""
"""Raised when the number of args required by an event handler is more than provided by the event trigger."""


class DynamicRouteArgShadowsStateVar(ReflexError, NameError):
Expand Down
3 changes: 2 additions & 1 deletion reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import TYPE_CHECKING, Any, List, Optional, Union

from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import exceptions
from reflex.utils.console import deprecate

Expand Down Expand Up @@ -439,7 +440,7 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]:

from reflex.state import State

if state_full_name == "state" and name not in State.__dict__:
if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__:
return ("", to_snake_case(handler.fn.__qualname__))

return (state_full_name, name)
Expand Down
4 changes: 2 additions & 2 deletions reflex/utils/pyi_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from multiprocessing import Pool, cpu_count
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Type, get_args, get_origin
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin

from reflex.components.component import Component
from reflex.utils import types as rx_types
Expand Down Expand Up @@ -560,7 +560,7 @@ def figure_out_return_type(annotation: Any):
inspect.signature(event_specs).return_annotation
)
if not isinstance(
event_specs := event_triggers[trigger], tuple
event_specs := event_triggers[trigger], Sequence
)
else ast.Subscript(
ast.Name("Union"),
Expand Down
Loading
Loading