Skip to content

Commit

Permalink
Fix typing for ElementFilter and User (#3381)
Browse files Browse the repository at this point in the history
* improve typing and naming

* fix typing

* reverted "kind" and type T of ElementFilter

* fix element filter type inference

* fix typing in __next__

* ignore typing where it can not be matched

* open page

* code review

---------

Co-authored-by: Falko Schindler <falko@zauberzeug.com>
  • Loading branch information
rodja and falkoschindler authored Jul 22, 2024
1 parent 0cd7d04 commit e8afcb5
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 49 deletions.
64 changes: 40 additions & 24 deletions nicegui/element_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Generic, Iterator, List, Optional, Type, TypeVar, Union
from typing import Generic, Iterator, List, Optional, Type, TypeVar, Union, overload

from typing_extensions import Self

Expand All @@ -12,9 +12,26 @@
T = TypeVar('T', bound=Element)


class ElementFilter(Generic[T], Iterator[T]):
class ElementFilter(Generic[T]):
DEFAULT_LOCAL_SCOPE = False

@overload
def __init__(self: ElementFilter[Element], *,
marker: Union[str, List[str], None] = None,
content: Union[str, List[str], None] = None,
local_scope: bool = DEFAULT_LOCAL_SCOPE,
) -> None:
...

@overload
def __init__(self, *,
kind: Type[T],
marker: Union[str, List[str], None] = None,
content: Union[str, List[str], None] = None,
local_scope: bool = DEFAULT_LOCAL_SCOPE,
) -> None:
...

def __init__(self, *,
kind: Optional[Type[T]] = None,
marker: Union[str, List[str], None] = None,
Expand All @@ -33,18 +50,22 @@ def __init__(self, *,
:param content: filter for elements which contain ``content`` in one of their content attributes like ``.text``, ``.value``, ``.source``, ...; can be a singe string or a list of strings which all must match
:param local_scope: if `True`, only elements within the current scope are returned; by default the whole page is searched (this default behavior can be changed with ``ElementFilter.DEFAULT_LOCAL_SCOPE = True``)
"""
self._kind = kind or Element
self._kind = kind
self._markers = marker.split() if isinstance(marker, str) else marker
self._contents = [content] if isinstance(content, str) else content
self._within_types: List[Type[Element]] = []
self._within_markers: List[str] = []

self._within_kinds: List[Type[Element]] = []
self._not_within_types: List[Type[Element]] = []
self._not_within_markers: List[str] = []
self._within_instances: List[Element] = []
self._within_markers: List[str] = []

self._not_within_kinds: List[Type[Element]] = []
self._not_within_instances: List[Element] = []
self._not_within_markers: List[str] = []

self._exclude_kinds: List[Type[Element]] = []
self._exclude_markers: List[str] = []
self._exclude_content: List[str] = []

self._scope = context.slot.parent if local_scope else context.client.layout

def __iter__(self) -> Iterator[T]:
Expand Down Expand Up @@ -72,28 +93,23 @@ def _iterate(self, parent: Element, *, visited: Optional[List[Element]] = None)
(self._kind is None or isinstance(element, self._kind)) and
(not self._markers or all(m in element._markers for m in self._markers)) and
(not self._contents or all(c in content for c in self._contents)) and
(not self._exclude_kinds or not any(isinstance(element, type_) for type_ in self._exclude_kinds)) and
(not self._exclude_kinds or not isinstance(element, tuple(self._exclude_kinds))) and
(not self._exclude_markers or not any(m in element._markers for m in self._exclude_markers)) and
(not self._exclude_content or (hasattr(element, 'text') and not any(text in element.text for text in self._exclude_content))) and
(not self._within_kinds or any(element in within_kind for within_kind in self._within_kinds))
(not self._exclude_content or not any(text in getattr(element, 'text', '') for text in self._exclude_content)) and
(not self._within_instances or any(element in instance for instance in self._within_instances))
):
if (
(not self._within_types or any(isinstance(element, type_) for type_ in self._within_types for element in visited)) and
(not self._within_kinds or any(isinstance(element, kind) for kind in self._within_kinds for element in visited)) and
(not self._within_markers or any(m in element._markers for m in self._within_markers for element in visited)) and
(not self._not_within_types or not any(isinstance(element, type_) for type_ in self._not_within_types for element in visited)) and
(not self._not_within_kinds or not any(isinstance(element, kinds) for kinds in self._not_within_kinds for element in visited)) and
(not self._not_within_markers or not any(m in element._markers
for m in self._not_within_markers
for element in visited))
):
yield element
if element not in self._not_within_kinds:
yield element # type: ignore
if element not in self._not_within_instances:
yield from self._iterate(element, visited=[*visited, element])

def __next__(self) -> T:
if self._iterator is None:
raise StopIteration
return next(self._iterator)

def __len__(self) -> int:
return len(list(iter(self)))

Expand All @@ -104,14 +120,14 @@ def within(self, *, kind: Optional[Type] = None, marker: Optional[str] = None, i
"""Filter elements which have a specific match in the parent hierarchy."""
if kind is not None:
assert issubclass(kind, Element)
self._within_types.append(kind)
self._within_kinds.append(kind)
if marker is not None:
self._within_markers.append(marker)
if instance is not None:
self._within_kinds.extend(instance if isinstance(instance, list) else [instance])
self._within_instances.extend(instance if isinstance(instance, list) else [instance])
return self

def exclude(self, *, kind: Optional[Element] = None, marker: Optional[str] = None, content: Optional[str] = None) -> Self:
def exclude(self, *, kind: Optional[Type[Element]] = None, marker: Optional[str] = None, content: Optional[str] = None) -> Self:
"""Exclude elements with specific element type, marker or content."""
if kind is not None:
assert issubclass(kind, Element)
Expand All @@ -126,11 +142,11 @@ def not_within(self, *, kind: Optional[Type] = None, marker: Optional[str] = Non
"""Exclude elements which have a parent of a specific type or marker."""
if kind is not None:
assert issubclass(kind, Element)
self._not_within_types.append(kind)
self._not_within_kinds.append(kind)
if marker is not None:
self._not_within_markers.append(marker)
if instance is not None:
self._not_within_kinds.extend(instance if isinstance(instance, list) else [instance])
self._not_within_instances.extend(instance if isinstance(instance, list) else [instance])
return self

def classes(self, add: Optional[str] = None, *, remove: Optional[str] = None, replace: Optional[str] = None) -> Self:
Expand Down
28 changes: 22 additions & 6 deletions nicegui/testing/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,31 @@ async def should_not_see(self,

@overload
def find(self,
target: Union[str, Type[T]],
) -> UserInteraction:
target: str,
) -> UserInteraction[Element]:
...

@overload
def find(self,
target: Type[T],
) -> UserInteraction[T]:
...

@overload
def find(self: User,
*,
marker: Union[str, list[str], None] = None,
content: Union[str, list[str], None] = None,
) -> UserInteraction[Element]:
...

@overload
def find(self,
*,
kind: Type[T] = Element,
kind: Type[T],
marker: Union[str, list[str], None] = None,
content: Union[str, list[str], None] = None,
) -> UserInteraction:
) -> UserInteraction[T]:
...

def find(self,
Expand All @@ -154,7 +168,7 @@ def find(self,
kind: Optional[Type[T]] = None,
marker: Union[str, list[str], None] = None,
content: Union[str, list[str], None] = None,
) -> UserInteraction:
) -> UserInteraction[T]:
"""Select elements for interaction."""
assert self.client
with self.client:
Expand All @@ -177,9 +191,11 @@ def _gather_elements(self,
content: Union[str, list[str], None] = None,
) -> Set[T]:
if target is None:
if kind is None:
return set(ElementFilter(marker=marker, content=content)) # type: ignore
return set(ElementFilter(kind=kind, marker=marker, content=content))
elif isinstance(target, str):
return set(ElementFilter(marker=target)).union(ElementFilter(content=target))
return set(ElementFilter(marker=target)).union(ElementFilter(content=target)) # type: ignore
else:
return set(ElementFilter(kind=target))

Expand Down
4 changes: 2 additions & 2 deletions nicegui/testing/user_interaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Set, TypeVar
from typing import TYPE_CHECKING, Generic, Set, TypeVar

from typing_extensions import Self

Expand All @@ -13,7 +13,7 @@
T = TypeVar('T', bound=Element)


class UserInteraction:
class UserInteraction(Generic[T]):

def __init__(self, user: User, elements: Set[T]) -> None:
"""Iteraction object of the simulated user.
Expand Down
47 changes: 30 additions & 17 deletions tests/test_element_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@

pytestmark = pytest.mark.usefixtures('user')

# pylint: disable=missing-function-docstring

def test_find_all():

def test_find_all() -> None:
ui.button('button A')
ui.label('label A')
with ui.row():
ui.button('button B')
ui.label('label B')

elements: List[ui.element] = list(ElementFilter())

assert len(elements) == 8
assert elements[0].tag == 'q-page-container'
assert elements[1].tag == 'q-page'
Expand All @@ -41,7 +42,7 @@ def test_find_by_text_element():
assert result == ['button A', 'label A', 'button B', 'label B']


def test_find_by_type():
def test_find_by_kind():
ui.button('button A')
ui.label('label A')
ui.button('button B')
Expand Down Expand Up @@ -106,18 +107,6 @@ def test_find_by_multiple_markers():
assert result == ['button B', 'button C']


def test_find_within_type():
ui.button('button A')
ui.label('label A')
with ui.row():
ui.button('button B')
ui.label('label B')

result = [element.text for element in ElementFilter(kind=ui.button).within(kind=ui.row)]

assert result == ['button B']


def test_find_within_marker():
ui.button('button A')
ui.label('label A')
Expand Down Expand Up @@ -155,7 +144,19 @@ def test_find_within_elements():
assert result == ['button A', 'label A', 'button B', 'label B']


def test_find_with_excluding_type():
def test_find_within_kind():
ui.button('button A')
with ui.row():
ui.label('label A')
ui.button('button B')
ui.label('label B')

result = [element.text for element in ElementFilter(content='B').within(kind=ui.row)]

assert result == ['button B', 'label B']


def test_find_with_excluding_kind():
ui.button('button A')
ui.label('label A')
ui.button('button B')
Expand Down Expand Up @@ -188,7 +189,7 @@ def test_find_with_excluding_text():
assert result == ['button B']


def test_find_not_within_type():
def test_find_not_within_kind():
ui.button('button A')
ui.label('label A')
with ui.row():
Expand Down Expand Up @@ -266,3 +267,15 @@ async def test_setting_props(user: User):
await user.open('/')
for button in user.find('button').elements:
assert button._props['flat'] # pylint: disable=protected-access


async def test_typing(user: User):
ui.button('button A')
ui.label('label A')

await user.open('/')
# NOTE we have not yet found a way to test the typing suggestions automatically
# to test, hover over the variable and verify that your IDE infers the correct type
_ = ElementFilter(kind=ui.button) # ElementFilter[ui.button]
_ = ElementFilter(kind=ui.label) # ElementFilter[ui.label]
_ = ElementFilter() # ElementFilter[Element]
16 changes: 16 additions & 0 deletions tests/test_user_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,19 @@ def page():
Card
Image [src=https://via.placehol...]
'''.strip()


async def test_typing(user: User) -> None:
@ui.page('/')
def page():
ui.label('Hello!')
ui.button('World!')

await user.open('/')
# NOTE we have not yet found a way to test the typing suggestions automatically
# to test, hover over the variable and verify that your IDE inferres the correct type
_ = user.find(kind=ui.label).elements # Set[ui.label]
_ = user.find(ui.label).elements # Set[ui.label]
_ = user.find('World').elements # Set[ui.element]
_ = user.find('Hello').elements # Set[ui.element]
_ = user.find('!').elements # Set[ui.element]

0 comments on commit e8afcb5

Please sign in to comment.