diff --git a/nicegui/element_filter.py b/nicegui/element_filter.py index 3e5449909..91f4d537c 100644 --- a/nicegui/element_filter.py +++ b/nicegui/element_filter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, Iterator, List, Optional, Type, TypeVar, Union +from typing import Generic, Iterator, List, Optional, Type, TypeVar, Union, overload from typing_extensions import Self @@ -12,9 +12,26 @@ T = TypeVar('T', bound=Element) -class ElementFilter(Generic[T], Iterator[T]): +class ElementFilter(Generic[T]): DEFAULT_LOCAL_SCOPE = False + @overload + def __init__(self: ElementFilter[Element], *, + marker: Union[str, List[str], None] = None, + content: Union[str, List[str], None] = None, + local_scope: bool = DEFAULT_LOCAL_SCOPE, + ) -> None: + ... + + @overload + def __init__(self, *, + kind: Type[T], + marker: Union[str, List[str], None] = None, + content: Union[str, List[str], None] = None, + local_scope: bool = DEFAULT_LOCAL_SCOPE, + ) -> None: + ... + def __init__(self, *, kind: Optional[Type[T]] = None, marker: Union[str, List[str], None] = None, @@ -33,18 +50,22 @@ def __init__(self, *, :param content: filter for elements which contain ``content`` in one of their content attributes like ``.text``, ``.value``, ``.source``, ...; can be a singe string or a list of strings which all must match :param local_scope: if `True`, only elements within the current scope are returned; by default the whole page is searched (this default behavior can be changed with ``ElementFilter.DEFAULT_LOCAL_SCOPE = True``) """ - self._kind = kind or Element + self._kind = kind self._markers = marker.split() if isinstance(marker, str) else marker self._contents = [content] if isinstance(content, str) else content - self._within_types: List[Type[Element]] = [] - self._within_markers: List[str] = [] + self._within_kinds: List[Type[Element]] = [] - self._not_within_types: List[Type[Element]] = [] - self._not_within_markers: List[str] = [] + self._within_instances: List[Element] = [] + self._within_markers: List[str] = [] + self._not_within_kinds: List[Type[Element]] = [] + self._not_within_instances: List[Element] = [] + self._not_within_markers: List[str] = [] + self._exclude_kinds: List[Type[Element]] = [] self._exclude_markers: List[str] = [] self._exclude_content: List[str] = [] + self._scope = context.slot.parent if local_scope else context.client.layout def __iter__(self) -> Iterator[T]: @@ -72,28 +93,23 @@ def _iterate(self, parent: Element, *, visited: Optional[List[Element]] = None) (self._kind is None or isinstance(element, self._kind)) and (not self._markers or all(m in element._markers for m in self._markers)) and (not self._contents or all(c in content for c in self._contents)) and - (not self._exclude_kinds or not any(isinstance(element, type_) for type_ in self._exclude_kinds)) and + (not self._exclude_kinds or not isinstance(element, tuple(self._exclude_kinds))) and (not self._exclude_markers or not any(m in element._markers for m in self._exclude_markers)) and - (not self._exclude_content or (hasattr(element, 'text') and not any(text in element.text for text in self._exclude_content))) and - (not self._within_kinds or any(element in within_kind for within_kind in self._within_kinds)) + (not self._exclude_content or not any(text in getattr(element, 'text', '') for text in self._exclude_content)) and + (not self._within_instances or any(element in instance for instance in self._within_instances)) ): if ( - (not self._within_types or any(isinstance(element, type_) for type_ in self._within_types for element in visited)) and + (not self._within_kinds or any(isinstance(element, kind) for kind in self._within_kinds for element in visited)) and (not self._within_markers or any(m in element._markers for m in self._within_markers for element in visited)) and - (not self._not_within_types or not any(isinstance(element, type_) for type_ in self._not_within_types for element in visited)) and + (not self._not_within_kinds or not any(isinstance(element, kinds) for kinds in self._not_within_kinds for element in visited)) and (not self._not_within_markers or not any(m in element._markers for m in self._not_within_markers for element in visited)) ): - yield element - if element not in self._not_within_kinds: + yield element # type: ignore + if element not in self._not_within_instances: yield from self._iterate(element, visited=[*visited, element]) - def __next__(self) -> T: - if self._iterator is None: - raise StopIteration - return next(self._iterator) - def __len__(self) -> int: return len(list(iter(self))) @@ -104,14 +120,14 @@ def within(self, *, kind: Optional[Type] = None, marker: Optional[str] = None, i """Filter elements which have a specific match in the parent hierarchy.""" if kind is not None: assert issubclass(kind, Element) - self._within_types.append(kind) + self._within_kinds.append(kind) if marker is not None: self._within_markers.append(marker) if instance is not None: - self._within_kinds.extend(instance if isinstance(instance, list) else [instance]) + self._within_instances.extend(instance if isinstance(instance, list) else [instance]) return self - def exclude(self, *, kind: Optional[Element] = None, marker: Optional[str] = None, content: Optional[str] = None) -> Self: + def exclude(self, *, kind: Optional[Type[Element]] = None, marker: Optional[str] = None, content: Optional[str] = None) -> Self: """Exclude elements with specific element type, marker or content.""" if kind is not None: assert issubclass(kind, Element) @@ -126,11 +142,11 @@ def not_within(self, *, kind: Optional[Type] = None, marker: Optional[str] = Non """Exclude elements which have a parent of a specific type or marker.""" if kind is not None: assert issubclass(kind, Element) - self._not_within_types.append(kind) + self._not_within_kinds.append(kind) if marker is not None: self._not_within_markers.append(marker) if instance is not None: - self._not_within_kinds.extend(instance if isinstance(instance, list) else [instance]) + self._not_within_instances.extend(instance if isinstance(instance, list) else [instance]) return self def classes(self, add: Optional[str] = None, *, remove: Optional[str] = None, replace: Optional[str] = None) -> Self: diff --git a/nicegui/testing/user.py b/nicegui/testing/user.py index 3ca3e1013..5985c49f7 100644 --- a/nicegui/testing/user.py +++ b/nicegui/testing/user.py @@ -135,17 +135,31 @@ async def should_not_see(self, @overload def find(self, - target: Union[str, Type[T]], - ) -> UserInteraction: + target: str, + ) -> UserInteraction[Element]: + ... + + @overload + def find(self, + target: Type[T], + ) -> UserInteraction[T]: + ... + + @overload + def find(self: User, + *, + marker: Union[str, list[str], None] = None, + content: Union[str, list[str], None] = None, + ) -> UserInteraction[Element]: ... @overload def find(self, *, - kind: Type[T] = Element, + kind: Type[T], marker: Union[str, list[str], None] = None, content: Union[str, list[str], None] = None, - ) -> UserInteraction: + ) -> UserInteraction[T]: ... def find(self, @@ -154,7 +168,7 @@ def find(self, kind: Optional[Type[T]] = None, marker: Union[str, list[str], None] = None, content: Union[str, list[str], None] = None, - ) -> UserInteraction: + ) -> UserInteraction[T]: """Select elements for interaction.""" assert self.client with self.client: @@ -177,9 +191,11 @@ def _gather_elements(self, content: Union[str, list[str], None] = None, ) -> Set[T]: if target is None: + if kind is None: + return set(ElementFilter(marker=marker, content=content)) # type: ignore return set(ElementFilter(kind=kind, marker=marker, content=content)) elif isinstance(target, str): - return set(ElementFilter(marker=target)).union(ElementFilter(content=target)) + return set(ElementFilter(marker=target)).union(ElementFilter(content=target)) # type: ignore else: return set(ElementFilter(kind=target)) diff --git a/nicegui/testing/user_interaction.py b/nicegui/testing/user_interaction.py index b92b794ac..7d4b92f75 100644 --- a/nicegui/testing/user_interaction.py +++ b/nicegui/testing/user_interaction.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Set, TypeVar +from typing import TYPE_CHECKING, Generic, Set, TypeVar from typing_extensions import Self @@ -13,7 +13,7 @@ T = TypeVar('T', bound=Element) -class UserInteraction: +class UserInteraction(Generic[T]): def __init__(self, user: User, elements: Set[T]) -> None: """Iteraction object of the simulated user. diff --git a/tests/test_element_filter.py b/tests/test_element_filter.py index 24ac1560e..b86e3dfe2 100644 --- a/tests/test_element_filter.py +++ b/tests/test_element_filter.py @@ -8,8 +8,10 @@ pytestmark = pytest.mark.usefixtures('user') +# pylint: disable=missing-function-docstring -def test_find_all(): + +def test_find_all() -> None: ui.button('button A') ui.label('label A') with ui.row(): @@ -17,7 +19,6 @@ def test_find_all(): ui.label('label B') elements: List[ui.element] = list(ElementFilter()) - assert len(elements) == 8 assert elements[0].tag == 'q-page-container' assert elements[1].tag == 'q-page' @@ -41,7 +42,7 @@ def test_find_by_text_element(): assert result == ['button A', 'label A', 'button B', 'label B'] -def test_find_by_type(): +def test_find_by_kind(): ui.button('button A') ui.label('label A') ui.button('button B') @@ -106,18 +107,6 @@ def test_find_by_multiple_markers(): assert result == ['button B', 'button C'] -def test_find_within_type(): - ui.button('button A') - ui.label('label A') - with ui.row(): - ui.button('button B') - ui.label('label B') - - result = [element.text for element in ElementFilter(kind=ui.button).within(kind=ui.row)] - - assert result == ['button B'] - - def test_find_within_marker(): ui.button('button A') ui.label('label A') @@ -155,7 +144,19 @@ def test_find_within_elements(): assert result == ['button A', 'label A', 'button B', 'label B'] -def test_find_with_excluding_type(): +def test_find_within_kind(): + ui.button('button A') + with ui.row(): + ui.label('label A') + ui.button('button B') + ui.label('label B') + + result = [element.text for element in ElementFilter(content='B').within(kind=ui.row)] + + assert result == ['button B', 'label B'] + + +def test_find_with_excluding_kind(): ui.button('button A') ui.label('label A') ui.button('button B') @@ -188,7 +189,7 @@ def test_find_with_excluding_text(): assert result == ['button B'] -def test_find_not_within_type(): +def test_find_not_within_kind(): ui.button('button A') ui.label('label A') with ui.row(): @@ -266,3 +267,15 @@ async def test_setting_props(user: User): await user.open('/') for button in user.find('button').elements: assert button._props['flat'] # pylint: disable=protected-access + + +async def test_typing(user: User): + ui.button('button A') + ui.label('label A') + + await user.open('/') + # NOTE we have not yet found a way to test the typing suggestions automatically + # to test, hover over the variable and verify that your IDE infers the correct type + _ = ElementFilter(kind=ui.button) # ElementFilter[ui.button] + _ = ElementFilter(kind=ui.label) # ElementFilter[ui.label] + _ = ElementFilter() # ElementFilter[Element] diff --git a/tests/test_user_simulation.py b/tests/test_user_simulation.py index 4c81ba19f..57e8ddb3a 100644 --- a/tests/test_user_simulation.py +++ b/tests/test_user_simulation.py @@ -236,3 +236,19 @@ def page(): Card Image [src=https://via.placehol...] '''.strip() + + +async def test_typing(user: User) -> None: + @ui.page('/') + def page(): + ui.label('Hello!') + ui.button('World!') + + await user.open('/') + # NOTE we have not yet found a way to test the typing suggestions automatically + # to test, hover over the variable and verify that your IDE inferres the correct type + _ = user.find(kind=ui.label).elements # Set[ui.label] + _ = user.find(ui.label).elements # Set[ui.label] + _ = user.find('World').elements # Set[ui.element] + _ = user.find('Hello').elements # Set[ui.element] + _ = user.find('!').elements # Set[ui.element]