diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f6561c20..c5784b994 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: args: - --ignore=W503,E501,E265,E402,F405,E305,E126 - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v2.0.1 + rev: v2.0.2 hooks: - id: autopep8 args: diff --git a/DOCS.md b/DOCS.md index 0a4e4d118..ebacc5ea4 100644 --- a/DOCS.md +++ b/DOCS.md @@ -85,10 +85,9 @@ pip install coconut[opt_dep_1,opt_dep_2] The full list of optional dependencies is: -- `all`: alias for `jupyter,watch,jobs,mypy,backports,xonsh` (this is the recommended way to install a feature-complete version of Coconut). +- `all`: alias for `jupyter,watch,mypy,backports,xonsh` (this is the recommended way to install a feature-complete version of Coconut). - `jupyter`/`ipython`: enables use of the `--jupyter` / `--ipython` flag. - `watch`: enables use of the `--watch` flag. -- `jobs`: improves use of the `--jobs` flag. - `mypy`: enables use of the `--mypy` flag. - `backports`: installs libraries that backport newer Python features to older versions, which Coconut will automatically use instead of the standard library if the standard library is not available. Specifically: - Installs [`typing`](https://pypi.org/project/typing/) and [`typing_extensions`](https://pypi.org/project/typing-extensions/) to backport [`typing`](https://docs.python.org/3/library/typing.html). @@ -122,10 +121,11 @@ depth: 1 ``` coconut [-h] [--and source [dest ...]] [-v] [-t version] [-i] [-p] [-a] [-l] [-k] [-w] - [-r] [-n] [-d] [-q] [-s] [--no-tco] [--no-wrap] [-c code] [-j processes] [-f] - [--minify] [--jupyter ...] [--mypy ...] [--argv ...] [--tutorial] [--docs] - [--style name] [--history-file path] [--vi-mode] [--recursion-limit limit] - [--site-install] [--site-uninstall] [--verbose] [--trace] [--profile] + [-r] [-n] [-d] [-q] [-s] [--no-tco] [--no-wrap-types] [-c code] [-j processes] + [-f] [--minify] [--jupyter ...] [--mypy ...] 
[--argv ...] [--tutorial] + [--docs] [--style name] [--history-file path] [--vi-mode] + [--recursion-limit limit] [--stack-size kbs] [--site-install] + [--site-uninstall] [--verbose] [--trace] [--profile] [source] [dest] ``` @@ -140,7 +140,6 @@ dest destination directory for compiled files (defaults to ##### Optional Arguments ``` -optional arguments: -h, --help show this help message and exit --and source [dest ...] add an additional source/dest pair to compile @@ -167,12 +166,13 @@ optional arguments: runnable code to stdout) -s, --strict enforce code cleanliness standards --no-tco, --notco disable tail call optimization - --no-wrap, --nowrap disable wrapping type annotations in strings and turn off 'from + --no-wrap-types, --nowraptypes + disable wrapping type annotations in strings and turn off 'from __future__ import annotations' behavior -c code, --code code run Coconut passed in as a string (can also be piped into stdin) -j processes, --jobs processes - number of additional processes to use (defaults to 0) (pass 'sys' to - use machine default) + number of additional processes to use (defaults to 'sys') (0 is no + additional processes; 'sys' uses machine default) -f, --force force re-compilation even when source code and compilation parameters haven't changed --minify reduce size of compiled Python @@ -195,7 +195,12 @@ optional arguments: --vi-mode, --vimode enable vi mode in the interpreter (currently set to False) (can be modified by setting COCONUT_VI_MODE environment variable) --recursion-limit limit, --recursionlimit limit - set maximum recursion depth in compiler (defaults to 2090) + set maximum recursion depth in compiler (defaults to 1920) (when + increasing --recursion-limit, you may also need to increase --stack- + size) + --stack-size kbs, --stacksize kbs + run the compiler in a separate thread with the given stack size in + kilobytes --site-install, --siteinstall set up coconut.convenience to be imported on Python start --site-uninstall, 
--siteuninstall @@ -213,7 +218,7 @@ coconut-run ``` as an alias for ``` -coconut --run --quiet --target sys --argv +coconut --run --quiet --target sys --line-numbers --argv ``` which will quietly compile and run ``, passing any additional arguments to the script, mimicking how the `python` command works. @@ -222,15 +227,21 @@ which will quietly compile and run ``, passing any additional arguments #!/usr/bin/env coconut-run ``` +To pass additional compilation arguments to `coconut-run` (e.g. `--no-tco`), put them before the `` file. + #### Naming Source Files -Coconut source files should, so the compiler can recognize them, use the extension `.coco` (preferred), `.coc`, or `.coconut`. When Coconut compiles a `.coco` (or `.coc`/`.coconut`) file, it will compile to another file with the same name, except with `.py` instead of `.coco`, which will hold the compiled code. If an extension other than `.py` is desired for the compiled files, such as `.pyde` for [Python Processing](http://py.processing.org/), then that extension can be put before `.coco` in the source file name, and it will be used instead of `.py` for the compiled files. For example, `name.coco` will compile to `name.py`, whereas `name.pyde.coco` will compile to `name.pyde`. +Coconut source files should, so the compiler can recognize them, use the extension `.coco` (preferred), `.coc`, or `.coconut`. + +When Coconut compiles a `.coco` file, it will compile to another file with the same name, except with `.py` instead of `.coco`, which will hold the compiled code. + +If an extension other than `.py` is desired for the compiled files, then that extension can be put before `.coco` in the source file name, and it will be used instead of `.py` for the compiled files. For example, `name.coco` will compile to `name.py`, whereas `name.abc.coco` will compile to `name.abc`. #### Compilation Modes Files compiled by the `coconut` command-line utility will vary based on compilation parameters. 
If an entire directory of files is compiled (which the compiler will search recursively for any folders containing `.coco`, `.coc`, or `.coconut` files), a `__coconut__.py` file will be created to house necessary functions (package mode), whereas if only a single file is compiled, that information will be stored within a header inside the file (standalone mode). Standalone mode is better for single files because it gets rid of the overhead involved in importing `__coconut__.py`, but package mode is better for large packages because it gets rid of the need to run the same Coconut header code again in every file, since it can just be imported from `__coconut__.py`. -By default, if the `source` argument to the command-line utility is a file, it will perform standalone compilation on it, whereas if it is a directory, it will recursively search for all `.coco` (or `.coc` / `.coconut`) files and perform package compilation on them. Thus, in most cases, the mode chosen by Coconut automatically will be the right one. But if it is very important that no additional files like `__coconut__.py` be created, for example, then the command-line utility can also be forced to use a specific mode with the `--package` (`-p`) and `--standalone` (`-a`) flags. +By default, if the `source` argument to the command-line utility is a file, it will perform standalone compilation on it, whereas if it is a directory, it will recursively search for all `.coco` files and perform package compilation on them. Thus, in most cases, the mode chosen by Coconut automatically will be the right one. But if it is very important that no additional files like `__coconut__.py` be created, for example, then the command-line utility can also be forced to use a specific mode with the `--package` (`-p`) and `--standalone` (`-a`) flags. 
#### Compatible Python Versions @@ -239,6 +250,7 @@ While Coconut syntax is based off of the latest Python 3, Coconut code compiled To make Coconut built-ins universal across Python versions, Coconut makes available on any Python version built-ins that only exist in later versions, including **automatically overwriting Python 2 built-ins with their Python 3 counterparts.** Additionally, Coconut also [overwrites some Python 3 built-ins for optimization and enhancement purposes](#enhanced-built-ins). If access to the original Python versions of any overwritten built-ins is desired, the old built-ins can be retrieved by prefixing them with `py_`. Specifically, the overwritten built-ins are: - `py_chr` +- `py_dict` - `py_hex` - `py_input` - `py_int` @@ -263,8 +275,6 @@ _Note: Coconut's `repr` can be somewhat tricky, as it will attempt to remove the For standard library compatibility, **Coconut automatically maps imports under Python 3 names to imports under Python 2 names**. Thus, Coconut will automatically take care of any standard library modules that were renamed from Python 2 to Python 3 if just the Python 3 name is used. For modules or packages that only exist in Python 3, however, Coconut has no way of maintaining compatibility. -Additionally, Coconut allows the [`__set_name__`](https://docs.python.org/3/reference/datamodel.html#object.__set_name__) magic method for descriptors to work on any Python version. - Finally, while Coconut will try to compile Python-3-specific syntax to its universal equivalent, the following constructs have no equivalent in Python 2, and require the specification of a target of at least `3` to be used: - the `nonlocal` keyword, @@ -275,6 +285,8 @@ Finally, while Coconut will try to compile Python-3-specific syntax to its unive - `a[x, *y]` variadic generic syntax (use [type parameter syntax](#type-parameter-syntax) for universal code) (requires `--target 3.11`), and - `except*` multi-except statements (requires `--target 3.11`). 
+_Note: Coconut also universalizes many magic methods, including making `__bool__` and [`__set_name__`](https://docs.python.org/3/reference/datamodel.html#object.__set_name__) work on any Python version._ + #### Allowable Targets If the version of Python that the compiled code will be running on is known ahead of time, a target should be specified with `--target`. The given target will only affect the compiled code and whether or not the Python-3-specific syntax detailed above is allowed. Where Python syntax differs across versions, Coconut syntax will always follow the latest Python 3 across all targets. The supported targets are: @@ -295,7 +307,7 @@ If the version of Python that the compiled code will be running on is known ahea - `3.12` (will work on any Python `>= 3.12`), and - `sys` (chooses the target corresponding to the current Python version). -_Note: Periods are ignored in target specifications, such that the target `27` is equivalent to the target `2.7`._ +_Note: Periods are optional in target specifications, such that the target `27` is equivalent to the target `2.7`._ #### `strict` Mode @@ -338,7 +350,7 @@ Text editors with support for Coconut syntax highlighting are: - **SublimeText**: See SublimeText section below. - **Spyder** (or any other editor that supports **Pygments**): See Pygments section below. - **Vim**: See [`coconut.vim`](https://github.com/manicmaniac/coconut.vim). -- **Emacs**: See [`coconut-mode`](https://github.com/NickSeagull/coconut-mode). +- **Emacs**: See [`emacs-coconut`](https://codeberg.org/kobarity/emacs-coconut)/[`emacs-ob-coconut`](https://codeberg.org/kobarity/emacs-ob-coconut). - **Atom**: See [`language-coconut`](https://github.com/enilsen16/language-coconut). Alternatively, if none of the above work for you, you can just treat Coconut as Python. Simply set up your editor so it interprets all `.coco` files as Python and that should highlight most of your code well enough (e.g. 
for IntelliJ IDEA see [registering file types](https://www.jetbrains.com/help/idea/creating-and-registering-file-types.html)). @@ -375,7 +387,7 @@ If Coconut is used as a kernel, all code in the console or notebook will be sent Simply installing Coconut should add a `Coconut` kernel to your Jupyter/IPython notebooks. If you are having issues accessing the Coconut kernel, however, the command `coconut --jupyter` will re-install the `Coconut` kernel to ensure it is using the current Python as well as add the additional kernels `Coconut (Default Python)`, `Coconut (Default Python 2)`, and `Coconut (Default Python 3)` which will use, respectively, the Python accessible as `python`, `python2`, and `python3` (these kernels are accessible in the console as `coconut_py`, `coconut_py2`, and `coconut_py3`). Furthermore, the Coconut kernel fully supports [`nb_conda_kernels`](https://github.com/Anaconda-Platform/nb_conda_kernels) to enable accessing the Coconut kernel in one Conda environment from another Conda environment. -The Coconut kernel will always compile using the parameters: `--target sys --line-numbers --keep-lines --no-wrap`. +The Coconut kernel will always compile using the parameters: `--target sys --line-numbers --keep-lines --no-wrap-types`. Coconut also provides the following convenience commands: @@ -393,7 +405,9 @@ The line magic `%load_ext coconut` will load Coconut as an extension, providing _Note: Unlike the normal Coconut command-line, `%%coconut` defaults to the `sys` target rather than the `universal` target._ -#### MyPy Integration +#### Type Checking + +##### MyPy Integration Coconut has the ability to integrate with [MyPy](http://mypy-lang.org/) to provide optional static type_checking, including for all Coconut built-ins. Simply pass `--mypy` to `coconut` to enable MyPy integration, though be careful to pass it only as the last argument, since all arguments after `--mypy` are passed to `mypy`, not Coconut. 
@@ -402,7 +416,22 @@ You can also run `mypy`—or any other static type checker—directly on the com 1. run `coconut --mypy install` and 2. tell your static type checker of choice to look in `~/.coconut_stubs` for stub files (for `mypy`, this is done by adding it to your [`MYPYPATH`](https://mypy.readthedocs.io/en/latest/running_mypy.html#how-imports-are-found)). -To explicitly annotate your code with types to be checked, Coconut supports [Python 3 function type annotations](https://www.python.org/dev/peps/pep-0484/), [Python 3.6 variable type annotations](https://www.python.org/dev/peps/pep-0526/), and even Coconut's own [enhanced type annotation syntax](#enhanced-type-annotation). By default, all type annotations are compiled to Python-2-compatible type comments, which means it all works on any Python version. Coconut also supports [PEP 695 type parameter syntax](#type-parameter-syntax) for easily adding type parameters to classes, functions, [`data` types](#data), and type aliases. +To distribute your code with checkable type annotations, you'll need to include `coconut` as a dependency (though a `--no-deps` install should be fine), as installing it is necessary to make the requisite stub files available. You'll also probably want to include a [`py.typed`](https://peps.python.org/pep-0561/) file. + +##### Syntax + +To explicitly annotate your code with types to be checked, Coconut supports: +* [Python 3 function type annotations](https://www.python.org/dev/peps/pep-0484/), +* [Python 3.6 variable type annotations](https://www.python.org/dev/peps/pep-0526/), +* [PEP 695 type parameter syntax](#type-parameter-syntax) for easily adding type parameters to classes, functions, [`data` types](#data), and type aliases, +* Coconut's own [enhanced type annotation syntax](#enhanced-type-annotation), and +* Coconut's [protocol intersection operator](#protocol-intersection). 
+ +By default, all type annotations are compiled to Python-2-compatible type comments, which means it all works on any Python version. + +Sometimes, MyPy will not know how to handle certain Coconut constructs, such as `addpattern`. For the `addpattern` case, it is recommended to pass `--allow-redefinition` to MyPy (i.e. run `coconut --mypy --allow-redefinition`), though in some cases `--allow-redefinition` may not be sufficient. In that case, either hide the offending code using [`TYPE_CHECKING`](#type_checking) or put a `# type: ignore` comment on the Coconut line which is generating the line MyPy is complaining about and the comment will be added to every generated line. + +##### Interpreter Coconut even supports `--mypy` in the interpreter, which will intelligently scan each new line of code, in the context of previous lines, for newly-introduced MyPy errors. For example: ```coconut_pycon @@ -414,10 +443,6 @@ Coconut even supports `--mypy` in the interpreter, which will intelligently scan ``` _For more information on `reveal_type`, see [`reveal_type` and `reveal_locals`](#reveal-type-and-reveal-locals)._ -Sometimes, MyPy will not know how to handle certain Coconut constructs, such as `addpattern`. For the `addpattern` case, it is recommended to pass `--allow-redefinition` to MyPy (i.e. run `coconut --mypy --allow-redefinition`), though in some cases `--allow-redefinition` may not be sufficient. In that case, either hide the offending code using [`TYPE_CHECKING`](#type_checking) or put a `# type: ignore` comment on the Coconut line which is generating the line MyPy is complaining about and the comment will be added to every generated line. - -To distribute your code with checkable type annotations, you'll need to include `coconut` as a dependency (though a `--no-deps` install should be fine), as installing it is necessary to make the requisite stub files available. You'll also probably want to include a [`py.typed`](https://peps.python.org/pep-0561/) file. 
- #### `numpy` Integration To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all compiled Coconut code will do a number of special things to better integrate with `numpy` (if `numpy` is available to import when the code is run). Specifically: @@ -429,9 +454,10 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all * [`cartesian_product`](#cartesian_product) can compute the Cartesian product of given `numpy` arrays as a `numpy` array. * [`all_equal`](#all_equal) allows for easily checking if all the elements in a `numpy` array are the same. - [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification). +- `numpy` objects are allowed seamlessly in Coconut's [implicit coefficient syntax](#implicit-function-application-and-coefficients), allowing the use of e.g. `A B**2` shorthand for `A * B**2` when `A` and `B` are `numpy` arrays (note: **not** `A @ B**2`). - Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions). -Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/) and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `jax.numpy` methods over `numpy` methods when given `jax` arrays. +Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `pandas`/`jax`-specific methods over `numpy` methods when given `pandas`/`jax` objects. 
#### `xonsh` Support @@ -445,7 +471,7 @@ user@computer ~ $ $(ls -la) |> .splitlines() |> len 30 ``` -Note that the way that Coconut integrates with `xonsh`, `@()` syntax will only work with Python code, not Coconut code. +Note that the way that Coconut integrates with `xonsh`, `@()` syntax and the `execx` command will only work with Python code, not Coconut code. Additionally, Coconut will only compile individual commands—Coconut will not touch the `.xonshrc` or any other `.xsh` files. @@ -465,18 +491,19 @@ In order of precedence, highest first, the operators supported in Coconut are: ====================== ========================== Symbol(s) Associativity ====================== ========================== -f x n/a await x n/a -.. n/a -** right +** right (allows unary) +f x n/a +, -, ~ unary *, /, //, %, @ left +, - left <<, >> left & left +&: left ^ left | left :: n/a (lazy) +.. n/a a `b` c, left (captures lambda) all custom operators ?? left (short-circuits) @@ -620,9 +647,11 @@ Coconut uses pipe operators for pipeline-style function application. All the ope (<**?|) => None-aware keyword arg pipe backward ``` -Additionally, all pipe operators support a lambda as the last argument, despite lambdas having a lower precedence. Thus, `a |> x -> b |> c` is equivalent to `a |> (x -> b |> c)`, not `a |> (x -> b) |> c`. +The None-aware pipe operators here are equivalent to a [monadic bind](https://en.wikipedia.org/wiki/Monad_(functional_programming)) treating the object as a `Maybe` monad composed of either `None` or the given object. Thus, `x |?> f` is equivalent to `None if x is None else f(x)`. Note that only the object being piped, not the function being piped into, may be `None` for `None`-aware pipes. -The None-aware pipe operators here are equivalent to a [monadic bind](https://en.wikipedia.org/wiki/Monad_(functional_programming)) treating the object as a `Maybe` monad composed of either `None` or the given object. 
Note that only the object being piped, not the function being piped into, may be `None` for `None`-aware pipes. +For working with `async` functions in pipes, all non-starred pipes support piping into `await` to await the awaitable piped into them, such that `x |> await` is equivalent to `await x`. + +Additionally, all pipe operators support a lambda as the last argument, despite lambdas having a lower precedence. Thus, `a |> x -> b |> c` is equivalent to `a |> (x -> b |> c)`, not `a |> (x -> b) |> c`. _Note: To visually spread operations across several lines, just use [parenthetical continuation](#enhanced-parenthetical-continuation)._ @@ -642,7 +671,7 @@ If Coconut compiled each of the partials in the pipe syntax as an actual partial This applies even to in-place pipes such as `|>=`. -##### Example +##### Examples **Coconut:** ```coconut @@ -650,6 +679,15 @@ def sq(x) = x**2 (1, 2) |*> (+) |> sq |> print ``` +```coconut +async def do_stuff(some_data) = ( + some_data + |> async_func + |> await + |> post_proc +) +``` + **Python:** ```coconut_python import operator @@ -657,6 +695,11 @@ def sq(x): return x**2 print(sq(operator.add(1, 2))) ``` +```coconut_python +async def do_stuff(some_data): + return post_proc(await async_func(some_data)) +``` + ### Function Composition Coconut has three basic function composition operators: `..`, `..>`, and `<..`. Both `..` and `<..` use math-style "backwards" function composition, where the first function is called last, while `..>` uses "forwards" function composition, where the first function is called first. Forwards and backwards function composition pipes cannot be used together in the same expression (unlike normal pipes) and have precedence in-between `None`-coalescing and normal pipes. 
@@ -679,7 +722,7 @@ The `..>` and `<..` function composition pipe operators also have multi-arg, key Note that `None`-aware function composition pipes don't allow either function to be `None`—rather, they allow the return of the first evaluated function to be `None`, in which case `None` is returned immediately rather than calling the next function. -The `..` operator has lower precedence than `await` but higher precedence than `**` while the `..>` pipe operators have a precedence directly higher than normal pipes. +The `..` operator has lower precedence than `::` but higher precedence than infix functions while the `..>` pipe operators have a precedence directly higher than normal pipes. All function composition operators also have in-place versions (e.g. `..=`). @@ -704,7 +747,7 @@ Coconut uses a `$` sign right after an iterator before a slice to perform iterat Iterator slicing works just like sequence slicing, including support for negative indices and slices, and support for `slice` objects in the same way as can be done with normal slicing. Iterator slicing makes no guarantee, however, that the original iterator passed to it be preserved (to preserve the iterator, use Coconut's [`reiterable`](#reiterable) built-in). -Coconut's iterator slicing is very similar to Python's `itertools.islice`, but unlike `itertools.islice`, Coconut's iterator slicing supports negative indices, and will preferentially call an object's `__iter_getitem__` (Coconut-specific magic method, preferred) or `__getitem__` (general Python magic method), if they exist. Coconut's iterator slicing is also optimized to work well with all of Coconut's built-in objects, only computing the elements of each that are actually necessary to extract the desired slice. 
+Coconut's iterator slicing is very similar to Python's `itertools.islice`, but unlike `itertools.islice`, Coconut's iterator slicing supports negative indices, and will preferentially call an object's `__iter_getitem__` (always used if available) or `__getitem__` (only used if the object is a collections.abc.Sequence). Coconut's iterator slicing is also optimized to work well with all of Coconut's built-in objects, only computing the elements of each that are actually necessary to extract the desired slice. ##### Example @@ -829,7 +872,7 @@ from import operator Custom operators will often need to be surrounded by whitespace (or parentheses when used as an operator function) to be parsed correctly. -If a custom operator that is also a valid name is desired, you can use a backslash before the name to get back the name instead with Coconut's [keyword/variable disambiguation syntax](#handling-keywordvariable-name-overlap). +If a custom operator that is also a valid name is desired, you can use a backslash before the name to get back the name instead using Coconut's [keyword/variable disambiguation syntax](#handling-keywordvariable-name-overlap). ##### Examples @@ -928,6 +971,61 @@ import functools (lambda result: None if result is None else result.attr[index].method())(could_be_none()) ``` +### Protocol Intersection + +Coconut uses the `&:` operator to indicate protocol intersection. That is, for two [`typing.Protocol`s](https://docs.python.org/3/library/typing.html#typing.Protocol) `Protocol1` and `Protocol2`, `Protocol1 &: Protocol2` is equivalent to a `Protocol` that combines the requirements of both `Protocol1` and `Protocol2`. + +The recommended way to use Coconut's protocol intersection operator is in combination with Coconut's [operator `Protocol`s](#supported-protocols). 
Note, however, that while `&:` will work anywhere, operator `Protocol`s will only work inside type annotations (which means, for example, you'll need to do `type HasAdd = (+)` instead of just `HasAdd = (+)`). + +See Coconut's [enhanced type annotation](#enhanced-type-annotation) for more information on how Coconut handles type annotations more generally. + +##### Example + +**Coconut:** +```coconut +from typing import Protocol + +class X(Protocol): + x: str + +class Y(Protocol): + y: str + +def foo(xy: X &: Y) -> None: + print(xy.x, xy.y) + +type CanAddAndSub = (+) &: (-) +``` + +**Python:** +```coconut_python +from typing import Protocol, TypeVar, Generic + +class X(Protocol): + x: str + +class Y(Protocol): + y: str + +class XY(X, Y, Protocol): + pass + +def foo(xy: XY) -> None: + print(xy.x, xy.y) + +T = TypeVar("T", infer_variance=True) +U = TypeVar("U", infer_variance=True) +V = TypeVar("V", infer_variance=True) + +class CanAddAndSub(Protocol, Generic[T, U, V]): + def __add__(self: T, other: U) -> V: + raise NotImplementedError + def __sub__(self: T, other: U) -> V: + raise NotImplementedError + def __neg__(self: T) -> V: + raise NotImplementedError +``` + ### Unicode Alternatives Coconut supports Unicode alternatives to many different operator symbols. The Unicode alternatives are relatively straightforward, and chosen to reflect the look and/or meaning of the original symbol. @@ -936,9 +1034,9 @@ Coconut supports Unicode alternatives to many different operator symbols. 
The Un ``` → (\u2192) => "->" -× (\xd7) => "*" -↑ (\u2191) => "**" -÷ (\xf7) => "/" +× (\xd7) => "*" (only multiplication) +↑ (\u2191) => "**" (only exponentiation) +÷ (\xf7) => "/" (only division) ÷/ (\xf7/) => "//" ⁻ (\u207b) => "-" (only negation) ≠ (\u2260) or ¬= (\xac=) => "!=" @@ -1033,7 +1131,10 @@ base_pattern ::= ( | "class" NAME "(" patterns ")" # classes | "{" pattern_pairs # dictionaries ["," "**" (NAME | "{}")] "}" # (keys must be constants or equality checks) - | ["s"] "{" pattern_consts "}" # sets + | ["s" | "f" | "m"] "{" + pattern_consts + ["," ("*_" | "*()")] + "}" # sets | (EXPR) -> pattern # view patterns | "(" patterns ")" # sequences can be in tuple form | "[" patterns "]" # or in list form @@ -1086,7 +1187,6 @@ base_pattern ::= ( - Constants, Numbers, and Strings: will only match to the same constant, number, or string in the same position in the arguments. - Equality Checks (`==`): will check that whatever is in that position is `==` to the expression ``. - Identity Checks (`is `): will check that whatever is in that position `is` the expression ``. - - Sets (`{}`): will only match a set (`collections.abc.Set`) of the same length and contents. - Arbitrary Function Patterns: - Infix Checks (`` `` ``): will check that the operator `$(?, )` returns a truthy value when called on whatever is in that position, then matches ``. For example, `` x `isinstance` int `` will check that whatever is in that position `isinstance$(?, int)` and bind it to `x`. If `` is not given, will simply check `` directly rather than `$()`. Additionally, `` `` `` can instead be a [custom operator](#custom-operators) (in that case, no backticks should be used). - View Patterns (`() -> `): calls `` on the item being matched and matches the result to ``. The match fails if a [`MatchError`](#matcherror) is raised. `` may be unparenthesized only when it is a single atom. 
@@ -1097,6 +1197,11 @@ base_pattern ::= ( - Mapping Destructuring: - Dicts (`{: , ...}`): will match any mapping (`collections.abc.Mapping`) with the given keys and values that match the value patterns. Keys must be constants or equality checks. - Dicts With Rest (`{, **}`): will match a mapping (`collections.abc.Mapping`) containing all the ``, and will put a `dict` of everything else into ``. If `` is `{}`, will enforce that the mapping is exactly the same length as ``. +- Set Destructuring: + - Sets (`s{, *_}`): will match a set (`collections.abc.Set`) that contains the given ``, though it may also contain other items. The `s` prefix and the `*_` are optional. + - Fixed-length Sets (`s{, *()}`): will match a `set` (`collections.abc.Set`) that contains the given ``, and nothing else. + - Frozensets (`f{}`): will match a `frozenset` (`frozenset`) that contains the given ``. May use either normal or fixed-length syntax. + - Multisets (`m{}`): will match a [`multiset`](#multiset) (`collections.Counter`) that contains at least the given ``. May use either normal or fixed-length syntax. - Sequence Destructuring: - Lists (`[]`), Tuples (`()`): will only match a sequence (`collections.abc.Sequence`) of the same length, and will check the contents against `` (Coconut automatically registers `numpy` arrays and `collections.deque` objects as sequences). - Lazy lists (`(||)`): same as list or tuple matching, but checks for an Iterable (`collections.abc.Iterable`) instead of a Sequence. 
@@ -1149,11 +1254,9 @@ data Node(l, r) from Tree def depth(Tree()) = 0 -@addpattern(depth) -def depth(Tree(n)) = 1 +addpattern def depth(Tree(n)) = 1 -@addpattern(depth) -def depth(Tree(l, r)) = 1 + max([depth(l), depth(r)]) +addpattern def depth(Tree(l, r)) = 1 + max([depth(l), depth(r)]) Empty() |> depth |> print Leaf(5) |> depth |> print @@ -1173,8 +1276,7 @@ _Showcases head-tail splitting, one of the most common uses of pattern-matching, def sieve([head] :: tail) = [head] :: sieve(n for n in tail if n % head) -@addpattern(sieve) -def sieve((||)) = [] +addpattern def sieve((||)) = [] ``` _Showcases how to match against iterators, namely that the empty iterator case (`(||)`) must come last, otherwise that case will exhaust the whole iterator before any other pattern has a chance to match against it._ @@ -1351,11 +1453,9 @@ data Node(l, r) def size(Empty()) = 0 -@addpattern(size) -def size(Leaf(n)) = 1 +addpattern def size(Leaf(n)) = 1 -@addpattern(size) -def size(Node(l, r)) = size(l) + size(r) +addpattern def size(Node(l, r)) = size(l) + size(r) size(Node(Empty(), Leaf(10))) == 1 ``` @@ -1442,8 +1542,6 @@ c = a + b ### Handling Keyword/Variable Name Overlap In Coconut, the following keywords are also valid variable names: -- `async` (keyword in Python 3.5) -- `await` (keyword in Python 3.5) - `data` - `match` - `case` @@ -1500,9 +1598,9 @@ The statement lambda syntax is an extension of the [normal lambda syntax](#lambd The syntax for a statement lambda is ``` -[async] [match] def (arguments) -> statement; statement; ... +[async|match|copyclosure] def (arguments) -> statement; statement; ... ``` -where `arguments` can be standard function arguments or [pattern-matching function definition](#pattern-matching-functions) arguments and `statement` can be an assignment statement or a keyword statement. Note that the `async` and `match` keywords can be in any order. 
+where `arguments` can be standard function arguments or [pattern-matching function definition](#pattern-matching-functions) arguments and `statement` can be an assignment statement or a keyword statement. Note that the `async`, `match`, and [`copyclosure`](#copyclosure-functions) keywords can be combined and can be in any order. If the last `statement` (not followed by a semicolon) in a statement lambda is an `expression`, it will automatically be returned. @@ -1596,7 +1694,9 @@ A very common thing to do in functional programming is to make use of function v (and) => # boolean and (or) => # boolean or (is) => (operator.is_) +(is not) => (operator.is_not) (in) => (operator.contains) +(not in) => # negative containment (assert) => def (cond, msg=None) -> assert cond, msg # (but a better msg if msg is None) (raise) => def (exc=None, from_exc=None) -> raise exc from from_exc # or just raise if exc is None # there are two operator functions that don't require parentheses: @@ -1604,7 +1704,9 @@ A very common thing to do in functional programming is to make use of function v .$[] => # iterator slicing operator ``` -_For an operator function for function application, see [`call`](#call)._ +For an operator function for function application, see [`call`](#call). + +Though no operator function is available for `await`, an equivalent syntax is available for [pipes](#pipes) in the form of `awaitable |> await`. ##### Example @@ -1625,7 +1727,6 @@ Coconut supports a number of different syntactical aliases for common partial ap ```coconut .attr => operator.attrgetter("attr") .method(args) => operator.methodcaller("method", args) -obj. => getattr$(obj) func$ => ($)$(func) seq[] => operator.getitem$(seq) iter$[] => # the equivalent of seq[] for iterators @@ -1649,6 +1750,8 @@ Additionally, Coconut also supports implicit operator function partials for arbi ``` based on Coconut's [infix notation](#infix-functions) where `` is the name of the function. 
Additionally, `` `` `` can instead be a [custom operator](#custom-operators) (in that case, no backticks should be used). +_DEPRECATED: Coconut also supports `obj.` as an implicit partial for `getattr$(obj)`, but its usage is deprecated and will show a warning to switch to `getattr$(obj)` instead._ + ##### Example **Coconut:** @@ -1671,7 +1774,7 @@ Since Coconut syntax is a superset of Python 3 syntax, it supports [Python 3 fun Since not all supported Python versions support the [`typing`](https://docs.python.org/3/library/typing.html) module, Coconut provides the [`TYPE_CHECKING`](#type_checking) built-in for hiding your `typing` imports and `TypeVar` definitions from being executed at runtime. Coconut will also automatically use [`typing_extensions`](https://pypi.org/project/typing-extensions/) over `typing` when importing objects not available in `typing` on the current Python version. -Furthermore, when compiling type annotations to Python 3 versions without [PEP 563](https://www.python.org/dev/peps/pep-0563/) support, Coconut wraps annotation in strings to prevent them from being evaluated at runtime (note that `--no-wrap` disables all wrapping, including via PEP 563 support). +Furthermore, when compiling type annotations to Python 3 versions without [PEP 563](https://www.python.org/dev/peps/pep-0563/) support, Coconut wraps annotation in strings to prevent them from being evaluated at runtime (note that `--no-wrap-types` disables all wrapping, including via PEP 563 support). Additionally, Coconut adds special syntax for making type annotations easier and simpler to write. When inside of a type annotation, Coconut treats certain syntax constructs differently, compiling them to type annotations instead of what they would normally represent. Specifically, Coconut applies the following transformations: ```coconut @@ -1700,6 +1803,8 @@ async () -> ``` where `typing` is the Python 3.5 built-in [`typing` module](https://docs.python.org/3/library/typing.html). 
For more information on the Callable syntax, see [PEP 677](https://peps.python.org/pep-0677), which Coconut fully supports. +Additionally, many of Coconut's [operator functions](#operator-functions) will compile into equivalent [`Protocol`s](https://docs.python.org/3/library/typing.html#typing.Protocol) instead when inside a type annotation. See below for the full list and specification. + _Note: The transformation to `Union` is not done on Python 3.10 as Python 3.10 has native [PEP 604](https://www.python.org/dev/peps/pep-0604) support._ To use these transformations in a [type alias](https://peps.python.org/pep-0484/#type-aliases), use the syntax @@ -1710,7 +1815,52 @@ which will allow `` to include Coconut's special type annotation syntax an Such type alias statements—as well as all `class`, `data`, and function definitions in Coconut—also support Coconut's [type parameter syntax](#type-parameter-syntax), allowing you to do things like `type OrStr[T] = T | str`. -Importantly, note that `int[]` does not map onto `typing.List[int]` but onto `typing.Sequence[int]`. This is because, when writing in an idiomatic functional style, assignment should be rare and tuples should be common. Using `Sequence` covers both cases, accommodating tuples and lists and preventing indexed assignment. When an indexed assignment is attempted into a variable typed with `Sequence`, MyPy will generate an error: +##### Supported Protocols + +Using Coconut's [operator function](#operator-functions) syntax inside of a type annotation will instead produce a [`Protocol`](https://docs.python.org/3/library/typing.html#typing.Protocol) corresponding to that operator (or raise a syntax error if no such `Protocol` is available). All available `Protocol`s are listed below. 
+ +For the operator functions +``` +(+) +(*) +(**) +(/) +(//) +(%) +(&) +(^) +(|) +(<<) +(>>) +(@) +``` +the resulting `Protocol` is +```coconut +class SupportsOp[T, U, V](Protocol): + def __op__(self: T, other: U) -> V: + raise NotImplementedError(...) +``` +where `__op__` is the magic method corresponding to that operator. + +For the operator function `(-)`, the resulting `Protocol` is: +```coconut +class SupportsMinus[T, U, V](Protocol): + def __sub__(self: T, other: U) -> V: + raise NotImplementedError + def __neg__(self: T) -> V: + raise NotImplementedError +``` + +For the operator function `(~)`, the resulting `Protocol` is: +```coconut +class SupportsInv[T, V](Protocol): + def __invert__(self: T) -> V: + raise NotImplementedError(...) +``` + +##### `List` vs. `Sequence` + +Importantly, note that `T[]` does not map onto `typing.List[T]` but onto `typing.Sequence[T]`. This allows the resulting type to be covariant, such that if `U` is a subtype of `T`, then `U[]` is a subtype of `T[]`. Additionally, `Sequence[T]` allows for tuples, and when writing in an idiomatic functional style, assignment should be rare and tuples should be common. Using `Sequence` covers both cases, accommodating tuples and lists and preventing indexed assignment. When an indexed assignment is attempted into a variable typed with `Sequence`, MyPy will generate an error: ```coconut foo: int[] = [0, 1, 2, 3, 4, 5] @@ -1730,24 +1880,38 @@ def int_map( xs: int[], ) -> int[] = xs |> map$(f) |> list + +type CanAddAndSub = (+) &: (-) ``` **Python:** ```coconut_python import typing # unlike this typing import, Coconut produces universal code + def int_map( f, # type: typing.Callable[[int], int] xs, # type: typing.Sequence[int] ): # type: (...) 
-> typing.Sequence[int] return list(map(f, xs)) + +T = typing.TypeVar("T", infer_variance=True) +U = typing.TypeVar("U", infer_variance=True) +V = typing.TypeVar("V", infer_variance=True) +class CanAddAndSub(typing.Protocol, typing.Generic[T, U, V]): + def __add__(self: T, other: U) -> V: + raise NotImplementedError + def __sub__(self: T, other: U) -> V: + raise NotImplementedError + def __neg__(self: T) -> V: + raise NotImplementedError ``` ### Multidimensional Array Literal/Concatenation Syntax Coconut supports multidimensional array literal and array [concatenation](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html)/[stack](https://numpy.org/doc/stable/reference/generated/numpy.stack.html) syntax. -By default, all multidimensional array syntax will simply operate on Python lists of lists. However, if [`numpy`](#numpy-integration) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`). +By default, all multidimensional array syntax will simply operate on Python lists of lists (or any non-`str` `Sequence`). However, if [`numpy`](#numpy-integration) objects are used, the appropriate `numpy` calls will be made instead. To give custom objects multidimensional array concatenation support, define `type(obj).__matconcat__` (should behave as `np.concat`), `obj.ndim` (should behave as `np.ndarray.ndim`), and `obj.reshape` (should behave as `np.ndarray.reshape`). As a simple example, 2D matrices can be constructed by separating the rows with `;;` inside of a list literal: ```coconut_pycon @@ -1838,16 +2002,28 @@ Lazy lists, where sequences are only evaluated when their contents are requested **Python:** _Can't be done without a complicated iterator comprehension in place of the lazy list. 
See the compiled code for the Python syntax._ -### Implicit Function Application +### Implicit Function Application and Coefficients -Coconut supports implicit function application of the form `f x y`, which is compiled to `f(x, y)` (note: **not** `f(x)(y)` as is common in many languages with automatic currying). Implicit function application has a lower precedence than attribute access, slices, normal function calls, etc. but a higher precedence than `await`. +Coconut supports implicit function application of the form `f x y`, which is compiled to `f(x, y)` (note: **not** `f(x)(y)` as is common in many languages with automatic currying). -Supported arguments to implicit function application are highly restricted, and must be: +Additionally, if the first argument is not callable, and is instead an `int`, `float`, `complex`, or [`numpy`](#numpy-integration) object, then the result is multiplication rather than function application, such that `2 x` is equivalent to `2*x`. + +Though the first item may be any atom, following arguments are highly restricted, and must be: - variables/attributes (e.g. `a.b`), -- literal constants (e.g. `True`), or -- number literals (e.g. `1.5`). +- literal constants (e.g. `True`), +- number literals (e.g. `1.5`), or +- one of the above followed by an exponent (e.g. `a**-5`). + +For example, `(f .. g) x 1` will work, but `f x [1]`, `f x (1+2)`, and `f "abc"` will not. + +Implicit function application and coefficient syntax is only intended for simple use cases. For more complex cases, use the standard multiplication operator `*`, standard function application, or [pipes](#pipes). -For example, `f x 1` will work but `f x [1]`, `f x (1+2)`, and `f "abc"` will not. Strings are disallowed due to conflicting with [Python's implicit string concatenation](https://stackoverflow.com/questions/18842779/string-concatenation-without-operator). 
Implicit function application is only intended for simple use cases—for more complex cases, use either standard function application or [pipes](#pipes). +Implicit function application and coefficient syntax has a lower precedence than `**` but a higher precedence than unary operators. As a result, `2 x**2 + 3 x` is equivalent to `2 * x**2 + 3 * x`. + +Due to potential confusion, some syntactic constructs are explicitly disallowed in implicit function application and coefficient syntax. Specifically: +- Strings are always disallowed everywhere in implicit function application / coefficient syntax due to conflicting with [Python's implicit string concatenation](https://stackoverflow.com/questions/18842779/string-concatenation-without-operator). +- Multiplying two or more numeric literals with implicit coefficient syntax is prohibited, so `10 20` is not allowed. +- `await` is not allowed in front of implicit function application and coefficient syntax. To use `await`, simply parenthesize the expression, as in `await (f x)`. ##### Examples @@ -1862,6 +2038,10 @@ def p1(x) = x + 1 print <| p1 5 ``` +```coconut +quad = 5 x**2 + 3 x + 1 +``` + **Python:** ```coconut_python def f(x, y): return (x, y) @@ -1873,6 +2053,10 @@ def p1(x): return x + 1 print(p1(5)) ``` +```coconut_python +quad = 5 * x**2 + 3 * x + 1 +``` + ### Anonymous Namedtuples Coconut supports anonymous [`namedtuple`](https://docs.python.org/3/library/collections.html#collections.namedtuple) literals, such that `(a=1, b=2)` can be used just as `(1, 2)`, but with added names. Anonymous `namedtuple`s are always pickleable. 
@@ -2027,12 +2211,10 @@ _Showcases tail recursion elimination._ ```coconut # unlike in Python, neither of these functions will ever hit a maximum recursion depth error def is_even(0) = True -@addpattern(is_even) -def is_even(n `isinstance` int if n > 0) = is_odd(n-1) +addpattern def is_even(n `isinstance` int if n > 0) = is_odd(n-1) def is_odd(0) = False -@addpattern(is_odd) -def is_odd(n `isinstance` int if n > 0) = is_even(n-1) +addpattern def is_odd(n `isinstance` int if n > 0) = is_even(n-1) ``` _Showcases tail call optimization._ @@ -2104,7 +2286,7 @@ print(binexp(5)) Coconut pattern-matching functions are just normal functions, except where the arguments are patterns to be matched against instead of variables to be assigned to. The syntax for pattern-matching function definition is ```coconut -[async] [match] def (, , ... [if ]) [-> ]: +[match] def (, , ... [if ]) [-> ]: ``` where `` is defined as @@ -2119,7 +2301,7 @@ In addition to supporting pattern-matching in their arguments, pattern-matching - If pattern-matching function definition fails, it will raise a [`MatchError`](#matcherror) (just like [destructuring assignment](#destructuring-assignment)) instead of a `TypeError`. - All defaults in pattern-matching function definition are late-bound rather than early-bound. Thus, `match def f(xs=[]) = xs` will instantiate a new list for each call where `xs` is not given, unlike `def f(xs=[]) = xs`, which will use the same list for all calls where `xs` is unspecified. -_Note: Pattern-matching function definition can be combined with assignment and/or infix function definition._ +Pattern-matching function definition can also be combined with `async` functions, [`copyclosure` functions](#copyclosure-functions), [`yield` functions](#explicit-generators), [infix function definition](#infix-functions), and [assignment function syntax](#assignment-functions). The various keywords in front of the `def` can be put in any order. 
##### Example @@ -2147,29 +2329,82 @@ match def func(...): ``` syntax using the [`addpattern`](#addpattern) decorator. +Additionally, `addpattern def` will act just like a normal [`match def`](#pattern-matching-functions) if the function has not previously been defined, allowing for `addpattern def` to be used for each case rather than requiring `match def` for the first case and `addpattern def` for future cases. + If you want to put a decorator on an `addpattern def` function, make sure to put it on the _last_ pattern function. ##### Example **Coconut:** ```coconut -def factorial(0) = 1 +addpattern def factorial(0) = 1 addpattern def factorial(n) = n * factorial(n - 1) ``` **Python:** _Can't be done without a complicated decorator definition and a long series of checks for each pattern-matching. See the compiled code for the Python syntax._ +### `copyclosure` Functions + +Coconut supports the syntax +``` +copyclosure def (): + +``` +to define a function that uses as its closure a shallow copy of its enclosing scopes at the time that the function is defined, rather than a reference to those scopes (as with normal Python functions). + +For example, in +```coconut +def outer_func(): + funcs = [] + for x in range(10): + copyclosure def inner_func(): + return x + funcs.append(inner_func) + return funcs +``` +the resulting `inner_func`s will each return a _different_ `x` value rather than all the same `x` value, since they look at what `x` was bound to at function definition time rather than during function execution. + +`copyclosure` functions can also be combined with `async` functions, [`yield` functions](#explicit-generators), [pattern-matching functions](#pattern-matching-functions), [infix function definition](#infix-functions), and [assignment function syntax](#assignment-functions). The various keywords in front of the `def` can be put in any order. 
+ +If `global` or `nonlocal` are used in a `copyclosure` function, they will not be able to modify variables in enclosing scopes. However, they will allow state to be preserved across multiple calls to the `copyclosure` function. + +##### Example + +**Coconut:** +```coconut +def outer_func(): + funcs = [] + for x in range(10): + copyclosure def inner_func(): + return x + funcs.append(inner_func) + return funcs +``` + +**Python:** +```coconut_python +from functools import partial + +def outer_func(): + funcs = [] + for x in range(10): + def inner_func(_x): + return _x + funcs.append(partial(inner_func, x)) + return funcs +``` + ### Explicit Generators Coconut supports the syntax ``` -[async] yield def (): +yield def (): ``` -to denote that you are explicitly defining a generator function. This is useful to ensure that, even if all the `yield`s in your function are removed, it'll always be a generator function. Note that the `async` and `yield` keywords can be in any order. +to denote that you are explicitly defining a generator function. This is useful to ensure that, even if all the `yield`s in your function are removed, it'll always be a generator function. -Explicit generator functions also support [pattern-matching syntax](#pattern-matching-functions), [infix function definition](#infix-functions), and [assignment function syntax](#assignment-functions) (though note that assignment function syntax here creates a generator return). +Explicit generator functions can also be combined with `async` functions, [`copyclosure` functions](#copyclosure-functions), [pattern-matching functions](#pattern-matching-functions), [infix function definition](#infix-functions), and [assignment function syntax](#assignment-functions) (though note that assignment function syntax here creates a generator return). The various keywords in front of the `def` can be put in any order. 
##### Example @@ -2187,7 +2422,7 @@ def empty_it(): ### Dotted Function Definition -Coconut allows for function definition using a dotted name to assign a function as a method of an object as specified in [PEP 542](https://www.python.org/dev/peps/pep-0542/). +Coconut allows for function definition using a dotted name to assign a function as a method of an object as specified in [PEP 542](https://www.python.org/dev/peps/pep-0542/). Dotted function definition can be combined with all other types of function definition above. ##### Example @@ -2248,9 +2483,11 @@ Coconut fully supports [PEP 695](https://peps.python.org/pep-0695/) type paramet That includes type parameters for classes, [`data` types](#data), and [all types of function definition](#function-definition). For different types of function definition, the type parameters always come in brackets right after the function name. Coconut's [enhanced type annotation syntax](#enhanced-type-annotation) is supported for all type parameter bounds. +_Warning: until `mypy` adds support for `infer_variance=True` in `TypeVar`, `TypeVar`s created this way will always be invariant._ + Additionally, Coconut supports the alternative bounds syntax of `type NewType[T <: bound] = ...` rather than `type NewType[T: bound] = ...`, to make it more clear that it is an upper bound rather than a type. In `--strict` mode, `<:` is required over `:` for all type parameter bounds. _DEPRECATED: `<=` can also be used as an alternative to `<:`._ -_Note that, by default, all type declarations are wrapped in strings to enable forward references and improve runtime performance. If you don't want that—e.g. because you want to use type annotations at runtime—simply pass the `--no-wrap` flag._ +_Note that, by default, all type declarations are wrapped in strings to enable forward references and improve runtime performance. If you don't want that—e.g. 
because you want to use type annotations at runtime—simply pass the `--no-wrap-types` flag._ ##### PEP 695 Docs @@ -2824,6 +3061,9 @@ data Expected[T](result: T? = None, error: BaseException? = None): if not self.result `isinstance` Expected: raise TypeError("Expected.join() requires an Expected[Expected[_]]") return self.result + def map_error(self, func: BaseException -> BaseException) -> Expected[T]: + """Maps func over the error if it exists.""" + return self if self else self.__class__(error=func(self.error)) def or_else[U](self, func: BaseException -> Expected[U]) -> Expected[T | U]: """Return self if no error, otherwise return the result of evaluating func on the error.""" return self if self else func(self.error) @@ -2928,6 +3168,8 @@ For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the map For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. +For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s). + For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to ```coconut_python async def fmap_over_async_iters(func, async_iter): @@ -2958,7 +3200,7 @@ _Can't be done without a series of method definitions for each data type. See th **call**(_func_, /, *_args_, \*\*_kwargs_) -Coconut's `call` simply implements function application. Thus, `call` is equivalent to +Coconut's `call` simply implements function application. Thus, `call` is effectively equivalent to ```coconut def call(f, /, *args, **kwargs) = f(*args, **kwargs) ``` @@ -3568,15 +3810,7 @@ assert list(product(v, v)) == [(1, 1), (1, 2), (2, 1), (2, 2)] Coconut's `multi_enumerate` enumerates through an iterable of iterables. 
`multi_enumerate` works like enumerate, but indexes through inner iterables and produces a tuple index representing the index in each inner iterable. Supports indexing. -For [`numpy`](#numpy-integration) objects, effectively equivalent to: -```coconut_python -def multi_enumerate(iterable): - it = np.nditer(iterable, flags=["multi_index"]) - for x in it: - yield it.multi_index, x -``` - -Also supports `len` for [`numpy`](#numpy-integration). +For [`numpy`](#numpy-integration) objects, uses [`np.nditer`](https://numpy.org/doc/stable/reference/generated/numpy.nditer.html) under the hood. Also supports `len` for [`numpy`](#numpy-integration) arrays. ##### Example diff --git a/HELP.md b/HELP.md index aed288d3c..e016cb271 100644 --- a/HELP.md +++ b/HELP.md @@ -1132,7 +1132,6 @@ Another useful Coconut feature is implicit partials. Coconut supports a number o ```coconut .attr .method(args) -obj. func$ seq[] iter$[] diff --git a/MANIFEST.in b/MANIFEST.in index c0c085b1e..f15216482 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,6 +8,7 @@ global-include *.md global-include *.json global-include *.toml global-include *.coco +global-include *.ini global-include py.typed prune coconut/tests/dest prune docs diff --git a/Makefile b/Makefile index 80770b79e..2e7c73d2d 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ dev-py3: clean setup-py3 .PHONY: setup setup: python -m ensurepip - python -m pip install --upgrade "setuptools<58" wheel pip pytest_remotedata + python -m pip install --upgrade setuptools wheel pip pytest_remotedata .PHONY: setup-py2 setup-py2: @@ -32,7 +32,7 @@ setup-py2: .PHONY: setup-py3 setup-py3: python3 -m ensurepip - python3 -m pip install --upgrade "setuptools<58" wheel pip pytest_remotedata + python3 -m pip install --upgrade setuptools wheel pip pytest_remotedata .PHONY: setup-pypy setup-pypy: @@ -42,7 +42,7 @@ setup-pypy: .PHONY: setup-pypy3 setup-pypy3: pypy3 -m ensurepip - pypy3 -m pip install --upgrade "setuptools<58" wheel pip pytest_remotedata + 
pypy3 -m pip install --upgrade setuptools wheel pip pytest_remotedata .PHONY: install install: setup @@ -57,11 +57,11 @@ install-py3: setup-py3 python3 -m pip install -e .[tests] .PHONY: install-pypy -install-pypy: +install-pypy: setup-pypy pypy -m pip install -e .[tests] .PHONY: install-pypy3 -install-pypy3: +install-pypy3: setup-pypy3 pypy3 -m pip install -e .[tests] .PHONY: format @@ -155,11 +155,19 @@ test-verbose: clean python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py -# same as test-mypy but uses --verbose and --check-untyped-defs +# same as test-mypy but uses --verbose +.PHONY: test-mypy-verbose +test-mypy-verbose: export COCONUT_USE_COLOR=TRUE +test-mypy-verbose: clean + python ./coconut/tests --strict --force --target sys --verbose --jobs 0 --keep-lines --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition + python ./coconut/tests/dest/runner.py + python ./coconut/tests/dest/extras.py + +# same as test-mypy but uses --check-untyped-defs .PHONY: test-mypy-all test-mypy-all: export COCONUT_USE_COLOR=TRUE test-mypy-all: clean - python ./coconut/tests --strict --force --target sys --verbose --keep-lines --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition --check-untyped-defs + python ./coconut/tests --strict --force --target sys --keep-lines --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition --check-untyped-defs python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py @@ -198,6 +206,12 @@ test-watch: clean test-mini: coconut ./coconut/tests/src/cocotest/agnostic ./coconut/tests/dest/cocotest --force --jobs 0 +.PHONY: debug-comp-crash +debug-comp-crash: export COCONUT_USE_COLOR=TRUE +debug-comp-crash: export COCONUT_PURE_PYTHON=TRUE +debug-comp-crash: + python -X dev -m coconut ./coconut/tests/src/cocotest/agnostic ./coconut/tests/dest/cocotest --strict --line-numbers --keep-lines --force --jobs 0 + .PHONY: debug-test-crash debug-test-crash: 
python -X dev ./coconut/tests/dest/runner.py @@ -221,7 +235,7 @@ clean: .PHONY: wipe wipe: clean - rm -rf vprof.json profile.log *.egg-info + rm -rf vprof.json profile.log *.egg-info -find . -name "__pycache__" -delete -C:/GnuWin32/bin/find.exe . -name "__pycache__" -delete -find . -name "*.pyc" -delete @@ -253,19 +267,18 @@ check-reqs: python ./coconut/requirements.py .PHONY: profile-parser +profile-parser: export COCONUT_USE_COLOR=TRUE profile-parser: export COCONUT_PURE_PYTHON=TRUE profile-parser: coconut ./coconut/tests/src/cocotest/agnostic ./coconut/tests/dest/cocotest --force --profile --verbose --recursion-limit 4096 2>&1 | tee ./profile.log .PHONY: profile-time -profile-time: export COCONUT_PURE_PYTHON=TRUE profile-time: - vprof -c h "coconut ./coconut/tests/src/cocotest/agnostic ./coconut/tests/dest/cocotest --force" --output-file ./vprof.json + vprof -c h "./coconut ./coconut/tests/src/cocotest/agnostic ./coconut/tests/dest/cocotest --force" --output-file ./vprof.json .PHONY: profile-memory -profile-memory: export COCONUT_PURE_PYTHON=TRUE profile-memory: - vprof -c m "coconut ./coconut/tests/src/cocotest/agnostic ./coconut/tests/dest/cocotest --force" --output-file ./vprof.json + vprof -c m "./coconut ./coconut/tests/src/cocotest/agnostic ./coconut/tests/dest/cocotest --force" --output-file ./vprof.json .PHONY: view-profile view-profile: diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 333086bdb..4a42bb999 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -8,36 +8,9 @@ License: Apache 2.0 Description: MyPy stub file for __coconut__.py. 
""" -# ----------------------------------------------------------------------------------------------------------------------- -# IMPORTS: -# ----------------------------------------------------------------------------------------------------------------------- - import sys import typing as _t -if sys.version_info >= (3, 11): - from typing import dataclass_transform as _dataclass_transform -else: - try: - from typing_extensions import dataclass_transform as _dataclass_transform - except ImportError: - dataclass_transform = ... - -import _coconut as __coconut # we mock _coconut as a package since mypy doesn't handle namespace classes very well -_coconut = __coconut - -if sys.version_info >= (3, 2): - from functools import lru_cache as _lru_cache -else: - from backports.functools_lru_cache import lru_cache as _lru_cache # `pip install -U coconut[mypy]` to fix errors on this line - _coconut.functools.lru_cache = _lru_cache # type: ignore - -if sys.version_info >= (3, 7): - from dataclasses import dataclass as _dataclass -else: - @_dataclass_transform() - def _dataclass(cls: type[_T], **kwargs: _t.Any) -> type[_T]: ... - # ----------------------------------------------------------------------------------------------------------------------- # TYPE VARS: # ----------------------------------------------------------------------------------------------------------------------- @@ -82,6 +55,40 @@ _P = _t.ParamSpec("_P") class _SupportsIndex(_t.Protocol): def __index__(self) -> int: ... 
+ +# ----------------------------------------------------------------------------------------------------------------------- +# IMPORTS: +# ----------------------------------------------------------------------------------------------------------------------- + +if sys.version_info >= (3, 11): + from typing import dataclass_transform as _dataclass_transform +else: + try: + from typing_extensions import dataclass_transform as _dataclass_transform + except ImportError: + dataclass_transform = ... + +import _coconut as __coconut  # we mock _coconut as a package since mypy doesn't handle namespace classes very well +_coconut = __coconut + +if sys.version_info >= (3, 2): + from functools import lru_cache as _lru_cache +else: + from backports.functools_lru_cache import lru_cache as _lru_cache  # `pip install -U coconut[mypy]` to fix errors on this line + _coconut.functools.lru_cache = _lru_cache  # type: ignore + +if sys.version_info >= (3, 7): + from dataclasses import dataclass as _dataclass +else: + @_dataclass_transform() + def _dataclass(cls: type[_T], **kwargs: _t.Any) -> type[_T]: ... + +try: + from typing_extensions import deprecated as _deprecated  # type: ignore +except ImportError: + def _deprecated(message: _t.Text) -> _t.Callable[[_T], _T]: ... 
# type: ignore + + # ----------------------------------------------------------------------------------------------------------------------- # STUB: # ----------------------------------------------------------------------------------------------------------------------- @@ -126,6 +133,7 @@ if sys.version_info < (3, 7): py_chr = chr +py_dict = dict py_hex = hex py_input = input py_int = int @@ -210,7 +218,7 @@ def _coconut_tco(func: _Tfunc) -> _Tfunc: return func -# any changes here should also be made to safe_call below +# any changes here should also be made to safe_call and call_or_coefficient below @_t.overload def call( _func: _t.Callable[[_T], _U], @@ -304,6 +312,7 @@ class Expected(_BaseExpected[_T]): def __getitem__(self, index: slice) -> _t.Tuple[_T | BaseException | None, ...]: ... def and_then(self, func: _t.Callable[[_T], Expected[_U]]) -> Expected[_U]: ... def join(self: Expected[Expected[_T]]) -> Expected[_T]: ... + def map_error(self, func: _t.Callable[[BaseException], BaseException]) -> Expected[_T]: ... def or_else(self, func: _t.Callable[[BaseException], Expected[_U]]) -> Expected[_T | _U]: ... def result_or(self, default: _U) -> _T | _U: ... def result_or_else(self, func: _t.Callable[[BaseException], _U]) -> _T | _U: ... @@ -363,6 +372,58 @@ def safe_call( ) -> Expected[_T]: ... +# based on call above +@_t.overload +def _coconut_call_or_coefficient( + _func: _t.Callable[[_T], _U], + _x: _T, +) -> _U: ... +@_t.overload +def _coconut_call_or_coefficient( + _func: _t.Callable[[_T, _U], _V], + _x: _T, + _y: _U, +) -> _V: ... +@_t.overload +def _coconut_call_or_coefficient( + _func: _t.Callable[[_T, _U, _V], _W], + _x: _T, + _y: _U, + _z: _V, +) -> _W: ... +@_t.overload +def _coconut_call_or_coefficient( + _func: _t.Callable[_t.Concatenate[_T, _P], _U], + _x: _T, + *args: _t.Any, +) -> _U: ... +@_t.overload +def _coconut_call_or_coefficient( + _func: _t.Callable[_t.Concatenate[_T, _U, _P], _V], + _x: _T, + _y: _U, + *args: _t.Any, +) -> _V: ... 
+@_t.overload +def _coconut_call_or_coefficient( + _func: _t.Callable[_t.Concatenate[_T, _U, _V, _P], _W], + _x: _T, + _y: _U, + _z: _V, + *args: _t.Any, +) -> _W: ... +@_t.overload +def _coconut_call_or_coefficient( + _func: _t.Callable[..., _T], + *args: _t.Any, +) -> _T: ... +@_t.overload +def _coconut_call_or_coefficient( + _func: _T, + *args: _T, +) -> _T: ... + + def recursive_iterator(func: _T_iter_func) -> _T_iter_func: return func @@ -412,6 +473,7 @@ def addpattern( *add_funcs: _Callable, allow_any_func: bool=False, ) -> _t.Callable[..., _t.Any]: ... + _coconut_addpattern = prepattern = addpattern @@ -812,6 +874,10 @@ def _coconut_bool_or(a: _t.Literal[False], b: _T) -> _T: ... def _coconut_bool_or(a: _T, b: _U) -> _t.Union[_T, _U]: ... +def _coconut_in(a: _T, b: _t.Sequence[_T]) -> bool: ... +_coconut_not_in = _coconut_in + + @_t.overload def _coconut_none_coalesce(a: _T, b: None) -> _T: ... @_t.overload @@ -957,6 +1023,7 @@ _coconut_flatten = flatten def makedata(data_type: _t.Type[_T], *args: _t.Any) -> _T: ... +@_deprecated("use makedata instead") def datamaker(data_type: _t.Type[_T]) -> _t.Callable[..., _T]: return _coconut.functools.partial(makedata, data_type) @@ -985,10 +1052,10 @@ def fmap(func: _t.Callable[[_T], _U], obj: _t.Iterator[_T]) -> _t.Iterator[_U]: def fmap(func: _t.Callable[[_T], _U], obj: _t.Set[_T]) -> _t.Set[_U]: ... @_t.overload def fmap(func: _t.Callable[[_T], _U], obj: _t.AsyncIterable[_T]) -> _t.AsyncIterable[_U]: ... -@_t.overload -def fmap(func: _t.Callable[[_t.Tuple[_T, _U]], _t.Tuple[_V, _W]], obj: _t.Dict[_T, _U]) -> _t.Dict[_V, _W]: ... -@_t.overload -def fmap(func: _t.Callable[[_t.Tuple[_T, _U]], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U]) -> _t.Mapping[_V, _W]: ... +# @_t.overload +# def fmap(func: _t.Callable[[_t.Tuple[_T, _U]], _t.Tuple[_V, _W]], obj: _t.Dict[_T, _U]) -> _t.Dict[_V, _W]: ... 
+# @_t.overload +# def fmap(func: _t.Callable[[_t.Tuple[_T, _U]], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U]) -> _t.Mapping[_V, _W]: ... @_t.overload def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Dict[_T, _U], starmap_over_mappings: _t.Literal[True]) -> _t.Dict[_V, _W]: ... @_t.overload @@ -1042,16 +1109,16 @@ class _coconut_lifted_1(_t.Generic[_T, _W]): # self, # _g: _t.Callable[[_X], _T], # ) -> _t.Callable[[_X], _W]: ... - @_t.overload - def __call__( - self, - _g: _t.Callable[[_X, _Y], _T], - ) -> _t.Callable[[_X, _Y], _W]: ... - @_t.overload - def __call__( - self, - _g: _t.Callable[[_X, _Y, _Z], _T], - ) -> _t.Callable[[_X, _Y, _Z], _W]: ... + # @_t.overload + # def __call__( + # self, + # _g: _t.Callable[[_X, _Y], _T], + # ) -> _t.Callable[[_X, _Y], _W]: ... + # @_t.overload + # def __call__( + # self, + # _g: _t.Callable[[_X, _Y, _Z], _T], + # ) -> _t.Callable[[_X, _Y, _Z], _W]: ... @_t.overload def __call__( self, @@ -1076,18 +1143,18 @@ class _coconut_lifted_2(_t.Generic[_T, _U, _W]): # _g: _t.Callable[[_X], _T], # _h: _t.Callable[[_X], _U], # ) -> _t.Callable[[_X], _W]: ... - @_t.overload - def __call__( - self, - _g: _t.Callable[[_X, _Y], _T], - _h: _t.Callable[[_X, _Y], _U], - ) -> _t.Callable[[_X, _Y], _W]: ... - @_t.overload - def __call__( - self, - _g: _t.Callable[[_X, _Y, _Z], _T], - _h: _t.Callable[[_X, _Y, _Z], _U], - ) -> _t.Callable[[_X, _Y, _Z], _W]: ... + # @_t.overload + # def __call__( + # self, + # _g: _t.Callable[[_X, _Y], _T], + # _h: _t.Callable[[_X, _Y], _U], + # ) -> _t.Callable[[_X, _Y], _W]: ... + # @_t.overload + # def __call__( + # self, + # _g: _t.Callable[[_X, _Y, _Z], _T], + # _h: _t.Callable[[_X, _Y, _Z], _U], + # ) -> _t.Callable[[_X, _Y, _Z], _W]: ... @_t.overload def __call__( self, @@ -1116,20 +1183,20 @@ class _coconut_lifted_3(_t.Generic[_T, _U, _V, _W]): # _h: _t.Callable[[_X], _U], # _i: _t.Callable[[_X], _V], # ) -> _t.Callable[[_X], _W]: ... 
- @_t.overload - def __call__( - self, - _g: _t.Callable[[_X, _Y], _T], - _h: _t.Callable[[_X, _Y], _U], - _i: _t.Callable[[_X, _Y], _V], - ) -> _t.Callable[[_X, _Y], _W]: ... - @_t.overload - def __call__( - self, - _g: _t.Callable[[_X, _Y, _Z], _T], - _h: _t.Callable[[_X, _Y, _Z], _U], - _i: _t.Callable[[_X, _Y, _Z], _V], - ) -> _t.Callable[[_X, _Y, _Z], _W]: ... + # @_t.overload + # def __call__( + # self, + # _g: _t.Callable[[_X, _Y], _T], + # _h: _t.Callable[[_X, _Y], _U], + # _i: _t.Callable[[_X, _Y], _V], + # ) -> _t.Callable[[_X, _Y], _W]: ... + # @_t.overload + # def __call__( + # self, + # _g: _t.Callable[[_X, _Y, _Z], _T], + # _h: _t.Callable[[_X, _Y, _Z], _U], + # _i: _t.Callable[[_X, _Y, _Z], _V], + # ) -> _t.Callable[[_X, _Y, _Z], _W]: ... @_t.overload def __call__( self, @@ -1267,3 +1334,62 @@ def _coconut_multi_dim_arr( @_t.overload def _coconut_multi_dim_arr(arrs: _Tuple, dim: int) -> _Sequence: ... + + +class _coconut_SupportsAdd(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __add__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsMinus(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __sub__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + def __neg__(self: _Tco) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsMul(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __mul__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsPow(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __pow__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsTruediv(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __truediv__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsFloordiv(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __floordiv__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class 
_coconut_SupportsMod(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __mod__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsAnd(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __and__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsXor(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __xor__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsOr(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __or__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsLshift(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __lshift__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsRshift(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __rshift__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsMatmul(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): + def __matmul__(self: _Tco, other: _Ucontra) -> _Vco: + raise NotImplementedError + +class _coconut_SupportsInv(_t.Protocol, _t.Generic[_Tco, _Vco]): + def __invert__(self: _Tco) -> _Vco: + raise NotImplementedError diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 6b5c906b5..e60765ee8 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -60,32 +60,46 @@ except ImportError: else: _abc.Sequence.register(_numpy.ndarray) +# ----------------------------------------------------------------------------------------------------------------------- +# TYPING: +# ----------------------------------------------------------------------------------------------------------------------- + +typing = _t + +from typing_extensions import TypeVar +typing.TypeVar = TypeVar # type: ignore + +if sys.version_info < (3, 8): + try: + from typing_extensions import Protocol + except ImportError: + Protocol = ... 
# type: ignore + typing.Protocol = Protocol # type: ignore + if sys.version_info < (3, 10): try: from typing_extensions import TypeAlias, ParamSpec, Concatenate except ImportError: - TypeAlias = ... - ParamSpec = ... - typing.TypeAlias = TypeAlias - typing.ParamSpec = ParamSpec - typing.Concatenate = Concatenate - + TypeAlias = ... # type: ignore + ParamSpec = ... # type: ignore + Concatenate = ... # type: ignore + typing.TypeAlias = TypeAlias # type: ignore + typing.ParamSpec = ParamSpec # type: ignore + typing.Concatenate = Concatenate # type: ignore if sys.version_info < (3, 11): try: from typing_extensions import TypeVarTuple, Unpack except ImportError: - TypeVarTuple = ... - Unpack = ... - typing.TypeVarTuple = TypeVarTuple - typing.Unpack = Unpack + TypeVarTuple = ... # type: ignore + Unpack = ... # type: ignore + typing.TypeVarTuple = TypeVarTuple # type: ignore + typing.Unpack = Unpack # type: ignore # ----------------------------------------------------------------------------------------------------------------------- # STUB: # ----------------------------------------------------------------------------------------------------------------------- -typing = _t - collections = _collections copy = _copy functools = _functools @@ -108,13 +122,16 @@ if sys.version_info >= (2, 7): OrderedDict = collections.OrderedDict else: OrderedDict = dict + abc = _abc abc.Sequence.register(collections.deque) + numpy = _numpy npt = _npt # Fake, like typing zip_longest = _zip_longest numpy_modules: _t.Any = ... +pandas_numpy_modules: _t.Any = ... jax_numpy_modules: _t.Any = ... tee_type: _t.Any = ... reiterables: _t.Any = ... @@ -134,6 +151,7 @@ StopIteration = StopIteration RuntimeError = RuntimeError callable = callable classmethod = classmethod +complex = complex all = all any = any bool = bool @@ -154,6 +172,7 @@ iter = iter len: _t.Callable[..., int] = ... 
# pattern-matching needs an untyped _coconut.len to avoid type errors list = list locals = locals +globals = globals map = map min = min max = max diff --git a/coconut/__coconut__.pyi b/coconut/__coconut__.pyi index 1964666c7..45d413ea3 100644 --- a/coconut/__coconut__.pyi +++ b/coconut/__coconut__.pyi @@ -1,2 +1,2 @@ from __coconut__ import * -from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_super, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose +from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, 
_coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in diff --git a/coconut/_pyparsing.py b/coconut/_pyparsing.py index 1806c433b..d975a6d14 100644 --- a/coconut/_pyparsing.py +++ b/coconut/_pyparsing.py @@ -28,8 +28,8 @@ from collections import defaultdict from coconut.constants import ( - PURE_PYTHON, PYPY, + PURE_PYTHON, use_fast_pyparsing_reprs, use_packrat_parser, packrat_cache_size, @@ -39,6 +39,8 @@ 
pure_python_env_var, enable_pyparsing_warnings, use_left_recursion_if_available, + get_bool_env_var, + use_computation_graph_env_var, ) from coconut.util import get_clock_time # NOQA from coconut.util import ( @@ -121,9 +123,12 @@ + " (run either '{python} -m pip install cPyparsing<{max_ver}' or '{python} -m pip install pyparsing<{max_ver}' to fix)".format(python=sys.executable, max_ver=max_ver_str), ) -USE_COMPUTATION_GRAPH = ( - not MODERN_PYPARSING # not yet supported - and not PYPY # experimentally determined +USE_COMPUTATION_GRAPH = get_bool_env_var( + use_computation_graph_env_var, + default=( + not MODERN_PYPARSING # not yet supported + and not PYPY # experimentally determined + ), ) if enable_pyparsing_warnings: diff --git a/coconut/command/cli.py b/coconut/command/cli.py index f666adcd7..5e9c930a1 100644 --- a/coconut/command/cli.py +++ b/coconut/command/cli.py @@ -34,6 +34,7 @@ prompt_histfile, home_env_var, py_version_str, + default_jobs, ) # ----------------------------------------------------------------------------------------------------------------------- @@ -166,7 +167,7 @@ ) arguments.add_argument( - "--no-wrap", "--nowrap", + "--no-wrap-types", "--nowraptypes", action="store_true", help="disable wrapping type annotations in strings and turn off 'from __future__ import annotations' behavior", ) @@ -182,7 +183,7 @@ "-j", "--jobs", metavar="processes", type=str, - help="number of additional processes to use (defaults to 0) (pass 'sys' to use machine default)", + help="number of additional processes to use (defaults to " + ascii(default_jobs) + ") (0 is no additional processes; 'sys' uses machine default)", ) arguments.add_argument( @@ -255,7 +256,14 @@ "--recursion-limit", "--recursionlimit", metavar="limit", type=int, - help="set maximum recursion depth in compiler (defaults to " + str(default_recursion_limit) + ")", + help="set maximum recursion depth in compiler (defaults to " + ascii(default_recursion_limit) + ") (when increasing 
--recursion-limit, you may also need to increase --stack-size)", +) + +arguments.add_argument( + "--stack-size", "--stacksize", + metavar="kbs", + type=int, + help="run the compiler in a separate thread with the given stack size in kilobytes", ) arguments.add_argument( diff --git a/coconut/command/command.py b/coconut/command/command.py index a4fd5d5f6..ebeeace41 100644 --- a/coconut/command/command.py +++ b/coconut/command/command.py @@ -66,6 +66,8 @@ mypy_builtin_regex, coconut_pth_file, error_color_code, + jupyter_console_commands, + default_jobs, ) from coconut.util import ( univ_open, @@ -93,6 +95,7 @@ set_recursion_limit, can_parse, invert_mypy_arg, + run_with_stack_size, ) from coconut.compiler.util import ( should_indent, @@ -110,18 +113,25 @@ class Command(object): """Coconut command-line interface.""" comp = None # current coconut.compiler.Compiler - show = False # corresponds to --display flag runner = None # the current Runner - jobs = 0 # corresponds to --jobs flag executor = None # runs --jobs exit_code = 0 # exit status to return errmsg = None # error message to display + + show = False # corresponds to --display flag + jobs = 0 # corresponds to --jobs flag mypy_args = None # corresponds to --mypy flag argv_args = None # corresponds to --argv flag + stack_size = 0 # corresponds to --stack-size flag - def __init__(self): - """Create the CLI.""" - self.prompt = Prompt() + _prompt = None + + @property + def prompt(self): + """Delay creation of a Prompt() until it's needed.""" + if self._prompt is None: + self._prompt = Prompt() + return self._prompt def start(self, run=False): """Endpoint for coconut and coconut-run.""" @@ -144,21 +154,29 @@ def start(self, run=False): def cmd(self, args=None, argv=None, interact=True, default_target=None): """Process command-line arguments.""" - if args is None: - parsed_args = arguments.parse_args() - else: - parsed_args = arguments.parse_args(args) - if argv is not None: - if parsed_args.argv is not None: - raise 
CoconutException("cannot pass --argv/--args when using coconut-run (coconut-run interprets any arguments after the source file as --argv/--args)") - parsed_args.argv = argv - if parsed_args.target is None: - parsed_args.target = default_target - self.exit_code = 0 with self.handling_exceptions(): - self.use_args(parsed_args, interact, original_args=args) + if args is None: + parsed_args = arguments.parse_args() + else: + parsed_args = arguments.parse_args(args) + if argv is not None: + if parsed_args.argv is not None: + raise CoconutException("cannot pass --argv/--args when using coconut-run (coconut-run interprets any arguments after the source file as --argv/--args)") + parsed_args.argv = argv + if parsed_args.target is None: + parsed_args.target = default_target + self.exit_code = 0 + self.stack_size = parsed_args.stack_size + self.run_with_stack_size(self.execute_args, parsed_args, interact, original_args=args) self.exit_on_error() + def run_with_stack_size(self, func, *args, **kwargs): + """Execute func with the correct stack size.""" + if self.stack_size: + return run_with_stack_size(self.stack_size, func, *args, **kwargs) + else: + return func(*args, **kwargs) + def setup(self, *args, **kwargs): """Set parameters for the compiler.""" if self.comp is None: @@ -182,159 +200,162 @@ def exit_on_error(self): kill_children() sys.exit(self.exit_code) - def use_args(self, args, interact=True, original_args=None): + def execute_args(self, args, interact=True, original_args=None): """Handle command-line arguments.""" - # fix args - if not DEVELOP: - args.trace = args.profile = False - - # set up logger - logger.quiet, logger.verbose, logger.tracing = args.quiet, args.verbose, args.trace - if args.verbose or args.trace or args.profile: - set_grammar_names() - if args.trace or args.profile: - unset_fast_pyparsing_reprs() - if args.profile: - collect_timing_info() - logger.enable_colors() - - logger.log(cli_version) - if original_args is not None: - logger.log("Directly 
passed args:", original_args) - logger.log("Parsed args:", args) - - # validate general command args - if args.mypy is not None and args.line_numbers: - logger.warn("extraneous --line-numbers argument passed; --mypy implies --line-numbers") - if args.site_install and args.site_uninstall: - raise CoconutException("cannot --site-install and --site-uninstall simultaneously") - for and_args in getattr(args, "and") or []: - if len(and_args) > 2: - raise CoconutException( - "--and accepts at most two arguments, source and dest ({n} given: {args!r})".format( - n=len(and_args), - args=and_args, - ), - ) + with self.handling_exceptions(): + # fix args + if not DEVELOP: + args.trace = args.profile = False + + # set up logger + logger.quiet, logger.verbose, logger.tracing = args.quiet, args.verbose, args.trace + if args.verbose or args.trace or args.profile: + set_grammar_names() + if args.trace or args.profile: + unset_fast_pyparsing_reprs() + if args.profile: + collect_timing_info() + logger.enable_colors() + + logger.log(cli_version) + if original_args is not None: + logger.log("Directly passed args:", original_args) + logger.log("Parsed args:", args) + + # validate general command args + if args.stack_size and args.stack_size % 4 != 0: + logger.warn("--stack-size should generally be a multiple of 4, not {stack_size} (to support 4 KB pages)".format(stack_size=args.stack_size)) + if args.mypy is not None and args.line_numbers: + logger.warn("extraneous --line-numbers argument passed; --mypy implies --line-numbers") + if args.site_install and args.site_uninstall: + raise CoconutException("cannot --site-install and --site-uninstall simultaneously") + for and_args in getattr(args, "and") or []: + if len(and_args) > 2: + raise CoconutException( + "--and accepts at most two arguments, source and dest ({n} given: {args!r})".format( + n=len(and_args), + args=and_args, + ), + ) - # process general command args - if args.recursion_limit is not None: - 
set_recursion_limit(args.recursion_limit) - if args.jobs is not None: - self.set_jobs(args.jobs) - if args.display: - self.show = True - if args.style is not None: - self.prompt.set_style(args.style) - if args.history_file is not None: - self.prompt.set_history_file(args.history_file) - if args.vi_mode: - self.prompt.vi_mode = True - if args.docs: - launch_documentation() - if args.tutorial: - launch_tutorial() - if args.site_uninstall: - self.site_uninstall() - if args.site_install: - self.site_install() - if args.argv is not None: - self.argv_args = list(args.argv) - - # additional validation after processing - if args.profile and self.jobs != 0: - raise CoconutException("--profile incompatible with --jobs {jobs}".format(jobs=args.jobs)) - - # process general compiler args - self.setup( - target=args.target, - strict=args.strict, - minify=args.minify, - line_numbers=args.line_numbers or args.mypy is not None, - keep_lines=args.keep_lines, - no_tco=args.no_tco, - no_wrap=args.no_wrap, - ) + # process general command args + self.set_jobs(args.jobs, args.profile) + if args.recursion_limit is not None: + set_recursion_limit(args.recursion_limit) + if args.display: + self.show = True + if args.style is not None: + self.prompt.set_style(args.style) + if args.history_file is not None: + self.prompt.set_history_file(args.history_file) + if args.vi_mode: + self.prompt.vi_mode = True + if args.docs: + launch_documentation() + if args.tutorial: + launch_tutorial() + if args.site_uninstall: + self.site_uninstall() + if args.site_install: + self.site_install() + if args.argv is not None: + self.argv_args = list(args.argv) + + # process general compiler args + self.setup( + target=args.target, + strict=args.strict, + minify=args.minify, + line_numbers=args.line_numbers or args.mypy is not None, + keep_lines=args.keep_lines, + no_tco=args.no_tco, + no_wrap=args.no_wrap_types, + ) - # process mypy args and print timing info (must come after compiler setup) - if args.mypy is not 
None: - self.set_mypy_args(args.mypy) - logger.log("Grammar init time: " + str(self.comp.grammar_init_time) + " secs / Total init time: " + str(get_clock_time() - first_import_time) + " secs") - - if args.source is not None: - # warnings if source is given - if args.interact and args.run: - logger.warn("extraneous --run argument passed; --interact implies --run") - if args.package and self.mypy: - logger.warn("extraneous --package argument passed; --mypy implies --package") - - # errors if source is given - if args.standalone and args.package: - raise CoconutException("cannot compile as both --package and --standalone") - if args.standalone and self.mypy: - raise CoconutException("cannot compile as both --package (implied by --mypy) and --standalone") - if args.no_write and self.mypy: - raise CoconutException("cannot compile with --no-write when using --mypy") - - # process all source, dest pairs - src_dest_package_triples = [ - self.process_source_dest(src, dst, args) - for src, dst in ( - [(args.source, args.dest)] - + (getattr(args, "and") or []) - ) - ] - - # do compilation - with self.running_jobs(exit_on_error=not args.watch): - filepaths = [] - for source, dest, package in src_dest_package_triples: - filepaths += self.compile_path(source, dest, package, run=args.run or args.interact, force=args.force) - self.run_mypy(filepaths) - - # validate args if no source is given - elif ( - args.run - or args.no_write - or args.force - or args.package - or args.standalone - or args.watch - ): - raise CoconutException("a source file/folder must be specified when options that depend on the source are enabled") - elif getattr(args, "and"): - raise CoconutException("--and should only be used for extra source/dest pairs, not the first source/dest pair") - - # handle extra cli tasks - if args.code is not None: - self.execute(self.parse_block(args.code)) - got_stdin = False - if args.jupyter is not None: - self.start_jupyter(args.jupyter) - elif stdin_readable(): - 
logger.log("Reading piped input from stdin...") - self.execute(self.parse_block(sys.stdin.read())) - got_stdin = True - if args.interact or ( - interact and not ( - got_stdin - or args.source - or args.code - or args.tutorial - or args.docs + # process mypy args and print timing info (must come after compiler setup) + if args.mypy is not None: + self.set_mypy_args(args.mypy) + logger.log("Grammar init time: " + str(self.comp.grammar_init_time) + " secs / Total init time: " + str(get_clock_time() - first_import_time) + " secs") + + if args.source is not None: + # warnings if source is given + if args.interact and args.run: + logger.warn("extraneous --run argument passed; --interact implies --run") + if args.package and self.mypy: + logger.warn("extraneous --package argument passed; --mypy implies --package") + + # errors if source is given + if args.standalone and args.package: + raise CoconutException("cannot compile as both --package and --standalone") + if args.standalone and self.mypy: + raise CoconutException("cannot compile as both --package (implied by --mypy) and --standalone") + if args.no_write and self.mypy: + raise CoconutException("cannot compile with --no-write when using --mypy") + + # process all source, dest pairs + src_dest_package_triples = [ + self.process_source_dest(src, dst, args) + for src, dst in ( + [(args.source, args.dest)] + + (getattr(args, "and") or []) + ) + ] + + # disable jobs if we know we're only compiling one file + if len(src_dest_package_triples) <= 1 and not any(package for _, _, package in src_dest_package_triples): + self.disable_jobs() + + # do compilation + with self.running_jobs(exit_on_error=not args.watch): + filepaths = [] + for source, dest, package in src_dest_package_triples: + filepaths += self.compile_path(source, dest, package, run=args.run or args.interact, force=args.force) + self.run_mypy(filepaths) + + # validate args if no source is given + elif ( + args.run + or args.no_write + or args.force + or 
args.package + or args.standalone or args.watch - or args.site_uninstall - or args.site_install - or args.jupyter is not None - or args.mypy == [mypy_install_arg] - ) - ): - self.start_prompt() - if args.watch: - # src_dest_package_triples is always available here - self.watch(src_dest_package_triples, args.run, args.force) - if args.profile: - print_timing_info() + or args.jobs + ): + raise CoconutException("a source file/folder must be specified when options that depend on the source are enabled") + elif getattr(args, "and"): + raise CoconutException("--and should only be used for extra source/dest pairs, not the first source/dest pair") + + # handle extra cli tasks + if args.code is not None: + self.execute(self.parse_block(args.code)) + got_stdin = False + if args.jupyter is not None: + self.start_jupyter(args.jupyter) + elif stdin_readable(): + logger.log("Reading piped input from stdin...") + self.execute(self.parse_block(sys.stdin.read())) + got_stdin = True + if args.interact or ( + interact and not ( + got_stdin + or args.source + or args.code + or args.tutorial + or args.docs + or args.watch + or args.site_uninstall + or args.site_install + or args.jupyter is not None + or args.mypy == [mypy_install_arg] + ) + ): + self.start_prompt() + if args.watch: + # src_dest_package_triples is always available here + self.watch(src_dest_package_triples, args.run, args.force) + if args.profile: + print_timing_info() def process_source_dest(self, source, dest, args): """Determine the correct source, dest, package mode to use for the given source, dest, and args.""" @@ -408,7 +429,9 @@ def handling_exceptions(self): except SystemExit as err: self.register_exit_code(err.code) except BaseException as err: - if isinstance(err, CoconutException): + if isinstance(err, GeneratorExit): + raise + elif isinstance(err, CoconutException): logger.print_exc() elif not isinstance(err, KeyboardInterrupt): logger.print_exc() @@ -529,9 +552,9 @@ def callback(compiled): 
self.execute_file(destpath, argv_source_path=codepath) if package is True: - self.submit_comp_job(codepath, callback, "parse_package", code, package_level=package_level) + self.submit_comp_job(codepath, callback, "parse_package", code, package_level=package_level, filename=os.path.basename(codepath)) elif package is False: - self.submit_comp_job(codepath, callback, "parse_file", code) + self.submit_comp_job(codepath, callback, "parse_file", code, filename=os.path.basename(codepath)) else: raise CoconutInternalException("invalid value for package", package) @@ -574,18 +597,18 @@ def submit_comp_job(self, path, callback, method, *args, **kwargs): with logger.in_path(path): # pickle the compiler in the path context future = self.executor.submit(multiprocess_wrapper(self.comp, method), *args, **kwargs) - def callback_wrapper(completed_future): - """Ensures that all errors are always caught, since errors raised in a callback won't be propagated.""" - with logger.in_path(path): # handle errors in the path context - with self.handling_exceptions(): - result = completed_future.result() - callback(result) - future.add_done_callback(callback_wrapper) + def callback_wrapper(completed_future): + """Ensures that all errors are always caught, since errors raised in a callback won't be propagated.""" + with logger.in_path(path): # handle errors in the path context + with self.handling_exceptions(): + result = completed_future.result() + callback(result) + future.add_done_callback(callback_wrapper) - def set_jobs(self, jobs): + def set_jobs(self, jobs, profile=False): """Set --jobs.""" - if jobs == "sys": - self.jobs = None + if jobs in (None, "sys"): + self.jobs = jobs else: try: jobs = int(jobs) @@ -594,11 +617,30 @@ def set_jobs(self, jobs): if jobs < 0: raise CoconutException("--jobs must be an integer >= 0 or 'sys'") self.jobs = jobs + logger.log("Jobs:", self.jobs) + if profile and self.jobs != 0: + raise CoconutException("--profile incompatible with --jobs 
{jobs}".format(jobs=jobs)) + + def disable_jobs(self): + """Disables use of --jobs.""" + if self.jobs not in (0, 1, None): + logger.warn("got --jobs {jobs} but only compiling one file; disabling --jobs".format(jobs=self.jobs)) + self.jobs = 0 + logger.log("Jobs:", self.jobs) + + def get_max_workers(self): + """Get the max_workers to use for creating ProcessPoolExecutor.""" + jobs = self.jobs if self.jobs is not None else default_jobs + if jobs == "sys": + return None + else: + return jobs @property def using_jobs(self): """Determine whether or not multiprocessing is being used.""" - return self.jobs is None or self.jobs > 1 + max_workers = self.get_max_workers() + return max_workers is None or max_workers > 1 @contextmanager def running_jobs(self, exit_on_error=True): @@ -607,7 +649,7 @@ def running_jobs(self, exit_on_error=True): if self.using_jobs: from concurrent.futures import ProcessPoolExecutor try: - with ProcessPoolExecutor(self.jobs) as self.executor: + with ProcessPoolExecutor(self.get_max_workers()) as self.executor: yield finally: self.executor = None @@ -798,7 +840,10 @@ def run_mypy(self, paths=(), code=None): if code is None: # file logger.printerr(line) self.register_exit_code(errmsg="MyPy error") - elif not line.startswith(mypy_silent_non_err_prefixes): + elif line.startswith(mypy_silent_non_err_prefixes): + if code is None: # file + logger.print("MyPy", line) + else: if code is None: # file logger.printerr(line) if any(infix in line for infix in mypy_err_infixes): @@ -923,10 +968,9 @@ def start_jupyter(self, args): logger.warn("could not find {name!r} kernel; using {kernel!r} kernel instead".format(name=icoconut_custom_kernel_name, kernel=kernel)) # pass the kernel to the console or otherwise just launch Jupyter now that we know our kernel is available - if args[0] == "console": - run_args = jupyter + ["console", "--kernel", kernel] + args[1:] - else: - run_args = jupyter + args + if args[0] in jupyter_console_commands: + args += ["--kernel", 
kernel] + run_args = jupyter + args if newly_installed_kernels: logger.show_sig("Successfully installed Jupyter kernels: '" + "', '".join(newly_installed_kernels) + "'") diff --git a/coconut/command/util.py b/coconut/command/util.py index 74dfe8394..8403def86 100644 --- a/coconut/command/util.py +++ b/coconut/command/util.py @@ -23,6 +23,7 @@ import os import subprocess import shutil +import threading from select import select from contextlib import contextmanager from functools import partial @@ -73,6 +74,9 @@ interpreter_uses_auto_compilation, interpreter_uses_coconut_breakpoint, interpreter_compiler_var, + must_use_specific_target_builtins, + kilobyte, + min_stack_size_kbs, ) if PY26: @@ -216,7 +220,7 @@ def handling_broken_process_pool(): yield except BrokenProcessPool: logger.log_exc() - raise BaseCoconutException("broken process pool") + raise BaseCoconutException("broken process pool (if this is due to a stack overflow, you may be able to fix by re-running with a larger '--stack-size', otherwise try disabling multiprocessing with '--jobs 0')") def kill_children(): @@ -226,7 +230,7 @@ def kill_children(): except ImportError: logger.warn( "missing psutil; --jobs may not properly terminate", - extra="run '{python} -m pip install coconut[jobs]' to fix".format(python=sys.executable), + extra="run '{python} -m pip install psutil' to fix".format(python=sys.executable), ) else: parent = psutil.Process() @@ -425,6 +429,20 @@ def invert_mypy_arg(arg): return None +def run_with_stack_size(stack_kbs, func, *args, **kwargs): + """Run the given function with a stack of the given size in KBs.""" + if stack_kbs < min_stack_size_kbs: + raise CoconutException("--stack-size must be at least " + str(min_stack_size_kbs) + " KB") + old_stack_size = threading.stack_size(stack_kbs * kilobyte) + out = [] + thread = threading.Thread(target=lambda *args, **kwargs: out.append(func(*args, **kwargs)), args=args, kwargs=kwargs) + thread.start() + thread.join() + logger.log("Stack size 
used:", old_stack_size, "->", stack_kbs * kilobyte) + internal_assert(len(out) == 1, "invalid threading results", out) + return out[0] + + # ----------------------------------------------------------------------------------------------------------------------- # CLASSES: # ----------------------------------------------------------------------------------------------------------------------- @@ -568,7 +586,7 @@ def fix_pickle(self): """Fix pickling of Coconut header objects.""" from coconut import __coconut__ # this is expensive, so only do it here for var in self.vars: - if not var.startswith("__") and var in dir(__coconut__): + if not var.startswith("__") and var in dir(__coconut__) and var not in must_use_specific_target_builtins: cur_val = self.vars[var] static_val = getattr(__coconut__, var) if getattr(cur_val, "__doc__", None) == getattr(static_val, "__doc__", None): @@ -649,21 +667,26 @@ class multiprocess_wrapper(pickleable_obj): """Wrapper for a method that needs to be multiprocessed.""" __slots__ = ("base", "method", "rec_limit", "logger", "argv") - def __init__(self, base, method, _rec_limit=None, _logger=None, _argv=None): + def __init__(self, base, method, stack_size=None, _rec_limit=None, _logger=None, _argv=None): """Create new multiprocessable method.""" self.base = base self.method = method + self.stack_size = stack_size self.rec_limit = sys.getrecursionlimit() if _rec_limit is None else _rec_limit self.logger = logger.copy() if _logger is None else _logger self.argv = sys.argv if _argv is None else _argv def __reduce__(self): """Pickle for transfer across processes.""" - return (self.__class__, (self.base, self.method, self.rec_limit, self.logger, self.argv)) + return (self.__class__, (self.base, self.method, self.stack_size, self.rec_limit, self.logger, self.argv)) def __call__(self, *args, **kwargs): """Call the method.""" sys.setrecursionlimit(self.rec_limit) logger.copy_from(self.logger) sys.argv = self.argv - return getattr(self.base, 
self.method)(*args, **kwargs) + func = getattr(self.base, self.method) + if self.stack_size: + return run_with_stack_size(self.stack_size, func, args, kwargs) + else: + return func(*args, **kwargs) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 69c5d371f..6f0ff640c 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -86,6 +86,7 @@ streamline_grammar_for_len, all_builtins, in_place_op_funcs, + match_first_arg_var, ) from coconut.util import ( pickleable_obj, @@ -95,6 +96,7 @@ clean, get_target_info, get_clock_time, + get_name, ) from coconut.exceptions import ( CoconutException, @@ -142,7 +144,7 @@ parse, all_matches, get_target_info_smart, - split_leading_comment, + split_leading_comments, compile_regex, append_it, interleaved_join, @@ -157,8 +159,8 @@ rem_and_count_indents, normalize_indent_markers, try_parse, + does_parse, prep_grammar, - split_leading_whitespace, ordered, tuple_str_of_str, dict_to_str, @@ -175,6 +177,13 @@ # ----------------------------------------------------------------------------------------------------------------------- +match_func_paramdef = "{match_first_arg_var}=_coconut_sentinel, *{match_to_args_var}, **{match_to_kwargs_var}".format( + match_first_arg_var=match_first_arg_var, + match_to_args_var=match_to_args_var, + match_to_kwargs_var=match_to_kwargs_var, +) + + def set_to_tuple(tokens): """Converts set literal tokens to tuples.""" internal_assert(len(tokens) == 1, "invalid set maker tokens", tokens) @@ -355,6 +364,58 @@ def reconstitute_paramdef(pos_only_args, req_args, default_args, star_arg, kwd_o return ", ".join(args_list) +def split_star_expr_tokens(tokens, is_dict=False): + """Split testlist_star_expr or dict_literal tokens.""" + groups = [[]] + has_star = False + has_comma = False + for tok_grp in tokens: + if tok_grp == ",": + has_comma = True + elif len(tok_grp) == 1: + internal_assert(not is_dict, "found non-star non-pair item in dict literal", tok_grp) + 
groups[-1].append(tok_grp[0]) + elif len(tok_grp) == 2: + internal_assert(not tok_grp[0].lstrip("*"), "invalid star expr item signifier", tok_grp[0]) + has_star = True + groups.append(tok_grp[1]) + groups.append([]) + elif len(tok_grp) == 3: + internal_assert(is_dict, "found dict key-value pair in non-dict tokens", tok_grp) + k, c, v = tok_grp + internal_assert(c == ":", "invalid colon in dict literal item", c) + groups[-1].append((k, v)) + else: + raise CoconutInternalException("invalid testlist_star_expr tokens", tokens) + if not groups[-1]: + groups.pop() + return groups, has_star, has_comma + + +def join_dict_group(group, as_tuples=False): + """Join group from split_star_expr_tokens$(is_dict=True).""" + items = [] + for k, v in group: + if as_tuples: + items.append("(" + k + ", " + v + ")") + else: + items.append(k + ": " + v) + if as_tuples: + return tuple_str_of(items, add_parens=False) + else: + return ", ".join(items) + + +def call_decorators(decorators, func_name): + """Convert decorators into function calls on func_name.""" + out = func_name + for decorator in reversed(decorators.splitlines()): + internal_assert(decorator.startswith("@"), "invalid decorator", decorator) + base_decorator = rem_comment(decorator[1:]) + out = "(" + base_decorator + ")(" + out + ")" + return out + + # end: UTILITIES # ----------------------------------------------------------------------------------------------------------------------- # COMPILER: @@ -375,6 +436,7 @@ class Compiler(Grammar, pickleable_obj): ] reformatprocs = [ + # deferred_code_proc must come first lambda self: self.deferred_code_proc, lambda self: self.reind_proc, lambda self: self.endline_repl, @@ -390,13 +452,17 @@ def __init__(self, *args, **kwargs): """Creates a new compiler with the given parsing parameters.""" self.setup(*args, **kwargs) - # changes here should be reflected in the stub for coconut.convenience.setup + # changes here should be reflected in __reduce__ and in the stub for 
coconut.convenience.setup def setup(self, target=None, strict=False, minify=False, line_numbers=False, keep_lines=False, no_tco=False, no_wrap=False): """Initializes parsing parameters.""" if target is None: target = "" else: - target = str(target).replace(".", "") + target = str(target) + if len(target) > 1 and target[1] == ".": + target = target[:1] + target[2:] + if "." in target: + raise CoconutException("target Python version must be major.minor, not major.minor.micro") if target == "sys": target = sys_target if target in pseudo_targets: @@ -450,10 +516,15 @@ def genhash(self, code, package_level=-1): temp_var_counts = None operators = None - def reset(self, keep_state=False): - """Resets references.""" + def reset(self, keep_state=False, filename=None): + """Reset references. + + IMPORTANT: When adding anything here, consider whether it should also be added to inner_environment. + """ + self.filename = filename self.indchar = None self.comments = {} + self.wrapped_type_ignore = None self.refs = [] self.skips = [] self.docstring = "" @@ -461,9 +532,6 @@ def reset(self, keep_state=False): if self.temp_var_counts is None or not keep_state: self.temp_var_counts = defaultdict(int) self.parsing_context = defaultdict(list) - self.add_code_before = {} - self.add_code_before_regexes = {} - self.add_code_before_replacements = {} self.unused_imports = defaultdict(list) self.kept_lines = [] self.num_lines = 0 @@ -471,6 +539,10 @@ def reset(self, keep_state=False): if self.operators is None or not keep_state: self.operators = [] self.operator_repl_table = [] + self.add_code_before = {} + self.add_code_before_regexes = {} + self.add_code_before_replacements = {} + self.add_code_before_ignore_names = {} @contextmanager def inner_environment(self): @@ -478,6 +550,7 @@ def inner_environment(self): line_numbers, self.line_numbers = self.line_numbers, False keep_lines, self.keep_lines = self.keep_lines, False comments, self.comments = self.comments, {} + wrapped_type_ignore, 
self.wrapped_type_ignore = self.wrapped_type_ignore, None skips, self.skips = self.skips, [] docstring, self.docstring = self.docstring, "" parsing_context, self.parsing_context = self.parsing_context, defaultdict(list) @@ -489,6 +562,7 @@ def inner_environment(self): self.line_numbers = line_numbers self.keep_lines = keep_lines self.comments = comments + self.wrapped_type_ignore = wrapped_type_ignore self.skips = skips self.docstring = docstring self.parsing_context = parsing_context @@ -602,6 +676,7 @@ def bind(cls): cls.endline <<= attach(cls.endline_ref, cls.method("endline_handle")) cls.normal_pipe_expr <<= trace_attach(cls.normal_pipe_expr_tokens, cls.method("pipe_handle")) cls.return_typedef <<= trace_attach(cls.return_typedef_ref, cls.method("typedef_handle")) + cls.power_in_impl_call <<= trace_attach(cls.power, cls.method("power_in_impl_call_check")) # handle all atom + trailers constructs with item_handle cls.trailer_atom <<= trace_attach(cls.trailer_atom_ref, cls.method("item_handle")) @@ -612,6 +687,10 @@ def bind(cls): cls.string_atom <<= trace_attach(cls.string_atom_ref, cls.method("string_atom_handle")) cls.f_string_atom <<= trace_attach(cls.f_string_atom_ref, cls.method("string_atom_handle")) + # handle all keyword funcdefs with keyword_funcdef_handle + cls.keyword_funcdef <<= trace_attach(cls.keyword_funcdef_ref, cls.method("keyword_funcdef_handle")) + cls.async_keyword_funcdef <<= trace_attach(cls.async_keyword_funcdef_ref, cls.method("keyword_funcdef_handle")) + # standard handlers of the form name <<= trace_attach(name_tokens, method("name_handle")) (implies name_tokens is reused) cls.function_call <<= trace_attach(cls.function_call_tokens, cls.method("function_call_handle")) cls.testlist_star_namedexpr <<= trace_attach(cls.testlist_star_namedexpr_tokens, cls.method("testlist_star_expr_handle")) @@ -649,6 +728,8 @@ def bind(cls): cls.base_match_for_stmt <<= trace_attach(cls.base_match_for_stmt_ref, cls.method("base_match_for_stmt_handle")) 
cls.unsafe_typedef_tuple <<= trace_attach(cls.unsafe_typedef_tuple_ref, cls.method("unsafe_typedef_tuple_handle")) cls.funcname_typeparams <<= trace_attach(cls.funcname_typeparams_ref, cls.method("funcname_typeparams_handle")) + cls.impl_call <<= trace_attach(cls.impl_call_ref, cls.method("impl_call_handle")) + cls.protocol_intersect_expr <<= trace_attach(cls.protocol_intersect_expr_ref, cls.method("protocol_intersect_expr_handle")) # these handlers just do strict/target checking cls.u_string <<= trace_attach(cls.u_string_ref, cls.method("u_string_check")) @@ -697,6 +778,10 @@ def adjust(self, ln, skips=None): adj_ln = i return adj_ln + need_unskipped + def reformat_post_deferred_code_proc(self, snip): + """Do post-processing that comes after deferred_code_proc.""" + return self.apply_procs(self.reformatprocs[1:], snip, reformatting=True, log=False) + def reformat(self, snip, *indices, **kwargs): """Post process a preprocessed snippet.""" internal_assert("ignore_errors" in kwargs, "reformat() missing required keyword argument: 'ignore_errors'") @@ -711,6 +796,12 @@ def reformat(self, snip, *indices, **kwargs): + tuple(len(self.reformat(snip[:index], **kwargs)) for index in indices) ) + def reformat_without_adding_code_before(self, code, **kwargs): + """Reformats without adding code before and instead returns what would have been added.""" + got_code_to_add_before = {} + reformatted_code = self.reformat(code, put_code_to_add_before_in=got_code_to_add_before, **kwargs) + return reformatted_code, tuple(got_code_to_add_before.keys()), got_code_to_add_before.values() + def literal_eval(self, code): """Version of ast.literal_eval that reformats first.""" return literal_eval(self.reformat(code, ignore_errors=False)) @@ -752,6 +843,7 @@ def complain_on_err(self): try: yield except ParseBaseException as err: + # don't reformat, since we might have gotten here because reformat failed complain(self.make_parse_err(err, reformat=False, include_ln=False)) except CoconutException 
as err: complain(err) @@ -819,11 +911,8 @@ def wrap_passthrough(self, text, multiline=True, early=False): out += "\n" return out - def wrap_comment(self, text, reformat=True): + def wrap_comment(self, text): """Wrap a comment.""" - if reformat: - whitespace, base_comment = split_leading_whitespace(text) - text = whitespace + self.reformat(base_comment, ignore_errors=False) return "#" + self.add_ref("comment", text) + unwrapper def wrap_error(self, error): @@ -838,7 +927,10 @@ def raise_or_wrap_error(self, error): return self.wrap_error(error) def type_ignore_comment(self): - return self.wrap_comment(" type: ignore", reformat=False) + """Get a "type: ignore" comment.""" + if self.wrapped_type_ignore is None: + self.wrapped_type_ignore = self.wrap_comment(" type: ignore") + return self.wrapped_type_ignore def wrap_line_number(self, ln): """Wrap a line number.""" @@ -860,13 +952,18 @@ def apply_procs(self, procs, inputstring, log=True, **kwargs): def pre(self, inputstring, **kwargs): """Perform pre-processing.""" + log = kwargs.get("log", True) out = self.apply_procs(self.preprocs, str(inputstring), **kwargs) - logger.log_tag("skips", self.skips) + if log: + logger.log_tag("skips", self.skips) return out def post(self, result, **kwargs): """Perform post-processing.""" internal_assert(isinstance(result, str), "got non-string parse result", result) + log = kwargs.get("log", True) + if log: + logger.log_tag("before post-processing", result, multiline=True) return self.apply_procs(self.postprocs, result, **kwargs) def getheader(self, which, use_hash=None, polish=True): @@ -938,7 +1035,7 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor if extra is not None: kwargs["extra"] = extra - return errtype(message, snippet, loc_in_snip, ln, endpoint=endpt_in_snip, **kwargs) + return errtype(message, snippet, loc_in_snip, ln, endpoint=endpt_in_snip, filename=self.filename, **kwargs) def make_syntax_err(self, err, original): """Make a 
CoconutSyntaxError from a CoconutDeferredSyntaxError.""" @@ -980,10 +1077,10 @@ def inner_parse_eval( return self.post(parsed, **postargs) @contextmanager - def parsing(self, keep_state=False): + def parsing(self, keep_state=False, filename=None): """Acquire the lock and reset the parser.""" with self.lock: - self.reset(keep_state) + self.reset(keep_state, filename) self.current_compiler[0] = self yield @@ -994,7 +1091,7 @@ def streamline(self, grammar, inputstring=""): prep_grammar(grammar, streamline=True) logger.log_lambda( lambda: "Streamlined {grammar} in {time} seconds (streamlined due to receiving input of length {length}).".format( - grammar=grammar.name, + grammar=get_name(grammar), time=get_clock_time() - start_time, length=len(inputstring), ), @@ -1017,9 +1114,9 @@ def run_final_checks(self, original, keep_state=False): loc, ) - def parse(self, inputstring, parser, preargs, postargs, streamline=True, keep_state=False): + def parse(self, inputstring, parser, preargs, postargs, streamline=True, keep_state=False, filename=None): """Use the parser to parse the inputstring with appropriate setup and teardown.""" - with self.parsing(keep_state): + with self.parsing(keep_state, filename): if streamline: self.streamline(parser, inputstring) with logger.gather_parsing_stats(): @@ -1033,10 +1130,11 @@ def parse(self, inputstring, parser, preargs, postargs, streamline=True, keep_st except CoconutDeferredSyntaxError as err: internal_assert(pre_procd is not None, "invalid deferred syntax error in pre-processing", err) raise self.make_syntax_err(err, pre_procd) + # RuntimeError, not RecursionError, for Python < 3.5 except RuntimeError as err: raise CoconutException( str(err), extra="try again with --recursion-limit greater than the current " - + str(sys.getrecursionlimit()), + + str(sys.getrecursionlimit()) + " (you may also need to increase --stack-size)", ) self.run_final_checks(pre_procd, keep_state) return out @@ -1084,7 +1182,7 @@ def str_proc(self, inputstring, 
**kwargs): if hold is not None: if len(hold) == 1: # hold == [_comment] if c == "\n": - out.append(self.wrap_comment(hold[_comment], reformat=False) + c) + out += [self.wrap_comment(hold[_comment]), c] hold = None else: hold[_comment] += c @@ -1148,9 +1246,9 @@ def str_proc(self, inputstring, **kwargs): if hold is not None or found is not None: raise self.make_err(CoconutSyntaxError, "unclosed string", inputstring, x, reformat=False) - else: - self.set_skips(skips) - return "".join(out) + + self.set_skips(skips) + return "".join(out) def passthrough_proc(self, inputstring, **kwargs): """Process python passthroughs.""" @@ -1187,7 +1285,7 @@ def passthrough_proc(self, inputstring, **kwargs): count = -1 multiline = True else: - out.append("\\" + c) + out += ["\\", c] found = None elif c == "\\": found = True @@ -1250,7 +1348,7 @@ def operator_proc(self, inputstring, keep_state=False, **kwargs): any_delimiter = r"|".join(re.escape(sym) for sym in delimiter_symbols) self.operator_repl_table.append(( compile_regex(r"(^|\s|(? 
0 else None break ''', - add_newline=True, ).format( expr=expr, yield_from_var=self.get_temp_var("yield_from"), @@ -2405,7 +2627,7 @@ def endline_handle(self, original, loc, tokens): out = [] ln = lineno(loc, original) for endline in lines: - out.append(self.wrap_line_number(ln) + endline) + out += [self.wrap_line_number(ln), endline] ln += 1 return "".join(out) @@ -2532,14 +2754,20 @@ def match_datadef_handle(self, original, loc, tokens): matcher = self.get_matcher(original, loc, check_var, name_list=[]) pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg = split_args_list(matches, loc) - matcher.match_function(match_to_args_var, match_to_kwargs_var, pos_only_args, req_args + default_args, star_arg, kwd_only_args, dubstar_arg) + matcher.match_function( + pos_only_match_args=pos_only_args, + match_args=req_args + default_args, + star_arg=star_arg, + kwd_only_match_args=kwd_only_args, + dubstar_arg=dubstar_arg, + ) if cond is not None: matcher.add_guard(cond) extra_stmts = handle_indentation( ''' -def __new__(_coconut_cls, *{match_to_args_var}, **{match_to_kwargs_var}): +def __new__(_coconut_cls, {match_func_paramdef}): {check_var} = False {matching} {pattern_error} @@ -2547,8 +2775,7 @@ def __new__(_coconut_cls, *{match_to_args_var}, **{match_to_kwargs_var}): ''', add_newline=True, ).format( - match_to_args_var=match_to_args_var, - match_to_kwargs_var=match_to_kwargs_var, + match_func_paramdef=match_func_paramdef, check_var=check_var, matching=matcher.out(), pattern_error=self.pattern_error(original, loc, match_to_args_var, check_var, function_match_error_var), @@ -2733,17 +2960,24 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts, definition of Expected in header.py_template. 
""" # create class - out = ( - "".join(paramdefs) - + decorators - + "class " + name + "(" - + namedtuple_call - + (", " + inherit if inherit is not None else "") - + (", " + self.get_generic_for_typevars() if paramdefs else "") - + (", _coconut.object" if not self.target.startswith("3") else "") - + "):\n" - + openindent - ) + out = [ + "".join(paramdefs), + decorators, + "class ", + name, + "(", + namedtuple_call, + ] + if inherit is not None: + out += [", ", inherit] + if paramdefs: + out += [", ", self.get_generic_for_typevars()] + if not self.target.startswith("3"): + out.append(", _coconut.object") + out += [ + "):\n", + openindent, + ] # add universal statements all_extra_stmts = handle_indentation( @@ -2770,31 +3004,31 @@ def __hash__(self): # manage docstring rest = None if "simple" in stmts and len(stmts) == 1: - out += all_extra_stmts + out += [all_extra_stmts] rest = stmts[0] elif "docstring" in stmts and len(stmts) == 1: - out += stmts[0] + all_extra_stmts + out += [stmts[0], all_extra_stmts] elif "complex" in stmts and len(stmts) == 1: - out += all_extra_stmts + out += [all_extra_stmts] rest = "".join(stmts[0]) elif "complex" in stmts and len(stmts) == 2: - out += stmts[0] + all_extra_stmts + out += [stmts[0], all_extra_stmts] rest = "".join(stmts[1]) elif "empty" in stmts and len(stmts) == 1: - out += all_extra_stmts.rstrip() + stmts[0] + out += [all_extra_stmts.rstrip(), stmts[0]] else: raise CoconutInternalException("invalid inner data tokens", stmts) # create full data definition if rest is not None and rest != "pass\n": - out += rest - out += closeindent + out.append(rest) + out.append(closeindent) # add override detection if self.target_info < (3, 6): - out += "_coconut_call_set_names(" + name + ")\n" + out += ["_coconut_call_set_names(", name, ")\n"] - return out + return "".join(out) def anon_namedtuple_handle(self, tokens): """Handle anonymous named tuples.""" @@ -2915,7 +3149,7 @@ def universal_import(self, imports, imp_from=None): 
stmts.append( handle_indentation(""" try: - {store_var} = sys + {store_var} = sys {type_ignore} except _coconut.NameError: {store_var} = _coconut_sentinel sys = _coconut_sys @@ -2931,6 +3165,7 @@ def universal_import(self, imports, imp_from=None): new_imp="\n".join(self.single_import(new_imp, imp_as)), # should only type: ignore the old import old_imp="\n".join(self.single_import(old_imp, imp_as, type_ignore=type_ignore)), + type_ignore=self.type_ignore_comment(), ), ) return "\n".join(stmts) @@ -2962,20 +3197,26 @@ def complex_raise_stmt_handle(self, tokens): if self.target.startswith("3"): return "raise " + raise_expr + " from " + from_expr else: - raise_from_var = self.get_temp_var("raise_from") - return ( - raise_from_var + " = " + raise_expr + "\n" - + raise_from_var + ".__cause__ = " + from_expr + "\n" - + "raise " + raise_from_var + return handle_indentation( + ''' +{raise_from_var} = {raise_expr} +{raise_from_var}.__cause__ = {from_expr} +raise {raise_from_var} + ''', + ).format( + raise_from_var=self.get_temp_var("raise_from"), + raise_expr=raise_expr, + from_expr=from_expr, ) def dict_comp_handle(self, loc, tokens): """Process Python 2.7 dictionary comprehension.""" key, val, comp = tokens - if self.target.startswith("3"): + # on < 3.9 have to use _coconut.dict since it's different than py_dict + if self.target_info >= (3, 9): return "{" + key + ": " + val + " " + comp + "}" else: - return "dict(((" + key + "), (" + val + ")) " + comp + ")" + return "_coconut.dict(((" + key + "), (" + val + ")) " + comp + ")" def pattern_error(self, original, loc, value_var, check_var, match_error_class='_coconut_MatchError'): """Construct a pattern-matching error message.""" @@ -3020,9 +3261,15 @@ def full_match_handle(self, original, loc, tokens, match_to_var=None, match_chec matching.match(matches, match_to_var) if cond: matching.add_guard(cond) - return ( - match_to_var + " = " + item + "\n" - + matching.build(stmts, invert=invert) + return handle_indentation( + ''' 
+{match_to_var} = {item} +{match} + ''', + ).format( + match_to_var=match_to_var, + item=item, + match=matching.build(stmts, invert=invert), ) def destructuring_stmt_handle(self, original, loc, tokens): @@ -3048,15 +3295,18 @@ def name_match_funcdef_handle(self, original, loc, tokens): matcher = self.get_matcher(original, loc, check_var) pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg = split_args_list(matches, loc) - matcher.match_function(match_to_args_var, match_to_kwargs_var, pos_only_args, req_args + default_args, star_arg, kwd_only_args, dubstar_arg) + matcher.match_function( + pos_only_match_args=pos_only_args, + match_args=req_args + default_args, + star_arg=star_arg, + kwd_only_match_args=kwd_only_args, + dubstar_arg=dubstar_arg, + ) if cond is not None: matcher.add_guard(cond) - before_colon = ( - "def " + func - + "(*" + match_to_args_var + ", **" + match_to_kwargs_var + ")" - ) + before_colon = "def " + func + "(" + match_func_paramdef + ")" after_docstring = ( openindent + check_var + " = False\n" @@ -3120,13 +3370,16 @@ def set_letter_literal_handle(self, tokens): def stmt_lambdef_handle(self, original, loc, tokens): """Process multi-line lambdef statements.""" - kwds, params, stmts_toks = tokens + got_kwds, params, stmts_toks = tokens is_async = False - for kwd in kwds: + add_kwds = [] + for kwd in got_kwds: if kwd == "async": self.internal_assert(not is_async, original, loc, "duplicate stmt_lambdef async keyword", kwd) is_async = True + elif kwd == "copyclosure": + add_kwds.append(kwd) else: raise CoconutInternalException("invalid stmt_lambdef keyword", kwd) @@ -3158,6 +3411,8 @@ def stmt_lambdef_handle(self, original, loc, tokens): + body ) + funcdef = " ".join(add_kwds + [funcdef]) + self.add_code_before[name] = self.decoratable_funcdef_stmt_handle(original, loc, [decorators, funcdef], is_async, is_stmt_lambda=True) return name @@ -3194,22 +3449,52 @@ def await_expr_handle(self, original, loc, tokens): def 
unsafe_typedef_handle(self, tokens): """Process type annotations without a comma after them.""" - return self.typedef_handle(tokens.asList() + [","]) + # we add an empty string token to take the place of the comma, + # but it should be empty so we don't actually put a comma in + return self.typedef_handle(tokens.asList() + [""]) + + def wrap_code_before(self, add_code_before_list): + """Wrap code to add before by putting it behind a TYPE_CHECKING check.""" + if not add_code_before_list: + return "" + return "if _coconut.typing.TYPE_CHECKING:\n" + openindent + "\n".join(add_code_before_list) + closeindent - def wrap_typedef(self, typedef, for_py_typedef): + def wrap_typedef(self, typedef, for_py_typedef, duplicate=False): """Wrap a type definition in a string to defer it unless --no-wrap or __future__.annotations.""" if self.no_wrap or for_py_typedef and self.target_info >= (3, 7): return typedef else: - return self.wrap_str_of(self.reformat(typedef, ignore_errors=False)) + reformatted_typedef, ignore_names, add_code_before_list = self.reformat_without_adding_code_before(typedef, ignore_errors=False) + wrapped = self.wrap_str_of(reformatted_typedef) + if duplicate: + # duplicate means that the necessary add_code_before will already have been done + add_code_before = "" + else: + # since we're wrapping the typedef, also wrap the code to add before + add_code_before = self.wrap_code_before(add_code_before_list) + return self.add_code_before_marker_with_replacement(wrapped, add_code_before, ignore_names=ignore_names) + + def wrap_type_comment(self, typedef, is_return=False, add_newline=False): + reformatted_typedef, ignore_names, add_code_before_list = self.reformat_without_adding_code_before(typedef, ignore_errors=False) + if is_return: + type_comment = " type: (...) 
-> " + reformatted_typedef + else: + type_comment = " type: " + reformatted_typedef + wrapped = self.wrap_comment(type_comment) + if add_newline: + wrapped += non_syntactic_newline + # since we're wrapping the typedef, also wrap the code to add before + add_code_before = self.wrap_code_before(add_code_before_list) + return self.add_code_before_marker_with_replacement(wrapped, add_code_before, ignore_names=ignore_names) def typedef_handle(self, tokens): """Process Python 3 type annotations.""" if len(tokens) == 1: # return typedef + typedef, = tokens if self.target.startswith("3"): - return " -> " + self.wrap_typedef(tokens[0], for_py_typedef=True) + ":" + return " -> " + self.wrap_typedef(typedef, for_py_typedef=True) + ":" else: - return ":\n" + self.wrap_comment(" type: (...) -> " + tokens[0]) + return ":\n" + self.wrap_type_comment(typedef, is_return=True) else: # argument typedef if len(tokens) == 3: varname, typedef, comma = tokens @@ -3221,7 +3506,7 @@ def typedef_handle(self, tokens): if self.target.startswith("3"): return varname + ": " + self.wrap_typedef(typedef, for_py_typedef=True) + default + comma else: - return varname + default + comma + self.wrap_passthrough(self.wrap_comment(" type: " + typedef) + non_syntactic_newline, early=True) + return varname + default + comma + self.wrap_type_comment(typedef, add_newline=True) def typed_assign_stmt_handle(self, tokens): """Process Python 3.6 variable type annotations.""" @@ -3236,19 +3521,22 @@ def typed_assign_stmt_handle(self, tokens): if self.target_info >= (3, 6): return name + ": " + self.wrap_typedef(typedef, for_py_typedef=True) + ("" if value is None else " = " + value) else: - return handle_indentation(''' + return handle_indentation( + ''' {name} = {value}{comment} if "__annotations__" not in _coconut.locals(): - __annotations__ = {{}} + __annotations__ = {{}} {type_ignore} __annotations__["{name}"] = {annotation} - ''').format( + ''', + ).format( name=name, value=( value if value is not None else 
"_coconut.typing.cast(_coconut.typing.Any, {ellipsis})".format(ellipsis=self.ellipsis_handle()) ), - comment=self.wrap_comment(" type: " + typedef), - annotation=self.wrap_typedef(typedef, for_py_typedef=False), + comment=self.wrap_type_comment(typedef), + annotation=self.wrap_typedef(typedef, for_py_typedef=False, duplicate=True), + type_ignore=self.type_ignore_comment(), ) def funcname_typeparams_handle(self, tokens): @@ -3258,23 +3546,20 @@ def funcname_typeparams_handle(self, tokens): return name else: name, paramdefs = tokens - # temp_marker will be set back later, but needs to be a unique name until then for add_code_before - temp_marker = self.get_temp_var("type_param_func") - self.add_code_before[temp_marker] = "".join(paramdefs) - self.add_code_before_replacements[temp_marker] = name - return temp_marker + return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) funcname_typeparams_handle.ignore_one_token = True def type_param_handle(self, original, loc, tokens): """Compile a type param into an assignment.""" bounds = "" + kwargs = "" if "TypeVar" in tokens: TypeVarFunc = "TypeVar" - if len(tokens) == 1: - name, = tokens + if len(tokens) == 2: + name_loc, name = tokens else: - name, bound_op, bound = tokens + name_loc, name, bound_op, bound = tokens if bound_op == "<=": self.strict_err_or_warn( "use of " + repr(bound_op) + " as a type parameter bound declaration operator is deprecated (Coconut style is to use '<:' operator)", @@ -3290,28 +3575,40 @@ def type_param_handle(self, original, loc, tokens): else: self.internal_assert(bound_op == "<:", original, loc, "invalid type_param bound_op", bound_op) bounds = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False) + # uncomment this line whenever mypy adds support for infer_variance in TypeVar + # (and remove the warning about it in the DOCS) + # kwargs = ", infer_variance=True" elif "TypeVarTuple" in tokens: TypeVarFunc = "TypeVarTuple" - name, = tokens + 
name_loc, name = tokens elif "ParamSpec" in tokens: TypeVarFunc = "ParamSpec" - name, = tokens + name_loc, name = tokens else: raise CoconutInternalException("invalid type_param tokens", tokens) + name_loc = int(name_loc) + internal_assert(name_loc == loc if TypeVarFunc == "TypeVar" else name_loc >= loc, "invalid name location for " + TypeVarFunc, (name_loc, loc, tokens)) + typevar_info = self.current_parsing_context("typevars") if typevar_info is not None: - if name in typevar_info["all_typevars"]: - raise CoconutDeferredSyntaxError("type variable {name!r} already defined", loc) - temp_name = self.get_temp_var("typevar_" + name) - typevar_info["all_typevars"][name] = temp_name - typevar_info["new_typevars"].append((TypeVarFunc, temp_name)) - name = temp_name - - return '{name} = _coconut.typing.{TypeVarFunc}("{name}"{bounds})\n'.format( + # check to see if we already parsed this exact typevar, in which case just reuse the existing temp_name + if typevar_info["typevar_locs"].get(name, None) == name_loc: + name = typevar_info["all_typevars"][name] + else: + if name in typevar_info["all_typevars"]: + raise CoconutDeferredSyntaxError("type variable {name!r} already defined".format(name=name), loc) + temp_name = self.get_temp_var("typevar_" + name) + typevar_info["all_typevars"][name] = temp_name + typevar_info["new_typevars"].append((TypeVarFunc, temp_name)) + typevar_info["typevar_locs"][name] = name_loc + name = temp_name + + return '{name} = _coconut.typing.{TypeVarFunc}("{name}"{bounds}{kwargs})\n'.format( name=name, TypeVarFunc=TypeVarFunc, bounds=bounds, + kwargs=kwargs, ) def get_generic_for_typevars(self): @@ -3328,7 +3625,7 @@ def get_generic_for_typevars(self): else: generics.append("_coconut.typing.Unpack[" + name + "]") else: - raise CoconutInternalException("invalid TypeVarFunc", TypeVarFunc) + raise CoconutInternalException("invalid TypeVarFunc", TypeVarFunc, "(", name, ")") return "_coconut.typing.Generic[" + ", ".join(generics) + "]" @contextmanager @@ 
-3339,6 +3636,7 @@ def type_alias_stmt_manage(self, item=None, original=None, loc=None): typevars_stack.append({ "all_typevars": {} if prev_typevar_info is None else prev_typevar_info["all_typevars"].copy(), "new_typevars": [], + "typevar_locs": {}, }) try: yield @@ -3549,30 +3847,9 @@ def unsafe_typedef_or_expr_handle(self, tokens): else: return "_coconut.typing.Union[" + ", ".join(tokens) + "]" - def split_star_expr_tokens(self, tokens): - """Split testlist_star_expr or dict_literal tokens.""" - groups = [[]] - has_star = False - has_comma = False - for tok_grp in tokens: - if tok_grp == ",": - has_comma = True - elif len(tok_grp) == 1: - groups[-1].append(tok_grp[0]) - elif len(tok_grp) == 2: - internal_assert(not tok_grp[0].lstrip("*"), "invalid star expr item signifier", tok_grp[0]) - has_star = True - groups.append(tok_grp[1]) - groups.append([]) - else: - raise CoconutInternalException("invalid testlist_star_expr tokens", tokens) - if not groups[-1]: - groups.pop() - return groups, has_star, has_comma - def testlist_star_expr_handle(self, original, loc, tokens, is_list=False): """Handle naked a, *b.""" - groups, has_star, has_comma = self.split_star_expr_tokens(tokens) + groups, has_star, has_comma = split_star_expr_tokens(tokens) is_sequence = has_comma or is_list if not is_sequence and not has_star: @@ -3619,35 +3896,44 @@ def list_expr_handle(self, original, loc, tokens): """Handle non-comprehension list literals.""" return self.testlist_star_expr_handle(original, loc, tokens, is_list=True) + def make_dict(self, tok_grp): + """Construct a dictionary literal out of the given group.""" + # on < 3.9 have to use _coconut.dict since it's different than py_dict + if self.target_info >= (3, 9): + return "{" + join_dict_group(tok_grp) + "}" + else: + return "_coconut.dict((" + join_dict_group(tok_grp, as_tuples=True) + "))" + def dict_literal_handle(self, tokens): """Handle {**d1, **d2}.""" if not tokens: - return "{}" + # on < 3.9 have to use _coconut.dict since 
it's different than py_dict + return "{}" if self.target_info >= (3, 9) else "_coconut.dict()" - groups, has_star, _ = self.split_star_expr_tokens(tokens) + groups, has_star, _ = split_star_expr_tokens(tokens, is_dict=True) if not has_star: internal_assert(len(groups) == 1, "dict_literal group splitting failed on", tokens) - return "{" + ", ".join(groups[0]) + "}" + return self.make_dict(groups[0]) - # naturally supported on 3.5+ - elif self.target_info >= (3, 5): + # supported on 3.5, but only guaranteed to be ordered on 3.7 + elif self.target_info >= (3, 7): to_literal = [] for g in groups: - if isinstance(g, list): - to_literal.extend(g) - else: + if not isinstance(g, list): to_literal.append("**" + g) + elif g: + to_literal.append(join_dict_group(g)) return "{" + ", ".join(to_literal) + "}" # otherwise universalize else: to_merge = [] for g in groups: - if isinstance(g, list): - to_merge.append("{" + ", ".join(g) + "}") - else: + if not isinstance(g, list): to_merge.append(g) + elif g: + to_merge.append(self.make_dict(g)) return "_coconut_dict_merge(" + ", ".join(to_merge) + ")" def new_testlist_star_expr_handle(self, tokens): @@ -3716,6 +4002,53 @@ def term_handle(self, tokens): out += [op, term] return " ".join(out) + def impl_call_handle(self, loc, tokens): + """Process implicit function application or coefficient syntax.""" + internal_assert(len(tokens) >= 2, "invalid implicit call / coefficient tokens", tokens) + first_is_num = does_parse(self.number, tokens[0]) + if first_is_num: + if does_parse(self.number, tokens[1]): + raise CoconutDeferredSyntaxError("multiplying two or more numeric literals with implicit coefficient syntax is prohibited", loc) + return "(" + " * ".join(tokens) + ")" + else: + return "_coconut_call_or_coefficient(" + ", ".join(tokens) + ")" + + def keyword_funcdef_handle(self, tokens): + """Process function definitions with keywords in front.""" + keywords, funcdef = tokens + for kwd in keywords: + if kwd == "yield": + funcdef += 
handle_indentation( + """ +if False: + yield + """, + add_newline=True, + extra_indent=1, + ) + else: + # new keywords here must be replicated in def_regex and handled in proc_funcdef + internal_assert(kwd in ("addpattern", "copyclosure"), "unknown deferred funcdef keyword", kwd) + funcdef = kwd + " " + funcdef + return funcdef + + def protocol_intersect_expr_handle(self, tokens): + if len(tokens) == 1: + return tokens[0] + internal_assert(len(tokens) >= 2, "invalid protocol intersection tokens", tokens) + protocol_var = self.get_temp_var("protocol_intersection") + self.add_code_before[protocol_var] = handle_indentation( + ''' +class {protocol_var}({tokens}, _coconut.typing.Protocol): pass + ''', + ).format( + protocol_var=protocol_var, + tokens=", ".join(tokens), + ) + return protocol_var + + protocol_intersect_expr_handle.ignore_one_token = True + # end: HANDLERS # ----------------------------------------------------------------------------------------------------------------------- # CHECKING HANDLERS: @@ -3759,6 +4092,17 @@ def match_check_equals_check(self, original, loc, tokens): """Check for old-style =item in pattern-matching.""" return self.check_strict("deprecated equality-checking '=...' pattern; use '==...' 
instead", original, loc, tokens, always_warn=True) + def power_in_impl_call_check(self, original, loc, tokens): + """Check for exponentation in implicit function application / coefficient syntax.""" + return self.check_strict( + "syntax with new behavior in Coconut v3; 'f x ** y' is now equivalent to 'f(x**y)' not 'f(x)**y'", + original, + loc, + tokens, + only_warn=True, + always_warn=True, + ) + def check_py(self, version, name, original, loc, tokens): """Check for Python-version-specific syntax.""" self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) @@ -3839,17 +4183,21 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): if typevar_info is not None: typevars = typevar_info["all_typevars"] if name in typevars: - if assign: - return self.raise_or_wrap_error( - self.make_err( - CoconutSyntaxError, - "cannot reassign type variable '{name}'".format(name=name), - original, - loc, - extra="use explicit '\\{name}' syntax if intended".format(name=name), - ), - ) - return typevars[name] + # if we're looking at the same position where the typevar was defined, + # then we shouldn't treat this as a typevar, since then it's either + # a reparse of a setname in a typevar, or not a typevar at all + if typevar_info["typevar_locs"].get(name, None) != loc: + if assign: + return self.raise_or_wrap_error( + self.make_err( + CoconutSyntaxError, + "cannot reassign type variable '{name}'".format(name=name), + original, + loc, + extra="use explicit '\\{name}' syntax if intended".format(name=name), + ), + ) + return typevars[name] if not assign: self.unused_imports.pop(name, None) @@ -3890,11 +4238,7 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): if self.in_method: cls_context = self.current_parsing_context("class") enclosing_cls = cls_context["name_prefix"] + cls_context["name"] - # temp_marker will be set back later, but needs to be a unique name until then for add_code_before - 
temp_marker = self.get_temp_var("super") - self.add_code_before[temp_marker] = "__class__ = " + enclosing_cls + "\n" - self.add_code_before_replacements[temp_marker] = name - return temp_marker + return self.add_code_before_marker_with_replacement(name, "__class__ = " + enclosing_cls + "\n", add_spaces=False) else: return name elif not escaped and name.startswith(reserved_prefix) and name not in self.operators: diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 5f0d92a6a..0c830210e 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -29,6 +29,7 @@ from collections import defaultdict from contextlib import contextmanager +from functools import partial from coconut._pyparsing import ( CaselessLiteral, @@ -54,6 +55,7 @@ from coconut.util import ( memoize, get_clock_time, + keydefaultdict, ) from coconut.exceptions import ( CoconutInternalException, @@ -76,6 +78,8 @@ untcoable_funcs, early_passthrough_wrapper, new_operators, + wildcard, + op_func_protocols, ) from coconut.compiler.util import ( combine, @@ -95,21 +99,22 @@ split_trailing_indent, split_leading_indent, collapse_indents, - keyword, + base_keyword, match_in, disallow_keywords, regex_item, stores_loc_item, invalid_syntax, skip_to_in_line, - handle_indentation, labeled_group, any_keyword_in, any_char, tuple_str_of, any_len_perm, + any_len_perm_at_least_one, boundary, compile_regex, + always_match, ) @@ -468,7 +473,7 @@ def simple_kwd_assign_handle(tokens): simple_kwd_assign_handle.ignore_one_token = True -def compose_item_handle(tokens): +def compose_expr_handle(tokens): """Process function composition.""" if len(tokens) == 1: return tokens[0] @@ -476,13 +481,7 @@ def compose_item_handle(tokens): return "_coconut_forward_compose(" + ", ".join(reversed(tokens)) + ")" -compose_item_handle.ignore_one_token = True - - -def impl_call_item_handle(tokens): - """Process implicit function application.""" - internal_assert(len(tokens) > 1, "invalid implicit function 
application tokens", tokens) - return tokens[0] + "(" + ", ".join(tokens[1:]) + ")" +compose_expr_handle.ignore_one_token = True def tco_return_handle(tokens): @@ -527,7 +526,10 @@ def where_handle(tokens): def kwd_err_msg_handle(tokens): """Handle keyword parse error messages.""" kwd, = tokens - return 'invalid use of the keyword "' + kwd + '"' + if kwd == "def": + return "invalid function definition" + else: + return 'invalid use of the keyword "' + kwd + '"' def alt_ternary_handle(tokens): @@ -536,19 +538,6 @@ def alt_ternary_handle(tokens): return "{if_true} if {cond} else {if_false}".format(cond=cond, if_true=if_true, if_false=if_false) -def yield_funcdef_handle(tokens): - """Handle yield def explicit generators.""" - funcdef, = tokens - return funcdef + handle_indentation( - """ -if False: - yield - """, - add_newline=True, - extra_indent=1, - ) - - def partial_op_item_handle(tokens): """Handle operator function implicit partials.""" tok_grp, = tokens @@ -603,6 +592,23 @@ def array_literal_handle(loc, tokens): return "_coconut_multi_dim_arr(" + tuple_str_of(array_elems) + ", " + str(sep_level) + ")" +def typedef_op_item_handle(loc, tokens): + """Converts operator functions in type contexts into Protocols.""" + op_name, = tokens + op_name = op_name.strip("_") + if op_name.startswith("coconut"): + op_name = op_name[len("coconut"):] + op_name = op_name.lstrip("._") + if op_name.startswith("operator."): + op_name = op_name[len("operator."):] + + proto = op_func_protocols.get(op_name) + if proto is None: + raise CoconutDeferredSyntaxError("operator Protocol for " + repr(op_name) + " operator not supported", loc) + + return proto + + # end: HANDLERS # ----------------------------------------------------------------------------------------------------------------------- # MAIN GRAMMAR: @@ -700,7 +706,8 @@ class Grammar(object): | fixto(Literal("<**?\u2218"), "<**?..") | invalid_syntax("") + ~Literal("|*") + Literal("|") | fixto(Literal("\u2228") | Literal("\u222a"), 
"|") bar = ~rbanana + unsafe_bar | invalid_syntax("\xa6", "invalid broken bar character", greedy=True) @@ -718,19 +725,13 @@ class Grammar(object): questionmark = ~dubquestion + Literal("?") bang = ~Literal("!=") + Literal("!") + kwds = keydefaultdict(partial(base_keyword, explicit_prefix=colon)) + keyword = kwds.__getitem__ + except_star_kwd = combine(keyword("except") + star) - except_kwd = ~except_star_kwd + keyword("except") - lambda_kwd = keyword("lambda") | fixto(keyword("\u03bb", explicit_prefix=colon), "lambda") - async_kwd = keyword("async", explicit_prefix=colon) - await_kwd = keyword("await", explicit_prefix=colon) - data_kwd = keyword("data", explicit_prefix=colon) - match_kwd = keyword("match", explicit_prefix=colon) - case_kwd = keyword("case", explicit_prefix=colon) - cases_kwd = keyword("cases", explicit_prefix=colon) - where_kwd = keyword("where", explicit_prefix=colon) - addpattern_kwd = keyword("addpattern", explicit_prefix=colon) - then_kwd = keyword("then", explicit_prefix=colon) - type_kwd = keyword("type", explicit_prefix=colon) + kwds["except"] = ~except_star_kwd + keyword("except") + kwds["lambda"] = keyword("lambda") | fixto(keyword("\u03bb"), "lambda") + kwds["operator"] = base_keyword("operator", explicit_prefix=colon, require_whitespace=True) ellipsis = Forward() ellipsis_tokens = Literal("...") | fixto(Literal("\u2026"), "...") @@ -808,16 +809,15 @@ class Grammar(object): bin_num = combine(CaselessLiteral("0b") + Optional(underscore.suppress()) + binint) oct_num = combine(CaselessLiteral("0o") + Optional(underscore.suppress()) + octint) hex_num = combine(CaselessLiteral("0x") + Optional(underscore.suppress()) + hexint) - number = addspace( - ( - bin_num - | oct_num - | hex_num - | imag_num - | numitem - ) - + Optional(condense(dot + unsafe_name)), + number = ( + bin_num + | oct_num + | hex_num + | imag_num + | numitem ) + # make sure that this gets addspaced not condensed so it doesn't produce a SyntaxError + num_atom = addspace(number 
+ Optional(condense(dot + unsafe_name))) moduledoc_item = Forward() unwrap = Literal(unwrapper) @@ -929,6 +929,15 @@ class Grammar(object): namedexpr_test = Forward() # for namedexpr locations only supported in Python 3.10 new_namedexpr_test = Forward() + lambdef = Forward() + + typedef = Forward() + typedef_default = Forward() + unsafe_typedef_default = Forward() + typedef_test = Forward() + typedef_tuple = Forward() + typedef_ellipsis = Forward() + typedef_op_item = Forward() negable_atom_item = condense(Optional(neg_minus) + atom_item) @@ -958,7 +967,8 @@ class Grammar(object): lbrace.suppress() + Optional( tokenlist( - Group(addspace(condense(test + colon) + test)) | dubstar_expr, + Group(test + colon + test) + | dubstar_expr, comma, ), ) @@ -1027,9 +1037,13 @@ class Grammar(object): | fixto(ne, "_coconut.operator.ne") | fixto(tilde, "_coconut.operator.inv") | fixto(matrix_at, "_coconut_matmul") + | fixto(keyword("is") + keyword("not"), "_coconut.operator.is_not") + | fixto(keyword("not") + keyword("in"), "_coconut_not_in") + + # must come after is not / not in | fixto(keyword("not"), "_coconut.operator.not_") | fixto(keyword("is"), "_coconut.operator.is_") - | fixto(keyword("in"), "_coconut.operator.contains") + | fixto(keyword("in"), "_coconut_in") ) partialable_op = base_op_item | infix_op partial_op_item_tokens = ( @@ -1037,17 +1051,14 @@ class Grammar(object): | labeled_group(test_no_infix + partialable_op + dot.suppress(), "left partial") ) partial_op_item = attach(partial_op_item_tokens, partial_op_item_handle) - op_item = trace(partial_op_item | base_op_item) + op_item = trace( + typedef_op_item + | partial_op_item + | base_op_item, + ) partial_op_atom_tokens = lparen.suppress() + partial_op_item_tokens + rparen.suppress() - typedef = Forward() - typedef_default = Forward() - unsafe_typedef_default = Forward() - typedef_test = Forward() - typedef_tuple = Forward() - typedef_ellipsis = Forward() - # we include (var)arg_comma to ensure the pattern matches 
the whole arg arg_comma = comma | fixto(FollowedBy(rparen), "") setarg_comma = arg_comma | fixto(FollowedBy(colon), "") @@ -1243,7 +1254,7 @@ class Grammar(object): known_atom = trace( keyword_atom | string_atom - | number + | num_atom | list_item | dict_comp | dict_literal @@ -1277,7 +1288,7 @@ class Grammar(object): Group(condense(dollar + lbrack) + subscriptgroup + rbrack.suppress()) # $[ | Group(condense(dollar + lbrack + rbrack)) # $[] | Group(condense(lbrack + rbrack)) # [] - | Group(dot + ~unsafe_name + ~lbrack) # . + | Group(dot + ~unsafe_name + ~lbrack + ~dot) # . | Group(questionmark) # ? ) + ~questionmark partial_trailer = ( @@ -1344,51 +1355,50 @@ class Grammar(object): type_param = Forward() type_param_bound_op = lt_colon | colon | le + type_var_name = stores_loc_item + setname type_param_ref = ( - (setname + Optional(type_param_bound_op + typedef_test))("TypeVar") - | (star.suppress() + setname)("TypeVarTuple") - | (dubstar.suppress() + setname)("ParamSpec") + (type_var_name + Optional(type_param_bound_op + typedef_test))("TypeVar") + | (star.suppress() + type_var_name)("TypeVarTuple") + | (dubstar.suppress() + type_var_name)("ParamSpec") ) type_params = Group(lbrack.suppress() + tokenlist(type_param, comma) + rbrack.suppress()) type_alias_stmt = Forward() - type_alias_stmt_ref = type_kwd.suppress() + setname + Optional(type_params) + equals.suppress() + typedef_test + type_alias_stmt_ref = keyword("type").suppress() + setname + Optional(type_params) + equals.suppress() + typedef_test + + await_expr = Forward() + await_expr_ref = keyword("await").suppress() + atom_item + await_item = await_expr | atom_item + + factor = Forward() + unary = plus | neg_minus | tilde + + power = condense(exp_dubstar + ZeroOrMore(unary) + await_item) + power_in_impl_call = Forward() - impl_call_arg = disallow_keywords(reserved_vars) + ( + impl_call_arg = condense(( keyword_atom | number - | dotted_refname - ) - impl_call = attach( + | disallow_keywords(reserved_vars) + 
dotted_refname + ) + Optional(power_in_impl_call)) + impl_call_item = condense( disallow_keywords(reserved_vars) + + ~any_string + atom_item - + OneOrMore(impl_call_arg), - impl_call_item_handle, + + Optional(power_in_impl_call), ) - impl_call_item = ( - atom_item + ~impl_call_arg - | impl_call + impl_call = Forward() + impl_call_ref = ( + impl_call_item + OneOrMore(impl_call_arg) ) - await_expr = Forward() - await_expr_ref = await_kwd.suppress() + impl_call_item - await_item = await_expr | impl_call_item - - lambdef = Forward() - - compose_item = attach( - tokenlist( - await_item, - dotdot + Optional(invalid_syntax(lambdef, "lambdas only allowed after composition pipe operators '..>' and '<..', not '..' (replace '..' with '<..' to fix)")), - allow_trailing=False, - ), compose_item_handle, + factor <<= condense( + ZeroOrMore(unary) + ( + impl_call + | await_item + Optional(power) + ), ) - factor = Forward() - unary = plus | neg_minus | tilde - power = trace(condense(compose_item + Optional(exp_dubstar + factor))) - factor <<= condense(ZeroOrMore(unary) + power) - mulop = mul_star | div_slash | div_dubslash | percent | matrix_at addop = plus | sub_minus shift = lshift | rshift @@ -1400,30 +1410,41 @@ class Grammar(object): # arith_expr = exprlist(term, addop) # shift_expr = exprlist(arith_expr, shift) # and_expr = exprlist(shift_expr, amp) - # xor_expr = exprlist(and_expr, caret) - xor_expr = exprlist( + and_expr = exprlist( term, addop | shift - | amp - | caret, + | amp, ) + protocol_intersect_expr = Forward() + protocol_intersect_expr_ref = tokenlist(and_expr, amp_colon, allow_trailing=False) + + xor_expr = exprlist(protocol_intersect_expr, caret) + or_expr = typedef_or_expr | exprlist(xor_expr, bar) chain_expr = attach(tokenlist(or_expr, dubcolon, allow_trailing=False), chain_handle) + compose_expr = attach( + tokenlist( + chain_expr, + dotdot + Optional(invalid_syntax(lambdef, "lambdas only allowed after composition pipe operators '..>' and '<..', not '..' 
(replace '..' with '<..' to fix)")), + allow_trailing=False, + ), compose_expr_handle, + ) + infix_op <<= backtick.suppress() + test_no_infix + backtick.suppress() infix_expr = Forward() infix_item = attach( - Group(Optional(chain_expr)) + Group(Optional(compose_expr)) + OneOrMore( - infix_op + Group(Optional(lambdef | chain_expr)), + infix_op + Group(Optional(lambdef | compose_expr)), ), infix_handle, ) infix_expr <<= ( - chain_expr + ~backtick + compose_expr + ~backtick | infix_item ) @@ -1468,7 +1489,8 @@ class Grammar(object): ) pipe_item = ( # we need the pipe_op since any of the atoms could otherwise be the start of an expression - labeled_group(attrgetter_atom_tokens, "attrgetter") + pipe_op + labeled_group(keyword("await"), "await") + pipe_op + | labeled_group(attrgetter_atom_tokens, "attrgetter") + pipe_op | labeled_group(itemgetter_atom_tokens, "itemgetter") + pipe_op | labeled_group(partial_atom_tokens, "partial") + pipe_op | labeled_group(partial_op_atom_tokens, "op partial") + pipe_op @@ -1477,7 +1499,8 @@ class Grammar(object): ) pipe_augassign_item = trace( # should match pipe_item but with pipe_op -> end_simple_stmt_item and no expr - labeled_group(attrgetter_atom_tokens, "attrgetter") + end_simple_stmt_item + labeled_group(keyword("await"), "await") + end_simple_stmt_item + | labeled_group(attrgetter_atom_tokens, "attrgetter") + end_simple_stmt_item | labeled_group(itemgetter_atom_tokens, "itemgetter") + end_simple_stmt_item | labeled_group(partial_atom_tokens, "partial") + end_simple_stmt_item | labeled_group(partial_op_atom_tokens, "op partial") + end_simple_stmt_item, @@ -1486,6 +1509,7 @@ class Grammar(object): lambdef("expr") # we need longest here because there's no following pipe_op we can use as above | longest( + keyword("await")("await"), attrgetter_atom_tokens("attrgetter"), itemgetter_atom_tokens("itemgetter"), partial_atom_tokens("partial"), @@ -1525,7 +1549,7 @@ class Grammar(object): classic_lambdef = Forward() classic_lambdef_params 
= maybeparens(lparen, set_args_list, rparen) new_lambdef_params = lparen.suppress() + set_args_list + rparen.suppress() | setname - classic_lambdef_ref = addspace(lambda_kwd + condense(classic_lambdef_params + colon)) + classic_lambdef_ref = addspace(keyword("lambda") + condense(classic_lambdef_params + colon)) new_lambdef = attach(new_lambdef_params + arrow.suppress(), lambdef_handle) implicit_lambdef = fixto(arrow, "lambda _=None:") lambdef_base = classic_lambdef | new_lambdef | implicit_lambdef @@ -1547,7 +1571,8 @@ class Grammar(object): general_stmt_lambdef = ( Group( any_len_perm( - async_kwd, + keyword("async"), + keyword("copyclosure"), ), ) + keyword("def").suppress() + stmt_lambdef_params @@ -1557,8 +1582,9 @@ class Grammar(object): match_stmt_lambdef = ( Group( any_len_perm( - match_kwd.suppress(), - async_kwd, + keyword("match").suppress(), + keyword("async"), + keyword("copyclosure"), ), ) + keyword("def").suppress() + stmt_lambdef_match_params @@ -1582,7 +1608,7 @@ class Grammar(object): ), ) unsafe_typedef_callable = attach( - Optional(async_kwd, default="") + Optional(keyword("async"), default="") + typedef_callable_params + arrow.suppress() + typedef_test, @@ -1604,21 +1630,25 @@ class Grammar(object): unsafe_typedef_ellipsis = ellipsis_tokens - _typedef_test, typedef_callable, _typedef_trailer, _typedef_or_expr, _typedef_tuple, _typedef_ellipsis = disable_outside( + unsafe_typedef_op_item = attach(base_op_item, typedef_op_item_handle) + + _typedef_test, typedef_callable, _typedef_trailer, _typedef_or_expr, _typedef_tuple, _typedef_ellipsis, _typedef_op_item = disable_outside( test, unsafe_typedef_callable, unsafe_typedef_trailer, unsafe_typedef_or_expr, unsafe_typedef_tuple, unsafe_typedef_ellipsis, + unsafe_typedef_op_item, ) typedef_test <<= _typedef_test typedef_trailer <<= _typedef_trailer typedef_or_expr <<= _typedef_or_expr typedef_tuple <<= _typedef_tuple typedef_ellipsis <<= _typedef_ellipsis + typedef_op_item <<= _typedef_op_item - 
alt_ternary_expr = attach(keyword("if").suppress() + test_item + then_kwd.suppress() + test_item + keyword("else").suppress() + test, alt_ternary_handle) + alt_ternary_expr = attach(keyword("if").suppress() + test_item + keyword("then").suppress() + test_item + keyword("else").suppress() + test, alt_ternary_handle) test <<= ( typedef_callable | lambdef @@ -1669,7 +1699,7 @@ class Grammar(object): | test_item ) base_comp_for = addspace(keyword("for") + assignlist + keyword("in") + comp_it_item + Optional(comp_iter)) - async_comp_for_ref = addspace(async_kwd + base_comp_for) + async_comp_for_ref = addspace(keyword("async") + base_comp_for) comp_for <<= async_comp_for | base_comp_for comp_if = addspace(keyword("if") + test_no_cond + Optional(comp_iter)) comp_iter <<= comp_for | comp_if @@ -1780,10 +1810,16 @@ class Grammar(object): | Optional(neg_minus) + number | match_dotted_name_const, ) + empty_const = fixto( + lparen + rparen + | lbrack + rbrack + | set_letter + lbrace + rbrace, + "()", + ) - matchlist_set = Group(Optional(tokenlist(match_const, comma))) match_pair = Group(match_const + colon.suppress() + match) matchlist_dict = Group(Optional(tokenlist(match_pair, comma))) + set_star = star.suppress() + (keyword(wildcard) | empty_const) matchlist_tuple_items = ( match + OneOrMore(comma.suppress() + match) + Optional(comma.suppress()) @@ -1832,14 +1868,22 @@ class Grammar(object): | match_const("const") | (keyword_atom | keyword("is").suppress() + negable_atom_item)("is") | (keyword("in").suppress() + negable_atom_item)("in") - | (lbrace.suppress() + matchlist_dict + Optional(dubstar.suppress() + (setname | condense(lbrace + rbrace))) + rbrace.suppress())("dict") - | (Optional(set_s.suppress()) + lbrace.suppress() + matchlist_set + rbrace.suppress())("set") | iter_match | match_lazy("lazy") | sequence_match | star_match | (lparen.suppress() + match + rparen.suppress())("paren") - | (data_kwd.suppress() + dotted_refname + lparen.suppress() + matchlist_data + 
rparen.suppress())("data") + | (lbrace.suppress() + matchlist_dict + Optional(dubstar.suppress() + (setname | condense(lbrace + rbrace)) + Optional(comma.suppress())) + rbrace.suppress())("dict") + | ( + Group(Optional(set_letter)) + + lbrace.suppress() + + ( + Group(tokenlist(match_const, comma, allow_trailing=False)) + Optional(comma.suppress() + set_star + Optional(comma.suppress())) + | Group(always_match) + set_star + Optional(comma.suppress()) + | Group(Optional(tokenlist(match_const, comma))) + ) + rbrace.suppress() + )("set") + | (keyword("data").suppress() + dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("data") | (keyword("class").suppress() + dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("class") | (dotted_refname + lparen.suppress() + matchlist_data + rparen.suppress())("data_or_class") | Optional(keyword("as").suppress()) + setname("var"), @@ -1876,23 +1920,25 @@ class Grammar(object): full_suite = colon.suppress() - Group((newline.suppress() - indent.suppress() - OneOrMore(stmt) - dedent.suppress()) | simple_stmt) full_match = Forward() full_match_ref = ( - match_kwd.suppress() + keyword("match").suppress() + many_match + addspace(Optional(keyword("not")) + keyword("in")) - - testlist_star_namedexpr - - match_guard + + testlist_star_namedexpr + + match_guard + # avoid match match-case blocks + + ~FollowedBy(colon + newline + indent + keyword("case")) - full_suite ) match_stmt = trace(condense(full_match - Optional(else_stmt))) destructuring_stmt = Forward() - base_destructuring_stmt = Optional(match_kwd.suppress()) + many_match + equals.suppress() + test_expr + base_destructuring_stmt = Optional(keyword("match").suppress()) + many_match + equals.suppress() + test_expr destructuring_stmt_ref, match_dotted_name_const_ref = disable_inside(base_destructuring_stmt, must_be_dotted_name + ~lparen) # both syntaxes here must be kept the same except for the keywords case_match_co_syntax = trace( Group( - 
(match_kwd | case_kwd).suppress() + (keyword("match") | keyword("case")).suppress() + stores_loc_item + many_match + Optional(keyword("if").suppress() + namedexpr_test) @@ -1900,13 +1946,13 @@ class Grammar(object): ), ) cases_stmt_co_syntax = ( - (cases_kwd | case_kwd) + testlist_star_namedexpr + colon.suppress() + newline.suppress() + (keyword("cases") | keyword("case")) + testlist_star_namedexpr + colon.suppress() + newline.suppress() + indent.suppress() + Group(OneOrMore(case_match_co_syntax)) + dedent.suppress() + Optional(keyword("else").suppress() + suite) ) case_match_py_syntax = trace( Group( - case_kwd.suppress() + keyword("case").suppress() + stores_loc_item + many_match + Optional(keyword("if").suppress() + namedexpr_test) @@ -1914,7 +1960,7 @@ class Grammar(object): ), ) cases_stmt_py_syntax = ( - match_kwd + testlist_star_namedexpr + colon.suppress() + newline.suppress() + keyword("match") + testlist_star_namedexpr + colon.suppress() + newline.suppress() + indent.suppress() + Group(OneOrMore(case_match_py_syntax)) + dedent.suppress() + Optional(keyword("else").suppress() - suite) ) @@ -1939,7 +1985,7 @@ class Grammar(object): base_match_for_stmt = Forward() base_match_for_stmt_ref = keyword("for").suppress() + many_match + keyword("in").suppress() - new_testlist_star_expr - colon.suppress() - condense(nocolon_suite - Optional(else_stmt)) - match_for_stmt = Optional(match_kwd.suppress()) + base_match_for_stmt + match_for_stmt = Optional(keyword("match").suppress()) + base_match_for_stmt except_item = ( testlist_has_comma("list") @@ -1947,15 +1993,15 @@ class Grammar(object): ) - Optional( keyword("as").suppress() - setname, ) - except_clause = attach(except_kwd + except_item, except_handle) + except_clause = attach(keyword("except") + except_item, except_handle) except_star_clause = Forward() except_star_clause_ref = attach(except_star_kwd + except_item, except_handle) try_stmt = condense( keyword("try") - suite + ( keyword("finally") - suite | ( - 
OneOrMore(except_clause - suite) - Optional(except_kwd - suite) - | except_kwd - suite + OneOrMore(except_clause - suite) - Optional(keyword("except") - suite) + | keyword("except") - suite | OneOrMore(except_star_clause - suite) ) - Optional(else_stmt) - Optional(keyword("finally") - suite) ), @@ -2024,16 +2070,16 @@ class Grammar(object): ) match_def_modifiers = trace( any_len_perm( - match_kwd.suppress(), - # we don't suppress addpattern so its presence can be detected later - addpattern_kwd, + keyword("match").suppress(), + # addpattern is detected later + keyword("addpattern"), ), ) match_funcdef = addspace(match_def_modifiers + def_match_funcdef) where_stmt = attach( unsafe_simple_stmt_item - + where_kwd.suppress() + + keyword("where").suppress() - full_suite, where_handle, ) @@ -2044,7 +2090,7 @@ class Grammar(object): ) implicit_return_where = attach( implicit_return - + where_kwd.suppress() + + keyword("where").suppress() - full_suite, where_handle, ) @@ -2086,74 +2132,72 @@ class Grammar(object): async_stmt = Forward() async_stmt_ref = addspace( - async_kwd + (with_stmt | for_stmt | match_for_stmt) # handles async [match] for - | match_kwd.suppress() + async_kwd + base_match_for_stmt, # handles match async for + keyword("async") + (with_stmt | for_stmt | match_for_stmt) # handles async [match] for + | keyword("match").suppress() + keyword("async") + base_match_for_stmt, # handles match async for ) - async_funcdef = async_kwd.suppress() + (funcdef | math_funcdef) + async_funcdef = keyword("async").suppress() + (funcdef | math_funcdef) async_match_funcdef = trace( addspace( any_len_perm( - match_kwd.suppress(), - # we don't suppress addpattern so its presence can be detected later - addpattern_kwd, - required=(async_kwd.suppress(),), + keyword("match").suppress(), + # addpattern is detected later + keyword("addpattern"), + required=(keyword("async").suppress(),), ) + (def_match_funcdef | math_match_funcdef), ), ) - async_yield_funcdef = attach( - trace( - 
any_len_perm( - required=( - async_kwd.suppress(), - keyword("yield").suppress(), - ), - ) + (funcdef | math_funcdef), + + async_keyword_normal_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), + required=(keyword("async").suppress(),), ), - yield_funcdef_handle, - ) - async_yield_match_funcdef = attach( - trace( - addspace( - any_len_perm( - match_kwd.suppress(), - # we don't suppress addpattern so its presence can be detected later - addpattern_kwd, - required=( - async_kwd.suppress(), - keyword("yield").suppress(), - ), - ) + (def_match_funcdef | math_match_funcdef), - ), + ) + (funcdef | math_funcdef) + async_keyword_match_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), + keyword("match").suppress(), + # addpattern is detected later + keyword("addpattern"), + required=(keyword("async").suppress(),), ), - yield_funcdef_handle, - ) + ) + (def_match_funcdef | math_match_funcdef) + async_keyword_funcdef = Forward() + async_keyword_funcdef_ref = async_keyword_normal_funcdef | async_keyword_match_funcdef + async_funcdef_stmt = ( async_funcdef | async_match_funcdef - | async_yield_funcdef - | async_yield_match_funcdef + | async_keyword_funcdef ) - yield_normal_funcdef = keyword("yield").suppress() + (funcdef | math_funcdef) - yield_match_funcdef = trace( - addspace( - any_len_perm( - match_kwd.suppress(), - # we don't suppress addpattern so its presence can be detected later - addpattern_kwd, - required=(keyword("yield").suppress(),), - ) + (def_match_funcdef | math_match_funcdef), + keyword_normal_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), ), - ) - yield_funcdef = attach(yield_normal_funcdef | yield_match_funcdef, yield_funcdef_handle) + ) + (funcdef | math_funcdef) + keyword_match_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), + keyword("match").suppress(), + # addpattern is detected later 
+ keyword("addpattern"), + ), + ) + (def_match_funcdef | math_match_funcdef) + keyword_funcdef = Forward() + keyword_funcdef_ref = keyword_normal_funcdef | keyword_match_funcdef normal_funcdef_stmt = ( funcdef | math_funcdef | math_match_funcdef | match_funcdef - | yield_funcdef + | keyword_funcdef ) datadef = Forward() @@ -2181,7 +2225,7 @@ class Grammar(object): ) datadef_ref = ( Optional(decorators, default="") - + data_kwd.suppress() + + keyword("data").suppress() + classname + Optional(type_params, default=()) + data_args @@ -2196,8 +2240,8 @@ class Grammar(object): # we don't support type_params here since we don't support types match_datadef_ref = ( Optional(decorators, default="") - + Optional(match_kwd.suppress()) - + data_kwd.suppress() + + Optional(keyword("match").suppress()) + + keyword("data").suppress() + classname + match_data_args + data_inherit @@ -2323,7 +2367,7 @@ class Grammar(object): whitespace_regex = compile_regex(r"\s") - def_regex = compile_regex(r"((async|addpattern)\s+)*def\b") + def_regex = compile_regex(r"((async|addpattern|copyclosure)\s+)*def\b") yield_regex = compile_regex(r"\byield(?!\s+_coconut\.asyncio\.From)\b") tco_disable_regex = compile_regex(r"try\b|(async\s+)?(with\b|for\b)|while\b") @@ -2344,10 +2388,10 @@ def get_tre_return_grammar(self, func_name): """The TRE return grammar is parameterized by the name of the function being optimized.""" return ( self.start_marker - + keyword("return").suppress() + + self.keyword("return").suppress() + maybeparens( self.lparen, - keyword(func_name, explicit_prefix=False).suppress() + base_keyword(func_name).suppress() + self.original_function_call_tokens, self.rparen, ) + self.end_marker @@ -2396,12 +2440,12 @@ def get_tre_return_grammar(self, func_name): | ~comma + ~rparen + ~equals + any_char, ), ) - tfpdef_tokens = unsafe_name - Optional(colon.suppress() - rest_of_tfpdef.suppress()) - tfpdef_default_tokens = tfpdef_tokens - Optional(equals.suppress() - rest_of_tfpdef) + tfpdef_tokens 
= unsafe_name - Optional(colon - rest_of_tfpdef).suppress() + tfpdef_default_tokens = tfpdef_tokens - Optional(equals - rest_of_tfpdef) type_comment = Optional( - comment_tokens.suppress() - | passthrough_item.suppress(), - ) + comment_tokens + | passthrough_item, + ).suppress() parameters_tokens = Group( Optional( tokenlist( @@ -2424,7 +2468,7 @@ def get_tre_return_grammar(self, func_name): ) stores_scope = boundary + ( - lambda_kwd + keyword("lambda") # match comprehensions but not for loops | ~indent + ~dedent + any_char + keyword("for") + unsafe_name + keyword("in") ) @@ -2435,7 +2479,7 @@ def get_tre_return_grammar(self, func_name): unsafe_equals = Literal("=") - kwd_err_msg = attach(any_keyword_in(keyword_vars), kwd_err_msg_handle) + kwd_err_msg = attach(any_keyword_in(keyword_vars + reserved_vars), kwd_err_msg_handle) parse_err_msg = ( start_marker + ( fixto(end_of_line, "misplaced newline (maybe missing ':')") @@ -2456,10 +2500,9 @@ def get_tre_return_grammar(self, func_name): string_start = start_marker + quotedString - operator_kwd = keyword("operator", explicit_prefix=colon, require_whitespace=True) operator_stmt = ( start_marker - + operator_kwd.suppress() + + keyword("operator").suppress() + restOfLine ) @@ -2469,7 +2512,7 @@ def get_tre_return_grammar(self, func_name): + keyword("from").suppress() + unsafe_import_from_name + keyword("import").suppress() - + operator_kwd.suppress() + + keyword("operator").suppress() + restOfLine ) diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 9aa506816..da436a7fa 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -22,7 +22,7 @@ import os.path from functools import partial -from coconut.root import _indent +from coconut.root import _indent, _get_root_header from coconut.exceptions import CoconutInternalException from coconut.terminal import internal_assert from coconut.constants import ( @@ -33,6 +33,7 @@ justify_len, report_this_text, numpy_modules, + 
pandas_numpy_modules, jax_numpy_modules, self_match_types, is_data_var, @@ -174,11 +175,11 @@ def base_pycondition(target, ver, if_lt=None, if_ge=None, indent=None, newline=F return out -def make_py_str(str_contents, target_startswith, after_py_str_defined=False): +def make_py_str(str_contents, target, after_py_str_defined=False): """Get code that effectively wraps the given code in py_str.""" return ( - repr(str_contents) if target_startswith == "3" - else "b" + repr(str_contents) if target_startswith == "2" + repr(str_contents) if target.startswith("3") + else "b" + repr(str_contents) if target.startswith("2") else "py_str(" + repr(str_contents) + ")" if after_py_str_defined else "str(" + repr(str_contents) + ")" ) @@ -202,35 +203,37 @@ def __getattr__(self, attr): def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): """Create the dictionary passed to str.format in the header.""" - target_startswith = one_num_ver(target) target_info = get_target_info(target) pycondition = partial(base_pycondition, target) format_dict = dict( COMMENT=COMMENT, - empty_dict="{}", + empty_dict="{}" if target_info >= (3, 7) else "_coconut.dict()", + empty_py_dict="{}" if target_info >= (3, 7) else "_coconut_py_dict()", lbrace="{", rbrace="}", is_data_var=is_data_var, data_defaults_var=data_defaults_var, - target_startswith=target_startswith, + target_major=one_num_ver(target), default_encoding=default_encoding, hash_line=hash_prefix + use_hash + "\n" if use_hash is not None else "", typing_line="# type: ignore\n" if which == "__coconut__" else "", + _coconut_="_coconut_" if which != "__coconut__" else "", # only for aliases defined at the end of the header VERSION_STR=VERSION_STR, module_docstring='"""Built-in Coconut utilities."""\n\n' if which == "__coconut__" else "", - __coconut__=make_py_str("__coconut__", target_startswith), - _coconut_cached__coconut__=make_py_str("_coconut_cached__coconut__", target_startswith), - object="" if target_startswith == "3" 
else "(object)", - comma_object="" if target_startswith == "3" else ", object", + __coconut__=make_py_str("__coconut__", target), + _coconut_cached__coconut__=make_py_str("_coconut_cached__coconut__", target), + object="" if target.startswith("3") else "(object)", + comma_object="" if target.startswith("3") else ", object", comma_slash=", /" if target_info >= (3, 8) else "", report_this_text=report_this_text, numpy_modules=tuple_str_of(numpy_modules, add_quotes=True), + pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True), jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True), self_match_types=tuple_str_of(self_match_types), set_super=( # we have to use _coconut_super even on the universal target, since once we set __class__ it becomes a local variable - "super = _coconut_super\n" if target_startswith != 3 else "" + "super = py_super" if target.startswith("3") else "super = _coconut_super" ), import_pickle=pycondition( (3,), @@ -268,9 +271,9 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): else "zip_longest = itertools.izip_longest", indent=1, ), - comma_bytearray=", bytearray" if target_startswith != "3" else "", - lstatic="staticmethod(" if target_startswith != "3" else "", - rstatic=")" if target_startswith != "3" else "", + comma_bytearray=", bytearray" if not target.startswith("3") else "", + lstatic="staticmethod(" if not target.startswith("3") else "", + rstatic=")" if not target.startswith("3") else "", zip_iter=prepare( r''' for items in _coconut.iter(_coconut.zip(*self.iters, strict=self.strict) if _coconut_sys.version_info >= (3, 10) else _coconut.zip_longest(*self.iters, fillvalue=_coconut_sentinel) if self.strict else _coconut.zip(*self.iters)): @@ -346,7 +349,7 @@ def pattern_prepender(func): set_name = _coconut.getattr(v, "__set_name__", None) if set_name is not None: set_name(cls, k)''' - if target_startswith == "2" else + if target.startswith("2") else r'''def 
_coconut_call_set_names(cls): pass''' if target_info >= (3, 6) else r'''def _coconut_call_set_names(cls): @@ -419,23 +422,12 @@ def _coconut_matmul(a, b, **kwargs): else: if result is not _coconut.NotImplemented: return result - if "numpy" in (a.__class__.__module__, b.__class__.__module__): + if "numpy" in (_coconut_get_base_module(a), _coconut_get_base_module(b)): from numpy import matmul return matmul(a, b) raise _coconut.TypeError("unsupported operand type(s) for @: " + _coconut.repr(_coconut.type(a)) + " and " + _coconut.repr(_coconut.type(b))) ''', ), - import_typing_NamedTuple=pycondition( - (3, 6), - if_lt=''' -def NamedTuple(name, fields): - return _coconut.collections.namedtuple(name, [x for x, t in fields]) -typing.NamedTuple = NamedTuple -NamedTuple = staticmethod(NamedTuple) - ''', - indent=1, - newline=True, - ), def_total_and_comparisons=pycondition( (3, 10), if_lt=''' @@ -483,10 +475,21 @@ def __lt__(self, other): indent=1, newline=True, ), + assign_multiset_views=pycondition( + (3,), + if_lt=''' +keys = _coconut.collections.Counter.viewkeys +values = _coconut.collections.Counter.viewvalues +items = _coconut.collections.Counter.viewitems + ''', + indent=1, + newline=True, + ), + # used in the second round tco_comma="_coconut_tail_call, _coconut_tco, " if not no_tco else "", call_set_names_comma="_coconut_call_set_names, " if target_info < (3, 6) else "", - handle_cls_args_comma="_coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, " if target_startswith != "3" else "", + handle_cls_args_comma="_coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, " if not target.startswith("3") else "", async_def_anext=prepare( r''' async def __anext__(self): @@ -525,9 +528,10 @@ async def __anext__(self): ) # second round for format dict elements that use the format dict + # (extra_format_dict is to keep indentation levels matching) extra_format_dict = dict( # when anything is added to this list it must also be added to *both* __coconut__ stub files - 
underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_super, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose".format(**format_dict), + underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, 
_coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in".format(**format_dict), import_typing=pycondition( (3, 5), if_ge="import typing", @@ -542,36 +546,76 @@ def cast(self, t, x): return x def __getattr__(self, name): raise _coconut.ImportError("the typing module is not available at runtime in Python 3.4 or earlier; try hiding your typedefs behind an 'if TYPE_CHECKING:' block") + def TypeVar(name, *args, **kwargs): + """Runtime mock of typing.TypeVar for Python 3.4 and earlier.""" + return name + class Generic_mock{object}: + """Runtime mock of typing.Generic for Python 3.4 and earlier.""" + __slots__ = () + def __getitem__(self, vars): + return _coconut.object + Generic = Generic_mock() typing = typing_mock() '''.format(**format_dict), indent=1, ), # all typing_extensions imports must be added to the _coconut stub file - import_typing_TypeAlias_ParamSpec_Concatenate=pycondition( + 
import_typing_36=pycondition( + (3, 6), + if_lt=''' +def NamedTuple(name, fields): + return _coconut.collections.namedtuple(name, [x for x, t in fields]) +typing.NamedTuple = NamedTuple +NamedTuple = staticmethod(NamedTuple) + ''', + indent=1, + newline=True, + ), + import_typing_38=pycondition( + (3, 8), + if_lt=''' +try: + from typing_extensions import Protocol +except ImportError: + class YouNeedToInstallTypingExtensions{object}: + __slots__ = () + Protocol = YouNeedToInstallTypingExtensions +typing.Protocol = Protocol + '''.format(**format_dict), + indent=1, + newline=True, + ), + import_typing_310=pycondition( (3, 10), if_lt=''' try: - from typing_extensions import TypeAlias, ParamSpec, Concatenate + from typing_extensions import ParamSpec, TypeAlias, Concatenate except ImportError: + def ParamSpec(name, *args, **kwargs): + """Runtime mock of typing.ParamSpec for Python 3.9 and earlier.""" + return _coconut.typing.TypeVar(name) class you_need_to_install_typing_extensions{object}: __slots__ = () - TypeAlias = ParamSpec = Concatenate = you_need_to_install_typing_extensions() -typing.TypeAlias = TypeAlias + TypeAlias = Concatenate = you_need_to_install_typing_extensions() typing.ParamSpec = ParamSpec +typing.TypeAlias = TypeAlias typing.Concatenate = Concatenate '''.format(**format_dict), indent=1, newline=True, ), - import_typing_TypeVarTuple_Unpack=pycondition( + import_typing_311=pycondition( (3, 11), if_lt=''' try: from typing_extensions import TypeVarTuple, Unpack except ImportError: + def TypeVarTuple(name, *args, **kwargs): + """Runtime mock of typing.TypeVarTuple for Python 3.10 and earlier.""" + return _coconut.typing.TypeVar(name) class you_need_to_install_typing_extensions{object}: __slots__ = () - TypeVarTuple = Unpack = you_need_to_install_typing_extensions() + Unpack = you_need_to_install_typing_extensions() typing.TypeVarTuple = TypeVarTuple typing.Unpack = Unpack '''.format(**format_dict), @@ -599,7 +643,7 @@ class 
you_need_to_install_trollius{object}: _coconut_amap = None ''', if_ge=r''' -class _coconut_amap(_coconut_base_hashable): +class _coconut_amap(_coconut_baseclass): __slots__ = ("func", "aiter") def __init__(self, func, aiter): self.func = func @@ -658,13 +702,13 @@ def getheader(which, use_hash, target, no_tco, strict, no_wrap): # initial, __coconut__, package:n, sys, code, file - target_startswith = one_num_ver(target) target_info = get_target_info(target) - header_info = tuple_str_of((VERSION, target, no_tco, strict, no_wrap), add_quotes=True) + # header_info only includes arguments that affect __coconut__.py compatibility + header_info = tuple_str_of((VERSION, target, strict), add_quotes=True) format_dict = process_header_args(which, use_hash, target, no_tco, strict, no_wrap) if which == "initial" or which == "__coconut__": - header = '''#!/usr/bin/env python{target_startswith} + header = '''#!/usr/bin/env python{target_major} # -*- coding: {default_encoding} -*- {hash_line}{typing_line} # Compiled with Coconut version {VERSION_STR} @@ -682,7 +726,7 @@ def getheader(which, use_hash, target, no_tco, strict, no_wrap): header += section("Coconut Header", newline_before=False) - if target_startswith != "3": + if not target.startswith("3"): header += "from __future__ import print_function, absolute_import, unicode_literals, division\n" # including generator_stop here is fine, even though to universalize # generator returns we raise StopIteration errors, since we only do so @@ -756,16 +800,18 @@ def getheader(which, use_hash, target, no_tco, strict, no_wrap): newline=True, ).format(**format_dict) + if target_info >= (3, 9): + header += _get_root_header("39") if target_info >= (3, 7): - header += PY37_HEADER - elif target_startswith == "3": - header += PY3_HEADER + header += _get_root_header("37") + elif target.startswith("3"): + header += _get_root_header("3") elif target_info >= (2, 7): - header += PY27_HEADER - elif target_startswith == "2": - header += PY2_HEADER + 
header += _get_root_header("27") + elif target.startswith("2"): + header += _get_root_header("2") else: - header += PYCHECK_HEADER + header += _get_root_header("universal") header += get_template("header").format(**format_dict) diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index 947035aa2..2bc2e5a8d 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -42,6 +42,9 @@ data_defaults_var, default_matcher_style, self_match_types, + match_first_arg_var, + match_to_args_var, + match_to_kwargs_var, ) from coconut.compiler.util import ( paren_join, @@ -230,9 +233,12 @@ def using_python_rules(self): """Whether the current style uses PEP 622 rules.""" return self.style.startswith("python") - def rule_conflict_warn(self, message, if_coconut=None, if_python=None, extra=None): + def rule_conflict_warn(self, message, if_coconut=None, if_python=None, extra=None, only_strict=False): """Warns on conflicting style rules if callback was given.""" - if self.style.endswith("warn") or self.style.endswith("strict") and self.comp.strict: + if ( + self.style.endswith("warn") and (not only_strict or self.comp.strict) + or self.style.endswith("strict") and self.comp.strict + ): full_msg = message if if_python or if_coconut: full_msg += " (" + (if_python if self.using_python_rules else if_coconut) + ")" @@ -343,10 +349,33 @@ def check_len_in(self, min_len, max_len, item): else: self.add_check(str(min_len) + " <= _coconut.len(" + item + ") <= " + str(max_len)) - def match_function(self, args, kwargs, pos_only_match_args=(), match_args=(), star_arg=None, kwd_only_match_args=(), dubstar_arg=None): + def match_function( + self, + first_arg=match_first_arg_var, + args=match_to_args_var, + kwargs=match_to_kwargs_var, + pos_only_match_args=(), + match_args=(), + star_arg=None, + kwd_only_match_args=(), + dubstar_arg=None, + ): """Matches a pattern-matching function.""" # before everything, pop the FunctionMatchError from context 
self.add_def(function_match_error_var + " = _coconut_get_function_match_error()") + # and fix args to include first_arg, which we have to do to make super work + self.add_def( + handle_indentation( + """ +if {first_arg} is not _coconut_sentinel: + {args} = ({first_arg},) + {args} + """, + ).format( + first_arg=first_arg, + args=args, + ), + ) + with self.down_a_level(): self.match_in_args_kwargs(pos_only_match_args, match_args, args, kwargs, allow_star_args=star_arg is not None) @@ -475,15 +504,16 @@ def match_dict(self, tokens, item): self.rule_conflict_warn( "found pattern with new behavior in Coconut v2; dict patterns now allow the dictionary being matched against to contain extra keys", extra="use explicit '{..., **_}' or '{..., **{}}' syntax to resolve", + only_strict=True, ) - check_len = not self.using_python_rules + strict_len = not self.using_python_rules elif rest == "{}": - check_len = True + strict_len = True rest = None else: - check_len = False + strict_len = False - if check_len: + if strict_len: self.add_check("_coconut.len(" + item + ") == " + str(len(matches))) seen_keys = set() @@ -500,8 +530,8 @@ def match_dict(self, tokens, item): if rest is not None and rest != wildcard: match_keys = [k for k, v in matches] rest_item = ( - "dict((k, v) for k, v in " - + item + ".items() if k not in set((" + "_coconut.dict((k, v) for k, v in " + + item + ".items() if k not in _coconut.set((" + ", ".join(match_keys) + ("," if len(match_keys) == 1 else "") + ")))" ) @@ -900,11 +930,48 @@ def match_in(self, tokens, item): def match_set(self, tokens, item): """Matches a set.""" - match, = tokens - self.add_check("_coconut.isinstance(" + item + ", _coconut.abc.Set)") - self.add_check("_coconut.len(" + item + ") == " + str(len(match))) - for const in match: - self.add_check(const + " in " + item) + if len(tokens) == 2: + letter_toks, match = tokens + star = None + else: + letter_toks, match, star = tokens + + if letter_toks: + letter, = letter_toks + else: + letter = 
"s" + + # process *() or *_ + if star is None: + self.rule_conflict_warn( + "found pattern with new behavior in Coconut v3; set patterns now allow the set being matched against to contain extra items", + extra="use explicit '{..., *_}' or '{..., *()}' syntax to resolve", + ) + strict_len = not self.using_python_rules + elif star == wildcard: + strict_len = False + else: + internal_assert(star == "()", "invalid set match tokens", tokens) + strict_len = True + + # handle set letter + if letter == "s": + self.add_check("_coconut.isinstance(" + item + ", _coconut.abc.Set)") + elif letter == "f": + self.add_check("_coconut.isinstance(" + item + ", _coconut.frozenset)") + elif letter == "m": + self.add_check("_coconut.isinstance(" + item + ", _coconut.collections.Counter)") + else: + raise CoconutInternalException("invalid set match letter", letter) + + # match set contents + if letter == "m": + self.add_check("_coconut_multiset(" + tuple_str_of(match) + ") " + ("== " if strict_len else "<= ") + item) + else: + if strict_len: + self.add_check("_coconut.len(" + item + ") == " + str(len(match))) + for const in match: + self.add_check(const + " in " + item) def split_data_or_class_match(self, tokens): """Split data/class match tokens into cls_name, pos_matches, name_matches, star_match.""" diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 33848360f..347eb1178 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -11,7 +11,8 @@ def _coconut_super(type=None, object_or_type=None): self = frame.f_locals[frame.f_code.co_varnames[0]] return _coconut_py_super(cls, self) return _coconut_py_super(type, object_or_type) -{set_super}class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} +{set_super} +class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} import collections, copy, functools, types, itertools, operator, 
threading, os, warnings, contextlib, traceback, weakref, multiprocessing from multiprocessing import dummy as multiprocessing_dummy {maybe_bind_lru_cache}{import_copyreg} @@ -20,7 +21,7 @@ def _coconut_super(type=None, object_or_type=None): {import_OrderedDict} {import_collections_abc} {import_typing} -{import_typing_NamedTuple}{import_typing_TypeAlias_ParamSpec_Concatenate}{import_typing_TypeVarTuple_Unpack}{set_zip_longest} +{import_typing_36}{import_typing_38}{import_typing_310}{import_typing_311}{set_zip_longest} try: import numpy except ImportError: @@ -30,15 +31,39 @@ def _coconut_super(type=None, object_or_type=None): else: abc.Sequence.register(numpy.ndarray) numpy_modules = {numpy_modules} + pandas_numpy_modules = {pandas_numpy_modules} jax_numpy_modules = {jax_numpy_modules} tee_type = type(itertools.tee((), 1)[0]) reiterables = abc.Sequence, abc.Mapping, abc.Set abc.Sequence.register(collections.deque) - Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} -class _coconut_Sentinel{object}: 
- __slots__ = () -_coconut_sentinel = _coconut_Sentinel() -class _coconut_base_hashable{object}: + Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} +def _coconut_handle_cls_kwargs(**kwargs): + """Some code taken from six under the terms of its MIT license.""" + metaclass = kwargs.pop("metaclass", None) + if kwargs and metaclass is None: + raise _coconut.TypeError("unexpected keyword argument(s) in class definition: %r" % (kwargs,)) + def coconut_handle_cls_kwargs_wrapper(cls): + if metaclass is None: + return cls + orig_vars = cls.__dict__.copy() + slots = orig_vars.get("__slots__") + if slots is not None: + if _coconut.isinstance(slots, _coconut.str): + slots = [slots] + for slots_var in slots: + orig_vars.pop(slots_var) + orig_vars.pop("__dict__", None) + orig_vars.pop("__weakref__", None) + if _coconut.hasattr(cls, "__qualname__"): + orig_vars["__qualname__"] = cls.__qualname__ + return metaclass(cls.__name__, cls.__bases__, orig_vars, **kwargs) + return 
coconut_handle_cls_kwargs_wrapper +def _coconut_handle_cls_stargs(*args): + temp_names = ["_coconut_base_cls_%s" % (i,) for i in _coconut.range(_coconut.len(args))] + ns = _coconut_py_dict(_coconut.zip(temp_names, args)) + _coconut_exec("class _coconut_cls_stargs_base(" + ", ".join(temp_names) + "): pass", ns) + return ns["_coconut_cls_stargs_base"] +class _coconut_baseclass{object}: __slots__ = ("__weakref__",) def __reduce_ex__(self, _): return self.__reduce__() @@ -49,7 +74,19 @@ class _coconut_base_hashable{object}: def __setstate__(self, setvars):{COMMENT.fixes_unpickling_with_slots} for k, v in setvars.items(): _coconut.setattr(self, k, v) -class MatchError(_coconut_base_hashable, Exception): + def __iter_getitem__(self, index): + getitem = _coconut.getattr(self, "__getitem__", None) + if getitem is None: + raise _coconut.NotImplementedError + return getitem(index) +class _coconut_Sentinel(_coconut_baseclass): + __slots__ = () + def __reduce__(self): + return (self.__class__, ()) +_coconut_sentinel = _coconut_Sentinel() +def _coconut_get_base_module(obj): + return obj.__class__.__module__.split(".", 1)[0] +class MatchError(_coconut_baseclass, Exception): """Pattern-matching error. 
Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below} max_val_repr_len = 500 def __init__(self, pattern=None, value=None): @@ -75,18 +112,20 @@ class MatchError(_coconut_base_hashable, Exception): def __reduce__(self): return (self.__class__, (self.pattern, self.value), {lbrace}"_message": self._message{rbrace}) def __setstate__(self, state): - _coconut_base_hashable.__setstate__(self, state) + _coconut_baseclass.__setstate__(self, state) if self._message is not None: Exception.__init__(self, self._message) _coconut_cached_MatchError = None if _coconut_cached__coconut__ is None else getattr(_coconut_cached__coconut__, "MatchError", None) if _coconut_cached_MatchError is not None:{patch_cached_MatchError} MatchError = _coconut_cached_MatchError -class _coconut_tail_call{object}: +class _coconut_tail_call(_coconut_baseclass): __slots__ = ("func", "args", "kwargs") def __init__(self, _coconut_func, *args, **kwargs): self.func = _coconut_func self.args = args self.kwargs = kwargs + def __reduce__(self): + return (self.__class__, (self.func, self.args, self.kwargs)) _coconut_tco_func_dict = {empty_dict} def _coconut_tco(func): @_coconut.functools.wraps(func) @@ -141,7 +180,7 @@ def tee(iterable, n=2): else:{COMMENT.no_break} return _coconut.tuple(existing_copies) return _coconut.itertools.tee(iterable, n) -class _coconut_has_iter(_coconut_base_hashable): +class _coconut_has_iter(_coconut_baseclass): __slots__ = ("lock", "iter") def __new__(cls, iterable): self = _coconut.object.__new__(cls) @@ -151,10 +190,10 @@ class _coconut_has_iter(_coconut_base_hashable): def get_new_iter(self): """Tee the underlying iterator.""" with self.lock: - self.iter = _coconut_reiterable(self.iter) + self.iter = {_coconut_}reiterable(self.iter) return self.iter def __fmap__(self, func): - return _coconut_map(func, self) + return {_coconut_}map(func, self) class reiterable(_coconut_has_iter): """Allow an iterator to be iterated over multiple times with 
the same results.""" __slots__ = () @@ -165,7 +204,7 @@ class reiterable(_coconut_has_iter): def get_new_iter(self): """Tee the underlying iterator.""" with self.lock: - self.iter, new_iter = _coconut_tee(self.iter) + self.iter, new_iter = {_coconut_}tee(self.iter) return new_iter def __iter__(self): return _coconut.iter(self.get_new_iter()) @@ -178,7 +217,7 @@ class reiterable(_coconut_has_iter): def __getitem__(self, index): return _coconut_iter_getitem(self.get_new_iter(), index) def __reversed__(self): - return _coconut_reversed(self.get_new_iter()) + return {_coconut_}reversed(self.get_new_iter()) def __len__(self): if not _coconut.isinstance(self.iter, _coconut.abc.Sized): return _coconut.NotImplemented @@ -203,12 +242,12 @@ def _coconut_iter_getitem_special_case(iterable, start, stop, step): def _coconut_iter_getitem(iterable, index): """Iterator slicing works just like sequence slicing, including support for negative indices and slices, and support for `slice` objects in the same way as can be done with normal slicing. - Coconut's iterator slicing is very similar to Python's `itertools.islice`, but unlike `itertools.islice`, Coconut's iterator slicing supports negative indices, and will preferentially call an object's `__iter_getitem__` (Coconut-specific magic method, preferred) or `__getitem__` (general Python magic method), if they exist. Coconut's iterator slicing is also optimized to work well with all of Coconut's built-in objects, only computing the elements of each that are actually necessary to extract the desired slice. + Coconut's iterator slicing is very similar to Python's `itertools.islice`, but unlike `itertools.islice`, Coconut's iterator slicing supports negative indices, and will preferentially call an object's `__iter_getitem__` (always used if available) or `__getitem__` (only used if the object is a collections.abc.Sequence). 
Coconut's iterator slicing is also optimized to work well with all of Coconut's built-in objects, only computing the elements of each that are actually necessary to extract the desired slice. Some code taken from more_itertools under the terms of its MIT license. """ obj_iter_getitem = _coconut.getattr(iterable, "__iter_getitem__", None) - if obj_iter_getitem is None: + if obj_iter_getitem is None and _coconut.isinstance(iterable, _coconut.abc.Sequence): obj_iter_getitem = _coconut.getattr(iterable, "__getitem__", None) if obj_iter_getitem is not None: try: @@ -216,8 +255,7 @@ def _coconut_iter_getitem(iterable, index): except _coconut.NotImplementedError: pass else: - if result is not _coconut.NotImplemented: - return result + return result if not _coconut.isinstance(index, _coconut.slice): index = _coconut.operator.index(index) if index < 0: @@ -258,7 +296,7 @@ def _coconut_iter_getitem(iterable, index): return () if n < -start or step != 1: cache = _coconut.itertools.islice(cache, 0, n, step) - return _coconut_map(_coconut.operator.itemgetter(1), cache) + return {_coconut_}map(_coconut.operator.itemgetter(1), cache) elif stop is None or stop >= 0: return _coconut.itertools.islice(iterable, start, stop, step) else: @@ -273,7 +311,7 @@ def _coconut_iter_getitem(iterable, index): i, j = start, stop else: i, j = _coconut.min(start - len_iter, -1), None - return _coconut_map(_coconut.operator.itemgetter(1), _coconut.tuple(cache)[i:j:step]) + return {_coconut_}map(_coconut.operator.itemgetter(1), _coconut.tuple(cache)[i:j:step]) else: if stop is not None: m = stop + 1 @@ -292,7 +330,7 @@ def _coconut_iter_getitem(iterable, index): return () iterable = _coconut.itertools.islice(iterable, 0, n) return _coconut.tuple(iterable)[i::step] -class _coconut_base_compose(_coconut_base_hashable): +class _coconut_base_compose(_coconut_baseclass): __slots__ = ("func", "func_infos") def __init__(self, func, *func_infos): self.func = func @@ -439,6 +477,12 @@ def 
_coconut_bool_and(a, b): def _coconut_bool_or(a, b): """Boolean or operator (or). Equivalent to (a, b) -> a or b.""" return a or b +def _coconut_in(a, b): + """Containment operator (in). Equivalent to (a, b) -> a in b.""" + return a in b +def _coconut_not_in(a, b): + """Negative containment operator (not in). Equivalent to (a, b) -> a not in b.""" + return a not in b def _coconut_none_coalesce(a, b): """None coalescing operator (??). Equivalent to (a, b) -> a if a is not None else b.""" return b if a is None else a @@ -517,7 +561,7 @@ class reversed(_coconut_has_iter): """Find the index of elem in the reversed iterable.""" return _coconut.len(self.iter) - self.iter.index(elem) - 1 def __fmap__(self, func): - return self.__class__(_coconut_map(func, self.iter)) + return self.__class__({_coconut_}map(func, self.iter)) class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_become_very_innefficient} """Flatten an iterable of iterables into a single iterable. Only flattens the top level of the iterable.""" @@ -538,9 +582,9 @@ class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_bec with self.lock: if not self._made_reit: for i in _coconut.reversed(_coconut.range(0 if self.levels is None else self.levels + 1)): - mapper = _coconut_reiterable + mapper = {_coconut_}reiterable for _ in _coconut.range(i): - mapper = _coconut.functools.partial(_coconut_map, mapper) + mapper = _coconut.functools.partial({_coconut_}map, mapper) self.iter = mapper(self.iter) self._made_reit = True return self.iter @@ -564,9 +608,9 @@ class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_bec return _coconut.reversed(_coconut.tuple(self._iter_all_levels(new=True))) reversed_iter = self.get_new_iter() for i in _coconut.reversed(_coconut.range(self.levels + 1)): - reverser = _coconut_reversed + reverser = {_coconut_}reversed for _ in _coconut.range(i): - reverser = _coconut.functools.partial(_coconut_map, reverser) + reverser = 
_coconut.functools.partial({_coconut_}map, reverser) reversed_iter = reverser(reversed_iter) return self.__class__(reversed_iter, self.levels) def __repr__(self): @@ -597,9 +641,9 @@ class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_bec raise ValueError("%r not in %r" % (elem, self)) def __fmap__(self, func): if self.levels == 1: - return self.__class__(_coconut_map(_coconut.functools.partial(_coconut_map, func), self.get_new_iter())) - return _coconut_map(func, self) -class cartesian_product(_coconut_base_hashable): + return self.__class__({_coconut_}map(_coconut.functools.partial({_coconut_}map, func), self.get_new_iter())) + return {_coconut_}map(func, self) +class cartesian_product(_coconut_baseclass): __slots__ = ("iters", "repeat") __doc__ = getattr(_coconut.itertools.product, "__doc__", "Cartesian product of input iterables.") + """ @@ -613,17 +657,21 @@ Additionally supports Cartesian products of numpy arrays.""" repeat = 1 if repeat < 0: raise _coconut.ValueError("cartesian_product: repeat cannot be negative") - if iterables and _coconut.all(it.__class__.__module__ in _coconut.numpy_modules for it in iterables): - if _coconut.any(it.__class__.__module__ in _coconut.jax_numpy_modules for it in iterables): - from jax import numpy - else: - numpy = _coconut.numpy - iterables *= repeat - dtype = numpy.result_type(*iterables) - arr = numpy.empty([_coconut.len(a) for a in iterables] + [_coconut.len(iterables)], dtype=dtype) - for i, a in _coconut.enumerate(numpy.ix_(*iterables)): - arr[..., i] = a - return arr.reshape(-1, _coconut.len(iterables)) + if iterables: + it_modules = [_coconut_get_base_module(it) for it in iterables] + if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules): + if _coconut.any(mod in _coconut.pandas_numpy_modules for mod in it_modules): + iterables = tuple((it.to_numpy() if _coconut_get_base_module(it) in _coconut.pandas_numpy_modules else it) for it in iterables) + if _coconut.any(mod in 
_coconut.jax_numpy_modules for mod in it_modules): + from jax import numpy + else: + numpy = _coconut.numpy + iterables *= repeat + dtype = numpy.result_type(*iterables) + arr = numpy.empty([_coconut.len(a) for a in iterables] + [_coconut.len(iterables)], dtype=dtype) + for i, a in _coconut.enumerate(numpy.ix_(*iterables)): + arr[..., i] = a + return arr.reshape(-1, _coconut.len(iterables)) self = _coconut.object.__new__(cls) self.iters = iterables self.repeat = repeat @@ -635,7 +683,7 @@ Additionally supports Cartesian products of numpy arrays.""" def __reduce__(self): return (self.__class__, self.iters, {lbrace}"repeat": self.repeat{rbrace}) def __copy__(self): - self.iters = _coconut.tuple(_coconut_reiterable(it) for it in self.iters) + self.iters = _coconut.tuple({_coconut_}reiterable(it) for it in self.iters) return self.__class__(*self.iters, repeat=self.repeat) @property def all_iters(self): @@ -663,8 +711,8 @@ Additionally supports Cartesian products of numpy arrays.""" return total_count return total_count def __fmap__(self, func): - return _coconut_map(func, self) -class map(_coconut_base_hashable, _coconut.map): + return {_coconut_}map(func, self) +class map(_coconut_baseclass, _coconut.map): __slots__ = ("func", "iters") __doc__ = getattr(_coconut.map, "__doc__", "") def __new__(cls, function, *iterables, **kwargs): @@ -672,7 +720,7 @@ class map(_coconut_base_hashable, _coconut.map): if kwargs: raise _coconut.TypeError(cls.__name__ + "() got unexpected keyword arguments " + _coconut.repr(kwargs)) if strict and _coconut.len(iterables) > 1: - return _coconut_starmap(function, _coconut_zip(*iterables, strict=True)) + return {_coconut_}starmap(function, {_coconut_}zip(*iterables, strict=True)) self = _coconut.map.__new__(cls, function, *iterables) self.func = function self.iters = iterables @@ -682,7 +730,7 @@ class map(_coconut_base_hashable, _coconut.map): return self.__class__(self.func, *(_coconut_iter_getitem(it, index) for it in self.iters)) return 
self.func(*(_coconut_iter_getitem(it, index) for it in self.iters)) def __reversed__(self): - return self.__class__(self.func, *(_coconut_reversed(it) for it in self.iters)) + return self.__class__(self.func, *({_coconut_}reversed(it) for it in self.iters)) def __len__(self): if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters): return _coconut.NotImplemented @@ -692,13 +740,13 @@ class map(_coconut_base_hashable, _coconut.map): def __reduce__(self): return (self.__class__, (self.func,) + self.iters) def __copy__(self): - self.iters = _coconut.tuple(_coconut_reiterable(it) for it in self.iters) + self.iters = _coconut.tuple({_coconut_}reiterable(it) for it in self.iters) return self.__class__(self.func, *self.iters) def __iter__(self): return _coconut.iter(_coconut.map(self.func, *self.iters)) def __fmap__(self, func): return self.__class__(_coconut_forward_compose(self.func, func), *self.iters) -class _coconut_parallel_concurrent_map_func_wrapper(_coconut_base_hashable): +class _coconut_parallel_concurrent_map_func_wrapper(_coconut_baseclass): __slots__ = ("map_cls", "func", "star") def __init__(self, map_cls, func, star): self.map_cls = map_cls @@ -726,7 +774,7 @@ class _coconut_base_parallel_concurrent_map(map): def get_pool_stack(cls): return cls.threadlocal_ns.__dict__.setdefault("pool_stack", [None]) def __new__(cls, function, *iterables, **kwargs): - self = _coconut_map.__new__(cls, function, *iterables) + self = {_coconut_}map.__new__(cls, function, *iterables) self.result = None self.chunksize = kwargs.pop("chunksize", 1) self.strict = kwargs.pop("strict", False) @@ -754,10 +802,10 @@ class _coconut_base_parallel_concurrent_map(map): if _coconut.len(self.iters) == 1: self.result = _coconut.list(self.get_pool_stack()[-1].imap(_coconut_parallel_concurrent_map_func_wrapper(self.__class__, self.func, False), self.iters[0], self.chunksize)) elif self.strict: - self.result = 
_coconut.list(self.get_pool_stack()[-1].imap(_coconut_parallel_concurrent_map_func_wrapper(self.__class__, self.func, True), _coconut_zip(*self.iters, strict=True), self.chunksize)) + self.result = _coconut.list(self.get_pool_stack()[-1].imap(_coconut_parallel_concurrent_map_func_wrapper(self.__class__, self.func, True), {_coconut_}zip(*self.iters, strict=True), self.chunksize)) else: self.result = _coconut.list(self.get_pool_stack()[-1].imap(_coconut_parallel_concurrent_map_func_wrapper(self.__class__, self.func, True), _coconut.zip(*self.iters), self.chunksize)) - self.func = _coconut_ident + self.func = {_coconut_}ident self.iters = (self.result,) return self.result def __iter__(self): @@ -786,7 +834,7 @@ class concurrent_map(_coconut_base_parallel_concurrent_map): @staticmethod def make_pool(max_workers=None): return _coconut.multiprocessing_dummy.Pool(_coconut.multiprocessing.cpu_count() * 5 if max_workers is None else max_workers) -class zip(_coconut_base_hashable, _coconut.zip): +class zip(_coconut_baseclass, _coconut.zip): __slots__ = ("iters", "strict") __doc__ = getattr(_coconut.zip, "__doc__", "") def __new__(cls, *iterables, **kwargs): @@ -801,7 +849,7 @@ class zip(_coconut_base_hashable, _coconut.zip): return self.__class__(*(_coconut_iter_getitem(it, index) for it in self.iters), strict=self.strict) return _coconut.tuple(_coconut_iter_getitem(it, index) for it in self.iters) def __reversed__(self): - return self.__class__(*(_coconut_reversed(it) for it in self.iters), strict=self.strict) + return self.__class__(*({_coconut_}reversed(it) for it in self.iters), strict=self.strict) def __len__(self): if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters): return _coconut.NotImplemented @@ -811,17 +859,17 @@ class zip(_coconut_base_hashable, _coconut.zip): def __reduce__(self): return (self.__class__, self.iters, {lbrace}"strict": self.strict{rbrace}) def __copy__(self): - self.iters = _coconut.tuple(_coconut_reiterable(it) 
for it in self.iters) + self.iters = _coconut.tuple({_coconut_}reiterable(it) for it in self.iters) return self.__class__(*self.iters, strict=self.strict) def __iter__(self): {zip_iter} def __fmap__(self, func): - return _coconut_map(func, self) + return {_coconut_}map(func, self) class zip_longest(zip): __slots__ = ("fillvalue",) __doc__ = getattr(_coconut.zip_longest, "__doc__", "Version of zip that fills in missing values with fillvalue.") def __new__(cls, *iterables, **kwargs): - self = _coconut_zip.__new__(cls, *iterables, strict=False) + self = {_coconut_}zip.__new__(cls, *iterables, strict=False) self.fillvalue = kwargs.pop("fillvalue", None) if kwargs: raise _coconut.TypeError(cls.__name__ + "() got unexpected keyword arguments " + _coconut.repr(kwargs)) @@ -862,11 +910,11 @@ class zip_longest(zip): def __reduce__(self): return (self.__class__, self.iters, {lbrace}"fillvalue": self.fillvalue{rbrace}) def __copy__(self): - self.iters = _coconut.tuple(_coconut_reiterable(it) for it in self.iters) + self.iters = _coconut.tuple({_coconut_}reiterable(it) for it in self.iters) return self.__class__(*self.iters, fillvalue=self.fillvalue) def __iter__(self): return _coconut.iter(_coconut.zip_longest(*self.iters, fillvalue=self.fillvalue)) -class filter(_coconut_base_hashable, _coconut.filter): +class filter(_coconut_baseclass, _coconut.filter): __slots__ = ("func", "iter") __doc__ = getattr(_coconut.filter, "__doc__", "") def __new__(cls, function, iterable): @@ -875,19 +923,19 @@ class filter(_coconut_base_hashable, _coconut.filter): self.iter = iterable return self def __reversed__(self): - return self.__class__(self.func, _coconut_reversed(self.iter)) + return self.__class__(self.func, {_coconut_}reversed(self.iter)) def __repr__(self): return "filter(%r, %s)" % (self.func, _coconut.repr(self.iter)) def __reduce__(self): return (self.__class__, (self.func, self.iter)) def __copy__(self): - self.iter = _coconut_reiterable(self.iter) + self.iter = 
{_coconut_}reiterable(self.iter) return self.__class__(self.func, self.iter) def __iter__(self): return _coconut.iter(_coconut.filter(self.func, self.iter)) def __fmap__(self, func): - return _coconut_map(func, self) -class enumerate(_coconut_base_hashable, _coconut.enumerate): + return {_coconut_}map(func, self) +class enumerate(_coconut_baseclass, _coconut.enumerate): __slots__ = ("iter", "start") __doc__ = getattr(_coconut.enumerate, "__doc__", "") def __new__(cls, iterable, start=0): @@ -899,11 +947,11 @@ class enumerate(_coconut_base_hashable, _coconut.enumerate): def __repr__(self): return "enumerate(%s, %r)" % (_coconut.repr(self.iter), self.start) def __fmap__(self, func): - return _coconut_map(func, self) + return {_coconut_}map(func, self) def __reduce__(self): return (self.__class__, (self.iter, self.start)) def __copy__(self): - self.iter = _coconut_reiterable(self.iter) + self.iter = {_coconut_}reiterable(self.iter) return self.__class__(self.iter, self.start) def __iter__(self): return _coconut.iter(_coconut.enumerate(self.iter, self.start)) @@ -921,7 +969,7 @@ class multi_enumerate(_coconut_has_iter): in each inner iterable. Supports indexing. 
For numpy arrays, effectively equivalent to: - it = np.nditer(iterable, flags=["multi_index"]) + it = np.nditer(iterable, flags=["multi_index", "refs_ok"]) for x in it: yield it.multi_index, x @@ -936,11 +984,12 @@ class multi_enumerate(_coconut_has_iter): return self.__class__(self.get_new_iter()) @property def is_numpy(self): - return self.iter.__class__.__module__ in _coconut.numpy_modules + return _coconut_get_base_module(self.iter) in _coconut.numpy_modules def __iter__(self): if self.is_numpy: - it = _coconut.numpy.nditer(self.iter, flags=["multi_index"]) + it = _coconut.numpy.nditer(self.iter, ["multi_index", "refs_ok"], [["readonly"]]) for x in it: + x, = x.flatten() yield it.multi_index, x else: ind = [-1] @@ -971,7 +1020,7 @@ class multi_enumerate(_coconut_has_iter): if self.is_numpy: return self.iter.size return _coconut.NotImplemented -class count(_coconut_base_hashable): +class count(_coconut_baseclass): __slots__ = ("start", "step") __doc__ = getattr(_coconut.itertools.count, "__doc__", "count(start, step) returns an infinite iterator starting at start and increasing by step.") def __init__(self, start=0, step=1): @@ -987,7 +1036,7 @@ class count(_coconut_base_hashable): if self.step: self.start += self.step def __fmap__(self, func): - return _coconut_map(func, self) + return {_coconut_}map(func, self) def __contains__(self, elem): if not self.step: return elem == self.start @@ -1006,7 +1055,7 @@ class count(_coconut_base_hashable): return self.__class__(new_start, new_step) if self.step and _coconut.isinstance(self.start, _coconut.int) and _coconut.isinstance(self.step, _coconut.int): return _coconut.range(new_start, self.start + self.step * index.stop, new_step) - return _coconut_map(self.__getitem__, _coconut.range(index.start if index.start is not None else 0, index.stop, index.step if index.step is not None else 1)) + return {_coconut_}map(self.__getitem__, _coconut.range(index.start if index.start is not None else 0, index.stop, index.step if 
index.step is not None else 1)) raise _coconut.IndexError("count() indices cannot be negative") if index < 0: raise _coconut.IndexError("count() indices cannot be negative") @@ -1059,9 +1108,9 @@ class cycle(_coconut_has_iter): raise _coconut.IndexError("cycle index out of range") return self.iter[index % _coconut.len(self.iter)] if self.times is None: - return _coconut_map(self.__getitem__, _coconut_count()[index]) + return {_coconut_}map(self.__getitem__, {_coconut_}count()[index]) else: - return _coconut_map(self.__getitem__, _coconut_range(0, _coconut.len(self))[index]) + return {_coconut_}map(self.__getitem__, {_coconut_}range(0, _coconut.len(self))[index]) def __len__(self): if self.times is None or not _coconut.isinstance(self.iter, _coconut.abc.Sized): return _coconut.NotImplemented @@ -1069,7 +1118,7 @@ class cycle(_coconut_has_iter): def __reversed__(self): if self.times is None: raise _coconut.TypeError(_coconut.repr(self) + " object is not reversible") - return self.__class__(_coconut_reversed(self.get_new_iter()), self.times) + return self.__class__({_coconut_}reversed(self.get_new_iter()), self.times) def count(self, elem): """Count the number of times elem appears in the cycle.""" return self.iter.count(elem) * (float("inf") if self.times is None else self.times) @@ -1160,7 +1209,7 @@ class groupsof(_coconut_has_iter): return (self.__class__, (self.group_size, self.iter)) def __copy__(self): return self.__class__(self.group_size, self.get_new_iter()) -class recursive_iterator(_coconut_base_hashable): +class recursive_iterator(_coconut_baseclass): """Decorator that optimizes a recursive function that returns an iterator (e.g. 
a recursive generator).""" __slots__ = ("func", "reit_store", "backup_reit_store") def __init__(self, func): @@ -1181,13 +1230,13 @@ class recursive_iterator(_coconut_base_hashable): for k, v in self.backup_reit_store: if k == key: return reit - reit = _coconut_reiterable(self.func(*args, **kwargs)) + reit = {_coconut_}reiterable(self.func(*args, **kwargs)) self.backup_reit_store.append([key, reit]) return reit else: reit = self.reit_store.get(key) if reit is None: - reit = _coconut_reiterable(self.func(*args, **kwargs)) + reit = {_coconut_}reiterable(self.func(*args, **kwargs)) self.reit_store[key] = reit return reit def __repr__(self): @@ -1198,7 +1247,7 @@ class recursive_iterator(_coconut_base_hashable): if obj is None: return self {return_method_of_self} -class _coconut_FunctionMatchErrorContext{object}: +class _coconut_FunctionMatchErrorContext(_coconut_baseclass): __slots__ = ("exc_class", "taken") threadlocal_ns = _coconut.threading.local() def __init__(self, exc_class): @@ -1206,28 +1255,26 @@ class _coconut_FunctionMatchErrorContext{object}: self.taken = False @classmethod def get_contexts(cls): - try: - return cls.threadlocal_ns.contexts - except _coconut.AttributeError: - cls.threadlocal_ns.contexts = [] - return cls.threadlocal_ns.contexts + return cls.threadlocal_ns.__dict__.setdefault("contexts", []) def __enter__(self): self.get_contexts().append(self) def __exit__(self, type, value, traceback): self.get_contexts().pop() + def __reduce__(self): + return (self.__class__, (self.exc_class,)) def _coconut_get_function_match_error(): - try: - ctx = _coconut_FunctionMatchErrorContext.get_contexts()[-1] - except _coconut.IndexError: - return _coconut_MatchError + contexts = _coconut_FunctionMatchErrorContext.get_contexts() + if not contexts: + return {_coconut_}MatchError + ctx = contexts[-1] if ctx.taken: - return _coconut_MatchError + return {_coconut_}MatchError ctx.taken = True return ctx.exc_class -class 
_coconut_base_pattern_func(_coconut_base_hashable):{COMMENT.no_slots_to_allow_func_attrs} +class _coconut_base_pattern_func(_coconut_baseclass):{COMMENT.no_slots_to_allow_func_attrs} _coconut_is_match = True def __init__(self, *funcs): - self.FunctionMatchError = _coconut.type(_coconut_py_str("MatchError"), (_coconut_MatchError,), {empty_dict}) + self.FunctionMatchError = _coconut.type(_coconut_py_str("MatchError"), ({_coconut_}MatchError,), {empty_py_dict}) self.patterns = [] self.__doc__ = None self.__name__ = None @@ -1285,7 +1332,7 @@ def addpattern(base_func, *add_funcs, **kwargs): return _coconut.functools.partial(_coconut_base_pattern_func, base_func) _coconut_addpattern = addpattern {def_prepattern} -class _coconut_partial(_coconut_base_hashable): +class _coconut_partial(_coconut_baseclass): __slots__ = ("func", "_argdict", "_arglen", "_pos_kwargs", "_stargs", "keywords") def __init__(self, _coconut_func, _coconut_argdict, _coconut_arglen, _coconut_pos_kwargs, *args, **kwargs): self.func = _coconut_func @@ -1343,7 +1390,7 @@ class _coconut_partial(_coconut_base_hashable): def consume(iterable, keep_last=0): """consume(iterable, keep_last) fully exhausts iterable and returns the last keep_last elements.""" return _coconut.collections.deque(iterable, maxlen=keep_last) -class starmap(_coconut_base_hashable, _coconut.itertools.starmap): +class starmap(_coconut_baseclass, _coconut.itertools.starmap): __slots__ = ("func", "iter") __doc__ = getattr(_coconut.itertools.starmap, "__doc__", "starmap(func, iterable) = (func(*args) for args in iterable)") def __new__(cls, function, iterable): @@ -1356,7 +1403,7 @@ class starmap(_coconut_base_hashable, _coconut.itertools.starmap): return self.__class__(self.func, _coconut_iter_getitem(self.iter, index)) return self.func(*_coconut_iter_getitem(self.iter, index)) def __reversed__(self): - return self.__class__(self.func, *_coconut_reversed(self.iter)) + return self.__class__(self.func, *{_coconut_}reversed(self.iter)) def 
__len__(self): if not _coconut.isinstance(self.iter, _coconut.abc.Sized): return _coconut.NotImplemented @@ -1366,7 +1413,7 @@ class starmap(_coconut_base_hashable, _coconut.itertools.starmap): def __reduce__(self): return (self.__class__, (self.func, self.iter)) def __copy__(self): - self.iter = _coconut_reiterable(self.iter) + self.iter = {_coconut_}reiterable(self.iter) return self.__class__(self.func, self.iter) def __iter__(self): return _coconut.iter(_coconut.itertools.starmap(self.func, self.iter)) @@ -1406,7 +1453,7 @@ class multiset(_coconut.collections.Counter{comma_object}): if result < 0: raise _coconut.ValueError("multiset has negative count for " + _coconut.repr(item)) return result -{def_total_and_comparisons}_coconut.abc.MutableSet.register(multiset) +{def_total_and_comparisons}{assign_multiset_views}_coconut.abc.MutableSet.register(multiset) def _coconut_base_makedata(data_type, args): if _coconut.hasattr(data_type, "_make") and _coconut.issubclass(data_type, _coconut.tuple): return data_type._make(args) @@ -1438,10 +1485,15 @@ def fmap(func, obj, **kwargs): else: if result is not _coconut.NotImplemented: return result - if obj.__class__.__module__ in _coconut.jax_numpy_modules: + obj_module = _coconut_get_base_module(obj) + if obj_module in _coconut.pandas_numpy_modules: + if obj.ndim <= 1: + return obj.apply(func) + return obj.apply(func, axis=obj.ndim-1) + if obj_module in _coconut.jax_numpy_modules: import jax.numpy as jnp return jnp.vectorize(func)(obj) - if obj.__class__.__module__ in _coconut.numpy_modules: + if obj_module in _coconut.numpy_modules: return _coconut.numpy.vectorize(func)(obj) obj_aiter = _coconut.getattr(obj, "__aiter__", None) if obj_aiter is not None and _coconut_amap is not None: @@ -1453,9 +1505,9 @@ def fmap(func, obj, **kwargs): if aiter is not _coconut.NotImplemented: return _coconut_amap(func, aiter) if starmap_over_mappings: - return _coconut_base_makedata(obj.__class__, _coconut_starmap(func, obj.items()) if 
_coconut.isinstance(obj, _coconut.abc.Mapping) else _coconut_map(func, obj)) + return _coconut_base_makedata(obj.__class__, {_coconut_}starmap(func, obj.items()) if _coconut.isinstance(obj, _coconut.abc.Mapping) else {_coconut_}map(func, obj)) else: - return _coconut_base_makedata(obj.__class__, _coconut_map(func, obj.items() if _coconut.isinstance(obj, _coconut.abc.Mapping) else obj)) + return _coconut_base_makedata(obj.__class__, {_coconut_}map(func, obj.items() if _coconut.isinstance(obj, _coconut.abc.Mapping) else obj)) def _coconut_memoize_helper(maxsize=None, typed=False): return maxsize, typed def memoize(*args, **kwargs): @@ -1468,7 +1520,7 @@ def memoize(*args, **kwargs): maxsize, typed = _coconut_memoize_helper(*args, **kwargs) return _coconut.functools.lru_cache(maxsize, typed) {def_call_set_names} -class override(_coconut_base_hashable): +class override(_coconut_baseclass): __slots__ = ("func",) def __init__(self, func): self.func = func @@ -1489,32 +1541,6 @@ def reveal_locals(): """Special function to get MyPy to print the type of the current locals. 
At runtime, reveal_locals always returns None.""" pass -def _coconut_handle_cls_kwargs(**kwargs): - """Some code taken from six under the terms of its MIT license.""" - metaclass = kwargs.pop("metaclass", None) - if kwargs and metaclass is None: - raise _coconut.TypeError("unexpected keyword argument(s) in class definition: %r" % (kwargs,)) - def coconut_handle_cls_kwargs_wrapper(cls): - if metaclass is None: - return cls - orig_vars = cls.__dict__.copy() - slots = orig_vars.get("__slots__") - if slots is not None: - if _coconut.isinstance(slots, _coconut.str): - slots = [slots] - for slots_var in slots: - orig_vars.pop(slots_var) - orig_vars.pop("__dict__", None) - orig_vars.pop("__weakref__", None) - if _coconut.hasattr(cls, "__qualname__"): - orig_vars["__qualname__"] = cls.__qualname__ - return metaclass(cls.__name__, cls.__bases__, orig_vars, **kwargs) - return coconut_handle_cls_kwargs_wrapper -def _coconut_handle_cls_stargs(*args): - temp_names = ["_coconut_base_cls_%s" % (i,) for i in _coconut.range(_coconut.len(args))] - ns = _coconut.dict(_coconut.zip(temp_names, args)) - _coconut_exec("class _coconut_cls_stargs_base(" + ", ".join(temp_names) + "): pass", ns) - return ns["_coconut_cls_stargs_base"] def _coconut_dict_merge(*dicts, **kwargs): for_func = kwargs.pop("for_func", False) assert not kwargs, "error with internal Coconut function _coconut_dict_merge {report_this_text}" @@ -1556,9 +1582,9 @@ def safe_call(_coconut_f{comma_slash}, *args, **kwargs): return Expected(error=err) """ try: - return _coconut_Expected(_coconut_f(*args, **kwargs)) + return {_coconut_}Expected(_coconut_f(*args, **kwargs)) except _coconut.Exception as err: - return _coconut_Expected(error=err) + return {_coconut_}Expected(error=err) class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){comma_object}): '''Coconut's Expected built-in is a Coconut data that represents a value that may or may not be an error, similar to Haskell's Either. 
@@ -1580,6 +1606,9 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){ if not self.result `isinstance` Expected: raise TypeError("Expected.join() requires an Expected[Expected[_]]") return self.result + def map_error(self, func: BaseException -> BaseException) -> Expected[T]: + """Maps func over the error if it exists.""" + return self if self else self.__class__(error=func(self.error)) def or_else[U](self, func: BaseException -> Expected[U]) -> Expected[T | U]: """Return self if no error, otherwise return the result of evaluating func on the error.""" return self if self else func(self.error) @@ -1618,7 +1647,11 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){ def __bool__(self): return self.error is None def __fmap__(self, func): - return self if not self else self.__class__(func(self.result)) + """Maps func over the result if it exists. + + __fmap__ should be used directly only when fmap is not available (e.g. when consuming an Expected in vanilla Python). + """ + return self.__class__(func(self.result)) if self else self def and_then(self, func): """Maps a T -> Expected[U] over an Expected[T] to produce an Expected[U]. Implements a monadic bind. Equivalent to fmap ..> .join().""" @@ -1627,15 +1660,18 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){ """Monadic join. 
Converts Expected[Expected[T]] to Expected[T].""" if not self: return self - if not _coconut.isinstance(self.result, _coconut_Expected): + if not _coconut.isinstance(self.result, {_coconut_}Expected): raise _coconut.TypeError("Expected.join() requires an Expected[Expected[_]]") return self.result + def map_error(self, func): + """Maps func over the error if it exists.""" + return self if self else self.__class__(error=func(self.error)) def or_else(self, func): """Return self if no error, otherwise return the result of evaluating func on the error.""" if self: return self got = func(self.error) - if not _coconut.isinstance(got, _coconut_Expected): + if not _coconut.isinstance(got, {_coconut_}Expected): raise _coconut.TypeError("Expected.or_else() requires a function that returns an Expected") return got def result_or(self, default): @@ -1649,7 +1685,7 @@ class Expected(_coconut.collections.namedtuple("Expected", ("result", "error")){ if not self: raise self.error return self.result -class flip(_coconut_base_hashable): +class flip(_coconut_baseclass): """Given a function, return a new function with inverse argument order. 
If nargs is passed, only the first nargs arguments are reversed.""" __slots__ = ("func", "nargs") @@ -1671,7 +1707,7 @@ class flip(_coconut_base_hashable): return self.func(*(args[self.nargs-1::-1] + args[self.nargs:]), **kwargs) def __repr__(self): return "flip(%r%s)" % (self.func, "" if self.nargs is None else ", " + _coconut.repr(self.nargs)) -class const(_coconut_base_hashable): +class const(_coconut_baseclass): """Create a function that, whatever its arguments, just returns the given value.""" __slots__ = ("value",) def __init__(self, value): @@ -1682,7 +1718,7 @@ class const(_coconut_base_hashable): return self.value def __repr__(self): return "const(%s)" % (_coconut.repr(self.value),) -class _coconut_lifted(_coconut_base_hashable): +class _coconut_lifted(_coconut_baseclass): __slots__ = ("func", "func_args", "func_kwargs") def __init__(self, _coconut_func, *func_args, **func_kwargs): self.func = _coconut_func @@ -1691,10 +1727,10 @@ class _coconut_lifted(_coconut_base_hashable): def __reduce__(self): return (self.__class__, (self.func,) + self.func_args, {lbrace}"func_kwargs": self.func_kwargs{rbrace}) def __call__(self, *args, **kwargs): - return self.func(*(g(*args, **kwargs) for g in self.func_args), **_coconut.dict((k, h(*args, **kwargs)) for k, h in self.func_kwargs.items())) + return self.func(*(g(*args, **kwargs) for g in self.func_args), **_coconut_py_dict((k, h(*args, **kwargs)) for k, h in self.func_kwargs.items())) def __repr__(self): return "lift(%r)(%s%s)" % (self.func, ", ".join(_coconut.repr(g) for g in self.func_args), ", ".join(k + "=" + _coconut.repr(h) for k, h in self.func_kwargs.items())) -class lift(_coconut_base_hashable): +class lift(_coconut_baseclass): """Lifts a function up so that all of its arguments are functions. For a binary function f(x, y) and two unary functions g(z) and h(z), lift works as the S' combinator: @@ -1724,7 +1760,10 @@ def all_equal(iterable): Supports numpy arrays. 
Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'. """ - if iterable.__class__.__module__ in _coconut.numpy_modules: + iterable_module = _coconut_get_base_module(iterable) + if iterable_module in _coconut.numpy_modules: + if iterable_module in _coconut.pandas_numpy_modules: + iterable = iterable.to_numpy() return not _coconut.len(iterable) or (iterable == iterable[0]).all() first_item = _coconut_sentinel for item in iterable: @@ -1767,22 +1806,27 @@ def _coconut_mk_anon_namedtuple(fields, types=None, of_kwargs=None): return NT return NT(**of_kwargs) def _coconut_ndim(arr): - if (arr.__class__.__module__ in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): + if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): return arr.ndim - if not _coconut.isinstance(arr, _coconut.abc.Sequence): + if not _coconut.isinstance(arr, _coconut.abc.Sequence) or _coconut.isinstance(arr, (_coconut.str, _coconut.bytes)): return 0 if _coconut.len(arr) == 0: return 1 arr_dim = 1 inner_arr = arr[0] + if inner_arr == arr: + return 0 while _coconut.isinstance(inner_arr, _coconut.abc.Sequence): arr_dim += 1 if _coconut.len(inner_arr) < 1: break - inner_arr = inner_arr[0] + new_inner_arr = inner_arr[0] + if new_inner_arr == inner_arr: + break + inner_arr = new_inner_arr return arr_dim def _coconut_expand_arr(arr, new_dims): - if (arr.__class__.__module__ in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "reshape"): + if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "reshape"): return arr.reshape((1,) * new_dims + arr.shape) for _ in _coconut.range(new_dims): arr = [arr] @@ -1790,17 +1834,21 @@ def _coconut_expand_arr(arr, new_dims): def _coconut_concatenate(arrs, axis): matconcat = 
None for a in arrs: - if a.__class__.__module__ in _coconut.jax_numpy_modules: + if _coconut.hasattr(a.__class__, "__matconcat__"): + matconcat = a.__class__.__matconcat__ + break + a_module = _coconut_get_base_module(a) + if a_module in _coconut.pandas_numpy_modules: + from pandas import concat as matconcat + break + if a_module in _coconut.jax_numpy_modules: from jax.numpy import concatenate as matconcat break - if a.__class__.__module__ in _coconut.numpy_modules: + if a_module in _coconut.numpy_modules: matconcat = _coconut.numpy.concatenate break - if _coconut.hasattr(a.__class__, "__matconcat__"): - matconcat = a.__class__.__matconcat__ - break if matconcat is not None: - return matconcat(arrs, axis) + return matconcat(arrs, axis=axis) if not axis: return _coconut.list(_coconut.itertools.chain.from_iterable(arrs)) return [_coconut_concatenate(rows, axis - 1) for rows in _coconut.zip(*arrs)] @@ -1810,5 +1858,58 @@ def _coconut_multi_dim_arr(arrs, dim): arr_dims.append(dim) max_arr_dim = _coconut.max(arr_dims) return _coconut_concatenate(arrs, max_arr_dim - dim) +def _coconut_call_or_coefficient(func, *args): + if _coconut.callable(func): + return func(*args) + if not _coconut.isinstance(func, (_coconut.int, _coconut.float, _coconut.complex)) and _coconut_get_base_module(func) not in _coconut.numpy_modules: + raise _coconut.TypeError("implicit function application and coefficient syntax only supported for Callable, int, float, complex, and numpy objects") + func = func + for x in args: + func = func * x{COMMENT.no_times_equals_to_avoid_modification} + return func +class _coconut_SupportsAdd(_coconut.typing.Protocol): + def __add__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((+) in a typing context is a Protocol)") +class _coconut_SupportsMinus(_coconut.typing.Protocol): + def __sub__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)") + def 
__neg__(self): + raise NotImplementedError("Protocol methods cannot be called at runtime ((-) in a typing context is a Protocol)") +class _coconut_SupportsMul(_coconut.typing.Protocol): + def __mul__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((*) in a typing context is a Protocol)") +class _coconut_SupportsPow(_coconut.typing.Protocol): + def __pow__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((**) in a typing context is a Protocol)") +class _coconut_SupportsTruediv(_coconut.typing.Protocol): + def __truediv__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((/) in a typing context is a Protocol)") +class _coconut_SupportsFloordiv(_coconut.typing.Protocol): + def __floordiv__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((//) in a typing context is a Protocol)") +class _coconut_SupportsMod(_coconut.typing.Protocol): + def __mod__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((%) in a typing context is a Protocol)") +class _coconut_SupportsAnd(_coconut.typing.Protocol): + def __and__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((&) in a typing context is a Protocol)") +class _coconut_SupportsXor(_coconut.typing.Protocol): + def __xor__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((^) in a typing context is a Protocol)") +class _coconut_SupportsOr(_coconut.typing.Protocol): + def __or__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((|) in a typing context is a Protocol)") +class _coconut_SupportsLshift(_coconut.typing.Protocol): + def __lshift__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((<<) in a typing context is a Protocol)") +class 
_coconut_SupportsRshift(_coconut.typing.Protocol): + def __rshift__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((>>) in a typing context is a Protocol)") +class _coconut_SupportsMatmul(_coconut.typing.Protocol): + def __matmul__(self, other): + raise NotImplementedError("Protocol methods cannot be called at runtime ((@) in a typing context is a Protocol)") +class _coconut_SupportsInv(_coconut.typing.Protocol): + def __invert__(self): + raise NotImplementedError("Protocol methods cannot be called at runtime ((~) in a typing context is a Protocol)") _coconut_self_match_types = {self_match_types} _coconut_Expected, _coconut_MatchError, _coconut_count, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_ident, _coconut_map, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_starmap, _coconut_tee, _coconut_zip, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, count, enumerate, flatten, filter, ident, map, multiset, range, reiterable, reversed, starmap, tee, zip, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index bb2edbbcb..126136543 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -93,6 +93,7 @@ indchars, comment_chars, non_syntactic_newline, + allow_explicit_keyword_vars, ) from coconut.exceptions import ( CoconutException, @@ -306,8 +307,7 @@ def postParse(self, original, loc, tokens): def add_action(item, action, make_copy=None): """Add a parse action to the given item.""" if make_copy is None: - item_ref_count = sys.getrefcount(item) if CPYTHON else float("inf") - # keep this a lambda to prevent CPython refcounting changes from breaking release builds + item_ref_count = sys.getrefcount(item) if CPYTHON and not on_new_python else float("inf") internal_assert(lambda: item_ref_count >= temp_grammar_item_ref_count, "add_action 
got item with too low ref count", (item, type(item), item_ref_count)) make_copy = item_ref_count > temp_grammar_item_ref_count if make_copy: @@ -389,8 +389,9 @@ def parsing_context(inner_parse=True): finally: if inner_parse and use_packrat_parser: ParserElement.packrat_cache = old_cache - ParserElement.packrat_cache_stats[0] += old_cache_stats[0] - ParserElement.packrat_cache_stats[1] += old_cache_stats[1] + if logger.verbose: + ParserElement.packrat_cache_stats[0] += old_cache_stats[0] + ParserElement.packrat_cache_stats[1] += old_cache_stats[1] def prep_grammar(grammar, streamline=False): @@ -403,37 +404,46 @@ def prep_grammar(grammar, streamline=False): return grammar.parseWithTabs() -def parse(grammar, text, inner=True): +def parse(grammar, text, inner=True, eval_parse_tree=True): """Parse text using grammar.""" with parsing_context(inner): - return unpack(prep_grammar(grammar).parseString(text)) + result = prep_grammar(grammar).parseString(text) + if eval_parse_tree: + result = unpack(result) + return result -def try_parse(grammar, text, inner=True): +def try_parse(grammar, text, inner=True, eval_parse_tree=True): """Attempt to parse text using grammar else None.""" try: - return parse(grammar, text, inner) + return parse(grammar, text, inner, eval_parse_tree) except ParseBaseException: return None -def all_matches(grammar, text, inner=True): +def does_parse(grammar, text, inner=True): + """Determine if text can be parsed using grammar.""" + return try_parse(grammar, text, inner, eval_parse_tree=False) + + +def all_matches(grammar, text, inner=True, eval_parse_tree=True): """Find all matches for grammar in text.""" with parsing_context(inner): for tokens, start, stop in prep_grammar(grammar).scanString(text): - yield unpack(tokens), start, stop + if eval_parse_tree: + tokens = unpack(tokens) + yield tokens, start, stop def parse_where(grammar, text, inner=True): """Determine where the first parse is.""" - with parsing_context(inner): - for tokens, start, stop 
in prep_grammar(grammar).scanString(text): - return start, stop + for tokens, start, stop in all_matches(grammar, text, inner, eval_parse_tree=False): + return start, stop return None, None def match_in(grammar, text, inner=True): - """Determine if there is a match for grammar in text.""" + """Determine if there is a match for grammar anywhere in text.""" start, stop = parse_where(grammar, text, inner) internal_assert((start is None) == (stop is None), "invalid parse_where results", (start, stop)) return start is not None @@ -452,6 +462,7 @@ def transform(grammar, text, inner=True): # TARGETS: # ----------------------------------------------------------------------------------------------------------------------- +on_new_python = False raw_sys_target = str(sys.version_info[0]) + str(sys.version_info[1]) if raw_sys_target in pseudo_targets: @@ -460,6 +471,7 @@ def transform(grammar, text, inner=True): sys_target = raw_sys_target elif sys.version_info > supported_py3_vers[-1]: sys_target = "".join(str(i) for i in supported_py3_vers[-1]) + on_new_python = True elif sys.version_info < supported_py2_vers[0]: sys_target = "".join(str(i) for i in supported_py2_vers[0]) elif sys.version_info < (3,): @@ -522,11 +534,11 @@ def get_target_info_smart(target, mode="lowest"): class Wrap(ParseElementEnhance): """PyParsing token that wraps the given item in the given context manager.""" - def __init__(self, item, wrapper, greedy=False, can_affect_parse_success=False): + def __init__(self, item, wrapper, greedy=False, include_in_packrat_context=False): super(Wrap, self).__init__(item) self.wrapper = wrapper self.greedy = greedy - self.can_affect_parse_success = can_affect_parse_success + self.include_in_packrat_context = include_in_packrat_context @property def wrapped_name(self): @@ -538,7 +550,7 @@ def wrapped_packrat_context(self): Required to allow the packrat cache to distinguish between wrapped and unwrapped parses. 
Only supported natively on cPyparsing.""" - if self.can_affect_parse_success and hasattr(self, "packrat_context"): + if self.include_in_packrat_context and hasattr(self, "packrat_context"): self.packrat_context.append(self.wrapper) try: yield @@ -555,12 +567,12 @@ def parseImpl(self, original, loc, *args, **kwargs): with logger.indent_tracing(): with self.wrapper(self, original, loc): with self.wrapped_packrat_context(): - parse_loc, evaluated_toks = super(Wrap, self).parseImpl(original, loc, *args, **kwargs) + parse_loc, tokens = super(Wrap, self).parseImpl(original, loc, *args, **kwargs) if self.greedy: - evaluated_toks = evaluate_tokens(evaluated_toks) + tokens = evaluate_tokens(tokens) if logger.tracing: # avoid the overhead of the call if not tracing - logger.log_trace(self.wrapped_name, original, loc, evaluated_toks) - return parse_loc, evaluated_toks + logger.log_trace(self.wrapped_name, original, loc, tokens) + return parse_loc, tokens def __str__(self): return self.wrapped_name @@ -575,7 +587,7 @@ def disable_inside(item, *elems, **kwargs): Returns (item with elem disabled, *new versions of elems). 
""" _invert = kwargs.pop("_invert", False) - internal_assert(not kwargs, "excess keyword arguments passed to disable_inside") + internal_assert(not kwargs, "excess keyword arguments passed to disable_inside", kwargs) level = [0] # number of wrapped items deep we are; in a list to allow modification @@ -587,7 +599,7 @@ def manage_item(self, original, loc): finally: level[0] -= 1 - yield Wrap(item, manage_item, can_affect_parse_success=True) + yield Wrap(item, manage_item, include_in_packrat_context=True) @contextmanager def manage_elem(self, original, loc): @@ -597,7 +609,7 @@ def manage_elem(self, original, loc): raise ParseException(original, loc, self.errmsg, self) for elem in elems: - yield Wrap(elem, manage_elem, can_affect_parse_success=True) + yield Wrap(elem, manage_elem, include_in_packrat_context=True) def disable_outside(item, *elems): @@ -777,21 +789,22 @@ def stores_loc_action(loc, tokens): stores_loc_action.ignore_tokens = True -stores_loc_item = attach(Empty(), stores_loc_action, make_copy=False) +always_match = Empty() +stores_loc_item = attach(always_match, stores_loc_action) def disallow_keywords(kwds, with_suffix=None): """Prevent the given kwds from matching.""" item = ~( - keyword(kwds[0], explicit_prefix=False) + base_keyword(kwds[0]) if with_suffix is None else - keyword(kwds[0], explicit_prefix=False) + with_suffix + base_keyword(kwds[0]) + with_suffix ) for k in kwds[1:]: item += ~( - keyword(k, explicit_prefix=False) + base_keyword(k) if with_suffix is None else - keyword(k, explicit_prefix=False) + with_suffix + base_keyword(k) + with_suffix ) return item @@ -801,21 +814,13 @@ def any_keyword_in(kwds): return regex_item(r"|".join(k + r"\b" for k in kwds)) -@memoize() -def keyword(name, explicit_prefix=None, require_whitespace=False): +def base_keyword(name, explicit_prefix=False, require_whitespace=False): """Construct a grammar which matches name as a Python keyword.""" - if explicit_prefix is not False: - internal_assert( - (name in 
reserved_vars) is (explicit_prefix is not None), - "invalid keyword call for", name, - extra="pass explicit_prefix to keyword for all reserved_vars and only reserved_vars", - ) - base_kwd = regex_item(name + r"\b" + (r"(?=\s)" if require_whitespace else "")) - if explicit_prefix in (None, False): - return base_kwd - else: + if explicit_prefix and name in reserved_vars + allow_explicit_keyword_vars: return combine(Optional(explicit_prefix.suppress()) + base_kwd) + else: + return base_kwd boundary = regex_item(r"\b") @@ -870,6 +875,17 @@ def any_len_perm(*optional, **kwargs): return any_len_perm_with_one_of_each_group(*groups_and_elems) +def any_len_perm_at_least_one(*elems, **kwargs): + """Any length permutation of elems that includes at least one of the elems and all the required.""" + required = kwargs.pop("required", ()) + internal_assert(not kwargs, "invalid any_len_perm kwargs", kwargs) + + groups_and_elems = [] + groups_and_elems.extend((-1, e) for e in elems) + groups_and_elems.extend(enumerate(required)) + return any_len_perm_with_one_of_each_group(*groups_and_elems) + + # ----------------------------------------------------------------------------------------------------------------------- # UTILITIES: # ----------------------------------------------------------------------------------------------------------------------- @@ -1036,15 +1052,22 @@ def should_indent(code): return last_line.endswith((":", "=", "\\")) or paren_change(last_line) < 0 -def split_leading_comment(inputstr): - """Split into leading comment and rest. 
- Comment must be at very start of string.""" - if inputstr.startswith(comment_chars): - comment_line, rest = inputstr.split("\n", 1) - comment, indent = split_trailing_indent(comment_line) - return comment + "\n", indent + rest - else: - return "", inputstr +def split_leading_comments(inputstr): + """Split into leading comments and rest.""" + comments = "" + indent, base = split_leading_indent(inputstr) + + while base.startswith(comment_chars): + comment_line, rest = base.split("\n", 1) + + got_comment, got_indent = split_trailing_indent(comment_line) + comments += got_comment + "\n" + indent += got_indent + + got_indent, base = split_leading_indent(rest) + indent += got_indent + + return comments, indent + base def split_trailing_comment(inputstr): @@ -1060,7 +1083,7 @@ def split_trailing_comment(inputstr): def split_leading_indent(inputstr, max_indents=None): """Split inputstr into leading indent and main.""" - indent = "" + indents = [] while ( (max_indents is None or max_indents > 0) and inputstr.startswith(indchars) @@ -1069,13 +1092,13 @@ def split_leading_indent(inputstr, max_indents=None): # max_indents only refers to openindents/closeindents, not all indchars if max_indents is not None and got_ind in (openindent, closeindent): max_indents -= 1 - indent += got_ind - return indent, inputstr + indents.append(got_ind) + return "".join(indents), inputstr def split_trailing_indent(inputstr, max_indents=None, handle_comments=True): """Split inputstr into leading indent and main.""" - indent = "" + indents_from_end = [] while ( (max_indents is None or max_indents > 0) and inputstr.endswith(indchars) @@ -1084,13 +1107,13 @@ def split_trailing_indent(inputstr, max_indents=None, handle_comments=True): # max_indents only refers to openindents/closeindents, not all indchars if max_indents is not None and got_ind in (openindent, closeindent): max_indents -= 1 - indent = got_ind + indent + indents_from_end.append(got_ind) if handle_comments: inputstr, comment = 
split_trailing_comment(inputstr) inputstr, inner_indent = split_trailing_indent(inputstr, max_indents, handle_comments=False) inputstr = inputstr + comment - indent = inner_indent + indent - return inputstr, indent + indents_from_end.append(inner_indent) + return inputstr, "".join(reversed(indents_from_end)) def split_leading_trailing_indent(line, max_indents=None): @@ -1269,6 +1292,8 @@ def should_trim_arity(func): func_args = get_func_args(func) except TypeError: return True + if not func_args: + return True if func_args[0] == "self": func_args.pop(0) if func_args[:3] == ["original", "loc", "tokens"]: diff --git a/coconut/constants.py b/coconut/constants.py index 5093071bb..e42f8a8cb 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -79,7 +79,7 @@ def get_bool_env_var(env_var, default=False): IPY = ( ((PY2 and not PY26) or PY35) and not (PYPY and WINDOWS) - and not (PY311 and not WINDOWS) + and (PY37 or not PYPY) ) MYPY = ( PY37 @@ -113,6 +113,8 @@ def get_bool_env_var(env_var, default=False): varchars = string.ascii_letters + string.digits + "_" +use_computation_graph_env_var = "COCONUT_USE_COMPUTATION_GRAPH" + # ----------------------------------------------------------------------------------------------------------------------- # COMPILER CONSTANTS: # ----------------------------------------------------------------------------------------------------------------------- @@ -125,19 +127,26 @@ def get_bool_env_var(env_var, default=False): temp_grammar_item_ref_count = 3 if PY311 else 5 minimum_recursion_limit = 128 -default_recursion_limit = 2090 +# shouldn't be raised any higher to avoid stack overflows +default_recursion_limit = 1920 if sys.getrecursionlimit() < default_recursion_limit: sys.setrecursionlimit(default_recursion_limit) # modules that numpy-like arrays can live in +pandas_numpy_modules = ( + "pandas", +) jax_numpy_modules = ( - "jaxlib.xla_extension", + "jaxlib", ) numpy_modules = ( "numpy", - "pandas", -) + jax_numpy_modules + 
"torch", +) + ( + pandas_numpy_modules + + jax_numpy_modules +) legal_indent_chars = " \t" # the only Python-legal indent chars @@ -183,6 +192,7 @@ def get_bool_env_var(env_var, default=False): "26": "2", "32": "3", } +assert all(v in specific_targets or v in pseudo_targets for v in ROOT_HEADER_VERSIONS) targets = ("",) + specific_targets @@ -205,6 +215,7 @@ def get_bool_env_var(env_var, default=False): data_defaults_var = reserved_prefix + "_data_defaults" # prefer Matcher.get_temp_var to proliferating more vars here +match_first_arg_var = reserved_prefix + "_match_first_arg" match_to_args_var = reserved_prefix + "_match_args" match_to_kwargs_var = reserved_prefix + "_match_kwargs" function_match_error_var = reserved_prefix + "_FunctionMatchError" @@ -248,7 +259,7 @@ def get_bool_env_var(env_var, default=False): justify_len = 79 # ideal line length # for pattern-matching -default_matcher_style = "python warn on strict" +default_matcher_style = "python warn" wildcard = "_" in_place_op_funcs = { @@ -273,6 +284,28 @@ def get_bool_env_var(env_var, default=False): "..?**>=": "_coconut_forward_none_dubstar_compose", } +op_func_protocols = { + "add": "_coconut_SupportsAdd", + "minus": "_coconut_SupportsMinus", + "mul": "_coconut_SupportsMul", + "pow": "_coconut_SupportsPow", + "truediv": "_coconut_SupportsTruediv", + "floordiv": "_coconut_SupportsFloordiv", + "mod": "_coconut_SupportsMod", + "and": "_coconut_SupportsAnd", + "xor": "_coconut_SupportsXor", + "or": "_coconut_SupportsOr", + "lshift": "_coconut_SupportsLshift", + "rshift": "_coconut_SupportsRshift", + "matmul": "_coconut_SupportsMatmul", + "inv": "_coconut_SupportsInv", +} + +allow_explicit_keyword_vars = ( + "async", + "await", +) + keyword_vars = ( "and", "as", @@ -304,7 +337,7 @@ def get_bool_env_var(env_var, default=False): "with", "yield", "nonlocal", -) +) + allow_explicit_keyword_vars const_vars = ( "True", @@ -314,8 +347,6 @@ def get_bool_env_var(env_var, default=False): # names that can be 
backslash-escaped reserved_vars = ( - "async", - "await", "data", "match", "case", @@ -325,6 +356,7 @@ def get_bool_env_var(env_var, default=False): "then", "operator", "type", + "copyclosure", "\u03bb", # lambda ) @@ -517,6 +549,7 @@ def get_bool_env_var(env_var, default=False): more_prompt = " " mypy_path_env_var = "MYPYPATH" + style_env_var = "COCONUT_STYLE" vi_mode_env_var = "COCONUT_VI_MODE" home_env_var = "COCONUT_HOME" @@ -566,16 +599,17 @@ def get_bool_env_var(env_var, default=False): # always use atomic --xxx=yyy rather than --xxx yyy coconut_run_args = ("--run", "--target=sys", "--line-numbers", "--quiet") coconut_run_verbose_args = ("--run", "--target=sys", "--line-numbers") -coconut_import_hook_args = ("--target=sys", "--line-numbers", "--quiet") +coconut_import_hook_args = ("--target=sys", "--line-numbers", "--keep-lines", "--quiet") default_mypy_args = ( "--pretty", ) verbose_mypy_args = ( + "--show-traceback", + "--show-error-context", "--warn-unused-configs", "--warn-redundant-casts", "--warn-return-any", - "--show-error-context", "--warn-incomplete-stub", ) @@ -594,6 +628,11 @@ def get_bool_env_var(env_var, default=False): oserror_retcode = 127 +kilobyte = 1024 +min_stack_size_kbs = 160 + +default_jobs = "sys" if not PY26 else 0 + mypy_install_arg = "install" mypy_builtin_regex = re.compile(r"\b(reveal_type|reveal_locals)\b") @@ -601,10 +640,13 @@ def get_bool_env_var(env_var, default=False): interpreter_uses_auto_compilation = True interpreter_uses_coconut_breakpoint = True -coconut_pth_file = os.path.join(base_dir, "command", "resources", "zcoconut.pth") +command_resources_dir = os.path.join(base_dir, "command", "resources") +coconut_pth_file = os.path.join(command_resources_dir, "zcoconut.pth") interpreter_compiler_var = "__coconut_compiler__" +jupyter_console_commands = ("console", "qtconsole") + # ----------------------------------------------------------------------------------------------------------------------- # HIGHLIGHTER CONSTANTS: # 
----------------------------------------------------------------------------------------------------------------------- @@ -656,6 +698,7 @@ def get_bool_env_var(env_var, default=False): "cycle", "windowsof", "py_chr", + "py_dict", "py_hex", "py_input", "py_int", @@ -680,6 +723,11 @@ def get_bool_env_var(env_var, default=False): "reveal_locals", ) +# builtins that must be imported from the exact right target header +must_use_specific_target_builtins = ( + "super", +) + coconut_exceptions = ( "MatchError", ) @@ -703,6 +751,7 @@ def get_bool_env_var(env_var, default=False): r"->", r"\?\??", r"<:", + r"&:", "\u2192", # -> "\\??\\*?\\*?\u21a6", # |> "\u21a4\\*?\\*?\\??", # <| @@ -760,7 +809,7 @@ def get_bool_env_var(env_var, default=False): "pyparsing", ), "non-py26": ( - "pygments", + "psutil", ), "py2": ( "futures", @@ -773,25 +822,31 @@ def get_bool_env_var(env_var, default=False): "py26": ( "argparse", ), - "jobs": ( - "psutil", + "py<39": ( + ("pygments", "mark<39"), + ), + "py39": ( + ("pygments", "mark39"), ), "kernel": ( ("ipython", "py2"), - ("ipython", "py3"), + ("ipython", "py3;py<38"), + ("ipython", "py38"), ("ipykernel", "py2"), - ("ipykernel", "py3"), - ("jupyter-client", "py2"), + ("ipykernel", "py3;py<38"), + ("ipykernel", "py38"), + ("jupyter-client", "py<35"), ("jupyter-client", "py==35"), ("jupyter-client", "py36"), - "jedi", + ("jedi", "py<39"), + ("jedi", "py39"), ("pywinpty", "py2;windows"), ), "jupyter": ( "jupyter", - ("jupyter-console", "py2"), - ("jupyter-console", "py==35"), - ("jupyter-console", "py36"), + ("jupyter-console", "py<35"), + ("jupyter-console", "py>=35;py<37"), + ("jupyter-console", "py37"), ("jupyterlab", "py35"), ("jupytext", "py3"), "papermill", @@ -800,7 +855,8 @@ def get_bool_env_var(env_var, default=False): "mypy[python2]", "types-backports", ("typing_extensions", "py==35"), - ("typing_extensions", "py36"), + ("typing_extensions", "py==36"), + ("typing_extensions", "py37"), ), "watch": ( "watchdog", @@ -814,7 +870,8 @@ def 
get_bool_env_var(env_var, default=False): ("dataclasses", "py==36"), ("typing", "py<35"), ("typing_extensions", "py==35"), - ("typing_extensions", "py36"), + ("typing_extensions", "py==36"), + ("typing_extensions", "py37"), ), "dev": ( ("pre-commit", "py3"), @@ -823,7 +880,8 @@ def get_bool_env_var(env_var, default=False): ), "docs": ( "sphinx", - "pygments", + ("pygments", "mark<39"), + ("pygments", "mark39"), "myst-parser", "pydata-sphinx-theme", ), @@ -832,13 +890,14 @@ def get_bool_env_var(env_var, default=False): "pexpect", ("numpy", "py34"), ("numpy", "py2;cpy"), + ("pandas", "py36"), ), } # min versions are inclusive min_versions = { - "cPyparsing": (2, 4, 7, 1, 2, 0), - ("pre-commit", "py3"): (2, 21), + "cPyparsing": (2, 4, 7, 1, 2, 1), + ("pre-commit", "py3"): (3,), "psutil": (5,), "jupyter": (1, 0), "types-backports": (0, 1), @@ -847,27 +906,34 @@ def get_bool_env_var(env_var, default=False): "argparse": (1, 4), "pexpect": (4,), ("trollius", "py2;cpy"): (2, 2), - "requests": (2, 28), + "requests": (2, 29), ("numpy", "py34"): (1,), ("numpy", "py2;cpy"): (1,), ("dataclasses", "py==36"): (0, 8), ("aenum", "py<34"): (3,), - "sphinx": (5, 3), - "pydata-sphinx-theme": (0, 12), - "myst-parser": (0, 18), - "mypy[python2]": (0, 991), - ("jupyter-console", "py36"): (6, 4), + "pydata-sphinx-theme": (0, 13), + "myst-parser": (1,), + "mypy[python2]": (1, 2), + ("jupyter-console", "py37"): (6,), ("typing", "py<35"): (3, 10), + ("typing_extensions", "py37"): (4, 5), + ("ipython", "py38"): (8,), + ("ipykernel", "py38"): (6,), + ("jedi", "py39"): (0, 18), + ("pygments", "mark39"): (2, 15), # pinned reqs: (must be added to pinned_reqs below) + # don't upgrade until myst-parser supports the new version + "sphinx": (6,), # don't upgrade this; it breaks on Python 3.6 + ("pandas", "py36"): (1,), ("jupyter-client", "py36"): (7, 1, 2), - ("typing_extensions", "py36"): (4, 1), + ("typing_extensions", "py==36"): (4, 1), # don't upgrade these; they break on Python 3.5 - 
("ipykernel", "py3"): (5, 5), - ("ipython", "py3"): (7, 9), - ("jupyter-console", "py==35"): (6, 1), + ("ipykernel", "py3;py<38"): (5, 5), + ("ipython", "py3;py<38"): (7, 9), + ("jupyter-console", "py>=35;py<37"): (6, 1), ("jupyter-client", "py==35"): (6, 1, 12), ("jupytext", "py3"): (1, 8), ("jupyterlab", "py35"): (2, 2), @@ -880,30 +946,32 @@ def get_bool_env_var(env_var, default=False): # don't upgrade this; it breaks on unix "vprof": (0, 36), # don't upgrade this; it breaks on Python 3.4 - "pygments": (2, 3), + ("pygments", "mark<39"): (2, 3), # don't upgrade these; they break on Python 2 - ("jupyter-client", "py2"): (5, 3), + ("jupyter-client", "py<35"): (5, 3), ("pywinpty", "py2;windows"): (0, 5), - ("jupyter-console", "py2"): (5, 2), + ("jupyter-console", "py<35"): (5, 2), ("ipython", "py2"): (5, 4), ("ipykernel", "py2"): (4, 10), ("prompt_toolkit", "mark2"): (1,), "watchdog": (0, 10), "papermill": (1, 2), # don't upgrade this; it breaks with old IPython versions - "jedi": (0, 17), + ("jedi", "py<39"): (0, 17), # Coconut requires pyparsing 2 "pyparsing": (2, 4, 7), } # should match the reqs with comments above pinned_reqs = ( + "sphinx", + ("pandas", "py36"), ("jupyter-client", "py36"), - ("typing_extensions", "py36"), - ("jupyter-client", "py2"), - ("ipykernel", "py3"), - ("ipython", "py3"), - ("jupyter-console", "py==35"), + ("typing_extensions", "py==36"), + ("jupyter-client", "py<35"), + ("ipykernel", "py3;py<38"), + ("ipython", "py3;py<38"), + ("jupyter-console", "py>=35;py<37"), ("jupyter-client", "py==35"), ("jupytext", "py3"), ("jupyterlab", "py35"), @@ -912,15 +980,15 @@ def get_bool_env_var(env_var, default=False): ("prompt_toolkit", "mark3"), "pytest", "vprof", - "pygments", + ("pygments", "mark<39"), ("pywinpty", "py2;windows"), - ("jupyter-console", "py2"), + ("jupyter-console", "py<35"), ("ipython", "py2"), ("ipykernel", "py2"), ("prompt_toolkit", "mark2"), "watchdog", "papermill", - "jedi", + ("jedi", "py<39"), "pyparsing", ) @@ -933,14 
+1001,11 @@ def get_bool_env_var(env_var, default=False): "pyparsing": _, "cPyparsing": (_, _, _), ("prompt_toolkit", "mark2"): _, - "jedi": _, + ("jedi", "py<39"): _, ("pywinpty", "py2;windows"): _, + ("ipython", "py3;py<38"): _, } -allowed_constrained_but_unpinned_reqs = ( - "cPyparsing", -) - classifiers = ( "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: Apache Software License", @@ -1107,6 +1172,8 @@ def get_bool_env_var(env_var, default=False): conda_build_env_var = "CONDA_BUILD" +disabled_xonsh_modes = ("exec", "eval") + # ----------------------------------------------------------------------------------------------------------------------- # DOCUMENTATION CONSTANTS: # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/exceptions.py b/coconut/exceptions.py index 4293a8aa9..c49429cf0 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -88,19 +88,28 @@ class CoconutException(BaseCoconutException, Exception): class CoconutSyntaxError(CoconutException): """Coconut SyntaxError.""" point_to_endpoint = False + argnames = ("message", "source", "point", "ln", "extra", "endpoint", "filename") - def __init__(self, message, source=None, point=None, ln=None, extra=None, endpoint=None): + def __init__(self, message, source=None, point=None, ln=None, extra=None, endpoint=None, filename=None): """Creates the Coconut SyntaxError.""" - self.args = (message, source, point, ln, extra, endpoint) + self.args = (message, source, point, ln, extra, endpoint, filename) - def message(self, message, source, point, ln, extra=None, endpoint=None): + @property + def kwargs(self): + """Get the arguments as keyword arguments.""" + return dict(zip(self.args, self.argnames)) + + def message(self, message, source, point, ln, extra=None, endpoint=None, filename=None): """Creates a SyntaxError-like message.""" if message is None: message = "parsing failed" if extra 
is not None: message += " (" + str(extra) + ")" if ln is not None: - message += " (line " + str(ln) + ")" + message += " (line " + str(ln) + if filename is not None: + message += " in " + repr(filename) + message += ")" if source: if point is None: for line in source.splitlines(): @@ -174,10 +183,20 @@ def message(self, message, source, point, ln, extra=None, endpoint=None): def syntax_err(self): """Creates a SyntaxError.""" - args = self.args[:2] + (None, None) + self.args[4:] - err = SyntaxError(self.message(*args)) - err.offset = args[2] - err.lineno = args[3] + kwargs = self.kwargs + if self.point_to_endpoint and "endpoint" in kwargs: + point = kwargs.pop("endpoint") + else: + point = kwargs.pop("point") + kwargs["point"] = kwargs["endpoint"] = None + ln = kwargs.pop("ln") + filename = kwargs.pop("filename", None) + + err = SyntaxError(self.message(**kwargs)) + err.offset = point + err.lineno = ln + if filename is not None: + err.filename = filename return err def set_point_to_endpoint(self, point_to_endpoint): @@ -189,25 +208,26 @@ def set_point_to_endpoint(self, point_to_endpoint): class CoconutStyleError(CoconutSyntaxError): """Coconut --strict error.""" - def __init__(self, message, source=None, point=None, ln=None, extra="remove --strict to dismiss", endpoint=None): + def __init__(self, message, source=None, point=None, ln=None, extra="remove --strict to dismiss", endpoint=None, filename=None): """Creates the --strict Coconut error.""" - self.args = (message, source, point, ln, extra, endpoint) + self.args = (message, source, point, ln, extra, endpoint, filename) class CoconutTargetError(CoconutSyntaxError): """Coconut --target error.""" + argnames = ("message", "source", "point", "ln", "target", "endpoint", "filename") - def __init__(self, message, source=None, point=None, ln=None, target=None, endpoint=None): + def __init__(self, message, source=None, point=None, ln=None, target=None, endpoint=None, filename=None): """Creates the --target Coconut 
error.""" - self.args = (message, source, point, ln, target, endpoint) + self.args = (message, source, point, ln, target, endpoint, filename) - def message(self, message, source, point, ln, target, endpoint): + def message(self, message, source, point, ln, target, endpoint, filename): """Creates the --target Coconut error message.""" if target is None: extra = None else: extra = "pass --target " + get_displayable_target(target) + " to fix" - return super(CoconutTargetError, self).message(message, source, point, ln, extra, endpoint) + return super(CoconutTargetError, self).message(message, source, point, ln, extra, endpoint, filename) class CoconutParseError(CoconutSyntaxError): diff --git a/coconut/icoconut/root.py b/coconut/icoconut/root.py index 2067673b2..45fc6f1ad 100644 --- a/coconut/icoconut/root.py +++ b/coconut/icoconut/root.py @@ -214,7 +214,7 @@ def run_cell(self, raw_cell, store_history=False, silent=False, shell_futures=Tr if asyncio is not None: @override - {async_}def run_cell_async(self, raw_cell, store_history=False, silent=False, shell_futures=True, cell_id=None, **kwargs): + {coroutine}def run_cell_async(self, raw_cell, store_history=False, silent=False, shell_futures=True, cell_id=None, **kwargs): """Version of run_cell_async that always uses shell_futures.""" # same as above return super({cls}, self).run_cell_async(raw_cell, store_history, silent, shell_futures=True, **kwargs) @@ -233,8 +233,8 @@ def user_expressions(self, expressions): format_dict = dict( dict="{}", - async_=( - "async " if PY311 else + coroutine=( + "" if PY311 else """@asyncio.coroutine """ ), diff --git a/coconut/integrations.py b/coconut/integrations.py index 77017c84f..bbed00a40 100644 --- a/coconut/integrations.py +++ b/coconut/integrations.py @@ -21,7 +21,10 @@ from types import MethodType -from coconut.constants import coconut_kernel_kwargs +from coconut.constants import ( + coconut_kernel_kwargs, + disabled_xonsh_modes, +) # 
----------------------------------------------------------------------------------------------------------------------- # IPYTHON: @@ -86,14 +89,45 @@ def magic(line, cell=None): class CoconutXontribLoader(object): """Implements Coconut's _load_xontrib_.""" - timing_info = [] + loaded = False compiler = None runner = None + timing_info = [] + + def new_parse(self, parser, code, mode="exec", *args, **kwargs): + """Coconut-aware version of xonsh's _parse.""" + if self.loaded and mode not in disabled_xonsh_modes: + # hide imports to avoid circular dependencies + from coconut.exceptions import CoconutException + from coconut.terminal import format_error + from coconut.util import get_clock_time + from coconut.terminal import logger + + parse_start_time = get_clock_time() + quiet, logger.quiet = logger.quiet, True + try: + code = self.compiler.parse_xonsh(code, keep_state=True) + except CoconutException as err: + err_str = format_error(err).splitlines()[0] + code += " #" + err_str + finally: + logger.quiet = quiet + self.timing_info.append(("parse", get_clock_time() - parse_start_time)) + return parser.__class__.parse(parser, code, mode=mode, *args, **kwargs) + + def new_try_subproc_toks(self, ctxtransformer, *args, **kwargs): + """Version of try_subproc_toks that handles the fact that Coconut + code may have different columns than Python code.""" + mode = ctxtransformer.mode + if self.loaded: + ctxtransformer.mode = "eval" + try: + return ctxtransformer.__class__.try_subproc_toks(ctxtransformer, *args, **kwargs) + finally: + ctxtransformer.mode = mode def __call__(self, xsh, **kwargs): # hide imports to avoid circular dependencies - from coconut.exceptions import CoconutException - from coconut.terminal import format_error from coconut.util import get_clock_time start_time = get_clock_time() @@ -109,26 +143,28 @@ def __call__(self, xsh, **kwargs): self.runner.update_vars(xsh.ctx) - def new_parse(execer, s, *args, **kwargs): - """Coconut-aware version of xonsh's 
_parse.""" - parse_start_time = get_clock_time() - try: - s = self.compiler.parse_xonsh(s, keep_state=True) - except CoconutException as err: - err_str = format_error(err).splitlines()[0] - s += " #" + err_str - self.timing_info.append(("parse", get_clock_time() - parse_start_time)) - return execer.__class__.parse(execer, s, *args, **kwargs) - main_parser = xsh.execer.parser - main_parser.parse = MethodType(new_parse, main_parser) + main_parser.parse = MethodType(self.new_parse, main_parser) + + ctxtransformer = xsh.execer.ctxtransformer + ctx_parser = ctxtransformer.parser + ctx_parser.parse = MethodType(self.new_parse, ctx_parser) - ctx_parser = xsh.execer.ctxtransformer.parser - ctx_parser.parse = MethodType(new_parse, ctx_parser) + ctxtransformer.try_subproc_toks = MethodType(self.new_try_subproc_toks, ctxtransformer) self.timing_info.append(("load", get_clock_time() - start_time)) + self.loaded = True return self.runner.vars + def unload(self, xsh): + if not self.loaded: + # hide imports to avoid circular dependencies + from coconut.exceptions import CoconutException + raise CoconutException("attempting to unload Coconut xontrib but it was never loaded") + self.loaded = False + _load_xontrib_ = CoconutXontribLoader() + +_unload_xontrib_ = _load_xontrib_.unload diff --git a/coconut/requirements.py b/coconut/requirements.py index bbd880084..04be698d7 100644 --- a/coconut/requirements.py +++ b/coconut/requirements.py @@ -23,9 +23,9 @@ from coconut.integrations import embed from coconut.constants import ( - PYPY, CPYTHON, PY34, + PY39, IPY, MYPY, XONSH, @@ -186,7 +186,6 @@ def everything_in(req_dict): extras = { "kernel": get_reqs("kernel"), "watch": get_reqs("watch"), - "jobs": get_reqs("jobs"), "mypy": get_reqs("mypy"), "backports": get_reqs("backports"), "xonsh": get_reqs("xonsh"), @@ -205,7 +204,6 @@ def everything_in(req_dict): "tests": uniqueify_all( get_reqs("tests"), extras["backports"], - extras["jobs"] if not PYPY else [], extras["jupyter"] if IPY else 
[], extras["mypy"] if MYPY else [], extras["xonsh"] if XONSH else [], @@ -240,6 +238,8 @@ def everything_in(req_dict): extras[":python_version>='2.7'"] = get_reqs("non-py26") extras[":python_version<'3'"] = get_reqs("py2") extras[":python_version>='3'"] = get_reqs("py3") + extras[":python_version<'3.9'"] = get_reqs("py<39") + extras[":python_version>='3.9'"] = get_reqs("py39") else: # old method if PY26: @@ -250,6 +250,10 @@ def everything_in(req_dict): requirements += get_reqs("py2") else: requirements += get_reqs("py3") + if PY39: + requirements += get_reqs("py39") + else: + requirements += get_reqs("py<39") # ----------------------------------------------------------------------------------------------------------------------- # MAIN: diff --git a/coconut/root.py b/coconut/root.py index fd857f54e..48e7e69b1 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -23,12 +23,15 @@ # VERSION: # ----------------------------------------------------------------------------------------------------------------------- -VERSION = "2.2.0" +VERSION = "3.0.0" VERSION_NAME = None # False for release, int >= 1 for develop DEVELOP = False ALPHA = False # for pre releases rather than post releases +assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1" +assert DEVELOP or not ALPHA, "alpha releases are only for develop" + # ----------------------------------------------------------------------------------------------------------------------- # UTILITIES: # ----------------------------------------------------------------------------------------------------------------------- @@ -43,63 +46,23 @@ def _indent(code, by=1, tabsize=4, strip=False, newline=False, initial_newline=F # ----------------------------------------------------------------------------------------------------------------------- -# CONSTANTS: +# HEADER: # ----------------------------------------------------------------------------------------------------------------------- -assert 
isinstance(DEVELOP, int) or DEVELOP is False, "DEVELOP must be an int or False" -assert DEVELOP or not ALPHA, "alpha releases are only for develop" - -if DEVELOP: - VERSION += "-" + ("a" if ALPHA else "post") + "_dev" + str(int(DEVELOP)) -VERSION_STR = VERSION + (" [" + VERSION_NAME + "]" if VERSION_NAME else "") - -PY2 = _coconut_sys.version_info < (3,) -PY26 = _coconut_sys.version_info < (2, 7) -PY37 = _coconut_sys.version_info >= (3, 7) - -_non_py37_extras = r'''def _coconut_default_breakpointhook(*args, **kwargs): - hookname = _coconut.os.getenv("PYTHONBREAKPOINT") - if hookname != "0": - if not hookname: - hookname = "pdb.set_trace" - modname, dot, funcname = hookname.rpartition(".") - if not dot: - modname = "builtins" if _coconut_sys.version_info >= (3,) else "__builtin__" - if _coconut_sys.version_info >= (2, 7): - import importlib - module = importlib.import_module(modname) - else: - import imp - module = imp.load_module(modname, *imp.find_module(modname)) - hook = _coconut.getattr(module, funcname) - return hook(*args, **kwargs) -if not hasattr(_coconut_sys, "__breakpointhook__"): - _coconut_sys.__breakpointhook__ = _coconut_default_breakpointhook -def breakpoint(*args, **kwargs): - return _coconut.getattr(_coconut_sys, "breakpointhook", _coconut_default_breakpointhook)(*args, **kwargs) -''' - # if a new assignment is added below, a new builtins import should be added alongside it -_base_py3_header = r'''from builtins import chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr -py_chr, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr -_coconut_py_str, _coconut_py_super = str, super +_base_py3_header = r'''from builtins import chr, dict, hex, input, int, map, object, oct, open, print, 
range, str, super, zip, filter, reversed, enumerate, repr +py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr +_coconut_py_str, _coconut_py_super, _coconut_py_dict = str, super, dict from functools import wraps as _coconut_wraps exec("_coconut_exec = exec") ''' -PY37_HEADER = _base_py3_header + r'''py_breakpoint = breakpoint -''' - -PY3_HEADER = _base_py3_header + r'''if _coconut_sys.version_info < (3, 7): -''' + _indent(_non_py37_extras) + r'''else: - py_breakpoint = breakpoint -''' - # if a new assignment is added below, a new builtins import should be added alongside it -PY27_HEADER = r'''from __builtin__ import chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long -py_chr, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = chr, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr -_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr = raw_input, xrange, int, long, print, str, super, unicode, repr +_base_py2_header = r'''from __builtin__ import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long +py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = chr, dict, hex, input, int, map, object, oct, open, 
print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr +_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict = raw_input, xrange, int, long, print, str, super, unicode, repr, dict from functools import wraps as _coconut_wraps +from collections import Sequence as _coconut_Sequence from future_builtins import * chr, str = unichr, unicode from io import open @@ -189,7 +152,6 @@ def __copy__(self): return self.__class__(*self._args) def __eq__(self, other): return self.__class__ is other.__class__ and self._args == other._args -from collections import Sequence as _coconut_Sequence _coconut_Sequence.register(range) @_coconut_wraps(_coconut_py_print) def print(*args, **kwargs): @@ -236,9 +198,63 @@ def _coconut_exec(obj, globals=None, locals=None): if globals is None: globals = _coconut_sys._getframe(1).f_globals exec(obj, globals, locals) -''' + _non_py37_extras +''' + +_non_py37_extras = r'''def _coconut_default_breakpointhook(*args, **kwargs): + hookname = _coconut.os.getenv("PYTHONBREAKPOINT") + if hookname != "0": + if not hookname: + hookname = "pdb.set_trace" + modname, dot, funcname = hookname.rpartition(".") + if not dot: + modname = "builtins" if _coconut_sys.version_info >= (3,) else "__builtin__" + if _coconut_sys.version_info >= (2, 7): + import importlib + module = importlib.import_module(modname) + else: + import imp + module = imp.load_module(modname, *imp.find_module(modname)) + hook = _coconut.getattr(module, funcname) + return hook(*args, **kwargs) +if not hasattr(_coconut_sys, "__breakpointhook__"): + _coconut_sys.__breakpointhook__ = _coconut_default_breakpointhook +def breakpoint(*args, **kwargs): + return _coconut.getattr(_coconut_sys, "breakpointhook", _coconut_default_breakpointhook)(*args, **kwargs) +''' + +_finish_dict_def = ''' + def __or__(self, other): + out = self.copy() + 
out.update(other) + return out + def __ror__(self, other): + out = self.__class__(other) + out.update(self) + return out + def __ior__(self, other): + self.update(other) + return self +class _coconut_dict_meta(type): + def __instancecheck__(cls, inst): + return _coconut.isinstance(inst, _coconut_py_dict) + def __subclasscheck__(cls, subcls): + return _coconut.issubclass(subcls, _coconut_py_dict) +dict = _coconut_dict_meta(py_str("dict"), _coconut_dict_base.__bases__, _coconut_dict_base.__dict__.copy()) +''' -PY2_HEADER = PY27_HEADER + '''if _coconut_sys.version_info < (2, 7): +_below_py37_extras = '''from collections import OrderedDict as _coconut_OrderedDict +class _coconut_dict_base(_coconut_OrderedDict): + __slots__ = () + __doc__ = getattr(_coconut_OrderedDict, "__doc__", "") + __eq__ = _coconut_py_dict.__eq__ + def __repr__(self): + return "{" + ", ".join("{k!r}: {v!r}".format(k=k, v=v) for k, v in self.items()) + "}"''' + _finish_dict_def + +_py37_py38_extras = '''class _coconut_dict_base(_coconut_py_dict): + __slots__ = () + __doc__ = getattr(_coconut_py_dict, "__doc__", "")''' + _finish_dict_def + +_py26_extras = '''if _coconut_sys.version_info < (2, 7): import functools as _coconut_functools, copy_reg as _coconut_copy_reg def _coconut_new_partial(func, args, keywords): return _coconut_functools.partial(func, *(args if args is not None else ()), **(keywords if keywords is not None else {})) @@ -248,9 +264,78 @@ def _coconut_reduce_partial(self): _coconut_copy_reg.pickle(_coconut_functools.partial, _coconut_reduce_partial) ''' -PYCHECK_HEADER = r'''if _coconut_sys.version_info < (3,): -''' + _indent(PY2_HEADER) + '''else: -''' + _indent(PY3_HEADER) + +# whenever new versions are added here, header.py must be updated to use them +ROOT_HEADER_VERSIONS = ( + "universal", + "2", + "3", + "27", + "37", + "39", +) + + +def _get_root_header(version="universal"): + assert version in ROOT_HEADER_VERSIONS, version + + if version == "universal": + return r'''if 
_coconut_sys.version_info < (3,): +''' + _indent(_get_root_header("2")) + '''else: +''' + _indent(_get_root_header("3")) + + header = "" + + if version.startswith("3"): + header += _base_py3_header + else: + assert version.startswith("2"), version + # if a new assignment is added below, a new builtins import should be added alongside it + header += _base_py2_header + + if version in ("37", "39"): + header += r'''py_breakpoint = breakpoint +''' + elif version == "3": + header += r'''if _coconut_sys.version_info < (3, 7): +''' + _indent(_non_py37_extras) + r'''else: + py_breakpoint = breakpoint +''' + else: + assert version.startswith("2"), version + header += _non_py37_extras + if version == "2": + header += _py26_extras + + if version == "3": + header += r'''if _coconut_sys.version_info < (3, 7): +''' + _indent(_below_py37_extras) + r'''elif _coconut_sys.version_info < (3, 9): +''' + _indent(_py37_py38_extras) + elif version == "37": + header += r'''if _coconut_sys.version_info < (3, 9): +''' + _indent(_py37_py38_extras) + elif version.startswith("2"): + header += _below_py37_extras + '''dict.keys = _coconut_OrderedDict.viewkeys +dict.values = _coconut_OrderedDict.viewvalues +dict.items = _coconut_OrderedDict.viewitems +''' + else: + assert version == "39", version + + return header + + +# ----------------------------------------------------------------------------------------------------------------------- +# CONSTANTS: +# ----------------------------------------------------------------------------------------------------------------------- + +if DEVELOP: + VERSION += "-" + ("a" if ALPHA else "post") + "_dev" + str(int(DEVELOP)) +VERSION_STR = VERSION + (" [" + VERSION_NAME + "]" if VERSION_NAME else "") + +PY2 = _coconut_sys.version_info < (3,) +PY26 = _coconut_sys.version_info < (2, 7) +PY37 = _coconut_sys.version_info >= (3, 7) # ----------------------------------------------------------------------------------------------------------------------- # SETUP: @@ 
-267,11 +352,4 @@ def _coconut_reduce_partial(self): import os _coconut.os = os -if PY26: - exec(PY2_HEADER) -elif PY2: - exec(PY27_HEADER) -elif PY37: - exec(PY37_HEADER) -else: - exec(PY3_HEADER) +exec(_get_root_header()) diff --git a/coconut/terminal.py b/coconut/terminal.py index 200edbf76..bdb92196e 100644 --- a/coconut/terminal.py +++ b/coconut/terminal.py @@ -24,12 +24,15 @@ import traceback import logging from contextlib import contextmanager +from collections import defaultdict +from functools import wraps if sys.version_info < (2, 7): from StringIO import StringIO else: from io import StringIO from coconut._pyparsing import ( + MODERN_PYPARSING, lineno, col, ParserElement, @@ -99,7 +102,7 @@ def complain(error): error = error() else: return - if not isinstance(error, CoconutInternalException) and isinstance(error, CoconutException): + if not isinstance(error, BaseException) or (not isinstance(error, CoconutInternalException) and isinstance(error, CoconutException)): error = CoconutInternalException(str(error)) if not DEVELOP: logger.warn_err(error) @@ -196,7 +199,7 @@ def enable_colors(cls): # necessary to resolve https://bugs.python.org/issue40134 try: os.system("") - except Exception: + except BaseException: logger.log_exc() cls.colors_enabled = True @@ -212,7 +215,18 @@ def copy(self): """Make a copy of the logger.""" return Logger(self) - def display(self, messages, sig="", end="\n", file=None, level="normal", color=None, **kwargs): + def display( + self, + messages, + sig="", + end="\n", + file=None, + level="normal", + color=None, + # flush by default to ensure our messages show up when printing from a child process + flush=True, + **kwargs + ): """Prints an iterator of messages.""" if level == "normal": file = file or sys.stdout @@ -248,8 +262,9 @@ def display(self, messages, sig="", end="\n", file=None, level="normal", color=N components.append(end) full_message = "".join(components) - # we use end="" to ensure atomic printing (and so we add the 
end in earlier) - print(full_message, file=file, end="", **kwargs) + if full_message: + # we use end="" to ensure atomic printing (and so we add the end in earlier) + print(full_message, file=file, end="", flush=flush, **kwargs) def print(self, *messages, **kwargs): """Print messages to stdout.""" @@ -440,8 +455,9 @@ def log_trace(self, expr, original, loc, item=None, extra=None): msg = displayable(str(item)) if "{" in msg: head, middle = msg.split("{", 1) - middle, tail = middle.rsplit("}", 1) - msg = head + "{...}" + tail + if "}" in middle: + middle, tail = middle.rsplit("}", 1) + msg = head + "{...}" + tail out.append(msg) add_line_col = False elif len(item) == 1 and isinstance(item[0], str): @@ -464,7 +480,7 @@ def _trace_exc_action(self, original, loc, expr, exc): def trace(self, item): """Traces a parse element (only enabled in develop).""" - if DEVELOP: + if DEVELOP and not MODERN_PYPARSING: item.debugActions = ( None, # no start action self._trace_success_action, @@ -489,20 +505,47 @@ def gather_parsing_stats(self): else: yield + total_block_time = defaultdict(int) + + @contextmanager + def time_block(self, name): + start_time = get_clock_time() + try: + yield + finally: + elapsed_time = get_clock_time() - start_time + self.total_block_time[name] += elapsed_time + self.printlog("Time while running", name + ":", elapsed_time, "secs (total so far:", self.total_block_time[name], "secs)") + def time_func(self, func): """Decorator to print timing info for a function.""" + @wraps(func) def timed_func(*args, **kwargs): """Function timed by logger.time_func.""" if not DEVELOP or self.quiet: return func(*args, **kwargs) - start_time = get_clock_time() - try: + with self.time_block(func.__name__): return func(*args, **kwargs) - finally: - elapsed_time = get_clock_time() - start_time - self.printlog("Time while running", func.__name__ + ":", elapsed_time, "secs") return timed_func + def debug_func(self, func): + """Decorates a function to print the input/output 
behavior.""" + @wraps(func) + def printing_func(*args, **kwargs): + """Function decorated by logger.debug_func.""" + if not DEVELOP or self.quiet: + return func(*args, **kwargs) + if not kwargs: + self.printerr(func, "<*|", args) + elif not args: + self.printerr(func, "<**|", kwargs) + else: + self.printerr(func, "<<|", args, kwargs) + out = func(*args, **kwargs) + self.printerr(func, "=>", repr(out)) + return out + return printing_func + def patch_logging(self): """Patches built-in Python logging if necessary.""" if not hasattr(logging, "getLogger"): diff --git a/coconut/tests/constants_test.py b/coconut/tests/constants_test.py index f8efcb163..bb2d561c5 100644 --- a/coconut/tests/constants_test.py +++ b/coconut/tests/constants_test.py @@ -109,7 +109,9 @@ def test_imports(self): def test_reqs(self): assert set(constants.pinned_reqs) <= set(constants.min_versions), "found old pinned requirement" - assert set(constants.max_versions) <= set(constants.pinned_reqs) | set(constants.allowed_constrained_but_unpinned_reqs), "found unlisted constrained but unpinned requirements" + assert set(constants.max_versions) <= set(constants.pinned_reqs) | set(("cPyparsing",)), "found unlisted constrained but unpinned requirements" + for maxed_ver in constants.max_versions: + assert isinstance(maxed_ver, tuple) or maxed_ver in ("pyparsing", "cPyparsing"), "maxed versions must be tagged to a specific Python version" # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 83f0fb4f2..d73a33d0b 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -49,7 +49,8 @@ MYPY, PY35, PY36, - PY39, + PY38, + PY310, icoconut_default_kernel_names, icoconut_custom_kernel_name, mypy_err_infixes, @@ -60,13 +61,27 @@ auto_compilation, setup, ) + + +# 
----------------------------------------------------------------------------------------------------------------------- +# SETUP: +# ----------------------------------------------------------------------------------------------------------------------- + + auto_compilation(False) +logger.verbose = property(lambda self: True, lambda self, value: print("WARNING: ignoring attempt to set logger.verbose = {value}".format(value=value))) + +os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + + # ----------------------------------------------------------------------------------------------------------------------- # CONSTANTS: # ----------------------------------------------------------------------------------------------------------------------- -logger.verbose = property(lambda self: True, lambda self, value: print("WARNING: ignoring attempt to set logger.verbose = {value}".format(value=value))) + +default_recursion_limit = "4096" +default_stack_size = "4096" base = os.path.dirname(os.path.relpath(__file__)) src = os.path.join(base, "src") @@ -85,7 +100,8 @@ prelude_git = "https://github.com/evhub/coconut-prelude" bbopt_git = "https://github.com/evhub/bbopt.git" -coconut_snip = r"msg = ''; pmsg = print$(msg); `pmsg`" +coconut_snip = "msg = ''; pmsg = print$(msg); `pmsg`" +target_3_snip = "assert super is py_super; print('')" always_err_strs = ( "CoconutInternalException", @@ -116,6 +132,8 @@ ignore_last_lines_with = ( "DeprecationWarning: The distutils package is deprecated", "from distutils.version import LooseVersion", + ": SyntaxWarning: 'int' object is not ", + " assert_raises(", ) kernel_installation_msg = ( @@ -236,7 +254,15 @@ def call(raw_cmd, assert_output=False, check_mypy=False, check_errors=True, stde line = raw_lines[i] # ignore https://bugs.python.org/issue39098 errors - if sys.version_info < (3, 9) and line == "Error in atexit._run_exitfuncs:": + if sys.version_info < (3, 9) and ( + line == "Error in atexit._run_exitfuncs:" + or ( + line == "Traceback 
(most recent call last):" + and i + 1 < len(raw_lines) + and "concurrent/futures/process.py" in raw_lines[i + 1] + and "_python_exit" in raw_lines[i + 1] + ) + ): while True: i += 1 if i >= len(raw_lines): @@ -267,7 +293,8 @@ def call(raw_cmd, assert_output=False, check_mypy=False, check_errors=True, stde for line in lines: for errstr in always_err_strs: assert errstr not in line, "{errstr!r} in {line!r}".format(errstr=errstr, line=line) - if check_errors: + # ignore SyntaxWarnings containing assert_raises + if check_errors and "assert_raises(" not in line: assert "Traceback (most recent call last):" not in line, "Traceback in " + repr(line) assert "Exception" not in line, "Exception in " + repr(line) assert "Error" not in line, "Error in " + repr(line) @@ -300,8 +327,10 @@ def call_python(args, **kwargs): def call_coconut(args, **kwargs): """Calls Coconut.""" - if "--jobs" not in args and not PYPY and not PY26: - args = ["--jobs", "sys"] + args + if default_recursion_limit is not None and "--recursion-limit" not in args: + args = ["--recursion-limit", default_recursion_limit] + args + if default_stack_size is not None and "--stack-size" not in args: + args = ["--stack-size", default_stack_size] + args if "--mypy" in args and "check_mypy" not in kwargs: kwargs["check_mypy"] = True if PY26: @@ -642,6 +671,10 @@ class TestShell(unittest.TestCase): def test_code(self): call(["coconut", "-s", "-c", coconut_snip], assert_output=True) + if not PY2: + def test_target_3_snip(self): + call(["coconut", "-t3", "-c", target_3_snip], assert_output=True) + def test_pipe(self): call('echo ' + escape(coconut_snip) + "| coconut -s", shell=True, assert_output=True) @@ -685,6 +718,8 @@ def test_xontrib(self): p.expect("$") p.sendline("!(ls -la) |> bool") p.expect("True") + p.sendline("xontrib unload coconut") + p.expect("$") p.sendeof() if p.isalive(): p.terminate() @@ -703,6 +738,8 @@ def test_kernel_installation(self): call(["coconut", "--jupyter"], 
assert_output=kernel_installation_msg) stdout, stderr, retcode = call_output(["jupyter", "kernelspec", "list"]) stdout, stderr = "".join(stdout), "".join(stderr) + if not stdout: + stdout, stderr = stderr, "" assert not retcode and not stderr, stderr for kernel in (icoconut_custom_kernel_name,) + icoconut_default_kernel_names: assert kernel in stdout @@ -778,6 +815,14 @@ def test_no_tco(self): def test_no_wrap(self): run(["--no-wrap"]) + if get_bool_env_var("COCONUT_TEST_VERBOSE"): + def test_verbose(self): + run(["--jobs", "0", "--verbose"]) + + if get_bool_env_var("COCONUT_TEST_TRACE"): + def test_trace(self): + run(["--jobs", "0", "--trace"], check_errors=False) + # avoids a strange, unreproducable failure on appveyor if not (WINDOWS and sys.version_info[:2] == (3, 8)): def test_run(self): @@ -808,27 +853,28 @@ class TestExternal(unittest.TestCase): def test_pyprover(self): with using_path(pyprover): comp_pyprover() - run_pyprover() + if PY38: + run_pyprover() if not PYPY or PY2: def test_prelude(self): with using_path(prelude): comp_prelude() - if MYPY: + if MYPY and PY38: run_prelude() + def test_bbopt(self): + with using_path(bbopt): + comp_bbopt() + if not PYPY and PY38 and not PY310: + install_bbopt() + def test_pyston(self): with using_path(pyston): comp_pyston(["--no-tco"]) if PYPY and PY2: run_pyston() - def test_bbopt(self): - with using_path(bbopt): - comp_bbopt() - if not PYPY and (PY2 or PY36) and not PY39: - install_bbopt() - # ----------------------------------------------------------------------------------------------------------------------- # MAIN: diff --git a/coconut/tests/src/cocotest/agnostic/primary.coco b/coconut/tests/src/cocotest/agnostic/primary.coco index a74ac3fe3..8f61821a0 100644 --- a/coconut/tests/src/cocotest/agnostic/primary.coco +++ b/coconut/tests/src/cocotest/agnostic/primary.coco @@ -12,14 +12,8 @@ from math import \log10 as (log10) from importlib import reload # NOQA from enum import Enum # noqa -def assert_raises(c, exc): 
- """Test whether callable c raises an exception of type exc.""" - try: - c() - except exc: - return True - else: - raise AssertionError("%r failed to raise exception %r" % (c, exc)) +from .util import assert_raises + def primary_test() -> bool: """Basic no-dependency tests.""" @@ -103,7 +97,6 @@ def primary_test() -> bool: assert isinstance(one_line_class(), one_line_class) assert (.join)("")(["1", "2", "3"]) == "123" assert "" |> .join <| ["1","2","3"] == "123" - assert "". <| "join" <| ["1","2","3"] == "123" assert 1 |> [1,2,3][] == 2 == 1 |> [1,2,3]$[] assert 1 |> "123"[] == "2" == 1 |> "123"$[] assert (| -1, 0, |) :: range(1, 5) |> list == [-1, 0, 1, 2, 3, 4] @@ -454,7 +447,6 @@ def primary_test() -> bool: assert None?[herp].derp is None # type: ignore assert None?(derp)[herp] is None # type: ignore assert None?$(herp)(derp) is None # type: ignore - assert "a b c" == (" ". ?? "not gonna happen")("join")("abc") a: int[]? = None # type: ignore assert a is None assert range(5) |> iter |> reiterable |> .[1] == 1 @@ -1270,8 +1262,8 @@ def primary_test() -> bool: assert ys |> list == [] some_err = ValueError() - assert Expected(10) |> fmap$(.+1) == Expected(11) - assert Expected(error=some_err) |> fmap$(.+1) == Expected(error=some_err) + assert Expected(10) |> fmap$(.+1) == Expected(11) == Expected(10) |> .__fmap__(.+1) + assert Expected(error=some_err) |> fmap$(.+1) == Expected(error=some_err) == Expected(error=some_err) |> .__fmap__(.+1) res, err = Expected(10) assert (res, err) == (10, None) assert Expected("abc") @@ -1299,6 +1291,8 @@ def primary_test() -> bool: assert x == 10 Expected(error=err) = Expected(error=some_err) assert err is some_err + assert Expected(error=TypeError()).map_error(const some_err) == Expected(error=some_err) + assert Expected(10).map_error(const some_err) == Expected(10) recit = ([1,2,3] :: recit) |> map$(.+1) assert tee(recit) @@ -1416,7 +1410,7 @@ def primary_test() -> bool: assert weakref.ref(hardref)() |> list == [2, 3, 4] 
my_match_err = MatchError("my match error", 123) assert parallel_map(ident, [my_match_err]) |> list |> str == str([my_match_err]) - # repeat the same thin again now that my_match_err.str has been called + # repeat the same thing again now that my_match_err.str has been called assert parallel_map(ident, [my_match_err]) |> list |> str == str([my_match_err]) match data tuple(1, 2) in (1, 2, 3): assert False @@ -1501,4 +1495,98 @@ def primary_test() -> bool: optx <*?..= (,) optx <**?..= const None assert optx() is None + + s{} = s{1, 2} + s{*_} = s{1, 2} + s{*()} = s{} + s{*[]} = s{} + s{*s{}} = s{} + s{*f{}} = s{} + s{*m{}} = s{} + match s{*()} in s{1, 2}: + assert False + s{} = f{1, 2} + f{1} = f{1, 2} + f{1, *_} = f{1, 2} + f{1, 2, *()} = f{1, 2} + match f{} in s{}: + assert False + s{} = m{1, 1} + s{1} = m{1} + m{1, 1} = m{1, 1} + m{1} = m{1, 1} + match m{1, 1} in m{1}: + assert False + m{1, *_} = m{1, 1} + match m{1, *()} in m{1, 1}: + assert False + s{*(),} = s{} + s{1, *_,} = s{1, 2} + {**{},} = {} + m{} = collections.Counter() + match m{1, 1} in collections.Counter((1,)): + assert False + + assert_raises(() :: 1 .. 2, TypeError) + two = 2 + three = 3 + five = 5 + assert 1.0 two three ** -4 five == 2*5/3**4 + x = 10 + assert 2 x == 20 + assert 2 x**2 + 3 x == 230 + match 1 in (1,): + case True: + pass + case _: + assert False + assert two**2 three**2 == 2**2 * 3**2 + assert_raises(-> five (two + three), TypeError) + assert_raises(-> 5 (10), TypeError) + assert_raises(-> 5 [0], TypeError) + assert five ** 2 two == 50 + assert 2i x == 20i + some_str = "some" + assert_raises(-> some_str five, TypeError) + assert (not in)("a", "bcd") + assert not (not in)("a", "abc") + assert ("a" not in .)("bcd") + assert (. not in "abc")("d") + assert (is not)(1, True) + assert not (is not)(False, False) + assert (True is not .)(1) + assert (. 
is not True)(1) + a_dict = {} + a_dict[1] = 1 + a_dict[3] = 2 + a_dict[2] = 3 + assert a_dict |> str == "{1: 1, 3: 2, 2: 3}" == a_dict |> repr, a_dict + assert a_dict.keys() |> tuple == (1, 3, 2) + assert not a_dict.keys() `isinstance` list + assert not a_dict.values() `isinstance` list + assert not a_dict.items() `isinstance` list + assert len(a_dict.keys()) == len(a_dict.values()) == len(a_dict.items()) == 3 + assert {1: 1, 3: 2, 2: 3}.keys() |> tuple == (1, 3, 2) + assert {**{1: 0, 3: 0}, 2: 0}.keys() |> tuple == (1, 3, 2) == {**dict([(1, 1), (3, 2), (2, 3)])}.keys() |> tuple + assert a_dict == {1: 1, 2: 3, 3: 2} + assert {1: 1} |> str == "{1: 1}" == {1: 1} |> repr + assert py_dict `issubclass` dict + assert py_dict() `isinstance` dict + assert {5:0, 3:0, **{2:0, 6:0}, 8:0}.keys() |> tuple == (5, 3, 2, 6, 8) + a_multiset = m{1,1,2} + assert not a_multiset.keys() `isinstance` list + assert not a_multiset.values() `isinstance` list + assert not a_multiset.items() `isinstance` list + assert len(a_multiset.keys()) == len(a_multiset.values()) == len(a_multiset.items()) == 2 + assert (in)(1, [1, 2]) + assert not (1 not in .)([1, 2]) + assert not (in)([[]], []) + assert ("{a}" . 
.)("format")(a=1) == "1" + a_dict = {"a": 1, "b": 2} + a_dict |= {"a": 10, "c": 20} + assert a_dict == {"a": 10, "b": 2, "c": 20} == {"a": 1, "b": 2} | {"a": 10, "c": 20} + assert ["abc" ; "def"] == ['abc', 'def'] + assert ["abc" ;; "def"] == [['abc'], ['def']] + assert {"a":0, "b":1}$[0] == "a" + assert (|0, NotImplemented, 2|)$[1] is NotImplemented return True diff --git a/coconut/tests/src/cocotest/agnostic/specific.coco b/coconut/tests/src/cocotest/agnostic/specific.coco index e3c74ed82..128f82dcd 100644 --- a/coconut/tests/src/cocotest/agnostic/specific.coco +++ b/coconut/tests/src/cocotest/agnostic/specific.coco @@ -1,4 +1,4 @@ -from io import StringIO # type: ignore +from io import StringIO if TYPE_CHECKING: from typing import Any @@ -106,7 +106,7 @@ def py36_spec_test(tco: bool) -> bool: data D2[T <: int[]](xs: T) # type: ignore assert D2((10, 20)).xs == (10, 20) - def myid[T](x: T) -> T = x + def myid[ T ]( x : T ) -> T = x assert myid(10) == 10 def fst[T](x: T, y: T) -> T = x diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 935a34269..46c2fdd5f 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -92,7 +92,7 @@ def suite_test() -> bool: assert collatz(27) assert preop(1, 2).add() == 3 assert vector(3, 4) |> abs == 5 == vector_with_id(3, 4, 1) |> abs - assert vector(1, 2) |> ((v) -> map(v., ("x", "y"))) |> tuple == (1, 2) # type: ignore + assert vector(1, 2) |> ((v) -> map(getattr$(v), ("x", "y"))) |> tuple == (1, 2) # type: ignore assert vector(3, 1) |> vector(1, 2).transform |> ((v) -> map(v[], (0, 1))) |> tuple == (4, 3) # type: ignore assert vector(1, 2) |> vector(1, 2).__eq__ assert not vector(1, 2) |> vector(3, 4).__eq__ @@ -210,7 +210,7 @@ def suite_test() -> bool: assert not is_one([]) assert is_one([1]) assert trilen(3, 4).h == 5 == datamaker(trilen)(5).h - assert A().true() is True + assert clsA().true() is True inh_a = 
inh_A() assert inh_a.true() is True assert inh_a.inh_true1() is True @@ -437,7 +437,7 @@ def suite_test() -> bool: assert myreduce((+), (1, 2, 3)) == 6 assert recurse_n_times(10000) assert fake_recurse_n_times(10000) - a = A() + a = clsA() assert ((not)..a.true)() is False assert 10 % 4 % 3 == 2 == 10 `mod` 4 `mod` 3 assert square_times2_plus1(3) == 19 == square_times2_plus1_(3) @@ -706,7 +706,7 @@ def suite_test() -> bool: m = methtest2() assert m.inf_rec(5) == 10 == m.inf_rec_(5) assert reqs(lazy_client)$[:10] |> list == range(10) |> list == reqs(lazy_client_)$[:10] |> list - class T(A, B, *(C, D), metaclass=Meta, e=5) # type: ignore + class T(clsA, clsB, *(clsC, clsD), metaclass=Meta, e=5) # type: ignore assert T.a == 1 assert T.b == 2 assert T.c == 3 @@ -743,8 +743,8 @@ def suite_test() -> bool: assert just_it_of_int(1) |> list == [1] == just_it_of_int_(1) |> list assert must_be_int(4) == 4 == must_be_int_(4) assert typed_plus(1, 2) == 3 - (class inh_A() `isinstance` A) `isinstance` object = inh_A() - class inh_A() `isinstance` A `isinstance` object = inh_A() + (class inh_A() `isinstance` clsA) `isinstance` object = inh_A() + class inh_A() `isinstance` clsA `isinstance` object = inh_A() for maxdiff in (maxdiff1, maxdiff2, maxdiff3, maxdiff_): assert maxdiff([7,1,4,5]) == 4, "failed for " + repr(maxdiff) assert all(r == 4 for r in parallel_map(call$(?, [7,1,4,5]), [maxdiff1, maxdiff2, maxdiff3])) @@ -919,12 +919,12 @@ forward 2""") == 900 ) ), )) - A(.a=1) = A() - match A(.a=2) in A(): + clsA(.a=1) = clsA() + match clsA(.a=2) in clsA(): assert False - assert_raises((def -> A(.b=1) = A()), AttributeError) + assert_raises((def -> clsA(.b=1) = clsA()), AttributeError) assert MySubExc("derp") `isinstance` Exception - assert A().not_super() is True + assert clsA().not_super() is True match class store.A(1) = store.A(1) match data store.A(1) = store.A(1) match store.A(1) = store.A(1) @@ -1014,6 +1014,37 @@ forward 2""") == 900 assert try_divide(1, 2) |> fmap$(.+1) == 
Expected(1.5) assert sum_evens(0, 5) == 6 == sum_evens(1, 6) assert sum_evens(7, 3) == 0 == sum_evens(4, 4) + assert num_it() |> list == [5] + assert left_right_diff([10,4,8,3]) |> list == [15, 1, 11, 22] + assert num_until_neg_sum([2,-1,0,1,-3,3,-3]) == 6 + assert S((+), (.*10)) <| 2 == 22 + assert K 1 <| 2 == 1 + assert I 1 == 1 + assert KI 1 <| 2 == 2 + assert W(+) <| 3 == 6 + assert C((/), 5) <| 20 == 4 + assert B((.+1), (.*2)) <| 3 == 7 + assert B1((.*2), (+), 3) <| 4 == 14 + assert B2(((0,)+.), (,), 1, 2) <| 3 == (0, 1, 2, 3) + assert B3((.+1), (.*2), (.**2)) <| 3 == 19 + assert D((+), 5, (.+1)) <| 2 == 8 + assert Phi((,), (.+1), (.-1)) <| 5 == (6, 4) + assert Psi((,), (.+1), 3) <| 4 == (4, 5) + assert D1((,), 0, 1, (.+1)) <| 1 == (0, 1, 2) + assert D2((+), (.*2), 3, (.+1)) <| 4 == 11 + assert E((+), 10, (*), 2) <| 3 == 16 + assert Phi1((,), (+), (*), 2) <| 3 == (5, 6) + assert BE((,), (+), 10, 2, (*), 2) <| 3 == (12, 6) + assert (+) `on` (.*2) <*| (3, 5) == 16 + assert test_super_B().method({'somekey': 'string', 'someotherkey': 42}) + assert outer_func_normal() |> map$(call) |> list == [4] * 5 + for outer_func in (outer_func_1, outer_func_2, outer_func_3, outer_func_4, outer_func_5): + assert outer_func() |> map$(call) |> list == range(5) |> list + assert get_glob() == 0 + assert wrong_get_set_glob(10) == 0 + assert get_glob() == 0 + assert wrong_get_set_glob(20) == 10 + assert take_xy(xy("a", "b")) == ("a", "b") # must come at end assert fibs_calls[0] == 1 diff --git a/coconut/tests/src/cocotest/agnostic/util.coco b/coconut/tests/src/cocotest/agnostic/util.coco index bb86d3376..86aa712a8 100644 --- a/coconut/tests/src/cocotest/agnostic/util.coco +++ b/coconut/tests/src/cocotest/agnostic/util.coco @@ -225,7 +225,7 @@ operator ” (“) = (”) = (,) ..> map$(str) ..> "".join operator ! -match def (int(x))! = 0 if x else 1 +addpattern def (int(x))! = 0 if x else 1 # type: ignore addpattern def (float(x))! = 0.0 if x else 1.0 # type: ignore addpattern def x! 
if x = False # type: ignore addpattern def x! = True # type: ignore @@ -617,7 +617,7 @@ match def fact(n) = fact(n, 1) # type: ignore match addpattern def fact(0, acc) = acc # type: ignore addpattern match def fact(n, acc) = fact(n-1, acc*n) # type: ignore -def factorial(0, acc=1) = acc +addpattern def factorial(0, acc=1) = acc # type: ignore addpattern def factorial(int() as n, acc=1 if n > 0) = # type: ignore """this is a docstring""" factorial(n-1, acc*n) @@ -649,10 +649,10 @@ def classify(value): return "empty dict" else: return "dict" - match _ `isinstance` (set, frozenset) in value: - match s{} in value: + match s{*_} in value: + match s{*()} in value: return "empty set" - match {0} in value: + match {0, *()} in value: return "set of 0" return "set" raise TypeError() @@ -829,7 +829,7 @@ data trilen(h): return (a**2 + b**2)**0.5 |> datamaker(cls) # Inheritance: -class A: +class clsA: a = 1 def true(self): return True @@ -838,7 +838,7 @@ class A: return super().true() @classmethod def cls_true(cls) = True -class inh_A(A): +class inh_A(clsA): def inh_true1(self) = super().true() def inh_true2(self) = @@ -850,11 +850,11 @@ class inh_A(A): inh_true5 = def (self) -> super().true() @classmethod def inh_cls_true(cls) = super().cls_true() -class B: +class clsB: b = 2 -class C: +class clsC: c = 3 -class D: +class clsD: d = 4 class MyExc(Exception): @@ -865,6 +865,17 @@ class MySubExc(MyExc): def __init__(self, m): super().__init__(m) +class test_super_A: + @classmethod + addpattern def method(cls, {'somekey': str()}) = True + + +class test_super_B(test_super_A): + @classmethod + addpattern def method(cls, {'someotherkey': int(), **rest}) = + super().method(rest) + + # Nesting: class Nest: class B: @@ -1033,10 +1044,54 @@ class unrepresentable: # Typing if TYPE_CHECKING or sys.version_info >= (3, 5): - from typing import List, Dict, Any, cast + from typing import ( + List, + Dict, + Any, + cast, + Protocol, + TypeVar, + Generic, + ) + + T = TypeVar("T", covariant=True) 
+ U = TypeVar("U", contravariant=True) + V = TypeVar("V", covariant=True) + + class SupportsAdd(Protocol, Generic[T, U, V]): + def __add__(self: T, other: U) -> V: + raise NotImplementedError + + class SupportsMul(Protocol, Generic[T, U, V]): + def __mul__(self: T, other: U) -> V: + raise NotImplementedError + + class X(Protocol): + x: str + + class Y(Protocol): + y: str + else: def cast(typ, value) = value +obj_with_add_and_mul: SupportsAdd &: SupportsMul = 10 +an_int: ( + (+) + &: (-) + &: (*) + &: (**) + &: (/) + &: (//) + &: (%) + &: (&) + &: (^) + &: (|) + &: (<<) + &: (>>) + &: (~) +) = 10 + def args_kwargs_func(args: List[Any]=[], kwargs: Dict[Any, Any]={}) -> typing.Literal[True] = True @@ -1056,6 +1111,15 @@ def try_divide(x: float, y: float) -> Expected[float]: except Exception as err: return Expected(error=err) +class xy: + def __init__(self, x: str, y: str): + self.x: str = x + self.y: str = y + +def take_xy(xy: X &: Y) -> (str; str) = + xy.x, xy.y + + # Enhanced Pattern-Matching def fact_(0, acc=1) = acc @@ -1291,6 +1355,13 @@ def ret_globals() = abc = 1 locals() +global glob = 0 +copyclosure def wrong_get_set_glob(x): + global glob + old_glob, glob = glob, x + return old_glob +def get_glob() = glob + # Pos/kwd only args match def pos_only(a, b, /) = a, b @@ -1379,6 +1450,34 @@ yield match def just_it_of_int(int() as x): match yield def just_it_of_int_(int() as x): yield x +yield def num_it() -> int$[]: + yield 5 + + +# combinators + +def S(f, g) = lift(f)(ident, g) +K = const +I = ident +KI = const(ident) +def W(f) = lift(f)(ident, ident) +def C(f, x) = flip(f)$(x) +B = (..) +def B1(f, g, x) = f .. g$(x) +def B2(f, g, x, y) = f .. g$(x, y) +def B3(f, g, h) = f .. g .. 
h +def D(f, x, g) = lift(f)(const x, g) +def Phi(f, g, h) = lift(f)(g, h) +def Psi(f, g, x) = g ..> lift(f)(const(g x), ident) +def D1(f, x, y, g) = lift(f)(const x, const y, g) +def D2(f, g, x, h) = lift(f)(const(g x), h) +def E(f, x, g, y) = lift(f)(const x, g$(y)) +def Phi1(f, g, h, x) = lift(f)(g$(x), h$(x)) +def BE(f, g, x, y, h, z) = lift(f)(const(g x y), h$(z)) + +def on(b, u) = (,) ..> map$(u) ..*> b + + # maximum difference def maxdiff1(ns) = ( ns @@ -1389,7 +1488,6 @@ def maxdiff1(ns) = ( |> reduce$(max, ?, -1) ) -def S(binop, unop) = lift(binop)(ident, unop) def ne_zero(x) = x != 0 maxdiff2 = ( @@ -1540,7 +1638,7 @@ data End(offset `isinstance` int = 0 if offset <= 0): # type: ignore end = End() -# advent of code +# coding challenges proc_moves = ( .strip() ..> .splitlines() @@ -1658,6 +1756,33 @@ def first_disjoint_n_(n, arr) = ( |> .$[0] ) +shifted_left_sum = ( + scan$((+), ?, 0) + ..> .$[:-1] +) + +shifted_right_sum = ( + reversed + ..> shifted_left_sum + ..> list + ..> reversed +) + +left_right_diff = ( + lift(zip)(shifted_left_sum, shifted_right_sum) + ..> starmap$(-) + ..> map$(abs) +) # type: ignore + +num_until_neg_sum = ( + sorted + ..> reversed + ..> scan$(+) + ..> filter$(.>0) + ..> list + ..> len +) + # Search patterns def first_twin(_ + [p, (.-2) -> p] + _) = (p, p+2) @@ -1787,3 +1912,50 @@ data Arr(shape, arr): setind(new_arr, ind, func(getind(self.arr, ind))) return self.__class__(self.shape, new_arr) def __neg__(self) = self |> fmap$(-) + + +# copyclosure + +def outer_func_normal(): + funcs = [] + for x in range(5): + def inner_func() = x + funcs.append(inner_func) + return funcs + +def outer_func_1(): + funcs = [] + for x in range(5): + copyclosure def inner_func() = x + funcs.append(inner_func) + return funcs + +def outer_func_2(): + funcs = [] + for x in range(5): + funcs.append(copyclosure def -> x) + return funcs + +def outer_func_3(): + funcs = [] + for x in range(5): + class inner_cls + @staticmethod + copyclosure def 
inner_cls.inner_func() = x + funcs.append(inner_cls.inner_func) + return funcs + +def outer_func_4(): + funcs = [] + for x in range(5): + match def inner_func(x) = x + addpattern copyclosure def inner_func() = x + funcs.append(inner_func) + return funcs + +def outer_func_5() -> (() -> int)[]: + funcs = [] + for x in range(5): + copyclosure def inner_func() -> int = x + funcs.append(inner_func) + return funcs diff --git a/coconut/tests/src/cocotest/non_strict/non_strict_test.coco b/coconut/tests/src/cocotest/non_strict/non_strict_test.coco index 17284f1d3..099e0dad2 100644 --- a/coconut/tests/src/cocotest/non_strict/non_strict_test.coco +++ b/coconut/tests/src/cocotest/non_strict/non_strict_test.coco @@ -80,6 +80,8 @@ def non_strict_test() -> bool: assert weird_func()()(5) == 5 a_dict: TextMap[str, int] = {"a": 1} assert a_dict["a"] == 1 + assert "". <| "join" <| ["1","2","3"] == "123" + assert "a b c" == (" ". ?? "not gonna happen")("join")("abc") return True if __name__ == "__main__": diff --git a/coconut/tests/src/cocotest/target_36/py36_test.coco b/coconut/tests/src/cocotest/target_36/py36_test.coco index 422ec2934..43e420fa0 100644 --- a/coconut/tests/src/cocotest/target_36/py36_test.coco +++ b/coconut/tests/src/cocotest/target_36/py36_test.coco @@ -8,9 +8,9 @@ def py36_test() -> bool: loop = asyncio.new_event_loop() async def ayield(x) = x - async def arange(n): + :async def arange(n): for i in range(n): - yield await ayield(i) + yield :await ayield(i) async def afor_test(): # syntax 1 got = [] @@ -44,11 +44,36 @@ def py36_test() -> bool: pass l: typing.List[int] = [] async def aiter_test(): - await (range(10) |> toa |> fmap$(l.append) |> aconsume) - await (arange_(10) |> fmap$(l.append) |> aconsume) + range(10) |> toa |> fmap$(l.append) |> aconsume |> await + arange_(10) |> fmap$(l.append) |> aconsume |> await loop.run_until_complete(aiter_test()) assert l == list(range(10)) + list(range(10)) + async def arec(x) = await arec(x-1) if x else x + async def 
outer_func(): + funcs = [] + for x in range(5): + funcs.append(async copyclosure def -> x) + return funcs + async def await_all(xs) = [await x for x in xs] + async def atest(): + assert ( + 10 + |> arec + |> await + |> (.+10) + |> arec + |> await + ) == 0 + assert ( + outer_func() + |> await + |> map$(call) + |> await_all + |> await + ) == range(5) |> list + loop.run_until_complete(atest()) + loop.close() return True diff --git a/coconut/tests/src/cocotest/target_sys/target_sys_test.coco b/coconut/tests/src/cocotest/target_sys/target_sys_test.coco index 2bc5b34b3..10cf50399 100644 --- a/coconut/tests/src/cocotest/target_sys/target_sys_test.coco +++ b/coconut/tests/src/cocotest/target_sys/target_sys_test.coco @@ -61,8 +61,8 @@ def asyncio_test() -> bool: aplus1: AsyncNumFunc[int] = async def x -> x + 1 async def main(): assert await async_map_test() - assert `(+)$(1) .. await aplus 1` 1 == 3 - assert `(.+1) .. await aplus_ 1` 1 == 3 + assert `(+)$(1) .. await (aplus 1)` 1 == 3 + assert `(.+1) .. 
await (aplus_ 1)` 1 == 3 assert await (async def (x, y) -> x + y)(1, 2) == 3 assert await (async def (int(x), int(y)) -> x + y)(1, 2) == 3 assert await (async match def (int(x), int(y)) -> x + y)(1, 2) == 3 diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index f99d7be6e..a94313b5a 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -1,14 +1,15 @@ from collections.abc import Sequence -from coconut.__coconut__ import consume as coc_consume # type: ignore +from coconut.__coconut__ import consume as coc_consume from coconut.constants import ( IPY, PY2, PY34, PY35, - WINDOWS, + PY36, PYPY, ) # type: ignore +from coconut._pyparsing import USE_COMPUTATION_GRAPH # type: ignore from coconut.exceptions import ( CoconutSyntaxError, CoconutStyleError, @@ -24,9 +25,9 @@ from coconut.convenience import ( coconut_eval, ) -if IPY and not WINDOWS: +if IPY: if PY35: - import asyncio # type: ignore + import asyncio from coconut.icoconut import CoconutKernel # type: ignore else: CoconutKernel = None # type: ignore @@ -36,6 +37,10 @@ def assert_raises(c, exc, not_exc=None, err_has=None): """Test whether callable c raises an exception of type exc.""" if not_exc is None and exc is CoconutSyntaxError: not_exc = CoconutParseError + # we don't check err_has without the computation graph since errors can be quite different + if not USE_COMPUTATION_GRAPH: + err_has = None + try: c() except exc as err: @@ -46,6 +51,8 @@ def assert_raises(c, exc, not_exc=None, err_has=None): assert any(has in str(err) for has in err_has), f"{str(err)!r} does not contain any of {err_has!r}" else: assert err_has in str(err), f"{err_has!r} not in {str(err)!r}" + if exc `isinstance` CoconutSyntaxError: + assert "SyntaxError" in str(exc.syntax_err()) except BaseException as err: raise AssertionError(f"got wrong exception {err} (expected {exc})") else: @@ -102,7 +109,7 @@ def test_setup_none() -> bool: assert "Ellipsis" not in parse("x: ... 
= 1") # things that don't parse correctly without the computation graph - if not PYPY: + if USE_COMPUTATION_GRAPH: exec(parse("assert (1,2,3,4) == ([1, 2], [3, 4]) |*> def (x, y) -> *x, *y"), {}) assert_raises(-> parse("(a := b)"), CoconutTargetError) @@ -121,6 +128,7 @@ def test_setup_none() -> bool: assert_raises(-> parse("f(**x, y)"), CoconutSyntaxError) assert_raises(-> parse("def f(x) = return x"), CoconutSyntaxError) assert_raises(-> parse("def f(x) =\n return x"), CoconutSyntaxError) + assert_raises(-> parse("10 20"), CoconutSyntaxError) assert_raises(-> parse("()[(())"), CoconutSyntaxError, err_has=""" unclosed open '[' (line 1) @@ -164,6 +172,8 @@ mismatched open '[' and close ')' (line 1) assert_raises(-> parse("(.+1) .. x -> x * 2"), CoconutSyntaxError, err_has="<..") assert_raises(-> parse('f"Black holes {*all_black_holes} and revelations"'), CoconutSyntaxError, err_has="format string") assert_raises(-> parse("operator ++\noperator ++"), CoconutSyntaxError, err_has="custom operator already declared") + assert_raises(-> parse("type HasIn = (in)"), CoconutSyntaxError, err_has="not supported") + assert_raises( -> parse("type abc[T,T] = T | T"), CoconutSyntaxError, @@ -184,19 +194,21 @@ def f() = assert 2 """.strip()), CoconutParseError, err_has=( """ - assert 2 - ^ - """.strip(), - """ assert 2 ~~~~~~~~~~~~^ """.strip(), + """ + assert 2 + ^ + """.strip() )) assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has=" ~~~~~~~^") assert_raises(-> parse('"abc" b"def"'), CoconutParseError, err_has=" ~~~~~~^") + assert_raises(-> parse('"a" 10'), CoconutParseError, err_has=" ~~~~^") + assert_raises(-> parse("A. 
."), CoconutParseError, err_has=" ~~~^") - assert_raises(-> parse("return = 1"), CoconutParseError, err_has="invalid use of the keyword") + assert_raises(-> parse("return = 1"), CoconutParseError, err_has='invalid use of the keyword "return"') assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") @@ -219,12 +231,13 @@ def gam_eps_rate(bitarr) = ( except CoconutParseError as err: err_str = str(err) assert "misplaced '?'" in err_str - assert """ + if not PYPY: + assert """ |> map$(int(?, 2)) ~~~~~^""" in err_str or """ |> map$(int(?, 2)) - ~~~~~~~~~~~~~~~~~^""" in err_str + ~~~~~~~~~~~~~~~~~^""" in err_str, err_str else: assert False @@ -286,6 +299,10 @@ else: assert_raises(-> parse("""case x: match x: pass"""), CoconutStyleError, err_has="case x:") + assert_raises(-> parse("obj."), CoconutStyleError, err_has="getattr") + + setup(strict=True, target="sys") + assert_raises(-> parse("await f x"), CoconutParseError, err_has='invalid use of the keyword "await"') setup(target="2.7") assert parse("from io import BytesIO", mode="lenient") == "from io import BytesIO" @@ -371,6 +388,10 @@ def test_kernel() -> bool: def test_numpy() -> bool: import numpy as np + A = np.array([1, 2;; 3, 4]) + B = np.array([5, 6;; 7, 8]) + C = np.array([19, 22;; 43, 50]) + assert isinstance(np.array([1, 2]) |> fmap$(.+1), np.ndarray) assert np.all(fmap(-> _ + 1, np.arange(3)) == np.array([1, 2, 3])) # type: ignore assert np.array([1, 2;; 3, 4]).shape == (2, 2) @@ -391,13 +412,12 @@ def test_numpy() -> bool: assert [1;2 ;;;; 3;4] |> np.array |> .shape == (2, 1, 1, 2) assert [1,2 ;;;; 3,4] |> np.array |> .shape == (2, 1, 1, 2) assert np.array([1,2 ;; 3,4]) `np.array_equal` np.array([[1,2],[3,4]]) - a = np.array([1,2 ;; 3,4]) - assert [a ; a] `np.array_equal` np.array([1,2,1,2 ;; 3,4,3,4]) - assert [a ;; a] `np.array_equal` np.array([1,2;; 3,4;; 1,2;; 3,4]) - assert [a ;;; 
a].shape == (2, 2, 2) # type: ignore - assert np.array([1, 2;; 3, 4]) @ np.array([5, 6;; 7, 8]) `np.array_equal` np.array([19, 22;; 43, 50]) - assert np.array([1, 2;; 3, 4]) @ np.identity(2) @ np.identity(2) `np.array_equal` np.array([1, 2;; 3, 4]) - assert (@)(np.array([1, 2;; 3, 4]), np.array([5, 6;; 7, 8])) `np.array_equal` np.array([19, 22;; 43, 50]) + assert [A ; A] `np.array_equal` np.array([1,2,1,2 ;; 3,4,3,4]) + assert [A ;; A] `np.array_equal` np.array([1,2;; 3,4;; 1,2;; 3,4]) + assert [A ;;; A].shape == (2, 2, 2) # type: ignore + assert A @ B `np.array_equal` C + assert A @ np.identity(2) @ np.identity(2) `np.array_equal` A + assert (@)(A, B) `np.array_equal` C non_zero_diags = ( np.array ..> lift(,)(ident, reversed ..> np.array) @@ -410,6 +430,9 @@ def test_numpy() -> bool: assert len(enumeration) == 4 # type: ignore assert enumeration[2] == ((1, 0), 3) # type: ignore assert list(enumeration) == [((0, 0), 1), ((0, 1), 2), ((1, 0), 3), ((1, 1), 4)] + for ind, x in multi_enumerate(np.array([1, 2])): + assert ind `isinstance` tuple, (type(ind), ind) + assert x `isinstance` (np.int32, np.int64), (type(x), x) assert all_equal(np.array([])) assert all_equal(np.array([1])) assert all_equal(np.array([1, 1])) @@ -429,12 +452,53 @@ def test_numpy() -> bool: assert (flatten(np.array([1,2;;3,4])) |> list) == [1,2,3,4] assert cycle(np.array([1,2;;3,4]), 2) `isinstance` cycle assert (cycle(np.array([1,2;;3,4]), 2) |> np.asarray) `np.array_equal` np.array([1,2;;3,4;;1,2;;3,4]) + assert 10 A `np.array_equal` A * 10 + assert A 10 `np.array_equal` A * 10 # type: ignore + assert A B `np.array_equal` A * B + obj_arr = np.array([[1, "a"], [2.3, "abc"]], dtype=object) + assert obj_arr |> multi_enumerate |> map$(.[0]) |> list == [(0, 0), (0, 1), (1, 0), (1, 1)] + + # must come at end; checks no modification + assert A `np.array_equal` np.array([1, 2;; 3, 4]) + assert B `np.array_equal` np.array([5, 6;; 7, 8]) + return True + + +def test_pandas() -> bool: + import pandas as pd 
+ import numpy as np + d1 = pd.DataFrame({"nums": [1, 2, 3], "chars": ["a", "b", "c"]}) + assert d1$[0] == "nums" + assert [d1; d1].keys() |> list == ["nums", "chars"] * 2 # type: ignore + assert [d1;; d1].itertuples() |> list == [(0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c'), (0, 1, 'a'), (1, 2, 'b'), (2, 3, 'c')] # type: ignore + d2 = pd.DataFrame({"a": range(3) |> list, "b": range(1, 4) |> list}) + d3 = d2 |> fmap$(fmap$(.+1)) + assert d3["a"] |> list == range(1, 4) |> list + assert d3["b"] |> list == range(2, 5) |> list + assert multi_enumerate(d1) |> list == [((0, 0), 1), ((1, 0), 2), ((2, 0), 3), ((0, 1), 'a'), ((1, 1), 'b'), ((2, 1), 'c')] + assert not all_equal(d1) + assert not all_equal(d2) + assert cartesian_product(d1["nums"], d1["chars"]) `np.array_equal` np.array([ + 1; 'a';; + 1; 'b';; + 1; 'c';; + 2; 'a';; + 2; 'b';; + 2; 'c';; + 3; 'a';; + 3; 'b';; + 3; 'c';; + ], dtype=object) # type: ignore + d4 = d1 |> fmap$(def r -> r["nums2"] = r["nums"]*2; r) + assert (d4["nums"] * 2 == d4["nums2"]).all() return True def test_extras() -> bool: if not PYPY and (PY2 or PY34): assert test_numpy() is True + if not PYPY and PY36: + assert test_pandas() is True if CoconutKernel is not None: assert test_kernel() is True assert test_setup_none() is True diff --git a/coconut/util.py b/coconut/util.py index 2af23327e..216d0e4e3 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -30,6 +30,7 @@ from warnings import warn from types import MethodType from contextlib import contextmanager +from collections import defaultdict if sys.version_info >= (3, 2): from functools import lru_cache @@ -56,11 +57,6 @@ # ----------------------------------------------------------------------------------------------------------------------- -def printerr(*args, **kwargs): - """Prints to standard error.""" - print(*args, file=sys.stderr, **kwargs) - - def univ_open(filename, opentype="r+", encoding=None, **kwargs): """Open a file using default_encoding.""" if encoding is None: @@ -215,6 +211,14 
@@ def memoize(maxsize=None, *args, **kwargs): return lru_cache(maxsize, *args, **kwargs) +class keydefaultdict(defaultdict, object): + """Version of defaultdict that calls the factory with the key.""" + + def __missing__(self, key): + self[key] = self.default_factory(key) + return self[key] + + # ----------------------------------------------------------------------------------------------------------------------- # VERSIONING: # ----------------------------------------------------------------------------------------------------------------------- diff --git a/xontrib/coconut.py b/xontrib/coconut.py index ebc278637..b681a1f99 100644 --- a/xontrib/coconut.py +++ b/xontrib/coconut.py @@ -19,7 +19,7 @@ from coconut.root import * # NOQA -from coconut.integrations import _load_xontrib_ +from coconut.integrations import _load_xontrib_, _unload_xontrib_ # NOQA # ----------------------------------------------------------------------------------------------------------------------- # MAIN: