diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2df5155a4..c7868a2d8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: args: - --autofix - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 args: diff --git a/DOCS.md b/DOCS.md index 1355ca8fb..69cd43466 100644 --- a/DOCS.md +++ b/DOCS.md @@ -225,22 +225,22 @@ as an alias for ``` coconut --quiet --target sys --keep-lines --run --argv ``` -which will quietly compile and run ``, passing any additional arguments to the script, mimicking how the `python` command works. +which will quietly compile and run ``, passing any additional arguments to the script, mimicking how the `python` command works. To instead pass additional compilation arguments to Coconut itself (e.g. `--no-tco`), put them before the `` file. + +`coconut-run` can be used to compile and run directories rather than files, again mimicking how the `python` command works. Specifically, Coconut will compile the directory and then run the `__main__.coco` in that directory, which must exist. `coconut-run` can be used in a Unix shebang line to create a Coconut script by adding the following line to the start of your script: ```bash #!/usr/bin/env coconut-run ``` -To pass additional compilation arguments to `coconut-run` (e.g. `--no-tco`), put them before the `` file. - `coconut-run` will always enable [automatic compilation](#automatic-compilation), such that Coconut source files can be directly imported from any Coconut files run via `coconut-run`. Additionally, compilation parameters (e.g. `--no-tco`) used in `coconut-run` will be passed along and used for any auto compilation. On Python 3.4+, `coconut-run` will use a `__coconut_cache__` directory to cache the compiled Python. Note that `__coconut_cache__` will always be removed from `__file__`. #### Naming Source Files -Coconut source files should, so the compiler can recognize them, use the extension `.coco` (preferred), `.coc`, or `.coconut`. +Coconut source files should, so the compiler can recognize them, use the extension `.coco`. When Coconut compiles a `.coco` file, it will compile to another file with the same name, except with `.py` instead of `.coco`, which will hold the compiled code. @@ -248,7 +248,7 @@ If an extension other than `.py` is desired for the compiled files, then that ex #### Compilation Modes -Files compiled by the `coconut` command-line utility will vary based on compilation parameters. If an entire directory of files is compiled (which the compiler will search recursively for any folders containing `.coco`, `.coc`, or `.coconut` files), a `__coconut__.py` file will be created to house necessary functions (package mode), whereas if only a single file is compiled, that information will be stored within a header inside the file (standalone mode). Standalone mode is better for single files because it gets rid of the overhead involved in importing `__coconut__.py`, but package mode is better for large packages because it gets rid of the need to run the same Coconut header code again in every file, since it can just be imported from `__coconut__.py`. +Files compiled by the `coconut` command-line utility will vary based on compilation parameters. If an entire directory of files is compiled (which the compiler will search recursively for any folders containing `.coco` files), a `__coconut__.py` file will be created to house necessary functions (package mode), whereas if only a single file is compiled, that information will be stored within a header inside the file (standalone mode). Standalone mode is better for single files because it gets rid of the overhead involved in importing `__coconut__.py`, but package mode is better for large packages because it gets rid of the need to run the same Coconut header code again in every file, since it can just be imported from `__coconut__.py`. By default, if the `source` argument to the command-line utility is a file, it will perform standalone compilation on it, whereas if it is a directory, it will recursively search for all `.coco` files and perform package compilation on them. Thus, in most cases, the mode chosen by Coconut automatically will be the right one. But if it is very important that no additional files like `__coconut__.py` be created, for example, then the command-line utility can also be forced to use a specific mode with the `--package` (`-p`) and `--standalone` (`-a`) flags. @@ -258,6 +258,7 @@ While Coconut syntax is based off of the latest Python 3, Coconut code compiled To make Coconut built-ins universal across Python versions, Coconut makes available on any Python version built-ins that only exist in later versions, including **automatically overwriting Python 2 built-ins with their Python 3 counterparts.** Additionally, Coconut also [overwrites some Python 3 built-ins for optimization and enhancement purposes](#enhanced-built-ins). If access to the original Python versions of any overwritten built-ins is desired, the old built-ins can be retrieved by prefixing them with `py_`. Specifically, the overwritten built-ins are: +- `py_bytes` - `py_chr` - `py_dict` - `py_hex` @@ -480,14 +481,14 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all - Coconut's [multidimensional array literal and array concatenation syntax](#multidimensional-array-literalconcatenation-syntax) supports `numpy` objects, including using fast `numpy` concatenation methods if given `numpy` arrays rather than Coconut's default much slower implementation built for Python lists of lists. - Many of Coconut's built-ins include special `numpy` support, specifically: * [`fmap`](#fmap) will use [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) to map over `numpy` arrays. - * [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multi-dimensional indices in a `numpy` array. + * [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multidimensional indices in a `numpy` array. * [`cartesian_product`](#cartesian_product) can compute the Cartesian product of given `numpy` arrays as a `numpy` array. * [`all_equal`](#all_equal) allows for easily checking if all the elements in a `numpy` array are the same. - [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification). - `numpy` objects are allowed seamlessly in Coconut's [implicit coefficient syntax](#implicit-function-application-and-coefficients), allowing the use of e.g. `A B**2` shorthand for `A * B**2` when `A` and `B` are `numpy` arrays (note: **not** `A @ B**2`). - Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions). -Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `pandas`/`jax`-specific methods over `numpy` methods when given `pandas`/`jax` objects. +Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`xarray`](https://docs.xarray.dev/en/stable/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects. #### `xonsh` Support @@ -688,9 +689,10 @@ Coconut uses pipe operators for pipeline-style function application. All the ope The None-aware pipe operators here are equivalent to a [monadic bind](https://en.wikipedia.org/wiki/Monad_(functional_programming)) treating the object as a `Maybe` monad composed of either `None` or the given object. Thus, `x |?> f` is equivalent to `None if x is None else f(x)`. Note that only the object being piped, not the function being piped into, may be `None` for `None`-aware pipes. -For working with `async` functions in pipes, all non-starred pipes support piping into `await` to await the awaitable piped into them, such that `x |> await` is equivalent to `await x`. - -Additionally, all pipe operators support a lambda as the last argument, despite lambdas having a lower precedence. Thus, `a |> x => b |> c` is equivalent to `a |> (x => b |> c)`, not `a |> (x => b) |> c`. +Additionally, some special syntax constructs are only available in pipes to enable doing as many operations as possible via pipes if so desired: +* For working with `async` functions in pipes, all non-starred pipes support piping into `await` to await the awaitable piped into them, such that `x |> await` is equivalent to `await x`. +* All non-starred pipes support piping into `( := .)` (mirroring the syntax for [operator implicit partials](#implicit-partial-application)) to assign the piped in item to ``. +* All pipe operators support a lambda as the last argument, despite lambdas having a lower precedence. Thus, `a |> x => b |> c` is equivalent to `a |> (x => b |> c)`, not `a |> (x => b) |> c`. _Note: To visually spread operations across several lines, just use [parenthetical continuation](#enhanced-parenthetical-continuation)._ @@ -1147,11 +1149,17 @@ depth: 1 ### `match` -Coconut provides fully-featured, functional pattern-matching through its `match` statements. +Coconut provides fully-featured, functional pattern-matching through its `match` statements. Coconut `match` syntax is a strict superset of [Python's `match` syntax](https://peps.python.org/pep-0636/). + +_Note: In describing Coconut's pattern-matching syntax, this section focuses on `match` statements, but Coconut's pattern-matching can also be used in many other places, such as [pattern-matching function definition](#pattern-matching-functions), [`case` statements](#case), [destructuring assignment](#destructuring-assignment), [`match data`](#match-data), and [`match for`](#match-for)._ ##### Overview -Match statements follow the basic syntax `match in `. The match statement will attempt to match the value against the pattern, and if successful, bind any variables in the pattern to whatever is in the same position in the value, and execute the code below the match statement. Match statements also support, in their basic syntax, an `if ` that will check the condition after executing the match before executing the code below, and an `else` statement afterwards that will only be executed if the `match` statement is not. What is allowed in the match statement's pattern has no equivalent in Python, and thus the specifications below are provided to explain it. +Match statements follow the basic syntax `match in `. The match statement will attempt to match the value against the pattern, and if successful, bind any variables in the pattern to whatever is in the same position in the value, and execute the code below the match statement. + +Match statements also support, in their basic syntax, an `if ` that will check the condition after executing the match before executing the code below, and an `else` statement afterwards that will only be executed if the `match` statement is not. + +All pattern-matching in Coconut is atomic, such that no assignments will be executed unless the whole match succeeds. ##### Syntax Specification @@ -1726,6 +1734,8 @@ If the last `statement` (not followed by a semicolon) in a statement lambda is a Statement lambdas also support implicit lambda syntax such that `def => _` is equivalent to `def (_=None) => _` as well as explicitly marking them as pattern-matching such that `match def (x) => x` will be a pattern-matching function. +Additionally, statement lambdas have slightly different scoping rules than normal lambdas. When a statement lambda is inside of an expression with an expression-local variable, such as a normal lambda or comprehension, the statement lambda will capture the value of the variable at the time that the statement lambda is defined (rather than a reference to the overall namespace as with normal lambdas). As a result, while `[=> y for y in range(2)] |> map$(call) |> list` is `[1, 1]`, `[def => y for y in range(2)] |> map$(call) |> list` is `[0, 1]`. Note that this only works for expression-local variables: to copy the entire namespace at the time of function definition, use [`copyclosure`](#copyclosure-functions) (which can be used with statement lambdas). + Note that statement lambdas have a lower precedence than normal lambdas and thus capture things like trailing commas. To avoid confusion, statement lambdas should always be wrapped in their own set of parentheses. _Deprecated: Statement lambdas also support `->` instead of `=>`. Note that when using `->`, any lambdas in the body of the statement lambda must also use `->` rather than `=>`._ @@ -1764,6 +1774,8 @@ _Deprecated: if the deprecated `->` is used in place of `=>`, then return type a Coconut uses a simple operator function short-hand: surround an operator with parentheses to retrieve its function. Similarly to iterator comprehensions, if the operator function is the only argument to a function, the parentheses of the function call can also serve as the parentheses for the operator function. +All operator functions also support [implicit partial application](#implicit-partial-application), e.g. `(. + 1)` is equivalent to `(=> _ + 1)`. + ##### Rationale A very common thing to do in functional programming is to make use of function versions of built-in operators: currying them, composing them, and piping them. To make this easy, Coconut provides a short-hand syntax to access operator functions. @@ -1822,6 +1834,10 @@ A very common thing to do in functional programming is to make use of function v (not in) => # negative containment (assert) => def (cond, msg=None) => assert cond, msg # (but a better msg if msg is None) (raise) => def (exc=None, from_exc=None) => raise exc from from_exc # or just raise if exc is None +# operator functions for multidimensional array concatenation use brackets: +[;] => def (x, y) => [x; y] +[;;] => def (x, y) => [x;; y] +... # and so on for any number of semicolons # there are two operator functions that don't require parentheses: .[] => (operator.getitem) .$[] => # iterator slicing operator @@ -1849,25 +1865,34 @@ print(list(map(operator.add, range(0, 5), range(5, 10)))) Coconut supports a number of different syntactical aliases for common partial application use cases. These are: ```coconut -.attr => operator.attrgetter("attr") -.method(args) => operator.methodcaller("method", args) -func$ => ($)$(func) -seq[] => operator.getitem$(seq) -iter$[] => # the equivalent of seq[] for iterators -.[a:b:c] => operator.itemgetter(slice(a, b, c)) -.$[a:b:c] => # the equivalent of .[a:b:c] for iterators -``` +# attribute access and method calling +.attr1.attr2 => operator.attrgetter("attr1.attr2") +.method(args) => operator.methodcaller("method", args) +.attr.method(args) => .attr ..> .method(args) -Additionally, `.attr.method(args)`, `.[x][y]`, `.$[x]$[y]`, and `.method[x]` are also supported. +# indexing +.[a:b:c] => operator.itemgetter(slice(a, b, c)) +.[x][y] => .[x] ..> .[y] +.method[x] => .method ..> .[x] +seq[] => operator.getitem$(seq) + +# iterator indexing +.$[a:b:c] => # the equivalent of .[a:b:c] for iterators +.$[x]$[y] => .$[x] ..> .$[y] +iter$[] => # the equivalent of seq[] for iterators + +# currying +func$ => ($)$(func) +``` In addition, for every Coconut [operator function](#operator-functions), Coconut supports syntax for implicitly partially applying that operator function as ``` (. ) ( .) ``` -where `` is the operator function and `` is any expression. Note that, as with operator functions themselves, the parentheses are necessary for this type of implicit partial application. +where `` is the operator function and `` is any expression. Note that, as with operator functions themselves, the parentheses are necessary for this type of implicit partial application. This syntax is slightly different for multidimensional array concatenation operator functions, which use brackets instead of parentheses. -Additionally, Coconut also supports implicit operator function partials for arbitrary functions as +Furthermore, Coconut also supports implicit operator function partials for arbitrary functions as ``` (. `` ) ( `` .) @@ -2067,6 +2092,8 @@ If multiple different concatenation operators are used, the operators with the l [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] ``` +_Note: the [operator functions](#operator-functions) for multidimensional array concatenation are spelled `[;]`, `[;;]`, etc. (with any number of parentheses). The [implicit partials](#implicit-partial-application) are similarly spelled `[. ; x]`, `[x ; .]`, etc._ + ##### Comparison to Julia Coconut's multidimensional array syntax is based on that of [Julia](https://docs.julialang.org/en/v1/manual/arrays/#man-array-literals). The primary difference between Coconut's syntax and Julia's syntax is that multidimensional arrays are row-first in Coconut (following `numpy`), but column-first in Julia. Thus, `;` is vertical concatenation in Julia but **horizontal concatenation** in Coconut and `;;` is horizontal concatenation in Julia but **vertical concatenation** in Coconut. @@ -2469,11 +2496,11 @@ where `` is defined as ``` where `` is the name of the function, `` is an optional additional check, `` is the body of the function, `` is defined by Coconut's [`match` statement](#match), `` is the optional default if no argument is passed, and `` is the optional return type annotation (note that argument type annotations are not supported for pattern-matching functions). The `match` keyword at the beginning is optional, but is sometimes necessary to disambiguate pattern-matching function definition from normal function definition, since Python function definition will always take precedence. Note that the `async` and `match` keywords can be in any order. -If `` has a variable name (either directly or with `as`), the resulting pattern-matching function will support keyword arguments using that variable name. +If `` has a variable name (via any variable binding that binds the entire pattern, e.g. `x` in `int(x)` or `[a, b] as x`), the resulting pattern-matching function will support keyword arguments using that variable name. In addition to supporting pattern-matching in their arguments, pattern-matching function definitions also have a couple of notable differences compared to Python functions. Specifically: - If pattern-matching function definition fails, it will raise a [`MatchError`](#matcherror) (just like [destructuring assignment](#destructuring-assignment)) instead of a `TypeError`. -- All defaults in pattern-matching function definition are late-bound rather than early-bound. Thus, `match def f(xs=[]) = xs` will instantiate a new list for each call where `xs` is not given, unlike `def f(xs=[]) = xs`, which will use the same list for all calls where `xs` is unspecified. +- All defaults in pattern-matching function definition are late-bound rather than early-bound. Thus, `match def f(xs=[]) = xs` will instantiate a new list for each call where `xs` is not given, unlike `def f(xs=[]) = xs`, which will use the same list for all calls where `xs` is unspecified. This also allows defaults for later arguments to be specified in terms of matched values from earlier arguments, as in `match def f(x, y=x) = (x, y)`. Pattern-matching function definition can also be combined with `async` functions, [`copyclosure` functions](#copyclosure-functions), [`yield` functions](#explicit-generators), [infix function definition](#infix-functions), and [assignment function syntax](#assignment-functions). The various keywords in front of the `def` can be put in any order. @@ -3364,16 +3391,10 @@ _Can't be done without a series of method definitions for each data type. See th In Haskell, `fmap(func, obj)` takes a data type `obj` and returns a new data type with `func` mapped over the contents. Coconut's `fmap` function does the exact same thing for Coconut's [data types](#data). -`fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, and `dict` as a variant of `map` that returns back an object of the same type. - -The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor). +`fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray`, and `dict` as a variant of `map` that returns back an object of the same type. For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. _Deprecated: `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them._ -For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. - -For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s). - For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to ```coconut_python async def fmap_over_async_iters(func, async_iter): @@ -3382,6 +3403,13 @@ async def fmap_over_async_iters(func, async_iter): ``` such that `fmap` can effectively be used as an async map. +Some objects from external libraries are also given special support: +* For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result. +* For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s). +* For [`xarray`](https://docs.xarray.dev/en/stable/) objects, `fmap` will first convert them into `pandas` objects, apply `fmap`, then convert them back. + +The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor). + _Deprecated: `fmap(func, obj, fallback_to_init=True)` will fall back to `obj.__class__(map(func, obj))` if no `fmap` implementation is available rather than raise `TypeError`._ ##### Example @@ -3490,11 +3518,11 @@ def flip(f, nargs=None) = ) ``` -#### `lift` +#### `lift` and `lift_apart` -**lift**(_func_) +##### **lift**(_func_) -**lift**(_func_, *_func\_args_, **_func\_kwargs_) +##### **lift**(_func_, *_func\_args_, **_func\_kwargs_) Coconut's `lift` built-in is a higher-order function that takes in a function and “lifts” it up so that all of its arguments are functions. @@ -3518,7 +3546,33 @@ def lift(f) = ( `lift` also supports a shortcut form such that `lift(f, *func_args, **func_kwargs)` is equivalent to `lift(f)(*func_args, **func_kwargs)`. -##### Example +##### **lift\_apart**(_func_) + +##### **lift\_apart**(_func_, *_func\_args_, **_func\_kwargs_) + +Coconut's `lift_apart` built-in is very similar to `lift`, except instead of duplicating the final arguments to each function, it separates them out. + +For a binary function `f(x, y)` and two unary functions `g(z)` and `h(z)`, `lift_apart` works as +```coconut +lift_apart(f)(g, h)(z, w) == f(g(z), h(w)) +``` +such that in this case `lift_apart` implements the `D2` combinator. + +In the general case, `lift_apart` is equivalent to a pickleable version of +```coconut +def lift_apart(f) = ( + (*func_args, **func_kwargs) => + (*args, **kwargs) => + f( + *(f(x) for f, x in zip(func_args, args, strict=True)), + **{k: func_kwargs[k](kwargs[k]) for k in func_kwargs.keys() | kwargs.keys()}, + ) +) +``` + +`lift_apart` supports the same shortcut form as `lift`. + +##### Examples **Coconut:** ```coconut @@ -3537,8 +3591,33 @@ def plus_and_times(x, y): return x + y, x * y ``` +**Coconut:** +```coconut +first_false_and_last_true = ( + lift(,)(ident, reversed) + ..*> lift_apart(,)(dropwhile$(bool), dropwhile$(not)) + ..*> lift_apart(,)(.$[0], .$[0]) +) +``` + +**Python:** +```coconut_python +from itertools import dropwhile + +def first_false_and_last_true(xs): + rev_xs = reversed(xs) + return ( + next(dropwhile(bool, xs)), + next(dropwhile(lambda x: not x, rev_xs)), + ) +``` + #### `and_then` and `and_then_await` +**and\_then**(_first\_async\_func_, _second\_func_) + +**and\_then\_await**(_first\_async\_func_, _second\_async\_func_) + Coconut provides the `and_then` and `and_then_await` built-ins for composing `async` functions. Specifically: * To forwards compose an async function `async_f` with a normal function `g` (such that `g` is called on the result of `await`ing `async_f`), write ``async_f `and_then` g``. * To forwards compose an async function `async_f` with another async function `async_g` (such that `async_g` is called on the result of `await`ing `async_f`, and then `async_g` is itself awaited), write ``async_f `and_then_await` async_g``. @@ -3891,7 +3970,7 @@ flat_it = iter_of_iters |> flatten |> list ```coconut_python from itertools import chain iter_of_iters = [[1, 2], [3, 4]] -flat_it = iter_of_iters |> chain.from_iterable |> list +flat_it = list(chain.from_iterable(iter_of_iters)) ``` #### `scan` @@ -4141,9 +4220,15 @@ _Can't be done without the definition of `windowsof`; see the compiled header fo #### `all_equal` -**all\_equal**(_iterable_) +**all\_equal**(_iterable_, _to_=`...`) + +Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. + +If _to_ is passed, `all_equal` will check that all the elements are specifically equal to that value, rather than just equal to each other. + +Note that `all_equal` assumes transitivity of equality, that `!=` is the negation of `==`, and that empty arrays always have all their elements equal. -Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. `all_equal` assumes transitivity of equality and that `!=` is the negation of `==`. Special support is provided for [`numpy`](#numpy-integration) objects. +Special support is provided for [`numpy`](#numpy-integration) objects. ##### Example @@ -4716,7 +4801,7 @@ Switches the [`breakpoint` built-in](https://www.python.org/dev/peps/pep-0553/) Both functions behave identically to [`setuptools.find_packages`](https://setuptools.pypa.io/en/latest/userguide/quickstart.html#package-discovery), except that they find Coconut packages rather than Python packages. `find_and_compile_packages` additionally compiles any Coconut packages that it finds in-place. -Note that if you want to use either of these functions in your `setup.py`, you'll need to include `coconut` as a [build-time dependency in your `pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/#build-time-dependencies). If you want `setuptools` to package your Coconut files, you'll also need to add `global-include *.coco` to your [`MANIFEST.in`](https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html). +Note that if you want to use either of these functions in your `setup.py`, you'll need to include `coconut` as a [build-time dependency in your `pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/#build-time-dependencies). If you want `setuptools` to package your Coconut files, you'll also need to add `global-include *.coco` to your [`MANIFEST.in`](https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html) and [pass `include_package_data=True` to `setuptools.setup`](https://setuptools.pypa.io/en/latest/userguide/datafiles.html). ##### Example diff --git a/Makefile b/Makefile index 93742e5e7..eb2094c8f 100644 --- a/Makefile +++ b/Makefile @@ -141,7 +141,7 @@ test-any-of: test-univ .PHONY: test-mypy-univ test-mypy-univ: export COCONUT_USE_COLOR=TRUE test-mypy-univ: clean - python ./coconut/tests --strict --keep-lines --force --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition + python ./coconut/tests --strict --keep-lines --force --no-cache --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py @@ -149,7 +149,7 @@ test-mypy-univ: clean .PHONY: test-mypy test-mypy: export COCONUT_USE_COLOR=TRUE test-mypy: clean - python ./coconut/tests --strict --keep-lines --force --target sys --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition + python ./coconut/tests --strict --keep-lines --force --target sys --no-cache --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py @@ -198,7 +198,7 @@ test-mypy-verbose: clean .PHONY: test-mypy-all test-mypy-all: export COCONUT_USE_COLOR=TRUE test-mypy-all: clean - python ./coconut/tests --strict --keep-lines --force --target sys --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition --check-untyped-defs + python ./coconut/tests --strict --keep-lines --force --target sys --no-cache --mypy --follow-imports silent --ignore-missing-imports --allow-redefinition --check-untyped-defs python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 70a0646f5..09313eb57 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -155,6 +155,7 @@ if sys.version_info < (3, 7): ... +py_bytes = bytes py_chr = chr py_dict = dict py_hex = hex @@ -1216,6 +1217,11 @@ def _coconut_comma_op(*args: _t.Any) -> _Tuple: ... +def _coconut_if_op(cond: _t.Any, if_true: _T, if_false: _U) -> _t.Union[_T, _U]: + """If operator (if). Equivalent to (cond, if_true, if_false) => if_true if cond else if_false.""" + ... + + if sys.version_info < (3, 5): @_t.overload def _coconut_matmul(a: _T, b: _T) -> _T: ... @@ -1451,7 +1457,7 @@ def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U], Supports: * Coconut data types - * `str`, `dict`, `list`, `tuple`, `set`, `frozenset` + * `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray` * `dict` (maps over .items()) * asynchronous iterables * numpy arrays (uses np.vectorize) @@ -1461,6 +1467,8 @@ def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U], """ ... +_coconut_fmap = fmap + def _coconut_handle_cls_kwargs(**kwargs: _t.Dict[_t.Text, _t.Any]) -> _t.Callable[[_T], _T]: ... @@ -1636,22 +1644,44 @@ def lift(func: _t.Callable[[_T, _U], _W]) -> _coconut_lifted_2[_T, _U, _W]: ... def lift(func: _t.Callable[[_T, _U, _V], _W]) -> _coconut_lifted_3[_T, _U, _V, _W]: ... @_t.overload def lift(func: _t.Callable[..., _W]) -> _t.Callable[..., _t.Callable[..., _W]]: - """Lift a function up so that all of its arguments are functions. + """Lift a function up so that all of its arguments are functions that all take the same arguments. For a binary function f(x, y) and two unary functions g(z) and h(z), lift works as the S' combinator: lift(f)(g, h)(z) == f(g(z), h(z)) In general, lift is equivalent to: - def lift(f) = ((*func_args, **func_kwargs) -> (*args, **kwargs) -> - f(*(g(*args, **kwargs) for g in func_args), **{lbrace}k: h(*args, **kwargs) for k, h in func_kwargs.items(){rbrace})) + def lift(f) = ((*func_args, **func_kwargs) => (*args, **kwargs) => ( + f(*(g(*args, **kwargs) for g in func_args), **{k: h(*args, **kwargs) for k, h in func_kwargs.items()})) + ) lift also supports a shortcut form such that lift(f, *func_args, **func_kwargs) is equivalent to lift(f)(*func_args, **func_kwargs). """ ... _coconut_lift = lift +@_t.overload +def lift_apart(func: _t.Callable[[_T], _W]) -> _t.Callable[[_t.Callable[[_U], _T]], _t.Callable[[_U], _W]]: ... +@_t.overload +def lift_apart(func: _t.Callable[[_T, _X], _W]) -> _t.Callable[[_t.Callable[[_U], _T], _t.Callable[[_Y], _X]], _t.Callable[[_U, _Y], _W]]: ... +@_t.overload +def lift_apart(func: _t.Callable[..., _W]) -> _t.Callable[..., _t.Callable[..., _W]]: + """Lift a function up so that all of its arguments are functions that each take separate arguments. + + For a binary function f(x, y) and two unary functions g(z) and h(z), lift_apart works as the D2 combinator: + lift_apart(f)(g, h)(z, w) == f(g(z), h(w)) + + In general, lift_apart is equivalent to: + def lift_apart(func) = (*func_args, **func_kwargs) => (*args, **kwargs) => func( + *map(call, func_args, args, strict=True), + **{k: func_kwargs[k](kwargs[k]) for k in func_kwargs.keys() | kwargs.keys()}, + ) -def all_equal(iterable: _Iterable) -> bool: + lift_apart also supports a shortcut form such that lift_apart(f, *func_args, **func_kwargs) is equivalent to lift_apart(f)(*func_args, **func_kwargs). + """ + ... + + +def all_equal(iterable: _t.Iterable[_T], to: _T = ...) -> bool: """For a given iterable, check whether all elements in that iterable are equal to each other. Supports numpy arrays. Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'. @@ -1828,45 +1858,44 @@ def _coconut_mk_anon_namedtuple( # @_t.overload -# def _coconut_multi_dim_arr( -# arrs: _t.Tuple[_coconut.npt.NDArray[_DType], ...], +# def _coconut_arr_concat_op( # dim: int, +# *arrs: _coconut.npt.NDArray[_DType], # ) -> _coconut.npt.NDArray[_DType]: ... # @_t.overload -# def _coconut_multi_dim_arr( -# arrs: _t.Tuple[_DType, ...], +# def _coconut_arr_concat_op( # dim: int, +# *arrs: _DType, # ) -> _coconut.npt.NDArray[_DType]: ... - @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_t.Sequence[_T], ...], +def _coconut_arr_concat_op( dim: _t.Literal[1], + *arrs: _t.Sequence[_T], ) -> _t.Sequence[_T]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_T, ...], +def _coconut_arr_concat_op( dim: _t.Literal[1], + *arrs: _T, ) -> _t.Sequence[_T]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_t.Sequence[_t.Sequence[_T]], ...], +def _coconut_arr_concat_op( dim: _t.Literal[2], + *arrs: _t.Sequence[_t.Sequence[_T]], ) -> _t.Sequence[_t.Sequence[_T]]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_t.Sequence[_T], ...], +def _coconut_arr_concat_op( dim: _t.Literal[2], + *arrs: _t.Sequence[_T], ) -> _t.Sequence[_t.Sequence[_T]]: ... @_t.overload -def _coconut_multi_dim_arr( - arrs: _t.Tuple[_T, ...], +def _coconut_arr_concat_op( dim: _t.Literal[2], + *arrs: _T, ) -> _t.Sequence[_t.Sequence[_T]]: ... @_t.overload -def _coconut_multi_dim_arr(arrs: _Tuple, dim: int) -> _Sequence: ... +def _coconut_arr_concat_op(dim: int, *arrs: _t.Any) -> _Sequence: ... class _coconut_SupportsAdd(_t.Protocol, _t.Generic[_Tco, _Ucontra, _Vco]): diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 31d9fd411..17c0e3418 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -109,8 +109,10 @@ npt = _npt # Fake, like typing zip_longest = _zip_longest numpy_modules: _t.Any = ... -pandas_numpy_modules: _t.Any = ... +xarray_modules: _t.Any = ... +pandas_modules: _t.Any = ... jax_numpy_modules: _t.Any = ... + tee_type: _t.Any = ... reiterables: _t.Any = ... fmappables: _t.Any = ... @@ -129,6 +131,7 @@ ValueError = _builtins.ValueError StopIteration = _builtins.StopIteration RuntimeError = _builtins.RuntimeError callable = _builtins.callable +chr = _builtins.chr classmethod = _builtins.classmethod complex = _builtins.complex all = _builtins.all @@ -157,6 +160,7 @@ min = _builtins.min max = _builtins.max next = _builtins.next object = _builtins.object +ord = _builtins.ord print = _builtins.print property = _builtins.property range = _builtins.range diff --git a/coconut/__coconut__.pyi b/coconut/__coconut__.pyi index e56d0e55e..520b56973 100644 --- a/coconut/__coconut__.pyi +++ b/coconut/__coconut__.pyi @@ -1,2 +1,2 @@ from __coconut__ import * -from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter +from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op diff --git a/coconut/_pyparsing.py b/coconut/_pyparsing.py index c973208b5..6d08487a6 100644 --- a/coconut/_pyparsing.py +++ b/coconut/_pyparsing.py @@ -49,6 +49,7 @@ warn_on_multiline_regex, num_displayed_timing_items, use_cache_file, + use_line_by_line_parser, ) from coconut.util import get_clock_time # NOQA from coconut.util import ( @@ -183,7 +184,6 @@ def _parseCache(self, instring, loc, doActions=True, callPreParse=True): if isinstance(value, Exception): raise value return value[0], value[1].copy() - ParserElement._parseCache = _parseCache # [CPYPARSING] fix append @@ -249,11 +249,12 @@ def enableIncremental(*args, **kwargs): ) SUPPORTS_ADAPTIVE = ( - hasattr(MatchFirst, "setAdaptiveMode") - and USE_COMPUTATION_GRAPH + USE_COMPUTATION_GRAPH + and hasattr(MatchFirst, "setAdaptiveMode") ) USE_CACHE = SUPPORTS_INCREMENTAL and use_cache_file +USE_LINE_BY_LINE = USE_COMPUTATION_GRAPH and use_line_by_line_parser if MODERN_PYPARSING: _trim_arity = _pyparsing.core._trim_arity diff --git a/coconut/command/command.py b/coconut/command/command.py index 95e21d0da..fc6fe2d3e 100644 --- a/coconut/command/command.py +++ b/coconut/command/command.py @@ -246,22 +246,33 @@ def execute_args(self, args, interact=True, original_args=None): unset_fast_pyparsing_reprs() if args.profile: start_profiling() - logger.enable_colors() logger.log(cli_version) if original_args is not None: logger.log("Directly passed args:", original_args) logger.log("Parsed args:", args) - # validate general command args + # validate args and show warnings if args.stack_size and args.stack_size % 4 != 0: logger.warn("--stack-size should generally be a multiple of 4, not {stack_size} (to support 4 KB pages)".format(stack_size=args.stack_size)) if args.mypy is not None and args.no_line_numbers: logger.warn("using --mypy running with --no-line-numbers is not recommended; mypy error messages won't include Coconut line numbers") + if args.interact and args.run: + logger.warn("extraneous --run argument passed; --interact implies --run") + if args.package and self.mypy: + logger.warn("extraneous --package argument passed; --mypy implies --package") + + # validate args and raise errors if args.line_numbers and args.no_line_numbers: raise CoconutException("cannot compile with both --line-numbers and --no-line-numbers") if args.site_install and args.site_uninstall: raise CoconutException("cannot --site-install and --site-uninstall simultaneously") + if args.standalone and args.package: + raise CoconutException("cannot compile as both --package and --standalone") + if args.standalone and self.mypy: + raise CoconutException("cannot compile as both --package (implied by --mypy) and --standalone") + if args.no_write and self.mypy: + raise CoconutException("cannot compile with --no-write when using --mypy") for and_args in getattr(args, "and") or []: if len(and_args) > 2: raise CoconutException( @@ -271,6 +282,9 @@ def execute_args(self, args, interact=True, original_args=None): ), ) + # modify args + args.run = args.run or args.interact + # process general command args self.set_jobs(args.jobs, args.profile) if args.recursion_limit is not None: @@ -338,44 +352,45 @@ def execute_args(self, args, interact=True, original_args=None): # do compilation, keeping track of compiled filepaths filepaths = [] if args.source is not None: - # warnings if source is given - if args.interact and args.run: - logger.warn("extraneous --run argument passed; --interact implies --run") - if args.package and self.mypy: - logger.warn("extraneous --package argument passed; --mypy implies --package") - - # errors if source is given - if args.standalone and args.package: - raise CoconutException("cannot compile as both --package and --standalone") - if args.standalone and self.mypy: - raise CoconutException("cannot compile as both --package (implied by --mypy) and --standalone") - if args.no_write and self.mypy: - raise CoconutException("cannot compile with --no-write when using --mypy") - # process all source, dest pairs - src_dest_package_triples = [] + all_compile_path_kwargs = [] + extra_compile_path_kwargs = [] for and_args in [(args.source, args.dest)] + (getattr(args, "and") or []): if len(and_args) == 1: src, = and_args dest = None else: src, dest = and_args - src_dest_package_triples.append(self.process_source_dest(src, dest, args)) + all_new_main_kwargs, all_new_extra_kwargs = self.process_source_dest(src, dest, args) + all_compile_path_kwargs += all_new_main_kwargs + extra_compile_path_kwargs += all_new_extra_kwargs # disable jobs if we know we're only compiling one file - if len(src_dest_package_triples) <= 1 and not any(os.path.isdir(source) for source, dest, package in src_dest_package_triples): + if len(all_compile_path_kwargs) <= 1 and not any(os.path.isdir(kwargs["source"]) for kwargs in all_compile_path_kwargs): self.disable_jobs() - # do compilation - with self.running_jobs(exit_on_error=not ( + # do main compilation + exit_on_error = extra_compile_path_kwargs or not ( args.watch or args.profile - )): - for source, dest, package in src_dest_package_triples: - filepaths += self.compile_path(source, dest, package, run=args.run or args.interact, force=args.force) + ) + with self.running_jobs(exit_on_error=exit_on_error): + for kwargs in all_compile_path_kwargs: + filepaths += self.compile_path(**kwargs) + + # run mypy on compiled files self.run_mypy(filepaths) + # do extra compilation if there is any + if extra_compile_path_kwargs: + with self.running_jobs(exit_on_error=exit_on_error): + for kwargs in extra_compile_path_kwargs: + extra_filepaths = self.compile_path(**kwargs) + internal_assert(lambda: set(extra_filepaths) <= set(filepaths), "new file paths from extra compilation", (extra_filepaths, filepaths)) + # validate args if no source is given + elif getattr(args, "and"): + raise CoconutException("--and should only be used for extra source/dest pairs, not the first source/dest pair") elif ( args.run or args.no_write @@ -386,8 +401,6 @@ def execute_args(self, args, interact=True, original_args=None): or args.jobs ): raise CoconutException("a source file/folder must be specified when options that depend on the source are enabled") - elif getattr(args, "and"): - raise CoconutException("--and should only be used for extra source/dest pairs, not the first source/dest pair") # handle extra cli tasks if args.code is not None: @@ -417,8 +430,8 @@ def execute_args(self, args, interact=True, original_args=None): ): self.start_prompt() if args.watch: - # src_dest_package_triples is always available here - self.watch(src_dest_package_triples, args.run, args.force) + # all_compile_path_kwargs is always available here + self.watch(all_compile_path_kwargs) if args.profile: print_profiling_results() @@ -426,16 +439,11 @@ def execute_args(self, args, interact=True, original_args=None): return filepaths def process_source_dest(self, source, dest, args): - """Determine the correct source, dest, package mode to use for the given source, dest, and args.""" + """Get all the compile_path kwargs to use for the given source, dest, and args.""" # determine source processed_source = fixpath(source) # validate args - if (args.run or args.interact) and os.path.isdir(processed_source): - if args.run: - raise CoconutException("source path %r must point to file not directory when --run is enabled" % (source,)) - if args.interact: - raise CoconutException("source path %r must point to file not directory when --run (implied by --interact) is enabled" % (source,)) if args.watch and os.path.isfile(processed_source): raise CoconutException("source path %r must point to directory not file when --watch is enabled" % (source,)) @@ -464,67 +472,51 @@ def process_source_dest(self, source, dest, args): else: raise CoconutException("could not find source path", source) - return processed_source, processed_dest, package - - def register_exit_code(self, code=1, errmsg=None, err=None): - """Update the exit code and errmsg.""" - if err is not None: - internal_assert(errmsg is None, "register_exit_code accepts only one of errmsg or err") - if logger.verbose: - errmsg = format_error(err) - else: - errmsg = err.__class__.__name__ - if errmsg is not None: - if self.errmsg is None: - self.errmsg = errmsg - elif errmsg not in self.errmsg: - if logger.verbose: - self.errmsg += "\nAnd error: " + errmsg - else: - self.errmsg += "; " + errmsg - if code is not None: - self.exit_code = code or self.exit_code - - @contextmanager - def handling_exceptions(self, exit_on_error=None, on_keyboard_interrupt=None): - """Perform proper exception handling.""" - if exit_on_error is None: - exit_on_error = self.fail_fast - try: - if self.using_jobs: - with handling_broken_process_pool(): - yield - else: - yield - except SystemExit as err: - self.register_exit_code(err.code) - # make sure we don't catch GeneratorExit below - except GeneratorExit: - raise - except BaseException as err: - if isinstance(err, CoconutException): - logger.print_exc() - elif isinstance(err, KeyboardInterrupt): - if on_keyboard_interrupt is not None: - on_keyboard_interrupt() - else: - logger.print_exc() - logger.printerr(report_this_text) - self.register_exit_code(err=err) - if exit_on_error: - self.exit_on_error() + # handle running directories + run = args.run + extra_compilation_tasks = [] + if run and os.path.isdir(processed_source): + main_source = os.path.join(processed_source, "__main__" + code_exts[0]) + if not os.path.isfile(main_source): + raise CoconutException("source directory {source} must contain a __main__{ext} when --run{implied} is enabled".format( + source=source, + ext=code_exts[0], + implied=" (implied by --interact)" if args.interact else "", + )) + # first compile the directory without --run + run = False + # then compile just __main__ with --run + extra_compilation_tasks.append(dict( + source=main_source, + dest=processed_dest, + package=package, + run=True, + force=args.force, + )) + + # compile_path kwargs + main_compilation_tasks = [ + dict( + source=processed_source, + dest=processed_dest, + package=package, + run=run, + force=args.force, + ), + ] + return main_compilation_tasks, extra_compilation_tasks - def compile_path(self, path, write=True, package=True, handling_exceptions_kwargs={}, **kwargs): + def compile_path(self, source, dest=True, package=True, handling_exceptions_kwargs={}, **kwargs): """Compile a path and return paths to compiled files.""" - if not isinstance(write, bool): - write = fixpath(write) - if os.path.isfile(path): - destpath = self.compile_file(path, write, package, **kwargs) + if not isinstance(dest, bool): + dest = fixpath(dest) + if os.path.isfile(source): + destpath = self.compile_file(source, dest, package, **kwargs) return [destpath] if destpath is not None else [] - elif os.path.isdir(path): - return self.compile_folder(path, write, package, handling_exceptions_kwargs=handling_exceptions_kwargs, **kwargs) + elif os.path.isdir(source): + return self.compile_folder(source, dest, package, handling_exceptions_kwargs=handling_exceptions_kwargs, **kwargs) else: - raise CoconutException("could not find source path", path) + raise CoconutException("could not find source path", source) def compile_folder(self, directory, write=True, package=True, handling_exceptions_kwargs={}, **kwargs): """Compile a directory and return paths to compiled files.""" @@ -693,6 +685,54 @@ def callback_wrapper(completed_future): callback(result) future.add_done_callback(callback_wrapper) + def register_exit_code(self, code=1, errmsg=None, err=None): + """Update the exit code and errmsg.""" + if err is not None: + internal_assert(errmsg is None, "register_exit_code accepts only one of errmsg or err") + if logger.verbose: + errmsg = format_error(err) + else: + errmsg = err.__class__.__name__ + if errmsg is not None: + if self.errmsg is None: + self.errmsg = errmsg + elif errmsg not in self.errmsg: + if logger.verbose: + self.errmsg += "\nAnd error: " + errmsg + else: + self.errmsg += "; " + errmsg + if code is not None: + self.exit_code = code or self.exit_code + + @contextmanager + def handling_exceptions(self, exit_on_error=None, on_keyboard_interrupt=None): + """Perform proper exception handling.""" + if exit_on_error is None: + exit_on_error = self.fail_fast + try: + if self.using_jobs: + with handling_broken_process_pool(): + yield + else: + yield + except SystemExit as err: + self.register_exit_code(err.code) + # make sure we don't catch GeneratorExit below + except GeneratorExit: + raise + except BaseException as err: + if isinstance(err, CoconutException): + logger.print_exc() + elif isinstance(err, KeyboardInterrupt): + if on_keyboard_interrupt is not None: + on_keyboard_interrupt() + else: + logger.print_exc() + logger.printerr(report_this_text) + self.register_exit_code(err=err) + if exit_on_error: + self.exit_on_error() + def set_jobs(self, jobs, profile=False): """Set --jobs.""" if jobs in (None, "sys"): @@ -1085,21 +1125,23 @@ def start_jupyter(self, args): if run_args is not None: self.register_exit_code(run_cmd(run_args, raise_errs=False), errmsg="Jupyter error") - def watch(self, src_dest_package_triples, run=False, force=False): + def watch(self, all_compile_path_kwargs): """Watch a source and recompile on change.""" from coconut.command.watch import Observer, RecompilationWatcher - for src, _, _ in src_dest_package_triples: + for kwargs in all_compile_path_kwargs: logger.show() - logger.show_tabulated("Watching", showpath(src), "(press Ctrl-C to end)...") + logger.show_tabulated("Watching", showpath(kwargs["source"]), "(press Ctrl-C to end)...") interrupted = [False] # in list to allow modification def interrupt(): interrupted[0] = True - def recompile(path, src, dest, package): + def recompile(path, **kwargs): path = fixpath(path) + src = kwargs.pop("source") + dest = kwargs.pop("dest") if os.path.isfile(path) and os.path.splitext(path)[1] in code_exts: with self.handling_exceptions(on_keyboard_interrupt=interrupt): if dest is True or dest is None: @@ -1111,19 +1153,17 @@ def recompile(path, src, dest, package): filepaths = self.compile_path( path, writedir, - package, - run=run, - force=force, show_unchanged=False, handling_exceptions_kwargs=dict(on_keyboard_interrupt=interrupt), + **kwargs # no comma for py2 ) self.run_mypy(filepaths) observer = Observer() watchers = [] - for src, dest, package in src_dest_package_triples: - watcher = RecompilationWatcher(recompile, src, dest, package) - observer.schedule(watcher, src, recursive=True) + for kwargs in all_compile_path_kwargs: + watcher = RecompilationWatcher(recompile, **kwargs) + observer.schedule(watcher, kwargs["source"], recursive=True) watchers.append(watcher) with self.running_jobs(): diff --git a/coconut/command/util.py b/coconut/command/util.py index 53cb00bfb..c4e0b1e7d 100644 --- a/coconut/command/util.py +++ b/coconut/command/util.py @@ -421,7 +421,12 @@ def unlink(link_path): def rm_dir_or_link(dir_to_rm): """Safely delete a directory without deleting the contents of symlinks.""" if not unlink(dir_to_rm) and os.path.exists(dir_to_rm): - if WINDOWS: + if PY2: # shutil.rmtree doesn't seem to be fully safe on Python 2 + try: + os.rmdir(dir_to_rm) + except OSError: + logger.warn_exc() + elif WINDOWS: try: os.rmdir(dir_to_rm) except OSError: diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 24307a965..b567759dc 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -17,6 +17,7 @@ # - Compiler # - Processors # - Handlers +# - Managers # - Checking Handlers # - Endpoints # - Binding @@ -40,6 +41,7 @@ from coconut._pyparsing import ( USE_COMPUTATION_GRAPH, USE_CACHE, + USE_LINE_BY_LINE, ParseBaseException, ParseResults, col as getcol, @@ -49,6 +51,7 @@ ) from coconut.constants import ( + PY35, specific_targets, targets, pseudo_targets, @@ -95,7 +98,7 @@ pickleable_obj, checksum, clip, - logical_lines, + literal_lines, clean, get_target_info, get_clock_time, @@ -130,6 +133,7 @@ attrgetter_atom_handle, itemgetter_handle, partial_op_item_handle, + partial_arr_concat_handle, ) from coconut.compiler.util import ( ExceptionNode, @@ -148,7 +152,6 @@ match_in, transform, parse, - all_matches, get_target_info_smart, split_leading_comments, compile_regex, @@ -178,7 +181,9 @@ load_cache_for, pickle_cache, handle_and_manage, + manage, sub_all, + ComputationNode, ) from coconut.compiler.header import ( minify_header, @@ -485,6 +490,7 @@ class Compiler(Grammar, pickleable_obj): def __init__(self, *args, **kwargs): """Creates a new compiler with the given parsing parameters.""" self.setup(*args, **kwargs) + self.reset() # changes here should be reflected in __reduce__, get_cli_args, and in the stub for coconut.api.setup def setup(self, target=None, strict=False, minify=False, line_numbers=True, keep_lines=False, no_tco=False, no_wrap=False): @@ -599,6 +605,7 @@ def reset(self, keep_state=False, filename=None): self.add_code_before_regexes = {} self.add_code_before_replacements = {} self.add_code_before_ignore_names = {} + self.remaining_original = None @contextmanager def inner_environment(self, ln=None): @@ -615,8 +622,10 @@ def inner_environment(self, ln=None): parsing_context, self.parsing_context = self.parsing_context, defaultdict(list) kept_lines, self.kept_lines = self.kept_lines, [] num_lines, self.num_lines = self.num_lines, 0 + remaining_original, self.remaining_original = self.remaining_original, None try: - yield + with ComputationNode.using_overrides(): + yield finally: self.outer_ln = outer_ln self.line_numbers = line_numbers @@ -628,23 +637,7 @@ def inner_environment(self, ln=None): self.parsing_context = parsing_context self.kept_lines = kept_lines self.num_lines = num_lines - - def current_parsing_context(self, name, default=None): - """Get the current parsing context for the given name.""" - stack = self.parsing_context[name] - if stack: - return stack[-1] - else: - return default - - @contextmanager - def add_to_parsing_context(self, name, obj): - """Add the given object to the parsing context for the given name.""" - self.parsing_context[name].append(obj) - try: - yield - finally: - self.parsing_context[name].pop() + self.remaining_original = remaining_original @contextmanager def disable_checks(self): @@ -693,15 +686,15 @@ def method(cls, method_name, is_action=None, **kwargs): trim_arity = should_trim_arity(cls_method) if is_action else False @wraps(cls_method) - def method(original, loc, tokens): + def method(original, loc, tokens_or_item): self_method = getattr(cls.current_compiler, method_name) if kwargs: self_method = partial(self_method, **kwargs) if trim_arity: self_method = _trim_arity(self_method) - return self_method(original, loc, tokens) + return self_method(original, loc, tokens_or_item) internal_assert( - hasattr(cls_method, "ignore_tokens") is hasattr(method, "ignore_tokens") + hasattr(cls_method, "ignore_arguments") is hasattr(method, "ignore_arguments") and hasattr(cls_method, "ignore_no_tokens") is hasattr(method, "ignore_no_tokens") and hasattr(cls_method, "ignore_one_token") is hasattr(method, "ignore_one_token"), "failed to properly wrap method", @@ -718,16 +711,19 @@ def bind(cls): cls.classdef_ref, cls.method("classdef_handle"), cls.method("class_manage"), + include_in_packrat_context=False, ) cls.datadef <<= handle_and_manage( cls.datadef_ref, cls.method("datadef_handle"), cls.method("class_manage"), + include_in_packrat_context=False, ) cls.match_datadef <<= handle_and_manage( cls.match_datadef_ref, cls.method("match_datadef_handle"), cls.method("class_manage"), + include_in_packrat_context=False, ) # handle parsing_context for function definitions @@ -735,16 +731,19 @@ def bind(cls): cls.stmt_lambdef_ref, cls.method("stmt_lambdef_handle"), cls.method("func_manage"), + include_in_packrat_context=False, ) cls.decoratable_normal_funcdef_stmt <<= handle_and_manage( cls.decoratable_normal_funcdef_stmt_ref, cls.method("decoratable_funcdef_stmt_handle"), cls.method("func_manage"), + include_in_packrat_context=False, ) cls.decoratable_async_funcdef_stmt <<= handle_and_manage( cls.decoratable_async_funcdef_stmt_ref, cls.method("decoratable_funcdef_stmt_handle", is_async=True), cls.method("func_manage"), + include_in_packrat_context=False, ) # handle parsing_context for type aliases @@ -752,6 +751,7 @@ def bind(cls): cls.type_alias_stmt_ref, cls.method("type_alias_stmt_handle"), cls.method("type_alias_stmt_manage"), + include_in_packrat_context=False, ) # handle parsing_context for where statements @@ -759,11 +759,37 @@ def bind(cls): cls.where_stmt_ref, cls.method("where_stmt_handle"), cls.method("where_stmt_manage"), + include_in_packrat_context=False, ) cls.implicit_return_where <<= handle_and_manage( cls.implicit_return_where_ref, cls.method("where_stmt_handle"), cls.method("where_stmt_manage"), + include_in_packrat_context=False, + ) + + # handle parsing_context for expr_setnames + # (we need include_in_packrat_context here because some parses will be in an expr_setname context and some won't) + cls.expr_lambdef <<= manage( + cls.expr_lambdef_ref, + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, + ) + cls.lambdef_no_cond <<= manage( + cls.lambdef_no_cond_ref, + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, + ) + cls.comprehension_expr <<= manage( + cls.comprehension_expr_ref, + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, + ) + cls.dict_comp <<= handle_and_manage( + cls.dict_comp_ref, + cls.method("dict_comp_handle"), + cls.method("has_expr_setname_manage"), + include_in_packrat_context=True, ) # greedy handlers (we need to know about them even if suppressed and/or they use the parsing_context) @@ -775,7 +801,16 @@ def bind(cls): # name handlers cls.refname <<= attach(cls.name_ref, cls.method("name_handle")) cls.setname <<= attach(cls.name_ref, cls.method("name_handle", assign=True)) - cls.classname <<= attach(cls.name_ref, cls.method("name_handle", assign=True, classname=True), greedy=True) + cls.classname <<= attach( + cls.name_ref, + cls.method("name_handle", assign=True, classname=True), + greedy=True, + ) + cls.expr_setname <<= attach( + cls.name_ref, + cls.method("name_handle", assign=True, expr_setname=True), + greedy=True, + ) # abnormally named handlers cls.moduledoc_item <<= attach(cls.moduledoc, cls.method("set_moduledoc")) @@ -788,6 +823,11 @@ def bind(cls): cls.trailer_atom <<= attach(cls.trailer_atom_ref, cls.method("item_handle")) cls.no_partial_trailer_atom <<= attach(cls.no_partial_trailer_atom_ref, cls.method("item_handle")) cls.simple_assign <<= attach(cls.simple_assign_ref, cls.method("item_handle")) + cls.expr_simple_assign <<= attach(cls.expr_simple_assign_ref, cls.method("item_handle")) + + # handle all star assignments with star_assign_item_check + cls.star_assign_item <<= attach(cls.star_assign_item_ref, cls.method("star_assign_item_check")) + cls.expr_star_assign_item <<= attach(cls.expr_star_assign_item_ref, cls.method("star_assign_item_check")) # handle all string atoms with string_atom_handle cls.string_atom <<= attach(cls.string_atom_ref, cls.method("string_atom_handle")) @@ -811,7 +851,6 @@ def bind(cls): cls.complex_raise_stmt <<= attach(cls.complex_raise_stmt_ref, cls.method("complex_raise_stmt_handle")) cls.augassign_stmt <<= attach(cls.augassign_stmt_ref, cls.method("augassign_stmt_handle")) cls.kwd_augassign <<= attach(cls.kwd_augassign_ref, cls.method("kwd_augassign_handle")) - cls.dict_comp <<= attach(cls.dict_comp_ref, cls.method("dict_comp_handle")) cls.destructuring_stmt <<= attach(cls.destructuring_stmt_ref, cls.method("destructuring_stmt_handle")) cls.full_match <<= attach(cls.full_match_ref, cls.method("full_match_handle")) cls.name_match_funcdef <<= attach(cls.name_match_funcdef_ref, cls.method("name_match_funcdef_handle")) @@ -841,7 +880,6 @@ def bind(cls): # these handlers just do strict/target checking cls.u_string <<= attach(cls.u_string_ref, cls.method("u_string_check")) cls.nonlocal_stmt <<= attach(cls.nonlocal_stmt_ref, cls.method("nonlocal_check")) - cls.star_assign_item <<= attach(cls.star_assign_item_ref, cls.method("star_assign_item_check")) cls.keyword_lambdef <<= attach(cls.keyword_lambdef_ref, cls.method("lambdef_check")) cls.star_sep_arg <<= attach(cls.star_sep_arg_ref, cls.method("star_sep_check")) cls.star_sep_setarg <<= attach(cls.star_sep_setarg_ref, cls.method("star_sep_check")) @@ -998,7 +1036,7 @@ def remove_strs(self, inputstring, inner_environment=True, **kwargs): try: with (self.inner_environment() if inner_environment else noop_ctx()): return self.str_proc(inputstring, **kwargs) - except Exception: + except CoconutSyntaxError: logger.log_exc() return None @@ -1160,7 +1198,7 @@ def target_info(self): """Return information on the current target as a version tuple.""" return get_target_info(self.target) - def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, reformat=True, endpoint=None, include_causes=False, **kwargs): + def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, reformat=True, endpoint=None, include_causes=False, use_startpoint=False, **kwargs): """Generate an error of the specified type.""" logger.log_loc("raw_loc", original, loc) logger.log_loc("raw_endpoint", original, endpoint) @@ -1170,13 +1208,19 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor logger.log_loc("loc", original, loc) # get endpoint + startpoint = None if endpoint is None: endpoint = reformat if endpoint is False: endpoint = loc else: if endpoint is True: - endpoint = get_highest_parse_loc(original) + if self.remaining_original is None: + endpoint = get_highest_parse_loc(original) + else: + startpoint = ComputationNode.add_to_loc + raw_endpoint = get_highest_parse_loc(self.remaining_original) + endpoint = startpoint + raw_endpoint logger.log_loc("highest_parse_loc", original, endpoint) endpoint = clip( move_endpt_to_non_whitespace(original, endpoint, backwards=True), @@ -1184,6 +1228,40 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor ) logger.log_loc("endpoint", original, endpoint) + # process startpoint + if startpoint is not None: + startpoint = move_loc_to_non_whitespace(original, startpoint) + logger.log_loc("startpoint", original, startpoint) + + # determine possible causes + if include_causes: + self.internal_assert(extra is None, original, loc, "make_err cannot include causes with extra") + causes = dictset() + for check_loc in dictset((loc, endpoint, startpoint)): + if check_loc is not None: + cause = try_parse(self.parse_err_msg, original[check_loc:], inner=True) + if cause: + causes.add(cause) + if causes: + extra = "possible cause{s}: {causes}".format( + s="s" if len(causes) > 1 else "", + causes=", ".join(ordered(causes)), + ) + else: + extra = None + + # use startpoint if appropriate + if startpoint is None: + use_startpoint = False + else: + if use_startpoint is None: + use_startpoint = ( + "\n" not in original[loc:endpoint] + and "\n" in original[startpoint:loc] + ) + if use_startpoint: + loc = startpoint + # get line number if ln is None: if self.outer_ln is None: @@ -1192,7 +1270,7 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor ln = self.outer_ln # get line indices for the error locs - original_lines = tuple(logical_lines(original, True)) + original_lines = tuple(literal_lines(original, True)) loc_line_ind = clip(lineno(loc, original) - 1, max=len(original_lines) - 1) # build the source snippet that the error is referring to @@ -1205,33 +1283,27 @@ def make_err(self, errtype, message, original, loc=0, ln=None, extra=None, refor logger.log_loc("loc_in_snip", snippet, loc_in_snip) logger.log_loc("endpt_in_snip", snippet, endpt_in_snip) - # determine possible causes - if include_causes: - self.internal_assert(extra is None, original, loc, "make_err cannot include causes with extra") - causes = dictset() - for cause, _, _ in all_matches(self.parse_err_msg, snippet[loc_in_snip:]): - if cause: - causes.add(cause) - for cause, _, _ in all_matches(self.parse_err_msg, snippet[endpt_in_snip:]): - if cause: - causes.add(cause) - if causes: - extra = "possible cause{s}: {causes}".format( - s="s" if len(causes) > 1 else "", - causes=", ".join(ordered(causes)), - ) - else: - extra = None - # reformat the snippet and fix error locations to match if reformat: snippet, loc_in_snip, endpt_in_snip = self.reformat_locs(snippet, loc_in_snip, endpt_in_snip) logger.log_loc("reformatted_loc", snippet, loc_in_snip) logger.log_loc("reformatted_endpt", snippet, endpt_in_snip) + # build the error if extra is not None: kwargs["extra"] = extra - return errtype(message, snippet, loc_in_snip, ln, endpoint=endpt_in_snip, filename=self.filename, **kwargs) + return errtype( + message, + snippet, + loc_in_snip, + ln, + endpoint=endpt_in_snip, + filename=self.filename, + **kwargs # no comma + ).set_formatting( + point_to_endpoint=True if use_startpoint else None, + max_err_msg_lines=2 if use_startpoint else None, + ) def make_syntax_err(self, err, original, after_parsing=False): """Make a CoconutSyntaxError from a CoconutDeferredSyntaxError.""" @@ -1244,7 +1316,7 @@ def make_parse_err(self, err, msg=None, include_ln=True, **kwargs): loc = err.loc ln = self.adjust(err.lineno) if include_ln else None - return self.make_err(CoconutParseError, msg, original, loc, ln, include_causes=True, **kwargs) + return self.make_err(CoconutParseError, msg, original, loc, ln, include_causes=True, use_startpoint=None, **kwargs) def make_internal_syntax_err(self, original, loc, msg, item, extra): """Make a CoconutInternalSyntaxError.""" @@ -1286,23 +1358,24 @@ def parsing(self, keep_state=False, codepath=None): Compiler.current_compiler = self yield - def streamline(self, grammar, inputstring=None, force=False, inner=False): - """Streamline the given grammar for the given inputstring.""" - input_len = 0 if inputstring is None else len(inputstring) - if force or (streamline_grammar_for_len is not None and input_len > streamline_grammar_for_len): - start_time = get_clock_time() - prep_grammar(grammar, streamline=True) - logger.log_lambda( - lambda: "Streamlined {grammar} in {time} seconds{info}.".format( - grammar=get_name(grammar), - time=get_clock_time() - start_time, - info="" if inputstring is None else " (streamlined due to receiving input of length {length})".format( - length=input_len, + def streamline(self, grammars, inputstring=None, force=False, inner=False): + """Streamline the given grammar(s) for the given inputstring.""" + for grammar in grammars if isinstance(grammars, tuple) else (grammars,): + input_len = 0 if inputstring is None else len(inputstring) + if force or (streamline_grammar_for_len is not None and input_len > streamline_grammar_for_len): + start_time = get_clock_time() + prep_grammar(grammar, streamline=True) + logger.log_lambda( + lambda: "Streamlined {grammar} in {time} seconds{info}.".format( + grammar=get_name(grammar), + time=get_clock_time() - start_time, + info="" if inputstring is None else " (streamlined due to receiving input of length {length})".format( + length=input_len, + ), ), - ), - ) - elif inputstring is not None and not inner: - logger.log("No streamlining done for input of length {length}.".format(length=input_len)) + ) + elif inputstring is not None and not inner: + logger.log("No streamlining done for input of length {length}.".format(length=input_len)) def run_final_checks(self, original, keep_state=False): """Run post-parsing checks to raise any necessary errors/warnings.""" @@ -1320,6 +1393,30 @@ def run_final_checks(self, original, keep_state=False): endpoint=False, ) + def parse_line_by_line(self, init_parser, line_parser, original): + """Apply init_parser then line_parser repeatedly.""" + if not USE_LINE_BY_LINE: + raise CoconutException("line-by-line parsing not supported", extra="run 'pip install --upgrade cPyparsing' to fix") + with ComputationNode.using_overrides(): + ComputationNode.override_original = original + out_parts = [] + init = True + cur_loc = 0 + while cur_loc < len(original): + self.remaining_original = original[cur_loc:] + ComputationNode.add_to_loc = cur_loc + results = parse(init_parser if init else line_parser, self.remaining_original, inner=False) + if len(results) == 1: + got_loc, = results + else: + got, got_loc = results + out_parts.append(got) + got_loc = int(got_loc) + internal_assert(got_loc >= cur_loc and (init or got_loc > cur_loc), "invalid line by line parse", (cur_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0])) + cur_loc = got_loc + init = False + return "".join(out_parts) + def parse( self, inputstring, @@ -1349,7 +1446,11 @@ def parse( with logger.gather_parsing_stats(): try: pre_procd = self.pre(inputstring, keep_state=keep_state, **preargs) - parsed = parse(parser, pre_procd, inner=False) + if isinstance(parser, tuple): + init_parser, line_parser = parser + parsed = self.parse_line_by_line(init_parser, line_parser, pre_procd) + else: + parsed = parse(parser, pre_procd, inner=False) out = self.post(parsed, keep_state=keep_state, **postargs) except ParseBaseException as err: raise self.make_parse_err(err) @@ -1357,7 +1458,7 @@ def parse( internal_assert(pre_procd is not None, "invalid deferred syntax error in pre-processing", err) raise self.make_syntax_err(err, pre_procd, after_parsing=parsed is not None) # RuntimeError, not RecursionError, for Python < 3.5 - except RuntimeError as err: + except (RecursionError if PY35 else RuntimeError) as err: raise CoconutException( str(err), extra="try again with --recursion-limit greater than the current " + str(sys.getrecursionlimit()) + " (you may also need to increase --stack-size)", @@ -1378,7 +1479,7 @@ def prepare(self, inputstring, strip=False, nl_at_eof_check=False, **kwargs): if self.strict and nl_at_eof_check and inputstring and not inputstring.endswith("\n"): end_index = len(inputstring) - 1 if inputstring else 0 raise self.make_err(CoconutStyleError, "missing new line at end of file", inputstring, end_index) - kept_lines = inputstring.splitlines() + kept_lines = tuple(literal_lines(inputstring)) self.num_lines = len(kept_lines) if self.keep_lines: self.kept_lines = kept_lines @@ -1559,16 +1660,8 @@ def str_proc(self, inputstring, **kwargs): # start the string hold if we're at the start of a string if hold is not None: - is_f = False - j = i - len(hold["start"]) - while j >= 0: - prev_c = inputstring[j] - if prev_c == "f": - is_f = True - break - elif prev_c != "r": - break - j -= 1 + is_f_check_str = inputstring[clip(i - len(hold["start"]) + 1 - self.start_f_str_regex_len, min=0): i - len(hold["start"]) + 1] + is_f = self.start_f_str_regex.search(is_f_check_str) if is_f: hold.update({ "type": "f string", @@ -1656,7 +1749,7 @@ def operator_proc(self, inputstring, keep_state=False, **kwargs): """Process custom operator definitions.""" out = [] skips = self.copy_skips() - for i, raw_line in enumerate(logical_lines(inputstring, keep_newlines=True)): + for i, raw_line in enumerate(literal_lines(inputstring, keep_newlines=True)): ln = i + 1 base_line = rem_comment(raw_line) stripped_line = base_line.lstrip() @@ -1743,7 +1836,7 @@ def leading_whitespace(self, inputstring): def ind_proc(self, inputstring, **kwargs): """Process indentation and ensure balanced parentheses.""" - lines = tuple(logical_lines(inputstring)) + lines = tuple(literal_lines(inputstring)) new = [] # new lines current = None # indentation level of previous line levels = [] # indentation levels of all previous blocks, newest at end @@ -1814,7 +1907,7 @@ def ind_proc(self, inputstring, **kwargs): original=line, ln=self.adjust(len(new)), **err_kwargs - ).set_point_to_endpoint(True) + ).set_formatting(point_to_endpoint=True) self.set_skips(skips) if new: @@ -1836,11 +1929,8 @@ def reind_proc(self, inputstring, ignore_errors=False, **kwargs): out_lines = [] level = 0 - next_line_is_fake = False - for line in inputstring.splitlines(True): - is_fake = next_line_is_fake - next_line_is_fake = line.endswith("\f") and line.rstrip("\f") == line.rstrip() - + is_fake = False + for next_line_is_real, line in literal_lines(inputstring, True, yield_next_line_is_real=True): line, comment = split_comment(line.strip()) indent, line = split_leading_indent(line) @@ -1869,6 +1959,8 @@ def reind_proc(self, inputstring, ignore_errors=False, **kwargs): line = (line + comment).rstrip() out_lines.append(line) + is_fake = not next_line_is_real + if not ignore_errors and level != 0: logger.log_lambda(lambda: "failed to reindent:\n" + inputstring) complain("non-zero final indentation level: " + repr(level)) @@ -1915,7 +2007,7 @@ def endline_repl(self, inputstring, reformatting=False, ignore_errors=False, **k """Add end of line comments.""" out_lines = [] ln = 1 # line number in pre-processed original - for line in logical_lines(inputstring): + for line in literal_lines(inputstring): add_one_to_ln = False try: @@ -2050,7 +2142,7 @@ def split_docstring(self, block): pass else: raw_first_line = split_leading_trailing_indent(rem_comment(first_line))[1] - if match_in(self.just_a_string, raw_first_line): + if match_in(self.just_a_string, raw_first_line, inner=True): return first_line, rest_of_lines return None, block @@ -2268,7 +2360,7 @@ def transform_returns(self, original, loc, raw_lines, tre_return_grammar=None, i def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, is_stmt_lambda): """Determines if TCO or TRE can be done and if so does it, handles dotted function names, and universalizes async functions.""" - raw_lines = list(logical_lines(funcdef, True)) + raw_lines = list(literal_lines(funcdef, True)) def_stmt = raw_lines.pop(0) out = [] @@ -2383,7 +2475,8 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, raise self.make_err( CoconutTargetError, "async function definition requires a specific target", - original, loc, + original, + loc, target="sys", ) elif self.target_info >= (3, 5): @@ -2394,7 +2487,8 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, raise self.make_err( CoconutTargetError, "found Python 3.6 async generator (Coconut can only backport async generators as far back as 3.5)", - original, loc, + original, + loc, target="35", ) else: @@ -2568,7 +2662,7 @@ def {mock_var}({mock_paramdef}): {vars_var} = {{"{def_name}": {def_name}}} else: {vars_var} = _coconut.globals().copy() - {vars_var}.update(_coconut.locals().copy()) + {vars_var}.update(_coconut.locals()) _coconut_exec({code_str}, {vars_var}) {func_name} = {func_from_vars} ''', @@ -2621,7 +2715,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names= self.compile_add_code_before_regexes() out = [] - for raw_line in inputstring.splitlines(True): + for raw_line in literal_lines(inputstring, True): bef_ind, line, aft_ind = split_leading_trailing_indent(raw_line) # look for deferred errors @@ -2644,7 +2738,7 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names= # handle any non-function code that was added before the funcdef pre_def_lines = [] post_def_lines = [] - funcdef_lines = list(logical_lines(funcdef, True)) + funcdef_lines = list(literal_lines(funcdef, True)) for i, line in enumerate(funcdef_lines): if self.def_regex.match(line): pre_def_lines = funcdef_lines[:i] @@ -2753,15 +2847,18 @@ def function_call_handle(self, loc, tokens): """Enforce properly ordered function parameters.""" return "(" + join_args(*self.split_function_call(tokens, loc)) + ")" - def pipe_item_split(self, tokens, loc): + def pipe_item_split(self, original, loc, tokens): """Process a pipe item, which could be a partial, an attribute access, a method call, or an expression. - Return (type, split) where split is: - - (expr,) for expression - - (func, pos_args, kwd_args) for partial - - (name, args) for attr/method - - (attr, [(op, args)]) for itemgetter - - (op, arg) for right op partial + Return (type, split) where split is, for each type: + - expr: (expr,) + - partial: (func, pos_args, kwd_args) + - attrgetter: (name, args) + - itemgetter: (attr, [(op, args)]) for itemgetter + - right op partial: (op, arg) + - right arr concat partial: (op, arg) + - await: () + - namedexpr: (varname,) """ # list implies artificial tokens, which must be expr if isinstance(tokens, list) or "expr" in tokens: @@ -2791,9 +2888,32 @@ def pipe_item_split(self, tokens, loc): return "right op partial", (op, arg) else: raise CoconutInternalException("invalid op partial tokens in pipe_item", inner_toks) + elif "arr concat partial" in tokens: + inner_toks, = tokens + if "left arr concat partial" in inner_toks: + arg, op = inner_toks + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "partial", ("_coconut_arr_concat_op", str(len(op)) + ", " + arg, "") + elif "right arr concat partial" in inner_toks: + op, arg = inner_toks + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "right arr concat partial", (op, arg) + else: + raise CoconutInternalException("invalid arr concat partial tokens in pipe_item", inner_toks) elif "await" in tokens: internal_assert(len(tokens) == 1 and tokens[0] == "await", "invalid await pipe item tokens", tokens) - return "await", [] + return "await", () + elif "namedexpr" in tokens: + if self.target_info < (3, 8): + raise self.make_err( + CoconutTargetError, + "named expression partial in pipe only supported for targets 3.8+", + original, + loc, + target="38", + ) + varname, = tokens + return "namedexpr", (varname,) else: raise CoconutInternalException("invalid pipe item tokens", tokens) @@ -2807,7 +2927,7 @@ def pipe_handle(self, original, loc, tokens, **kwargs): return item # we've only been given one operand, so we can't do any optimization, so just produce the standard object - name, split_item = self.pipe_item_split(item, loc) + name, split_item = self.pipe_item_split(original, loc, item) if name == "expr": expr, = split_item return expr @@ -2820,8 +2940,12 @@ def pipe_handle(self, original, loc, tokens, **kwargs): return itemgetter_handle(item) elif name == "right op partial": return partial_op_item_handle(item) + elif name == "right arr concat partial": + return partial_arr_concat_handle(item) elif name == "await": raise CoconutDeferredSyntaxError("await in pipe must have something piped into it", loc) + elif name == "namedexpr": + raise CoconutDeferredSyntaxError("named expression partial in pipe must have something piped into it", loc) else: raise CoconutInternalException("invalid split pipe item", split_item) @@ -2852,7 +2976,7 @@ def pipe_handle(self, original, loc, tokens, **kwargs): elif direction == "forwards": # if this is an implicit partial, we have something to apply it to, so optimize it - name, split_item = self.pipe_item_split(item, loc) + name, split_item = self.pipe_item_split(original, loc, item) subexpr = self.pipe_handle(original, loc, tokens) if name == "expr": @@ -2888,11 +3012,22 @@ def pipe_handle(self, original, loc, tokens, **kwargs): raise CoconutDeferredSyntaxError("cannot star pipe into operator partial", loc) op, arg = split_item return "({op})({x}, {arg})".format(op=op, x=subexpr, arg=arg) + elif name == "right arr concat partial": + if stars: + raise CoconutDeferredSyntaxError("cannot star pipe into array concatenation operator partial", loc) + op, arg = split_item + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "_coconut_arr_concat_op({dim}, {x}, {arg})".format(dim=len(op), x=subexpr, arg=arg) elif name == "await": internal_assert(not split_item, "invalid split await pipe item tokens", split_item) if stars: raise CoconutDeferredSyntaxError("cannot star pipe into await", loc) return self.await_expr_handle(original, loc, [subexpr]) + elif name == "namedexpr": + if stars: + raise CoconutDeferredSyntaxError("cannot star pipe into named expression partial", loc) + varname, = split_item + return "({varname} := {item})".format(varname=varname, item=subexpr) else: raise CoconutInternalException("invalid split pipe item", split_item) @@ -3044,7 +3179,7 @@ def yield_from_handle(self, loc, tokens): def endline_handle(self, original, loc, tokens): """Add line number information to end of line.""" endline, = tokens - lines = endline.splitlines(True) + lines = tuple(literal_lines(endline, True)) if self.minify: lines = lines[0] out = [] @@ -3790,69 +3925,105 @@ def set_letter_literal_handle(self, tokens): def stmt_lambdef_handle(self, original, loc, tokens): """Process multi-line lambdef statements.""" - if len(tokens) == 4: - got_kwds, params, stmts_toks, followed_by = tokens - typedef = None - else: - got_kwds, params, typedef, stmts_toks, followed_by = tokens + name = self.get_temp_var("lambda", loc) - if followed_by == ",": - self.strict_err_or_warn("found statement lambda followed by comma; this isn't recommended as it can be unclear whether the comma is inside or outside the lambda (just wrap the lambda in parentheses)", original, loc) - else: - internal_assert(followed_by == "", "invalid stmt_lambdef followed_by", followed_by) - - is_async = False - add_kwds = [] - for kwd in got_kwds: - if kwd == "async": - self.internal_assert(not is_async, original, loc, "duplicate stmt_lambdef async keyword", kwd) - is_async = True - elif kwd == "copyclosure": - add_kwds.append(kwd) + # avoid regenerating the code if we already built it on a previous call + if name not in self.add_code_before: + if len(tokens) == 4: + got_kwds, params, stmts_toks, followed_by = tokens + typedef = None + else: + got_kwds, params, typedef, stmts_toks, followed_by = tokens + + if followed_by == ",": + self.strict_err_or_warn("found statement lambda followed by comma; this isn't recommended as it can be unclear whether the comma is inside or outside the lambda (just wrap the lambda in parentheses)", original, loc) else: - raise CoconutInternalException("invalid stmt_lambdef keyword", kwd) - - if len(stmts_toks) == 1: - stmts, = stmts_toks - elif len(stmts_toks) == 2: - stmts, last = stmts_toks - if "tests" in stmts_toks: - stmts = stmts.asList() + ["return " + last] + internal_assert(followed_by == "", "invalid stmt_lambdef followed_by", followed_by) + + is_async = False + add_kwds = [] + for kwd in got_kwds: + if kwd == "async": + self.internal_assert(not is_async, original, loc, "duplicate stmt_lambdef async keyword", kwd) + is_async = True + elif kwd == "copyclosure": + add_kwds.append(kwd) + else: + raise CoconutInternalException("invalid stmt_lambdef keyword", kwd) + + if len(stmts_toks) == 1: + stmts, = stmts_toks + elif len(stmts_toks) == 2: + stmts, last = stmts_toks + if "tests" in stmts_toks: + stmts = stmts.asList() + ["return " + last] + else: + stmts = stmts.asList() + [last] else: - stmts = stmts.asList() + [last] - else: - raise CoconutInternalException("invalid statement lambda body tokens", stmts_toks) + raise CoconutInternalException("invalid statement lambda body tokens", stmts_toks) - name = self.get_temp_var("lambda", loc) - body = openindent + "\n".join(stmts) + closeindent + body = openindent + "\n".join(stmts) + closeindent - if typedef is None: - colon = ":" - else: - colon = self.typedef_handle([typedef]) - if isinstance(params, str): - decorators = "" - funcdef = "def " + name + params + colon + "\n" + body - else: - match_tokens = [name] + list(params) - before_colon, after_docstring = self.name_match_funcdef_handle(original, loc, match_tokens) - decorators = "@_coconut_mark_as_match\n" - funcdef = ( - before_colon - + colon - + "\n" - + after_docstring - + body - ) + if typedef is None: + colon = ":" + else: + colon = self.typedef_handle([typedef]) + if isinstance(params, str): + decorators = "" + funcdef = "def " + name + params + colon + "\n" + body + else: + match_tokens = [name] + list(params) + before_colon, after_docstring = self.name_match_funcdef_handle(original, loc, match_tokens) + decorators = "@_coconut_mark_as_match\n" + funcdef = ( + before_colon + + colon + + "\n" + + after_docstring + + body + ) + + funcdef = " ".join(add_kwds + [funcdef]) - funcdef = " ".join(add_kwds + [funcdef]) + self.add_code_before[name] = self.decoratable_funcdef_stmt_handle(original, loc, [decorators, funcdef], is_async, is_stmt_lambda=True) - self.add_code_before[name] = self.decoratable_funcdef_stmt_handle(original, loc, [decorators, funcdef], is_async, is_stmt_lambda=True) + expr_setname_context = self.current_parsing_context("expr_setnames") + if expr_setname_context is None: + return name + else: + builder_name = self.get_temp_var("lambda_builder", loc) + + parent_context = expr_setname_context["parent"] + parent_setnames = set() + while parent_context: + parent_setnames |= parent_context["new_names"] + parent_context = parent_context["parent"] + + def stmt_lambdef_callback(): + expr_setnames = parent_setnames | expr_setname_context["new_names"] + expr_setnames_str = ", ".join(sorted(expr_setnames) + ["**_coconut_other_locals"]) + # the actual code for the function will automatically be added by add_code_before for name + builder_code = handle_indentation(""" +def {builder_name}({expr_setnames_str}): + del _coconut_other_locals + return {name} + """).format( + builder_name=builder_name, + expr_setnames_str=expr_setnames_str, + name=name, + ) + self.add_code_before[builder_name] = builder_code - return name + expr_setname_context["callbacks"].append(stmt_lambdef_callback) + if parent_setnames: + # use _coconut.dict to ensure it supports | + builder_args = "**(_coconut.dict(" + ", ".join(name + '=' + name for name in sorted(parent_setnames)) + ") | _coconut.locals())" + else: + builder_args = "**_coconut.locals()" + return builder_name + "(" + builder_args + ")" def decoratable_funcdef_stmt_handle(self, original, loc, tokens, is_async=False, is_stmt_lambda=False): - """Wraps the given function for later processing""" + """Wrap the given function for later processing.""" if len(tokens) == 1: funcdef, = tokens decorators = "" @@ -3869,7 +4040,8 @@ def await_expr_handle(self, original, loc, tokens): raise self.make_err( CoconutTargetError, "await requires a specific target", - original, loc, + original, + loc, target="sys", ) elif self.target_info >= (3, 5): @@ -3973,247 +4145,75 @@ def typed_assign_stmt_handle(self, tokens): type_ignore=self.type_ignore_comment(), ) - def funcname_typeparams_handle(self, tokens): - """Handle function names with type parameters.""" - if len(tokens) == 1: - name, = tokens - return name + def with_stmt_handle(self, tokens): + """Process with statements.""" + withs, body = tokens + if len(withs) == 1 or self.target_info >= (2, 7): + return "with " + ", ".join(withs) + body else: - name, paramdefs = tokens - return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) + return ( + "".join("with " + expr + ":\n" + openindent for expr in withs[:-1]) + + "with " + withs[-1] + body + + closeindent * (len(withs) - 1) + ) - funcname_typeparams_handle.ignore_one_token = True + def ellipsis_handle(self, tokens=None): + if self.target.startswith("3"): + return "..." + else: + return "_coconut.Ellipsis" - def type_param_handle(self, original, loc, tokens): - """Compile a type param into an assignment.""" - args = "" - bound_op = None - bound_op_type = "" - if "TypeVar" in tokens: - TypeVarFunc = "TypeVar" - bound_op_type = "bound" - if len(tokens) == 2: - name_loc, name = tokens - else: - name_loc, name, bound_op, bound = tokens - args = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False) - elif "TypeVar constraint" in tokens: - TypeVarFunc = "TypeVar" - bound_op_type = "constraint" - name_loc, name, bound_op, constraints = tokens - args = ", " + ", ".join(self.wrap_typedef(c, for_py_typedef=False) for c in constraints) - elif "TypeVarTuple" in tokens: - TypeVarFunc = "TypeVarTuple" - name_loc, name = tokens - elif "ParamSpec" in tokens: - TypeVarFunc = "ParamSpec" - name_loc, name = tokens + ellipsis_handle.ignore_arguments = True + + def match_case_tokens(self, match_var, check_var, original, tokens, top): + """Build code for matching the given case.""" + if len(tokens) == 3: + loc, matches, stmts = tokens + cond = None + elif len(tokens) == 4: + loc, matches, cond, stmts = tokens else: - raise CoconutInternalException("invalid type_param tokens", tokens) + raise CoconutInternalException("invalid case match tokens", tokens) + loc = int(loc) + matching = self.get_matcher(original, loc, check_var) + matching.match(matches, match_var) + if cond: + matching.add_guard(cond) + return matching.build(stmts, set_check_var=top) - kwargs = "" - if bound_op is not None: - self.internal_assert(bound_op_type in ("bound", "constraint"), original, loc, "invalid type_param bound_op", bound_op) - # uncomment this line whenever mypy adds support for infer_variance in TypeVar - # (and remove the warning about it in the DOCS) - # kwargs = ", infer_variance=True" - if bound_op == "<=": - self.strict_err_or_warn( - "use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator is deprecated (Coconut style is to use '<:' for bounds and ':' for constaints)", - original, - loc, - ) - else: - self.internal_assert(bound_op in (":", "<:"), original, loc, "invalid type_param bound_op", bound_op) - if bound_op_type == "bound" and bound_op != "<:" or bound_op_type == "constraint" and bound_op != ":": - self.strict_err( - "found use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator (Coconut style is to use '<:' for bounds and ':' for constaints)", - original, - loc, - ) + def cases_stmt_handle(self, original, loc, tokens): + """Process case blocks.""" + if len(tokens) == 3: + block_kwd, item, cases = tokens + default = None + elif len(tokens) == 4: + block_kwd, item, cases, default = tokens + else: + raise CoconutInternalException("invalid case tokens", tokens) - name_loc = int(name_loc) - internal_assert(name_loc == loc if TypeVarFunc == "TypeVar" else name_loc >= loc, "invalid name location for " + TypeVarFunc, (name_loc, loc, tokens)) + self.internal_assert(block_kwd in ("cases", "case", "match"), original, loc, "invalid case statement keyword", block_kwd) + if block_kwd == "case": + self.strict_err_or_warn("deprecated case keyword at top level in case ...: match ...: block (use Python 3.10 match ...: case ...: syntax instead)", original, loc) - typevar_info = self.current_parsing_context("typevars") - if typevar_info is not None: - # check to see if we already parsed this exact typevar, in which case just reuse the existing temp_name - if typevar_info["typevar_locs"].get(name, None) == name_loc: - name = typevar_info["all_typevars"][name] - else: - if name in typevar_info["all_typevars"]: - raise CoconutDeferredSyntaxError("type variable {name!r} already defined".format(name=name), loc) - temp_name = self.get_temp_var(("typevar", name), name_loc) - typevar_info["all_typevars"][name] = temp_name - typevar_info["new_typevars"].append((TypeVarFunc, temp_name)) - typevar_info["typevar_locs"][name] = name_loc - name = temp_name + check_var = self.get_temp_var("case_match_check", loc) + match_var = self.get_temp_var("case_match_to", loc) - return '{name} = _coconut.typing.{TypeVarFunc}("{name}"{args}{kwargs})\n'.format( - name=name, - TypeVarFunc=TypeVarFunc, - args=args, - kwargs=kwargs, + out = ( + match_var + " = " + item + "\n" + + self.match_case_tokens(match_var, check_var, original, cases[0], True) ) + for case in cases[1:]: + out += ( + "if not " + check_var + ":\n" + openindent + + self.match_case_tokens(match_var, check_var, original, case, False) + closeindent + ) + if default is not None: + out += "if not " + check_var + default + return out - def get_generic_for_typevars(self): - """Get the Generic instances for the current typevars.""" - typevar_info = self.current_parsing_context("typevars") - internal_assert(typevar_info is not None, "get_generic_for_typevars called with no typevars") - generics = [] - for TypeVarFunc, name in typevar_info["new_typevars"]: - if TypeVarFunc in ("TypeVar", "ParamSpec"): - generics.append(name) - elif TypeVarFunc == "TypeVarTuple": - if self.target_info >= (3, 11): - generics.append("*" + name) - else: - generics.append("_coconut.typing.Unpack[" + name + "]") - else: - raise CoconutInternalException("invalid TypeVarFunc", TypeVarFunc, "(", name, ")") - return "_coconut.typing.Generic[" + ", ".join(generics) + "]" - - @contextmanager - def type_alias_stmt_manage(self, item=None, original=None, loc=None): - """Manage the typevars parsing context.""" - prev_typevar_info = self.current_parsing_context("typevars") - with self.add_to_parsing_context("typevars", { - "all_typevars": {} if prev_typevar_info is None else prev_typevar_info["all_typevars"].copy(), - "new_typevars": [], - "typevar_locs": {}, - }): - yield - - def type_alias_stmt_handle(self, tokens): - """Handle type alias statements.""" - if len(tokens) == 2: - name, typedef = tokens - paramdefs = () - else: - name, paramdefs, typedef = tokens - if self.target_info >= (3, 12): - return "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) - else: - return "".join(paramdefs) + self.typed_assign_stmt_handle([ - name, - "_coconut.typing.TypeAlias", - self.wrap_typedef(typedef, for_py_typedef=False), - ]) - - def where_item_handle(self, tokens): - """Manage where items.""" - where_context = self.current_parsing_context("where") - internal_assert(not where_context["assigns"], "invalid where_context", where_context) - where_context["assigns"] = set() - return tokens - - @contextmanager - def where_stmt_manage(self, item, original, loc): - """Manage where statements.""" - with self.add_to_parsing_context("where", { - "assigns": None, - }): - yield - - def where_stmt_handle(self, loc, tokens): - """Process where statements.""" - main_stmt, body_stmts = tokens - - where_assigns = self.current_parsing_context("where")["assigns"] - internal_assert(lambda: where_assigns is not None, "missing where_assigns") - - where_init = "".join(body_stmts) - where_final = main_stmt + "\n" - out = where_init + where_final - if not where_assigns: - return out - - name_regexes = { - name: compile_regex(r"\b" + name + r"\b") - for name in where_assigns - } - name_replacements = { - name: self.get_temp_var(("where", name), loc) - for name in where_assigns - } - - where_init = self.deferred_code_proc(where_init) - where_final = self.deferred_code_proc(where_final) - out = where_init + where_final - - out = sub_all(out, name_regexes, name_replacements) - - return self.wrap_passthrough(out, early=True) - - def with_stmt_handle(self, tokens): - """Process with statements.""" - withs, body = tokens - if len(withs) == 1 or self.target_info >= (2, 7): - return "with " + ", ".join(withs) + body - else: - return ( - "".join("with " + expr + ":\n" + openindent for expr in withs[:-1]) - + "with " + withs[-1] + body - + closeindent * (len(withs) - 1) - ) - - def ellipsis_handle(self, tokens=None): - if self.target.startswith("3"): - return "..." - else: - return "_coconut.Ellipsis" - - ellipsis_handle.ignore_tokens = True - - def match_case_tokens(self, match_var, check_var, original, tokens, top): - """Build code for matching the given case.""" - if len(tokens) == 3: - loc, matches, stmts = tokens - cond = None - elif len(tokens) == 4: - loc, matches, cond, stmts = tokens - else: - raise CoconutInternalException("invalid case match tokens", tokens) - loc = int(loc) - matching = self.get_matcher(original, loc, check_var) - matching.match(matches, match_var) - if cond: - matching.add_guard(cond) - return matching.build(stmts, set_check_var=top) - - def cases_stmt_handle(self, original, loc, tokens): - """Process case blocks.""" - if len(tokens) == 3: - block_kwd, item, cases = tokens - default = None - elif len(tokens) == 4: - block_kwd, item, cases, default = tokens - else: - raise CoconutInternalException("invalid case tokens", tokens) - - self.internal_assert(block_kwd in ("cases", "case", "match"), original, loc, "invalid case statement keyword", block_kwd) - if block_kwd == "case": - self.strict_err_or_warn("deprecated case keyword at top level in case ...: match ...: block (use Python 3.10 match ...: case ...: syntax instead)", original, loc) - - check_var = self.get_temp_var("case_match_check", loc) - match_var = self.get_temp_var("case_match_to", loc) - - out = ( - match_var + " = " + item + "\n" - + self.match_case_tokens(match_var, check_var, original, cases[0], True) - ) - for case in cases[1:]: - out += ( - "if not " + check_var + ":\n" + openindent - + self.match_case_tokens(match_var, check_var, original, case, False) + closeindent - ) - if default is not None: - out += "if not " + check_var + default - return out - - def f_string_handle(self, original, loc, tokens): - """Process Python 3.6 format strings.""" - string, = tokens + def f_string_handle(self, original, loc, tokens): + """Process Python 3.6 format strings.""" + string, = tokens # strip raw r raw = string.startswith("r") @@ -4542,75 +4542,207 @@ class {protocol_var}({tokens}, _coconut.typing.Protocol): pass # end: HANDLERS # ----------------------------------------------------------------------------------------------------------------------- -# CHECKING HANDLERS: +# MANAGERS: # ----------------------------------------------------------------------------------------------------------------------- - def check_strict(self, name, original, loc, tokens=(None,), only_warn=False, always_warn=False): - """Check that syntax meets --strict requirements.""" - self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) - message = "found " + name - if self.strict: - kwargs = {} - if only_warn: - if not always_warn: - kwargs["extra"] = "remove --strict to dismiss" - self.syntax_warning(message, original, loc, **kwargs) - else: - if always_warn: - kwargs["extra"] = "remove --strict to downgrade to a warning" - return self.raise_or_wrap_error(self.make_err(CoconutStyleError, message, original, loc, **kwargs)) - elif always_warn: - self.syntax_warning(message, original, loc) - return tokens[0] + def current_parsing_context(self, name, default=None): + """Get the current parsing context for the given name.""" + stack = self.parsing_context[name] + if stack: + return stack[-1] + else: + return default - def lambdef_check(self, original, loc, tokens): - """Check for Python-style lambdas.""" - return self.check_strict("Python-style lambda", original, loc, tokens) + @contextmanager + def add_to_parsing_context(self, name, obj, callbacks_key=None): + """Pur the given object on the parsing context stack for the given name.""" + self.parsing_context[name].append(obj) + try: + yield + finally: + popped_ctx = self.parsing_context[name].pop() + if callbacks_key is not None: + for callback in popped_ctx[callbacks_key]: + callback() - def endline_semicolon_check(self, original, loc, tokens): - """Check for semicolons at the end of lines.""" - return self.check_strict("semicolon at end of line", original, loc, tokens, always_warn=True) + def funcname_typeparams_handle(self, tokens): + """Handle function names with type parameters.""" + if len(tokens) == 1: + name, = tokens + return name + else: + name, paramdefs = tokens + return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) - def u_string_check(self, original, loc, tokens): - """Check for Python-2-style unicode strings.""" - return self.check_strict("Python-2-style unicode string (all Coconut strings are unicode strings)", original, loc, tokens, always_warn=True) + funcname_typeparams_handle.ignore_one_token = True - def match_dotted_name_const_check(self, original, loc, tokens): - """Check for Python-3.10-style implicit dotted name match check.""" - return self.check_strict("Python-3.10-style dotted name in pattern-matching (Coconut style is to use '=={name}' not '{name}')".format(name=tokens[0]), original, loc, tokens) + def type_param_handle(self, original, loc, tokens): + """Compile a type param into an assignment.""" + args = "" + bound_op = None + bound_op_type = "" + if "TypeVar" in tokens: + TypeVarFunc = "TypeVar" + bound_op_type = "bound" + if len(tokens) == 2: + name_loc, name = tokens + else: + name_loc, name, bound_op, bound = tokens + args = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False) + elif "TypeVar constraint" in tokens: + TypeVarFunc = "TypeVar" + bound_op_type = "constraint" + name_loc, name, bound_op, constraints = tokens + args = ", " + ", ".join(self.wrap_typedef(c, for_py_typedef=False) for c in constraints) + elif "TypeVarTuple" in tokens: + TypeVarFunc = "TypeVarTuple" + name_loc, name = tokens + elif "ParamSpec" in tokens: + TypeVarFunc = "ParamSpec" + name_loc, name = tokens + else: + raise CoconutInternalException("invalid type_param tokens", tokens) - def match_check_equals_check(self, original, loc, tokens): - """Check for old-style =item in pattern-matching.""" - return self.check_strict("deprecated equality-checking '=...' pattern; use '==...' instead", original, loc, tokens, always_warn=True) + if bound_op is not None: + self.internal_assert(bound_op_type in ("bound", "constraint"), original, loc, "invalid type_param bound_op", bound_op) + if bound_op == "<=": + self.strict_err_or_warn( + "use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator is deprecated (Coconut style is to use '<:' for bounds and ':' for constaints)", + original, + loc, + ) + else: + self.internal_assert(bound_op in (":", "<:"), original, loc, "invalid type_param bound_op", bound_op) + if bound_op_type == "bound" and bound_op != "<:" or bound_op_type == "constraint" and bound_op != ":": + self.strict_err( + "found use of " + repr(bound_op) + " as a type parameter " + bound_op_type + " declaration operator (Coconut style is to use '<:' for bounds and ':' for constaints)", + original, + loc, + ) - def power_in_impl_call_check(self, original, loc, tokens): - """Check for exponentation in implicit function application / coefficient syntax.""" - return self.check_strict( - "syntax with new behavior in Coconut v3; 'f x ** y' is now equivalent to 'f(x**y)' not 'f(x)**y'", - original, - loc, - tokens, - only_warn=True, - always_warn=True, + kwargs = "" + # uncomment these lines whenever mypy adds support for infer_variance in TypeVar + # (and remove the warning about it in the DOCS) + # if TypeVarFunc == "TypeVar": + # kwargs += ", infer_variance=True" + + name_loc = int(name_loc) + internal_assert(name_loc == loc if TypeVarFunc == "TypeVar" else name_loc >= loc, "invalid name location for " + TypeVarFunc, (name_loc, loc, tokens)) + + typevar_info = self.current_parsing_context("typevars") + if typevar_info is not None: + # check to see if we already parsed this exact typevar, in which case just reuse the existing temp_name + if typevar_info["typevar_locs"].get(name, None) == name_loc: + name = typevar_info["all_typevars"][name] + else: + if name in typevar_info["all_typevars"]: + raise CoconutDeferredSyntaxError("type variable {name!r} already defined".format(name=name), loc) + temp_name = self.get_temp_var(("typevar", name), name_loc) + typevar_info["all_typevars"][name] = temp_name + typevar_info["new_typevars"].append((TypeVarFunc, temp_name)) + typevar_info["typevar_locs"][name] = name_loc + name = temp_name + + return '{name} = _coconut.typing.{TypeVarFunc}("{name}"{args}{kwargs})\n'.format( + name=name, + TypeVarFunc=TypeVarFunc, + args=args, + kwargs=kwargs, ) - def check_py(self, version, name, original, loc, tokens): - """Check for Python-version-specific syntax.""" - self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) - version_info = get_target_info(version) - if self.target_info < version_info: - return self.raise_or_wrap_error(self.make_err( - CoconutTargetError, - "found Python " + ".".join(str(v) for v in version_info) + " " + name, - original, - loc, - target=version, - )) + def get_generic_for_typevars(self): + """Get the Generic instances for the current typevars.""" + typevar_info = self.current_parsing_context("typevars") + internal_assert(typevar_info is not None, "get_generic_for_typevars called with no typevars") + generics = [] + for TypeVarFunc, name in typevar_info["new_typevars"]: + if TypeVarFunc in ("TypeVar", "ParamSpec"): + generics.append(name) + elif TypeVarFunc == "TypeVarTuple": + if self.target_info >= (3, 11): + generics.append("*" + name) + else: + generics.append("_coconut.typing.Unpack[" + name + "]") + else: + raise CoconutInternalException("invalid TypeVarFunc", TypeVarFunc, "(", name, ")") + return "_coconut.typing.Generic[" + ", ".join(generics) + "]" + + @contextmanager + def type_alias_stmt_manage(self, original=None, loc=None, item=None): + """Manage the typevars parsing context.""" + prev_typevar_info = self.current_parsing_context("typevars") + with self.add_to_parsing_context("typevars", { + "all_typevars": {} if prev_typevar_info is None else prev_typevar_info["all_typevars"].copy(), + "new_typevars": [], + "typevar_locs": {}, + }): + yield + + def type_alias_stmt_handle(self, tokens): + """Handle type alias statements.""" + if len(tokens) == 2: + name, typedef = tokens + paramdefs = () else: - return tokens[0] + name, paramdefs, typedef = tokens + out = "".join(paramdefs) + if self.target_info >= (3, 12): + out += "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) + else: + out += self.typed_assign_stmt_handle([ + name, + "_coconut.typing.TypeAlias", + self.wrap_typedef(typedef, for_py_typedef=False), + ]) + return out + + def where_item_handle(self, tokens): + """Manage where items.""" + where_context = self.current_parsing_context("where") + internal_assert(not where_context["assigns"], "invalid where_context", where_context) + where_context["assigns"] = set() + return tokens + + @contextmanager + def where_stmt_manage(self, original, loc, item): + """Manage where statements.""" + with self.add_to_parsing_context("where", { + "assigns": None, + }): + yield + + def where_stmt_handle(self, loc, tokens): + """Process where statements.""" + main_stmt, body_stmts = tokens + + where_assigns = self.current_parsing_context("where")["assigns"] + internal_assert(lambda: where_assigns is not None, "missing where_assigns") + + where_init = "".join(body_stmts) + where_final = main_stmt + "\n" + out = where_init + where_final + if not where_assigns: + return out + + name_regexes = { + name: compile_regex(r"\b" + name + r"\b") + for name in where_assigns + } + name_replacements = { + name: self.get_temp_var(("where", name), loc) + for name in where_assigns + } + + where_init = self.deferred_code_proc(where_init) + where_final = self.deferred_code_proc(where_final) + out = where_init + where_final + + out = sub_all(out, name_regexes, name_replacements) + + return self.wrap_passthrough(out, early=True) @contextmanager - def class_manage(self, item, original, loc): + def class_manage(self, original, loc, item): """Manage the class parsing context.""" cls_stack = self.parsing_context["class"] if cls_stack: @@ -4636,7 +4768,7 @@ def class_manage(self, item, original, loc): cls_stack.pop() @contextmanager - def func_manage(self, item, original, loc): + def func_manage(self, original, loc, item): """Manage the function parsing context.""" cls_context = self.current_parsing_context("class") if cls_context is not None: @@ -4655,8 +4787,25 @@ def in_method(self): cls_context = self.current_parsing_context("class") return cls_context is not None and cls_context["name"] is not None and cls_context["in_method"] - def name_handle(self, original, loc, tokens, assign=False, classname=False): + @contextmanager + def has_expr_setname_manage(self, original, loc, item): + """Handle parses that can assign expr_setname.""" + with self.add_to_parsing_context( + "expr_setnames", + { + "parent": self.current_parsing_context("expr_setnames"), + "new_names": set(), + "callbacks": [], + "loc": loc, + }, + callbacks_key="callbacks", + ): + yield + + def name_handle(self, original, loc, tokens, assign=False, classname=False, expr_setname=False): """Handle the given base name.""" + internal_assert(assign if expr_setname else True, "expr_setname should always imply assign", (expr_setname, assign)) + name, = tokens if name.startswith("\\"): name = name[1:] @@ -4679,6 +4828,11 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): self.internal_assert(cls_context is not None, original, loc, "found classname outside of class", tokens) cls_context["name"] = name + if expr_setname: + expr_setnames_context = self.current_parsing_context("expr_setnames") + self.internal_assert(expr_setnames_context is not None, original, loc, "found expr_setname outside of has_expr_setname_manage", tokens) + expr_setnames_context["new_names"].add(name) + # raise_or_wrap_error for all errors here to make sure we don't # raise spurious errors if not using the computation graph @@ -4713,8 +4867,8 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): # greedily, which means this might be an invalid parse, in which # case we can't be sure this is actually shadowing a builtin and USE_COMPUTATION_GRAPH - # classnames are handled greedily, so ditto the above - and not classname + # classnames and expr_setnames are handled greedily, so ditto the above + and not (classname or expr_setname) and name in all_builtins ): self.strict_err_or_warn( @@ -4757,6 +4911,75 @@ def name_handle(self, original, loc, tokens, assign=False, classname=False): else: return name +# end: MANAGERS +# ----------------------------------------------------------------------------------------------------------------------- +# CHECKING HANDLERS: +# ----------------------------------------------------------------------------------------------------------------------- + + def check_strict(self, name, original, loc, tokens=(None,), only_warn=False, always_warn=False): + """Check that syntax meets --strict requirements.""" + self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) + message = "found " + name + if self.strict: + kwargs = {} + if only_warn: + if not always_warn: + kwargs["extra"] = "remove --strict to dismiss" + self.syntax_warning(message, original, loc, **kwargs) + else: + if always_warn: + kwargs["extra"] = "remove --strict to downgrade to a warning" + return self.raise_or_wrap_error(self.make_err(CoconutStyleError, message, original, loc, **kwargs)) + elif always_warn: + self.syntax_warning(message, original, loc) + return tokens[0] + + def lambdef_check(self, original, loc, tokens): + """Check for Python-style lambdas.""" + return self.check_strict("Python-style lambda", original, loc, tokens) + + def endline_semicolon_check(self, original, loc, tokens): + """Check for semicolons at the end of lines.""" + return self.check_strict("semicolon at end of line", original, loc, tokens, always_warn=True) + + def u_string_check(self, original, loc, tokens): + """Check for Python-2-style unicode strings.""" + return self.check_strict("Python-2-style unicode string (all Coconut strings are unicode strings)", original, loc, tokens, always_warn=True) + + def match_dotted_name_const_check(self, original, loc, tokens): + """Check for Python-3.10-style implicit dotted name match check.""" + return self.check_strict("Python-3.10-style dotted name in pattern-matching (Coconut style is to use '=={name}' not '{name}')".format(name=tokens[0]), original, loc, tokens) + + def match_check_equals_check(self, original, loc, tokens): + """Check for old-style =item in pattern-matching.""" + return self.check_strict("deprecated equality-checking '=...' pattern; use '==...' instead", original, loc, tokens, always_warn=True) + + def power_in_impl_call_check(self, original, loc, tokens): + """Check for exponentation in implicit function application / coefficient syntax.""" + return self.check_strict( + "syntax with new behavior in Coconut v3; 'f x ** y' is now equivalent to 'f(x**y)' not 'f(x)**y'", + original, + loc, + tokens, + only_warn=True, + always_warn=True, + ) + + def check_py(self, version, name, original, loc, tokens): + """Check for Python-version-specific syntax.""" + self.internal_assert(len(tokens) == 1, original, loc, "invalid " + name + " tokens", tokens) + version_info = get_target_info(version) + if self.target_info < version_info: + return self.raise_or_wrap_error(self.make_err( + CoconutTargetError, + "found Python " + ".".join(str(v) for v in version_info) + " " + name, + original, + loc, + target=version, + )) + else: + return tokens[0] + def nonlocal_check(self, original, loc, tokens): """Check for Python 3 nonlocal statement.""" return self.check_py("3", "nonlocal statement", original, loc, tokens) @@ -4787,7 +5010,7 @@ def namedexpr_check(self, original, loc, tokens): def new_namedexpr_check(self, original, loc, tokens): """Check for Python 3.10 assignment expressions.""" - return self.check_py("310", "assignment expression in set literal or indexing", original, loc, tokens) + return self.check_py("310", "assignment expression in syntactic location only supported for 3.10+", original, loc, tokens) def except_star_clause_check(self, original, loc, tokens): """Check for Python 3.11 except* statements.""" diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 7eaed6226..967930699 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -32,6 +32,7 @@ from functools import partial from coconut._pyparsing import ( + USE_LINE_BY_LINE, Forward, Group, Literal, @@ -109,7 +110,6 @@ labeled_group, any_keyword_in, any_char, - tuple_str_of, any_len_perm, any_len_perm_at_least_one, boundary, @@ -551,6 +551,21 @@ def partial_op_item_handle(tokens): raise CoconutInternalException("invalid operator function implicit partial token group", tok_grp) +def partial_arr_concat_handle(tokens): + """Handle array concatenation operator function implicit partials.""" + tok_grp, = tokens + if "left arr concat partial" in tok_grp: + arg, op = tok_grp + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "_coconut_partial(_coconut_arr_concat_op, " + str(len(op)) + ", " + arg + ")" + elif "right arr concat partial" in tok_grp: + op, arg = tok_grp + internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) + return "_coconut_complex_partial(_coconut_arr_concat_op, {{0: {dim}, 2: {arg}}}, 3, ())".format(dim=len(op), arg=arg) + else: + raise CoconutInternalException("invalid array concatenation operator function implicit partial token group", tok_grp) + + def array_literal_handle(loc, tokens): """Handle multidimensional array literals.""" internal_assert(len(tokens) >= 1, "invalid array literal tokens", tokens) @@ -576,20 +591,27 @@ def array_literal_handle(loc, tokens): array_elems = [] for p in pieces: if p: - if len(p) > 1: + if p[0].lstrip(";") == "": + raise CoconutDeferredSyntaxError("invalid initial multidimensional array separator or broken-up multidimensional array concatenation operator function", loc) + elif len(p) > 1: internal_assert(sep_level > 1, "failed to handle array literal tokens", tokens) subarr_item = array_literal_handle(loc, p) - elif p[0].lstrip(";") == "": - raise CoconutDeferredSyntaxError("naked multidimensional array separators are not allowed", loc) else: subarr_item = p[0] array_elems.append(subarr_item) + # if multidimensional array literal is only separators, compile to implicit partial if not array_elems: - raise CoconutDeferredSyntaxError("multidimensional array literal cannot be only separators", loc) + if len(pieces) > 2: + raise CoconutDeferredSyntaxError("invalid empty multidimensional array literal or broken-up multidimensional array concatenation operator function", loc) + return "_coconut_partial(_coconut_arr_concat_op, " + str(sep_level) + ")" + + # check for initial top-level separators + if not pieces[0]: + raise CoconutDeferredSyntaxError("invalid initial multidimensional array separator", loc) # build multidimensional array - return "_coconut_multi_dim_arr(" + tuple_str_of(array_elems) + ", " + str(sep_level) + ")" + return "_coconut_arr_concat_op(" + str(sep_level) + ", " + ", ".join(array_elems) + ")" def typedef_op_item_handle(loc, tokens): @@ -627,8 +649,8 @@ class Grammar(object): unsafe_fat_arrow = Literal("=>") | fixto(Literal("\u21d2"), "=>") colon_eq = Literal(":=") unsafe_dubcolon = Literal("::") - unsafe_colon = Literal(":") colon = disambiguate_literal(":", ["::", ":="]) + indexing_colon = disambiguate_literal(":", [":="]) # same as : but :: is allowed lt_colon = Literal("<:") semicolon = Literal(";") | invalid_syntax("\u037e", "invalid Greek question mark instead of semicolon", greedy=True) multisemicolon = combine(OneOrMore(semicolon)) @@ -651,9 +673,9 @@ class Grammar(object): pipe = Literal("|>") | fixto(Literal("\u21a6"), "|>") star_pipe = Literal("|*>") | fixto(Literal("*\u21a6"), "|*>") dubstar_pipe = Literal("|**>") | fixto(Literal("**\u21a6"), "|**>") - back_pipe = Literal("<|") | fixto(Literal("\u21a4"), "<|") - back_star_pipe = Literal("<*|") | ~Literal("\u21a4**") + fixto(Literal("\u21a4*"), "<*|") - back_dubstar_pipe = Literal("<**|") | fixto(Literal("\u21a4**"), "<**|") + back_pipe = Literal("<|") | disambiguate_literal("\u21a4", ["\u21a4*", "\u21a4?"], fixesto="<|") + back_star_pipe = Literal("<*|") | disambiguate_literal("\u21a4*", ["\u21a4**", "\u21a4*?"], fixesto="<*|") + back_dubstar_pipe = Literal("<**|") | disambiguate_literal("\u21a4**", ["\u21a4**?"], fixesto="<**|") none_pipe = Literal("|?>") | fixto(Literal("?\u21a6"), "|?>") none_star_pipe = ( Literal("|?*>") @@ -780,6 +802,7 @@ class Grammar(object): refname = Forward() setname = Forward() + expr_setname = Forward() classname = Forward() name_ref = combine(Optional(backslash) + base_name) unsafe_name = combine(Optional(backslash.suppress()) + base_name) @@ -795,7 +818,7 @@ class Grammar(object): octint = combine(Word("01234567") + ZeroOrMore(underscore.suppress() + Word("01234567"))) hexint = combine(Word(hexnums) + ZeroOrMore(underscore.suppress() + Word(hexnums))) - imag_j = caseless_literal("j") | fixto(caseless_literal("i", suppress=True), "j") + imag_j = caseless_literal("j") | fixto(caseless_literal("i", suppress=True, disambiguate=True), "j") basenum = combine( Optional(integer) + dot + integer | integer + Optional(dot + Optional(integer)) @@ -930,16 +953,17 @@ class Grammar(object): ) atom_item = Forward() + const_atom = Forward() expr = Forward() star_expr = Forward() dubstar_expr = Forward() - comp_for = Forward() test_no_cond = Forward() infix_op = Forward() namedexpr_test = Forward() # for namedexpr locations only supported in Python 3.10 new_namedexpr_test = Forward() - lambdef = Forward() + comp_for = Forward() + comprehension_expr = Forward() typedef = Forward() typedef_default = Forward() @@ -949,6 +973,10 @@ class Grammar(object): typedef_ellipsis = Forward() typedef_op_item = Forward() + expr_lambdef = Forward() + stmt_lambdef = Forward() + lambdef = expr_lambdef | stmt_lambdef + negable_atom_item = condense(Optional(neg_minus) + atom_item) testlist = itemlist(test, comma, suppress_trailing=False) @@ -1046,6 +1074,7 @@ class Grammar(object): | fixto(dollar, "_coconut_partial") | fixto(keyword("assert"), "_coconut_assert") | fixto(keyword("raise"), "_coconut_raise") + | fixto(keyword("if"), "_coconut_if_op") | fixto(keyword("is") + keyword("not"), "_coconut.operator.is_not") | fixto(keyword("not") + keyword("in"), "_coconut_not_in") @@ -1058,14 +1087,14 @@ class Grammar(object): | fixto(keyword("is"), "_coconut.operator.is_") | fixto(keyword("in"), "_coconut_in") ) - partialable_op = base_op_item | infix_op + partialable_op = ~keyword("if") + (base_op_item | infix_op) partial_op_item_tokens = ( labeled_group(dot.suppress() + partialable_op + test_no_infix, "right partial") | labeled_group(test_no_infix + partialable_op + dot.suppress(), "left partial") ) partial_op_item = attach(partial_op_item_tokens, partial_op_item_handle) op_item = ( - # partial_op_item must come first, then typedef_op_item must come after base_op_item + # must stay in exactly this order partial_op_item | typedef_op_item | base_op_item @@ -1073,6 +1102,12 @@ class Grammar(object): partial_op_atom_tokens = lparen.suppress() + partial_op_item_tokens + rparen.suppress() + partial_arr_concat_tokens = lbrack.suppress() + ( + labeled_group(dot.suppress() + multisemicolon + test_no_infix + rbrack.suppress(), "right arr concat partial") + | labeled_group(test_no_infix + multisemicolon + dot.suppress() + rbrack.suppress(), "left arr concat partial") + ) + partial_arr_concat = attach(partial_arr_concat_tokens, partial_arr_concat_handle) + # we include (var)arg_comma to ensure the pattern matches the whole arg arg_comma = comma | fixto(FollowedBy(rparen), "") setarg_comma = arg_comma | fixto(FollowedBy(colon), "") @@ -1119,21 +1154,25 @@ class Grammar(object): ZeroOrMore( condense( # everything here must end with setarg_comma - setname + Optional(default) + setarg_comma - | (star | dubstar) + setname + setarg_comma + expr_setname + Optional(default) + setarg_comma + | (star | dubstar) + expr_setname + setarg_comma | star_sep_setarg | slash_sep_setarg ) ) ) ) + match_arg_default = Group( + const_atom("const") + | test("expr") + ) match_args_list = Group(Optional( tokenlist( Group( (star | dubstar) + match | star # not star_sep because pattern-matching can handle star separators on any Python version | slash # not slash_sep as above - | match + Optional(equals.suppress() + test) + | match + Optional(equals.suppress() + match_arg_default) ), comma, ) @@ -1151,7 +1190,7 @@ class Grammar(object): # everything here must end with rparen rparen.suppress() | tokenlist(Group(call_item), comma) + rparen.suppress() - | Group(attach(addspace(test + comp_for), add_parens_handle)) + rparen.suppress() + | Group(attach(comprehension_expr, add_parens_handle)) + rparen.suppress() | Group(op_item) + rparen.suppress() ) function_call = Forward() @@ -1170,21 +1209,23 @@ class Grammar(object): | op_item ) + # for .[] subscript_star = Forward() subscript_star_ref = star slicetest = Optional(test_no_chain) - sliceop = condense(unsafe_colon + slicetest) + sliceop = condense(indexing_colon + slicetest) subscript = condense( slicetest + sliceop + Optional(sliceop) - | Optional(subscript_star) + test + | Optional(subscript_star) + new_namedexpr_test ) - subscriptlist = itemlist(subscript, comma, suppress_trailing=False) | new_namedexpr_test + subscriptlist = itemlist(subscript, comma, suppress_trailing=False) + # for .$[] slicetestgroup = Optional(test_no_chain, default="") - sliceopgroup = unsafe_colon.suppress() + slicetestgroup + sliceopgroup = indexing_colon.suppress() + slicetestgroup subscriptgroup = attach( slicetestgroup + sliceopgroup + Optional(sliceopgroup) - | test, + | new_namedexpr_test, subscriptgroup_handle, ) subscriptgrouplist = itemlist(subscriptgroup, comma) @@ -1199,10 +1240,6 @@ class Grammar(object): comma, ) - comprehension_expr = ( - addspace(namedexpr_test + comp_for) - | invalid_syntax(star_expr + comp_for, "iterable unpacking cannot be used in comprehension") - ) paren_atom = condense(lparen + any_of( # everything here must end with rparen rparen, @@ -1228,7 +1265,8 @@ class Grammar(object): list_item = ( lbrack.suppress() + list_expr + rbrack.suppress() | condense(lbrack + Optional(comprehension_expr) + rbrack) - # array_literal must come last + # partial_arr_concat and array_literal must come last + | partial_arr_concat | array_literal ) @@ -1250,7 +1288,7 @@ class Grammar(object): setmaker = Group( (new_namedexpr_test + FollowedBy(rbrace))("test") | (new_namedexpr_testlist_has_comma + FollowedBy(rbrace))("list") - | addspace(new_namedexpr_test + comp_for + FollowedBy(rbrace))("comp") + | (comprehension_expr + FollowedBy(rbrace))("comp") | (testlist_star_namedexpr + FollowedBy(rbrace))("testlist_star_expr") ) set_literal_ref = lbrace.suppress() + setmaker + rbrace.suppress() @@ -1259,19 +1297,24 @@ class Grammar(object): lazy_items = Optional(tokenlist(test, comma)) lazy_list = attach(lbanana.suppress() + lazy_items + rbanana.suppress(), lazy_list_handle) - known_atom = ( + # for const_atom, value should be known at compile time + const_atom <<= ( keyword_atom - | string_atom | num_atom + # typedef ellipsis must come before ellipsis + | typedef_ellipsis + | ellipsis + ) + # for known_atom, type should be known at compile time + known_atom = ( + const_atom + | string_atom | list_item | dict_literal | dict_comp | set_literal | set_letter_literal | lazy_list - # typedef ellipsis must come before ellipsis - | typedef_ellipsis - | ellipsis ) atom = ( # known_atom must come before name to properly parse string prefixes @@ -1350,17 +1393,20 @@ class Grammar(object): no_partial_trailer_atom_ref = atom + ZeroOrMore(no_partial_trailer) partial_atom_tokens = no_partial_trailer_atom + partial_trailer_tokens + # must be kept in sync with expr_assignlist block below + assignlist = Forward() + star_assign_item = Forward() simple_assign = Forward() simple_assign_ref = maybeparens( lparen, - (setname | passthrough_atom) - + ZeroOrMore(ZeroOrMore(complex_trailer) + OneOrMore(simple_trailer)), + ( + # refname if there's a trailer, setname if not + (refname | passthrough_atom) + OneOrMore(ZeroOrMore(complex_trailer) + OneOrMore(simple_trailer)) + | setname + | passthrough_atom + ), rparen, ) - simple_assignlist = maybeparens(lparen, itemlist(simple_assign, comma, suppress_trailing=False), rparen) - - assignlist = Forward() - star_assign_item = Forward() base_assign_item = condense( simple_assign | lparen + assignlist + rparen @@ -1370,6 +1416,30 @@ class Grammar(object): assign_item = base_assign_item | star_assign_item assignlist <<= itemlist(assign_item, comma, suppress_trailing=False) + # must be kept in sync with assignlist block above (but with expr_setname) + expr_assignlist = Forward() + expr_star_assign_item = Forward() + expr_simple_assign = Forward() + expr_simple_assign_ref = maybeparens( + lparen, + ( + # refname if there's a trailer, expr_setname if not + (refname | passthrough_atom) + OneOrMore(ZeroOrMore(complex_trailer) + OneOrMore(simple_trailer)) + | expr_setname + | passthrough_atom + ), + rparen, + ) + expr_base_assign_item = condense( + expr_simple_assign + | lparen + expr_assignlist + rparen + | lbrack + expr_assignlist + rbrack + ) + expr_star_assign_item_ref = condense(star + expr_base_assign_item) + expr_assign_item = expr_base_assign_item | expr_star_assign_item + expr_assignlist <<= itemlist(expr_assign_item, comma, suppress_trailing=False) + + simple_assignlist = maybeparens(lparen, itemlist(simple_assign, comma, suppress_trailing=False), rparen) typed_assign_stmt = Forward() typed_assign_stmt_ref = simple_assign + colon.suppress() + typedef_test + Optional(equals.suppress() + test_expr) basic_stmt = addspace(ZeroOrMore(assignlist + equals) + test_expr) @@ -1407,7 +1477,6 @@ class Grammar(object): ) + Optional(power_in_impl_call)) impl_call_item = condense( disallow_keywords(reserved_vars) - + ~any_string + ~non_decimal_num + atom_item + Optional(power_in_impl_call) @@ -1530,6 +1599,9 @@ class Grammar(object): back_none_dubstar_pipe, use_adaptive=False, ) + pipe_namedexpr_partial = lparen.suppress() + setname + (colon_eq + dot + rparen).suppress() + + # make sure to keep these three definitions in sync pipe_item = ( # we need the pipe_op since any of the atoms could otherwise be the start of an expression labeled_group(keyword("await"), "await") + pipe_op @@ -1538,6 +1610,8 @@ class Grammar(object): | labeled_group(itemgetter_atom_tokens, "itemgetter") + pipe_op | labeled_group(attrgetter_atom_tokens, "attrgetter") + pipe_op | labeled_group(partial_op_atom_tokens, "op partial") + pipe_op + | labeled_group(partial_arr_concat_tokens, "arr concat partial") + pipe_op + | labeled_group(pipe_namedexpr_partial, "namedexpr") + pipe_op # expr must come at end | labeled_group(comp_pipe_expr, "expr") + pipe_op ) @@ -1548,22 +1622,26 @@ class Grammar(object): | labeled_group(itemgetter_atom_tokens, "itemgetter") + end_simple_stmt_item | labeled_group(attrgetter_atom_tokens, "attrgetter") + end_simple_stmt_item | labeled_group(partial_op_atom_tokens, "op partial") + end_simple_stmt_item + | labeled_group(partial_arr_concat_tokens, "arr concat partial") + end_simple_stmt_item + | labeled_group(pipe_namedexpr_partial, "namedexpr") + end_simple_stmt_item ) last_pipe_item = Group( lambdef("expr") # we need longest here because there's no following pipe_op we can use as above | longest( keyword("await")("await"), + partial_atom_tokens("partial"), itemgetter_atom_tokens("itemgetter"), attrgetter_atom_tokens("attrgetter"), - partial_atom_tokens("partial"), partial_op_atom_tokens("op partial"), + partial_arr_concat_tokens("arr concat partial"), + pipe_namedexpr_partial("namedexpr"), comp_pipe_expr("expr"), ) ) + normal_pipe_expr = Forward() normal_pipe_expr_tokens = OneOrMore(pipe_item) + last_pipe_item - pipe_expr = ( comp_pipe_expr + ~pipe_op | normal_pipe_expr @@ -1595,7 +1673,10 @@ class Grammar(object): unsafe_lambda_arrow = any_of(fat_arrow, arrow) keyword_lambdef_params = maybeparens(lparen, set_args_list, rparen) - arrow_lambdef_params = lparen.suppress() + set_args_list + rparen.suppress() | setname + arrow_lambdef_params = ( + lparen.suppress() + set_args_list + rparen.suppress() + | expr_setname + ) keyword_lambdef = Forward() keyword_lambdef_ref = addspace(keyword("lambda") + condense(keyword_lambdef_params + colon)) @@ -1607,7 +1688,6 @@ class Grammar(object): keyword_lambdef, ) - stmt_lambdef = Forward() match_guard = Optional(keyword("if").suppress() + namedexpr_test) closing_stmt = longest(new_testlist_star_expr("tests"), unsafe_simple_stmt_item) stmt_lambdef_match_params = Group(lparen.suppress() + match_args_list + match_guard + rparen.suppress()) @@ -1654,8 +1734,9 @@ class Grammar(object): | fixto(always_match, "") ) - lambdef <<= addspace(lambdef_base + test) | stmt_lambdef - lambdef_no_cond = addspace(lambdef_base + test_no_cond) + expr_lambdef_ref = addspace(lambdef_base + test) + lambdef_no_cond = Forward() + lambdef_no_cond_ref = addspace(lambdef_base + test_no_cond) typedef_callable_arg = Group( test("arg") @@ -1764,11 +1845,15 @@ class Grammar(object): invalid_syntax(maybeparens(lparen, namedexpr, rparen), "PEP 572 disallows assignment expressions in comprehension iterable expressions") | test_item ) - base_comp_for = addspace(keyword("for") + assignlist + keyword("in") + comp_it_item + Optional(comp_iter)) + base_comp_for = addspace(keyword("for") + expr_assignlist + keyword("in") + comp_it_item + Optional(comp_iter)) async_comp_for_ref = addspace(keyword("async") + base_comp_for) comp_for <<= base_comp_for | async_comp_for comp_if = addspace(keyword("if") + test_no_cond + Optional(comp_iter)) comp_iter <<= any_of(comp_for, comp_if) + comprehension_expr_ref = ( + addspace(namedexpr_test + comp_for) + | invalid_syntax(star_expr + comp_for, "iterable unpacking cannot be used in comprehension") + ) return_stmt = addspace(keyword("return") - Optional(new_testlist_star_expr)) @@ -2122,7 +2207,7 @@ class Grammar(object): ( lparen.suppress() + match - + Optional(equals.suppress() + test) + + Optional(equals.suppress() + match_arg_default) + rparen.suppress() ) | interior_name_match ) @@ -2161,7 +2246,7 @@ class Grammar(object): where_stmt_ref = where_item + where_suite implicit_return = ( - invalid_syntax(return_stmt, "expected expression but got return statement") + invalid_syntax(return_stmt, "assignment function expected expression as last statement but got return instead") | attach(new_testlist_star_expr, implicit_return_handle) ) implicit_return_where = Forward() @@ -2436,12 +2521,18 @@ class Grammar(object): line = newline | stmt - single_input = condense(Optional(line) - ZeroOrMore(newline)) file_input = condense(moduledoc_marker - ZeroOrMore(line)) + raw_file_parser = start_marker - file_input - end_marker + line_by_line_file_parser = ( + start_marker - moduledoc_marker - stores_loc_item, + start_marker - line - stores_loc_item, + ) + file_parser = line_by_line_file_parser if USE_LINE_BY_LINE else raw_file_parser + + single_input = condense(Optional(line) - ZeroOrMore(newline)) eval_input = condense(testlist - ZeroOrMore(newline)) single_parser = start_marker - single_input - end_marker - file_parser = start_marker - file_input - end_marker eval_parser = start_marker - eval_input - end_marker some_eval_parser = start_marker + eval_input @@ -2497,7 +2588,7 @@ class Grammar(object): original_function_call_tokens = ( lparen.suppress() + rparen.suppress() # we need to keep the parens here, since f(x for x in y) is fine but tail_call(f, x for x in y) is not - | condense(lparen + originalTextFor(test + comp_for) + rparen) + | condense(lparen + originalTextFor(comprehension_expr) + rparen) | attach(parens, strip_parens_handle) ) @@ -2601,23 +2692,25 @@ class Grammar(object): unsafe_equals = Literal("=") - kwd_err_msg = attach(any_keyword_in(keyword_vars + reserved_vars), kwd_err_msg_handle) - parse_err_msg = ( - start_marker + ( - fixto(end_of_line, "misplaced newline (maybe missing ':')") - | fixto(Optional(keyword("if") + skip_to_in_line(unsafe_equals)) + equals, "misplaced assignment (maybe should be '==')") - | kwd_err_msg - ) - | fixto( - questionmark + parse_err_msg = start_marker + ( + # should be in order of most likely to actually be the source of the error first + fixto( + ZeroOrMore(~questionmark + ~Literal("\n") + any_char) + + questionmark + ~dollar + ~lparen + ~lbrack + ~dot, "misplaced '?' (naked '?' is only supported inside partial application arguments)", ) + | fixto(Optional(keyword("if") + skip_to_in_line(unsafe_equals)) + equals, "misplaced assignment (maybe should be '==')") + | attach(any_keyword_in(keyword_vars + reserved_vars), kwd_err_msg_handle) + | fixto(end_of_line, "misplaced newline (maybe missing ':')") ) + start_f_str_regex = compile_regex(r"\br?fr?$") + start_f_str_regex_len = 4 + end_f_str_expr = combine(start_marker + (rbrace | colon | bang)) string_start = start_marker + python_quoted_string diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 8a60ff8cc..39b2d2664 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -33,8 +33,9 @@ justify_len, report_this_text, numpy_modules, - pandas_numpy_modules, + pandas_modules, jax_numpy_modules, + xarray_modules, self_match_types, is_data_var, data_defaults_var, @@ -290,11 +291,16 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): report_this_text=report_this_text, from_None=" from None" if target.startswith("3") else "", process_="process_" if target_info >= (3, 13) else "", - numpy_modules=tuple_str_of(numpy_modules, add_quotes=True), - pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True), + xarray_modules=tuple_str_of(xarray_modules, add_quotes=True), + pandas_modules=tuple_str_of(pandas_modules, add_quotes=True), jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True), self_match_types=tuple_str_of(self_match_types), + comma_bytearray=", bytearray" if not target.startswith("3") else "", + lstatic="staticmethod(" if not target.startswith("3") else "", + rstatic=")" if not target.startswith("3") else "", + all_keys="self.func_kwargs.keys() | kwargs.keys()" if target_info >= (3,) else "_coconut.set(self.func_kwargs.keys()) | _coconut.set(kwargs.keys())", + set_super=( # we have to use _coconut_super even on the universal target, since once we set __class__ it becomes a local variable "super = py_super" if target.startswith("3") else "super = _coconut_super" @@ -335,9 +341,6 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap): else "zip_longest = itertools.izip_longest", indent=1, ), - comma_bytearray=", bytearray" if not target.startswith("3") else "", - lstatic="staticmethod(" if not target.startswith("3") else "", - rstatic=")" if not target.startswith("3") else "", zip_iter=prepare( r''' for items in _coconut.iter(_coconut.zip(*self.iters, strict=self.strict) if _coconut_sys.version_info >= (3, 10) else _coconut.zip_longest(*self.iters, fillvalue=_coconut_sentinel) if self.strict else _coconut.zip(*self.iters)): @@ -638,7 +641,7 @@ def __anext__(self): # (extra_format_dict is to keep indentation levels matching) extra_format_dict = dict( # when anything is added to this list it must also be added to *both* __coconut__ stub files - underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_multi_dim_arr, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter".format(**format_dict), + underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op".format(**format_dict), import_typing=pycondition( (3, 5), if_ge=''' @@ -775,6 +778,15 @@ def __aiter__(self): {async_def_anext} '''.format(**format_dict), ), + handle_bytes=pycondition( + (3,), + if_lt=''' +if _coconut.isinstance(obj, _coconut.bytes): + return _coconut_base_makedata(_coconut.bytes, [func(_coconut.ord(x)) for x in obj], from_fmap=True, fallback_to_init=fallback_to_init) + ''', + indent=1, + newline=True, + ), maybe_bind_lru_cache=pycondition( (3, 2), if_lt=''' diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index 96765f91a..99e5457f5 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -307,6 +307,50 @@ def get_set_name_var(self, name): """Gets the var for checking whether a name should be set.""" return match_set_name_var + "_" + name + def add_default_expr(self, assign_to, default): + """Add code that evaluates expr in the context of any names that have been matched so far + and assigns the result to assign_to if assign_to is currently _coconut_sentinel.""" + default_expr, = default + if "const" in default: + self.add_def(handle_indentation(""" +if {assign_to} is _coconut_sentinel: + {assign_to} = {default_expr} + """.format( + assign_to=assign_to, + default_expr=default_expr, + ))) + else: + internal_assert("expr" in default, "invalid match default tokens", default) + vars_var = self.get_temp_var() + add_names_code = [] + for name in self.names: + add_names_code.append( + handle_indentation( + """ +if {set_name_var} is not _coconut_sentinel: + {vars_var}["{name}"] = {set_name_var} + """, + add_newline=True, + ).format( + set_name_var=self.get_set_name_var(name), + vars_var=vars_var, + name=name, + ) + ) + code = self.comp.reformat_post_deferred_code_proc(assign_to + " = " + default_expr) + self.add_def(handle_indentation(""" +if {assign_to} is _coconut_sentinel: + {vars_var} = _coconut.globals().copy() + {vars_var}.update(_coconut.locals()) + {add_names_code}_coconut_exec({code_str}, {vars_var}) + {assign_to} = {vars_var}["{assign_to}"] + """).format( + vars_var=vars_var, + add_names_code="".join(add_names_code), + assign_to=assign_to, + code_str=self.comp.wrap_str_of(code), + )) + def register_name(self, name): """Register a new name at the current position.""" internal_assert(lambda: name not in self.parent_names and name not in self.names, "attempt to register duplicate name", name) @@ -373,7 +417,7 @@ def match_function( ).format( first_arg=first_arg, args=args, - ), + ) ) with self.down_a_level(): @@ -418,7 +462,7 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al # if i >= req_len "_coconut.sum((_coconut.len(" + args + ") > " + str(i) + ", " + ", ".join('"' + name + '" in ' + kwargs for name in names) - + ")) == 1", + + ")) == 1" ) tempvar = self.get_temp_var() self.add_def( @@ -428,16 +472,19 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else " for name in names[:-1] ) - + kwargs + '.pop("' + names[-1] + '")', + + kwargs + '.pop("' + names[-1] + '")' ) with self.down_a_level(): self.match(match, tempvar) else: if not names: tempvar = self.get_temp_var() - self.add_def(tempvar + " = " + args + "[" + str(i) + "] if _coconut.len(" + args + ") > " + str(i) + " else " + default) - with self.down_a_level(): - self.match(match, tempvar) + self.add_def(tempvar + " = " + args + "[" + str(i) + "] if _coconut.len(" + args + ") > " + str(i) + " else _coconut_sentinel") + # go down to end to ensure we've matched as much as possible before evaluating the default + with self.down_to_end(): + self.add_default_expr(tempvar, default) + with self.down_a_level(): + self.match(match, tempvar) else: arg_checks[i] = ( # if i < req_len @@ -445,7 +492,7 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al # if i >= req_len "_coconut.sum((_coconut.len(" + args + ") > " + str(i) + ", " + ", ".join('"' + name + '" in ' + kwargs for name in names) - + ")) <= 1", + + ")) <= 1" ) tempvar = self.get_temp_var() self.add_def( @@ -455,10 +502,13 @@ def match_in_args_kwargs(self, pos_only_match_args, match_args, args, kwargs, al kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else " for name in names ) - + default, + + "_coconut_sentinel" ) - with self.down_a_level(): - self.match(match, tempvar) + # go down to end to ensure we've matched as much as possible before evaluating the default + with self.down_to_end(): + self.add_default_expr(tempvar, default) + with self.down_a_level(): + self.match(match, tempvar) # length checking max_len = None if allow_star_args else len(pos_only_match_args) + len(match_args) @@ -484,12 +534,18 @@ def match_in_kwargs(self, match_args, kwargs): kwargs + '.pop("' + name + '") if "' + name + '" in ' + kwargs + " else " for name in names ) - + (default if default is not None else "_coconut_sentinel"), + + "_coconut_sentinel" ) - with self.down_a_level(): - if default is None: + if default is None: + with self.down_a_level(): self.add_check(tempvar + " is not _coconut_sentinel") - self.match(match, tempvar) + self.match(match, tempvar) + else: + # go down to end to ensure we've matched as much as possible before evaluating the default + with self.down_to_end(): + self.add_default_expr(tempvar, default) + with self.down_a_level(): + self.match(match, tempvar) def match_dict(self, tokens, item): """Matches a dictionary.""" @@ -1050,11 +1106,11 @@ def match_class(self, tokens, item): handle_indentation( """ raise _coconut.TypeError("too many positional args in class match (pattern requires {num_pos_matches}; '{cls_name}' only supports 1)") - """, + """, ).format( num_pos_matches=len(pos_matches), cls_name=cls_name, - ), + ) ) else: self_match_matcher.match(pos_matches[0], item) @@ -1063,19 +1119,21 @@ def match_class(self, tokens, item): other_cls_matcher.add_check("not _coconut.type(" + item + ") in _coconut_self_match_types") match_args_var = other_cls_matcher.get_temp_var() other_cls_matcher.add_def( - handle_indentation(""" + handle_indentation( + """ {match_args_var} = _coconut.getattr({cls_name}, '__match_args__', ()) {type_any} {type_ignore} if not _coconut.isinstance({match_args_var}, _coconut.tuple): raise _coconut.TypeError("{cls_name}.__match_args__ must be a tuple") if _coconut.len({match_args_var}) < {num_pos_matches}: raise _coconut.TypeError("too many positional args in class match (pattern requires {num_pos_matches}; '{cls_name}' only supports %s)" % (_coconut.len({match_args_var}),)) - """).format( + """, + ).format( cls_name=cls_name, match_args_var=match_args_var, num_pos_matches=len(pos_matches), type_any=self.comp.wrap_comment(" type: _coconut.typing.Any"), type_ignore=self.comp.type_ignore_comment(), - ), + ) ) with other_cls_matcher.down_a_level(): for i, match in enumerate(pos_matches): @@ -1089,14 +1147,14 @@ def match_class(self, tokens, item): """ {match_args_var} = _coconut.getattr({cls_name}, '__match_args__', ()) {star_match_var} = _coconut.tuple(_coconut.getattr({item}, {match_args_var}[i]) for i in _coconut.range({num_pos_matches}, _coconut.len({match_args_var}))) - """, + """, ).format( match_args_var=self.get_temp_var(), cls_name=cls_name, star_match_var=star_match_var, item=item, num_pos_matches=len(pos_matches), - ), + ) ) with self.down_a_level(): self.match(star_match, star_match_var) @@ -1116,7 +1174,7 @@ def match_data(self, tokens, item): "_coconut.len({item}) >= {min_len}".format( item=item, min_len=len(pos_matches), - ), + ) ) self.match_all_in(pos_matches, item) @@ -1150,7 +1208,7 @@ def match_data(self, tokens, item): min_len=len(pos_matches), name_matches=tuple_str_of(name_matches, add_quotes=True), type_ignore=self.comp.type_ignore_comment(), - ), + ) ) with self.down_a_level(): self.add_check(temp_var) @@ -1164,13 +1222,13 @@ def match_data_or_class(self, tokens, item): handle_indentation( """ {is_data_result_var} = _coconut.getattr({cls_name}, "{is_data_var}", False) or _coconut.isinstance({cls_name}, _coconut.tuple) and _coconut.all(_coconut.getattr(_coconut_x, "{is_data_var}", False) for _coconut_x in {cls_name}) {type_ignore} - """, + """, ).format( is_data_result_var=is_data_result_var, is_data_var=is_data_var, cls_name=cls_name, type_ignore=self.comp.type_ignore_comment(), - ), + ) ) if_data, if_class = self.branches(2) @@ -1241,12 +1299,12 @@ def match_view(self, tokens, item): {func_result_var} = _coconut_sentinel else: raise - """, + """, ).format( func_result_var=func_result_var, view_func=view_func, item=item, - ), + ) ) with self.down_a_level(): @@ -1323,7 +1381,7 @@ def out(self): check_var=self.check_var, parameterization=parameterization, child_checks=child.out().rstrip(), - ), + ) ) # handle normal child groups @@ -1351,7 +1409,7 @@ def out(self): ).format( check_var=self.check_var, children_checks=children_checks, - ), + ) ) # commit variable definitions @@ -1367,7 +1425,7 @@ def out(self): ).format( set_name_var=self.get_set_name_var(name), name=name, - ), + ) ) if name_set_code: out.append( @@ -1379,7 +1437,7 @@ def out(self): ).format( check_var=self.check_var, name_set_code="".join(name_set_code), - ), + ) ) # handle guards @@ -1394,7 +1452,7 @@ def out(self): ).format( check_var=self.check_var, guards=paren_join(self.guards, "and"), - ), + ) ) return "".join(out) diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index e7ec5f6f1..a332de645 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -54,13 +54,14 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} else: abc.Sequence.register(numpy.ndarray) numpy_modules = {numpy_modules} - pandas_numpy_modules = {pandas_numpy_modules} + xarray_modules = {xarray_modules} + pandas_modules = {pandas_modules} jax_numpy_modules = {jax_numpy_modules} tee_type = type(itertools.tee((), 1)[0]) reiterables = abc.Sequence, abc.Mapping, abc.Set - fmappables = list, tuple, dict, set, frozenset + fmappables = list, tuple, dict, set, frozenset, bytes, bytearray abc.Sequence.register(collections.deque) - Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} + Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} @_coconut.functools.wraps(_coconut.functools.partial) def _coconut_partial(_coconut_func, *args, **kwargs): partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs) @@ -121,6 +122,20 @@ class _coconut_Sentinel(_coconut_baseclass): _coconut_sentinel = _coconut_Sentinel() def _coconut_get_base_module(obj): return obj.__class__.__module__.split(".", 1)[0] +def _coconut_xarray_to_pandas(obj): + import xarray + if isinstance(obj, xarray.Dataset): + return obj.to_dataframe() + elif isinstance(obj, xarray.DataArray): + return obj.to_series() + else: + return obj.to_pandas() +def _coconut_xarray_to_numpy(obj): + import xarray + if isinstance(obj, xarray.Dataset): + return obj.to_dataframe().to_numpy() + else: + return obj.to_numpy() class MatchError(_coconut_baseclass, Exception): """Pattern-matching error. Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below} max_val_repr_len = 500 @@ -432,7 +447,7 @@ def and_then(first_async_func, second_func): first_async_func: async (**T) -> U, second_func: U -> V, ) -> async (**T) -> V = - async def (*args, **kwargs) -> ( + async def (*args, **kwargs) => ( first_async_func(*args, **kwargs) |> await |> second_func @@ -447,7 +462,7 @@ def and_then_await(first_async_func, second_async_func): first_async_func: async (**T) -> U, second_async_func: async U -> V, ) -> async (**T) -> V = - async def (*args, **kwargs) -> ( + async def (*args, **kwargs) => ( first_async_func(*args, **kwargs) |> await |> second_async_func @@ -458,98 +473,98 @@ def and_then_await(first_async_func, second_async_func): def _coconut_forward_compose(func, *funcs): """Forward composition operator (..>). - (..>)(f, g) is effectively equivalent to (*args, **kwargs) -> g(f(*args, **kwargs)).""" + (..>)(f, g) is effectively equivalent to (*args, **kwargs) => g(f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 0, False) for f in funcs)) def _coconut_back_compose(*funcs): """Backward composition operator (<..). - (<..)(f, g) is effectively equivalent to (*args, **kwargs) -> f(g(*args, **kwargs)).""" + (<..)(f, g) is effectively equivalent to (*args, **kwargs) => f(g(*args, **kwargs)).""" return _coconut_forward_compose(*_coconut.reversed(funcs)) def _coconut_forward_none_compose(func, *funcs): """Forward none-aware composition operator (..?>). - (..?>)(f, g) is effectively equivalent to (*args, **kwargs) -> g?(f(*args, **kwargs)).""" + (..?>)(f, g) is effectively equivalent to (*args, **kwargs) => g?(f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 0, True) for f in funcs)) def _coconut_back_none_compose(*funcs): """Backward none-aware composition operator (<..?). - (<..?)(f, g) is effectively equivalent to (*args, **kwargs) -> f?(g(*args, **kwargs)).""" + (<..?)(f, g) is effectively equivalent to (*args, **kwargs) => f?(g(*args, **kwargs)).""" return _coconut_forward_none_compose(*_coconut.reversed(funcs)) def _coconut_forward_star_compose(func, *funcs): """Forward star composition operator (..*>). - (..*>)(f, g) is effectively equivalent to (*args, **kwargs) -> g(*f(*args, **kwargs)).""" + (..*>)(f, g) is effectively equivalent to (*args, **kwargs) => g(*f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 1, False) for f in funcs)) def _coconut_back_star_compose(*funcs): """Backward star composition operator (<*..). - (<*..)(f, g) is effectively equivalent to (*args, **kwargs) -> f(*g(*args, **kwargs)).""" + (<*..)(f, g) is effectively equivalent to (*args, **kwargs) => f(*g(*args, **kwargs)).""" return _coconut_forward_star_compose(*_coconut.reversed(funcs)) def _coconut_forward_none_star_compose(func, *funcs): """Forward none-aware star composition operator (..?*>). - (..?*>)(f, g) is effectively equivalent to (*args, **kwargs) -> g?(*f(*args, **kwargs)).""" + (..?*>)(f, g) is effectively equivalent to (*args, **kwargs) => g?(*f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 1, True) for f in funcs)) def _coconut_back_none_star_compose(*funcs): """Backward none-aware star composition operator (<*?..). - (<*?..)(f, g) is effectively equivalent to (*args, **kwargs) -> f?(*g(*args, **kwargs)).""" + (<*?..)(f, g) is effectively equivalent to (*args, **kwargs) => f?(*g(*args, **kwargs)).""" return _coconut_forward_none_star_compose(*_coconut.reversed(funcs)) def _coconut_forward_dubstar_compose(func, *funcs): """Forward double star composition operator (..**>). - (..**>)(f, g) is effectively equivalent to (*args, **kwargs) -> g(**f(*args, **kwargs)).""" + (..**>)(f, g) is effectively equivalent to (*args, **kwargs) => g(**f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 2, False) for f in funcs)) def _coconut_back_dubstar_compose(*funcs): """Backward double star composition operator (<**..). - (<**..)(f, g) is effectively equivalent to (*args, **kwargs) -> f(**g(*args, **kwargs)).""" + (<**..)(f, g) is effectively equivalent to (*args, **kwargs) => f(**g(*args, **kwargs)).""" return _coconut_forward_dubstar_compose(*_coconut.reversed(funcs)) def _coconut_forward_none_dubstar_compose(func, *funcs): """Forward none-aware double star composition operator (..?**>). - (..?**>)(f, g) is effectively equivalent to (*args, **kwargs) -> g?(**f(*args, **kwargs)).""" + (..?**>)(f, g) is effectively equivalent to (*args, **kwargs) => g?(**f(*args, **kwargs)).""" return _coconut_base_compose(func, *((f, 2, True) for f in funcs)) def _coconut_back_none_dubstar_compose(*funcs): """Backward none-aware double star composition operator (<**?..). - (<**?..)(f, g) is effectively equivalent to (*args, **kwargs) -> f?(**g(*args, **kwargs)).""" + (<**?..)(f, g) is effectively equivalent to (*args, **kwargs) => f?(**g(*args, **kwargs)).""" return _coconut_forward_none_dubstar_compose(*_coconut.reversed(funcs)) def _coconut_pipe(x, f): - """Pipe operator (|>). Equivalent to (x, f) -> f(x).""" + """Pipe operator (|>). Equivalent to (x, f) => f(x).""" return f(x) def _coconut_star_pipe(xs, f): - """Star pipe operator (*|>). Equivalent to (xs, f) -> f(*xs).""" + """Star pipe operator (*|>). Equivalent to (xs, f) => f(*xs).""" return f(*xs) def _coconut_dubstar_pipe(kws, f): - """Double star pipe operator (**|>). Equivalent to (kws, f) -> f(**kws).""" + """Double star pipe operator (**|>). Equivalent to (kws, f) => f(**kws).""" return f(**kws) def _coconut_back_pipe(f, x): - """Backward pipe operator (<|). Equivalent to (f, x) -> f(x).""" + """Backward pipe operator (<|). Equivalent to (f, x) => f(x).""" return f(x) def _coconut_back_star_pipe(f, xs): - """Backward star pipe operator (<*|). Equivalent to (f, xs) -> f(*xs).""" + """Backward star pipe operator (<*|). Equivalent to (f, xs) => f(*xs).""" return f(*xs) def _coconut_back_dubstar_pipe(f, kws): - """Backward double star pipe operator (<**|). Equivalent to (f, kws) -> f(**kws).""" + """Backward double star pipe operator (<**|). Equivalent to (f, kws) => f(**kws).""" return f(**kws) def _coconut_none_pipe(x, f): - """Nullable pipe operator (|?>). Equivalent to (x, f) -> f(x) if x is not None else None.""" + """Nullable pipe operator (|?>). Equivalent to (x, f) => f(x) if x is not None else None.""" return None if x is None else f(x) def _coconut_none_star_pipe(xs, f): - """Nullable star pipe operator (|?*>). Equivalent to (xs, f) -> f(*xs) if xs is not None else None.""" + """Nullable star pipe operator (|?*>). Equivalent to (xs, f) => f(*xs) if xs is not None else None.""" return None if xs is None else f(*xs) def _coconut_none_dubstar_pipe(kws, f): - """Nullable double star pipe operator (|?**>). Equivalent to (kws, f) -> f(**kws) if kws is not None else None.""" + """Nullable double star pipe operator (|?**>). Equivalent to (kws, f) => f(**kws) if kws is not None else None.""" return None if kws is None else f(**kws) def _coconut_back_none_pipe(f, x): - """Nullable backward pipe operator ( f(x) if x is not None else None.""" + """Nullable backward pipe operator ( f(x) if x is not None else None.""" return None if x is None else f(x) def _coconut_back_none_star_pipe(f, xs): - """Nullable backward star pipe operator (<*?|). Equivalent to (f, xs) -> f(*xs) if xs is not None else None.""" + """Nullable backward star pipe operator (<*?|). Equivalent to (f, xs) => f(*xs) if xs is not None else None.""" return None if xs is None else f(*xs) def _coconut_back_none_dubstar_pipe(f, kws): - """Nullable backward double star pipe operator (<**?|). Equivalent to (kws, f) -> f(**kws) if kws is not None else None.""" + """Nullable backward double star pipe operator (<**?|). Equivalent to (kws, f) => f(**kws) if kws is not None else None.""" return None if kws is None else f(**kws) def _coconut_assert(cond, msg=None): """Assert operator (assert). Asserts condition with optional message.""" @@ -563,28 +578,31 @@ def _coconut_raise(exc=None, from_exc=None): exc.__cause__ = from_exc raise exc def _coconut_bool_and(a, b): - """Boolean and operator (and). Equivalent to (a, b) -> a and b.""" + """Boolean and operator (and). Equivalent to (a, b) => a and b.""" return a and b def _coconut_bool_or(a, b): - """Boolean or operator (or). Equivalent to (a, b) -> a or b.""" + """Boolean or operator (or). Equivalent to (a, b) => a or b.""" return a or b def _coconut_in(a, b): - """Containment operator (in). Equivalent to (a, b) -> a in b.""" + """Containment operator (in). Equivalent to (a, b) => a in b.""" return a in b def _coconut_not_in(a, b): - """Negative containment operator (not in). Equivalent to (a, b) -> a not in b.""" + """Negative containment operator (not in). Equivalent to (a, b) => a not in b.""" return a not in b def _coconut_none_coalesce(a, b): - """None coalescing operator (??). Equivalent to (a, b) -> a if a is not None else b.""" + """None coalescing operator (??). Equivalent to (a, b) => a if a is not None else b.""" return b if a is None else a def _coconut_minus(a, b=_coconut_sentinel): - """Minus operator (-). Effectively equivalent to (a, b=None) -> a - b if b is not None else -a.""" + """Minus operator (-). Effectively equivalent to (a, b=None) => a - b if b is not None else -a.""" if b is _coconut_sentinel: return -a return a - b def _coconut_comma_op(*args): - """Comma operator (,). Equivalent to (*args) -> args.""" + """Comma operator (,). Equivalent to (*args) => args.""" return args +def _coconut_if_op(cond, if_true, if_false): + """If operator (if). Equivalent to (cond, if_true, if_false) => if_true if cond else if_false.""" + return if_true if cond else if_false {def_coconut_matmul} class scan(_coconut_has_iter): """Reduce func over iterable, yielding intermediate results, @@ -749,8 +767,10 @@ Additionally supports Cartesian products of numpy arrays.""" if iterables: it_modules = [_coconut_get_base_module(it) for it in iterables] if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules): - if _coconut.any(mod in _coconut.pandas_numpy_modules for mod in it_modules): - iterables = tuple((it.to_numpy() if _coconut_get_base_module(it) in _coconut.pandas_numpy_modules else it) for it in iterables) + if _coconut.any(mod in _coconut.xarray_modules for mod in it_modules): + iterables = tuple((_coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) + if _coconut.any(mod in _coconut.pandas_modules for mod in it_modules): + iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) if _coconut.any(mod in _coconut.jax_numpy_modules for mod in it_modules): from jax import numpy else: @@ -1580,7 +1600,7 @@ def fmap(func, obj, **kwargs): Supports: * Coconut data types - * `str`, `dict`, `list`, `tuple`, `set`, `frozenset` + * `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, `bytes`, `bytearray` * `dict` (maps over .items()) * asynchronous iterables * numpy arrays (uses np.vectorize) @@ -1602,7 +1622,9 @@ def fmap(func, obj, **kwargs): if result is not _coconut.NotImplemented: return result obj_module = _coconut_get_base_module(obj) - if obj_module in _coconut.pandas_numpy_modules: + if obj_module in _coconut.xarray_modules: + return {_coconut_}fmap(func, _coconut_xarray_to_pandas(obj)).to_xarray() + if obj_module in _coconut.pandas_modules: if obj.ndim <= 1: return obj.apply(func) return obj.apply(func, axis=obj.ndim-1) @@ -1620,10 +1642,11 @@ def fmap(func, obj, **kwargs): else: if aiter is not _coconut.NotImplemented: return _coconut_amap(func, aiter) - if starmap_over_mappings: - return _coconut_base_makedata(obj.__class__, {_coconut_}starmap(func, obj.items()) if _coconut.isinstance(obj, _coconut.abc.Mapping) else {_coconut_}map(func, obj), from_fmap=True, fallback_to_init=fallback_to_init) +{handle_bytes} if _coconut.isinstance(obj, _coconut.abc.Mapping): + mapped_obj = ({_coconut_}starmap if starmap_over_mappings else {_coconut_}map)(func, obj.items()) else: - return _coconut_base_makedata(obj.__class__, {_coconut_}map(func, obj.items() if _coconut.isinstance(obj, _coconut.abc.Mapping) else obj), from_fmap=True, fallback_to_init=fallback_to_init) + mapped_obj = _coconut_map(func, obj) + return _coconut_base_makedata(obj.__class__, mapped_obj, from_fmap=True, fallback_to_init=fallback_to_init) def _coconut_memoize_helper(maxsize=None, typed=False): return maxsize, typed def memoize(*args, **kwargs): @@ -1678,7 +1701,7 @@ def _coconut_dict_merge(*dicts, **kwargs): prevlen = _coconut.len(newdict) return newdict def ident(x, **kwargs): - """The identity function. Generally equivalent to x -> x. Useful in point-free programming. + """The identity function. Generally equivalent to x => x. Useful in point-free programming. Accepts one keyword-only argument, side_effect, which specifies a function to call on the argument before it is returned.""" side_effect = kwargs.pop("side_effect", None) if kwargs: @@ -1874,30 +1897,36 @@ class const(_coconut_base_callable): def __repr__(self): return "const(%s)" % (_coconut.repr(self.value),) class _coconut_lifted(_coconut_base_callable): - __slots__ = ("func", "func_args", "func_kwargs") - def __init__(self, _coconut_func, *func_args, **func_kwargs): - self.func = _coconut_func + __slots__ = ("apart", "func", "func_args", "func_kwargs") + def __init__(self, apart, func, func_args, func_kwargs): + self.apart = apart + self.func = func self.func_args = func_args self.func_kwargs = func_kwargs def __reduce__(self): - return (self.__class__, (self.func,) + self.func_args, {lbrace}"func_kwargs": self.func_kwargs{rbrace}) + return (self.__class__, (self.apart, self.func, self.func_args, self.func_kwargs)) def __call__(self, *args, **kwargs): - return self.func(*(g(*args, **kwargs) for g in self.func_args), **_coconut_py_dict((k, h(*args, **kwargs)) for k, h in self.func_kwargs.items())) + if self.apart: + return self.func(*(f(x) for f, x in {_coconut_}zip(self.func_args, args, strict=True)), **_coconut_py_dict((k, self.func_kwargs[k](kwargs[k])) for k in {all_keys})) + else: + return self.func(*(g(*args, **kwargs) for g in self.func_args), **_coconut_py_dict((k, h(*args, **kwargs)) for k, h in self.func_kwargs.items())) def __repr__(self): - return "lift(%r)(%s%s)" % (self.func, ", ".join(_coconut.repr(g) for g in self.func_args), ", ".join(k + "=" + _coconut.repr(h) for k, h in self.func_kwargs.items())) + return "lift%s(%r)(%s%s)" % (self.func, ("_apart" if self.apart else ""), ", ".join(_coconut.repr(g) for g in self.func_args), ", ".join(k + "=" + _coconut.repr(h) for k, h in self.func_kwargs.items())) class lift(_coconut_base_callable): - """Lift a function up so that all of its arguments are functions. + """Lift a function up so that all of its arguments are functions that all take the same arguments. For a binary function f(x, y) and two unary functions g(z) and h(z), lift works as the S' combinator: lift(f)(g, h)(z) == f(g(z), h(z)) In general, lift is equivalent to: - def lift(f) = ((*func_args, **func_kwargs) -> (*args, **kwargs) -> + def lift(f) = ((*func_args, **func_kwargs) => (*args, **kwargs) => ( f(*(g(*args, **kwargs) for g in func_args), **{lbrace}k: h(*args, **kwargs) for k, h in func_kwargs.items(){rbrace})) + ) lift also supports a shortcut form such that lift(f, *func_args, **func_kwargs) is equivalent to lift(f)(*func_args, **func_kwargs). """ __slots__ = ("func",) + _apart = False def __new__(cls, func, *func_args, **func_kwargs): self = _coconut.super({_coconut_}lift, cls).__new__(cls) self.func = func @@ -1907,20 +1936,38 @@ class lift(_coconut_base_callable): def __reduce__(self): return (self.__class__, (self.func,)) def __repr__(self): - return "lift(%r)" % (self.func,) + return "lift%s(%r)" % (("_apart" if self._apart else ""), self.func) def __call__(self, *func_args, **func_kwargs): - return _coconut_lifted(self.func, *func_args, **func_kwargs) -def all_equal(iterable): + return _coconut_lifted(self._apart, self.func, func_args, func_kwargs) +class lift_apart(lift): + """Lift a function up so that all of its arguments are functions that each take separate arguments. + + For a binary function f(x, y) and two unary functions g(z) and h(z), lift_apart works as the D2 combinator: + lift_apart(f)(g, h)(z, w) == f(g(z), h(w)) + + In general, lift_apart is equivalent to: + def lift_apart(func) = (*func_args, **func_kwargs) => (*args, **kwargs) => func( + *(f(x) for f, x in zip(func_args, args, strict=True)), + **{lbrace}k: func_kwargs[k](kwargs[k]) for k in func_kwargs.keys() | kwargs.keys(){rbrace}, + ) + + lift_apart also supports a shortcut form such that lift_apart(f, *func_args, **func_kwargs) is equivalent to lift_apart(f)(*func_args, **func_kwargs). + """ + _apart = True +def all_equal(iterable, to=_coconut_sentinel): """For a given iterable, check whether all elements in that iterable are equal to each other. + If 'to' is passed, check that all the elements are equal to that value. Supports numpy arrays. Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'. """ iterable_module = _coconut_get_base_module(iterable) if iterable_module in _coconut.numpy_modules: - if iterable_module in _coconut.pandas_numpy_modules: + if iterable_module in _coconut.xarray_modules: + iterable = _coconut_xarray_to_numpy(iterable) + elif iterable_module in _coconut.pandas_modules: iterable = iterable.to_numpy() - return not _coconut.len(iterable) or (iterable == iterable[0]).all() - first_item = _coconut_sentinel + return not _coconut.len(iterable) or (iterable == (iterable[0] if to is _coconut_sentinel else to)).all() + first_item = to for item in iterable: if first_item is _coconut_sentinel: first_item = item @@ -1974,7 +2021,7 @@ def collectby(key_func, iterable, value_func=None, **kwargs): If map_using is passed, calculate key_func and value_func by mapping them over the iterable using map_using as map. Useful with process_map/thread_map. """ - return {_coconut_}mapreduce(_coconut_lifted(_coconut_comma_op, key_func, {_coconut_}ident if value_func is None else value_func), iterable, **kwargs) + return {_coconut_}mapreduce(_coconut_lifted(False, _coconut_comma_op, (key_func, {_coconut_}ident if value_func is None else value_func), {empty_dict}), iterable, **kwargs) collectby.using_processes = _coconut_partial(_coconut_parallel_mapreduce, collectby, process_map) collectby.using_threads = _coconut_partial(_coconut_parallel_mapreduce, collectby, thread_map) def _namedtuple_of(**kwargs): @@ -1990,8 +2037,11 @@ def _coconut_mk_anon_namedtuple(fields, types=None, of_kwargs=None): return NT return NT(**of_kwargs) def _coconut_ndim(arr): - if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): + arr_mod = _coconut_get_base_module(arr) + if (arr_mod in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"): return arr.ndim + if arr_mod in _coconut.xarray_modules:{COMMENT.if_we_got_here_its_a_Dataset_not_a_DataArray} + return 2 if not _coconut.isinstance(arr, _coconut.abc.Sequence) or _coconut.isinstance(arr, (_coconut.str, _coconut.bytes)): return 0 if _coconut.len(arr) == 0: @@ -2016,27 +2066,25 @@ def _coconut_expand_arr(arr, new_dims): arr = [arr] return arr def _coconut_concatenate(arrs, axis): - matconcat = None for a in arrs: if _coconut.hasattr(a.__class__, "__matconcat__"): - matconcat = a.__class__.__matconcat__ - break - a_module = _coconut_get_base_module(a) - if a_module in _coconut.pandas_numpy_modules: - from pandas import concat as matconcat - break - if a_module in _coconut.jax_numpy_modules: - from jax.numpy import concatenate as matconcat - break - if a_module in _coconut.numpy_modules: - matconcat = _coconut.numpy.concatenate - break - if matconcat is not None: - return matconcat(arrs, axis=axis) + return a.__class__.__matconcat__(arrs, axis=axis) + arr_modules = [_coconut_get_base_module(a) for a in arrs] + if any(mod in _coconut.xarray_modules for mod in arr_modules): + return _coconut_concatenate([(_coconut_xarray_to_pandas(a) if mod in _coconut.xarray_modules else a) for a, mod in _coconut.zip(arrs, arr_modules)], axis).to_xarray() + if any(mod in _coconut.pandas_modules for mod in arr_modules): + import pandas + return pandas.concat(arrs, axis=axis) + if any(mod in _coconut.jax_numpy_modules for mod in arr_modules): + import jax.numpy + return jax.numpy.concatenate(arrs, axis=axis) + if any(mod in _coconut.numpy_modules for mod in arr_modules): + return _coconut.numpy.concatenate(arrs, axis=axis) if not axis: return _coconut.list(_coconut.itertools.chain.from_iterable(arrs)) return [_coconut_concatenate(rows, axis - 1) for rows in _coconut.zip(*arrs)] -def _coconut_multi_dim_arr(arrs, dim): +def _coconut_arr_concat_op(dim, *arrs): + """Coconut multi-dimensional array concatenation operator.""" arr_dims = [_coconut_ndim(a) for a in arrs] arrs = [_coconut_expand_arr(a, dim - d) if d < dim else a for a, d in _coconut.zip(arrs, arr_dims)] arr_dims.append(dim) @@ -2046,7 +2094,7 @@ def _coconut_call_or_coefficient(func, *args): if _coconut.callable(func): return func(*args) if not _coconut.isinstance(func, (_coconut.int, _coconut.float, _coconut.complex)) and _coconut_get_base_module(func) not in _coconut.numpy_modules: - raise _coconut.TypeError("implicit function application and coefficient syntax only supported for Callable, int, float, complex, and numpy objects") + raise _coconut.TypeError("first object in implicit function application and coefficient syntax must be Callable, int, float, complex, or numpy") func = func for x in args: func = func * x{COMMENT.no_times_equals_to_avoid_modification} @@ -2184,4 +2232,4 @@ class _coconut_SupportsInv(_coconut.typing.Protocol): {def_async_map} {def_aliases} _coconut_self_match_types = {self_match_types} -_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file} +_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_fmap, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, fmap, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file} diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index 64a0ff84f..ffbbf6151 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -88,6 +88,7 @@ univ_open, ensure_dir, get_clock_time, + literal_lines, ) from coconut.terminal import ( logger, @@ -123,16 +124,15 @@ unwrapper, incremental_cache_limit, incremental_mode_cache_successes, - adaptive_reparse_usage_weight, use_adaptive_any_of, disable_incremental_for_len, coconut_cache_dir, - use_adaptive_if_available, use_fast_pyparsing_reprs, save_new_cache_items, cache_validation_info, require_cache_clear_frac, reverse_any_of, + all_keywords, ) from coconut.exceptions import ( CoconutException, @@ -264,6 +264,19 @@ class ComputationNode(object): """A single node in the computation graph.""" __slots__ = ("action", "original", "loc", "tokens") pprinting = False + override_original = None + add_to_loc = 0 + + @classmethod + @contextmanager + def using_overrides(cls): + override_original, cls.override_original = cls.override_original, None + add_to_loc, cls.add_to_loc = cls.add_to_loc, 0 + try: + yield + finally: + cls.override_original = override_original + cls.add_to_loc = add_to_loc def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_one_token=False, greedy=False, trim_arity=True): """Create a ComputionNode to return from a parse action. @@ -281,8 +294,8 @@ def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_o self.action = _trim_arity(action) else: self.action = action - self.original = original - self.loc = loc + self.original = original if self.override_original is None else self.override_original + self.loc = self.add_to_loc + loc self.tokens = tokens if greedy: return self.evaluate() @@ -391,12 +404,38 @@ def add_action(item, action, make_copy=None): return item.addParseAction(action) -def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_tokens=None, trim_arity=None, make_copy=None, **kwargs): +def get_func_args(func): + """Inspect a function to determine its argument names.""" + if PY2: + return inspect.getargspec(func)[0] + else: + return inspect.getfullargspec(func)[0] + + +def should_trim_arity(func): + """Determine if we need to call _trim_arity on func.""" + annotation = getattr(func, "trim_arity", None) + if annotation is not None: + return annotation + try: + func_args = get_func_args(func) + except TypeError: + return True + if not func_args: + return True + if func_args[0] == "self": + func_args.pop(0) + if func_args[:3] == ["original", "loc", "tokens"]: + return False + return True + + +def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_arguments=None, trim_arity=None, make_copy=None, **kwargs): """Set the parse action for the given item to create a node in the computation graph.""" - if ignore_tokens is None: - ignore_tokens = getattr(action, "ignore_tokens", False) - # if ignore_tokens, then we can just pass in the computation graph and have it be ignored - if not ignore_tokens and USE_COMPUTATION_GRAPH: + if ignore_arguments is None: + ignore_arguments = getattr(action, "ignore_arguments", False) + # if ignore_arguments, then we can just pass in the computation graph and have it be ignored + if not ignore_arguments and USE_COMPUTATION_GRAPH: # use the action's annotations to generate the defaults if ignore_no_tokens is None: ignore_no_tokens = getattr(action, "ignore_no_tokens", False) @@ -417,43 +456,14 @@ def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_to def final_evaluate_tokens(tokens): """Same as evaluate_tokens but should only be used once a parse is assured.""" + result = evaluate_tokens(tokens, is_final=True) + # clear packrat cache after evaluating tokens so error creation gets to see the cache clear_packrat_cache() - return evaluate_tokens(tokens, is_final=True) - - -@contextmanager -def adaptive_manager(item, original, loc, reparse=False): - """Manage the use of MatchFirst.setAdaptiveMode.""" - if reparse: - cleared_cache = clear_packrat_cache() - if cleared_cache is not True: - item.include_in_packrat_context = True - MatchFirst.setAdaptiveMode(False, usage_weight=adaptive_reparse_usage_weight) - try: - yield - finally: - MatchFirst.setAdaptiveMode(False, usage_weight=1) - if cleared_cache is not True: - item.include_in_packrat_context = False - else: - MatchFirst.setAdaptiveMode(True) - try: - yield - except Exception as exc: - if DEVELOP: - logger.log("reparsing due to:", exc) - logger.record_stat("adaptive", False) - else: - if DEVELOP: - logger.record_stat("adaptive", True) - finally: - MatchFirst.setAdaptiveMode(False) + return result def final(item): """Collapse the computation graph upon parsing the given item.""" - if SUPPORTS_ADAPTIVE and use_adaptive_if_available: - item = Wrap(item, adaptive_manager, greedy=True) # evaluate_tokens expects a computation graph, so we just call add_action directly return add_action(trace(item), final_evaluate_tokens) @@ -489,11 +499,35 @@ def force_reset_packrat_cache(): @contextmanager -def parsing_context(inner_parse=True): +def parsing_context(inner_parse=None): """Context to manage the packrat cache across parse calls.""" - if not inner_parse: - yield - elif should_clear_cache(): + current_cache_matters = ( + inner_parse is not False + and ParserElement._packratEnabled + ) + new_cache_matters = ( + inner_parse is not True + and ParserElement._incrementalEnabled + and not ParserElement._incrementalWithResets + ) + will_clear_cache = ( + not ParserElement._incrementalEnabled + or ParserElement._incrementalWithResets + ) + if ( + current_cache_matters + and new_cache_matters + and ParserElement._incrementalWithResets + ): + incrementalWithResets, ParserElement._incrementalWithResets = ParserElement._incrementalWithResets, False + try: + yield + finally: + ParserElement._incrementalWithResets = incrementalWithResets + elif ( + current_cache_matters + and will_clear_cache + ): # store old packrat cache old_cache = ParserElement.packrat_cache old_cache_stats = ParserElement.packrat_cache_stats[:] @@ -507,13 +541,6 @@ def parsing_context(inner_parse=True): if logger.verbose: ParserElement.packrat_cache_stats[0] += old_cache_stats[0] ParserElement.packrat_cache_stats[1] += old_cache_stats[1] - # if we shouldn't clear the cache, but we're using incrementalWithResets, then do this to avoid clearing it - elif ParserElement._incrementalWithResets: - incrementalWithResets, ParserElement._incrementalWithResets = ParserElement._incrementalWithResets, False - try: - yield - finally: - ParserElement._incrementalWithResets = incrementalWithResets else: yield @@ -529,7 +556,7 @@ def prep_grammar(grammar, streamline=False): return grammar.parseWithTabs() -def parse(grammar, text, inner=True, eval_parse_tree=True): +def parse(grammar, text, inner=None, eval_parse_tree=True): """Parse text using grammar.""" with parsing_context(inner): result = prep_grammar(grammar).parseString(text) @@ -538,7 +565,7 @@ def parse(grammar, text, inner=True, eval_parse_tree=True): return result -def try_parse(grammar, text, inner=True, eval_parse_tree=True): +def try_parse(grammar, text, inner=None, eval_parse_tree=True): """Attempt to parse text using grammar else None.""" try: return parse(grammar, text, inner, eval_parse_tree) @@ -546,12 +573,12 @@ def try_parse(grammar, text, inner=True, eval_parse_tree=True): return None -def does_parse(grammar, text, inner=True): +def does_parse(grammar, text, inner=None): """Determine if text can be parsed using grammar.""" return try_parse(grammar, text, inner, eval_parse_tree=False) -def all_matches(grammar, text, inner=True, eval_parse_tree=True): +def all_matches(grammar, text, inner=None, eval_parse_tree=True): """Find all matches for grammar in text.""" with parsing_context(inner): for tokens, start, stop in prep_grammar(grammar).scanString(text): @@ -560,21 +587,21 @@ def all_matches(grammar, text, inner=True, eval_parse_tree=True): yield tokens, start, stop -def parse_where(grammar, text, inner=True): +def parse_where(grammar, text, inner=None): """Determine where the first parse is.""" for tokens, start, stop in all_matches(grammar, text, inner, eval_parse_tree=False): return start, stop return None, None -def match_in(grammar, text, inner=True): +def match_in(grammar, text, inner=None): """Determine if there is a match for grammar anywhere in text.""" start, stop = parse_where(grammar, text, inner) internal_assert((start is None) == (stop is None), "invalid parse_where results", (start, stop)) return start is not None -def transform(grammar, text, inner=True): +def transform(grammar, text, inner=None): """Transform text by replacing matches to grammar.""" with parsing_context(inner): result = prep_grammar(add_action(grammar, unpack)).transformString(text) @@ -753,7 +780,7 @@ def should_clear_cache(force=False): return True elif not ParserElement._packratEnabled: return False - elif SUPPORTS_INCREMENTAL and ParserElement._incrementalEnabled: + elif ParserElement._incrementalEnabled: if not in_incremental_mode(): return repeatedly_clear_incremental_cache if ( @@ -845,8 +872,8 @@ def get_cache_items_for(original, only_useful=False, exclude_stale=True): def get_highest_parse_loc(original): - """Get the highest observed parse location.""" - # find the highest observed parse location + """Get the highest observed parse location. + Note that there's no point in filtering for successes/failures, since we always see both at the same locations.""" highest_loc = 0 for lookup, _ in get_cache_items_for(original): loc = lookup[2] @@ -1139,7 +1166,7 @@ class Wrap(ParseElementEnhance): global_instance_counter = 0 inside = False - def __init__(self, item, wrapper, greedy=False, include_in_packrat_context=False): + def __init__(self, item, wrapper, greedy=False, include_in_packrat_context=True): super(Wrap, self).__init__(item) self.wrapper = wrapper self.greedy = greedy @@ -1179,7 +1206,7 @@ def parseImpl(self, original, loc, *args, **kwargs): reparse = False parse_loc = None while parse_loc is None: # lets wrapper catch errors to trigger a reparse - with self.wrapper(self, original, loc, **(dict(reparse=True) if reparse else {})): + with self.wrapper(original, loc, self, **(dict(reparse=True) if reparse else {})): with self.wrapped_context(): parse_loc, tokens = super(Wrap, self).parseImpl(original, loc, *args, **kwargs) if self.greedy: @@ -1198,10 +1225,14 @@ def __repr__(self): return self.wrapped_name -def handle_and_manage(item, handler, manager): +def manage(item, manager, include_in_packrat_context, greedy=True): + """Attach a manager to the given parse item.""" + return Wrap(item, manager, include_in_packrat_context=include_in_packrat_context, greedy=greedy) + + +def handle_and_manage(item, handler, manager, **kwargs): """Attach a handler and a manager to the given parse item.""" - new_item = attach(item, handler) - return Wrap(new_item, manager, greedy=True) + return manage(attach(item, handler), manager, **kwargs) def disable_inside(item, *elems, **kwargs): @@ -1215,7 +1246,7 @@ def disable_inside(item, *elems, **kwargs): level = [0] # number of wrapped items deep we are; in a list to allow modification @contextmanager - def manage_item(self, original, loc): + def manage_item(original, loc, self): level[0] += 1 try: yield @@ -1225,7 +1256,7 @@ def manage_item(self, original, loc): yield Wrap(item, manage_item, include_in_packrat_context=True) @contextmanager - def manage_elem(self, original, loc): + def manage_elem(original, loc, self): if level[0] == 0 if not _invert else level[0] > 0: yield else: @@ -1259,7 +1290,7 @@ def invalid_syntax(item, msg, **kwargs): def invalid_syntax_handle(loc, tokens): raise CoconutDeferredSyntaxError(msg, loc) - return attach(item, invalid_syntax_handle, ignore_tokens=True, **kwargs) + return attach(item, invalid_syntax_handle, ignore_arguments=True, **kwargs) def skip_to_in_line(item): @@ -1303,7 +1334,7 @@ def regex_item(regex, options=None): def fixto(item, output): """Force an item to result in a specific output.""" - return attach(item, replaceWith(output), ignore_tokens=True) + return attach(item, replaceWith(output), ignore_arguments=True) def addspace(item): @@ -1414,9 +1445,6 @@ def stores_loc_action(loc, tokens): return str(loc) -stores_loc_action.ignore_tokens = True - - always_match = Empty() stores_loc_item = attach(always_match, stores_loc_action) @@ -1430,12 +1458,15 @@ def disallow_keywords(kwds, with_suffix=""): return regex_item(r"(?!" + "|".join(to_disallow) + r")").suppress() -def disambiguate_literal(literal, not_literals): +def disambiguate_literal(literal, not_literals, fixesto=None): """Get an item that matchesl literal and not any of not_literals.""" - return regex_item( + item = regex_item( r"(?!" + "|".join(re.escape(s) for s in not_literals) + ")" + re.escape(literal) ) + if fixesto is not None: + item = fixto(item, fixesto) + return item def any_keyword_in(kwds): @@ -1515,12 +1546,16 @@ def any_len_perm_at_least_one(*elems, **kwargs): return any_len_perm_with_one_of_each_group(*groups_and_elems) -def caseless_literal(literalstr, suppress=False): +def caseless_literal(literalstr, suppress=False, disambiguate=False): """Version of CaselessLiteral that always parses to the given literalstr.""" + out = CaselessLiteral(literalstr) if suppress: - return CaselessLiteral(literalstr).suppress() + out = out.suppress() else: - return fixto(CaselessLiteral(literalstr), literalstr) + out = fixto(out, literalstr) + if disambiguate: + out = disallow_keywords(k for k in all_keywords if k.startswith((literalstr[0].lower(), literalstr[0].upper()))) + out + return out # ----------------------------------------------------------------------------------------------------------------------- @@ -1806,13 +1841,13 @@ def collapse_indents(indentation): def is_blank(line): """Determine whether a line is blank.""" line, _ = rem_and_count_indents(rem_comment(line)) - return line.strip() == "" + return not line or line.isspace() def final_indentation_level(code): """Determine the final indentation level of the given code.""" level = 0 - for line in code.splitlines(): + for line in literal_lines(code): leading_indent, _, trailing_indent = split_leading_trailing_indent(line) level += ind_change(leading_indent) + ind_change(trailing_indent) return level @@ -1883,32 +1918,6 @@ def literal_eval(py_code): raise CoconutInternalException("failed to literal eval", py_code) -def get_func_args(func): - """Inspect a function to determine its argument names.""" - if PY2: - return inspect.getargspec(func)[0] - else: - return inspect.getfullargspec(func)[0] - - -def should_trim_arity(func): - """Determine if we need to call _trim_arity on func.""" - annotation = getattr(func, "trim_arity", None) - if annotation is not None: - return annotation - try: - func_args = get_func_args(func) - except TypeError: - return True - if not func_args: - return True - if func_args[0] == "self": - func_args.pop(0) - if func_args[:3] == ["original", "loc", "tokens"]: - return False - return True - - def sequential_split(inputstr, splits): """Slice off parts of inputstr by sequential splits.""" out = [inputstr] @@ -2011,7 +2020,7 @@ def sub_all(inputstr, regexes, replacements): # ----------------------------------------------------------------------------------------------------------------------- -# PYTEST: +# EXTRAS: # ----------------------------------------------------------------------------------------------------------------------- @@ -2042,3 +2051,32 @@ def pytest_rewrite_asserts(code, module_name=reserved_prefix + "_pytest_module") rewrite_asserts(tree, module_name) fixed_tree = ast.fix_missing_locations(FixPytestNames().visit(tree)) return ast.unparse(fixed_tree) + + +@contextmanager +def adaptive_manager(original, loc, item, reparse=False): + """Manage the use of MatchFirst.setAdaptiveMode.""" + if reparse: + cleared_cache = clear_packrat_cache() + if cleared_cache is not True: + item.include_in_packrat_context = True + MatchFirst.setAdaptiveMode(False, usage_weight=10) + try: + yield + finally: + MatchFirst.setAdaptiveMode(False, usage_weight=1) + if cleared_cache is not True: + item.include_in_packrat_context = False + else: + MatchFirst.setAdaptiveMode(True) + try: + yield + except Exception as exc: + if DEVELOP: + logger.log("reparsing due to:", exc) + logger.record_stat("adaptive", False) + else: + if DEVELOP: + logger.record_stat("adaptive", True) + finally: + MatchFirst.setAdaptiveMode(False) diff --git a/coconut/constants.py b/coconut/constants.py index a6c276a8e..c9d7d095a 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -37,7 +37,7 @@ def fixpath(path): return os.path.normpath(os.path.realpath(os.path.expanduser(path))) -def get_bool_env_var(env_var, default=False): +def get_bool_env_var(env_var, default=None): """Get a boolean from an environment variable.""" boolstr = os.getenv(env_var, "").lower() if boolstr in ("true", "yes", "on", "1", "t"): @@ -84,7 +84,7 @@ def get_path_env_var(env_var, default): PY311 = sys.version_info >= (3, 11) PY312 = sys.version_info >= (3, 12) IPY = ( - PY35 + PY36 and (PY37 or not PYPY) and not (PYPY and WINDOWS) and sys.version_info[:2] != (3, 7) @@ -146,8 +146,7 @@ def get_path_env_var(env_var, default): # note that _parseIncremental produces much smaller caches use_incremental_if_available = False -use_adaptive_if_available = False # currently broken -adaptive_reparse_usage_weight = 10 +use_line_by_line_parser = False # these only apply to use_incremental_if_available, not compiler.util.enable_incremental_parsing() default_incremental_cache_size = None @@ -180,7 +179,10 @@ def get_path_env_var(env_var, default): sys.setrecursionlimit(default_recursion_limit) # modules that numpy-like arrays can live in -pandas_numpy_modules = ( +xarray_modules = ( + "xarray", +) +pandas_modules = ( "pandas", ) jax_numpy_modules = ( @@ -190,7 +192,8 @@ def get_path_env_var(env_var, default): "numpy", "torch", ) + ( - pandas_numpy_modules + xarray_modules + + pandas_modules + jax_numpy_modules ) @@ -315,12 +318,11 @@ def get_path_env_var(env_var, default): ) tabideal = 4 # spaces to indent code for displaying - -taberrfmt = 2 # spaces to indent exceptions - justify_len = 79 # ideal line length +taberrfmt = 2 # spaces to indent exceptions min_squiggles_in_err_msg = 1 +default_max_err_msg_lines = 10 # for pattern-matching default_matcher_style = "python warn" @@ -437,11 +439,11 @@ def get_path_env_var(env_var, default): r"locals", r"globals", r"(py_)?super", - r"(typing\.)?cast", - r"(sys\.)?exc_info", - r"(sys\.)?_getframe", - r"(sys\.)?_current_frames", - r"(sys\.)?_current_exceptions", + r"cast", + r"exc_info", + r"sys\.[a-zA-Z0-9_.]+", + r"traceback\.[a-zA-Z0-9_.]+", + r"typing\.[a-zA-Z0-9_.]+", ) py3_to_py2_stdlib = { @@ -574,45 +576,34 @@ def get_path_env_var(env_var, default): ) python_builtins = ( - '__import__', 'abs', 'all', 'any', 'bin', 'bool', 'bytearray', - 'breakpoint', 'bytes', 'chr', 'classmethod', 'compile', 'complex', - 'delattr', 'dict', 'dir', 'divmod', 'enumerate', 'eval', 'filter', - 'float', 'format', 'frozenset', 'getattr', 'globals', 'hasattr', - 'hash', 'hex', 'id', 'input', 'int', 'isinstance', 'issubclass', - 'iter', 'len', 'list', 'locals', 'map', 'max', 'memoryview', - 'min', 'next', 'object', 'oct', 'open', 'ord', 'pow', 'print', - 'property', 'range', 'repr', 'reversed', 'round', 'set', 'setattr', - 'slice', 'sorted', 'staticmethod', 'str', 'sum', 'super', 'tuple', - 'type', 'vars', 'zip', - 'Ellipsis', 'NotImplemented', - 'ArithmeticError', 'AssertionError', 'AttributeError', - 'BaseException', 'BufferError', 'BytesWarning', 'DeprecationWarning', - 'EOFError', 'EnvironmentError', 'Exception', 'FloatingPointError', - 'FutureWarning', 'GeneratorExit', 'IOError', 'ImportError', - 'ImportWarning', 'IndentationError', 'IndexError', 'KeyError', - 'KeyboardInterrupt', 'LookupError', 'MemoryError', 'NameError', - 'NotImplementedError', 'OSError', 'OverflowError', - 'PendingDeprecationWarning', 'ReferenceError', 'ResourceWarning', - 'RuntimeError', 'RuntimeWarning', 'StopIteration', - 'SyntaxError', 'SyntaxWarning', 'SystemError', 'SystemExit', - 'TabError', 'TypeError', 'UnboundLocalError', 'UnicodeDecodeError', - 'UnicodeEncodeError', 'UnicodeError', 'UnicodeTranslateError', - 'UnicodeWarning', 'UserWarning', 'ValueError', 'VMSError', - 'Warning', 'WindowsError', 'ZeroDivisionError', + "abs", "aiter", "all", "anext", "any", "ascii", + "bin", "bool", "breakpoint", "bytearray", "bytes", + "callable", "chr", "classmethod", "compile", "complex", + "delattr", "dict", "dir", "divmod", + "enumerate", "eval", "exec", + "filter", "float", "format", "frozenset", + "getattr", "globals", + "hasattr", "hash", "help", "hex", + "id", "input", "int", "isinstance", "issubclass", "iter", + "len", "list", "locals", + "map", "max", "memoryview", "min", + "next", + "object", "oct", "open", "ord", + "pow", "print", "property", + "range", "repr", "reversed", "round", + "set", "setattr", "slice", "sorted", "staticmethod", "str", "sum", "super", + "tuple", "type", + "vars", + "zip", + "__import__", '__name__', '__file__', '__annotations__', '__debug__', - # we treat these as coconut_exceptions so the highlighter will always know about them: - # 'ExceptionGroup', 'BaseExceptionGroup', - # don't include builtins that aren't always made available by Coconut: - # 'BlockingIOError', 'ChildProcessError', 'ConnectionError', - # 'BrokenPipeError', 'ConnectionAbortedError', 'ConnectionRefusedError', - # 'ConnectionResetError', 'FileExistsError', 'FileNotFoundError', - # 'InterruptedError', 'IsADirectoryError', 'NotADirectoryError', - # 'PermissionError', 'ProcessLookupError', 'TimeoutError', - # 'StopAsyncIteration', 'ModuleNotFoundError', 'RecursionError', - # 'EncodingWarning', +) + +python_exceptions = ( + "BaseException", "BaseExceptionGroup", "GeneratorExit", "KeyboardInterrupt", "SystemExit", "Exception", "ArithmeticError", "FloatingPointError", "OverflowError", "ZeroDivisionError", "AssertionError", "AttributeError", "BufferError", "EOFError", "ExceptionGroup", "BaseExceptionGroup", "ImportError", "ModuleNotFoundError", "LookupError", "IndexError", "KeyError", "MemoryError", "NameError", "UnboundLocalError", "OSError", "BlockingIOError", "ChildProcessError", "ConnectionError", "BrokenPipeError", "ConnectionAbortedError", "ConnectionRefusedError", "ConnectionResetError", "FileExistsError", "FileNotFoundError", "InterruptedError", "IsADirectoryError", "NotADirectoryError", "PermissionError", "ProcessLookupError", "TimeoutError", "ReferenceError", "RuntimeError", "NotImplementedError", "RecursionError", "StopAsyncIteration", "StopIteration", "SyntaxError", "IndentationError", "TabError", "SystemError", "TypeError", "ValueError", "UnicodeError", "UnicodeDecodeError", "UnicodeEncodeError", "UnicodeTranslateError", "Warning", "BytesWarning", "DeprecationWarning", "EncodingWarning", "FutureWarning", "ImportWarning", "PendingDeprecationWarning", "ResourceWarning", "RuntimeWarning", "SyntaxWarning", "UnicodeWarning", "UserWarning", ) # ----------------------------------------------------------------------------------------------------------------------- @@ -640,11 +631,13 @@ def get_path_env_var(env_var, default): coconut_home = get_path_env_var(home_env_var, "~") -use_color = get_bool_env_var("COCONUT_USE_COLOR", None) +use_color_env_var = "COCONUT_USE_COLOR" error_color_code = "31" log_color_code = "93" default_style = "default" +fake_styles = ("none", "list") + prompt_histfile = get_path_env_var( "COCONUT_HISTORY_FILE", os.path.join(coconut_home, ".coconut_history"), @@ -793,6 +786,7 @@ def get_path_env_var(env_var, default): "flip", "const", "lift", + "lift_apart", "all_equal", "collectby", "mapreduce", @@ -837,12 +831,16 @@ def get_path_env_var(env_var, default): coconut_exceptions = ( "MatchError", - "ExceptionGroup", - "BaseExceptionGroup", ) -highlight_builtins = coconut_specific_builtins + interp_only_builtins -all_builtins = frozenset(python_builtins + coconut_specific_builtins + coconut_exceptions) +highlight_builtins = coconut_specific_builtins + interp_only_builtins + python_builtins +highlight_exceptions = coconut_exceptions + python_exceptions +all_builtins = frozenset( + python_builtins + + python_exceptions + + coconut_specific_builtins + + coconut_exceptions +) magic_methods = ( "__fmap__", @@ -942,7 +940,8 @@ def get_path_env_var(env_var, default): ("ipython", "py3;py<37"), ("ipython", "py==37"), ("ipython", "py==38"), - ("ipython", "py>=39"), + ("ipython", "py==39"), + ("ipython", "py>=310"), ("ipykernel", "py<3"), ("ipykernel", "py3;py<38"), ("ipykernel", "py38"), @@ -976,8 +975,8 @@ def get_path_env_var(env_var, default): ), "xonsh": ( ("xonsh", "py<36"), - ("xonsh", "py>=36;py<38"), - ("xonsh", "py38"), + ("xonsh", "py>=36;py<39"), + ("xonsh", "py39"), ), "dev": ( ("pre-commit", "py3"), @@ -997,10 +996,12 @@ def get_path_env_var(env_var, default): ("numpy", "py34;py<39"), ("numpy", "py39"), ("pandas", "py36"), + ("xarray", "py39"), ), "tests": ( ("pytest", "py<36"), - ("pytest", "py36"), + ("pytest", "py>=36;py<38"), + ("pytest", "py38"), "pexpect", ), } @@ -1013,35 +1014,36 @@ def get_path_env_var(env_var, default): "jupyter": (1, 0), "types-backports": (0, 1), ("futures", "py<3"): (3, 4), - ("backports.functools-lru-cache", "py<3"): (1, 6), ("argparse", "py<27"): (1, 4), "pexpect": (4,), ("trollius", "py<3;cpy"): (2, 2), "requests": (2, 31), ("numpy", "py39"): (1, 26), + ("xarray", "py39"): (2024,), ("dataclasses", "py==36"): (0, 8), ("aenum", "py<34"): (3, 1, 15), - "pydata-sphinx-theme": (0, 14), + "pydata-sphinx-theme": (0, 15), "myst-parser": (2,), "sphinx": (7,), - "mypy[python2]": (1, 7), + "mypy[python2]": (1, 8), ("jupyter-console", "py37"): (6, 6), ("typing", "py<35"): (3, 10), - ("typing_extensions", "py>=38"): (4, 8), + ("typing_extensions", "py>=38"): (4, 9), ("ipykernel", "py38"): (6,), ("jedi", "py39"): (0, 19), ("pygments", "py>=39"): (2, 17), - ("xonsh", "py38"): (0, 14), - ("pytest", "py36"): (7,), + ("xonsh", "py39"): (0, 15), + ("pytest", "py38"): (8,), ("async_generator", "py35"): (1, 10), ("exceptiongroup", "py37;py<311"): (1,), - ("ipython", "py>=39"): (8, 18), + ("ipython", "py>=310"): (8, 22), "py-spy": (0, 3), } pinned_min_versions = { # don't upgrade these; they break on Python 3.9 ("numpy", "py34;py<39"): (1, 18), + ("ipython", "py==39"): (8, 18), # don't upgrade these; they break on Python 3.8 ("ipython", "py==38"): (8, 12), # don't upgrade these; they break on Python 3.7 @@ -1049,10 +1051,11 @@ def get_path_env_var(env_var, default): ("typing_extensions", "py==37"): (4, 7), # don't upgrade these; they break on Python 3.6 ("anyio", "py36"): (3,), - ("xonsh", "py>=36;py<38"): (0, 11), + ("xonsh", "py>=36;py<39"): (0, 11), ("pandas", "py36"): (1,), ("jupyter-client", "py36"): (7, 1, 2), ("typing_extensions", "py==36"): (4, 1), + ("pytest", "py>=36;py<38"): (7,), # don't upgrade these; they break on Python 3.5 ("ipykernel", "py3;py<38"): (5, 5), ("ipython", "py3;py<37"): (7, 9), @@ -1081,6 +1084,7 @@ def get_path_env_var(env_var, default): "watchdog": (0, 10), "papermill": (1, 2), ("numpy", "py<3;cpy"): (1, 16), + ("backports.functools-lru-cache", "py<3"): (1, 6), # don't upgrade this; it breaks with old IPython versions ("jedi", "py<39"): (0, 17), # Coconut requires pyparsing 2 @@ -1235,6 +1239,8 @@ def get_path_env_var(env_var, default): "coconut_pycon = coconut.highlighter:CoconutPythonConsoleLexer", ) +setuptools_distribution_names = ("bdist", "sdist") + requests_sleep_times = (0, 0.1, 0.2, 0.3, 0.4, 1) # ----------------------------------------------------------------------------------------------------------------------- @@ -1267,8 +1273,11 @@ def get_path_env_var(env_var, default): "coconut3", ) -py_syntax_version = 3 mimetype = "text/x-python3" +codemirror_mode = { + "name": "ipython", + "version": 3, +} all_keywords = keyword_vars + const_vars + reserved_vars @@ -1276,6 +1285,9 @@ def get_path_env_var(env_var, default): enabled_xonsh_modes = ("single",) +# 1 is safe, 2 seems to work okay, and 3 breaks stuff like '"""\n(\n)\n"""' +num_assemble_logical_lines_tries = 1 + # ----------------------------------------------------------------------------------------------------------------------- # DOCUMENTATION CONSTANTS: # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/exceptions.py b/coconut/exceptions.py index 341ef3831..89843a428 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -30,14 +30,16 @@ taberrfmt, report_this_text, min_squiggles_in_err_msg, + default_max_err_msg_lines, ) from coconut.util import ( pickleable_obj, clip, - logical_lines, + literal_lines, clean, get_displayable_target, normalize_newlines, + highlight, ) # ----------------------------------------------------------------------------------------------------------------------- @@ -88,7 +90,6 @@ class CoconutException(BaseCoconutException, Exception): class CoconutSyntaxError(CoconutException): """Coconut SyntaxError.""" - point_to_endpoint = False argnames = ("message", "source", "point", "ln", "extra", "endpoint", "filename") def __init__(self, message, source=None, point=None, ln=None, extra=None, endpoint=None, filename=None): @@ -100,21 +101,31 @@ def kwargs(self): """Get the arguments as keyword arguments.""" return dict(zip(self.argnames, self.args)) + point_to_endpoint = False + max_err_msg_lines = default_max_err_msg_lines + + def set_formatting(self, point_to_endpoint=None, max_err_msg_lines=None): + """Sets formatting values.""" + if point_to_endpoint is not None: + self.point_to_endpoint = point_to_endpoint + if max_err_msg_lines is not None: + self.max_err_msg_lines = max_err_msg_lines + return self + def message(self, message, source, point, ln, extra=None, endpoint=None, filename=None): """Creates a SyntaxError-like message.""" - if message is None: - message = "parsing failed" + message_parts = ["parsing failed" if message is None else message] if extra is not None: - message += " (" + str(extra) + ")" + message_parts += [" (", str(extra), ")"] if ln is not None: - message += " (line " + str(ln) + message_parts += [" (line ", str(ln)] if filename is not None: - message += " in " + repr(filename) - message += ")" + message_parts += [" in ", repr(filename)] + message_parts += [")"] if source: if point is None: for line in source.splitlines(): - message += "\n" + " " * taberrfmt + clean(line) + message_parts += ["\n", " " * taberrfmt, clean(line)] else: source = normalize_newlines(source) point = clip(point, 0, len(source)) @@ -129,7 +140,7 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam point_ind = getcol(point, source) - 1 endpoint_ind = getcol(endpoint, source) - 1 - source_lines = tuple(logical_lines(source, keep_newlines=True)) + source_lines = tuple(literal_lines(source, keep_newlines=True)) # walk the endpoint line back until it points to real text while endpoint_ln > point_ln and not "".join(source_lines[endpoint_ln - 1:endpoint_ln]).strip(): @@ -153,23 +164,25 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam point_ind = clip(point_ind, 0, len(part)) endpoint_ind = clip(endpoint_ind, point_ind, len(part)) - message += "\n" + " " * taberrfmt + part + # add code to message, highlighting part only at end so as not to change len(part) + message_parts += ["\n", " " * taberrfmt, highlight(part)] - if point_ind > 0 or endpoint_ind > 0: - err_len = endpoint_ind - point_ind - message += "\n" + " " * (taberrfmt + point_ind) + # add squiggles to message + err_len = endpoint_ind - point_ind + if (point_ind > 0 or endpoint_ind > 0) and err_len < len(part): + message_parts += ["\n", " " * (taberrfmt + point_ind)] if err_len <= min_squiggles_in_err_msg: if not self.point_to_endpoint: - message += "^" - message += "~" * err_len # err_len ~'s when there's only an extra char in one spot + message_parts += ["^"] + message_parts += ["~" * err_len] # err_len ~'s when there's only an extra char in one spot if self.point_to_endpoint: - message += "^" + message_parts += ["^"] else: - message += ( - ("^" if not self.point_to_endpoint else "\\") - + "~" * (err_len - 1) # err_len-1 ~'s when there's an extra char at the start and end - + ("^" if self.point_to_endpoint else "/" if endpoint_ind < len(part) else "|") - ) + message_parts += [ + ("^" if not self.point_to_endpoint else "\\"), + "~" * (err_len - 1), # err_len-1 ~'s when there's an extra char at the start and end + ("^" if self.point_to_endpoint else "/" if endpoint_ind < len(part) else "|"), + ] # multi-line error message else: @@ -182,20 +195,35 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam max_line_len = max(len(line) for line in lines) - message += "\n" + " " * (taberrfmt + point_ind) + # add top squiggles + message_parts += ["\n", " " * (taberrfmt + point_ind)] if point_ind >= len(lines[0]): - message += "|" + message_parts += ["|"] else: - message += "/" + "~" * (len(lines[0]) - point_ind - 1) - message += "~" * (max_line_len - len(lines[0])) + "\n" - for line in lines: - message += "\n" + " " * taberrfmt + line - message += ( - "\n\n" + " " * taberrfmt + "~" * endpoint_ind - + ("^" if self.point_to_endpoint else "/" if 0 < endpoint_ind < len(lines[-1]) else "|") - ) + message_parts += ["/", "~" * (len(lines[0]) - point_ind - 1)] + message_parts += ["~" * (max_line_len - len(lines[0])), "\n"] + + # add code, highlighting all of it together + code_parts = [] + if len(lines) > self.max_err_msg_lines: + for i in range(self.max_err_msg_lines // 2): + code_parts += ["\n", " " * taberrfmt, lines[i]] + code_parts += ["\n", " " * (taberrfmt // 2), "..."] + for i in range(len(lines) - self.max_err_msg_lines // 2, len(lines)): + code_parts += ["\n", " " * taberrfmt, lines[i]] + else: + for line in lines: + code_parts += ["\n", " " * taberrfmt, line] + message_parts += highlight("".join(code_parts)) - return message + # add bottom squiggles + message_parts += [ + "\n\n", + " " * taberrfmt + "~" * endpoint_ind, + ("^" if self.point_to_endpoint else "/" if 0 < endpoint_ind < len(lines[-1]) else "|"), + ] + + return "".join(message_parts) def syntax_err(self): """Creates a SyntaxError.""" @@ -217,11 +245,6 @@ def syntax_err(self): err.filename = filename return err - def set_point_to_endpoint(self, point_to_endpoint): - """Sets whether to point to the endpoint.""" - self.point_to_endpoint = point_to_endpoint - return self - class CoconutStyleError(CoconutSyntaxError): """Coconut --strict error.""" diff --git a/coconut/highlighter.py b/coconut/highlighter.py index a12686a06..9bf2b1c71 100644 --- a/coconut/highlighter.py +++ b/coconut/highlighter.py @@ -19,10 +19,12 @@ from coconut.root import * # NOQA +from pygments import highlight from pygments.lexers import Python3Lexer, PythonConsoleLexer from pygments.token import Text, Operator, Keyword, Name, Number from pygments.lexer import words, bygroups from pygments.util import shebang_matches +from pygments.formatters import Terminal256Formatter from coconut.constants import ( highlight_builtins, @@ -34,9 +36,13 @@ shebang_regex, magic_methods, template_ext, - coconut_exceptions, + highlight_exceptions, main_prompt, + style_env_var, + default_style, + fake_styles, ) +from coconut.terminal import logger # ----------------------------------------------------------------------------------------------------------------------- # LEXERS: @@ -94,7 +100,7 @@ class CoconutLexer(Python3Lexer): ] tokens["builtins"] += [ (words(highlight_builtins, suffix=r"\b"), Name.Builtin), - (words(coconut_exceptions, suffix=r"\b"), Name.Exception), + (words(highlight_exceptions, suffix=r"\b"), Name.Exception), ] tokens["numbers"] = [ (r"0b[01_]+", Number.Integer), @@ -113,3 +119,14 @@ def __init__(self, stripnl=False, stripall=False, ensurenl=True, tabsize=tabidea def analyse_text(text): return shebang_matches(text, shebang_regex) + + +def highlight_coconut_for_terminal(code): + """Highlight Coconut code for the terminal.""" + style = os.getenv(style_env_var, default_style) + if style not in fake_styles: + try: + return highlight(code, CoconutLexer(), Terminal256Formatter(style=style)) + except Exception: + logger.log_exc() + return code diff --git a/coconut/icoconut/root.py b/coconut/icoconut/root.py index 0b0cb77f9..7fd1d968c 100644 --- a/coconut/icoconut/root.py +++ b/coconut/icoconut/root.py @@ -34,7 +34,7 @@ ) from coconut.constants import ( PY311, - py_syntax_version, + codemirror_mode, mimetype, version_banner, tutorial_url, @@ -43,15 +43,17 @@ conda_build_env_var, coconut_kernel_kwargs, default_whitespace_chars, + num_assemble_logical_lines_tries, ) from coconut.terminal import logger from coconut.util import override, memoize_with_exceptions, replace_all from coconut.compiler import Compiler -from coconut.compiler.util import should_indent +from coconut.compiler.util import should_indent, paren_change from coconut.command.util import Runner try: from IPython.core.inputsplitter import IPythonInputSplitter + from IPython.core.inputtransformer import CoroutineInputTransformer from IPython.core.interactiveshell import InteractiveShellABC from IPython.core.compilerop import CachingCompiler from IPython.terminal.embed import InteractiveShellEmbed @@ -108,6 +110,10 @@ def syntaxerr_memoized_parse_block(code): # KERNEL: # ----------------------------------------------------------------------------------------------------------------------- +if papermill_translators is not None: + papermill_translators.register("coconut", PythonTranslator) + + if LOAD_MODULE: COMPILER.warm_up(enable_incremental_mode=True) @@ -154,8 +160,8 @@ class CoconutSplitter(IPythonInputSplitter, object): def __init__(self, *args, **kwargs): """Version of __init__ that sets up Coconut code compilation.""" super(CoconutSplitter, self).__init__(*args, **kwargs) - self._original_compile = self._compile - self._compile = self._coconut_compile + self._original_compile, self._compile = self._compile, self._coconut_compile + self.assemble_logical_lines = self._coconut_assemble_logical_lines() def _coconut_compile(self, source, *args, **kwargs): """Version of _compile that checks Coconut code. @@ -170,6 +176,60 @@ def _coconut_compile(self, source, *args, **kwargs): else: return True + @staticmethod + @CoroutineInputTransformer.wrap + def _coconut_assemble_logical_lines(): + """Version of assemble_logical_lines() that respects strings/parentheses/brackets/braces.""" + line = "" + while True: + line = (yield line) + if not line or line.isspace(): + continue + + parts = [] + level = 0 + while line is not None: + + # get no_strs_line + no_strs_line = None + while no_strs_line is None: + no_strs_line = line.strip() + if no_strs_line: + no_strs_line = COMPILER.remove_strs(no_strs_line) + if no_strs_line is None: + # if we're in the middle of a string, fetch a new line + for _ in range(num_assemble_logical_lines_tries): + new_line = (yield None) + if new_line is not None: + break + if new_line is None: + # if we're not able to build a no_strs_line, we should stop doing line joining + level = 0 + no_strs_line = "" + break + else: + line += new_line + + # update paren level + level += paren_change(no_strs_line) + + # put line in parts and break if done + if level < 0: + parts.append(line) + elif no_strs_line.endswith("\\"): + parts.append(line[:-1]) + else: + parts.append(line) + break + + # if we're not done, fetch a new line + for _ in range(num_assemble_logical_lines_tries): + line = (yield None) + if line is not None: + break + + line = ''.join(parts) + INTERACTIVE_SHELL_CODE = ''' input_splitter = CoconutSplitter(line_input_checker=True) input_transformer_manager = CoconutSplitter(line_input_checker=False) @@ -247,10 +307,7 @@ class CoconutKernel(IPythonKernel, object): "name": "coconut", "version": VERSION, "mimetype": mimetype, - "codemirror_mode": { - "name": "ipython", - "version": py_syntax_version, - }, + "codemirror_mode": codemirror_mode, "pygments_lexer": "coconut", "file_extension": code_exts[0], } @@ -293,6 +350,3 @@ class CoconutKernelApp(IPKernelApp, object): classes = IPKernelApp.classes + [CoconutKernel, CoconutShell] kernel_class = CoconutKernel subcommands = {} - - if papermill_translators is not None: - papermill_translators.register("coconut", PythonTranslator) diff --git a/coconut/root.py b/coconut/root.py index 2d622b4d8..44fe2b5c8 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -23,7 +23,7 @@ # VERSION: # ----------------------------------------------------------------------------------------------------------------------- -VERSION = "3.0.4" +VERSION = "3.1.0" VERSION_NAME = None # False for release, int >= 1 for develop DEVELOP = False @@ -61,7 +61,7 @@ def _get_target_info(target): # if a new assignment is added below, a new builtins import should be added alongside it _base_py3_header = r'''from builtins import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr -py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr +py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr _coconut_py_str, _coconut_py_super, _coconut_py_dict = str, super, dict from functools import wraps as _coconut_wraps exec("_coconut_exec = exec") @@ -69,8 +69,8 @@ def _get_target_info(target): # if a new assignment is added below, a new builtins import should be added alongside it _base_py2_header = r'''from __builtin__ import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long -py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr -_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict = raw_input, xrange, int, long, print, str, super, unicode, repr, dict +py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr +_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict, _coconut_py_bytes = raw_input, xrange, int, long, print, str, super, unicode, repr, dict, bytes from functools import wraps as _coconut_wraps from collections import Sequence as _coconut_Sequence from future_builtins import * @@ -96,6 +96,26 @@ def __instancecheck__(cls, inst): return _coconut.isinstance(inst, (_coconut_py_int, _coconut_py_long)) def __subclasscheck__(cls, subcls): return _coconut.issubclass(subcls, (_coconut_py_int, _coconut_py_long)) +class bytes(_coconut_py_bytes): + __slots__ = () + __doc__ = getattr(_coconut_py_bytes, "__doc__", "") + class __metaclass__(type): + def __instancecheck__(cls, inst): + return _coconut.isinstance(inst, _coconut_py_bytes) + def __subclasscheck__(cls, subcls): + return _coconut.issubclass(subcls, _coconut_py_bytes) + def __new__(self, *args): + if not args: + return b"" + elif _coconut.len(args) == 1: + if _coconut.isinstance(args[0], _coconut.int): + return b"\x00" * args[0] + elif _coconut.isinstance(args[0], _coconut.bytes): + return _coconut_py_bytes(args[0]) + else: + return b"".join(_coconut.chr(x) for x in args[0]) + else: + return args[0].encode(*args[1:]) class range(object): __slots__ = ("_xrange",) __doc__ = getattr(_coconut_py_xrange, "__doc__", "") diff --git a/coconut/terminal.py b/coconut/terminal.py index 11fb41cf7..3fe3cad9d 100644 --- a/coconut/terminal.py +++ b/coconut/terminal.py @@ -47,7 +47,8 @@ taberrfmt, use_packrat_parser, embed_on_internal_exc, - use_color, + use_color_env_var, + get_bool_env_var, error_color_code, log_color_code, ansii_escape, @@ -182,6 +183,16 @@ def logging(self): sys.stdout = old_stdout +def should_use_color(file=None): + """Determine if colors should be used for the given file object.""" + use_color = get_bool_env_var(use_color_env_var, default=None) + if use_color is not None: + return use_color + if get_bool_env_var("CLICOLOR_FORCE") or get_bool_env_var("FORCE_COLOR"): + return True + return file is not None and not isatty(file) + + # ----------------------------------------------------------------------------------------------------------------------- # LOGGER: # ----------------------------------------------------------------------------------------------------------------------- @@ -207,8 +218,10 @@ def __init__(self, other=None): self.patch_logging() @classmethod - def enable_colors(cls): + def enable_colors(cls, file=None): """Attempt to enable CLI colors.""" + if not should_use_color(file): + return False if not cls.colors_enabled: # necessary to resolve https://bugs.python.org/issue40134 try: @@ -216,6 +229,7 @@ def enable_colors(cls): except BaseException: logger.log_exc() cls.colors_enabled = True + return True def copy_from(self, other): """Copy other onto self.""" @@ -265,11 +279,8 @@ def display( else: raise CoconutInternalException("invalid logging level", level) - if use_color is False or (use_color is None and not isatty(file)): - color = None - if color: - self.enable_colors() + color = self.enable_colors(file) and color raw_message = " ".join(str(msg) for msg in messages) # if there's nothing to display but there is a sig, display the sig diff --git a/coconut/tests/constants_test.py b/coconut/tests/constants_test.py index eb3250b29..d60976c19 100644 --- a/coconut/tests/constants_test.py +++ b/coconut/tests/constants_test.py @@ -81,6 +81,7 @@ class TestConstants(unittest.TestCase): def test_defaults(self): assert constants.use_fast_pyparsing_reprs assert not constants.embed_on_internal_exc + assert constants.num_assemble_logical_lines_tries >= 1 def test_fixpath(self): assert os.path.basename(fixpath("CamelCase.py")) == "CamelCase.py" @@ -133,6 +134,7 @@ def test_targets(self): def test_tuples(self): assert isinstance(constants.indchars, tuple) assert isinstance(constants.comment_chars, tuple) + assert isinstance(constants.setuptools_distribution_names, tuple) # ----------------------------------------------------------------------------------------------------------------------- diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 2d7bf296e..155f8e17b 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -57,6 +57,7 @@ PY38, PY39, PY310, + PY312, CPYTHON, adaptive_any_of_env_var, reverse_any_of_env_var, @@ -89,8 +90,10 @@ os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" -# run fewer tests on Windows so appveyor doesn't time out -TEST_ALL = get_bool_env_var("COCONUT_TEST_ALL", not WINDOWS) +TEST_ALL = get_bool_env_var("COCONUT_TEST_ALL", ( + # run fewer tests on Windows so appveyor doesn't time out + not WINDOWS +)) # ----------------------------------------------------------------------------------------------------------------------- @@ -145,14 +148,12 @@ "INTERNAL ERROR", ) ignore_error_lines_with = ( - # ignore SyntaxWarnings containing assert_raises - "assert_raises(", - " raise ", + # ignore SyntaxWarnings containing assert_raises or raise + "raise", ) mypy_snip = "a: str = count()[0]" -mypy_snip_err_2 = '''error: Incompatible types in assignment (expression has type\n"int", variable has type "unicode")''' -mypy_snip_err_3 = '''error: Incompatible types in assignment (expression has type\n"int", variable has type "str")''' +mypy_snip_err = '''error: Incompatible types in assignment (expression has type''' mypy_args = ["--follow-imports", "silent", "--ignore-missing-imports", "--allow-redefinition"] @@ -580,6 +581,15 @@ def using_env_vars(env_vars): os.environ.update(old_env) +def list_kernel_names(): + """Get a list of installed jupyter kernels.""" + stdout, stderr, retcode = call_output(["jupyter", "kernelspec", "list"]) + if not stdout: + stdout, stderr = stderr, "" + assert not retcode and not stderr, stderr + return stdout + + # ----------------------------------------------------------------------------------------------------------------------- # RUNNERS: # ----------------------------------------------------------------------------------------------------------------------- @@ -657,8 +667,19 @@ def run_extras(**kwargs): call_python([os.path.join(dest, "extras.py")], assert_output=True, check_errors=False, stderr_first=True, **kwargs) -def run(args=[], agnostic_target=None, use_run_arg=False, convert_to_import=False, always_sys=False, manage_cache=True, **kwargs): +def run( + args=[], + agnostic_target=None, + use_run_arg=False, + run_directory=False, + convert_to_import=False, + always_sys=False, + manage_cache=True, + **kwargs # no comma for compat +): """Compiles and runs tests.""" + assert use_run_arg + run_directory < 2 + if agnostic_target is None: agnostic_args = args else: @@ -683,12 +704,22 @@ def run(args=[], agnostic_target=None, use_run_arg=False, convert_to_import=Fals if sys.version_info >= (3, 11): comp_311(args, **spec_kwargs) - comp_agnostic(agnostic_args, **kwargs) + if not run_directory: + comp_agnostic(agnostic_args, **kwargs) comp_sys(args, **kwargs) # do non-strict at the end so we get the non-strict header comp_non_strict(args, **kwargs) - if use_run_arg: + if run_directory: + _kwargs = kwargs.copy() + _kwargs["assert_output"] = True + _kwargs["stderr_first"] = True + comp_agnostic( + # remove --strict so that we run with the non-strict header + ["--run"] + [arg for arg in agnostic_args if arg != "--strict"], + **_kwargs + ) + elif use_run_arg: _kwargs = kwargs.copy() _kwargs["assert_output"] = True comp_runner(["--run"] + agnostic_args, **_kwargs) @@ -824,7 +855,7 @@ def test_target_3_snip(self): def test_universal_mypy_snip(self): call( ["coconut", "-c", mypy_snip, "--mypy"], - assert_output=mypy_snip_err_3, + assert_output=mypy_snip_err, check_errors=False, check_mypy=False, ) @@ -832,7 +863,7 @@ def test_universal_mypy_snip(self): def test_sys_mypy_snip(self): call( ["coconut", "--target", "sys", "-c", mypy_snip, "--mypy"], - assert_output=mypy_snip_err_3, + assert_output=mypy_snip_err, check_errors=False, check_mypy=False, ) @@ -840,7 +871,7 @@ def test_sys_mypy_snip(self): def test_no_wrap_mypy_snip(self): call( ["coconut", "--target", "sys", "--no-wrap", "-c", mypy_snip, "--mypy"], - assert_output=mypy_snip_err_3, + assert_output=mypy_snip_err, check_errors=False, check_mypy=False, ) @@ -857,7 +888,8 @@ def test_import_hook(self): with using_coconut(): auto_compilation(True) import runnable - reload(runnable) + if not PY2: # triggers a weird metaclass conflict + reload(runnable) assert runnable.success == "" def test_find_packages(self): @@ -933,13 +965,11 @@ def test_ipython_extension(self): ) def test_kernel_installation(self): + assert icoconut_custom_kernel_name in list_kernel_names() call(["coconut", "--jupyter"], assert_output=kernel_installation_msg) - stdout, stderr, retcode = call_output(["jupyter", "kernelspec", "list"]) - if not stdout: - stdout, stderr = stderr, "" - assert not retcode and not stderr, stderr + kernels = list_kernel_names() for kernel in (icoconut_custom_kernel_name,) + icoconut_default_kernel_names: - assert kernel in stdout + assert kernel in kernels if not WINDOWS and not PYPY: def test_jupyter_console(self): @@ -987,18 +1017,20 @@ def test_always_sys(self): def test_target(self): run(agnostic_target=(2 if PY2 else 3)) - def test_standalone(self): - run(["--standalone"]) + def test_no_tco(self): + run(["--no-tco"]) def test_package(self): run(["--package"]) - def test_no_tco(self): - run(["--no-tco"]) + # TODO: re-allow these once we figure out what's causing the strange unreproducible errors with them on py3.12 + if not PY312: + def test_standalone(self): + run(["--standalone"]) - if PY35: - def test_no_wrap(self): - run(["--no-wrap"]) + if PY35: + def test_no_wrap(self): + run(["--no-wrap"]) if TEST_ALL: if CPYTHON: @@ -1021,6 +1053,9 @@ def test_and(self): def test_run_arg(self): run(use_run_arg=True) + def test_run_dir(self): + run(run_directory=True) + if not PYPY and not PY26: def test_jobs_zero(self): run(["--jobs", "0"]) diff --git a/coconut/tests/src/cocotest/agnostic/__main__.coco b/coconut/tests/src/cocotest/agnostic/__main__.coco new file mode 100644 index 000000000..4df76fafc --- /dev/null +++ b/coconut/tests/src/cocotest/agnostic/__main__.coco @@ -0,0 +1,20 @@ +import sys +import os.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import cocotest +from cocotest.main import run_main + + +def main() -> bool: + print(".", end="", flush=True) # . + assert cocotest.__doc__ + assert run_main( + outer_MatchError=MatchError, + test_easter_eggs="--test-easter-eggs" in sys.argv, + ) is True + return True + + +if __name__ == "__main__": + assert main() is True diff --git a/coconut/tests/src/cocotest/agnostic/main.coco b/coconut/tests/src/cocotest/agnostic/main.coco index 97c9d3df7..56bfad400 100644 --- a/coconut/tests/src/cocotest/agnostic/main.coco +++ b/coconut/tests/src/cocotest/agnostic/main.coco @@ -90,7 +90,7 @@ def run_main(outer_MatchError, test_easter_eggs=False) -> bool: if using_tco: assert hasattr(tco_func, "_coconut_tco_func") assert tco_test() is True - if outer_MatchError.__module__ != "__main__": + if not outer_MatchError.__module__.endswith("__main__"): assert package_test(outer_MatchError) is True print_dot() # ....... diff --git a/coconut/tests/src/cocotest/agnostic/primary_1.coco b/coconut/tests/src/cocotest/agnostic/primary_1.coco index bc85179c7..b8e9a44d5 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_1.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_1.coco @@ -973,8 +973,8 @@ def primary_test_1() -> bool: def ret1() = 1 assert ret1() == 1 assert (.,2)(1) == (1, 2) == (1,.)(2) - assert [[];] == [] - assert [[];;] == [[]] + assert [[];] == [] == [;]([]) + assert [[];;] == [[]] == [;;]([]) assert [1;] == [1] == [[1];] assert [1;;] == [[1]] == [[1];;] assert [[[1]];;] == [[1]] == [[1;];;] @@ -1009,7 +1009,7 @@ def primary_test_1() -> bool: 5, 6 ;; 7, 8] == [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] a = [1,2 ;; 3,4] - assert [a; a] == [[1,2,1,2], [3,4,3,4]] + assert [a; a] == [[1,2,1,2], [3,4,3,4]] == [;](a, a) assert [a;; a] == [[1,2],[3,4],[1,2],[3,4]] == [*a, *a] assert [a ;;; a] == [[[1,2],[3,4]], [[1,2],[3,4]]] == [a, a] assert [a ;;;; a] == [[a], [a]] diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index b4e55fb2e..ee8ca556b 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -303,8 +303,8 @@ def primary_test_2() -> bool: a_dict = {"a": 1, "b": 2} a_dict |= {"a": 10, "c": 20} assert a_dict == {"a": 10, "b": 2, "c": 20} == {"a": 1, "b": 2} | {"a": 10, "c": 20} - assert ["abc" ; "def"] == ['abc', 'def'] - assert ["abc" ;; "def"] == [['abc'], ['def']] + assert ["abc" ; "def"] == ['abc', 'def'] == [;] <*| ("abc", "def") + assert ["abc" ;; "def"] == [['abc'], ['def']] == [;;] <*| ("abc", "def") assert {"a":0, "b":1}$[0] == "a" assert (|0, NotImplemented, 2|)$[1] is NotImplemented assert m{1, 1, 2} |> fmap$(.+1) == m{2, 2, 3} @@ -410,6 +410,51 @@ def primary_test_2() -> bool: assert 0x == 0 == 0 x assert 0xff == 255 == 0x100-1 assert 11259375 == 0xabcdef + assert [[] ;; [] ;;;] == [[[], []]] + assert ( + 1 + |> [. ; 2] + |> [[3; 4] ;; .] + ) == [3; 4;; 1; 2] == [[3; 4] ;; .]([. ; 2](1)) + arr: Any = 1 + arr |>= [. ; 2] + arr |>= [[3; 4] ;; .] + assert arr == [3; 4;; 1; 2] == [[3; 4] ;; .] |> call$(?, [. ; 2] |> call$(?, 1)) + assert (if)(10, 20, 30) == 20 == (if)(0, 10, 20) + assert all_equal([], to=10) + assert all_equal([10; 10; 10; 10], to=10) + assert not all_equal([1, 1], to=10) + assert not 0in[1,2,3] + if"0":assert True + if"0": + assert True + b = "b" + assert "abc".find b == 1 + assert_raises(-> "a" 10, TypeError) + assert (,) ↤* (1, 2, 3) == (1, 2, 3) + assert (,) ↤? None is None + assert (,) ↤*? None is None # type: ignore + assert '''\u2029'''!='''\n''' + assert b"a" `isinstance` bytes + assert b"a" `isinstance` py_bytes + assert bytes() == b"" + assert bytes(10) == b"\x00" * 10 + assert bytes([35, 40]) == b'#(' + assert bytes(b"abc") == b"abc" == bytes("abc", "utf-8") + assert b"Abc" |> fmap$(.|32) == b"abc" + assert bytearray(b"Abc") |> fmap$(.|32) == bytearray(b"abc") + assert (bytearray(b"Abc") |> fmap$(.|32)) `isinstance` bytearray + assert 10 |> lift(+)((x -> x), (def y -> y)) == 20 + assert (x -> def y -> (x, y))(1)(2) == (1, 2) == (x -> copyclosure def y -> (x, y))(1)(2) # type: ignore + assert ((x, y) -> def z -> (x, y, z))(1, 2)(3) == (1, 2, 3) == (x -> y -> def z -> (x, y, z))(1)(2)(3) # type: ignore + assert [def x -> (x, y) for y in range(10)] |> map$(call$(?, 10)) |> list == [(10, y) for y in range(10)] + assert [x -> (x, y) for y in range(10)] |> map$(call$(?, 10)) |> list == [(10, 9) for y in range(10)] + assert [=> y for y in range(2)] |> map$(call) |> list == [1, 1] + assert [def => y for y in range(2)] |> map$(call) |> list == [0, 1] + assert (x -> x -> def y -> (x, y))(1)(2)(3) == (2, 3) + match def maybe_dup(x, y=x) = (x, y) + assert maybe_dup(1) == (1, 1) == maybe_dup(x=1) + assert maybe_dup(1, 2) == (1, 2) == maybe_dup(x=1, y=2) with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 813fe05b0..45d96810a 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -1029,7 +1029,7 @@ forward 2""") == 900 assert Phi((,), (.+1), (.-1)) <| 5 == (6, 4) assert Psi((,), (.+1), 3) <| 4 == (4, 5) assert D1((,), 0, 1, (.+1)) <| 1 == (0, 1, 2) - assert D2((+), (.*2), 3, (.+1)) <| 4 == 11 + assert D2((+), (.*2), 3, (.+1)) <| 4 == 11 == D2_((+), (.*2), (.+1))(3, 4) assert E((+), 10, (*), 2) <| 3 == 16 assert Phi1((,), (+), (*), 2) <| 3 == (5, 6) assert BE((,), (+), 10, 2, (*), 2) <| 3 == (12, 6) @@ -1075,6 +1075,10 @@ forward 2""") == 900 assert pickle_round_trip(.loc[0]) <| (loc=[10]) == 10 assert pickle_round_trip(.method(0)) <| (method=const 10) == 10 assert pickle_round_trip(.method(x=10)) <| (method=x -> x) == 10 + assert sq_and_t2p1(10) == (100, 21) + assert first_false_and_last_true([3, 2, 1, 0, "11", "1", ""]) == (0, "1") + assert ret_args_kwargs ↤** dict(a=1) == ((), dict(a=1)) + assert ret_args_kwargs ↤**? None is None with process_map.multiple_sequential_calls(): # type: ignore assert process_map(tuple <.. (|>)$(to_sort), qsorts) |> list == [to_sort |> sorted |> tuple] * len(qsorts) diff --git a/coconut/tests/src/cocotest/agnostic/util.coco b/coconut/tests/src/cocotest/agnostic/util.coco index c06598be7..f58003eec 100644 --- a/coconut/tests/src/cocotest/agnostic/util.coco +++ b/coconut/tests/src/cocotest/agnostic/util.coco @@ -34,6 +34,14 @@ def assert_raises(c, exc): else: raise AssertionError(f"{c} failed to raise exception {exc}") +def x `typed_eq` y = (type(x), x) == (type(y), y) + +def pickle_round_trip(obj) = ( + obj + |> pickle.dumps + |> pickle.loads +) + try: prepattern() # type: ignore except NameError, TypeError: @@ -44,14 +52,6 @@ except NameError, TypeError: return addpattern(func, base_func, **kwargs) return pattern_prepender -def x `typed_eq` y = (type(x), x) == (type(y), y) - -def pickle_round_trip(obj) = ( - obj - |> pickle.dumps - |> pickle.loads -) - # Old functions: old_fmap = fmap$(starmap_over_mappings=True) @@ -1544,6 +1544,24 @@ def BE(f, g, x, y, h, z) = lift(f)(const(g x y), h$(z)) def on(b, u) = (,) ..> map$(u) ..*> b +def D2_(f, g, h) = lift_apart(f)(g, h) + + +# branching +branch = lift(,) +branched = lift_apart(,) + +sq_and_t2p1 = ( + branch(ident, (.*2)) + ..*> branched((.**2), (.+1)) # type: ignore +) + +first_false_and_last_true = ( + lift(,)(ident, reversed) + ..*> lift_apart(,)(dropwhile$(bool), dropwhile$(not)) # type: ignore + ..*> lift_apart(,)(.$[0], .$[0]) # type: ignore +) + # maximum difference def maxdiff1(ns) = ( diff --git a/coconut/tests/src/cocotest/target_3/py3_test.coco b/coconut/tests/src/cocotest/target_3/py3_test.coco index acdef4f73..8ace419a2 100644 --- a/coconut/tests/src/cocotest/target_3/py3_test.coco +++ b/coconut/tests/src/cocotest/target_3/py3_test.coco @@ -27,14 +27,14 @@ def py3_test() -> bool: čeština = "czech" assert čeština == "czech" class HasExecMethod: - def exec(self, x) = x() + def \exec(self, x) = x() has_exec = HasExecMethod() assert hasattr(has_exec, "exec") assert has_exec.exec(-> 1) == 1 def exec_rebind_test(): - exec = 1 + \exec = 1 assert exec + 1 == 2 - def exec(x) = x + def \exec(x) = x assert exec(1) == 1 return True assert exec_rebind_test() is True diff --git a/coconut/tests/src/cocotest/target_311/py311_test.coco b/coconut/tests/src/cocotest/target_311/py311_test.coco index a2c655815..c527cf3a4 100644 --- a/coconut/tests/src/cocotest/target_311/py311_test.coco +++ b/coconut/tests/src/cocotest/target_311/py311_test.coco @@ -7,4 +7,6 @@ def py311_test() -> bool: except* ValueError as err: got_err = err assert repr(got_err) == repr(multi_err), (got_err, multi_err) + assert [1, 2, 3][x := 1] == 2 + assert x == 1 return True diff --git a/coconut/tests/src/cocotest/target_38/py38_test.coco b/coconut/tests/src/cocotest/target_38/py38_test.coco index 5df470874..13ed72b9c 100644 --- a/coconut/tests/src/cocotest/target_38/py38_test.coco +++ b/coconut/tests/src/cocotest/target_38/py38_test.coco @@ -7,4 +7,9 @@ def py38_test() -> bool: assert a == 3 == b def f(x: int, /, y: int) -> int = x + y assert f(1, y=2) == 3 + assert 10 |> (x := .) == 10 == x + assert 10 |> (x := .) |> (. + 1) == 11 + assert x == 10 + assert not consume(y := i for i in range(10)) + assert y == 9 return True diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 0d13f39d3..0bb22fbde 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -1,5 +1,8 @@ +import os from collections.abc import Sequence +os.environ["COCONUT_USE_COLOR"] = "False" + from coconut.__coconut__ import consume as coc_consume from coconut.constants import ( IPY, @@ -7,6 +10,7 @@ from coconut.constants import ( PY34, PY35, PY36, + PY39, PYPY, ) # type: ignore from coconut._pyparsing import USE_COMPUTATION_GRAPH # type: ignore @@ -112,6 +116,7 @@ def test_setup_none() -> bool: assert "==" not in parse("None = None") assert parse("(1\f+\f2)", "lenient") == "(1 + 2)" == parse("(1\f+\f2)", "eval") assert "Ellipsis" not in parse("x: ... = 1") + assert parse("linebreaks = '\x0b\x0c\x1c\x1d\x1e'") # things that don't parse correctly without the computation graph if USE_COMPUTATION_GRAPH: @@ -127,8 +132,11 @@ def test_setup_none() -> bool: assert_raises(-> parse("\\("), CoconutSyntaxError) assert_raises(-> parse("if a:\n b\n c"), CoconutSyntaxError) assert_raises(-> parse("_coconut"), CoconutSyntaxError) - assert_raises(-> parse("[;]"), CoconutSyntaxError) + assert_raises(-> parse("[; ;]"), CoconutSyntaxError) assert_raises(-> parse("[; ;; ;]"), CoconutSyntaxError) + assert_raises(-> parse("[; ; ;;]"), CoconutSyntaxError) + assert_raises(-> parse("[[] ;;; ;; [] ;]"), CoconutSyntaxError) + assert_raises(-> parse("[; []]"), CoconutSyntaxError) assert_raises(-> parse("f$()"), CoconutSyntaxError) assert_raises(-> parse("f(**x, y)"), CoconutSyntaxError) assert_raises(-> parse("def f(x) = return x"), CoconutSyntaxError) @@ -201,13 +209,16 @@ cannot reassign type variable 'T' (use explicit '\T' syntax if intended) (line 1 assert_raises(-> parse("$"), CoconutParseError) assert_raises(-> parse("@"), CoconutParseError) assert_raises(-> parse("range(1,10) |> reduce$(*, initializer = 1000) |> print"), CoconutParseError, err_has=( - " \\~~~~~~~~~~~~~~~~~~~~~~~^", - " \\~~~~~~~~~~~~^", + "\n \\~~^", + "\n \\~~~~~~~~~~~~~~~~~~~~~~~^", + )) + assert_raises(-> parse("a := b"), CoconutParseError, err_has=( + "\n ^", + "\n \\~^", )) - assert_raises(-> parse("a := b"), CoconutParseError, err_has=" \\~^") assert_raises(-> parse("1 + return"), CoconutParseError, err_has=( - " \\~~~^", - " \\~~~~^", + "\n \\~~^", + "\n \\~~~~^", )) assert_raises(-> parse(""" def f() = @@ -222,20 +233,39 @@ def f() = """ assert 2 ~^ - """.strip() + """.strip(), + )) + assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has=( + "\n ^", + "\n \\~~~~~~^", + )) + assert_raises(-> parse('"abc" b"def"'), CoconutParseError, err_has=( + "\n ^", + "\n \\~~~~~^", + )) + assert_raises(-> parse("A. ."), CoconutParseError, err_has=( + "\n \\~^", + "\n \\~~^", )) - assert_raises(-> parse('b"abc" "def"'), CoconutParseError, err_has=" \\~~~~~~^") - assert_raises(-> parse('"abc" b"def"'), CoconutParseError, err_has=" \\~~~~~^") - assert_raises(-> parse('"a" 10'), CoconutParseError, err_has=" \\~~~^") - assert_raises(-> parse("A. ."), CoconutParseError, err_has=" \\~~^") + assert_raises(-> parse("f([] {})"), CoconutParseError, err_has=( + "\n \\~~~^", + "\n \\~~~~^", + )) + assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=( + "\n ^", + "\n \\~^", + "\n \\~~^", + )) + assert_raises(-> parse("(. if 1)"), CoconutParseError, err_has=( + "\n ^", + "\n \\~~^", + )) + assert_raises(-> parse('''f"""{ }"""'''), CoconutSyntaxError, err_has="parsing failed for format string expression") - assert_raises(-> parse("f([] {})"), CoconutParseError, err_has=" \\~~~~^") - assert_raises(-> parse("return = 1"), CoconutParseError, err_has='invalid use of the keyword "return"') assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") - assert_raises(-> parse("0xfgf"), CoconutParseError, err_has=" \~~^") try: parse(""" @@ -257,7 +287,7 @@ def gam_eps_rate(bitarr) = ( err_str = str(err) assert "misplaced '?'" in err_str if not PYPY: - assert """ + assert r""" |> map$(int(?, 2)) \~~~~^""" in err_str or """ |> map$(int(?, 2)) @@ -283,8 +313,14 @@ def g(x) = x assert parse("def f(x):\n ${var}", "xonsh") == "def f(x):\n ${var}\n" assert "data ABC" not in parse("data ABC:\n ${var}", "xonsh") - assert parse('"abc" "xyz"', "lenient") == "'abcxyz'" + assert "builder" not in parse("def x -> x", "lenient") + assert parse("def x -> x", "lenient").count("def") == 1 + assert "builder" in parse("x -> def y -> (x, y)", "lenient") + assert parse("x -> def y -> (x, y)", "lenient").count("def") == 2 + assert "builder" in parse("[def x -> (x, y) for y in range(10)]", "lenient") + assert parse("[def x -> (x, y) for y in range(10)]", "lenient").count("def") == 2 + assert parse("123 # derp", "lenient") == "123 # derp" return True @@ -362,6 +398,22 @@ import abc except CoconutStyleError as err: assert str(err) == """found unused import 'abc' (add '# NOQA' to suppress) (remove --strict to downgrade to a warning) (line 1) import abc""" + assert_raises(-> parse("""class A(object): + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14 + 15"""), CoconutStyleError, err_has="\n ...\n") setup(line_numbers=False, strict=True, target="sys") assert_raises(-> parse("await f x"), CoconutParseError, err_has='invalid use of the keyword "await"') @@ -419,6 +471,11 @@ async def async_map_test() = # Compiled Coconut: ----------------------------------------------------------- type Num = int | float""".strip()) + assert parse("type L[T] = list[T]").strip().endswith(""" +# Compiled Coconut: ----------------------------------------------------------- + +_coconut_typevar_T_0 = _coconut.typing.TypeVar("_coconut_typevar_T_0") +type L = list[_coconut_typevar_T_0]""".strip()) setup(line_numbers=False, minify=True) assert parse("123 # derp", "lenient") == "123# derp" @@ -462,7 +519,7 @@ def test_kernel() -> bool: captured_messages: list[tuple] = [] else: captured_messages: list = [] - def send(self, stream, msg_or_type, content, *args, **kwargs): + def send(self, stream, msg_or_type, content=None, *args, **kwargs): self.captured_messages.append((msg_or_type, content)) if PY35: @@ -515,6 +572,16 @@ def test_kernel() -> bool: assert keyword_complete_result["cursor_start"] == 0 assert keyword_complete_result["cursor_end"] == 1 + assert k.do_execute("ident$(\n?,\n)(99)", False, True, {}, True) |> unwrap_future$(loop) |> .["status"] == "ok" + captured_msg_type, captured_msg_content = fake_session.captured_messages[-1] + assert captured_msg_content is None + assert captured_msg_type["content"]["data"]["text/plain"] == "99" + + assert k.do_execute('"""\n(\n)\n"""', False, True, {}, True) |> unwrap_future$(loop) |> .["status"] == "ok" + captured_msg_type, captured_msg_content = fake_session.captured_messages[-1] + assert captured_msg_content is None + assert captured_msg_type["content"]["data"]["text/plain"] == "'()'" + return True @@ -570,6 +637,9 @@ def test_numpy() -> bool: assert all_equal(np.array([1, 1])) assert all_equal(np.array([1, 1;; 1, 1])) assert not all_equal(np.array([1, 1;; 1, 2])) + assert all_equal(np.array([]), to=10) + assert all_equal(np.array([10; 10;; 10; 10]), to=10) + assert not all_equal(np.array([1, 1]), to=10) assert ( cartesian_product(np.array([1, 2]), np.array([3, 4])) `np.array_equal` @@ -631,22 +701,46 @@ def test_pandas() -> bool: return True +def test_xarray() -> bool: + import xarray as xr + import numpy as np + def ds1 `dataset_equal` ds2 = (ds1 == ds2).all().values() |> all + da = xr.DataArray([10, 11;; 12, 13], dims=["x", "y"]) + ds = xr.Dataset({"a": da, "b": da + 10}) + assert ds$[0] == "a" + ds_ = [da; da + 10] + assert ds `dataset_equal` ds_ # type: ignore + ds__ = [da; da |> fmap$(.+10)] + assert ds `dataset_equal` ds__ # type: ignore + assert ds `dataset_equal` (ds |> fmap$(ident)) + assert da.to_numpy() `np.array_equal` (da |> fmap$(ident) |> .to_numpy()) + assert (ds |> fmap$(r -> r["a"] + r["b"]) |> .to_numpy()) `np.array_equal` np.array([30; 32;; 34; 36]) + assert not all_equal(da) + assert not all_equal(ds) + assert multi_enumerate(da) |> list == [((0, 0), 10), ((0, 1), 11), ((1, 0), 12), ((1, 1), 13)] + assert cartesian_product(da.sel(x=0), da.sel(x=1)) `np.array_equal` np.array([10; 12;; 10; 13;; 11; 12;; 11; 13]) # type: ignore + return True + + def test_extras() -> bool: if not PYPY and (PY2 or PY34): assert test_numpy() is True print(".", end="") if not PYPY and PY36: assert test_pandas() is True # . + print(".", end="") + if not PYPY and PY39: + assert test_xarray() is True # .. print(".") # newline bc we print stuff after this - assert test_setup_none() is True # .. + assert test_setup_none() is True # ... print(".") # ditto - assert test_convenience() is True # ... + assert test_convenience() is True # .... # everything after here uses incremental parsing, so it must come last print(".", end="") - assert test_incremental() is True # .... + assert test_incremental() is True # ..... if IPY: print(".", end="") - assert test_kernel() is True # ..... + assert test_kernel() is True # ...... return True diff --git a/coconut/tests/src/runner.coco b/coconut/tests/src/runner.coco index 3265cf493..62a090d92 100644 --- a/coconut/tests/src/runner.coco +++ b/coconut/tests/src/runner.coco @@ -5,18 +5,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) import pytest pytest.register_assert_rewrite(py_str("cocotest")) -import cocotest -from cocotest.main import run_main - - -def main() -> bool: - print(".", end="", flush=True) # . - assert cocotest.__doc__ - assert run_main( - outer_MatchError=MatchError, - test_easter_eggs="--test-easter-eggs" in sys.argv, - ) is True - return True +from cocotest.__main__ import main if __name__ == "__main__": diff --git a/coconut/util.py b/coconut/util.py index b0e04be68..f9f4905d0 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -8,7 +8,7 @@ """ Author: Evan Hubinger License: Apache 2.0 -Description: Installer for the Coconut Jupyter kernel. +Description: Base Coconut utilities. """ # ----------------------------------------------------------------------------------------------------------------------- @@ -49,6 +49,7 @@ icoconut_custom_kernel_file_loc, WINDOWS, non_syntactic_newline, + setuptools_distribution_names, ) @@ -150,8 +151,8 @@ def clip(num, min=None, max=None): ) -def logical_lines(text, keep_newlines=False): - """Iterate over the logical code lines in text.""" +def literal_lines(text, keep_newlines=False, yield_next_line_is_real=False): + """Iterate over the literal code lines in text.""" prev_content = None for line in text.splitlines(True): real_line = True @@ -162,11 +163,14 @@ def logical_lines(text, keep_newlines=False): if not keep_newlines: line = line[:-1] else: - if prev_content is None: - prev_content = "" - prev_content += line + if not yield_next_line_is_real: + if prev_content is None: + prev_content = "" + prev_content += line real_line = False - if real_line: + if yield_next_line_is_real: + yield real_line, line + elif real_line: if prev_content is not None: line = prev_content + line prev_content = None @@ -265,6 +269,9 @@ def __missing__(self, key): class dictset(dict, object): """A set implemented using a dictionary to get ordering benefits.""" + def __init__(self, items=()): + super(dictset, self).__init__((x, True) for x in items) + def __bool__(self): return len(self) > 0 # fixes py2 issue @@ -324,6 +331,20 @@ def replace_all(inputstr, all_to_replace, replace_to): return inputstr +def highlight(code, force=False): + """Attempt to highlight Coconut code for the terminal.""" + from coconut.terminal import logger # hide to remove circular deps + if force or logger.enable_colors(sys.stdout) and logger.enable_colors(sys.stderr): + try: + from coconut.highlighter import highlight_coconut_for_terminal + except ImportError: + logger.log_exc() + else: + code_base, code_white = split_trailing_whitespace(code) + return highlight_coconut_for_terminal(code_base).rstrip() + code_white + return code + + # ----------------------------------------------------------------------------------------------------------------------- # VERSIONING: # ----------------------------------------------------------------------------------------------------------------------- @@ -369,12 +390,10 @@ def get_displayable_target(target): def get_kernel_data_files(argv): """Given sys.argv, write the custom kernel file and return data_files.""" - if any(arg.startswith("bdist") for arg in argv): + if any(arg.startswith(setuptools_distribution_names) for arg in argv): executable = "python" - elif any(arg.startswith("install") for arg in argv): - executable = sys.executable else: - return [] + executable = sys.executable install_custom_kernel(executable) return [ (