Use of __iter__ and __next__ does not confer all Iterator protocol abilities #17361

Closed
chadrik opened this issue Jun 11, 2024 · 4 comments
Labels
bug mypy got something wrong

Comments

@chadrik
Contributor

chadrik commented Jun 11, 2024

I've implemented a custom iterator class that defines __next__, and another class returns an instance of it from its __iter__. mypy correctly detects that the outer object is iterable, but it incorrectly omits other capabilities implied by the iterator protocol, such as membership testing with in (the __contains__ fallback).

From the python docs:

For user-defined classes which define the __contains__() method, x in y returns True if y.__contains__(x) returns a true value, and False otherwise.

For user-defined classes which do not define __contains__() but do define __iter__(), x in y is True if some value z, for which the expression x is z or x == z is true, is produced while iterating over y. If an exception is raised during the iteration, it is as if in raised that exception.
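The fallback described in the docs can be checked with a small sketch. The classes Bag and Exploding and the helper membership are hypothetical, made up for illustration. Note the "x is z" clause: a NaN is found by identity even though it never compares equal to itself.

```python
from typing import Any, Iterable, Iterator


class Bag:
    """Defines __iter__ but not __contains__, so `in` falls back to iteration."""

    def __init__(self, items: Iterable[object]) -> None:
        self.items = list(items)

    def __iter__(self) -> Iterator[object]:
        return iter(self.items)


class Exploding:
    """Yields one value, then raises partway through iteration."""

    def __iter__(self) -> Iterator[int]:
        yield 1
        raise ValueError("boom")


def membership(x: object, container: Any) -> str:
    """Return "True"/"False", or the name of the exception that `in` raised."""
    try:
        return str(x in container)
    except Exception as exc:
        return type(exc).__name__


nan = float("nan")
bag = Bag([1, nan])
print(membership(1, bag))             # True: 1 == 1
print(membership(nan, bag))           # True: identity match, even though nan != nan
print(membership(float("nan"), bag))  # False: a different NaN object, and NaN != NaN
print(membership(1, Exploding()))     # True: found before the exception is reached
print(membership(2, Exploding()))     # ValueError: propagated out of `in`
```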

This is my sample code:

from __future__ import absolute_import, print_function

from typing import Iterator


class IteratorA:
    def __iter__(self) -> Iterator[str]:
        yield "foo"


class MyIterator:
    def __init__(self):
        self.called = False

    def __next__(self) -> str:
        if self.called:
            raise StopIteration
        self.called = True
        return "foo"


class IteratorB:
    def __iter__(self) -> MyIterator:
        return MyIterator()


assert "foo" in IteratorA()
for x in IteratorA():
    print(x)
print(list(IteratorA()))

assert "foo" in IteratorB()  # E: Unsupported right operand type for in ("IteratorB")
for x in IteratorB():
    print(x)
print(list(IteratorB()))  # E: No overload variant of "list" matches argument type "IteratorB"

mypy prints:

test_iter_next.py:32: error: Unsupported right operand type for in ("IteratorB")
test_iter_next.py:35: error: No overload variant of "list" matches argument type "IteratorB"
test_iter_next.py:35: note: Possible overload variants:
test_iter_next.py:35: note:     def [_T] list(self) -> List[_T]
test_iter_next.py:35: note:     def [_T] list(self, Iterable[_T]) -> List[_T]

There should not be any errors here, because the two classes, IteratorA and IteratorB, are functionally equivalent.

The code is valid, and running it produces this output:

foo
['foo']
foo
['foo']
chad$ mypy --version
mypy 0.960 (compiled: yes)
@Hnasar
Contributor

Hnasar commented Jun 11, 2024

Just a note that your example works fine if you change the class definition to explicitly inherit from the ABC:

class MyIterator(Iterator[str]):
    ...
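Spelled out in full, that suggestion might look like the sketch below, with the class bodies copied from the report. Only the runtime behavior is checked here, not mypy's output. typing.Iterator is an alias of collections.abc.Iterator, whose mixin __iter__ returns self, so only __next__ has to be written by hand.

```python
from typing import Iterator


class MyIterator(Iterator[str]):
    # Inheriting from the Iterator ABC supplies a mixin __iter__ that
    # returns self; __next__ is the only method implemented here.
    def __init__(self) -> None:
        self.called = False

    def __next__(self) -> str:
        if self.called:
            raise StopIteration
        self.called = True
        return "foo"


class IteratorB:
    def __iter__(self) -> MyIterator:
        return MyIterator()


it = MyIterator()
print(iter(it) is it)        # True: the inherited __iter__ returns the iterator itself
print("foo" in IteratorB())  # True
print(list(IteratorB()))     # ['foo']
```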

@TeamSpen210
Contributor

The problem is that MyIterator does not define __iter__(), which the Iterator protocol requires. Omitting it happens to work in some cases today, but only as a historical accident: CPython doesn't enforce the requirement. See the glossary entry and the issue where that was discussed. Inheriting from Iterator solves it because that gives you an __iter__() implementation.
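An alternative to inheriting from the ABC is to add the missing method directly. The sketch below uses hypothetical names (MyIteratorOriginal for the report's class, MyIteratorFixed for the version with __iter__). collections.abc.Iterator performs a structural check for both __iter__ and __next__, so isinstance shows the same distinction the comment describes.

```python
from collections.abc import Iterator


class MyIteratorOriginal:
    """The report's class: defines __next__ only."""

    def __init__(self) -> None:
        self.called = False

    def __next__(self) -> str:
        if self.called:
            raise StopIteration
        self.called = True
        return "foo"


class MyIteratorFixed(MyIteratorOriginal):
    """Same class plus the __iter__ the Iterator protocol requires."""

    def __iter__(self) -> "MyIteratorFixed":
        # An iterator's __iter__ returns the iterator itself.
        return self


# The ABC's __subclasshook__ looks for both __iter__ and __next__.
print(isinstance(MyIteratorOriginal(), Iterator))  # False
print(isinstance(MyIteratorFixed(), Iterator))     # True
```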

@hauntsaninja
Collaborator

Closing per Spencer's comment

@hauntsaninja closed this as not planned on Jun 21, 2024
@chadrik
Contributor Author

chadrik commented Jun 25, 2024

To close the loop, here are a few examples given in the cited issue that do not work unless iter(iter(IteratorB())) is guaranteed to succeed:

import collections
from itertools import islice, zip_longest


def grouper(iterable, n, fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    print(args)
    # zip_longest is implemented in C and rejects MyIterator because it lacks an __iter__ method
    return zip_longest(*args, fillvalue=fillvalue)


def sliding_window(iterable, n):
    # sliding_window('ABCDEFG', 4) -> ABCD BCDE CDEF DEFG
    it = iter(iterable)
    # islice() calls iter() on its argument, so this also fails for MyIterator
    sl = islice(it, n)
    window = collections.deque(sl, maxlen=n)
    if len(window) == n:
        yield tuple(window)
    for x in it:
        window.append(x)
        yield tuple(window)
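To make the failure concrete, here is a self-contained sketch that repeats the report's MyIterator and IteratorB and probes the calls the recipes above rely on. fails_with_type_error is a hypothetical helper written for this illustration, not part of any library.

```python
from itertools import islice, zip_longest
from typing import Callable


class MyIterator:
    """The report's iterator: __next__ but no __iter__."""

    def __init__(self) -> None:
        self.called = False

    def __next__(self) -> str:
        if self.called:
            raise StopIteration
        self.called = True
        return "foo"


class IteratorB:
    def __iter__(self) -> MyIterator:
        return MyIterator()


def fails_with_type_error(fn: Callable[[], object]) -> bool:
    """Return True if calling fn raises TypeError."""
    try:
        fn()
    except TypeError:
        return True
    return False


it = iter(IteratorB())  # the first iter() succeeds and returns a MyIterator
print(fails_with_type_error(lambda: iter(it)))              # True: iter(iter(...)) fails
print(fails_with_type_error(lambda: zip_longest(it, it)))  # True: what grouper() hits
print(fails_with_type_error(lambda: islice(it, 2)))        # True: what sliding_window() hits
print(list(IteratorB()))                                   # ['foo']: a single iter() still works
```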
