Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing for ElementFilter and User #3381

Merged
merged 8 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions nicegui/element_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Generic, Iterator, List, Optional, Type, TypeVar, Union
from typing import Generic, Iterator, List, Optional, Type, TypeVar, Union, overload

from typing_extensions import Self

Expand All @@ -12,9 +12,26 @@
T = TypeVar('T', bound=Element)


class ElementFilter(Generic[T], Iterator[T]):
class ElementFilter(Generic[T]):
DEFAULT_LOCAL_SCOPE = False

@overload
def __init__(self: ElementFilter[Element], *,
marker: Union[str, List[str], None] = None,
content: Union[str, List[str], None] = None,
local_scope: bool = DEFAULT_LOCAL_SCOPE,
) -> None:
...

@overload
def __init__(self, *,
kind: Type[T],
marker: Union[str, List[str], None] = None,
content: Union[str, List[str], None] = None,
local_scope: bool = DEFAULT_LOCAL_SCOPE,
) -> None:
...

def __init__(self, *,
kind: Optional[Type[T]] = None,
marker: Union[str, List[str], None] = None,
Expand All @@ -33,18 +50,22 @@ def __init__(self, *,
:param content: filter for elements which contain ``content`` in one of their content attributes like ``.text``, ``.value``, ``.source``, ...; can be a singe string or a list of strings which all must match
:param local_scope: if `True`, only elements within the current scope are returned; by default the whole page is searched (this default behavior can be changed with ``ElementFilter.DEFAULT_LOCAL_SCOPE = True``)
"""
self._kind = kind or Element
self._kind = kind
self._markers = marker.split() if isinstance(marker, str) else marker
self._contents = [content] if isinstance(content, str) else content
self._within_types: List[Type[Element]] = []
self._within_markers: List[str] = []

self._within_kinds: List[Type[Element]] = []
self._not_within_types: List[Type[Element]] = []
self._not_within_markers: List[str] = []
self._within_instances: List[Element] = []
self._within_markers: List[str] = []

self._not_within_kinds: List[Type[Element]] = []
self._not_within_instances: List[Element] = []
self._not_within_markers: List[str] = []

self._exclude_kinds: List[Type[Element]] = []
self._exclude_markers: List[str] = []
self._exclude_content: List[str] = []

self._scope = context.slot.parent if local_scope else context.client.layout

def __iter__(self) -> Iterator[T]:
Expand Down Expand Up @@ -72,28 +93,23 @@ def _iterate(self, parent: Element, *, visited: Optional[List[Element]] = None)
(self._kind is None or isinstance(element, self._kind)) and
(not self._markers or all(m in element._markers for m in self._markers)) and
(not self._contents or all(c in content for c in self._contents)) and
(not self._exclude_kinds or not any(isinstance(element, type_) for type_ in self._exclude_kinds)) and
(not self._exclude_kinds or not isinstance(element, tuple(self._exclude_kinds))) and
(not self._exclude_markers or not any(m in element._markers for m in self._exclude_markers)) and
(not self._exclude_content or (hasattr(element, 'text') and not any(text in element.text for text in self._exclude_content))) and
(not self._within_kinds or any(element in within_kind for within_kind in self._within_kinds))
(not self._exclude_content or not any(text in getattr(element, 'text', '') for text in self._exclude_content)) and
(not self._within_instances or any(element in instance for instance in self._within_instances))
):
if (
(not self._within_types or any(isinstance(element, type_) for type_ in self._within_types for element in visited)) and
(not self._within_kinds or any(isinstance(element, kind) for kind in self._within_kinds for element in visited)) and
(not self._within_markers or any(m in element._markers for m in self._within_markers for element in visited)) and
(not self._not_within_types or not any(isinstance(element, type_) for type_ in self._not_within_types for element in visited)) and
(not self._not_within_kinds or not any(isinstance(element, kinds) for kinds in self._not_within_kinds for element in visited)) and
(not self._not_within_markers or not any(m in element._markers
for m in self._not_within_markers
for element in visited))
):
yield element
if element not in self._not_within_kinds:
yield element # type: ignore
if element not in self._not_within_instances:
yield from self._iterate(element, visited=[*visited, element])

def __next__(self) -> T:
if self._iterator is None:
raise StopIteration
return next(self._iterator)

def __len__(self) -> int:
return len(list(iter(self)))

Expand All @@ -104,14 +120,14 @@ def within(self, *, kind: Optional[Type] = None, marker: Optional[str] = None, i
"""Filter elements which have a specific match in the parent hierarchy."""
if kind is not None:
assert issubclass(kind, Element)
self._within_types.append(kind)
self._within_kinds.append(kind)
if marker is not None:
self._within_markers.append(marker)
if instance is not None:
self._within_kinds.extend(instance if isinstance(instance, list) else [instance])
self._within_instances.extend(instance if isinstance(instance, list) else [instance])
return self

def exclude(self, *, kind: Optional[Element] = None, marker: Optional[str] = None, content: Optional[str] = None) -> Self:
def exclude(self, *, kind: Optional[Type[Element]] = None, marker: Optional[str] = None, content: Optional[str] = None) -> Self:
"""Exclude elements with specific element type, marker or content."""
if kind is not None:
assert issubclass(kind, Element)
Expand All @@ -126,11 +142,11 @@ def not_within(self, *, kind: Optional[Type] = None, marker: Optional[str] = Non
"""Exclude elements which have a parent of a specific type or marker."""
if kind is not None:
assert issubclass(kind, Element)
self._not_within_types.append(kind)
self._not_within_kinds.append(kind)
if marker is not None:
self._not_within_markers.append(marker)
if instance is not None:
self._not_within_kinds.extend(instance if isinstance(instance, list) else [instance])
self._not_within_instances.extend(instance if isinstance(instance, list) else [instance])
return self

def classes(self, add: Optional[str] = None, *, remove: Optional[str] = None, replace: Optional[str] = None) -> Self:
Expand Down
28 changes: 22 additions & 6 deletions nicegui/testing/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,31 @@ async def should_not_see(self,

@overload
def find(self,
target: Union[str, Type[T]],
) -> UserInteraction:
target: str,
) -> UserInteraction[Element]:
...

@overload
def find(self,
target: Type[T],
) -> UserInteraction[T]:
...

@overload
def find(self: User,
*,
marker: Union[str, list[str], None] = None,
content: Union[str, list[str], None] = None,
) -> UserInteraction[Element]:
...

@overload
def find(self,
*,
kind: Type[T] = Element,
kind: Type[T],
marker: Union[str, list[str], None] = None,
content: Union[str, list[str], None] = None,
) -> UserInteraction:
) -> UserInteraction[T]:
...

def find(self,
Expand All @@ -154,7 +168,7 @@ def find(self,
kind: Optional[Type[T]] = None,
marker: Union[str, list[str], None] = None,
content: Union[str, list[str], None] = None,
) -> UserInteraction:
) -> UserInteraction[T]:
"""Select elements for interaction."""
assert self.client
with self.client:
Expand All @@ -177,9 +191,11 @@ def _gather_elements(self,
content: Union[str, list[str], None] = None,
) -> Set[T]:
if target is None:
if kind is None:
return set(ElementFilter(marker=marker, content=content)) # type: ignore
return set(ElementFilter(kind=kind, marker=marker, content=content))
elif isinstance(target, str):
return set(ElementFilter(marker=target)).union(ElementFilter(content=target))
return set(ElementFilter(marker=target)).union(ElementFilter(content=target)) # type: ignore
else:
return set(ElementFilter(kind=target))

Expand Down
4 changes: 2 additions & 2 deletions nicegui/testing/user_interaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Set, TypeVar
from typing import TYPE_CHECKING, Generic, Set, TypeVar

from typing_extensions import Self

Expand All @@ -13,7 +13,7 @@
T = TypeVar('T', bound=Element)


class UserInteraction:
class UserInteraction(Generic[T]):

def __init__(self, user: User, elements: Set[T]) -> None:
"""Iteraction object of the simulated user.
Expand Down
47 changes: 30 additions & 17 deletions tests/test_element_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@

pytestmark = pytest.mark.usefixtures('user')

# pylint: disable=missing-function-docstring

def test_find_all():

def test_find_all() -> None:
ui.button('button A')
ui.label('label A')
with ui.row():
ui.button('button B')
ui.label('label B')

elements: List[ui.element] = list(ElementFilter())

assert len(elements) == 8
assert elements[0].tag == 'q-page-container'
assert elements[1].tag == 'q-page'
Expand All @@ -41,7 +42,7 @@ def test_find_by_text_element():
assert result == ['button A', 'label A', 'button B', 'label B']


def test_find_by_type():
def test_find_by_kind():
ui.button('button A')
ui.label('label A')
ui.button('button B')
Expand Down Expand Up @@ -106,18 +107,6 @@ def test_find_by_multiple_markers():
assert result == ['button B', 'button C']


def test_find_within_type():
ui.button('button A')
ui.label('label A')
with ui.row():
ui.button('button B')
ui.label('label B')

result = [element.text for element in ElementFilter(kind=ui.button).within(kind=ui.row)]

assert result == ['button B']


def test_find_within_marker():
ui.button('button A')
ui.label('label A')
Expand Down Expand Up @@ -155,7 +144,19 @@ def test_find_within_elements():
assert result == ['button A', 'label A', 'button B', 'label B']


def test_find_with_excluding_type():
def test_find_within_kind():
ui.button('button A')
with ui.row():
ui.label('label A')
ui.button('button B')
ui.label('label B')

result = [element.text for element in ElementFilter(content='B').within(kind=ui.row)]

assert result == ['button B', 'label B']


def test_find_with_excluding_kind():
ui.button('button A')
ui.label('label A')
ui.button('button B')
Expand Down Expand Up @@ -188,7 +189,7 @@ def test_find_with_excluding_text():
assert result == ['button B']


def test_find_not_within_type():
def test_find_not_within_kind():
ui.button('button A')
ui.label('label A')
with ui.row():
Expand Down Expand Up @@ -266,3 +267,15 @@ async def test_setting_props(user: User):
await user.open('/')
for button in user.find('button').elements:
assert button._props['flat'] # pylint: disable=protected-access


async def test_typing(user: User):
ui.button('button A')
ui.label('label A')

await user.open('/')
# NOTE we have not yet found a way to test the typing suggestions automatically
# to test, hover over the variable and verify that your IDE infers the correct type
_ = ElementFilter(kind=ui.button) # ElementFilter[ui.button]
_ = ElementFilter(kind=ui.label) # ElementFilter[ui.label]
_ = ElementFilter() # ElementFilter[Element]
16 changes: 16 additions & 0 deletions tests/test_user_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,19 @@ def page():
Card
Image [src=https://via.placehol...]
'''.strip()


async def test_typing(user: User) -> None:
@ui.page('/')
def page():
ui.label('Hello!')
ui.button('World!')

await user.open('/')
# NOTE we have not yet found a way to test the typing suggestions automatically
# to test, hover over the variable and verify that your IDE inferres the correct type
_ = user.find(kind=ui.label).elements # Set[ui.label]
_ = user.find(ui.label).elements # Set[ui.label]
_ = user.find('World').elements # Set[ui.element]
_ = user.find('Hello').elements # Set[ui.element]
_ = user.find('!').elements # Set[ui.element]
Loading