Skip to content

Commit

Permalink
recursive_iterator to recursive_generator
Browse files Browse the repository at this point in the history
Resolves   #749.
  • Loading branch information
evhub committed Oct 28, 2023
1 parent 586dd5e commit 2e03ba0
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 69 deletions.
22 changes: 11 additions & 11 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2316,8 +2316,6 @@ Coconut will perform automatic [tail call](https://en.wikipedia.org/wiki/Tail_ca

Tail call optimization (though not tail recursion elimination) will work even for 1) mutual recursion and 2) pattern-matching functions split across multiple definitions using [`addpattern`](#addpattern).

If you are encountering a `RuntimeError` due to maximum recursion depth, it is highly recommended that you rewrite your function to meet either the criteria above for tail call optimization, or the corresponding criteria for [`recursive_iterator`](#recursive-iterator), either of which should prevent such errors.

##### Example

**Coconut:**
Expand Down Expand Up @@ -2960,6 +2958,8 @@ Coconut provides `functools.lru_cache` as a built-in under the name `memoize` wi

Use of `memoize` requires `functools.lru_cache`, which exists in the Python 3 standard library, but under Python 2 will require `pip install backports.functools_lru_cache` to function. Additionally, if on Python 2 and `backports.functools_lru_cache` is present, Coconut will patch `functools` such that `functools.lru_cache = backports.functools_lru_cache.lru_cache`.

Note that, if the function to be memoized is a generator or otherwise returns an iterator, [`recursive_generator`](#recursive_generator) can also be used to achieve a similar effect, the use of which is required for recursive generators.

##### Python Docs

@**memoize**(_user\_function_)
Expand Down Expand Up @@ -3060,36 +3060,36 @@ class B:
**Python:**
_Can't be done without a long decorator definition. The full definition of the decorator in Python can be found in the Coconut header._

#### `recursive_iterator`
#### `recursive_generator`

**recursive\_iterator**(_func_)
**recursive\_generator**(_func_)

Coconut provides a `recursive_iterator` decorator that memoizes any stateless, recursive function that returns an iterator. To use `recursive_iterator` on a function, it must meet the following criteria:
Coconut provides a `recursive_generator` decorator that memoizes and makes [`reiterable`](#reiterable) any generator or other stateless function that returns an iterator. To use `recursive_generator` on a function, it must meet the following criteria:

1. your function either always `return`s an iterator or generates an iterator using `yield`,
2. when called multiple times with arguments that are equal, your function produces the same iterator (your function is stateless), and
3. your function gets called (usually calls itself) multiple times with the same arguments.

If you are encountering a `RuntimeError` due to maximum recursion depth, it is highly recommended that you rewrite your function to meet either the criteria above for `recursive_iterator`, or the corresponding criteria for Coconut's [tail call optimization](#tail-call-optimization), either of which should prevent such errors.

Furthermore, `recursive_iterator` also allows the resolution of a [nasty segmentation fault in Python's iterator logic that has never been fixed](http://bugs.python.org/issue14010). Specifically, instead of writing
Importantly, `recursive_generator` also allows the resolution of a [nasty segmentation fault in Python's iterator logic that has never been fixed](http://bugs.python.org/issue14010). Specifically, instead of writing
```coconut
seq = get_elem() :: seq
```
which will crash due to the aforementioned Python issue, write
```coconut
@recursive_iterator
@recursive_generator
def seq() = get_elem() :: seq()
```
which will work just fine.

One pitfall to keep in mind working with `recursive_iterator` is that it shouldn't be used in contexts where the function can potentially be called multiple times with the same iterator object as an input, but with that object not actually corresponding to the same items (e.g. because the first time the object hasn't been iterated over yet and the second time it has been).
One pitfall to keep in mind working with `recursive_generator` is that it shouldn't be used in contexts where the function can potentially be called multiple times with the same iterator object as an input, but with that object not actually corresponding to the same items (e.g. because the first time the object hasn't been iterated over yet and the second time it has been).

_Deprecated: `recursive_iterator` is available as a deprecated alias for `recursive_generator`. Note that deprecated features are disabled in `--strict` mode._

##### Example

**Coconut:**
```coconut
@recursive_iterator
@recursive_generator
def fib() = (1, 1) :: map((+), fib(), fib()$[1:])
```

Expand Down
4 changes: 2 additions & 2 deletions FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ Information on every Coconut release is chronicled on the [GitHub releases page]

Yes! Coconut compiles the [newest](https://www.python.org/dev/peps/pep-0526/), [fanciest](https://www.python.org/dev/peps/pep-0484/) type annotation syntax into version-independent type comments which can then by checked using Coconut's built-in [MyPy Integration](./DOCS.md#mypy-integration).

### Help! I tried to write a recursive iterator and my Python segfaulted!
### Help! I tried to write a recursive generator and my Python segfaulted!

No problem—just use Coconut's [`recursive_iterator`](./DOCS.md#recursive-iterator) decorator and you should be fine. This is a [known Python issue](http://bugs.python.org/issue14010) but `recursive_iterator` will fix it for you.
No problem—just use Coconut's [`recursive_generator`](./DOCS.md#recursive_generator) decorator and you should be fine. This is a [known Python issue](http://bugs.python.org/issue14010) but `recursive_generator` will fix it for you.

### How do I split an expression across multiple lines in Coconut?

Expand Down
5 changes: 3 additions & 2 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,10 @@ def _coconut_call_or_coefficient(
) -> _T: ...


def recursive_iterator(func: _T_iter_func) -> _T_iter_func:
def recursive_generator(func: _T_iter_func) -> _T_iter_func:
"""Decorator that memoizes a recursive function that returns an iterator (e.g. a recursive generator)."""
return func
recursive_iterator = recursive_generator


# if sys.version_info >= (3, 12):
Expand Down Expand Up @@ -590,7 +591,7 @@ def addpattern(
*add_funcs: _Callable,
allow_any_func: bool=False,
) -> _t.Callable[..., _t.Any]:
"""Decorator to add a new case to a pattern-matching function (where the new case is checked last).
"""Decorator to add new cases to a pattern-matching function (where the new case is checked last).
Pass allow_any_func=True to allow any object as the base_func rather than just pattern-matching functions.
If add_funcs are passed, addpattern(base_func, add_func) is equivalent to addpattern(base_func)(add_func).
Expand Down
6 changes: 5 additions & 1 deletion coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap):
comma_object="" if target.startswith("3") else ", object",
comma_slash=", /" if target_info >= (3, 8) else "",
report_this_text=report_this_text,
from_None=" from None" if target.startswith("3") else "",
numpy_modules=tuple_str_of(numpy_modules, add_quotes=True),
pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True),
jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True),
Expand Down Expand Up @@ -310,7 +311,7 @@ def pattern_prepender(func):
def datamaker(data_type):
"""DEPRECATED: use makedata instead."""
return _coconut.functools.partial(makedata, data_type)
of, parallel_map, concurrent_map = call, process_map, thread_map
of, parallel_map, concurrent_map, recursive_iterator = call, process_map, thread_map, recursive_generator
'''
if not strict else
r'''
Expand All @@ -329,6 +330,9 @@ def parallel_map(*args, **kwargs):
def concurrent_map(*args, **kwargs):
"""Deprecated Coconut built-in 'concurrent_map' disabled by --strict compilation; use 'thread_map' instead."""
raise _coconut.NameError("deprecated Coconut built-in 'concurrent_map' disabled by --strict compilation; use 'thread_map' instead")
def recursive_iterator(*args, **kwargs):
"""Deprecated Coconut built-in 'recursive_iterator' disabled by --strict compilation; use 'recursive_generator' instead."""
raise _coconut.NameError("deprecated Coconut built-in 'recursive_iterator' disabled by --strict compilation; use 'recursive_generator' instead")
'''
),
return_method_of_self=pycondition(
Expand Down
68 changes: 29 additions & 39 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def _coconut_super(type=None, object_or_type=None):
try:
cls = frame.f_locals["__class__"]
except _coconut.AttributeError:
raise _coconut.RuntimeError("super(): __class__ cell not found")
raise _coconut.RuntimeError("super(): __class__ cell not found"){from_None}
self = frame.f_locals[frame.f_code.co_varnames[0]]
return _coconut_py_super(cls, self)
return _coconut_py_super(type, object_or_type)
Expand Down Expand Up @@ -1315,39 +1315,29 @@ class groupsof(_coconut_has_iter):
return (self.__class__, (self.group_size, self.iter))
def __copy__(self):
return self.__class__(self.group_size, self.get_new_iter())
class recursive_iterator(_coconut_base_callable):
class recursive_generator(_coconut_base_callable):
"""Decorator that memoizes a generator (or any function that returns an iterator).
Particularly useful for recursive generators, which may require recursive_iterator to function properly."""
__slots__ = ("func", "reit_store", "backup_reit_store")
Particularly useful for recursive generators, which may require recursive_generator to function properly."""
__slots__ = ("func", "reit_store")
def __init__(self, func):
self.func = func
self.reit_store = {empty_dict}
self.backup_reit_store = []
def __call__(self, *args, **kwargs):
key = (args, _coconut.frozenset(kwargs.items()))
use_backup = False
key = (0, args, _coconut.frozenset(kwargs.items()))
try:
_coconut.hash(key)
except _coconut.Exception:
except _coconut.TypeError:
try:
key = _coconut.pickle.dumps(key, -1)
key = (1, _coconut.pickle.dumps(key, -1))
except _coconut.Exception:
use_backup = True
if use_backup:
for k, v in self.backup_reit_store:
if k == key:
return reit
raise _coconut.TypeError("recursive_generator() requires function arguments to be hashable or pickleable"){from_None}
reit = self.reit_store.get(key)
if reit is None:
reit = {_coconut_}reiterable(self.func(*args, **kwargs))
self.backup_reit_store.append([key, reit])
return reit
else:
reit = self.reit_store.get(key)
if reit is None:
reit = {_coconut_}reiterable(self.func(*args, **kwargs))
self.reit_store[key] = reit
return reit
self.reit_store[key] = reit
return reit
def __repr__(self):
return "recursive_iterator(%r)" % (self.func,)
return "recursive_generator(%r)" % (self.func,)
def __reduce__(self):
return (self.__class__, (self.func,))
class _coconut_FunctionMatchErrorContext(_coconut_baseclass):
Expand Down Expand Up @@ -1416,7 +1406,7 @@ def _coconut_mark_as_match(base_func):{COMMENT._coconut_is_match_is_used_above_a
base_func._coconut_is_match = True
return base_func
def addpattern(base_func, *add_funcs, **kwargs):
"""Decorator to add a new case to a pattern-matching function (where the new case is checked last).
"""Decorator to add new cases to a pattern-matching function (where the new case is checked last).

Pass allow_any_func=True to allow any object as the base_func rather than just pattern-matching functions.
If add_funcs are passed, addpattern(base_func, add_func) is equivalent to addpattern(base_func)(add_func).
Expand Down Expand Up @@ -2010,7 +2000,7 @@ class _coconut_SupportsAdd(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __add__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((+) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((+) in a typing context is a Protocol)")
class _coconut_SupportsMinus(_coconut.typing.Protocol):
"""Coconut (-) Protocol. Equivalent to:

Expand All @@ -2021,9 +2011,9 @@ class _coconut_SupportsMinus(_coconut.typing.Protocol):
raise NotImplementedError
"""
def __sub__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)")
def __neg__(self):
raise NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)")
class _coconut_SupportsMul(_coconut.typing.Protocol):
"""Coconut (*) Protocol. Equivalent to:

Expand All @@ -2032,7 +2022,7 @@ class _coconut_SupportsMul(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __mul__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((*) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((*) in a typing context is a Protocol)")
class _coconut_SupportsPow(_coconut.typing.Protocol):
"""Coconut (**) Protocol. Equivalent to:

Expand All @@ -2041,7 +2031,7 @@ class _coconut_SupportsPow(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __pow__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((**) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((**) in a typing context is a Protocol)")
class _coconut_SupportsTruediv(_coconut.typing.Protocol):
"""Coconut (/) Protocol. Equivalent to:

Expand All @@ -2050,7 +2040,7 @@ class _coconut_SupportsTruediv(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __truediv__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((/) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((/) in a typing context is a Protocol)")
class _coconut_SupportsFloordiv(_coconut.typing.Protocol):
"""Coconut (//) Protocol. Equivalent to:

Expand All @@ -2059,7 +2049,7 @@ class _coconut_SupportsFloordiv(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __floordiv__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((//) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((//) in a typing context is a Protocol)")
class _coconut_SupportsMod(_coconut.typing.Protocol):
"""Coconut (%) Protocol. Equivalent to:

Expand All @@ -2068,7 +2058,7 @@ class _coconut_SupportsMod(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __mod__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((%) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((%) in a typing context is a Protocol)")
class _coconut_SupportsAnd(_coconut.typing.Protocol):
"""Coconut (&) Protocol. Equivalent to:

Expand All @@ -2077,7 +2067,7 @@ class _coconut_SupportsAnd(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __and__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((&) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((&) in a typing context is a Protocol)")
class _coconut_SupportsXor(_coconut.typing.Protocol):
"""Coconut (^) Protocol. Equivalent to:

Expand All @@ -2086,7 +2076,7 @@ class _coconut_SupportsXor(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __xor__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((^) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((^) in a typing context is a Protocol)")
class _coconut_SupportsOr(_coconut.typing.Protocol):
"""Coconut (|) Protocol. Equivalent to:

Expand All @@ -2095,7 +2085,7 @@ class _coconut_SupportsOr(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __or__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((|) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((|) in a typing context is a Protocol)")
class _coconut_SupportsLshift(_coconut.typing.Protocol):
"""Coconut (<<) Protocol. Equivalent to:

Expand All @@ -2104,7 +2094,7 @@ class _coconut_SupportsLshift(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __lshift__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((<<) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((<<) in a typing context is a Protocol)")
class _coconut_SupportsRshift(_coconut.typing.Protocol):
"""Coconut (>>) Protocol. Equivalent to:

Expand All @@ -2113,7 +2103,7 @@ class _coconut_SupportsRshift(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __rshift__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((>>) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((>>) in a typing context is a Protocol)")
class _coconut_SupportsMatmul(_coconut.typing.Protocol):
"""Coconut (@) Protocol. Equivalent to:

Expand All @@ -2122,7 +2112,7 @@ class _coconut_SupportsMatmul(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __matmul__(self, other):
raise NotImplementedError("Protocol methods cannot be called at runtime ((@) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((@) in a typing context is a Protocol)")
class _coconut_SupportsInv(_coconut.typing.Protocol):
"""Coconut (~) Protocol. Equivalent to:

Expand All @@ -2131,7 +2121,7 @@ class _coconut_SupportsInv(_coconut.typing.Protocol):
raise NotImplementedError(...)
"""
def __invert__(self):
raise NotImplementedError("Protocol methods cannot be called at runtime ((~) in a typing context is a Protocol)")
raise _coconut.NotImplementedError("Protocol methods cannot be called at runtime ((~) in a typing context is a Protocol)")
{def_aliases}
_coconut_self_match_types = {self_match_types}
_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file}
Loading

0 comments on commit 2e03ba0

Please sign in to comment.