diff --git a/DOCS.md b/DOCS.md index c93131e61..5ce14c3d8 100644 --- a/DOCS.md +++ b/DOCS.md @@ -678,7 +678,7 @@ f_into_g = lambda *args, **kwargs: g(f(*args, **kwargs)) Coconut uses a `$` sign right after an iterator before a slice to perform iterator slicing, as in `it$[:5]`. Coconut's iterator slicing works much the same as Python's sequence slicing, and looks much the same as Coconut's partial application, but with brackets instead of parentheses. -Iterator slicing works just like sequence slicing, including support for negative indices and slices, and support for `slice` objects in the same way as can be done with normal slicing. Iterator slicing makes no guarantee, however, that the original iterator passed to it be preserved (to preserve the iterator, use Coconut's [`tee`](#tee) or [`reiterable`](#reiterable) built-ins). +Iterator slicing works just like sequence slicing, including support for negative indices and slices, and support for `slice` objects in the same way as can be done with normal slicing. Iterator slicing makes no guarantee, however, that the original iterator passed to it be preserved (to preserve the iterator, use Coconut's [`reiterable`](#reiterable) built-in). Coconut's iterator slicing is very similar to Python's `itertools.islice`, but unlike `itertools.islice`, Coconut's iterator slicing supports negative indices, and will preferentially call an object's `__iter_getitem__` (Coconut-specific magic method, preferred) or `__getitem__` (general Python magic method), if they exist. Coconut's iterator slicing is also optimized to work well with all of Coconut's built-in objects, only computing the elements of each that are actually necessary to extract the desired slice. @@ -1074,7 +1074,7 @@ base_pattern ::= ( - Iterable Splits (` :: :: :: :: `): same as other sequence destructuring, but works on any iterable (`collections.abc.Iterable`), including infinite iterators (note that if an iterator is matched against it will be modified unless it is [`reiterable`](#reiterable)). - Complex String Matching (` + + + + `): string matching supports the same destructuring options as above. -_Note: Like [iterator slicing](#iterator-slicing), iterator and lazy list matching make no guarantee that the original iterator matched against be preserved (to preserve the iterator, use Coconut's [`tee`](#tee) or [`reiterable`](#reiterable) built-ins)._ +_Note: Like [iterator slicing](#iterator-slicing), iterator and lazy list matching make no guarantee that the original iterator matched against be preserved (to preserve the iterator, use Coconut's [`reiterable`](#reiterable) built-in)._ When checking whether or not an object can be matched against in a particular fashion, Coconut makes use of Python's abstract base classes. Therefore, to ensure proper matching for a custom object, it's recommended to register it with the proper abstract base classes. @@ -1271,7 +1271,7 @@ data () [from ]: ``` `` is the name of the new data type, `` are the arguments to its constructor as well as the names of its attributes, `` contains the data type's methods, and `` optionally contains any desired base classes. -Coconut allows data fields in `` to have defaults and/or [type annotations](#enhanced-type-annotation) attached to them, and supports a starred parameter at the end to collect extra arguments. +Coconut allows data fields in `` to have defaults and/or [type annotations](#enhanced-type-annotation) attached to them, and supports a starred parameter at the end to collect extra arguments. Additionally, Coconut allows type parameters to be specified in brackets after `` using Coconut's [type parameter syntax](#type-parameter-syntax). Writing constructors for `data` types must be done using the `__new__` method instead of the `__init__` method. For helping to easily write `__new__` methods, Coconut provides the [makedata](#makedata) built-in. @@ -2879,12 +2879,39 @@ if group: pairs.append(tuple(group)) ``` +### `reiterable` + +**reiterable**(_iterable_) + +`reiterable` wraps the given iterable to ensure that every time the `reiterable` is iterated over, it produces the same results. Note that the result need not be a `reiterable` object if the given iterable is already reiterable. `reiterable` uses [`tee`](#tee) under the hood and `tee` can be used in its place, though `reiterable` is generally recommended over `tee`. + +##### Example + +**Coconut:** +```coconut +def list_type(xs): + match reiterable(xs): + case [fst, snd] :: tail: + return "at least 2" + case [fst] :: tail: + return "at least 1" + case (| |): + return "empty" +``` + +**Python:** +_Can't be done without a long series of checks for each `match` statement. See the compiled code for the Python syntax._ + ### `tee` **tee**(_iterable_, _n_=`2`) Coconut provides an optimized version of `itertools.tee` as a built-in under the name `tee`. +Though `tee` is not deprecated, [`reiterable`](#reiterable) is generally recommended over `tee`. + +Custom `tee`/`reiterable` implementations for custom [Containers/Collections](https://docs.python.org/3/library/collections.abc.html) should be put in the `__copy__` method. Note that all [Sequences/Mappings/Sets](https://docs.python.org/3/library/collections.abc.html) are always assumed to be reiterable even without calling `__copy__`. + ##### Python Docs **tee**(_iterable, n=2_) @@ -2922,58 +2949,6 @@ original, temp = itertools.tee(original) sliced = itertools.islice(temp, 5, None) ``` -### `reiterable` - -**reiterable**(_iterable_) - -Sometimes, when an iterator may need to be iterated over an arbitrary number of times, [`tee`](#tee) can be cumbersome to use. For such cases, Coconut provides `reiterable`, which wraps the given iterator such that whenever an attempt to iterate over it is made, it iterates over a `tee` instead of the original. - -##### Example - -**Coconut:** -```coconut -def list_type(xs): - match reiterable(xs): - case [fst, snd] :: tail: - return "at least 2" - case [fst] :: tail: - return "at least 1" - case (| |): - return "empty" -``` - -**Python:** -_Can't be done without a long series of checks for each `match` statement. See the compiled code for the Python syntax._ - -### `consume` - -**consume**(_iterable_, _keep\_last_=`0`) - -Coconut provides the `consume` function to efficiently exhaust an iterator and thus perform any lazy evaluation contained within it. `consume` takes one optional argument, `keep_last`, that defaults to 0 and specifies how many, if any, items from the end to return as a sequence (`None` will keep all elements). - -Equivalent to: -```coconut -def consume(iterable, keep_last=0): - """Fully exhaust iterable and return the last keep_last elements.""" - return collections.deque(iterable, maxlen=keep_last) # fastest way to exhaust an iterator -``` - -##### Rationale - -In the process of lazily applying operations to iterators, eventually a point is reached where evaluation of the iterator is necessary. To do this efficiently, Coconut provides the `consume` function, which will fully exhaust the iterator given to it. - -##### Example - -**Coconut:** -```coconut -range(10) |> map$((x) -> x**2) |> map$(print) |> consume -``` - -**Python:** -```coconut_python -collections.deque(map(print, map(lambda x: x**2, range(10))), maxlen=0) -``` - ### `count` **count**(_start_=`0`, _step_=`1`) @@ -3160,7 +3135,7 @@ for x in input_data: **flatten**(_iterable_) -Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `len`, `repr`, `in`, `.count()`, `.index()`, and `fmap`. +Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `repr`, `in`, `.count()`, `.index()`, and `fmap`. Additionally, `flatten` includes special support for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, in which case a multidimensional array is returned instead of an iterator. @@ -3458,6 +3433,111 @@ with concurrent.futures.ThreadPoolExecutor() as executor: print(list(executor.map(get_data_for_user, get_all_users()))) ``` +### `consume` + +**consume**(_iterable_, _keep\_last_=`0`) + +Coconut provides the `consume` function to efficiently exhaust an iterator and thus perform any lazy evaluation contained within it. `consume` takes one optional argument, `keep_last`, that defaults to 0 and specifies how many, if any, items from the end to return as a sequence (`None` will keep all elements). + +Equivalent to: +```coconut +def consume(iterable, keep_last=0): + """Fully exhaust iterable and return the last keep_last elements.""" + return collections.deque(iterable, maxlen=keep_last) # fastest way to exhaust an iterator +``` + +##### Rationale + +In the process of lazily applying operations to iterators, eventually a point is reached where evaluation of the iterator is necessary. To do this efficiently, Coconut provides the `consume` function, which will fully exhaust the iterator given to it. + +##### Example + +**Coconut:** +```coconut +range(10) |> map$((x) -> x**2) |> map$(print) |> consume +``` + +**Python:** +```coconut_python +collections.deque(map(print, map(lambda x: x**2, range(10))), maxlen=0) +``` + +### `Expected` + +**Expected**(_result_=`None`, _error_=`None`) + +Coconut's `Expected` built-in is a Coconut [`data` type](#data) that represents a value that may or may not be an error, similar to Haskell's [`Either`](https://hackage.haskell.org/package/base-4.17.0.0/docs/Data-Either.html). + +`Expected` is effectively equivalent to the following: +```coconut +data Expected[T](result: T?, error: Exception?): + def __new__(cls, result: T?=None, error: Exception?=None) -> Expected[T]: + if result is not None and error is not None: + raise ValueError("Expected cannot have both a result and an error") + return makedata(cls, result, error) + def __bool__(self) -> bool: + return self.error is None + def __fmap__[U](self, func: T -> U) -> Expected[U]: + return self.__class__(func(self.result)) if self else self +``` + +`Expected` is primarily used as the return type for [`safe_call`](#safe_call). Generally, the best way to use `Expected` is with [`fmap`](#fmap), which will apply a function to the result if it exists, or otherwise retain the error. + +##### Example + +**Coconut:** +```coconut +def try_divide(x, y): + try: + return Expected(x / y) + except Exception as err: + return Expected(error=err) + +try_divide(1, 2) |> fmap$(.+1) |> print +try_divide(1, 0) |> fmap$(.+1) |> print +``` + +**Python:** +_Can't be done without a complex `Expected` definition. See the compiled code for the Python syntax._ + +### `call` + +**call**(_func_, /, *_args_, \*\*_kwargs_) + +Coconut's `call` simply implements function application. Thus, `call` is equivalent to +```coconut +def call(f, /, *args, **kwargs) = f(*args, **kwargs) +``` + +`call` is primarily useful as an [operator function](#operator-functions) for function application when writing in a point-free style. + +**DEPRECATED:** `of` is available as a deprecated alias for `call`. Note that deprecated features are disabled in `--strict` mode. + +### `safe_call` + +**safe_call**(_func_, /, *_args_, \*\*_kwargs_) + +Coconut's `safe_call` is a version of [`call`](#call) that catches any `Exception`s and returns an [`Expected`](#expected) containing either the result or the error. + +`safe_call` is effectively equivalent to: +```coconut +def safe_call(f, /, *args, **kwargs): + try: + return Expected(f(*args, **kwargs)) + except Exception as err: + return Expected(error=err) +``` + +##### Example + +**Coconut:** +```coconut +res, err = safe_call(-> 1 / 0) |> fmap$(.+1) +``` + +**Python:** +_Can't be done without a complex `Expected` definition. See the compiled code for the Python syntax._ + ### `lift` **lift**(_func_) @@ -3523,19 +3603,6 @@ def flip(f, nargs=None) = ) ``` -### `call` - -**call**(_func_, /, *_args_, \*\*_kwargs_) - -Coconut's `call` simply implements function application. Thus, `call` is equivalent to -```coconut -def call(f, /, *args, **kwargs) = f(*args, **kwargs) -``` - -`call` is primarily useful as an [operator function](#operator-functions) for function application when writing in a point-free style. - -**DEPRECATED:** `of` is available as a deprecated alias for `call`. Note that deprecated features are disabled in `--strict` mode. - ### `const` **const**(_value_) diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index a3f6cd073..c3a567a73 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -47,6 +47,7 @@ _Wco = _t.TypeVar("_Wco", covariant=True) _Tcontra = _t.TypeVar("_Tcontra", contravariant=True) _Tfunc = _t.TypeVar("_Tfunc", bound=_Callable) _Ufunc = _t.TypeVar("_Ufunc", bound=_Callable) +_Tfunc_contra = _t.TypeVar("_Tfunc_contra", bound=_Callable, contravariant=True) _Titer = _t.TypeVar("_Titer", bound=_Iterable) _T_iter_func = _t.TypeVar("_T_iter_func", bound=_t.Callable[..., _Iterable]) @@ -179,6 +180,7 @@ def _coconut_tco(func: _Tfunc) -> _Tfunc: return func +# any changes here should also be made to safe_call below @_t.overload def call( _func: _t.Callable[[_T], _Uco], @@ -232,6 +234,71 @@ def call( _coconut_tail_call = of = call +class _base_Expected(_t.NamedTuple, _t.Generic[_T]): + result: _t.Optional[_T] + error: _t.Optional[Exception] + def __fmap__(self, func: _t.Callable[[_T], _U]) -> Expected[_U]: ... +class Expected(_base_Expected[_T]): + __slots__ = () + def __new__( + self, + result: _t.Optional[_T] = None, + error: _t.Optional[Exception] = None + ) -> Expected[_T]: ... +_coconut_Expected = Expected + + +# should match call above but with Expected +@_t.overload +def safe_call( + _func: _t.Callable[[_T], _Uco], + _x: _T, +) -> Expected[_Uco]: ... +@_t.overload +def safe_call( + _func: _t.Callable[[_T, _U], _Vco], + _x: _T, + _y: _U, +) -> Expected[_Vco]: ... +@_t.overload +def safe_call( + _func: _t.Callable[[_T, _U, _V], _Wco], + _x: _T, + _y: _U, + _z: _V, +) -> Expected[_Wco]: ... +@_t.overload +def safe_call( + _func: _t.Callable[_t.Concatenate[_T, _P], _Uco], + _x: _T, + *args: _t.Any, + **kwargs: _t.Any, +) -> Expected[_Uco]: ... +@_t.overload +def safe_call( + _func: _t.Callable[_t.Concatenate[_T, _U, _P], _Vco], + _x: _T, + _y: _U, + *args: _t.Any, + **kwargs: _t.Any, +) -> Expected[_Vco]: ... +@_t.overload +def safe_call( + _func: _t.Callable[_t.Concatenate[_T, _U, _V, _P], _Wco], + _x: _T, + _y: _U, + _z: _V, + *args: _t.Any, + **kwargs: _t.Any, +) -> Expected[_Wco]: ... +@_t.overload +def safe_call( + _func: _t.Callable[..., _Tco], + *args: _t.Any, + **kwargs: _t.Any, +) -> Expected[_Tco]: ... + + def recursive_iterator(func: _T_iter_func) -> _T_iter_func: return func @@ -248,7 +315,7 @@ def _coconut_call_set_names(cls: object) -> None: ... class _coconut_base_pattern_func: def __init__(self, *funcs: _Callable) -> None: ... - def add(self, func: _Callable) -> None: ... + def add_pattern(self, func: _Callable) -> None: ... def __call__(self, *args: _t.Any, **kwargs: _t.Any) -> _t.Any: ... @_t.overload @@ -274,6 +341,7 @@ def _coconut_mark_as_match(func: _Tfunc) -> _Tfunc: class _coconut_partial(_t.Generic[_T]): args: _Tuple = ... + required_nargs: int = ... keywords: _t.Dict[_t.Text, _t.Any] = ... def __init__( self, @@ -564,6 +632,12 @@ def consume( ) -> _t.Sequence[_T]: ... +class _FMappable(_t.Protocol[_Tfunc_contra, _Tco]): + def __fmap__(self, func: _Tfunc_contra) -> _Tco: ... + + +@_t.overload +def fmap(func: _Tfunc, obj: _FMappable[_Tfunc, _Tco]) -> _Tco: ... @_t.overload def fmap(func: _t.Callable[[_Tco], _Tco], obj: _Titer) -> _Titer: ... @_t.overload diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 3c7cd29ac..fdd6fd5fd 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -22,13 +22,14 @@ import types as _types import itertools as _itertools import operator as _operator import threading as _threading -import weakref as _weakref import os as _os import warnings as _warnings import contextlib as _contextlib import traceback as _traceback -import pickle as _pickle +import weakref as _weakref import multiprocessing as _multiprocessing +import math as _math +import pickle as _pickle from multiprocessing import dummy as _multiprocessing_dummy if sys.version_info >= (3,): @@ -88,29 +89,38 @@ typing = _t collections = _collections copy = _copy -copyreg = _copyreg functools = _functools types = _types itertools = _itertools operator = _operator threading = _threading -weakref = _weakref os = _os warnings = _warnings contextlib = _contextlib traceback = _traceback -pickle = _pickle -asyncio = _asyncio -abc = _abc +weakref = _weakref multiprocessing = _multiprocessing +math = _math multiprocessing_dummy = _multiprocessing_dummy -numpy = _numpy -npt = _npt # Fake, like typing + +copyreg = _copyreg +asyncio = _asyncio +pickle = _pickle if sys.version_info >= (2, 7): OrderedDict = collections.OrderedDict else: OrderedDict = dict +abc = _abc +abc.Sequence.register(collections.deque) +numpy = _numpy +npt = _npt # Fake, like typing zip_longest = _zip_longest + +numpy_modules: _t.Any = ... +jax_numpy_modules: _t.Any = ... +tee_type: _t.Any = ... +reiterables: _t.Any = ... + Ellipsis = Ellipsis NotImplemented = NotImplemented NotImplementedError = NotImplementedError diff --git a/coconut/__coconut__.pyi b/coconut/__coconut__.pyi index 07e89e519..838ec0d57 100644 --- a/coconut/__coconut__.pyi +++ b/coconut/__coconut__.pyi @@ -1,2 +1,2 @@ from __coconut__ import * -from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_super, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten +from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_super, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 8f75df637..74d9517de 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -2706,7 +2706,11 @@ def make_namedtuple_call(self, name, namedtuple_args, types=None): return '_coconut.collections.namedtuple("' + name + '", ' + tuple_str_of(namedtuple_args, add_quotes=True) + ')' def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts, stmts, match_args, paramdefs=()): - """Create a data class definition from the given components.""" + """Create a data class definition from the given components. + + IMPORTANT: Any changes to assemble_data must be reflected in the + definition of Expected in header.py_template. + """ # create class out = ( "".join(paramdefs) diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index bf11dd365..00e2e1306 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -199,6 +199,7 @@ def process_header_args(which, target, use_hash, no_tco, strict, no_wrap): VERSION_STR=VERSION_STR, module_docstring='"""Built-in Coconut utilities."""\n\n' if which == "__coconut__" else "", object="" if target_startswith == "3" else "(object)", + comma_object="" if target_startswith == "3" else ", object", report_this_text=report_this_text, numpy_modules=tuple_str_of(numpy_modules, add_quotes=True), jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True), @@ -443,7 +444,7 @@ async def __anext__(self): # second round for format dict elements that use the format dict extra_format_dict = dict( # when anything is added to this list it must also be added to *both* __coconut__ stub files - underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_super, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten".format(**format_dict), + underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_super, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten".format(**format_dict), import_typing=pycondition( (3, 5), if_ge="import typing", diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index bce1ea763..33f2b6eb2 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -31,6 +31,8 @@ def _coconut_super(type=None, object_or_type=None): abc.Sequence.register(numpy.ndarray) numpy_modules = {numpy_modules} jax_numpy_modules = {jax_numpy_modules} + tee_type = type(itertools.tee((), 1)[0]) + reiterables = abc.Sequence, abc.Mapping, abc.Set abc.Sequence.register(collections.deque) Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, callable, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, callable, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} class _coconut_sentinel{object}: @@ -109,6 +111,80 @@ def _coconut_tco(func): tail_call_optimized_func.__qualname__ = _coconut.getattr(func, "__qualname__", None) _coconut_tco_func_dict[_coconut.id(tail_call_optimized_func)] = _coconut.weakref.ref(tail_call_optimized_func) return tail_call_optimized_func +@_coconut.functools.wraps(_coconut.itertools.tee) +def tee(iterable, n=2): + if n < 0: + raise ValueError("n must be >= 0") + elif n == 0: + return () + elif n == 1: + return (iterable,) + elif _coconut.isinstance(iterable, _coconut.reiterables): + return (iterable,) * n + else: + if _coconut.getattr(iterable, "__getitem__", None) is not None or _coconut.isinstance(iterable, (_coconut.tee_type, _coconut.abc.Sized, _coconut.abc.Container)): + existing_copies = [iterable] + while _coconut.len(existing_copies) < n: + try: + copy = _coconut.copy.copy(iterable) + except _coconut.TypeError: + break + else: + existing_copies.append(copy) + else:{COMMENT.no_break} + return _coconut.tuple(existing_copies) + return _coconut.itertools.tee(iterable, n) +class _coconut_has_iter(_coconut_base_hashable): + __slots__ = ("lock", "iter") + def __new__(cls, iterable): + self = _coconut.object.__new__(cls) + self.lock = _coconut.threading.Lock() + self.iter = iterable + return self + def get_new_iter(self): + """Tee the underlying iterator.""" + with self.lock: + self.iter = _coconut_reiterable(self.iter) + return self.iter +class reiterable(_coconut_has_iter): + """Allow an iterator to be iterated over multiple times with the same results.""" + __slots__ = () + def __new__(cls, iterable): + if _coconut.isinstance(iterable, _coconut.reiterables): + return iterable + return _coconut_has_iter.__new__(cls, iterable) + def get_new_iter(self): + """Tee the underlying iterator.""" + with self.lock: + self.iter, new_iter = _coconut_tee(self.iter) + return new_iter + def __iter__(self): + return _coconut.iter(self.get_new_iter()) + def __repr__(self): + return "reiterable(%s)" % (_coconut.repr(self.get_new_iter()),) + def __reduce__(self): + return (self.__class__, (self.iter,)) + def __copy__(self): + return self.__class__(self.get_new_iter()) + def __fmap__(self, func): + return _coconut_map(func, self) + def __getitem__(self, index): + return _coconut_iter_getitem(self.get_new_iter(), index) + def __reversed__(self): + return _coconut_reversed(self.get_new_iter()) + def __len__(self): + if not _coconut.isinstance(self.iter, _coconut.abc.Sized): + return _coconut.NotImplemented + return _coconut.len(self.get_new_iter()) + def __contains__(self, elem): + return elem in self.get_new_iter() + def count(self, elem): + """Count the number of times elem appears in the iterable.""" + return self.get_new_iter().count(elem) + def index(self, elem): + """Find the index of elem in the iterable.""" + return self.get_new_iter().index(elem) +_coconut.reiterables = (reiterable,) + _coconut.reiterables def _coconut_iter_getitem_special_case(iterable, start, stop, step): iterable = _coconut.itertools.islice(iterable, start, None) cache = _coconut.collections.deque(_coconut.itertools.islice(iterable, -stop), maxlen=-stop) @@ -141,7 +217,7 @@ def _coconut_iter_getitem(iterable, index): return _coconut.collections.deque(iterable, maxlen=-index)[0] result = _coconut.next(_coconut.itertools.islice(iterable, index, index + 1), _coconut_sentinel) if result is _coconut_sentinel: - raise _coconut.IndexError("$[] index out of range") + raise _coconut.IndexError(".$[] index out of range") return result start = _coconut.operator.index(index.start) if index.start is not None else None stop = _coconut.operator.index(index.stop) if index.stop is not None else None @@ -327,72 +403,21 @@ def _coconut_comma_op(*args): """Comma operator (,). Equivalent to (*args) -> args.""" return args {def_coconut_matmul} -@_coconut.functools.wraps(_coconut.itertools.tee) -def tee(iterable, n=2): - if n < 0: - raise ValueError("n must be >= 0") - elif n == 0: - return () - elif n == 1: - return (iterable,) - elif _coconut.isinstance(iterable, (_coconut.tuple, _coconut.frozenset)): - return (iterable,) * n - else: - if _coconut.getattr(iterable, "__getitem__", None) is not None: - try: - copy = _coconut.copy.copy(iterable) - except _coconut.TypeError: - pass - else: - return (iterable, copy) + _coconut.tuple(_coconut.copy.copy(iterable) for _ in _coconut.range(2, n)) - return _coconut.itertools.tee(iterable, n) -class reiterable(_coconut_base_hashable): - """Allow an iterator to be iterated over multiple times with the same results.""" - __slots__ = ("lock", "iter") - def __new__(cls, iterable): - if _coconut.isinstance(iterable, _coconut_reiterable): - return iterable - self = _coconut.object.__new__(cls) - self.lock = _coconut.threading.Lock() - self.iter = iterable - return self - def get_new_iter(self): - with self.lock: - self.iter, new_iter = _coconut_tee(self.iter) - return new_iter - def __iter__(self): - return _coconut.iter(self.get_new_iter()) - def __getitem__(self, index): - return _coconut_iter_getitem(self.get_new_iter(), index) - def __reversed__(self): - return _coconut_reversed(self.get_new_iter()) - def __len__(self): - if not _coconut.isinstance(self.iter, _coconut.abc.Sized): - return _coconut.NotImplemented - return _coconut.len(self.iter) - def __repr__(self): - return "reiterable(%s)" % (_coconut.repr(self.iter),) - def __reduce__(self): - return (self.__class__, (self.iter,)) - def __copy__(self): - return self.__class__(self.get_new_iter()) - def __fmap__(self, func): - return _coconut_map(func, self) -class scan(_coconut_base_hashable): +class scan(_coconut_has_iter): """Reduce func over iterable, yielding intermediate results, optionally starting from initial.""" - __slots__ = ("func", "iter", "initial") - def __init__(self, function, iterable, initial=_coconut_sentinel): + __slots__ = ("func", "initial") + def __new__(cls, function, iterable, initial=_coconut_sentinel): + self = _coconut_has_iter.__new__(cls, iterable) self.func = function - self.iter = iterable self.initial = initial + return self def __repr__(self): return "scan(%r, %s%s)" % (self.func, _coconut.repr(self.iter), "" if self.initial is _coconut_sentinel else ", " + _coconut.repr(self.initial)) def __reduce__(self): return (self.__class__, (self.func, self.iter, self.initial)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(self.func, new_iter, self.initial) + return self.__class__(self.func, self.get_new_iter(), self.initial) def __iter__(self): acc = self.initial if acc is not _coconut_sentinel: @@ -409,15 +434,14 @@ class scan(_coconut_base_hashable): return _coconut.len(self.iter) def __fmap__(self, func): return _coconut_map(func, self) -class reversed(_coconut_base_hashable): - __slots__ = ("iter",) +class reversed(_coconut_has_iter): + __slots__ = () __doc__ = getattr(_coconut.reversed, "__doc__", "") def __new__(cls, iterable): if _coconut.isinstance(iterable, _coconut.range): return iterable[::-1] if _coconut.getattr(iterable, "__reversed__", None) is None or _coconut.isinstance(iterable, (_coconut.list, _coconut.tuple)): - self = _coconut.object.__new__(cls) - self.iter = iterable + self = _coconut_has_iter.__new__(cls, iterable) return self return _coconut.reversed(iterable) def __repr__(self): @@ -425,8 +449,7 @@ class reversed(_coconut_base_hashable): def __reduce__(self): return (self.__class__, (self.iter,)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(new_iter) + return self.__class__(self.get_new_iter()) def __iter__(self): return _coconut.iter(_coconut.reversed(self.iter)) def __getitem__(self, index): @@ -449,53 +472,50 @@ class reversed(_coconut_base_hashable): return _coconut.len(self.iter) - self.iter.index(elem) - 1 def __fmap__(self, func): return self.__class__(_coconut_map(func, self.iter)) -class flatten(_coconut_base_hashable): +class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_become_very_innefficient} """Flatten an iterable of iterables into a single iterable. Flattens the first axis of numpy arrays.""" - __slots__ = ("iter",) + __slots__ = () def __new__(cls, iterable): if iterable.__class__.__module__ in _coconut.numpy_modules: if len(iterable.shape) < 2: raise _coconut.TypeError("flatten() on numpy arrays requires two or more dimensions") return iterable.reshape(-1, *iterable.shape[2:]) - self = _coconut.object.__new__(cls) - self.iter = iterable + self = _coconut_has_iter.__new__(cls, iterable) return self + def get_new_iter(self): + """Tee the underlying iterator.""" + with self.lock: + if not (_coconut.isinstance(self.iter, _coconut_reiterable) and _coconut.isinstance(self.iter.iter, _coconut_map) and self.iter.iter.func is _coconut_reiterable): + self.iter = _coconut_map(_coconut_reiterable, self.iter) + self.iter = _coconut_reiterable(self.iter) + return self.iter def __iter__(self): return _coconut.itertools.chain.from_iterable(self.iter) def __reversed__(self): - return self.__class__(_coconut_reversed(_coconut_map(_coconut_reversed, self.iter))) + return self.__class__(_coconut_reversed(_coconut_map(_coconut_reversed, self.get_new_iter()))) def __repr__(self): return "flatten(%s)" % (_coconut.repr(self.iter),) def __reduce__(self): return (self.__class__, (self.iter,)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(new_iter) + return self.__class__(self.get_new_iter()) def __contains__(self, elem): - self.iter, new_iter = _coconut_tee(self.iter) - return _coconut.any(elem in it for it in new_iter) - def __len__(self): - if not _coconut.isinstance(self.iter, _coconut.abc.Sized): - return _coconut.NotImplemented - self.iter, new_iter = _coconut_tee(self.iter) - return _coconut.sum(_coconut.len(it) for it in new_iter) + return _coconut.any(elem in it for it in self.get_new_iter()) def count(self, elem): """Count the number of times elem appears in the flattened iterable.""" - self.iter, new_iter = _coconut_tee(self.iter) - return _coconut.sum(it.count(elem) for it in new_iter) + return _coconut.sum(it.count(elem) for it in self.get_new_iter()) def index(self, elem): """Find the index of elem in the flattened iterable.""" - self.iter, new_iter = _coconut_tee(self.iter) ind = 0 - for it in new_iter: + for it in self.get_new_iter(): try: return ind + it.index(elem) except _coconut.ValueError: ind += _coconut.len(it) raise ValueError("%r not in %r" % (elem, self)) def __fmap__(self, func): - return self.__class__(_coconut_map(_coconut.functools.partial(_coconut_map, func), self.iter)) + return self.__class__(_coconut_map(_coconut.functools.partial(_coconut_map, func), self.get_new_iter())) class cartesian_product(_coconut_base_hashable): __slots__ = ("iters", "repeat") __doc__ = getattr(_coconut.itertools.product, "__doc__", "Cartesian product of input iterables.") + """ @@ -528,14 +548,8 @@ Additionally supports Cartesian products of numpy arrays.""" def __reduce__(self): return (self.__class__, self.iters, {lbrace}"repeat": self.repeat{rbrace}) def __copy__(self): - old_iters = [] - new_iters = [] - for it in self.iters: - old_it, new_it = _coconut_tee(it) - old_iters.append(old_it) - new_iters.append(new_it) - self.iters = old_iters - return self.__class__(*new_iters, repeat=self.repeat) + self.iters = _coconut.tuple(_coconut_reiterable(it) for it in self.iters) + return self.__class__(*self.iters, repeat=self.repeat) @property def all_iters(self): return _coconut.itertools.chain.from_iterable(_coconut.itertools.repeat(self.iters, self.repeat)) @@ -567,10 +581,10 @@ class map(_coconut_base_hashable, _coconut.map): __slots__ = ("func", "iters") __doc__ = getattr(_coconut.map, "__doc__", "") def __new__(cls, function, *iterables): - new_map = _coconut.map.__new__(cls, function, *iterables) - new_map.func = function - new_map.iters = iterables - return new_map + self = _coconut.map.__new__(cls, function, *iterables) + self.func = function + self.iters = iterables + return self def __getitem__(self, index): if _coconut.isinstance(index, _coconut.slice): return self.__class__(self.func, *(_coconut_iter_getitem(it, index) for it in self.iters)) @@ -586,14 +600,8 @@ class map(_coconut_base_hashable, _coconut.map): def __reduce__(self): return (self.__class__, (self.func,) + self.iters) def __copy__(self): - old_iters = [] - new_iters = [] - for it in self.iters: - old_it, new_it = _coconut_tee(it) - old_iters.append(old_it) - new_iters.append(new_it) - self.iters = old_iters - return self.__class__(self.func, *new_iters) + self.iters = _coconut.tuple(_coconut_reiterable(it) for it in self.iters) + return self.__class__(self.func, *self.iters) def __iter__(self): return _coconut.iter(_coconut.map(self.func, *self.iters)) def __fmap__(self, func): @@ -689,10 +697,10 @@ class filter(_coconut_base_hashable, _coconut.filter): __slots__ = ("func", "iter") __doc__ = getattr(_coconut.filter, "__doc__", "") def __new__(cls, function, iterable): - new_filter = _coconut.filter.__new__(cls, function, iterable) - new_filter.func = function - new_filter.iter = iterable - return new_filter + self = _coconut.filter.__new__(cls, function, iterable) + self.func = function + self.iter = iterable + return self def __reversed__(self): return self.__class__(self.func, _coconut_reversed(self.iter)) def __repr__(self): @@ -700,8 +708,8 @@ class filter(_coconut_base_hashable, _coconut.filter): def __reduce__(self): return (self.__class__, (self.func, self.iter)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(self.func, new_iter) + self.iter = _coconut_reiterable(self.iter) + return self.__class__(self.func, self.iter) def __iter__(self): return _coconut.iter(_coconut.filter(self.func, self.iter)) def __fmap__(self, func): @@ -710,12 +718,12 @@ class zip(_coconut_base_hashable, _coconut.zip): __slots__ = ("iters", "strict") __doc__ = getattr(_coconut.zip, "__doc__", "") def __new__(cls, *iterables, **kwargs): - new_zip = _coconut.zip.__new__(cls, *iterables) - new_zip.iters = iterables - new_zip.strict = kwargs.pop("strict", False) + self = _coconut.zip.__new__(cls, *iterables) + self.iters = iterables + self.strict = kwargs.pop("strict", False) if kwargs: raise _coconut.TypeError("zip() got unexpected keyword arguments " + _coconut.repr(kwargs)) - return new_zip + return self def __getitem__(self, index): if _coconut.isinstance(index, _coconut.slice): return self.__class__(*(_coconut_iter_getitem(i, index) for i in self.iters), strict=self.strict) @@ -731,14 +739,8 @@ class zip(_coconut_base_hashable, _coconut.zip): def __reduce__(self): return (self.__class__, self.iters, {lbrace}"strict": self.strict{rbrace}) def __copy__(self): - old_iters = [] - new_iters = [] - for it in self.iters: - old_it, new_it = _coconut_tee(it) - old_iters.append(old_it) - new_iters.append(new_it) - self.iters = old_iters - return self.__class__(*new_iters, strict=self.strict) + self.iters = _coconut.tuple(_coconut_reiterable(it) for it in self.iters) + return self.__class__(*self.iters, strict=self.strict) def __iter__(self): {zip_iter} def __fmap__(self, func): @@ -788,24 +790,18 @@ class zip_longest(zip): def __reduce__(self): return (self.__class__, self.iters, {lbrace}"fillvalue": self.fillvalue{rbrace}) def __copy__(self): - old_iters = [] - new_iters = [] - for it in self.iters: - old_it, new_it = _coconut_tee(it) - old_iters.append(old_it) - new_iters.append(new_it) - self.iters = old_iters - return self.__class__(*new_iters, fillvalue=self.fillvalue) + self.iters = _coconut.tuple(_coconut_reiterable(it) for it in self.iters) + return self.__class__(*self.iters, fillvalue=self.fillvalue) def __iter__(self): return _coconut.iter(_coconut.zip_longest(*self.iters, fillvalue=self.fillvalue)) class enumerate(_coconut_base_hashable, _coconut.enumerate): __slots__ = ("iter", "start") __doc__ = getattr(_coconut.enumerate, "__doc__", "") def __new__(cls, iterable, start=0): - new_enumerate = _coconut.enumerate.__new__(cls, iterable, start) - new_enumerate.iter = iterable - new_enumerate.start = start - return new_enumerate + self = _coconut.enumerate.__new__(cls, iterable, start) + self.iter = iterable + self.start = start + return self def __repr__(self): return "enumerate(%s, %r)" % (_coconut.repr(self.iter), self.start) def __fmap__(self, func): @@ -813,8 +809,8 @@ class enumerate(_coconut_base_hashable, _coconut.enumerate): def __reduce__(self): return (self.__class__, (self.iter, self.start)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(new_iter, self.start) + self.iter = _coconut_reiterable(self.iter) + return self.__class__(self.iter, self.start) def __iter__(self): return _coconut.iter(_coconut.enumerate(self.iter, self.start)) def __getitem__(self, index): @@ -825,7 +821,7 @@ class enumerate(_coconut_base_hashable, _coconut.enumerate): if not _coconut.isinstance(self.iter, _coconut.abc.Sized): return _coconut.NotImplemented return _coconut.len(self.iter) -class multi_enumerate(_coconut_base_hashable): +class multi_enumerate(_coconut_has_iter): """Enumerate an iterable of iterables. Works like enumerate, but indexes through inner iterables and produces a tuple index representing the index in each inner iterable. Supports indexing. @@ -837,9 +833,7 @@ class multi_enumerate(_coconut_base_hashable): Also supports len for numpy arrays. """ - __slots__ = ("iter",) - def __init__(self, iterable): - self.iter = iterable + __slots__ = () def __repr__(self): return "multi_enumerate(%s)" % (_coconut.repr(self.iter),) def __fmap__(self, func): @@ -847,8 +841,7 @@ class multi_enumerate(_coconut_base_hashable): def __reduce__(self): return (self.__class__, (self.iter,)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(new_iter) + return self.__class__(self.get_new_iter()) @property def is_numpy(self): return self.iter.__class__.__module__ in _coconut.numpy_modules @@ -943,17 +936,18 @@ class count(_coconut_base_hashable): return (self.__class__, (self.start, self.step)) def __fmap__(self, func): return _coconut_map(func, self) -class groupsof(_coconut_base_hashable): +class groupsof(_coconut_has_iter): """groupsof(n, iterable) splits iterable into groups of size n. If the length of the iterable is not divisible by n, the last group will be of size < n. """ - __slots__ = ("group_size", "iter") - def __init__(self, n, iterable): + __slots__ = ("group_size",) + def __new__(cls, n, iterable): + self = _coconut_has_iter.__new__(cls, iterable) self.group_size = _coconut.operator.index(n) if self.group_size <= 0: raise _coconut.ValueError("group size must be > 0; not %r" % (self.group_size,)) - self.iter = iterable + return self def __iter__(self): iterator = _coconut.iter(self.iter) loop = True @@ -976,17 +970,16 @@ class groupsof(_coconut_base_hashable): def __reduce__(self): return (self.__class__, (self.group_size, self.iter)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(self.group_size, new_iter) + return self.__class__(self.group_size, self.get_new_iter()) def __fmap__(self, func): return _coconut_map(func, self) class recursive_iterator(_coconut_base_hashable): """Decorator that optimizes a recursive function that returns an iterator (e.g. a recursive generator).""" - __slots__ = ("func", "tee_store", "backup_tee_store") + __slots__ = ("func", "reit_store", "backup_reit_store") def __init__(self, func): self.func = func - self.tee_store = {empty_dict} - self.backup_tee_store = [] + self.reit_store = {empty_dict} + self.backup_reit_store = [] def __call__(self, *args, **kwargs): key = (args, _coconut.frozenset(kwargs.items())) use_backup = False @@ -998,24 +991,18 @@ class recursive_iterator(_coconut_base_hashable): except _coconut.Exception: use_backup = True if use_backup: - for i, (k, v) in _coconut.enumerate(self.backup_tee_store): + for k, v in self.backup_reit_store: if k == key: - to_tee, store_pos = v, i - break - else:{COMMENT.no_break} - to_tee = self.func(*args, **kwargs) - store_pos = None - to_store, to_return = _coconut_tee(to_tee) - if store_pos is None: - self.backup_tee_store.append([key, to_store]) - else: - self.backup_tee_store[store_pos][1] = to_store + return reit + reit = _coconut_reiterable(self.func(*args, **kwargs)) + self.backup_reit_store.append([key, reit]) + return reit else: - it = self.tee_store.get(key) - if it is None: - it = self.func(*args, **kwargs) - self.tee_store[key], to_return = _coconut_tee(it) - return to_return + reit = self.reit_store.get(key) + if reit is None: + reit = _coconut_reiterable(self.func(*args, **kwargs)) + self.reit_store[key] = reit + return reit def __repr__(self): return "recursive_iterator(%r)" % (self.func,) def __reduce__(self): @@ -1174,10 +1161,10 @@ class starmap(_coconut_base_hashable, _coconut.itertools.starmap): __slots__ = ("func", "iter") __doc__ = getattr(_coconut.itertools.starmap, "__doc__", "starmap(func, iterable) = (func(*args) for args in iterable)") def __new__(cls, function, iterable): - new_map = _coconut.itertools.starmap.__new__(cls, function, iterable) - new_map.func = function - new_map.iter = iterable - return new_map + self = _coconut.itertools.starmap.__new__(cls, function, iterable) + self.func = function + self.iter = iterable + return self def __getitem__(self, index): if _coconut.isinstance(index, _coconut.slice): return self.__class__(self.func, _coconut_iter_getitem(self.iter, index)) @@ -1193,8 +1180,8 @@ class starmap(_coconut_base_hashable, _coconut.itertools.starmap): def __reduce__(self): return (self.__class__, (self.func, self.iter)) def __copy__(self): - self.iter, new_iter = _coconut_tee(self.iter) - return self.__class__(self.func, new_iter) + self.iter = _coconut_reiterable(self.iter) + return self.__class__(self.func, self.iter) def __iter__(self): return _coconut.iter(_coconut.itertools.starmap(self.func, self.iter)) def __fmap__(self, func): @@ -1338,6 +1325,42 @@ def call(_coconut_f, *args, **kwargs): """ return _coconut_f(*args, **kwargs) {of_is_call} +def safe_call(_coconut_f, *args, **kwargs): + """safe_call is a version of call that catches any Exceptions and + returns an Expected containing either the result or the error. + + Equivalent to: + def safe_call(f, /, *args, **kwargs): + try: + return Expected(f(*args, **kwargs)) + except Exception as err: + return Expected(error=err) + """ + try: + return _coconut_Expected(_coconut_f(*args, **kwargs)) + except _coconut.Exception as err: + return _coconut_Expected(error=err) +class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){comma_object}): + """TODO""" + _coconut_is_data = True + __slots__ = () + def __add__(self, other): return _coconut.NotImplemented + def __mul__(self, other): return _coconut.NotImplemented + def __rmul__(self, other): return _coconut.NotImplemented + __ne__ = _coconut.object.__ne__ + def __eq__(self, other): + return self.__class__ is other.__class__ and _coconut.tuple.__eq__(self, other) + def __hash__(self): + return _coconut.tuple.__hash__(self) ^ hash(self.__class__) + __match_args__ = ('result', 'error') + def __new__(cls, result=None, error=None): + if result is not None and error is not None: + raise _coconut.ValueError("Expected cannot have both a result and an error") + return _coconut.tuple.__new__(cls, (result, error)) + def __fmap__(self, func): + return self if self.error is not None else self.__class__(func(self.result)) + def __bool__(self): + return self.error is None class flip(_coconut_base_hashable): """Given a function, return a new function with inverse argument order. If nargs is passed, only the first nargs arguments are reversed.""" @@ -1491,4 +1514,4 @@ def _coconut_multi_dim_arr(arrs, dim): max_arr_dim = _coconut.max(arr_dims) return _coconut_concatenate(arrs, max_arr_dim - dim) _coconut_self_match_types = {self_match_types} -_coconut_MatchError, _coconut_count, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_map, _coconut_reiterable, _coconut_reversed, _coconut_starmap, _coconut_tee, _coconut_zip, TYPE_CHECKING, reduce, takewhile, dropwhile = MatchError, count, enumerate, flatten, filter, map, reiterable, reversed, starmap, tee, zip, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile +_coconut_Expected, _coconut_MatchError, _coconut_count, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_map, _coconut_reiterable, _coconut_reversed, _coconut_starmap, _coconut_tee, _coconut_zip, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, count, enumerate, flatten, filter, map, reiterable, reversed, starmap, tee, zip, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile diff --git a/coconut/constants.py b/coconut/constants.py index b6a5a36e6..67dcf31fc 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -590,9 +590,10 @@ def get_bool_env_var(env_var, default=False): ) coconut_specific_builtins = ( + "TYPE_CHECKING", + "Expected", "breakpoint", "help", - "TYPE_CHECKING", "reduce", "takewhile", "dropwhile", @@ -615,6 +616,7 @@ def get_bool_env_var(env_var, default=False): "flatten", "ident", "call", + "safe_call", "flip", "const", "lift", @@ -645,17 +647,17 @@ def get_bool_env_var(env_var, default=False): "_namedtuple_of", ) -all_builtins = frozenset(python_builtins + coconut_specific_builtins) +coconut_exceptions = ( + "MatchError", +) + +all_builtins = frozenset(python_builtins + coconut_specific_builtins + coconut_exceptions) magic_methods = ( "__fmap__", "__iter_getitem__", ) -exceptions = ( - "MatchError", -) - new_operators = ( r"@", r"\$", @@ -1000,7 +1002,7 @@ def get_bool_env_var(env_var, default=False): "islice", ) + ( coconut_specific_builtins - + exceptions + + coconut_exceptions + magic_methods + reserved_vars ) diff --git a/coconut/highlighter.py b/coconut/highlighter.py index 16b04c500..aef74f588 100644 --- a/coconut/highlighter.py +++ b/coconut/highlighter.py @@ -35,7 +35,7 @@ shebang_regex, magic_methods, template_ext, - exceptions, + coconut_exceptions, main_prompt, ) @@ -95,7 +95,7 @@ class CoconutLexer(Python3Lexer): ] tokens["builtins"] += [ (words(coconut_specific_builtins + interp_only_builtins, suffix=r"\b"), Name.Builtin), - (words(exceptions, suffix=r"\b"), Name.Exception), + (words(coconut_exceptions, suffix=r"\b"), Name.Exception), ] tokens["numbers"] = [ (r"0b[01_]+", Number.Integer), diff --git a/coconut/root.py b/coconut/root.py index f275e61d2..35ef56253 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -26,7 +26,7 @@ VERSION = "2.1.1" VERSION_NAME = "The Spanish Inquisition" # False for release, int >= 1 for develop -DEVELOP = 18 +DEVELOP = 19 ALPHA = False # for pre releases rather than post releases # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/tests/src/cocotest/agnostic/main.coco b/coconut/tests/src/cocotest/agnostic/main.coco index c0ee5dc76..165724df8 100644 --- a/coconut/tests/src/cocotest/agnostic/main.coco +++ b/coconut/tests/src/cocotest/agnostic/main.coco @@ -1265,6 +1265,25 @@ def main_test() -> bool: ys = (_ for _ in range(2)) :: (_ for _ in range(2)) assert ys |> list == [0, 1, 0, 1] assert ys |> list == [] + assert Expected(10) |> fmap$(.+1) == Expected(11) + some_err = ValueError() + assert Expected(error=some_err) |> fmap$(.+1) == Expected(error=some_err) + res, err = Expected(10) + assert (res, err) == (10, None) + assert Expected("abc") + assert not Expected(error=TypeError()) + assert all(it `isinstance` flatten for it in tee(flatten([[1], [2]]))) + fl12 = flatten([[1], [2]]) + assert fl12.get_new_iter() is fl12.get_new_iter() # type: ignore + res, err = safe_call(-> 1 / 0) |> fmap$(.+1) + assert res is None + assert err `isinstance` ZeroDivisionError + recit = ([1,2,3] :: recit) |> map$(.+1) + assert tee(recit) + rawit = (_ for _ in (0, 1)) + t1, t2 = tee(rawit) + t1a, t1b = tee(t1) + assert (list(t1a), list(t1b), list(t2)) == ([0, 1], [0, 1], [0, 1]) return True def test_asyncio() -> bool: diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 676831f2d..f76db8cdf 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -275,6 +275,7 @@ def suite_test() -> bool: assert vector(1, 2) |> (==)$(vector(1, 2)) assert vector(1, 2) |> .__eq__(other=vector(1, 2)) # type: ignore assert fibs()$[1:4] |> tuple == (1, 2, 3) == fibs_()$[1:4] |> tuple + assert fibs()$[:10] |> list == [1,1,2,3,5,8,13,21,34,55] == fibs_()$[:10] |> list assert fibs() |> takewhile$((i) -> i < 4000000 ) |> filter$((i) -> i % 2 == 0 ) |> sum == 4613732 == fibs_() |> takewhile$((i) -> i < 4000000 ) |> filter$((i) -> i % 2 == 0 ) |> sum # type: ignore assert loop([1,2])$[:4] |> list == [1, 2] * 2 assert parallel_map(list .. .$[:2] .. loop, ([1], [2]))$[:2] |> tuple == ([1, 1], [2, 2]) @@ -1001,6 +1002,8 @@ forward 2""") == 900 """) == 7 + 8 + 9 assert split_in_half("123456789") |> list == [("1","2","3","4","5"), ("6","7","8","9")] assert arr_of_prod([5,2,1,4,3]) |> list == [24,60,120,30,40] + assert safe_call(raise_exc).error `isinstance` Exception + assert safe_call((.+1), 5).result == 6 # must come at end assert fibs_calls[0] == 1