Skip to content

Commit

Permalink
Add memoize.RECURSIVE
Browse files Browse the repository at this point in the history
Refs   #858.
  • Loading branch information
evhub committed Nov 12, 2024
1 parent 35ff35a commit 4a366ea
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 13 deletions.
2 changes: 2 additions & 0 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3116,6 +3116,8 @@ _Note: Passing `--strict` disables deprecated features._

Coconut provides `functools.lru_cache` as a built-in under the name `memoize` with the modification that the _maxsize_ parameter is set to `None` by default. `memoize` makes the use case of optimizing recursive functions easier, as a _maxsize_ of `None` is usually what is desired in that case.

`memoize` also supports a special `maxsize=memoize.RECURSIVE` argument, which will allow the cache to grow without bound within a single call to the top-level function, but clear the cache after the top-level call returns.

Use of `memoize` requires `functools.lru_cache`, which exists in the Python 3 standard library, but under Python 2 will require `pip install backports.functools_lru_cache` to function. Additionally, if on Python 2 and `backports.functools_lru_cache` is present, Coconut will patch `functools` such that `functools.lru_cache = backports.functools_lru_cache.lru_cache`.

Note that, if the function to be memoized is a generator or otherwise returns an iterator, [`recursive_generator`](#recursive_generator) can also be used to achieve a similar effect, the use of which is required for recursive generators.
Expand Down
1 change: 1 addition & 0 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ _coconut_zip = zip

zip_longest = _coconut.zip_longest
memoize = _lru_cache
memoize.RECURSIVE = None # type: ignore
reduce = _coconut.functools.reduce
takewhile = _coconut.itertools.takewhile
dropwhile = _coconut.itertools.dropwhile
Expand Down
41 changes: 35 additions & 6 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -1643,17 +1643,46 @@ def fmap(func, obj, **kwargs):
else:
mapped_obj = _coconut_map(func, obj)
return _coconut_base_makedata(obj.__class__, mapped_obj, from_fmap=True, fallback_to_init=fallback_to_init)
def _coconut_memoize_helper(maxsize=None, typed=False):
return maxsize, typed
def memoize(*args, **kwargs):
"""Decorator that memoizes a function, preventing it from being recomputed
if it is called multiple times with the same arguments."""
if not kwargs and _coconut.len(args) == 1 and _coconut.callable(args[0]):
return _coconut.functools.lru_cache(maxsize=None)(args[0])
return _coconut_memoize_helper()(args[0])
if _coconut.len(kwargs) == 1 and "user_function" in kwargs and _coconut.callable(kwargs["user_function"]):
return _coconut.functools.lru_cache(maxsize=None)(kwargs["user_function"])
maxsize, typed = _coconut_memoize_helper(*args, **kwargs)
return _coconut.functools.lru_cache(maxsize, typed)
return _coconut_memoize_helper()(kwargs["user_function"])
return _coconut_memoize_helper(*args, **kwargs)
memoize.RECURSIVE = _coconut_Sentinel()
def _coconut_memoize_helper(maxsize=None, typed=False):
if maxsize is memoize.RECURSIVE:
def memoizer(func):
"""memoize(...)"""
inside = [False]
cache = {empty_dict}
@_coconut_wraps(func)
def memoized_func(*args, **kwargs):
if typed:
key = (_coconut.tuple((x, _coconut.type(x)) for x in args), _coconut.tuple((k, _coconut.type(k), v, _coconut.type(v)) for k, v in kwargs.items()))
else:
key = (args, _coconut.tuple(kwargs.items()))
got = cache.get(key, _coconut_sentinel)
if got is not _coconut_sentinel:
return got
outer_inside, inside[0] = inside[0], True
try:
got = func(*args, **kwargs)
cache[key] = got
return got
finally:
inside[0] = outer_inside
if not inside[0]:
cache.clear()
memoized_func.__module__ = _coconut.getattr(func, "__module__", None)
memoized_func.__name__ = _coconut.getattr(func, "__name__", None)
memoized_func.__qualname__ = _coconut.getattr(func, "__qualname__", None)
return memoized_func
return memoizer
else:
return _coconut.functools.lru_cache(maxsize, typed)
{def_call_set_names}
class override(_coconut_baseclass):
"""Declare a method in a subclass as an override of a parent class method.
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.1.2"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 3
DEVELOP = 4
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down
14 changes: 14 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,20 @@ def primary_test_2() -> bool:
assert reduce(function=(+), iterable=range(5), initial=-1) == 9 # type: ignore
assert takewhile(predicate=ident, iterable=[1, 2, 1, 0, 1]) |> list == [1, 2, 1] # type: ignore
assert dropwhile(predicate=(not), iterable=range(5)) |> list == [1, 2, 3, 4] # type: ignore
@memoize(typed=True)
def typed_memoized_func(x):
if x is 1:
return None
else:
return (x, typed_memoized_func(1))
assert typed_memoized_func(1.0) == (1.0, None)
assert typed_memoized_func(1.0)[0] |> type == float
@memoize()
def untyped_memoized_func(x=None):
if x is None:
return (untyped_memoized_func(1), untyped_memoized_func(1.0))
return x
assert untyped_memoized_func() |> map$(type) |> tuple == (int, float)

with process_map.multiple_sequential_calls(): # type: ignore
assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore
Expand Down
5 changes: 4 additions & 1 deletion coconut/tests/src/cocotest/agnostic/suite.coco
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,10 @@ def suite_test() -> bool:
assert plus1sqsum_all(1, 2) == 13 == plus1sqsum_all_(1, 2) # type: ignore
assert sum_list_range(10) == 45
assert sum2([3, 4]) == 7
assert ridiculously_recursive(300) == 201666561657114122540576123152528437944095370972927688812965354745141489205495516550423117825 == ridiculously_recursive_(300)
with process_map.multiple_sequential_calls():
for ridiculously_recursive in (ridiculously_recursive1, ridiculously_recursive2, ridiculously_recursive3):
got = process_map(ridiculously_recursive, [300])
assert got == (201666561657114122540576123152528437944095370972927688812965354745141489205495516550423117825,), got
assert [fib(n) for n in range(16)] == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] == [fib_(n) for n in range(16)]
assert [fib_alt1(n) for n in range(16)] == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] == [fib_alt2(n) for n in range(16)]
assert fib.cache_info().hits == 28
Expand Down
20 changes: 15 additions & 5 deletions coconut/tests/src/cocotest/agnostic/util.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1346,24 +1346,34 @@ def sum2(ab) = a + b where:
# Memoization
import functools

@memoize()
def ridiculously_recursive(n):
@memoize(None)
def ridiculously_recursive1(n):
"""Requires maxsize=None when called on large numbers."""
if n <= 0:
return 1
result = 0
for i in range(1, 200):
result += ridiculously_recursive(n-i)
result += ridiculously_recursive1(n-i)
return result

@functools.lru_cache(maxsize=None) # type: ignore
def ridiculously_recursive_(n):
def ridiculously_recursive2(n):
"""Requires maxsize=None when called on large numbers."""
if n <= 0:
return 1
result = 0
for i in range(1, 200):
result += ridiculously_recursive_(n-i)
result += ridiculously_recursive2(n-i)
return result

@memoize(memoize.RECURSIVE) # type: ignore
def ridiculously_recursive3(n):
"""Requires maxsize=None when called on large numbers."""
if n <= 0:
return 1
result = 0
for i in range(1, 200):
result += ridiculously_recursive3(n-i)
return result

def fib(n if n < 2) = n
Expand Down

0 comments on commit 4a366ea

Please sign in to comment.