From c20ba7feab6bdea6c299721310e04dbc10551cc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 14 Jun 2023 17:44:23 +0200 Subject: [PATCH] refactor(ir): glue patterns and rules together Replace the previous rules-based validation system with the new pattern matching system. This enables to use type annotations for operation definitions as well as better error handling. Also lays the groundwork to enable static type checking in the future. BREAKING CHANGE: the `ibis.common.validators` module has been removed and all validation rules from `ibis.expr.rules`, either use typehints or patterns from `ibis.common.patterns` --- docs/how_to/extending/elementwise.ipynb | 7 +- docs/how_to/extending/reduction.ipynb | 10 +- docs/tutorial/ibis-for-sql-users.ipynb | 6 +- .../bigquery/tests/unit/test_compiler.py | 3 +- ibis/backends/bigquery/udf/__init__.py | 8 +- ibis/backends/clickhouse/compiler/values.py | 6 +- .../clickhouse/tests/test_aggregations.py | 7 +- .../dask/tests/execution/test_functions.py | 3 +- ibis/backends/impala/tests/test_udf.py | 4 +- .../impala/tests/test_unary_builtins.py | 7 +- ibis/backends/impala/udf.py | 14 +- .../pandas/tests/execution/test_functions.py | 3 +- .../pandas/tests/execution/test_window.py | 4 +- ibis/backends/postgres/udf.py | 6 +- ibis/backends/tests/test_array.py | 78 +-- ibis/backends/tests/test_generic.py | 11 +- ibis/backends/tests/test_set_ops.py | 4 +- ibis/backends/tests/test_string.py | 3 +- ibis/backends/tests/test_temporal.py | 47 +- ibis/backends/tests/test_udf.py | 5 +- ibis/backends/tests/test_vectorized_udf.py | 2 +- ibis/common/annotations.py | 51 +- ibis/common/grounds.py | 80 ++- ibis/common/patterns.py | 221 ++++-- ibis/common/temporal.py | 15 +- ibis/common/tests/test_annotations.py | 106 +-- ibis/common/tests/test_graph.py | 24 +- ibis/common/tests/test_grounds.py | 236 +++++-- ibis/common/tests/test_grounds_py310.py | 20 +- ibis/common/tests/test_patterns.py | 17 +- ibis/common/tests/test_temporal.py | 23 +- ibis/common/tests/test_typing.py | 54 +- ibis/common/tests/test_validators.py | 300 --------- ibis/common/typing.py | 41 +- ibis/common/validators.py | 632 ------------------ ibis/config.py | 4 +- ibis/expr/analysis.py | 3 +- ibis/expr/builders.py | 77 +-- ibis/expr/datashape.py | 69 ++ ibis/expr/datatypes/core.py | 53 +- ibis/expr/datatypes/tests/test_core.py | 117 +++- ibis/expr/datatypes/tests/test_parse.py | 3 +- ibis/expr/datatypes/tests/test_value.py | 12 + ibis/expr/datatypes/value.py | 28 +- ibis/expr/operations/analytic.py | 71 +- ibis/expr/operations/arrays.py | 65 +- ibis/expr/operations/core.py | 80 ++- ibis/expr/operations/generic.py | 201 +++--- ibis/expr/operations/geospatial.py | 47 +- ibis/expr/operations/histograms.py | 28 +- ibis/expr/operations/json.py | 8 +- ibis/expr/operations/logical.py | 76 +-- ibis/expr/operations/maps.py | 25 +- ibis/expr/operations/numeric.py | 66 +- ibis/expr/operations/reductions.py | 108 +-- ibis/expr/operations/relations.py | 193 +++--- ibis/expr/operations/sortkeys.py | 35 +- ibis/expr/operations/strings.py | 119 ++-- ibis/expr/operations/structs.py | 12 +- ibis/expr/operations/temporal.py | 143 ++-- ibis/expr/operations/tests/__init__.py | 0 ibis/expr/operations/tests/test_generic.py | 105 +++ ibis/expr/operations/udf.py | 11 +- ibis/expr/operations/vectorized.py | 26 +- ibis/expr/operations/window.py | 82 ++- ibis/expr/rules.py | 406 +---------- ibis/expr/schema.py | 4 +- ibis/expr/tests/test_datashape.py | 83 +++ ibis/expr/tests/test_rules.py | 292 -------- ibis/expr/tests/test_schema.py | 4 +- ibis/expr/types/arrays.py | 16 +- ibis/expr/types/core.py | 19 +- ibis/expr/types/generic.py | 15 +- ibis/expr/types/groupby.py | 1 - ibis/expr/types/numeric.py | 16 - ibis/expr/types/relations.py | 9 +- ibis/expr/types/strings.py | 2 +- ibis/expr/types/temporal.py | 95 +-- ibis/formats/tests/test_numpy.py | 2 +- ibis/legacy/udf/validate.py | 8 +- ibis/selectors.py | 2 +- ibis/tests/expr/test_analytics.py | 11 +- ibis/tests/expr/test_decimal.py | 3 +- ibis/tests/expr/test_literal.py | 2 +- ibis/tests/expr/test_operations.py | 200 +++--- ibis/tests/expr/test_set_operations.py | 5 +- ibis/tests/expr/test_table.py | 18 +- ibis/tests/expr/test_udf.py | 10 +- ibis/tests/expr/test_value_exprs.py | 86 +-- ibis/tests/expr/test_visualize.py | 11 +- ibis/tests/expr/test_window_frames.py | 45 +- ibis/tests/test_config.py | 3 +- ibis/util.py | 5 - pyproject.toml | 1 + 94 files changed, 2318 insertions(+), 2981 deletions(-) delete mode 100644 ibis/common/tests/test_validators.py delete mode 100644 ibis/common/validators.py create mode 100644 ibis/expr/datashape.py create mode 100644 ibis/expr/operations/tests/__init__.py create mode 100644 ibis/expr/operations/tests/test_generic.py create mode 100644 ibis/expr/tests/test_datashape.py delete mode 100644 ibis/expr/tests/test_rules.py diff --git a/docs/how_to/extending/elementwise.ipynb b/docs/how_to/extending/elementwise.ipynb index 4289757f104f..469c824a6a95 100644 --- a/docs/how_to/extending/elementwise.ipynb +++ b/docs/how_to/extending/elementwise.ipynb @@ -45,11 +45,12 @@ "source": [ "import ibis.expr.datatypes as dt\n", "import ibis.expr.rules as rlz\n", - "from ibis.expr.operations import ValueOp\n", + "import ibis.expr.datashape as ds\n", + "from ibis.expr.operations import Value\n", "\n", "\n", - "class JulianDay(ValueOp):\n", - " arg = rlz.string\n", + "class JulianDay(Value):\n", + " arg: Value[dt.String, ds.Any]\n", "\n", " output_dtype = dt.float32\n", " output_shape = rlz.shape_like('arg')" diff --git a/docs/how_to/extending/reduction.ipynb b/docs/how_to/extending/reduction.ipynb index 8cafa24d0653..02c9d3b0a9da 100644 --- a/docs/how_to/extending/reduction.ipynb +++ b/docs/how_to/extending/reduction.ipynb @@ -53,17 +53,19 @@ "metadata": {}, "outputs": [], "source": [ + "from typing import Optional\n", "import ibis.expr.datatypes as dt\n", + "import ibis.expr.datashape as ds\n", "import ibis.expr.rules as rlz\n", - "from ibis.expr.operations import Reduction\n", + "from ibis.expr.operations import Reduction, Value\n", "\n", "\n", "class LastDate(Reduction):\n", - " arg = rlz.column(rlz.date)\n", - " where = rlz.optional(rlz.boolean)\n", + " arg: Value[dt.Date, ds.Any]\n", + " where: Optional[Value[dt.Boolean, ds.Any]] = None\n", "\n", " output_dtype = rlz.dtype_like('arg')\n", - " output_shape = rlz.Shape.SCALAR" + " output_shape = ds.scalar" ] }, { diff --git a/docs/tutorial/ibis-for-sql-users.ipynb b/docs/tutorial/ibis-for-sql-users.ipynb index 365ef56f5ca4..e1fdf74d2982 100644 --- a/docs/tutorial/ibis-for-sql-users.ipynb +++ b/docs/tutorial/ibis-for-sql-users.ipynb @@ -845,13 +845,13 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "dcb1b530-3d01-4b1a-a9e7-83cabb6df313", "metadata": {}, "source": [ "The default for sorting is in ascending order. To reverse the sort\n", - "direction of any key, either wrap it in `ibis.desc` or pass a tuple with\n", - "`False` as the second value:" + "direction of any key, wrap it in `ibis.desc`:" ] }, { @@ -862,7 +862,7 @@ "outputs": [], "source": [ "sorted = events.order_by(\n", - " [ibis.desc('event_type'), (events.ts.month(), False)]\n", + " [ibis.desc('event_type'), ibis.desc(events.ts.month())]\n", ").limit(100)\n", "\n", "ibis.show_sql(sorted)" diff --git a/ibis/backends/bigquery/tests/unit/test_compiler.py b/ibis/backends/bigquery/tests/unit/test_compiler.py index 32777460187e..ef2e81654a59 100644 --- a/ibis/backends/bigquery/tests/unit/test_compiler.py +++ b/ibis/backends/bigquery/tests/unit/test_compiler.py @@ -13,6 +13,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import _ +from ibis.common.patterns import ValidationError to_sql = ibis.bigquery.compile @@ -557,7 +558,7 @@ def test_cov(alltypes, how, snapshot): def test_cov_invalid_how(alltypes): d = alltypes.double_col - with pytest.raises(ValueError): + with pytest.raises(ValidationError): d.cov(d, how="error") diff --git a/ibis/backends/bigquery/udf/__init__.py b/ibis/backends/bigquery/udf/__init__.py index 883390561793..c9dc31417d4c 100644 --- a/ibis/backends/bigquery/udf/__init__.py +++ b/ibis/backends/bigquery/udf/__init__.py @@ -268,7 +268,10 @@ def js( if libraries is None: libraries = [] - udf_node_fields = {name: rlz.value(type_) for name, type_ in params.items()} + udf_node_fields = { + name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_) + for name, type_ in params.items() + } udf_node_fields["output_dtype"] = output_type udf_node_fields["output_shape"] = rlz.shape_like("args") @@ -362,10 +365,9 @@ def sql( """ validate_output_type(output_type) udf_node_fields = { - name: rlz.any if type_ == "ANY TYPE" else rlz.value(type_) + name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_) for name, type_ in params.items() } - return_type = BigQueryType.from_ibis(dt.dtype(output_type)) udf_node_fields["output_dtype"] = output_type diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py index 8e4354690881..7eb53a24d6d0 100644 --- a/ibis/backends/clickhouse/compiler/values.py +++ b/ibis/backends/clickhouse/compiler/values.py @@ -16,7 +16,6 @@ import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops -import ibis.expr.rules as rlz from ibis.backends.base.sql.registry import helpers from ibis.backends.clickhouse.datatypes import serialize @@ -850,10 +849,7 @@ def tr(op, *, cache, **kw): left_arg = helpers.parenthesize(left_arg) # special case non-foreign isin/notin expressions - if ( - not isinstance(options, tuple) - and options.output_shape is rlz.Shape.COLUMNAR - ): + if not isinstance(options, tuple) and options.output_shape.is_columnar(): # this will fail to execute if there's a correlation, but it's too # annoying to detect so we let it through to enable the # uncorrelated use case (pandas-style `.isin`) diff --git a/ibis/backends/clickhouse/tests/test_aggregations.py b/ibis/backends/clickhouse/tests/test_aggregations.py index 9b4f2719666e..11d91406880d 100644 --- a/ibis/backends/clickhouse/tests/test_aggregations.py +++ b/ibis/backends/clickhouse/tests/test_aggregations.py @@ -6,7 +6,8 @@ import pandas.testing as tm import pytest -from ibis import literal as L +import ibis +from ibis.common.patterns import ValidationError pytest.importorskip("clickhouse_connect") @@ -33,9 +34,9 @@ def test_std_var_pop(con, alltypes, method, translate, snapshot): @pytest.mark.parametrize('reduction', ['sum', 'count', 'max', 'min']) def test_reduction_invalid_where(alltypes, reduction): - condbad_literal = L('T') + condbad_literal = ibis.literal('T') - with pytest.raises(TypeError): + with pytest.raises(ValidationError): fn = methodcaller(reduction, where=condbad_literal) fn(alltypes.double_col) diff --git a/ibis/backends/dask/tests/execution/test_functions.py b/ibis/backends/dask/tests/execution/test_functions.py index cc33d6f3fbc8..d8f3fff19159 100644 --- a/ibis/backends/dask/tests/execution/test_functions.py +++ b/ibis/backends/dask/tests/execution/test_functions.py @@ -14,6 +14,7 @@ import ibis import ibis.expr.datatypes as dt from ibis.common.exceptions import OperationNotDefinedError +from ibis.common.patterns import ValidationError dd = pytest.importorskip("dask.dataframe") from dask.dataframe.utils import tm # noqa: E402 @@ -196,7 +197,7 @@ def test_quantile_scalar(t, df, ibis_func, dask_func): # out of range on quantile (lambda x: x.quantile(5.0), ValueError), # invalid interpolation arg - (lambda x: x.quantile(0.5, interpolation='foo'), ValueError), + (lambda x: x.quantile(0.5, interpolation='foo'), ValidationError), ], ) def test_arraylike_functions_transform_errors(t, df, ibis_func, exc): diff --git a/ibis/backends/impala/tests/test_udf.py b/ibis/backends/impala/tests/test_udf.py index 4cb8f607141a..223d52e03485 100644 --- a/ibis/backends/impala/tests/test_udf.py +++ b/ibis/backends/impala/tests/test_udf.py @@ -14,7 +14,7 @@ import ibis.expr.types as ir from ibis import util from ibis.backends.impala import ddl -from ibis.common.exceptions import IbisTypeError +from ibis.common.patterns import ValidationError from ibis.expr import rules pytest.importorskip("impala") @@ -223,7 +223,7 @@ def test_udf_invalid_typecasting(ty, valid_cast_indexer, all_cols): func = _register_udf([ty], 'int32', 'typecast') for expr in all_cols[valid_cast_indexer]: - with pytest.raises(IbisTypeError): + with pytest.raises(ValidationError): func(expr) diff --git a/ibis/backends/impala/tests/test_unary_builtins.py b/ibis/backends/impala/tests/test_unary_builtins.py index abe2483121e1..a478ffd3fa2d 100644 --- a/ibis/backends/impala/tests/test_unary_builtins.py +++ b/ibis/backends/impala/tests/test_unary_builtins.py @@ -2,9 +2,10 @@ import pytest +import ibis import ibis.expr.types as ir -from ibis import literal as L from ibis.backends.impala.tests.conftest import translate +from ibis.common.patterns import ValidationError @pytest.fixture(scope="module") @@ -99,8 +100,8 @@ def test_reduction_where(table, expr_fn, snapshot): @pytest.mark.parametrize("method_name", ["sum", "count", "mean", "max", "min"]) def test_reduction_invalid_where(table, method_name): - condbad_literal = L('T') + condbad_literal = ibis.literal('T') reduction = getattr(table.double_col, method_name) - with pytest.raises(TypeError): + with pytest.raises(ValidationError): reduction(where=condbad_literal) diff --git a/ibis/backends/impala/udf.py b/ibis/backends/impala/udf.py index deace95a4ff1..ce1d86a596f4 100644 --- a/ibis/backends/impala/udf.py +++ b/ibis/backends/impala/udf.py @@ -20,10 +20,10 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz -import ibis.legacy.udf.validate as v from ibis import util from ibis.backends.base.sql.registry import fixed_arity, sql_type_names from ibis.backends.impala.compiler import ImpalaExprTranslator +from ibis.legacy.udf.validate import validate_output_type __all__ = [ 'add_operation', @@ -67,7 +67,7 @@ def register(self, name: str, database: str) -> None: class ScalarFunction(Function): def _create_operation_class(self): - fields = {f'_{i}': rlz.value(dtype) for i, dtype in enumerate(self.inputs)} + fields = {f'_{i}': rlz.ValueOf(dtype) for i, dtype in enumerate(self.inputs)} fields['output_dtype'] = self.output fields['output_shape'] = rlz.shape_like('args') return type(f"UDF_{self.name}", (ops.Value,), fields) @@ -75,7 +75,7 @@ def _create_operation_class(self): class AggregateFunction(Function): def _create_operation_class(self): - fields = {f'_{i}': rlz.value(dtype) for i, dtype in enumerate(self.inputs)} + fields = {f'_{i}': rlz.ValueOf(dtype) for i, dtype in enumerate(self.inputs)} fields['output_dtype'] = self.output return type(f"UDA_{self.name}", (ops.Reduction,), fields) @@ -101,7 +101,7 @@ class ImpalaUDF(ScalarFunction, ImpalaFunction): """Feel free to customize my __doc__ or wrap in a nicer user API.""" def __init__(self, inputs, output, so_symbol=None, lib_path=None, name=None): - v.validate_output_type(output) + validate_output_type(output) self.so_symbol = so_symbol ImpalaFunction.__init__(self, name=name, lib_path=lib_path) ScalarFunction.__init__(self, inputs, output, name=self.name) @@ -136,7 +136,7 @@ def __init__( self.finalize_fn = finalize_fn self.serialize_fn = serialize_fn - v.validate_output_type(output) + validate_output_type(output) ImpalaFunction.__init__(self, name=name, lib_path=lib_path) AggregateFunction.__init__(self, inputs, output, name=self.name) @@ -268,10 +268,6 @@ def add_operation(op, func_name, db): database the relevant operator is registered to """ full_name = f'{db}.{func_name}' - # TODO - # if op.input_type is rlz.listof: - # translator = comp.varargs(full_name) - # else: arity = len(op.__signature__.parameters) translator = fixed_arity(full_name, arity) diff --git a/ibis/backends/pandas/tests/execution/test_functions.py b/ibis/backends/pandas/tests/execution/test_functions.py index 11ce66fee2c7..8acef621e289 100644 --- a/ibis/backends/pandas/tests/execution/test_functions.py +++ b/ibis/backends/pandas/tests/execution/test_functions.py @@ -16,6 +16,7 @@ from ibis.backends.pandas.execution import execute from ibis.backends.pandas.tests.conftest import TestConf as tm from ibis.backends.pandas.udf import udf +from ibis.common.patterns import ValidationError @pytest.mark.parametrize( @@ -172,7 +173,7 @@ def test_quantile_multi(t, df, ibis_func, pandas_func, column): # out of range on quantile (lambda x: x.quantile(5.0), ValueError), # invalid interpolation arg - (lambda x: x.quantile(0.5, interpolation='foo'), ValueError), + (lambda x: x.quantile(0.5, interpolation='foo'), ValidationError), ], ) def test_arraylike_functions_transform_errors(t, ibis_func, exc): diff --git a/ibis/backends/pandas/tests/execution/test_window.py b/ibis/backends/pandas/tests/execution/test_window.py index 391d40c0b199..aa88589c5e9b 100644 --- a/ibis/backends/pandas/tests/execution/test_window.py +++ b/ibis/backends/pandas/tests/execution/test_window.py @@ -10,7 +10,6 @@ from packaging.version import parse as vparse import ibis -import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.df.scope import Scope @@ -18,6 +17,7 @@ from ibis.backends.pandas.dispatch import pre_execute from ibis.backends.pandas.execution import execute from ibis.backends.pandas.tests.conftest import TestConf as tm +from ibis.common.patterns import ValidationError from ibis.legacy.udf.vectorized import reduction @@ -502,7 +502,7 @@ def test_window_with_mlb(): tm.assert_frame_equal(result, expected) rows_with_mlb = ibis.rows_with_max_lookback(5, 10) - with pytest.raises(com.IbisTypeError): + with pytest.raises(ValidationError): t.mutate( sum=lambda df: df.a.sum().over( ibis.trailing_window(rows_with_mlb, order_by='time') diff --git a/ibis/backends/postgres/udf.py b/ibis/backends/postgres/udf.py index 071f21481b4a..c48672c36c2b 100644 --- a/ibis/backends/postgres/udf.py +++ b/ibis/backends/postgres/udf.py @@ -12,10 +12,10 @@ import ibis import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -import ibis.legacy.udf.validate as v from ibis import IbisError from ibis.backends.postgres.compiler import PostgreSQLExprTranslator, PostgresUDFNode from ibis.backends.postgres.datatypes import PostgresType +from ibis.legacy.udf.validate import validate_output_type _udf_name_cache: MutableMapping[str, Any] = collections.defaultdict(itertools.count) @@ -70,10 +70,10 @@ def existing_udf(name, input_types, output_type, schema=None, parameters=None): ).format(len(input_types), len(parameters)) ) - v.validate_output_type(output_type) + validate_output_type(output_type) udf_node_fields = { - name: rlz.value(type_) for name, type_ in zip(parameters, input_types) + name: rlz.ValueOf(type_) for name, type_ in zip(parameters, input_types) } udf_node_fields['name'] = name udf_node_fields['output_dtype'] = output_type diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 60c533173af5..c16f69b9cb6d 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -193,10 +193,13 @@ def test_array_index(con, idx): builtin_array = toolz.compose( # these will almost certainly never be supported pytest.mark.never( - ["mysql", "sqlite"], + ["mysql"], reason="array types are unsupported", raises=com.OperationNotDefinedError, ), + pytest.mark.never( + ["sqlite"], reason="array types are unsupported", raises=NotImplementedError + ), # someone just needs to implement these pytest.mark.notimpl(["datafusion"], raises=Exception), duckdb_0_4_0, @@ -485,8 +488,9 @@ def test_array_slice(backend, start, stop): tm.assert_frame_equal(result, expected) +@builtin_array @pytest.mark.notimpl( - ["datafusion", "impala", "mssql", "polars", "snowflake", "sqlite"], + ["datafusion", "impala", "mssql", "polars", "snowflake", "sqlite", "mysql"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -495,9 +499,9 @@ def test_array_slice(backend, start, stop): reason="Operation 'ArrayMap' is not implemented for this backend", ) @pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules: ...", + ["sqlite"], + raises=NotImplementedError, + reason="Unsupported type: Array: ...", ) @pytest.mark.parametrize( ("input", "output"), @@ -525,8 +529,9 @@ def test_array_map(backend, con, input, output): backend.assert_frame_equal(result, expected) +@builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "impala", "mssql", "pandas", "polars", "snowflake"], + ["dask", "datafusion", "impala", "mssql", "pandas", "polars", "snowflake", "mysql"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -535,9 +540,7 @@ def test_array_map(backend, con, input, output): reason="Operation 'ArrayMap' is not implemented for this backend", ) @pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules: ...", + ["sqlite"], raises=NotImplementedError, reason="Unsupported type: Array..." ) @pytest.mark.parametrize( ("input", "output"), @@ -554,6 +557,7 @@ def test_array_filter(backend, con, input, output): backend.assert_frame_equal(result, expected) +@builtin_array @pytest.mark.notimpl( ["datafusion", "mssql", "pandas", "polars", "postgres"], raises=com.OperationNotDefinedError, @@ -561,11 +565,6 @@ def test_array_filter(backend, con, input, output): @pytest.mark.notimpl(["datafusion"], raises=Exception) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) @pytest.mark.never(["impala"], reason="array_types table isn't defined") -@pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules:....", -) def test_array_contains(backend, con): t = backend.array_types expr = t.x.contains(1) @@ -574,15 +573,11 @@ def test_array_contains(backend, con): backend.assert_series_equal(result, expected, check_names=False) +@builtin_array @pytest.mark.notimpl( ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules:....", -) def test_array_position(backend, con): t = ibis.memtable({"a": [[1], [], [42, 42], []]}) expr = t.a.index(42) @@ -591,15 +586,11 @@ def test_array_position(backend, con): backend.assert_series_equal(result, expected, check_names=False, check_dtype=False) +@builtin_array @pytest.mark.notimpl( ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules:....", -) def test_array_remove(backend, con): t = ibis.memtable({"a": [[3, 2], [], [42, 2], [2, 2], []]}) expr = t.a.remove(2) @@ -608,14 +599,18 @@ def test_array_remove(backend, con): backend.assert_series_equal(result, expected, check_names=False) +@builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], + ["dask", "datafusion", "impala", "mssql", "pandas", "polars", "mysql"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules:....", + ["sqlite"], raises=NotImplementedError, reason="Unsupported type: Array..." +) +@pytest.mark.notyet( + ["bigquery"], + raises=BadRequest, + reason="BigQuery doesn't support arrays with null elements", ) @pytest.mark.notyet( ["clickhouse"], @@ -650,15 +645,11 @@ def test_array_unique(backend, con, input, expected): backend.assert_series_equal(result, expected, check_names=False) +@builtin_array @pytest.mark.notimpl( ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules:....", -) def test_array_sort(backend, con): t = ibis.memtable({"a": [[3, 2], [], [42, 42], []]}) expr = t.a.sort() @@ -667,6 +658,7 @@ def test_array_sort(backend, con): backend.assert_series_equal(result, expected, check_names=False) +@builtin_array @pytest.mark.notimpl( ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, @@ -676,11 +668,6 @@ def test_array_sort(backend, con): raises=BadRequest, reason="BigQuery doesn't support arrays with null elements", ) -@pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules:....", -) def test_array_union(con): t = ibis.memtable({"a": [[3, 2], [], []], "b": [[1, 3], [None], [5]]}) expr = t.a.union(t.b) @@ -693,13 +680,11 @@ def test_array_union(con): @pytest.mark.notimpl( - ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], + ["dask", "datafusion", "impala", "mssql", "pandas", "polars", "mysql"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules:....", + ["sqlite"], raises=NotImplementedError, reason="Unsupported type: Array..." ) def test_array_intersect(con): t = ibis.memtable( @@ -715,17 +700,13 @@ def test_array_intersect(con): @unnest +@builtin_array @pytest.mark.notimpl( ["clickhouse"], raises=OperationalError, reason="ClickHouse won't accept dicts for struct type values", ) @pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError) -@pytest.mark.notimpl( - ["sqlite", "mysql"], - raises=com.IbisTypeError, - reason="argument passes none of the following rules: ...", -) def test_unnest_struct(con): data = {"value": [[{'a': 1}, {'a': 2}], [{'a': 3}, {'a': 4}]]} t = ibis.memtable(data, schema=ibis.schema({"value": "!array>"})) @@ -735,8 +716,9 @@ def test_unnest_struct(con): tm.assert_series_equal(result, expected) +@builtin_array @pytest.mark.never( - ["impala", "mssql", "mysql", "sqlite"], + ["impala", "mssql"], raises=com.OperationNotDefinedError, reason="no array support", ) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 1adfdd9c196b..5377765f4062 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -18,7 +18,7 @@ import ibis.expr.datatypes as dt import ibis.selectors as s from ibis import _ -from ibis import literal as L +from ibis.common.patterns import ValidationError try: import duckdb @@ -99,13 +99,13 @@ def test_boolean_literal(con, backend): id="na_fillna", ), param( - L(5).fillna(10), + ibis.literal(5).fillna(10), 5, marks=pytest.mark.notimpl(["mssql", "druid", "oracle"]), id="non_na_fillna", ), - param(L(5).nullif(5), None, id="nullif_null"), - param(L(10).nullif(5), 10, id="nullif_not_null"), + param(ibis.literal(5).nullif(5), None, id="nullif_null"), + param(ibis.literal(10).nullif(5), 10, id="nullif_not_null"), ], ) @pytest.mark.notimpl(["datafusion"]) @@ -471,7 +471,7 @@ def test_dropna_invalid(alltypes): ): alltypes.dropna(subset=['invalid_col']) - with pytest.raises(ValueError, match=r".*is not in.*"): + with pytest.raises(ValidationError, match=r"'invalid' doesn't match"): alltypes.dropna(how='invalid') @@ -491,7 +491,6 @@ def test_dropna_table(backend, alltypes, how, subset): ).select("col_1", "col_2", "col_3") table_pandas = table.execute() - result = table.dropna(subset, how).execute().reset_index(drop=True) expected = table_pandas.dropna(how=how, subset=subset).reset_index(drop=True) diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 861eecc85a9e..67baf219bca1 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -40,10 +40,10 @@ def union_subsets(alltypes, df): def test_union(backend, union_subsets, distinct): (a, b, c), (da, db, dc) = union_subsets - expr = ibis.union(a, b, c, distinct=distinct).order_by("id") + expr = ibis.union(a, b, distinct=distinct).order_by("id") result = expr.execute() - expected = pd.concat([da, db, dc], axis=0).sort_values("id").reset_index(drop=True) + expected = pd.concat([da, db], axis=0).sort_values("id").reset_index(drop=True) if distinct: expected = expected.drop_duplicates("id") diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 1bee66ededb4..04e58f8b65a7 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -10,6 +10,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt from ibis.common.exceptions import OperationNotDefinedError +from ibis.common.patterns import ValidationError try: from google.api_core.exceptions import BadRequest @@ -873,7 +874,7 @@ def test_re_replace_global(con): "a context where a condition is expected, near 'THEN'.DB-Lib error message 20018, severity 15:\n" ), ) -@pytest.mark.notimpl(["druid"], raises=com.IbisTypeError) +@pytest.mark.notimpl(["druid"], raises=ValidationError) @pytest.mark.broken( ["oracle"], raises=sa.exc.DatabaseError, diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 8d9b7f293803..6fca8efb16ac 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -17,6 +17,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt from ibis.backends.pandas.execution.temporal import day_name +from ibis.common.patterns import ValidationError try: from duckdb import InvalidInputException as DuckDBInvalidInputException @@ -711,7 +712,7 @@ def test_date_truncate(backend, alltypes, df, unit): ) @pytest.mark.notimpl( ["druid"], - raises=com.IbisTypeError, + raises=ValidationError, reason="Given argument with datatype interval('h') is not implicitly castable to string", ) def test_integer_to_interval_timestamp( @@ -813,7 +814,7 @@ def convert_to_offset(x): ), pytest.mark.notimpl( ["druid"], - raises=com.IbisTypeError, + raises=ValidationError, reason="Given argument with datatype interval('D') is not implicitly castable to string", ), ], @@ -841,7 +842,7 @@ def convert_to_offset(x): ), pytest.mark.notimpl( ["druid"], - raises=com.IbisTypeError, + raises=ValidationError, reason="Given argument with datatype interval('D') is not implicitly castable to string", ), ], @@ -868,8 +869,8 @@ def convert_to_offset(x): ), pytest.mark.notimpl( ["druid"], - raises=com.IbisTypeError, - reason="Given argument with datatype interval() is not implicitly castable to string", + raises=ValidationError, + reason="alltypes.timestamp_col is represented as string", ), ], ), @@ -931,7 +932,7 @@ def convert_to_offset(x): ), pytest.mark.notimpl( ["druid"], - raises=TypeError, + raises=ValidationError, reason="unsupported operand type(s) for -: 'StringColumn' and 'TimestampScalar'", ), pytest.mark.xfail_version( @@ -987,8 +988,13 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): marks=[ pytest.mark.broken( ['druid'], - raises=com.IbisTypeError, - reason="Given argument with datatype interval('D') is not implicitly castable to string", + raises=AssertionError, + reason="alltypes.timestamp_col is represented as string", + ), + pytest.mark.broken( + ["clickhouse"], + raises=AssertionError, + reason="DateTime column overflows, should use DateTime64", ), pytest.mark.broken( ["clickhouse"], @@ -1003,8 +1009,8 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): marks=[ pytest.mark.broken( ['druid'], - raises=com.IbisTypeError, - reason="Given argument with datatype interval('D') is not implicitly castable to string", + raises=AssertionError, + reason="alltypes.timestamp_col is represented as string", ), ], ), @@ -1014,8 +1020,8 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): marks=[ pytest.mark.broken( ['druid'], - raises=com.IbisTypeError, - reason="Given argument with datatype interval('D') is not implicitly castable to string", + raises=AssertionError, + reason="alltypes.timestamp_col is represented as string", ), ], ), @@ -1025,8 +1031,8 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): marks=[ pytest.mark.broken( ['druid'], - raises=com.IbisTypeError, - reason="Given argument with datatype interval('h') is not implicitly castable to string", + raises=AssertionError, + reason="alltypes.timestamp_col is represented as string", ), ], ), @@ -1036,8 +1042,8 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): marks=[ pytest.mark.broken( ['druid'], - raises=com.IbisTypeError, - reason="Given argument with datatype interval('m') is not implicitly castable to string", + raises=AssertionError, + reason="alltypes.timestamp_col is represented as string", ), ], ), @@ -1047,8 +1053,8 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): marks=[ pytest.mark.broken( ['druid'], - raises=com.IbisTypeError, - reason="Given argument with datatype interval('s') is not implicitly castable to string", + raises=AssertionError, + reason="alltypes.timestamp_col is represented as string", ), ], ), @@ -1066,6 +1072,11 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): raises=AssertionError, reason="DateTime column overflows, should use DateTime64", ), + pytest.mark.broken( + ["clickhouse"], + raises=AssertionError, + reason="DateTime column overflows, should use DateTime64", + ), ], ), param( diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index cc0c58387704..d09a37f3eaa0 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -32,8 +32,11 @@ def num_vowels(s: str, include_y: bool = False) -> int: return sum(map(s.lower().count, "aeiou" + ("y" * include_y))) batting = batting.limit(100) + nvowels = num_vowels(batting.playerID) + assert nvowels.op().__module__ == __name__ + assert type(nvowels.op()).__qualname__ == "num_vowels" - expr = batting.group_by(id_len=num_vowels(batting.playerID)).agg(n=_.count()) + expr = batting.group_by(id_len=nvowels).agg(n=_.count()) result = expr.execute() assert not result.empty diff --git a/ibis/backends/tests/test_vectorized_udf.py b/ibis/backends/tests/test_vectorized_udf.py index 8f9244f61310..0398bbef0fd6 100644 --- a/ibis/backends/tests/test_vectorized_udf.py +++ b/ibis/backends/tests/test_vectorized_udf.py @@ -547,7 +547,7 @@ def test_elementwise_udf_named_destruct(udf_backend, udf_alltypes): add_one_struct_udf = create_add_one_struct_udf( result_formatter=lambda v1, v2: (v1, v2) ) - with pytest.raises(TypeError, match=r"Unable to infer datatype of"): + with pytest.raises(com.IbisTypeError, match=r"Unable to infer"): udf_alltypes.mutate( new_struct=add_one_struct_udf(udf_alltypes['double_col']).destructure() ) diff --git a/ibis/common/annotations.py b/ibis/common/annotations.py index 590ad03c9614..e913a5657c7e 100644 --- a/ibis/common/annotations.py +++ b/ibis/common/annotations.py @@ -2,11 +2,18 @@ import functools import inspect -from typing import Any +from typing import Any as AnyType from ibis.common.collections import DotDict +from ibis.common.patterns import ( + Any, + FrozenDictOf, + Function, + Option, + TupleOf, + Validator, +) from ibis.common.typing import get_type_hints -from ibis.common.validators import Validator, any_, frozendict_of, option, tuple_of EMPTY = inspect.Parameter.empty # marker for missing argument KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY @@ -34,6 +41,12 @@ class Annotation: __slots__ = ('_validator', '_default', '_typehint') def __init__(self, validator=None, default=EMPTY, typehint=EMPTY): + if validator is None or isinstance(validator, Validator): + pass + elif callable(validator): + validator = Function(validator) + else: + raise TypeError(f"Unsupported validator {validator!r}") self._default = default self._typehint = typehint self._validator = validator @@ -52,10 +65,10 @@ def __repr__(self): f"default={self._default!r}, typehint={self._typehint!r})" ) - def validate(self, arg, **kwargs): + def validate(self, arg, context=None): if self._validator is None: return arg - return self._validator(arg, **kwargs) + return self._validator.validate(arg, context) class Attribute(Annotation): @@ -85,7 +98,7 @@ def initialize(self, this): value = self._default(this) else: value = self._default - return self.validate(value, this=this) + return self.validate(value, this) class Argument(Annotation): @@ -109,7 +122,7 @@ class Argument(Annotation): def __init__( self, validator: Validator | None = None, - default: Any = EMPTY, + default: AnyType = EMPTY, typehint: type | None = None, kind: int = POSITIONAL_OR_KEYWORD, ): @@ -130,20 +143,20 @@ def default(cls, default, validator=None, **kwargs): def optional(cls, validator=None, default=None, **kwargs): """Annotation to allow and treat `None` values as missing arguments.""" if validator is None: - validator = option(any_, default=default) + validator = Option(Any(), default=default) else: - validator = option(validator, default=default) + validator = Option(validator, default=default) return cls(validator, default=None, **kwargs) @classmethod def varargs(cls, validator=None, **kwargs): """Annotation to mark a variable length positional argument.""" - validator = None if validator is None else tuple_of(validator) + validator = None if validator is None else TupleOf(validator) return cls(validator, kind=VAR_POSITIONAL, **kwargs) @classmethod def varkwargs(cls, validator=None, **kwargs): - validator = None if validator is None else frozendict_of(any_, validator) + validator = None if validator is None else FrozenDictOf(Any(), validator) return cls(validator, kind=VAR_KEYWORD, **kwargs) @@ -295,7 +308,7 @@ def from_callable(cls, fn, validators=None, return_validator=None): return cls(parameters, return_annotation=return_annotation) - def unbind(self, this: Any): + def unbind(self, this: AnyType): """Reverse bind of the parameters. Attempts to reconstructs the original arguments as keyword only arguments. @@ -353,7 +366,8 @@ def validate(self, *args, **kwargs): for name, value in bound.arguments.items(): param = self.parameters[name] # TODO(kszucs): provide more error context on failure - this[name] = param.annotation.validate(value, this=this) + this[name] = param.annotation.validate(value, this) + return this def validate_nobind(self, **kwargs): @@ -363,16 +377,18 @@ def validate_nobind(self, **kwargs): value = kwargs.get(name, param.default) if value is EMPTY: raise TypeError(f"missing required argument `{name!r}`") - this[name] = param.annotation.validate(value, this=kwargs) + this[name] = param.annotation.validate(value, kwargs) return this - def validate_return(self, value): + def validate_return(self, value, context): """Validate the return value of a function. Parameters ---------- value : Any Return value of the function. + context : dict + Context dictionary. Returns ------- @@ -381,8 +397,7 @@ def validate_return(self, value): """ if self.return_annotation is EMPTY: return value - else: - return self.return_annotation(value) + return self.return_annotation.validate(value, context) # aliases for convenience @@ -412,7 +427,7 @@ def annotated(_1=None, _2=None, _3=None, **kwargs): 2. With argument validators passed as keyword arguments - >>> from ibis.common.validators import instance_of + >>> from ibis.common.patterns import InstanceOf as instance_of >>> @annotated(x=instance_of(int), y=instance_of(str)) ... def foo(x, y): ... return float(x) + float(y) @@ -478,7 +493,7 @@ def wrapped(*args, **kwargs): # 3. Call the function with the validated arguments result = func(*args, **kwargs) # 4. Validate the return value - return sig.validate_return(result) + return sig.validate_return(result, {}) wrapped.__signature__ = sig diff --git a/ibis/common/grounds.py b/ibis/common/grounds.py index 394bc4ccffdd..a1d1709b84d4 100644 --- a/ibis/common/grounds.py +++ b/ibis/common/grounds.py @@ -3,14 +3,30 @@ import contextlib from abc import ABCMeta, abstractmethod from copy import copy -from typing import Any +from typing import ( + Any, + ClassVar, + Mapping, + Tuple, + Union, + get_origin, +) from weakref import WeakValueDictionary -from ibis.common.annotations import EMPTY, Argument, Attribute, Signature, attribute +from typing_extensions import Self, dataclass_transform + +from ibis.common.annotations import ( + EMPTY, + Annotation, + Argument, + Attribute, + Signature, + attribute, +) from ibis.common.caching import WeakCache from ibis.common.collections import FrozenDict +from ibis.common.patterns import Validator from ibis.common.typing import evaluate_annotations -from ibis.common.validators import Validator class BaseMeta(ABCMeta): @@ -21,13 +37,13 @@ def __new__(metacls, clsname, bases, dct, **kwargs): dct.setdefault("__slots__", ()) return super().__new__(metacls, clsname, bases, dct, **kwargs) - def __call__(cls, *args, **kwargs) -> Base: + def __call__(cls, *args, **kwargs): return cls.__create__(*args, **kwargs) class Base(metaclass=BaseMeta): __slots__ = ('__weakref__',) - __create__ = classmethod(type.__call__) + __create__ = classmethod(type.__call__) # type: ignore class AnnotableMeta(BaseMeta): @@ -45,10 +61,15 @@ def __new__(metacls, clsname, bases, dct, **kwargs): signatures.append(parent.__signature__) # collection type annotations and convert them to validators - module_name = dct.get('__module__') + module = dct.get('__module__') + qualname = dct.get('__qualname__') or clsname annotations = dct.get('__annotations__', {}) - typehints = evaluate_annotations(annotations, module_name) + + # TODO(kszucs): pass dct as localns to evaluate_annotations + typehints = evaluate_annotations(annotations, module) for name, typehint in typehints.items(): + if get_origin(typehint) is ClassVar: + continue validator = Validator.from_typehint(typehint) if name in dct: dct[name] = Argument.default(dct[name], validator, typehint=typehint) @@ -77,6 +98,8 @@ def __new__(metacls, clsname, bases, dct, **kwargs): argnames = tuple(signature.parameters.keys()) namespace.update( + __module__=module, + __qualname__=qualname, __argnames__=argnames, __attributes__=attributes, __match_args__=argnames, @@ -85,23 +108,33 @@ def __new__(metacls, clsname, bases, dct, **kwargs): ) return super().__new__(metacls, clsname, bases, namespace, **kwargs) + def __or__(self, other): + # required to support `dt.Numeric | dt.Floating` annotation for python<3.10 + return Union[self, other] + +@dataclass_transform() class Annotable(Base, metaclass=AnnotableMeta): """Base class for objects with custom validation rules.""" + __argnames__: ClassVar[Tuple[str, ...]] + __attributes__: ClassVar[FrozenDict[str, Annotation]] + __match_args__: ClassVar[Tuple[str, ...]] + __signature__: ClassVar[Signature] + @classmethod - def __create__(cls, *args, **kwargs) -> Annotable: + def __create__(cls, *args: Any, **kwargs: Any) -> Self: # construct the instance by passing the validated keyword arguments kwargs = cls.__signature__.validate(*args, **kwargs) return super().__create__(**kwargs) @classmethod - def __recreate__(cls, kwargs) -> Annotable: + def __recreate__(cls, kwargs: Any) -> Self: # bypass signature binding by requiring keyword arguments only kwargs = cls.__signature__.validate_nobind(**kwargs) return super().__create__(**kwargs) - def __init__(self, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: # set the already validated arguments for name, value in kwargs.items(): object.__setattr__(self, name, value) @@ -114,7 +147,7 @@ def __init__(self, **kwargs) -> None: def __setattr__(self, name, value) -> None: if field := self.__attributes__.get(name): - value = field.validate(value, this=self) + value = field.validate(value, self) super().__setattr__(name, value) def __repr__(self) -> str: @@ -132,7 +165,7 @@ def __eq__(self, other) -> bool: ) @property - def __args__(self): + def __args__(self) -> Tuple[Any, ...]: return tuple(getattr(self, name) for name in self.__argnames__) def copy(self, **overrides: Any) -> Annotable: @@ -169,10 +202,10 @@ def __setattr__(self, name: str, _: Any) -> None: class Singleton(Base): - __instances__ = WeakValueDictionary() + __instances__: Mapping[Any, Self] = WeakValueDictionary() @classmethod - def __create__(cls, *args, **kwargs) -> Singleton: + def __create__(cls, *args, **kwargs): key = (cls, args, FrozenDict(kwargs)) try: return cls.__instances__[key] @@ -182,6 +215,15 @@ def __create__(cls, *args, **kwargs) -> Singleton: return instance +class Final(Base): + def __init_subclass__(cls, **kwargs): + cls.__init_subclass__ = cls.__prohibit_inheritance__ + + @classmethod + def __prohibit_inheritance__(cls, **kwargs): + raise TypeError(f"Cannot inherit from final class {cls}") + + class Comparable(Base): __cache__ = WeakCache() @@ -226,7 +268,7 @@ def __args__(self): return tuple(getattr(self, name) for name in self.__argnames__) @attribute.default - def __precomputed_hash__(self): + def __precomputed_hash__(self) -> int: return hash((self.__class__, self.__args__)) def __reduce__(self): @@ -235,10 +277,10 @@ def __reduce__(self): state = dict(zip(self.__argnames__, self.__args__)) return (self.__recreate__, (state,)) - def __hash__(self): + def __hash__(self) -> int: return self.__precomputed_hash__ - def __equals__(self, other): + def __equals__(self, other) -> bool: return self.__args__ == other.__args__ @property @@ -246,10 +288,10 @@ def args(self): return self.__args__ @property - def argnames(self): + def argnames(self) -> Tuple[str, ...]: return self.__argnames__ - def copy(self, **overrides): + def copy(self, **overrides) -> Self: kwargs = dict(zip(self.__argnames__, self.__args__)) kwargs.update(overrides) return self.__recreate__(kwargs) diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 50bba5f91176..954c83dd777a 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -1,7 +1,5 @@ from __future__ import annotations -import enum -import inspect import math import numbers from abc import ABC, abstractmethod @@ -11,6 +9,7 @@ from itertools import chain, zip_longest from typing import Any as AnyType from typing import ( + ForwardRef, Generic, # noqa: F401 Literal, Optional, @@ -19,11 +18,12 @@ Union, ) -from typing_extensions import Annotated, Self, get_args, get_origin +import toolz +from typing_extensions import Annotated, GenericMeta, Self, get_args, get_origin from ibis.common.collections import RewindableIterator, frozendict from ibis.common.dispatch import lazy_singledispatch -from ibis.common.typing import get_bound_typevars, get_type_params +from ibis.common.typing import Sentinel, get_bound_typevars, get_type_params from ibis.util import is_iterable, promote_tuple try: @@ -32,13 +32,21 @@ UnionType = object() -T = TypeVar("T") +T_cov = TypeVar("T_cov", covariant=True) class CoercionError(Exception): ... +class ValidationError(Exception): + ... + + +class MatchError(Exception): + ... + + class Coercible(ABC): """Protocol for defining coercible types. @@ -55,39 +63,44 @@ def __coerce__(cls, value: Any, **kwargs: Any) -> Self: ... -class ValidationError(Exception): - ... - - class Validator(ABC): __slots__ = () - @abstractmethod - def validate(self, value, context): - ... - - -class MatchError(Exception): - ... - - -class NoMatchType(enum.Enum): - NoMatch = "NoMatch" + @classmethod + def from_typevar(cls, var: TypeVar, bound: AnyType = None) -> Pattern: + """Construct a validator from a type variable. + This method is called from two places: + 1. `Validator.from_typehint` without additional bound argument + 2. `GenericInstanceOf` with a substituted type parameter given as bound -NoMatch = NoMatchType.NoMatch # Sentinel value for when a pattern doesn't match. + This method also ensures that the type variable is covariant, + contravariant and invariant type variables are not supported yet. + Parameters + ---------- + var + The type variable to construct the pattern from. + bound + An optional bound to use for the type variable. If not provided, + a no-op validator is returned. -# TODO(kszucs): have an As[int] or Coerced[int] type in ibis.common.typing which -# would be used to annotate an argument as coercible to int or to a certain type -# without needing for the type to inherit from Coercible -# TODO(kszucs): enforce typevars to be invariant otherwise raise an error -class Pattern(Validator, Hashable): - # TODO(kszucs): may need a flag to set preference for coercion over instance check + Returns + ------- + pattern + A pattern that matches the given type variable. + """ + if var.__covariant__: + if bound := bound or var.__bound__: + return cls.from_typehint(bound) + else: + return Any() + else: + raise NotImplementedError("Only covariant typevars are supported for now") @classmethod def from_typehint(cls, annot: type) -> Pattern: - """Construct a pattern from a python type annotation. + """Construct a validator from a python type annotation. Parameters ---------- @@ -101,44 +114,77 @@ def from_typehint(cls, annot: type) -> Pattern: A pattern that matches the given type annotation. """ # TODO(kszucs): cache the result of this function + # TODO(kszucs): explore issubclass(typ, SupportsInt) etc. origin, args = get_origin(annot), get_args(annot) if origin is None: + # the typehint is not generic if annot is Ellipsis or annot is AnyType: + # treat both `Any` and `...` as wildcard return Any() - elif annot is None: - return Is(None) - elif isinstance(annot, TypeVar): - # TODO(kszucs): only use coerced_to if annot.__covariant__ is True - if annot.__bound__ is None: - return Any() + elif isinstance(annot, type): + # the typehint is a concrete type (e.g. int, str, etc.) + if issubclass(annot, Coercible): + # the type implements the Coercible protocol so we try to + # coerce the value to the given type rather than checking + return CoercedTo(annot) else: - return cls.from_typehint(annot.__bound__) + return InstanceOf(annot) + elif isinstance(annot, TypeVar): + # if the typehint is a type variable we try to construct a + # validator from it only if it is covariant and has a bound + return cls.from_typevar(annot) elif isinstance(annot, Enum): + # for enums we check the value against the enum values return EqualTo(annot) - elif issubclass(annot, Coercible): - return CoercedTo(annot) + elif isinstance(annot, (str, ForwardRef)): + # for strings and forward references we check in a lazy way + return LazyInstanceOf(annot) else: - return InstanceOf(annot) + raise TypeError(f"Cannot create validator from annotation {annot!r}") elif origin is Literal: + # for literal types we check the value against the literal values return IsIn(args) elif origin is UnionType or origin is Union: - if len(args) == 2 and args[1] is type(None): - return Option(cls.from_typehint(args[0])) - inners = map(cls.from_typehint, args) - return AnyOf(*inners) + # this is slightly more complicated because we need to handle + # Optional[T] which is Union[T, None] and Union[T1, T2, ...] + *rest, last = args + if last is type(None): + # the typehint is Optional[*rest] which is equivalent to + # Union[*rest, None], so we construct an Option pattern + if len(rest) == 1: + inner = cls.from_typehint(rest[0]) + else: + inner = AnyOf(*map(cls.from_typehint, rest)) + return Option(inner) + else: + # the typehint is Union[*args] so we construct an AnyOf pattern + return AnyOf(*map(cls.from_typehint, args)) elif origin is Annotated: + # the Annotated typehint can be used to add extra validation logic + # to the typehint, e.g. Annotated[int, Positive], the first argument + # is used for isinstance checks, the rest are applied in conjunction annot, *extras = args return AllOf(cls.from_typehint(annot), *extras) elif origin is Callable: + # the Callable typehint is used to annotate functions, e.g. the + # following typehint annotates a function that takes two integers + # and returns a string: Callable[[int, int], str] if args: + # callable with args and return typehints construct a special + # CallableWith validator arg_hints, return_hint = args arg_patterns = tuple(map(cls.from_typehint, arg_hints)) return_pattern = cls.from_typehint(return_hint) return CallableWith(arg_patterns, return_pattern) else: + # in case of Callable without args we check for the Callable + # protocol only return InstanceOf(Callable) elif issubclass(origin, Tuple): + # construct validators for the tuple elements, but need to treat + # variadic tuples differently, e.g. tuple[int, ...] is a variadic + # tuple of integers, while tuple[int] is a tuple with a single int first, *rest = args # TODO(kszucs): consider to support the same SequenceOf path if args # has a single element, e.g. tuple[int] since annotation a single @@ -150,20 +196,36 @@ def from_typehint(cls, annot: type) -> Pattern: inners = tuple(map(cls.from_typehint, args)) return TupleOf(inners) elif issubclass(origin, Sequence): + # construct a validator for the sequence elements where all elements + # must be of the same type, e.g. Sequence[int] is a sequence of ints (value_inner,) = map(cls.from_typehint, args) return SequenceOf(value_inner, type=origin) elif issubclass(origin, Mapping): + # construct a validator for the mapping keys and values, e.g. + # Mapping[str, int] is a mapping with string keys and int values key_inner, value_inner = map(cls.from_typehint, args) return MappingOf(key_inner, value_inner, type=origin) - elif issubclass(origin, Coercible) and args: - return GenericCoercedTo(annot) - elif isinstance(origin, type) and args: - return GenericInstanceOf(annot) + elif isinstance(origin, GenericMeta): + # construct a validator for the generic type, see the specific + # Generic* validators for more details + if issubclass(origin, Coercible) and args: + return GenericCoercedTo(annot) + else: + return GenericInstanceOf(annot) else: - raise NotImplementedError( - f"Cannot create validator from annotation {annot} {origin}" + raise TypeError( + f"Cannot create validator from annotation {annot!r} {origin!r}" ) + +class NoMatch(metaclass=Sentinel): + """Marker to indicate that a pattern didn't match.""" + + +# TODO(kszucs): have an As[int] or Coerced[int] type in ibis.common.typing which +# would be used to annotate an argument as coercible to int or to a certain type +# without needing for the type to inherit from Coercible +class Pattern(Validator, Hashable): @abstractmethod def match(self, value: AnyType, context: dict[str, AnyType]) -> AnyType: """Match a value against the pattern. @@ -477,10 +539,10 @@ class GenericInstanceOf(Matcher): Examples -------- - >>> class MyNumber(Generic[T]): - ... value: T + >>> class MyNumber(Generic[T_cov]): + ... value: T_cov ... - ... def __init__(self, value: T): + ... def __init__(self, value: T_cov): ... self.value = value ... ... def __eq__(self, other): @@ -499,8 +561,11 @@ class GenericInstanceOf(Matcher): def __init__(self, typ): origin = get_origin(typ) - fields = get_bound_typevars(typ) - field_inners = {k: Pattern.from_typehint(v) for k, v in fields.items()} + typevars = get_bound_typevars(typ) + field_inners = { + attr: Pattern.from_typevar(var, type_) + for var, (attr, type_) in typevars.items() + } super().__init__(origin, frozendict(field_inners)) def match(self, value, context): @@ -541,6 +606,7 @@ def match(self, value, *, context): return NoMatch +# TODO(kszucs): to support As[int] or CoercedTo[int] syntax class CoercedTo(Matcher): """Force a value to have a particular Python type. @@ -577,6 +643,9 @@ def __repr__(self): return f"CoercedTo({self.target.__name__!r})" +As = CoercedTo + + class GenericCoercedTo(Matcher): """Force a value to have a particular generic Python type. @@ -589,7 +658,7 @@ class GenericCoercedTo(Matcher): -------- >>> from typing import Generic, TypeVar >>> - >>> T = TypeVar("T") + >>> T = TypeVar("T", covariant=True) >>> >>> class MyNumber(Coercible, Generic[T]): ... def __init__(self, value): @@ -780,6 +849,9 @@ def match(self, value, context): return NoMatch +In = IsIn + + class SequenceOf(Matcher): """Pattern that matches if all of the items in a sequence match a given pattern. @@ -901,6 +973,24 @@ def match(self, value, context): return result +class Attrs(Matcher): + __slots__ = ("patterns",) + + def __init__(self, **patterns): + super().__init__(frozendict(toolz.valmap(pattern, patterns))) + + def match(self, value, context): + for attr, pattern in self.patterns.items(): + if not hasattr(value, attr): + return NoMatch + + v = getattr(value, attr) + if match(pattern, v, context=context) is NoMatch: + return NoMatch + + return value + + class Object(Matcher): """Pattern that matches if the object has the given attributes and they match the given patterns. @@ -917,22 +1007,18 @@ class Object(Matcher): The keyword arguments to match against the attributes of the object. """ - __slots__ = ("type", "field_patterns") + __slots__ = ("type", "attrs_pattern") def __init__(self, type, *args, **kwargs): kwargs.update(dict(zip(type.__match_args__, args))) - super().__init__(type, frozendict(kwargs)) + super().__init__(type, Attrs(**kwargs)) def match(self, value, context): if not isinstance(value, self.type): return NoMatch - for attr, pattern in self.field_patterns.items(): - if not hasattr(value, attr): - return NoMatch - - if match(pattern, getattr(value, attr), context=context) is NoMatch: - return NoMatch + if not self.attrs_pattern.match(value, context=context): + return NoMatch return value @@ -944,19 +1030,16 @@ def __init__(self, args, return_=None): super().__init__(tuple(args), return_ or Any()) def match(self, value, context): + from ibis.common.annotations import annotated + if not callable(value): return NoMatch - fn = value - sig = inspect.signature(fn) - # TODO(kszucs): once the validators get replaced with matchers the - # following should be re-enabled - # from ibis.common.annotations import annotated - # fn = annotated(self.arg_patterns, self.return_pattern, value) + fn = annotated(self.arg_patterns, self.return_pattern, value) has_varargs = False positional, keyword_only = [], [] - for p in sig.parameters.values(): + for p in fn.__signature__.parameters.values(): if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): positional.append(p) elif p.kind is Parameter.KEYWORD_ONLY: diff --git a/ibis/common/temporal.py b/ibis/common/temporal.py index 0f3d23c93d5b..dded3fbcf5ba 100644 --- a/ibis/common/temporal.py +++ b/ibis/common/temporal.py @@ -13,7 +13,7 @@ from ibis import util from ibis.common.dispatch import lazy_singledispatch -from ibis.common.patterns import Coercible +from ibis.common.patterns import Coercible, CoercionError class ABCEnumMeta(EnumMeta, ABCMeta): @@ -30,6 +30,12 @@ def __coerce__(cls, value): @classmethod def from_string(cls, value): + # TODO(kszucs): perhaps this is not needed anymore + if isinstance(value, Unit): + value = value.value + elif not isinstance(value, str): + raise CoercionError(f"Unable to coerce {value} to {cls.__name__}") + # first look for aliases value = cls.aliases().get(value, value) @@ -45,7 +51,7 @@ def from_string(cls, value): try: return cls[value.upper()] except KeyError: - raise ValueError(f"Unable to coerce {value} to {cls.__name__}") + raise CoercionError(f"Unable to coerce {value} to {cls.__name__}") @classmethod def aliases(cls): @@ -219,6 +225,11 @@ def _from_number(value): return datetime.datetime.utcfromtimestamp(value) +@normalize_datetime.register(datetime.time) +def _from_time(value): + return datetime.datetime.combine(datetime.date.today(), value) + + @normalize_datetime.register(datetime.date) def _from_date(value): return datetime.datetime(year=value.year, month=value.month, day=value.day) diff --git a/ibis/common/tests/test_annotations.py b/ibis/common/tests/test_annotations.py index cd367032d446..f3ae813189e5 100644 --- a/ibis/common/tests/test_annotations.py +++ b/ibis/common/tests/test_annotations.py @@ -4,25 +4,33 @@ from typing import Union import pytest -from toolz import identity from typing_extensions import Annotated # noqa: TCH002 from ibis.common.annotations import Argument, Attribute, Parameter, Signature, annotated -from ibis.common.validators import instance_of, option +from ibis.common.patterns import ( + Any, + CoercedTo, + InstanceOf, + NoMatch, + Option, + TupleOf, + ValidationError, + pattern, +) -is_int = instance_of(int) +is_int = InstanceOf(int) def test_argument_repr(): argument = Argument(is_int, typehint=int, default=None) assert repr(argument) == ( - "Argument(validator=instance_of(,), default=None, " + "Argument(validator=InstanceOf(type=), default=None, " "typehint=)" ) def test_default_argument(): - annotation = Argument.default(validator=int, default=3) + annotation = Argument.default(validator=lambda x, context: int(x), default=3) assert annotation.validate(1) == 1 with pytest.raises(TypeError): annotation.validate(None) @@ -30,7 +38,7 @@ def test_default_argument(): @pytest.mark.parametrize( ('default', 'expected'), - [(None, None), (0, 0), ('default', 'default'), (lambda: 3, 3)], + [(None, None), (0, 0), ('default', 'default')], ) def test_optional_argument(default, expected): annotation = Argument.optional(default=default) @@ -40,15 +48,13 @@ def test_optional_argument(default, expected): @pytest.mark.parametrize( ('argument', 'value', 'expected'), [ - (Argument.optional(identity, default=None), None, None), - (Argument.optional(identity, default=None), 'three', 'three'), - (Argument.optional(identity, default=1), None, 1), - (Argument.optional(identity, default=lambda: 8), 'cat', 'cat'), - (Argument.optional(identity, default=lambda: 8), None, 8), - (Argument.optional(int, default=11), None, 11), - (Argument.optional(int, default=None), None, None), - (Argument.optional(int, default=None), 18, 18), - (Argument.optional(str, default=None), 'caracal', 'caracal'), + (Argument.optional(Any(), default=None), None, None), + (Argument.optional(Any(), default=None), 'three', 'three'), + (Argument.optional(Any(), default=1), None, 1), + (Argument.optional(CoercedTo(int), default=11), None, 11), + (Argument.optional(CoercedTo(int), default=None), None, None), + (Argument.optional(CoercedTo(int), default=None), 18, 18), + (Argument.optional(CoercedTo(str), default=None), 'caracal', 'caracal'), ], ) def test_valid_optional(argument, value, expected): @@ -90,19 +96,19 @@ def fn(x, this): assert p.annotation is annot assert p.default is inspect.Parameter.empty - assert p.annotation.validate('2', this={'other': 1}) == 3 + assert p.annotation.validate('2', {'other': 1}) == 3 with pytest.raises(TypeError): p.annotation.validate({}, valid=inspect.Parameter.empty) ofn = Argument.optional(fn) op = Parameter('test', annotation=ofn) - assert op.annotation._validator == option(fn, default=None) + assert op.annotation._validator == Option(fn, default=None) assert op.default is None - assert op.annotation.validate(None, this={'other': 1}) is None + assert op.annotation.validate(None, {'other': 1}) is None with pytest.raises(TypeError, match="annotation must be an instance of Argument"): - Parameter("wrong", annotation=Attribute("a")) + Parameter("wrong", annotation=Attribute(lambda x, context: x)) def test_signature(): @@ -128,7 +134,7 @@ def test(a: int, b: int, c: int = 1): sig = Signature.from_callable(test) assert sig.validate(2, 3) == {'a': 2, 'b': 3, 'c': 1} - with pytest.raises(TypeError): + with pytest.raises(ValidationError): sig.validate(2, 3, "4") args, kwargs = sig.unbind(sig.validate(2, 3)) @@ -148,7 +154,7 @@ def test(a: int, b: int, *args: int): assert sig.parameters['b'].annotation._typehint is int assert sig.parameters['args'].annotation._typehint is int - with pytest.raises(TypeError): + with pytest.raises(ValidationError): sig.validate(2, 3, 4, "5") args, kwargs = sig.unbind(sig.validate(2, 3, 4, 5)) @@ -210,21 +216,14 @@ def add_other(x, this): assert kwargs == {} -def as_float(x, this): - return float(x) - - -def as_tuple_of_floats(x, this): - return tuple(float(i) for i in x) - - -a = Parameter('a', annotation=Argument.required(validator=as_float)) -b = Parameter('b', annotation=Argument.required(validator=as_float)) -c = Parameter('c', annotation=Argument.default(default=0, validator=as_float)) +a = Parameter('a', annotation=Argument.required(CoercedTo(float))) +b = Parameter('b', annotation=Argument.required(CoercedTo(float))) +c = Parameter('c', annotation=Argument.default(default=0, validator=CoercedTo(float))) d = Parameter( - 'd', annotation=Argument.default(default=tuple(), validator=as_tuple_of_floats) + 'd', + annotation=Argument.default(default=tuple(), validator=TupleOf(CoercedTo(float))), ) -e = Parameter('e', annotation=Argument.optional(validator=as_float)) +e = Parameter('e', annotation=Argument.optional(validator=CoercedTo(float))) sig = Signature(parameters=[a, b, c, d, e]) @@ -242,7 +241,7 @@ def test_signature_unbind_with_empty_variadic(d): def test_annotated_function(): - @annotated(a=instance_of(int), b=instance_of(int), c=instance_of(int)) + @annotated(a=InstanceOf(int), b=InstanceOf(int), c=InstanceOf(int)) def test(a, b, c=1): return a + b + c @@ -251,10 +250,10 @@ def test(a, b, c=1): assert test(2, 3, c=4) == 9 assert test(a=2, b=3, c=4) == 9 - with pytest.raises(TypeError): + with pytest.raises(ValidationError): test(2, 3, c='4') - @annotated(a=instance_of(int)) + @annotated(a=InstanceOf(int)) def test(a, b, c=1): return (a, b, c) @@ -291,54 +290,55 @@ def test_wrong(a: int, b: int, c: int = 1) -> int: return "invalid result" assert test_ok(2, 3) == 6 - with pytest.raises(TypeError): + with pytest.raises(ValidationError): test_wrong(2, 3) def test_annotated_function_with_keyword_overrides(): - @annotated(b=instance_of(float)) + @annotated(b=InstanceOf(float)) def test(a: int, b: int, c: int = 1): return a + b + c - with pytest.raises(TypeError): + with pytest.raises(ValidationError): test(2, 3) assert test(2, 3.0) == 6.0 def test_annotated_function_with_list_overrides(): - @annotated([instance_of(int), instance_of(int), instance_of(float)]) + @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(float)]) def test(a: int, b: int, c: int = 1): return a + b + c - with pytest.raises(TypeError): + with pytest.raises(ValidationError): test(2, 3, 4) def test_annotated_function_with_list_overrides_and_return_override(): - @annotated( - [instance_of(int), instance_of(int), instance_of(float)], instance_of(float) - ) + @annotated([InstanceOf(int), InstanceOf(int), InstanceOf(float)], InstanceOf(float)) def test(a: int, b: int, c: int = 1): return a + b + c - with pytest.raises(TypeError): + with pytest.raises(ValidationError): test(2, 3, 4) assert test(2, 3, 4.0) == 9.0 +@pattern def short_str(x, this): if len(x) > 3: return x - raise ValueError("too short") + else: + return NoMatch +@pattern def endswith_d(x, this): if x.endswith('d'): return x else: - raise ValueError("doesn't end with d") + return NoMatch def test_annotated_function_with_complex_type_annotations(): @@ -349,11 +349,11 @@ def test(a: Annotated[str, short_str, endswith_d], b: Union[int, float]): assert test("abcd", 1) == ("abcd", 1) assert test("---d", 1.0) == ("---d", 1.0) - with pytest.raises(ValueError, match="doesn't end with d"): + with pytest.raises(ValidationError, match="doesn't match"): test("---c", 1) - with pytest.raises(ValueError, match="too short"): + with pytest.raises(ValidationError, match="doesn't match"): test("123", 1) - with pytest.raises(TypeError, match="passes none of the following rules"): + with pytest.raises(ValidationError, match="'qweqwe' doesn't match"): test("abcd", "qweqwe") @@ -385,7 +385,7 @@ def test(a: float, b: float, *args: int): assert test(1.0, 2.0, 3, 4) == 10.0 assert test(1.0, 2.0, 3, 4, 5) == 15.0 - with pytest.raises(TypeError): + with pytest.raises(ValidationError): test(1.0, 2.0, 3, 4, 5, 6.0) @@ -397,5 +397,5 @@ def test(a: float, b: float, **kwargs: int): assert test(1.0, 2.0, c=3, d=4) == 10.0 assert test(1.0, 2.0, c=3, d=4, e=5) == 15.0 - with pytest.raises(TypeError): + with pytest.raises(ValidationError): test(1.0, 2.0, c=3, d=4, e=5, f=6.0) diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py index 3da4e5f0388a..eb43a4e2ff7e 100644 --- a/ibis/common/tests/test_graph.py +++ b/ibis/common/tests/test_graph.py @@ -4,7 +4,7 @@ from ibis.common.graph import Graph, Node, bfs, dfs, toposort from ibis.common.grounds import Annotable, Concrete -from ibis.common.validators import any_of, instance_of, tuple_of +from ibis.common.patterns import InstanceOf, TupleOf class MyNode(Node): @@ -111,21 +111,19 @@ def __hash__(self): return hash((self.__class__, self.__args__)) class Literal(Example): - value = instance_of(object) + value = InstanceOf(object) class BoolLiteral(Literal): - value = instance_of(bool) + value = InstanceOf(bool) class And(Example): - operands = tuple_of(instance_of(BoolLiteral)) + operands = TupleOf(InstanceOf(BoolLiteral)) class Or(Example): - operands = tuple_of(instance_of(BoolLiteral)) + operands = TupleOf(InstanceOf(BoolLiteral)) class Collect(Example): - arguments = tuple_of( - any_of([tuple_of(instance_of(Example)), instance_of(Example)]) - ) + arguments = TupleOf(TupleOf(InstanceOf(Example)) | InstanceOf(Example)) a = BoolLiteral(True) b = BoolLiteral(False) @@ -154,15 +152,15 @@ class Bool(Concrete, Node): pass class Value(Bool): - value = instance_of(bool) + value = InstanceOf(bool) class Either(Bool): - left = instance_of(Bool) - right = instance_of(Bool) + left = InstanceOf(Bool) + right = InstanceOf(Bool) class All(Bool): - arguments = tuple_of(instance_of(Bool)) - strict = instance_of(bool) + arguments = TupleOf(InstanceOf(Bool)) + strict = InstanceOf(bool) T, F = Value(True), Value(False) diff --git a/ibis/common/tests/test_grounds.py b/ibis/common/tests/test_grounds.py index d9756947de3f..2c3b352040e0 100644 --- a/ibis/common/tests/test_grounds.py +++ b/ibis/common/tests/test_grounds.py @@ -3,7 +3,7 @@ import copy import pickle import weakref -from typing import Mapping, Sequence, Tuple, TypeVar +from typing import Callable, Generic, Mapping, Optional, Sequence, Tuple, TypeVar, Union import pytest @@ -24,18 +24,28 @@ Base, Comparable, Concrete, + Final, Immutable, Singleton, ) -from ibis.common.validators import Coercible, Validator, instance_of, option, validator +from ibis.common.patterns import ( + Any, + CoercedTo, + Coercible, + InstanceOf, + Option, + Pattern, + TupleOf, + ValidationError, +) from ibis.tests.util import assert_pickle_roundtrip -is_any = instance_of(object) -is_bool = instance_of(bool) -is_float = instance_of(float) -is_int = instance_of(int) -is_str = instance_of(str) -is_list = instance_of(list) +is_any = InstanceOf(object) +is_bool = InstanceOf(bool) +is_float = InstanceOf(float) +is_int = InstanceOf(int) +is_str = InstanceOf(str) +is_list = InstanceOf(list) class Op(Annotable): @@ -43,11 +53,11 @@ class Op(Annotable): class Value(Op): - arg = instance_of(object) + arg = InstanceOf(object) class StringOp(Value): - arg = instance_of(str) + arg = InstanceOf(str) class BetweenSimple(Annotable): @@ -86,14 +96,14 @@ class VariadicArgsAndKeywords(Concrete): kwargs = varkwargs(is_int) -T = TypeVar('T') -K = TypeVar('K') -V = TypeVar('V') +T = TypeVar('T', covariant=True) +K = TypeVar('K', covariant=True) +V = TypeVar('V', covariant=True) class List(Concrete, Sequence[T], Coercible): @classmethod - def __coerce__(self, values): + def __coerce__(self, values, T=None): values = tuple(values) if values: head, *rest = values @@ -126,7 +136,7 @@ def __len__(self): class Map(Concrete, Mapping[K, V], Coercible): @classmethod - def __coerce__(self, pairs): + def __coerce__(self, pairs, K=None, V=None): pairs = dict(pairs) if pairs: head_key = next(iter(pairs)) @@ -184,6 +194,29 @@ class MyExpr(Concrete): c: Map[str, Integer] +class MyInt(int, Coercible): + @classmethod + def __coerce__(cls, value): + return cls(value) + + +class MyFloat(float, Coercible): + @classmethod + def __coerce__(cls, value): + return cls(value) + + +J = TypeVar("J", bound=MyInt, covariant=True) +F = TypeVar("F", bound=MyFloat, covariant=True) +N = TypeVar("N", bound=Union[MyInt, MyFloat], covariant=True) + + +class MyValue(Annotable, Generic[J, F]): + integer: J + floating: F + numeric: N + + def test_immutable(): class Foo(Immutable): __slots__ = ("a", "b") @@ -221,17 +254,31 @@ class Between(BetweenSimple): assert obj.__argnames__ == argnames assert obj.__slots__ == ("value", "lower", "upper") assert not hasattr(obj, "__dict__") + assert obj.__module__ == __name__ + assert type(obj).__qualname__ == "BetweenSimple" # test that a child without additional arguments doesn't have __dict__ obj = Between(10, lower=2) assert obj.__slots__ == tuple() assert not hasattr(obj, "__dict__") + assert obj.__module__ == __name__ + assert type(obj).__qualname__ == "test_annotable..Between" assert obj == obj.copy() assert obj == copy.copy(obj) obj2 = Between(10, lower=8) assert obj.copy(lower=8) == obj2 +def test_annotable_with_bound_typevars_properly_coerce_values(): + v = MyValue(1.1, 2.2, 3.3) + assert isinstance(v.integer, MyInt) + assert v.integer == 1 + assert isinstance(v.floating, MyFloat) + assert v.floating == 2.2 + assert isinstance(v.numeric, MyInt) + assert v.numeric == 3 + + def test_annotable_with_additional_attributes(): a = BetweenWithExtra(10, lower=2) b = BetweenWithExtra(10, lower=2) @@ -265,28 +312,28 @@ class Op(Annotable): assert p.custom == 1 -def test_annotable_with_type_annotations(): +def test_annotable_with_type_annotations() -> None: # TODO(kszucs): bar: str = None # should raise - class Op(Annotable): + class Op1(Annotable): foo: int bar: str = "" - p = Op(1) + p = Op1(1) assert p.foo == 1 assert not p.bar - class Op(Annotable): + class Op2(Annotable): bar: str = None - with pytest.raises(TypeError): - Op() + with pytest.raises(ValidationError): + Op2() def test_annotable_with_recursive_generic_type_annotations(): # testing cons list - validator = Validator.from_typehint(List[Integer]) + pattern = Pattern.from_typehint(List[Integer]) values = ["1", 2.0, 3] - result = validator(values) + result = pattern.validate(values, {}) expected = ConsList(1, ConsList(2, ConsList(3, EmptyList()))) assert result == expected assert result[0] == 1 @@ -297,9 +344,9 @@ def test_annotable_with_recursive_generic_type_annotations(): result[3] # testing cons map - validator = Validator.from_typehint(Map[Integer, Float]) + pattern = Pattern.from_typehint(Map[Integer, Float]) values = {"1": 2, 3: "4.0", 5: 6.0} - result = validator(values) + result = pattern.validate(values, {}) expected = ConsMap((1, 2.0), ConsMap((3, 4.0), ConsMap((5, 6.0), EmptyMap()))) assert result == expected assert result[1] == 2.0 @@ -595,7 +642,7 @@ def test_dont_copy_default_argument(): default = tuple() class Op(Annotable): - arg = optional(instance_of(tuple), default=default) + arg = optional(InstanceOf(tuple), default=default) op = Op() assert op.arg is default @@ -603,8 +650,8 @@ class Op(Annotable): def test_copy_mutable_with_default_attribute(): class Test(Annotable): - a = attribute(instance_of(dict), default={}) - b = argument(instance_of(str)) + a = attribute(InstanceOf(dict), default={}) + b = argument(InstanceOf(str)) @attribute.default def c(self): @@ -615,7 +662,7 @@ def c(self): assert t.b == "t" assert t.c == "T" - with pytest.raises(TypeError): + with pytest.raises(ValidationError): t.a = 1 t.a = {"map": "ping"} assert t.a == {"map": "ping"} @@ -635,17 +682,17 @@ def c(self): def test_slots_are_inherited_and_overridable(): class Op(Annotable): __slots__ = ('_cache',) # first definition - arg = validator(lambda x: x) + arg = Any() class StringOp(Op): - arg = validator(str) # new overridden slot + arg = CoercedTo(str) # new overridden slot class StringSplit(StringOp): - sep = validator(str) # new slot + sep = CoercedTo(str) # new slot class StringJoin(StringOp): __slots__ = ('_memoize',) # new slot - sep = validator(str) # new overridden slot + sep = CoercedTo(str) # new overridden slot assert Op.__slots__ == ('_cache', 'arg') assert StringOp.__slots__ == ('arg',) @@ -661,13 +708,13 @@ class Op(Annotable): __slots__ = ('_hash',) class Value(Annotable): - arg = instance_of(object) + arg = InstanceOf(object) class Reduction(Value): pass class UDF(Value): - func = validator(lambda fn, this: fn) + func = InstanceOf(Callable) class UDAF(UDF, Reduction): arity = is_int @@ -737,40 +784,57 @@ class Between(Value, ConditionalOp): ) -class Value(Annotable): +def test_immutability(): + class Value(Annotable, Immutable): + a = is_int + + op = Value(1) + with pytest.raises(AttributeError): + op.a = 3 + + +class BaseValue(Annotable): i = is_int j = attribute(is_int) -class Value2(Value): +class Value2(BaseValue): @attribute.default def k(self): return 3 -class Value3(Value): +class Value3(BaseValue): k = attribute(is_int, default=3) -class Value4(Value): - k = attribute(option(is_int), default=None) +class Value4(BaseValue): + k = attribute(Option(is_int), default=None) -# TODO(kszucs): add a test case with __dict__ added to __slots__ +def test_annotable_with_dict_slot(): + class Flexible(Annotable): + __slots__ = ('__dict__',) + + v = Flexible() + v.a = 1 + v.b = 2 + assert v.a == 1 + assert v.b == 2 def test_annotable_attribute(): with pytest.raises(TypeError, match="too many positional arguments"): - Value(1, 2) + BaseValue(1, 2) - v = Value(1) + v = BaseValue(1) assert v.__slots__ == ('i', 'j') assert v.i == 1 assert not hasattr(v, 'j') v.j = 2 assert v.j == 2 - with pytest.raises(TypeError): + with pytest.raises(ValidationError): v.j = 'foo' @@ -792,14 +856,14 @@ def test_annotable_attribute_init(): def test_annotable_mutability_and_serialization(): - v_ = Value(1) + v_ = BaseValue(1) v_.j = 2 - v = Value(1) + v = BaseValue(1) v.j = 2 assert v_ == v assert v_.j == v.j == 2 - assert repr(v) == "Value(i=1)" + assert repr(v) == "BaseValue(i=1)" w = pickle.loads(pickle.dumps(v)) assert w.i == 1 assert w.j == 2 @@ -809,7 +873,7 @@ def test_annotable_mutability_and_serialization(): assert v_ != v w = pickle.loads(pickle.dumps(v)) assert w == v - assert repr(w) == "Value(i=1)" + assert repr(w) == "BaseValue(i=1)" def test_initialized_attribute_basics(): @@ -1005,7 +1069,7 @@ def test_singleton_basics(): assert OneAndOnly.__instances__[key] is one -def test_singleton_lifetime(): +def test_singleton_lifetime() -> None: one = OneAndOnly() assert len(OneAndOnly.__instances__) == 1 @@ -1013,7 +1077,7 @@ def test_singleton_lifetime(): assert len(OneAndOnly.__instances__) == 0 -def test_singleton_with_argument(): +def test_singleton_with_argument() -> None: dt1 = DataType(nullable=True) dt2 = DataType(nullable=False) dt3 = DataType(nullable=True) @@ -1030,26 +1094,26 @@ def test_singleton_with_argument(): assert len(DataType.__instances__) == 0 -def test_composition_of_annotable_and_singleton(): +def test_composition_of_annotable_and_singleton() -> None: class AnnSing(Annotable, Singleton): - value = validator(lambda x, this: int(x)) + value = CoercedTo(int) class SingAnn(Singleton, Annotable): # this is the preferable method resolution order - value = validator(lambda x, this: int(x)) + value = CoercedTo(int) # arguments looked up after validation - obj = AnnSing("3") - assert AnnSing("3") is obj - assert AnnSing(3) is obj - assert AnnSing(3.0) is obj + obj1 = AnnSing("3") + assert AnnSing("3") is obj1 + assert AnnSing(3) is obj1 + assert AnnSing(3.0) is obj1 # arguments looked up before validation - obj = SingAnn("3") - assert SingAnn("3") is obj - obj2 = SingAnn(3) - assert obj2 is not obj - assert SingAnn(3) is obj2 + obj2 = SingAnn("3") + assert SingAnn("3") is obj2 + obj3 = SingAnn(3) + assert obj3 is not obj2 + assert SingAnn(3) is obj3 def test_concrete(): @@ -1098,10 +1162,10 @@ def test_concrete(): def test_composition_of_concrete_and_singleton(): class ConcSing(Concrete, Singleton): - value = validator(lambda x, this: int(x)) + value = CoercedTo(int) class SingConc(Singleton, Concrete): - value = validator(lambda x, this: int(x)) + value = CoercedTo(int) # arguments looked up after validation obj = ConcSing("3") @@ -1127,3 +1191,47 @@ class Test2(Test, something="value", value="something"): pass assert Test2.kwargs == {"something": "value", "value": "something"} + + +def test_argument_order_using_optional_annotations(): + class Case1(Annotable): + results: Optional[Tuple[int]] = () + default: Optional[int] = None + + class SimpleCase1(Case1): + base: int + cases: Optional[Tuple[int]] = () + + class Case2(Annotable): + results = optional(TupleOf(is_int), default=()) + default = optional(is_int) + + class SimpleCase2(Case1): + base = is_int + cases = optional(TupleOf(is_int), default=()) + + assert ( + SimpleCase1.__argnames__ + == SimpleCase2.__argnames__ + == ("base", "cases", "results", "default") + ) + + +def test_annotable_with_optional_coercible_typehint(): + class Example(Annotable): + value: Optional[MyInt] = None + + assert Example().value is None + assert Example(None).value is None + assert Example(1).value == 1 + assert isinstance(Example(1).value, MyInt) + + +def test_final(): + class A(Final): + pass + + with pytest.raises(TypeError, match="Cannot inherit from final class .*A.*"): + + class B(A): + pass diff --git a/ibis/common/tests/test_grounds_py310.py b/ibis/common/tests/test_grounds_py310.py index 63623995308e..ed65dada8651 100644 --- a/ibis/common/tests/test_grounds_py310.py +++ b/ibis/common/tests/test_grounds_py310.py @@ -1,11 +1,11 @@ from ibis.common.grounds import Annotable -from ibis.common.validators import instance_of +from ibis.common.patterns import InstanceOf -IsAny = instance_of(object) -IsBool = instance_of(bool) -IsFloat = instance_of(float) -IsInt = instance_of(int) -IsStr = instance_of(str) +IsAny = InstanceOf(object) +IsBool = InstanceOf(bool) +IsFloat = InstanceOf(float) +IsInt = InstanceOf(int) +IsStr = InstanceOf(str) class Node(Annotable): @@ -13,16 +13,16 @@ class Node(Annotable): class Literal(Node): - value = instance_of((int, float, bool, str)) - dtype = instance_of(type) + value = InstanceOf((int, float, bool, str)) + dtype = InstanceOf(type) def __add__(self, other): return Add(self, other) class BinaryOperation(Annotable): - left = instance_of(Node) - right = instance_of(Node) + left = InstanceOf(Node) + right = InstanceOf(Node) class Add(BinaryOperation): diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index fdba96a7b2fd..151d90d7232c 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -173,8 +173,8 @@ def test_lazy_instance_of(): assert p.match("foo", context={}) is NoMatch -T = TypeVar("T") -S = TypeVar("S") +T = TypeVar("T", covariant=True) +S = TypeVar("S", covariant=True) @dataclass @@ -184,7 +184,7 @@ class My(Generic[T, S]): c: str -def test_generic_instance_of(): +def test_generic_instance_of_with_covariant_typevar(): p = Pattern.from_typehint(My[int, AnyType]) assert p.match(My(1, 2, "3"), context={}) == My(1, 2, "3") @@ -426,7 +426,15 @@ def func_with_mandatory_kwargs(*, c): assert p.match(func_with_kwargs, context={}) is NoMatch p = CallableWith([InstanceOf(int)] * 4, InstanceOf(int)) - assert p.match(func_with_args, context={}) == func_with_args + wrapped = p.match(func_with_args, context={}) + assert wrapped(1, 2, 3, 4) == 10 + + p = CallableWith([InstanceOf(int), InstanceOf(str)], InstanceOf(str)) + wrapped = p.match(func, context={}) + assert wrapped(1, "st") == "1st" + + with pytest.raises(ValidationError, match="2 doesn't match InstanceOf"): + wrapped(1, 2) def test_pattern_list(): @@ -662,6 +670,7 @@ def test_pattern_decorator(): (str, InstanceOf(str)), (bool, InstanceOf(bool)), (Optional[int], Option(InstanceOf(int))), + (Optional[Union[str, int]], Option(AnyOf(InstanceOf(str), InstanceOf(int)))), (Union[int, str], AnyOf(InstanceOf(int), InstanceOf(str))), (Annotated[int, Min(3)], AllOf(InstanceOf(int), Min(3))), (List[int], SequenceOf(InstanceOf(int), list)), diff --git a/ibis/common/tests/test_temporal.py b/ibis/common/tests/test_temporal.py index 0e55b2c5b3bf..890b312a5106 100644 --- a/ibis/common/tests/test_temporal.py +++ b/ibis/common/tests/test_temporal.py @@ -1,6 +1,7 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +import itertools +from datetime import date, datetime, time, timedelta, timezone import dateutil import pandas as pd @@ -11,7 +12,9 @@ from ibis.common.patterns import CoercedTo from ibis.common.temporal import ( + DateUnit, IntervalUnit, + TimeUnit, normalize_datetime, normalize_timedelta, normalize_timezone, @@ -48,9 +51,9 @@ def test_interval_units(singular, plural, short): def test_interval_unit_coercions(singular, plural, short): u = IntervalUnit[singular.upper()] v = CoercedTo(IntervalUnit) - assert v.match(singular, {}) == u - assert v.match(plural, {}) == u - assert v.match(short, {}) == u + assert v.validate(singular, {}) == u + assert v.validate(plural, {}) == u + assert v.validate(short, {}) == u @pytest.mark.parametrize( @@ -67,7 +70,7 @@ def test_interval_unit_coercions(singular, plural, short): ) def test_interval_unit_aliases(alias, expected): v = CoercedTo(IntervalUnit) - assert v.match(alias, {}) == IntervalUnit(expected) + assert v.validate(alias, {}) == IntervalUnit(expected) @pytest.mark.parametrize( @@ -113,6 +116,14 @@ def test_normalize_timedelta_invalid(value, unit): normalize_timedelta(value, unit) +def test_interval_unit_compatibility(): + v = CoercedTo(IntervalUnit) + for unit in itertools.chain(DateUnit, TimeUnit): + interval = v.validate(unit, {}) + assert isinstance(interval, IntervalUnit) + assert unit.value == interval.value + + @pytest.mark.parametrize( ("value", "expected"), [ @@ -178,6 +189,8 @@ def test_normalize_timezone(value, expected): (1000, datetime(1970, 1, 1, 0, 16, 40)), # floating point (1000.123, datetime(1970, 1, 1, 0, 16, 40, 123000)), + # time object + (time(0, 0, 0, 1), datetime.combine(date.today(), time(0, 0, 0, 1))), ], ) def test_normalize_datetime(value, expected): diff --git a/ibis/common/tests/test_typing.py b/ibis/common/tests/test_typing.py index 83cccef108b6..246652e65e39 100644 --- a/ibis/common/tests/test_typing.py +++ b/ibis/common/tests/test_typing.py @@ -1,8 +1,12 @@ from __future__ import annotations -from typing import Generic, Optional, TypeVar, Union +from typing import Generic, Optional, Union + +from typing_extensions import TypeVar from ibis.common.typing import ( + DefaultTypeVars, + Sentinel, evaluate_annotations, get_bound_typevars, get_type_hints, @@ -95,5 +99,49 @@ def test_get_type_params() -> None: def test_get_bound_typevars() -> None: - assert get_bound_typevars(A[int, float, str]) == {'t': int, 's': float, 'u': str} - assert get_bound_typevars(B[int, bool]) == {'t': int, 's': bool, 'u': bytes} + expected = { + T: ('t', int), + S: ('s', float), + U: ('u', str), + } + assert get_bound_typevars(A[int, float, str]) == expected + + expected = { + T: ('t', int), + S: ('s', bool), + U: ('u', bytes), + } + assert get_bound_typevars(B[int, bool]) == expected + + +def test_default_type_vars(): + T = TypeVar("T") + U = TypeVar("U", default=float) + + class Test(DefaultTypeVars, Generic[T, U]): + pass + + assert Test[int, float].__parameters__ == () + assert Test[int, float].__args__ == (int, float) + + assert Test[int].__parameters__ == () + assert Test[int].__args__ == (int, float) + + +def test_sentinel(): + class missing(metaclass=Sentinel): + """marker for missing value""" + + class missing1(metaclass=Sentinel): + """marker for missing value""" + + assert type(missing) is Sentinel + expected = ".missing'>" + assert repr(missing) == expected + assert missing.__name__ == "missing" + assert missing.__doc__ == "marker for missing value" + + assert missing is missing + assert missing is not missing1 + assert missing != missing1 + assert missing != "missing" diff --git a/ibis/common/tests/test_validators.py b/ibis/common/tests/test_validators.py deleted file mode 100644 index e203f1cdc272..000000000000 --- a/ibis/common/tests/test_validators.py +++ /dev/null @@ -1,300 +0,0 @@ -from __future__ import annotations - -import sys -from typing import ( - Callable, - Dict, - List, - Literal, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) - -import pytest -from typing_extensions import Annotated - -from ibis.common.collections import frozendict -from ibis.common.validators import ( - Coercible, - Validator, - all_of, - any_of, - bool_, - callable_with, - coerced_to, - dict_of, - equal_to, - frozendict_of, - instance_of, - int_, - isin, - list_of, - mapping_of, - min_, - pair_of, - ref, - sequence_of, - str_, - tuple_of, -) - -T = TypeVar("T") - - -def test_ref(): - assert ref("b", this={"a": 1, "b": 2}) == 2 - - -@pytest.mark.parametrize( - ('validator', 'value', 'expected'), - [ - (bool_, True, True), - (str_, "foo", "foo"), - (int_, 8, 8), - (int_(min=10), 11, 11), - (min_(3), 5, 5), - (instance_of(int), 1, 1), - (instance_of(float), 1.0, 1.0), - (isin({"a", "b"}), "a", "a"), - (isin({"a": 1, "b": 2}), "a", "a"), - (isin(['a', 'b']), 'a', 'a'), - (isin(('a', 'b')), 'b', 'b'), - (isin({'a', 'b', 'c'}), 'c', 'c'), - (tuple_of(instance_of(int)), (1, 2, 3), (1, 2, 3)), - (tuple_of((instance_of(int), instance_of(str))), (1, "a"), (1, "a")), - (list_of(instance_of(str)), ["a", "b"], ["a", "b"]), - (any_of((str_, int_(max=8))), "foo", "foo"), - (any_of((str_, int_(max=8))), 7, 7), - (all_of((int_, min_(3), min_(8))), 10, 10), - (dict_of(str_, int_), {"a": 1, "b": 2}, {"a": 1, "b": 2}), - (pair_of(bool_, str_), (True, "foo"), (True, "foo")), - (equal_to(1), 1, 1), - (equal_to(None), None, None), - (coerced_to(int), "1", 1), - ], -) -def test_validators_passing(validator, value, expected): - assert validator(value) == expected - - -@pytest.mark.parametrize( - ('validator', 'value'), - [ - (bool_, "foo"), - (str_, True), - (int_, 8.1), - (int_(min=10), 9), - (min_(3), 2), - (instance_of(int), None), - (instance_of(float), 1), - (isin(["a", "b"]), "c"), - (isin({"a", "b"}), "c"), - (isin({"a": 1, "b": 2}), "d"), - (tuple_of(instance_of(int)), (1, 2.0, 3)), - (list_of(instance_of(str)), ["a", "b", None]), - (any_of((str_, int_(max=8))), 3.14), - (any_of((str_, int_(max=8))), 9), - (all_of((int_, min_(3), min_(8))), 7), - (dict_of(int_, str_), {"a": 1, "b": 2}), - (pair_of(bool_, str_), (True, True, True)), - (pair_of(bool_, str_), ("str", True)), - (equal_to(1), 2), - ], -) -def test_validators_failing(validator, value): - with pytest.raises((TypeError, ValueError)): - validator(value) - - -def short_str(x, this): - return len(x) > 3 - - -def endswith_d(x, this): - return x.endswith("d") - - -@pytest.mark.parametrize( - ("annot", "expected"), - [ - (int, instance_of(int)), - (str, instance_of(str)), - (bool, instance_of(bool)), - (Optional[int], any_of((instance_of(int), instance_of(type(None))))), - (Union[int, str], any_of((instance_of(int), instance_of(str)))), - (Annotated[int, min_(3)], all_of((instance_of(int), min_(3)))), - ( - Annotated[str, short_str, endswith_d], - all_of((instance_of(str), short_str, endswith_d)), - ), - (List[int], sequence_of(instance_of(int), type=coerced_to(list))), - ( - Tuple[int, float, str], - tuple_of( - (instance_of(int), instance_of(float), instance_of(str)), - type=coerced_to(tuple), - ), - ), - (Tuple[int, ...], tuple_of(instance_of(int), type=coerced_to(tuple))), - ( - Dict[str, float], - dict_of(instance_of(str), instance_of(float), type=coerced_to(dict)), - ), - ( - frozendict[str, int], - frozendict_of( - instance_of(str), instance_of(int), type=coerced_to(frozendict) - ), - ), - (Literal["alpha", "beta", "gamma"], isin(("alpha", "beta", "gamma"))), - ( - Callable[[str, int], str], - callable_with((instance_of(str), instance_of(int)), instance_of(str)), - ), - (Callable, instance_of(Callable)), - ], -) -def test_validator_from_typehint(annot, expected): - assert Validator.from_typehint(annot) == expected - - -@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") -def test_validator_from_typehint_uniontype(): - # uniontype marks `type1 | type2` annotations and it's different from - # Union[type1, type2] - validator = Validator.from_typehint(str | int | float) - assert validator == any_of((instance_of(str), instance_of(int), instance_of(float))) - - -class PlusOne(Coercible): - def __init__(self, value): - self.value = value - - @classmethod - def __coerce__(cls, obj): - return cls(obj + 1) - - def __eq__(self, other): - return type(self) == type(other) and self.value == other.value - - -class PlusOneRaise(PlusOne): - @classmethod - def __coerce__(cls, obj): - raise TypeError("raise on coercion") - - -class PlusOneChild(PlusOne): - pass - - -class PlusTwo(PlusOne): - @classmethod - def __coerce__(cls, obj): - return obj + 2 - - -def test_coercible_protocol(): - s = Validator.from_typehint(PlusOne) - assert s(1) == PlusOne(2) - assert s(10) == PlusOne(11) - - -def test_coercible_bypass_coercion(): - s = Validator.from_typehint(PlusOneRaise) - # bypass coercion since it's already an instance of SomethingRaise - assert s(PlusOneRaise(10)) == PlusOneRaise(10) - # but actually call __coerce__ if it's not an instance - with pytest.raises(TypeError, match="raise on coercion"): - s(10) - - -def test_coercible_checks_type(): - s = Validator.from_typehint(PlusOneChild) - v = Validator.from_typehint(PlusTwo) - - assert s(1) == PlusOneChild(2) - - assert PlusTwo.__coerce__(1) == 3 - with pytest.raises(TypeError, match="not an instance of .*PlusTwo.*"): - v(1) - - -class DoubledList(List[T]): - @classmethod - def __coerce__(cls, obj): - return cls(list(obj) * 2) - - -def test_coercible_sequence_type(): - s = Validator.from_typehint(Sequence[PlusOne]) - with pytest.raises(TypeError, match=r"Sequence\(\) takes no arguments"): - s([1, 2, 3]) - - s = Validator.from_typehint(List[PlusOne]) - assert s == sequence_of(coerced_to(PlusOne), type=coerced_to(list)) - assert s([1, 2, 3]) == [PlusOne(2), PlusOne(3), PlusOne(4)] - - s = Validator.from_typehint(Tuple[PlusOne, ...]) - assert s == tuple_of(coerced_to(PlusOne), type=coerced_to(tuple)) - assert s([1, 2, 3]) == (PlusOne(2), PlusOne(3), PlusOne(4)) - - s = Validator.from_typehint(DoubledList[PlusOne]) - assert s == sequence_of(coerced_to(PlusOne), type=coerced_to(DoubledList)) - assert s([1, 2, 3]) == DoubledList( - [PlusOne(2), PlusOne(3), PlusOne(4), PlusOne(2), PlusOne(3), PlusOne(4)] - ) - - -def test_mapping_of(): - value = {"a": 1, "b": 2} - assert mapping_of(str, int, value, type=dict) == value - assert mapping_of(str, int, value, type=frozendict) == frozendict(value) - - with pytest.raises(TypeError, match="Argument must be a mapping"): - mapping_of(str, float, 10, type=dict) - - -def test_callable_with(): - def func(a, b): - return str(a) + b - - def func_with_args(a, b, *args): - return sum((a, b) + args) - - def func_with_kwargs(a, b, c=1, **kwargs): - return str(a) + b + str(c) - - def func_with_mandatory_kwargs(*, c): - return c - - msg = "Argument must be a callable" - with pytest.raises(TypeError, match=msg): - callable_with([instance_of(int), instance_of(str)], 10, "string") - - msg = "Callable has mandatory keyword-only arguments which cannot be specified" - with pytest.raises(TypeError, match=msg): - callable_with([instance_of(int)], instance_of(str), func_with_mandatory_kwargs) - - msg = "Callable has more positional arguments than expected" - with pytest.raises(TypeError, match=msg): - callable_with([instance_of(int)] * 2, instance_of(str), func_with_kwargs) - - msg = "Callable has less positional arguments than expected" - with pytest.raises(TypeError, match=msg): - callable_with([instance_of(int)] * 4, instance_of(str), func_with_kwargs) - - wrapped = callable_with([instance_of(int)] * 4, instance_of(int), func_with_args) - assert wrapped(1, 2, 3, 4) == 10 - - wrapped = callable_with( - [instance_of(int), instance_of(str)], instance_of(str), func - ) - assert wrapped(1, "st") == "1st" - - msg = "Given argument with type is not an instance of " - with pytest.raises(TypeError, match=msg): - wrapped(1, 2) diff --git a/ibis/common/typing.py b/ibis/common/typing.py index 0ca43c994b95..0e2c47ba81d9 100644 --- a/ibis/common/typing.py +++ b/ibis/common/typing.py @@ -1,11 +1,13 @@ from __future__ import annotations import sys +from itertools import zip_longest from typing import ( Any, Dict, Generic, # noqa: F401 Optional, + Tuple, TypeVar, get_args, get_origin, @@ -15,11 +17,18 @@ from ibis.common.caching import memoize -Namespace = Dict[str, Any] +try: + from types import UnionType +except ImportError: + UnionType = object() + T = TypeVar("T") U = TypeVar("U") +Namespace = Dict[str, Any] +VarTuple = Tuple[T, ...] + @memoize def get_type_hints( @@ -126,7 +135,7 @@ def get_bound_typevars(obj: Any) -> dict[str, Any]: ... b: U ... >>> get_bound_typevars(MyStruct[int, str]) - {'a': , 'b': } + {~T: ('a', ), ~U: ('b', )} >>> >>> class MyStruct(Generic[T, U]): ... a: T @@ -136,7 +145,7 @@ def get_bound_typevars(obj: Any) -> dict[str, Any]: ... ... ... >>> get_bound_typevars(MyStruct[float, bytes]) - {'a': , 'myprop': } + {~T: ('a', ), ~U: ('myprop', )} """ origin = get_origin(obj) or obj hints = get_type_hints(origin, include_properties=True) @@ -145,7 +154,7 @@ def get_bound_typevars(obj: Any) -> dict[str, Any]: result = {} for attr, typ in hints.items(): if isinstance(typ, TypeVar): - result[attr] = params[typ.__name__] + result[typ] = (attr, params[typ.__name__]) return result @@ -181,3 +190,27 @@ def evaluate_annotations( k: eval(v, globalns, localns) if isinstance(v, str) else v # noqa: PGH001 for k, v in annots.items() } + + +class DefaultTypeVars: + """Enable using default type variables in generic classes (PEP-0696).""" + + __slots__ = () + + def __class_getitem__(cls, params): + params = params if isinstance(params, tuple) else (params,) + pairs = zip_longest(params, cls.__parameters__) + params = tuple(p.__default__ if t is None else t for t, p in pairs) + return super().__class_getitem__(params) + + +class Sentinel(type): + """Create type-annotable unique objects.""" + + def __new__(cls, name, bases, namespace, **kwargs): + if bases: + raise TypeError("Sentinels cannot be subclassed") + return super().__new__(cls, name, bases, namespace, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + raise TypeError("Sentinels are not constructible") diff --git a/ibis/common/validators.py b/ibis/common/validators.py deleted file mode 100644 index 61765bd82498..000000000000 --- a/ibis/common/validators.py +++ /dev/null @@ -1,632 +0,0 @@ -from __future__ import annotations - -import math -from contextlib import suppress -from inspect import Parameter -from typing import ( - Any, - Callable, - Container, - Iterable, - Literal, - Mapping, - Sequence, - Tuple, - TypeVar, - Union, -) - -import toolz -from typing_extensions import Annotated, get_args, get_origin - -from ibis.common.collections import FrozenDict -from ibis.common.dispatch import lazy_singledispatch -from ibis.common.exceptions import IbisTypeError -from ibis.common.patterns import Coercible -from ibis.util import flatten_iterable, is_function, is_iterable - -try: - from types import UnionType -except ImportError: - UnionType = object() - -K = TypeVar('K') -V = TypeVar('V') -T = TypeVar('T') - - -class Validator(Callable): - """Abstract base class for defining argument validators.""" - - __slots__ = () - - @classmethod - def from_typehint(cls, annot: type) -> Validator: - """Construct a validator from a python type annotation. - - Parameters - ---------- - annot - The typehint annotation to construct a validator from. - - Returns - ------- - validator - A validator that can be used to validate objects, typically function - arguments. - """ - # TODO(kszucs): cache the result of this function - origin, args = get_origin(annot), get_args(annot) - - if origin is None: - if annot is Any: - return any_ - elif isinstance(annot, TypeVar): - return any_ - elif issubclass(annot, Coercible): - return coerced_to(annot) - else: - return instance_of(annot) - elif origin is Literal: - return isin(args) - elif origin is UnionType or origin is Union: - inners = map(cls.from_typehint, args) - return any_of(tuple(inners)) - elif origin is Annotated: - annot, *extras = args - return all_of((instance_of(annot), *extras)) - elif issubclass(origin, Tuple): - first, *rest = args - if rest == [Ellipsis]: - inners = cls.from_typehint(first) - else: - inners = tuple(map(cls.from_typehint, args)) - return tuple_of(inners, type=coerced_to(origin)) - elif issubclass(origin, Sequence): - (value_inner,) = map(cls.from_typehint, args) - return sequence_of(value_inner, type=coerced_to(origin)) - elif issubclass(origin, Mapping): - key_inner, value_inner = map(cls.from_typehint, args) - return mapping_of(key_inner, value_inner, type=coerced_to(origin)) - elif issubclass(origin, Callable): - if args: - arg_inners = tuple(map(cls.from_typehint, args[0])) - return_inner = cls.from_typehint(args[1]) - return callable_with(arg_inners, return_inner) - else: - return instance_of(Callable) - else: - raise NotImplementedError( - f"Cannot create validator from annotation {annot} {origin}" - ) - - -# TODO(kszucs): in order to cache valiadator instances we could subclass -# grounds.Singleton, but the imports would need to be reorganized -class Curried(toolz.curry, Validator): - """Enable convenient validator definition by decorating plain functions.""" - - def __repr__(self): - return '{}({}{})'.format( - self.func.__name__, - repr(self.args)[1:-1], - ', '.join(f'{k}={v!r}' for k, v in self.keywords.items()), - ) - - -validator = Curried - - -@validator -def ref(key: str, *, this: Mapping[str, Any]) -> Any: - """Retrieve a value from the already validated state. - - Parameters - ---------- - key - The key to retrieve from the state. - this - The state to retrieve the value from, usually the result of an annotated - function signature validation (including annotable object creation). - - Returns - ------- - value - The value retrieved from the state. - """ - try: - return this[key] - except KeyError: - raise IbisTypeError(f"Could not get `{key}` from {this}") - - -@validator -def any_(arg: Any, **kwargs: Any) -> Any: - """Validator that accepts any value, basically a no-op.""" - return arg - - -@validator -def option(inner: Validator, arg: Any, *, default: Any = None, **kwargs) -> Any: - """Validator that accepts `None` or a value that passes the inner validator. - - Parameters - ---------- - inner - The inner validator to use. - arg - The value to validate. - default - The default value to use if `arg` is `None`. - kwargs - Additional keyword arguments to pass to the inner validator. - - Returns - ------- - validated - The validated value or the default value if `arg` is `None`. - """ - if arg is None: - if default is None: - return None - elif is_function(default): - arg = default() - else: - arg = default - return inner(arg, **kwargs) - - -@validator -def instance_of(klasses: type | tuple[type], arg: Any, **kwargs: Any) -> Any: - """Require that a value has a particular Python type. - - Parameters - ---------- - klasses - The type or tuple of types to validate against. - arg - The value to validate. - kwargs - Omitted keyword arguments. - - Returns - ------- - validated - The input argument if it is an instance of the given type(s). - """ - if not isinstance(arg, klasses): - # TODO(kszucs): unify errors coming from various validators - raise IbisTypeError( - f"Given argument with type {type(arg)} is not an instance of {klasses}" - ) - return arg - - -@validator -def equal_to(value: T, arg: T, **kwargs: Any) -> T: - """Require that a value is equal to a particular value.""" - if arg != value: - raise IbisTypeError(f"Given argument {arg} is not equal to {value}") - return arg - - -@validator -def coerced_to(klass: T, arg: Any, **kwargs: Any) -> T: - """Force a value to have a particular Python type. - - If a Coercible subclass is passed, the `__coerce__` method will be used to - coerce the value. Otherwise, the type will be called with the value as the - only argument. - - Parameters - ---------- - klass - The type to coerce to. - arg - The value to coerce. - kwargs - Additional keyword arguments to pass to the inner validator. - - Returns - ------- - validated - The coerced value which is checked to be an instance of the given type. - """ - if isinstance(arg, klass): - return arg - try: - arg = klass.__coerce__(arg) - except AttributeError: - arg = klass(arg) - return instance_of(klass, arg, **kwargs) - - -class lazy_instance_of(Validator): - """A version of `instance_of` that accepts qualnames instead of imported classes. - - Useful for delaying imports. - """ - - def __init__(self, classes): - classes = (classes,) if isinstance(classes, str) else tuple(classes) - self._classes = classes - self._check = lazy_singledispatch(lambda x: False) - self._check.register(classes, lambda x: True) - - def __repr__(self): - return f"lazy_instance_of(classes={self._classes!r})" - - def __call__(self, arg, **kwargs): - if self._check(arg): - return arg - raise IbisTypeError( - f"Given argument with type {type(arg)} is not an instance of " - f"{self._classes}" - ) - - -@validator -def any_of(inners: Iterable[Validator], arg: Any, **kwargs: Any) -> Any: - """At least one of the inner validators must pass. - - Parameters - ---------- - inners - Iterable of value validators, each of which is applied from left to right and - the first one that passes gets returned. - arg - Value to be validated. - kwargs - Keyword arguments - - Returns - ------- - arg : Any - Value maybe coerced by inner validators to the appropriate types - """ - for inner in inners: - with suppress(IbisTypeError, ValueError): - return inner(arg, **kwargs) - - raise IbisTypeError( - "argument passes none of the following rules: " - f"{', '.join(map(repr, inners))}" - ) - - -one_of = any_of - - -@validator -def all_of(inners: Iterable[Validator], arg: Any, **kwargs: Any) -> Any: - """Construct a validator of other valdiators. - - Parameters - ---------- - inners - Iterable of value validators, each of which is applied from left to - right so `allof([rule1, rule2], arg)` is the same as `rule2(rule1(arg))`. - arg - Value to be validated. - kwargs - Keyword arguments - - Returns - ------- - arg : Any - Value maybe coerced by inner validators to the appropriate types - """ - for inner in inners: - arg = inner(arg, **kwargs) - return arg - - -@validator -def isin(values: Container, arg: T, **kwargs: Any) -> T: - """Check if the value is in the given container. - - Parameters - ---------- - values - Container of values to check against. - arg - Value to be looked for. - kwargs - Omitted keyword arguments. - - Returns - ------- - validated - The input argument if it is in the given container. - """ - if arg not in values: - raise ValueError(f'Value with type {type(arg)} is not in {values!r}') - return arg - - -@validator -def map_to(mapping: Mapping[K, V], variant: K, **kwargs: Any) -> V: - """Check if the value is in the given mapping and return the corresponding value. - - Parameters - ---------- - mapping - Mapping of values to check against. - variant - Value to be looked for. - kwargs - Omitted keyword arguments. - - Returns - ------- - validated - The value corresponding to the input argument if it is in the given mapping. - """ - try: - return mapping[variant] - except KeyError: - raise ValueError(f'Value with type {type(variant)} is not in {mapping!r}') - - -@validator -def pair_of( - inner1: Validator, inner2: Validator, arg: Any, *, type=tuple, **kwargs -) -> tuple[Any, Any]: - """Validate a pair of values (tuple of 2 items). - - Parameters - ---------- - inner1 - Validator to apply to the first element of the pair. - inner2 - Validator to apply to the second element of the pair. - arg - Pair to validate. - type - Type to coerce the pair to, typically a tuple. - kwargs - Additional keyword arguments to pass to the inner validator. - - Returns - ------- - validated - The validated pair with each element coerced according to the inner validators. - """ - try: - first, second = arg - except KeyError: - raise IbisTypeError('Argument must be a pair') - return type((inner1(first, **kwargs), inner2(second, **kwargs))) - - -@validator -def sequence_of( - inner: Validator, - arg: Any, - *, - type: Callable[[Iterable], T], - length: int | None = None, - min_length: int = 0, - max_length: int = math.inf, - flatten: bool = False, - **kwargs: Any, -) -> T: - """Validate a sequence of values. - - Parameters - ---------- - inner - Validator to apply to each element of the sequence. - arg - Sequence to validate. - type - Type to coerce the sequence to, typically a tuple or list. - length - If specified, the sequence must have exactly this length. - min_length - The sequence must have at least this many elements. - max_length - The sequence must have at most this many elements. - flatten - If True, the sequence is flattened before validation. - kwargs - Keyword arguments to pass to the inner validator. - - Returns - ------- - validated - The coerced sequence containing validated elements. - """ - if not is_iterable(arg): - raise IbisTypeError('Argument must be a sequence') - - if length is not None: - min_length = max_length = length - if len(arg) < min_length: - raise IbisTypeError(f'Arg must have at least {min_length} number of elements') - if len(arg) > max_length: - raise IbisTypeError(f'Arg must have at most {max_length} number of elements') - - if flatten: - arg = flatten_iterable(arg) - - return type(inner(item, **kwargs) for item in arg) - - -@validator -def tuple_of(inner: Validator | tuple[Validator], arg: Any, *, type=tuple, **kwargs): - """Validate a tuple of values. - - Parameters - ---------- - inner - Either a balidator to apply to each element of the tuple or a tuple of - validators which are applied to the elements of the tuple in order. - arg - Sequence to validate. - type - Type to coerce the sequence to, a tuple by default. - kwargs - Keyword arguments to pass to the inner validator. - - Returns - ------- - validated - The coerced tuple containing validated elements. - """ - if isinstance(inner, tuple): - if is_iterable(arg): - arg = tuple(arg) - else: - raise IbisTypeError('Argument must be a sequence') - - if len(inner) != len(arg): - raise IbisTypeError(f'Argument must has length {len(inner)}') - - return type(validator(item, **kwargs) for validator, item in zip(inner, arg)) - else: - return sequence_of(inner, arg, type=type, **kwargs) - - -@validator -def mapping_of( - key_inner: Validator, - value_inner: Validator, - arg: Any, - *, - type: T, - **kwargs: Any, -) -> T: - """Validate a mapping of values. - - Parameters - ---------- - key_inner - Validator to apply to each key of the mapping. - value_inner - Validator to apply to each value of the mapping. - arg - Mapping to validate. - type - Type to coerce the mapping to, typically a dict. - kwargs - Keyword arguments to pass to the inner validator. - - Returns - ------- - validated - The coerced mapping containing validated keys and values. - """ - if not isinstance(arg, Mapping): - raise IbisTypeError('Argument must be a mapping') - return type( - {key_inner(k, **kwargs): value_inner(v, **kwargs) for k, v in arg.items()} - ) - - -@validator -def callable_with( - arg_inners: Sequence[Validator], - return_inner: Validator, - value: Any, - **kwargs: Any, -) -> Callable: - """Validate a callable with a given signature and return type. - - The rule's responsility is twofold: - 1. Validate the signature of the callable (keyword only arguments are not supported) - 2. Wrap the callable with validation logic that validates the arguments and the - return value at runtime. - - Parameters - ---------- - arg_inners - Sequence of validators to apply to the arguments of the callable. - return_inner - Validator to apply to the return value of the callable. - value - Callable to validate. - kwargs - Keyword arguments to pass to the inner validators. - - Returns - ------- - validated - The callable wrapped with validation logic. - """ - from ibis.common.annotations import annotated - - if not callable(value): - raise IbisTypeError("Argument must be a callable") - - fn = annotated(arg_inners, return_inner, value) - - has_varargs = False - positional, keyword_only = [], [] - for p in fn.__signature__.parameters.values(): - if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): - positional.append(p) - elif p.kind is Parameter.KEYWORD_ONLY: - keyword_only.append(p) - elif p.kind is Parameter.VAR_POSITIONAL: - has_varargs = True - - if keyword_only: - raise IbisTypeError( - "Callable has mandatory keyword-only arguments which cannot be specified" - ) - elif len(positional) > len(arg_inners): - raise IbisTypeError("Callable has more positional arguments than expected") - elif len(positional) < len(arg_inners) and not has_varargs: - raise IbisTypeError("Callable has less positional arguments than expected") - else: - return fn - - -@validator -def int_(arg: Any, min: int = 0, max: int = math.inf, **kwargs: Any) -> int: - """Validate an integer. - - Parameters - ---------- - arg - Integer to validate. - min - Minimum value of the integer. - max - Maximum value of the integer. - kwargs - Omitted keyword arguments. - - Returns - ------- - validated - The validated integer. - """ - if not isinstance(arg, int): - raise IbisTypeError('Argument must be an integer') - arg = min_(min, arg, **kwargs) - arg = max_(max, arg, **kwargs) - return arg - - -@validator -def min_(min: int, arg: int, **kwargs: Any) -> int: - if arg < min: - raise ValueError(f'Argument must be greater than {min}') - return arg - - -@validator -def max_(max: int, arg: int, **kwargs: Any) -> int: - if arg > max: - raise ValueError(f'Argument must be less than {max}') - return arg - - -str_ = instance_of(str) -bool_ = instance_of(bool) -none_ = instance_of(type(None)) -dict_of = mapping_of(type=dict) -list_of = sequence_of(type=list) -frozendict_of = mapping_of(type=FrozenDict) diff --git a/ibis/config.py b/ibis/config.py index 29d8cf8708ae..c2ab73123b53 100644 --- a/ibis/config.py +++ b/ibis/config.py @@ -8,9 +8,9 @@ import ibis.common.exceptions as com from ibis.common.grounds import Annotable -from ibis.common.validators import min_ +from ibis.common.patterns import Between -PosInt = Annotated[int, min_(0)] +PosInt = Annotated[int, Between(lower=0)] class Config(Annotable): diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 43bc70ce79c7..f70be806ca38 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -13,6 +13,7 @@ import ibis.expr.types as ir from ibis import util from ibis.common.exceptions import IbisTypeError, IntegrityError +from ibis.common.patterns import ValidationError # --------------------------------------------------------------------- # Some expression metaprogramming / graph transformations to support @@ -204,7 +205,7 @@ def substitute(fn, node): try: return node.__class__(*new_args) - except TypeError: + except (TypeError, ValidationError): return node diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index da719879413a..1639af8b3a8a 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -1,8 +1,9 @@ from __future__ import annotations import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, Optional, Union +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz @@ -11,7 +12,10 @@ from ibis.common.annotations import annotated from ibis.common.exceptions import IbisInputError from ibis.common.grounds import Concrete -from ibis.expr.deferred import Deferred +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.expr.deferred import Deferred # noqa: TCH001 +from ibis.expr.operations.core import Value # noqa: TCH001 +from ibis.expr.operations.relations import Relation # noqa: TCH001 from ibis.expr.types.relations import bind_expr if TYPE_CHECKING: @@ -23,8 +27,8 @@ class Builder(Concrete): class CaseBuilder(Builder): - results = rlz.optional(rlz.tuple_of(rlz.any), default=[]) - default = rlz.optional(rlz.any) + results: VarTuple[Value] = () + default: Optional[ops.Value] = None def type(self): return rlz.highest_precedence_dtype(self.results) @@ -61,15 +65,16 @@ def end(self) -> ir.Value: class SearchedCaseBuilder(CaseBuilder): __type__ = ops.SearchedCase - cases = rlz.optional(rlz.tuple_of(rlz.boolean), default=[]) + cases: VarTuple[Value[dt.Boolean, ds.Any]] = () class SimpleCaseBuilder(CaseBuilder): __type__ = ops.SimpleCase - base = rlz.any - cases = rlz.optional(rlz.tuple_of(rlz.any), default=[]) + base: ops.Value + cases: VarTuple[Value] = () - def when(self, case_expr, result_expr) -> Self: + @annotated + def when(self, case_expr: Value, result_expr: Value): """Add a new case-result pair. Parameters @@ -80,7 +85,6 @@ def when(self, case_expr, result_expr) -> Self: result_expr Value when the case predicate evaluates to true. """ - case_expr = rlz.any(case_expr) if not rlz.comparable(self.base, case_expr): raise TypeError( f'Base expression {rlz._arg_type_error_format(self.base)} and ' @@ -89,6 +93,10 @@ def when(self, case_expr, result_expr) -> Self: return super().when(case_expr, result_expr) +RowsWindowBoundary = ops.WindowBoundary[dt.Integer] +RangeWindowBoundary = ops.WindowBoundary[dt.Numeric | dt.Interval] + + class WindowBuilder(Builder): """An unbound window frame specification. @@ -101,24 +109,16 @@ class WindowBuilder(Builder): Use 0 for `CURRENT ROW`. """ - how = rlz.optional(rlz.isin({'rows', 'range'}), default="rows") - start = end = rlz.optional(rlz.option(rlz.range_window_boundary)) - groupings = rlz.optional( - rlz.tuple_of( - rlz.one_of([rlz.instance_of((str, Deferred)), rlz.column(rlz.any)]) - ), - default=(), - ) - orderings = rlz.optional( - rlz.tuple_of(rlz.one_of([rlz.instance_of((str, Deferred)), rlz.any])), - default=(), - ) - max_lookback = rlz.optional(rlz.interval) + how: Literal["rows", "range"] = "rows" + start: Optional[RangeWindowBoundary] = None + end: Optional[RangeWindowBoundary] = None + groupings: VarTuple[Union[str, Deferred, Value]] = () + orderings: VarTuple[Union[str, Deferred, Value]] = () + max_lookback: Optional[Value[dt.Interval]] = None def _maybe_cast_boundary(self, boundary, dtype): if boundary.output_dtype == dtype: return boundary - value = ops.Cast(boundary.value, dtype) return boundary.copy(value=value) @@ -149,29 +149,26 @@ def _validate_boundaries(self, start, end): "Window frame's start point must be greater than its end point" ) - @annotated( - start=rlz.option(rlz.row_window_boundary), - end=rlz.option(rlz.row_window_boundary), - ) - def rows(self, start, end): + @annotated + def rows( + self, start: Optional[RowsWindowBoundary], end: Optional[RowsWindowBoundary] + ): self._validate_boundaries(start, end) start, end = self._maybe_cast_boundaries(start, end) return self.copy(how="rows", start=start, end=end) - @annotated( - start=rlz.option(rlz.range_window_boundary), - end=rlz.option(rlz.range_window_boundary), - ) - def range(self, start, end): + @annotated + def range( + self, start: Optional[RangeWindowBoundary], end: Optional[RangeWindowBoundary] + ): self._validate_boundaries(start, end) start, end = self._maybe_cast_boundaries(start, end) return self.copy(how="range", start=start, end=end) - @annotated( - start=rlz.option(rlz.range_window_boundary), - end=rlz.option(rlz.range_window_boundary), - ) - def between(self, start, end): + @annotated + def between( + self, start: Optional[RangeWindowBoundary], end: Optional[RangeWindowBoundary] + ): self._validate_boundaries(start, end) start, end = self._maybe_cast_boundaries(start, end) method = self._determine_how(start, end) @@ -186,8 +183,8 @@ def order_by(self, expr) -> Self: def lookback(self, value) -> Self: return self.copy(max_lookback=value) - @annotated(table=rlz.table) - def bind(self, table): + @annotated + def bind(self, table: Relation): groupings = bind_expr(table.to_expr(), self.groupings) orderings = bind_expr(table.to_expr(), self.orderings) if self.how == "rows": diff --git a/ibis/expr/datashape.py b/ibis/expr/datashape.py new file mode 100644 index 000000000000..cc173fa67be1 --- /dev/null +++ b/ibis/expr/datashape.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import Any + +from public import public + +from ibis.common.grounds import Singleton + + +@public +class DataShape(Singleton): + ndim: int + SCALAR: Scalar + COLUMNAR: Columnar + + def is_scalar(self) -> bool: + return self.ndim == 0 + + def is_columnar(self) -> bool: + return self.ndim == 1 + + def is_tabular(self) -> bool: + return self.ndim == 2 + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, DataShape): + return NotImplemented + return self.ndim < other.ndim + + def __le__(self, other: Any) -> bool: + if not isinstance(other, DataShape): + return NotImplemented + return self.ndim <= other.ndim + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DataShape): + return NotImplemented + return self.ndim == other.ndim + + def __hash__(self) -> int: + return hash((self.__class__, self.ndim)) + + +@public +class Scalar(DataShape): + ndim = 0 + + +@public +class Columnar(DataShape): + ndim = 1 + + +@public +class Tabular(DataShape): + ndim = 2 + + +# for backward compat +DataShape.SCALAR = Scalar() +DataShape.COLUMNAR = Columnar() +DataShape.TABULAR = Tabular() + +scalar = Scalar() +columnar = Columnar() +tabular = Tabular() + + +public(Any=DataShape) diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index 6f1557464bb1..191370bcb7ae 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -7,7 +7,7 @@ from abc import abstractmethod from collections.abc import Iterator, Mapping, Sequence from numbers import Integral, Real -from typing import Any, Iterable, Literal, NamedTuple, Optional +from typing import Any, Generic, Iterable, Literal, NamedTuple, Optional, TypeVar import toolz from public import public @@ -17,8 +17,8 @@ from ibis.common.collections import FrozenDict, MapSet from ibis.common.dispatch import lazy_singledispatch from ibis.common.grounds import Concrete, Singleton +from ibis.common.patterns import Coercible, CoercionError from ibis.common.temporal import IntervalUnit, TimestampUnit -from ibis.common.validators import Coercible @lazy_singledispatch @@ -94,6 +94,16 @@ class DataType(Concrete, Coercible): nullable: bool = True + @property + @abstractmethod + def scalar(self): + ... + + @property + @abstractmethod + def column(self): + ... + # TODO(kszucs): remove it, prefer to use Annotable.__repr__ instead @property def _pretty_piece(self) -> str: @@ -106,8 +116,13 @@ def name(self) -> str: return self.__class__.__name__ @classmethod - def __coerce__(cls, value): - return dtype(value) + def __coerce__(cls, value, **kwargs): + if isinstance(value, cls): + return value + try: + return dtype(value) + except TypeError as e: + raise CoercionError("Unable to coerce to a DataType") from e def __call__(self, **kwargs): return self.copy(**kwargs) @@ -147,6 +162,7 @@ def from_string(cls, value) -> Self: @classmethod def from_typehint(cls, typ, nullable=True) -> Self: origin_type = get_origin(typ) + if origin_type is None: if isinstance(typ, type): if issubclass(typ, DataType): @@ -446,9 +462,6 @@ class Variadic(DataType): class Parametric(DataType): """Types that can be parameterized.""" - def __class_getitem__(cls, params): - return cls(*params) if isinstance(params, tuple) else cls(params) - @public class Null(Primitive): @@ -481,6 +494,11 @@ def __contains__(self, value: int) -> bool: class Numeric(DataType): """Numeric types.""" + @property + @abstractmethod + def largest(self) -> DataType: + """Return the largest type in this family.""" + @public class Integer(Primitive, Numeric): @@ -819,9 +837,6 @@ class Struct(Parametric, MapSet): scalar = "StructScalar" column = "StructColumn" - def __class_getitem__(cls, fields): - return cls({slice_.start: slice_.stop for slice_ in fields}) - @classmethod def from_tuples( cls, pairs: Iterable[tuple[str, str | DataType]], nullable: bool = True @@ -870,11 +885,14 @@ def _pretty_piece(self) -> str: return f"<{pairs}>" +T = TypeVar("T", bound=DataType, covariant=True) + + @public -class Array(Variadic, Parametric): +class Array(Variadic, Parametric, Generic[T]): """Array values.""" - value_type: DataType + value_type: T scalar = "ArrayScalar" column = "ArrayColumn" @@ -884,12 +902,16 @@ def _pretty_piece(self) -> str: return f"<{self.value_type}>" +K = TypeVar("K", bound=DataType, covariant=True) +V = TypeVar("V", bound=DataType, covariant=True) + + @public -class Map(Variadic, Parametric): +class Map(Variadic, Parametric, Generic[K, V]): """Associative array values.""" - key_type: DataType - value_type: DataType + key_type: K + value_type: V scalar = "MapScalar" column = "MapColumn" @@ -1048,6 +1070,7 @@ class INET(String): public( + Any=DataType, null=null, boolean=boolean, int8=int8, diff --git a/ibis/expr/datatypes/tests/test_core.py b/ibis/expr/datatypes/tests/test_core.py index bdb0103ab937..9b499249f755 100644 --- a/ibis/expr/datatypes/tests/test_core.py +++ b/ibis/expr/datatypes/tests/test_core.py @@ -8,9 +8,11 @@ from typing import Dict, List, NamedTuple, Tuple import pytest +from typing_extensions import Annotated import ibis.expr.datatypes as dt -from ibis.common.temporal import TimestampUnit +from ibis.common.patterns import As, Attrs, NoMatch, Pattern, ValidationError +from ibis.common.temporal import TimestampUnit, TimeUnit def test_validate_type(): @@ -67,8 +69,8 @@ def test_dtype(spec, expected): (dt.Boolean, dt.boolean), (dt.Date, dt.date), (dt.Time, dt.time), - (dt.Timestamp, dt.timestamp), (dt.Decimal, dt.decimal), + (dt.Timestamp, dt.timestamp), ], ) def test_dtype_from_classes(klass, expected): @@ -100,6 +102,34 @@ class FooStruct: s: dt.Map(dt.string, dt.int16) +foo_struct = dt.Struct( + { + 'a': dt.int16, + 'b': dt.int32, + 'c': dt.int64, + 'd': dt.uint8, + 'e': dt.uint16, + 'f': dt.uint32, + 'g': dt.uint64, + 'h': dt.float32, + 'i': dt.float64, + 'j': dt.string, + 'k': dt.binary, + 'l': dt.boolean, + 'm': dt.date, + 'n': dt.time, + 'o': dt.timestamp, + 'oa': dt.Timestamp('UTC'), + 'ob': dt.Timestamp('UTC', 6), + 'pa': dt.Interval('s'), + 'q': dt.decimal, + 'qa': dt.Decimal(12, 2), + 'r': dt.Array(dt.int16), + 's': dt.Map(dt.string, dt.int16), + } +) + + class BarStruct: a: dt.Int16 b: dt.Int32 @@ -116,16 +146,12 @@ class BarStruct: m: dt.Date n: dt.Time o: dt.Timestamp - oa: dt.Timestamp['UTC'] # noqa: F821, UP037 - ob: dt.Timestamp['UTC', 6] # noqa: F821, UP037 - pa: dt.Interval['s'] # noqa: F821, UP037 q: dt.Decimal - qa: dt.Decimal[12, 2] r: dt.Array[dt.Int16] s: dt.Map[dt.String, dt.Int16] -baz_struct = dt.Struct( +bar_struct = dt.Struct( { 'a': dt.int16, 'b': dt.int32, @@ -142,11 +168,7 @@ class BarStruct: 'm': dt.date, 'n': dt.time, 'o': dt.timestamp, - 'oa': dt.Timestamp('UTC'), - 'ob': dt.Timestamp('UTC', 6), - 'pa': dt.Interval('s'), 'q': dt.decimal, - 'qa': dt.Decimal(12, 2), 'r': dt.Array(dt.int16), 's': dt.Map(dt.string, dt.int16), } @@ -275,16 +297,8 @@ class FooDataClass: [ (dt.Array[dt.Null], dt.Array(dt.Null())), (dt.Map[dt.Null, dt.Null], dt.Map(dt.Null(), dt.Null())), - (dt.Timestamp['UTC'], dt.Timestamp(timezone='UTC')), - (dt.Timestamp['UTC', 6], dt.Timestamp(timezone='UTC', scale=6)), - (dt.Interval['s'], dt.Interval('s')), - (dt.Decimal[12, 2], dt.Decimal(12, 2)), - ( - dt.Struct['a' : dt.Int16, 'b' : dt.Int32], - dt.Struct({'a': dt.Int16(), 'b': dt.Int32()}), - ), - (FooStruct, baz_struct), - (BarStruct, baz_struct), + (FooStruct, foo_struct), + (BarStruct, bar_struct), (PyStruct, py_struct), (FooNamedTuple, dt.Struct({'a': dt.string, 'b': dt.int64, 'c': dt.float64})), (FooDataClass, dt.Struct({'a': dt.string, 'b': dt.int64, 'c': dt.float64})), @@ -315,18 +329,6 @@ class Something: dt.dtype(Something) -def test_dtype_from_additional_struct_typehints(): - class A: - nested: dt.Struct({'a': dt.Int16, 'b': dt.Int32}) - - class B: - nested: dt.Struct['a' : dt.Int16, 'b' : dt.Int32] # noqa: F821, UP037 - - expected = dt.Struct({'nested': dt.Struct({'a': dt.Int16(), 'b': dt.Int32()})}) - assert dt.dtype(A) == expected - assert dt.dtype(B) == expected - - def test_struct_subclass_from_tuples(): class MyStruct(dt.Struct): pass @@ -445,13 +447,13 @@ def test_array_type_equals(): def test_interval_invalid_value_type(): - with pytest.raises(TypeError): + with pytest.raises(ValidationError): dt.Interval('m', dt.float32) @pytest.mark.parametrize('unit', ['H', 'unsupported']) def test_interval_invalid_unit(unit): - with pytest.raises(ValueError): + with pytest.raises(ValidationError): dt.Interval(dt.int32, unit) @@ -612,3 +614,48 @@ def test_is_temporal(): def test_set_is_an_alias_of_array(): assert dt.Set is dt.Array + + +def test_type_coercion(): + p = Pattern.from_typehint(dt.DataType) + assert p.match(dt.int8, {}) == dt.int8 + assert p.match('int8', {}) == dt.int8 + assert p.match(dt.string, {}) == dt.string + assert p.match('string', {}) == dt.string + assert p.match(3, {}) is NoMatch + + p = Pattern.from_typehint(dt.Primitive) + assert p.match(dt.int8, {}) == dt.int8 + assert p.match('int8', {}) == dt.int8 + assert p.match(dt.boolean, {}) == dt.boolean + assert p.match('boolean', {}) == dt.boolean + assert p.match(dt.Array(dt.int8), {}) is NoMatch + assert p.match('array', {}) is NoMatch + + p = Pattern.from_typehint(dt.Integer) + assert p.match(dt.int8, {}) == dt.int8 + assert p.match('int8', {}) == dt.int8 + assert p.match(dt.uint8, {}) == dt.uint8 + assert p.match('uint8', {}) == dt.uint8 + assert p.match(dt.boolean, {}) is NoMatch + assert p.match('boolean', {}) is NoMatch + + p = Pattern.from_typehint(dt.Array[dt.Integer]) + assert p.match(dt.Array(dt.int8), {}) == dt.Array(dt.int8) + assert p.match('array', {}) == dt.Array(dt.int8) + assert p.match(dt.Array(dt.uint8), {}) == dt.Array(dt.uint8) + assert p.match('array', {}) == dt.Array(dt.uint8) + assert p.match(dt.Array(dt.boolean), {}) is NoMatch + assert p.match('array', {}) is NoMatch + + p = Pattern.from_typehint(dt.Map[dt.String, dt.Integer]) + assert p.match(dt.Map(dt.string, dt.int8), {}) == dt.Map(dt.string, dt.int8) + assert p.match('map', {}) == dt.Map(dt.string, dt.int8) + assert p.match(dt.Map(dt.string, dt.uint8), {}) == dt.Map(dt.string, dt.uint8) + assert p.match('map', {}) == dt.Map(dt.string, dt.uint8) + assert p.match(dt.Map(dt.string, dt.boolean), {}) is NoMatch + assert p.match('map', {}) is NoMatch + + p = Pattern.from_typehint(Annotated[dt.Interval, Attrs(unit=As(TimeUnit))]) + assert p.match(dt.Interval('s'), {}) == dt.Interval('s') + assert p.match(dt.Interval('ns'), {}) == dt.Interval('ns') diff --git a/ibis/expr/datatypes/tests/test_parse.py b/ibis/expr/datatypes/tests/test_parse.py index 5a2de709a5b1..9eb7d922965a 100644 --- a/ibis/expr/datatypes/tests/test_parse.py +++ b/ibis/expr/datatypes/tests/test_parse.py @@ -4,6 +4,7 @@ import pytest import ibis.expr.datatypes as dt +from ibis.common.patterns import ValidationError @pytest.mark.parametrize( @@ -227,7 +228,7 @@ def test_parse_interval(unit): @pytest.mark.parametrize('unit', ['X', 'unsupported']) def test_parse_interval_with_invalid_unit(unit): definition = f"interval('{unit}')" - with pytest.raises(ValueError): + with pytest.raises(ValidationError): dt.dtype(definition) diff --git a/ibis/expr/datatypes/tests/test_value.py b/ibis/expr/datatypes/tests/test_value.py index f55e00d06974..4b5185678870 100644 --- a/ibis/expr/datatypes/tests/test_value.py +++ b/ibis/expr/datatypes/tests/test_value.py @@ -2,6 +2,7 @@ import decimal import enum +import json from collections import OrderedDict from datetime import date, datetime, timedelta @@ -330,3 +331,14 @@ def test_infer_numpy_array(numpy_array, expected_dtypes): pandas_series = pd.Series(numpy_array) assert dt.infer(numpy_array) in expected_dtypes assert dt.infer(pandas_series) in expected_dtypes + + +def test_normalize_json(): + obj = ['foo', {'bar': ('baz', None, 1.0, 2)}] + expected = json.dumps(obj) + + assert dt.normalize(dt.json, obj) == expected + assert dt.normalize(dt.json, expected) == expected + + with pytest.raises(TypeError): + dt.normalize(dt.json, "invalid") diff --git a/ibis/expr/datatypes/value.py b/ibis/expr/datatypes/value.py index d6865d2ec2ef..dfa8cc805fbc 100644 --- a/ibis/expr/datatypes/value.py +++ b/ibis/expr/datatypes/value.py @@ -5,6 +5,7 @@ import decimal import enum import ipaddress +import json import uuid from typing import Any, Mapping, NamedTuple, Sequence @@ -27,7 +28,9 @@ @lazy_singledispatch def infer(value: Any) -> dt.DataType: """Infer the corresponding ibis dtype for a python object.""" - raise InputTypeError(f"Unable to infer datatype of {value!r}") + raise InputTypeError( + f"Unable to infer datatype of value {value!r} with type {type(value)}" + ) # TODO(kszucs): support NamedTuples and dataclasses instead of OrderedDict @@ -248,7 +251,10 @@ def normalize(typ, value): return None if dtype.is_boolean(): - return bool(value) + try: + return bool(value) + except ValueError: + raise TypeError("Unable to normalize {value!r} to {dtype!r}") elif dtype.is_integer(): try: value = int(value) @@ -266,7 +272,19 @@ def normalize(typ, value): return float(value) except ValueError: raise TypeError("Unable to normalize {value!r} to {dtype!r}") - elif dtype.is_string() and not dtype.is_json(): + elif dtype.is_json(): + if isinstance(value, str): + try: + json.loads(value) + except json.JSONDecodeError: + raise TypeError(f"Invalid JSON string: {value!r}") + else: + return value + else: + return json.dumps(value) + elif dtype.is_binary(): + return bytes(value) + elif dtype.is_string(): return str(value) elif dtype.is_decimal(): out = decimal.Decimal(value) @@ -300,6 +318,8 @@ def normalize(typ, value): return _WellKnownText(value.wkt) elif dtype.is_date(): return normalize_datetime(value).date() + elif dtype.is_time(): + return normalize_datetime(value).time() elif dtype.is_timestamp(): value = normalize_datetime(value) tzinfo = normalize_timezone(dtype.timezone) @@ -312,7 +332,7 @@ def normalize(typ, value): elif dtype.is_interval(): return normalize_timedelta(value, dtype.unit) else: - return value + raise TypeError(f"Unable to normalize {value!r} to {dtype!r}") public(infer=infer, normalize=normalize) diff --git a/ibis/expr/operations/analytic.py b/ibis/expr/operations/analytic.py index 1839dfe25917..fad1f9c29ea2 100644 --- a/ibis/expr/operations/analytic.py +++ b/ibis/expr/operations/analytic.py @@ -1,16 +1,19 @@ from __future__ import annotations +from typing import Optional + from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute -from ibis.expr.operations.core import Value +from ibis.expr.operations.core import Column, Scalar, Value @public class Analytic(Value): - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar @property def __window_op__(self): @@ -19,10 +22,9 @@ def __window_op__(self): @public class ShiftBase(Analytic): - arg = rlz.column(rlz.any) - - offset = rlz.optional(rlz.one_of((rlz.integer, rlz.interval))) - default = rlz.optional(rlz.any) + arg: Column[dt.Any] + offset: Optional[Value[dt.Integer | dt.Interval]] = None + default: Optional[Value] = None output_dtype = rlz.dtype_like("arg") @@ -44,12 +46,12 @@ class RankBase(Analytic): @public class MinRank(RankBase): - arg = rlz.column(rlz.any) + arg: Column[dt.Any] @public class DenseRank(RankBase): - arg = rlz.column(rlz.any) + arg: Column[dt.Any] @public @@ -74,18 +76,18 @@ class RowNumber(RankBase): @public -class CumulativeOp(Analytic): +class Cumulative(Analytic): pass @public -class CumulativeSum(CumulativeOp): +class CumulativeSum(Cumulative): """Cumulative sum. Requires an ordering window. """ - arg = rlz.column(rlz.numeric) + arg: Column[dt.Numeric] @attribute.default def output_dtype(self): @@ -93,13 +95,13 @@ def output_dtype(self): @public -class CumulativeMean(CumulativeOp): +class CumulativeMean(Cumulative): """Cumulative mean. Requires an order window. """ - arg = rlz.column(rlz.numeric) + arg: Column[dt.Numeric] @attribute.default def output_dtype(self): @@ -107,50 +109,56 @@ def output_dtype(self): @public -class CumulativeMax(CumulativeOp): - arg = rlz.column(rlz.any) +class CumulativeMax(Cumulative): + arg: Column[dt.Any] + output_dtype = rlz.dtype_like("arg") @public -class CumulativeMin(CumulativeOp): +class CumulativeMin(Cumulative): """Cumulative min. Requires an order window. """ - arg = rlz.column(rlz.any) + arg: Column[dt.Any] + output_dtype = rlz.dtype_like("arg") @public -class CumulativeAny(CumulativeOp): - arg = rlz.column(rlz.boolean) +class CumulativeAny(Cumulative): + arg: Column[dt.Boolean] output_dtype = rlz.dtype_like("arg") @public -class CumulativeAll(CumulativeOp): - arg = rlz.column(rlz.boolean) +class CumulativeAll(Cumulative): + arg: Column[dt.Boolean] + output_dtype = rlz.dtype_like("arg") @public class PercentRank(Analytic): - arg = rlz.column(rlz.any) + arg: Column[dt.Any] + output_dtype = dt.double @public class CumeDist(Analytic): - arg = rlz.column(rlz.any) + arg: Column[dt.Any] + output_dtype = dt.double @public class NTile(Analytic): - arg = rlz.column(rlz.any) - buckets = rlz.scalar(rlz.integer) + arg: Column[dt.Any] + buckets: Scalar[dt.Integer] + output_dtype = dt.int64 @@ -158,7 +166,8 @@ class NTile(Analytic): class FirstValue(Analytic): """Retrieve the first element.""" - arg = rlz.column(rlz.any) + arg: Column[dt.Any] + output_dtype = rlz.dtype_like("arg") @@ -166,7 +175,8 @@ class FirstValue(Analytic): class LastValue(Analytic): """Retrieve the last element.""" - arg = rlz.column(rlz.any) + arg: Column[dt.Any] + output_dtype = rlz.dtype_like("arg") @@ -174,9 +184,10 @@ class LastValue(Analytic): class NthValue(Analytic): """Retrieve the Nth element.""" - arg = rlz.column(rlz.any) - nth = rlz.integer + arg: Column[dt.Any] + nth: Value[dt.Integer] + output_dtype = rlz.dtype_like("arg") -public(AnalyticOp=Analytic) +public(AnalyticOp=Analytic, CumulativeOp=Cumulative) diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 628fb32a7cd2..3ad30d06a921 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -1,18 +1,22 @@ from __future__ import annotations +from typing import Callable, Optional + from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute +from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Argument, Unary, Value @public class ArrayColumn(Value): - cols = rlz.tuple_of(rlz.any, min_length=1) + cols: VarTuple[Value] - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar @attribute.default def output_dtype(self): @@ -21,7 +25,7 @@ def output_dtype(self): @public class ArrayLength(Unary): - arg = rlz.array + arg: Value[dt.Array] output_dtype = dt.int64 output_shape = rlz.shape_like("args") @@ -29,9 +33,9 @@ class ArrayLength(Unary): @public class ArraySlice(Value): - arg = rlz.array - start = rlz.integer - stop = rlz.optional(rlz.integer) + arg: Value[dt.Array] + start: Value[dt.Integer] + stop: Optional[Value[dt.Integer]] = None output_dtype = rlz.dtype_like("arg") output_shape = rlz.shape_like("arg") @@ -39,8 +43,8 @@ class ArraySlice(Value): @public class ArrayIndex(Value): - arg = rlz.array - index = rlz.integer + arg: Value[dt.Array] + index: Value[dt.Integer] output_shape = rlz.shape_like("args") @@ -51,7 +55,7 @@ def output_dtype(self): @public class ArrayConcat(Value): - arg = rlz.tuple_of(rlz.array, min_length=2) + arg: VarTuple[Value[dt.Array]] @attribute.default def output_dtype(self): @@ -66,15 +70,15 @@ def output_shape(self): @public class ArrayRepeat(Value): - arg = rlz.array - times = rlz.integer + arg: Value[dt.Array] + times: Value[dt.Integer] output_dtype = rlz.dtype_like("arg") output_shape = rlz.shape_like("args") class ArrayApply(Value): - arg = rlz.array + arg: Value[dt.Array] @attribute.default def parameter(self): @@ -97,7 +101,7 @@ def output_shape(self): @public class ArrayMap(ArrayApply): - func = rlz.callable_with([rlz.expr_of(rlz.any)], rlz.any) + func: Callable[[Value], Value] @attribute.default def output_dtype(self) -> dt.DataType: @@ -106,26 +110,26 @@ def output_dtype(self) -> dt.DataType: @public class ArrayFilter(ArrayApply): - func = rlz.callable_with([rlz.expr_of(rlz.any)], rlz.boolean) + func: Callable[[Value], Value[dt.Boolean]] output_dtype = rlz.dtype_like("arg") @public class Unnest(Value): - arg = rlz.array + arg: Value[dt.Array] + + output_shape = ds.columnar @attribute.default def output_dtype(self): return self.arg.output_dtype.value_type - output_shape = rlz.Shape.COLUMNAR - @public class ArrayContains(Value): - arg = rlz.array - other = rlz.any + arg: Value[dt.Array] + other: Value output_dtype = dt.boolean output_shape = rlz.shape_like("args") @@ -133,8 +137,8 @@ class ArrayContains(Value): @public class ArrayPosition(Value): - arg = rlz.array - other = rlz.any + arg: Value[dt.Array] + other: Value output_dtype = dt.int64 output_shape = rlz.shape_like("args") @@ -142,8 +146,8 @@ class ArrayPosition(Value): @public class ArrayRemove(Value): - arg = rlz.array - other = rlz.any + arg: Value[dt.Array] + other: Value output_dtype = rlz.dtype_like("arg") output_shape = rlz.shape_like("args") @@ -151,14 +155,15 @@ class ArrayRemove(Value): @public class ArrayDistinct(Value): - arg = rlz.array + arg: Value[dt.Array] + output_dtype = rlz.dtype_like("arg") output_shape = rlz.shape_like("arg") @public class ArraySort(Value): - arg = rlz.array + arg: Value[dt.Array] output_dtype = rlz.dtype_like("arg") output_shape = rlz.shape_like("arg") @@ -166,8 +171,8 @@ class ArraySort(Value): @public class ArrayUnion(Value): - left = rlz.array - right = rlz.array + left: Value[dt.Array] + right: Value[dt.Array] output_dtype = rlz.dtype_like("args") output_shape = rlz.shape_like("args") @@ -175,8 +180,8 @@ class ArrayUnion(Value): @public class ArrayIntersect(Value): - left = rlz.array - right = rlz.array + left: Value[dt.Array] + right: Value[dt.Array] output_dtype = rlz.dtype_like("args") output_shape = rlz.shape_like("args") @@ -184,7 +189,7 @@ class ArrayIntersect(Value): @public class ArrayZip(Value): - arg = rlz.tuple_of(rlz.array, min_length=2) + arg: VarTuple[Value[dt.Array]] output_shape = rlz.shape_like("arg") diff --git a/ibis/expr/operations/core.py b/ibis/expr/operations/core.py index 94e258ef10b8..010e1c9539c6 100644 --- a/ibis/expr/operations/core.py +++ b/ibis/expr/operations/core.py @@ -1,22 +1,24 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import Generic, Optional from public import public +from typing_extensions import Any, Self, TypeVar +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis import util from ibis.common.graph import Node as Traversable from ibis.common.grounds import Concrete - -if TYPE_CHECKING: - import ibis.expr.datatypes as dt +from ibis.common.patterns import Coercible, CoercionError +from ibis.common.typing import DefaultTypeVars, VarTuple @public class Node(Concrete, Traversable): - def equals(self, other): + def equals(self, other) -> bool: if not isinstance(other, Node): raise TypeError( f"invalid equality comparison between Node and {type(other)}" @@ -24,7 +26,7 @@ def equals(self, other): return self.__cached_equals__(other) @util.deprecated(as_of='4.0', instead='remove intermediate .op() calls') - def op(self): + def op(self) -> Self: """Make `Node` backwards compatible with code that uses `Expr.op()`.""" return self @@ -38,11 +40,11 @@ def to_expr(self): @public class Named(ABC): - __slots__ = tuple() + __slots__: VarTuple[str] = tuple() @property @abstractmethod - def name(self): + def name(self) -> str: """Name of the operation. Returns @@ -51,18 +53,45 @@ def name(self): """ +T = TypeVar("T", bound=dt.DataType, covariant=True) +S = TypeVar("S", bound=ds.DataShape, default=ds.Any, covariant=True) + + @public -class Value(Node, Named): +class Value(Node, Named, Coercible, DefaultTypeVars, Generic[T, S]): + @classmethod + def __coerce__( + cls, value: Any, T: Optional[type] = None, S: Optional[type] = None + ) -> Self: + # note that S=Shape is unused here since the pattern will check the + # shape of the value expression after executing Value.__coerce__() + from ibis.expr.operations import Literal + from ibis.expr.types import Expr + + if isinstance(value, Expr): + value = value.op() + if isinstance(value, Value): + return value + + try: + try: + dtype = dt.dtype(T) + except TypeError: + dtype = dt.infer(value) + return Literal(value, dtype=dtype) + except TypeError: + raise CoercionError(f"Unable to coerce {value!r} to Value[{T!r}]") + # TODO(kszucs): cover it with tests # TODO(kszucs): figure out how to represent not named arguments @property - def name(self): + def name(self) -> str: args = ", ".join(arg.name for arg in self.__args__ if isinstance(arg, Named)) return f"{self.__class__.__name__}({args})" @property @abstractmethod - def output_dtype(self) -> dt.DataType: + def output_dtype(self) -> T: """Ibis datatype of the produced value expression. Returns @@ -72,7 +101,7 @@ def output_dtype(self) -> dt.DataType: @property @abstractmethod - def output_shape(self) -> rlz.Shape: + def output_shape(self) -> S: """Shape of the produced value expression. Possible values are: "scalar" and "columnar" @@ -93,10 +122,15 @@ def to_expr(self): return getattr(ir, typename)(self) +# convenience aliases +Scalar = Value[T, ds.Scalar] +Column = Value[T, ds.Columnar] + + @public class Alias(Value): - arg = rlz.any - name = rlz.instance_of(str) + arg: Value + name: str output_shape = rlz.shape_like("arg") output_dtype = rlz.dtype_like("arg") @@ -106,10 +140,10 @@ class Alias(Value): class Unary(Value): """A unary operation.""" - arg = rlz.any + arg: Value @property - def output_shape(self): + def output_shape(self) -> ds.DataShape: return self.arg.output_shape @@ -117,26 +151,26 @@ def output_shape(self): class Binary(Value): """A binary operation.""" - left = rlz.any - right = rlz.any + left: Value + right: Value @property - def output_shape(self): + def output_shape(self) -> ds.DataShape: return max(self.left.output_shape, self.right.output_shape) @public class Argument(Value): - name = rlz.instance_of(str) - shape = rlz.instance_of(rlz.Shape) - dtype = rlz.datatype + name: str + shape: ds.DataShape + dtype: dt.DataType @property def output_dtype(self) -> dt.DataType: return self.dtype @property - def output_shape(self) -> rlz.Shape: + def output_shape(self) -> ds.DataShape: return self.shape diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index 6ebfd0f35801..0a662f425568 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -1,33 +1,32 @@ from __future__ import annotations import abc -import datetime -import decimal -import enum -import ipaddress import itertools -import uuid -from operator import attrgetter +from typing import Any, Optional, Union +from typing import Literal as LiteralType from public import public +from typing_extensions import TypeVar +import ibis.common.exceptions as com +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -from ibis.common import exceptions as com from ibis.common.annotations import attribute -from ibis.common.collections import frozendict from ibis.common.grounds import Singleton -from ibis.expr.operations.core import Named, Unary, Value +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.expr.operations.core import Named, Scalar, Unary, Value +from ibis.expr.operations.relations import Relation # noqa: TCH001 @public class TableColumn(Value, Named): """Selects a column from a `Table`.""" - table = rlz.table - name = rlz.instance_of((str, int)) + table: Relation + name: Union[str, int] - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar def __init__(self, table, name): if isinstance(name, int): @@ -52,8 +51,9 @@ class RowID(Value, Named): """The row number (an autonumeric) of the returned result.""" name = "rowid" - table = rlz.table - output_shape = rlz.Shape.COLUMNAR + table: Relation + + output_shape = ds.columnar output_dtype = dt.int64 @@ -61,9 +61,9 @@ class RowID(Value, Named): class TableArrayView(Value, Named): """Helper operation class for creating scalar subqueries.""" - table = rlz.table + table: Relation - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar @property def output_dtype(self): @@ -78,26 +78,32 @@ def name(self): class Cast(Value): """Explicitly cast value to a specific data type.""" - arg = rlz.any - to = rlz.datatype + arg: Value + to: dt.DataType output_shape = rlz.shape_like("arg") - output_dtype = property(attrgetter("to")) @property def name(self): return f"{self.__class__.__name__}({self.arg.name}, {self.to})" + @property + def output_dtype(self): + return self.to + @public class TryCast(Value): """Explicitly try cast value to a specific data type.""" - arg = rlz.any - to = rlz.datatype + arg: Value + to: dt.DataType output_shape = rlz.shape_like("arg") - output_dtype = property(attrgetter("to")) + + @property + def output_dtype(self): + return self.to @public @@ -140,8 +146,9 @@ class ZeroIfNull(Unary): class IfNull(Value): """Set values to ifnull_expr if they are equal to NULL.""" - arg = rlz.any - ifnull_expr = rlz.any + arg: Value + ifnull_expr: Value + output_dtype = rlz.dtype_like("args") output_shape = rlz.shape_like("args") @@ -150,102 +157,87 @@ class IfNull(Value): class NullIf(Value): """Set values to NULL if they equal the null_if_expr.""" - arg = rlz.any - null_if_expr = rlz.any + arg: Value + null_if_expr: Value + output_dtype = rlz.dtype_like("args") output_shape = rlz.shape_like("args") @public class Coalesce(Value): - arg = rlz.tuple_of(rlz.any) + arg: VarTuple[Value] + output_shape = rlz.shape_like('arg') output_dtype = rlz.dtype_like('arg') @public class Greatest(Value): - arg = rlz.tuple_of(rlz.any) + arg: VarTuple[Value] + output_shape = rlz.shape_like('arg') output_dtype = rlz.dtype_like('arg') @public class Least(Value): - arg = rlz.tuple_of(rlz.any) + arg: VarTuple[Value] + output_shape = rlz.shape_like('arg') output_dtype = rlz.dtype_like('arg') +T = TypeVar("T", bound=dt.DataType, covariant=True) + + @public -class Literal(Value): - value = rlz.one_of( - ( - rlz.instance_of( - ( - bytes, - datetime.date, - datetime.datetime, - datetime.time, - datetime.timedelta, - enum.Enum, - float, - frozenset, - int, - ipaddress.IPv4Address, - ipaddress.IPv6Address, - frozendict, - str, - tuple, - type(None), - uuid.UUID, - decimal.Decimal, - ) - ), - rlz.lazy_instance_of( - ( - "shapely.geometry.BaseGeometry", - "numpy.generic", - "numpy.ndarray", - ) - ), - ) - ) - dtype = rlz.datatype - - # TODO(kszucs): it should be named actually - - output_shape = rlz.Shape.SCALAR - output_dtype = property(attrgetter("dtype")) +class Literal(Scalar): + value: Any + dtype: dt.DataType + + output_shape = ds.scalar + + def __init__(self, value, dtype): + # normalize ensures that the value is a valid value for the given dtype + value = dt.normalize(dtype, value) + super().__init__(value=value, dtype=dtype) @property def name(self): return repr(self.value) + @property + def output_dtype(self) -> T: + return self.dtype + @public -class ScalarParameter(Value, Named): +class ScalarParameter(Scalar, Named): _counter = itertools.count() - dtype = rlz.datatype - counter = rlz.optional( - rlz.instance_of(int), default=lambda: next(ScalarParameter._counter) - ) + dtype: dt.DataType + counter: Optional[int] = None + + output_shape = ds.scalar - output_shape = rlz.Shape.SCALAR - output_dtype = property(attrgetter("dtype")) + def __init__(self, dtype, counter): + if counter is None: + counter = next(self._counter) + super().__init__(dtype=dtype, counter=counter) @property def name(self): return f'param_{self.counter:d}' - def __hash__(self): - return hash((self.dtype, self.counter)) + @property + def output_dtype(self): + return self.dtype @public -class Constant(Value, Singleton): - output_shape = rlz.Shape.SCALAR +class Constant(Scalar, Singleton): + output_shape = ds.scalar @public @@ -270,7 +262,7 @@ class Pi(Constant): @public class Hash(Value): - arg = rlz.any + arg: Value output_dtype = dt.int64 output_shape = rlz.shape_like("arg") @@ -278,25 +270,22 @@ class Hash(Value): @public class HashBytes(Value): - arg = rlz.one_of({rlz.value(dt.string), rlz.value(dt.binary)}) - # TODO: these don't necessarily all belong here - how = rlz.isin( - { - "md5", - "MD5", - "sha1", - "SHA1", - "SHA224", - "sha256", - "SHA256", - "sha512", - "intHash32", - "intHash64", - "cityHash64", - "sipHash64", - "sipHash128", - } - ) + arg: Value[dt.String | dt.Binary] + how: LiteralType[ + "md5", + "MD5", + "sha1", + "SHA1", + "SHA224", + "sha256", + "SHA256", + "sha512", + "intHash32", + "intHash64", + "cityHash64", + "sipHash64", + "sipHash128", + ] output_dtype = dt.binary output_shape = rlz.shape_like("arg") @@ -307,10 +296,10 @@ class HashBytes(Value): # api.py @public class SimpleCase(Value): - base = rlz.any - cases = rlz.tuple_of(rlz.any) - results = rlz.tuple_of(rlz.any) - default = rlz.any + base: Value + cases: VarTuple[Value] + results: VarTuple[Value] + default: Value output_shape = rlz.shape_like("base") @@ -326,9 +315,9 @@ def output_dtype(self): @public class SearchedCase(Value): - cases = rlz.tuple_of(rlz.boolean) - results = rlz.tuple_of(rlz.any) - default = rlz.any + cases: VarTuple[Value[dt.Boolean]] + results: VarTuple[Value] + default: Value def __init__(self, cases, results, default): assert len(cases) == len(results) diff --git a/ibis/expr/operations/geospatial.py b/ibis/expr/operations/geospatial.py index 1cc4cb824005..30c3f0e0969f 100644 --- a/ibis/expr/operations/geospatial.py +++ b/ibis/expr/operations/geospatial.py @@ -2,9 +2,8 @@ from public import public -from ibis.expr import datatypes as dt -from ibis.expr import rules as rlz -from ibis.expr.operations.core import Binary, Unary +import ibis.expr.datatypes as dt +from ibis.expr.operations.core import Binary, Unary, Value from ibis.expr.operations.reductions import Reduction @@ -12,15 +11,15 @@ class GeoSpatialBinOp(Binary): """Geo Spatial base binary.""" - left = rlz.geospatial - right = rlz.geospatial + left: Value[dt.GeoSpatial] + right: Value[dt.GeoSpatial] @public class GeoSpatialUnOp(Unary): """Geo Spatial base unary.""" - arg = rlz.geospatial + arg: Value[dt.GeoSpatial] @public @@ -83,7 +82,7 @@ class GeoEquals(GeoSpatialBinOp): class GeoGeometryN(GeoSpatialUnOp): """Returns the Nth Geometry of a Multi geometry.""" - n = rlz.integer + n: Value[dt.Integer] output_dtype = dt.geometry @@ -120,8 +119,8 @@ class GeoLineLocatePoint(GeoSpatialBinOp): fraction of the total 2d line length. """ - left = rlz.linestring - right = rlz.point + left: Value[dt.LineString] + right: Value[dt.Point] output_dtype = dt.halffloat @@ -149,9 +148,9 @@ class GeoLineSubstring(GeoSpatialUnOp): This only works with linestrings. """ - arg = rlz.linestring - start = rlz.floating - end = rlz.floating + arg: Value[dt.LineString] + start: Value[dt.Floating] + end: Value[dt.Floating] output_dtype = dt.linestring @@ -303,8 +302,8 @@ class GeoPoint(GeoSpatialBinOp): Constant coordinates result in construction of a POINT literal. """ - left = rlz.numeric - right = rlz.numeric + left: Value[dt.Numeric] + right: Value[dt.Numeric] output_dtype = dt.point @@ -318,7 +317,7 @@ class GeoPointN(GeoSpatialUnOp): no linestring in the geometry """ - n = rlz.integer + n: Value[dt.Integer] output_dtype = dt.point @@ -350,7 +349,7 @@ class GeoSRID(GeoSpatialUnOp): class GeoSetSRID(GeoSpatialUnOp): """Set the spatial reference identifier for the ST_Geometry.""" - srid = rlz.integer + srid: Value[dt.Integer] output_dtype = dt.geometry @@ -362,7 +361,7 @@ class GeoBuffer(GeoSpatialUnOp): Calculations are in the Spatial Reference System of this geometry. """ - radius = rlz.floating + radius: Value[dt.Floating] output_dtype = dt.geometry @@ -377,7 +376,7 @@ class GeoCentroid(GeoSpatialUnOp): class GeoDFullyWithin(GeoSpatialBinOp): """Check if the geometries are fully within `distance` of one another.""" - distance = rlz.floating + distance: Value[dt.Floating] output_dtype = dt.boolean @@ -386,7 +385,7 @@ class GeoDFullyWithin(GeoSpatialBinOp): class GeoDWithin(GeoSpatialBinOp): """Check if the geometries are within `distance` of one another.""" - distance = rlz.floating + distance: Value[dt.Floating] output_dtype = dt.boolean @@ -406,8 +405,8 @@ class GeoAzimuth(GeoSpatialBinOp): 3=PI/2; 6=PI; 9=3PI/2. """ - left = rlz.point - right = rlz.point + left: Value[dt.Point] + right: Value[dt.Point] output_dtype = dt.float64 @@ -437,8 +436,8 @@ class GeoDifference(GeoSpatialBinOp): class GeoSimplify(GeoSpatialUnOp): """Returns a simplified version of the given geometry.""" - tolerance = rlz.floating - preserve_collapsed = rlz.boolean + tolerance: Value[dt.Floating] + preserve_collapsed: Value[dt.Boolean] output_dtype = dt.geometry @@ -447,7 +446,7 @@ class GeoSimplify(GeoSpatialUnOp): class GeoTransform(GeoSpatialUnOp): """Returns a transformed version of the given geometry into a new SRID.""" - srid = rlz.integer + srid: Value[dt.Integer] output_dtype = dt.geometry diff --git a/ibis/expr/operations/histograms.py b/ibis/expr/operations/histograms.py index ae42d3b8d728..10c2c9da7a07 100644 --- a/ibis/expr/operations/histograms.py +++ b/ibis/expr/operations/histograms.py @@ -1,24 +1,28 @@ from __future__ import annotations -import numbers +import numbers # noqa: TCH003 +from typing import Literal from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt from ibis.common.annotations import attribute -from ibis.expr import rules as rlz -from ibis.expr.operations.core import Value +from ibis.common.patterns import ValidationError +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.expr.operations.core import Column, Value @public class Bucket(Value): - arg = rlz.column(rlz.numeric) - buckets = rlz.tuple_of(rlz.instance_of(numbers.Real)) - closed = rlz.optional(rlz.isin({'left', 'right'}), default='left') - close_extreme = rlz.optional(rlz.instance_of(bool), default=True) - include_under = rlz.optional(rlz.instance_of(bool), default=False) - include_over = rlz.optional(rlz.instance_of(bool), default=False) - output_shape = rlz.Shape.COLUMNAR + arg: Column[dt.Numeric | dt.Boolean] + buckets: VarTuple[numbers.Real] + closed: Literal['left', 'right'] = 'left' + close_extreme: bool = True + include_under: bool = False + include_over: bool = False + + output_shape = ds.columnar @attribute.default def output_dtype(self): @@ -26,10 +30,10 @@ def output_dtype(self): def __init__(self, buckets, include_under, include_over, **kwargs): if not buckets: - raise ValueError('Must be at least one bucket edge') + raise ValidationError('Must be at least one bucket edge') elif len(buckets) == 1: if not include_under or not include_over: - raise ValueError( + raise ValidationError( 'If one bucket edge provided, must have ' 'include_under=True and include_over=True' ) diff --git a/ibis/expr/operations/json.py b/ibis/expr/operations/json.py index daaa8083bcc7..c2d33eabcc1a 100644 --- a/ibis/expr/operations/json.py +++ b/ibis/expr/operations/json.py @@ -9,8 +9,8 @@ @public class JSONGetItem(Value): - arg = rlz.json - index = rlz.one_of((rlz.string, rlz.integer)) + arg: Value[dt.JSON] + index: Value[dt.String | dt.Integer] output_dtype = dt.json output_shape = rlz.shape_like("args") @@ -18,7 +18,7 @@ class JSONGetItem(Value): @public class ToJSONArray(Value): - arg = rlz.json + arg: Value[dt.JSON] output_dtype = dt.Array(dt.json) output_shape = rlz.shape_like("arg") @@ -26,7 +26,7 @@ class ToJSONArray(Value): @public class ToJSONMap(Value): - arg = rlz.json + arg: Value[dt.JSON] output_dtype = dt.Map(dt.string, dt.json) output_shape = rlz.shape_like("arg") diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index b8c144b8a81a..65e12d166228 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -1,31 +1,33 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import ibis.expr.types as ir +from typing import Union from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute -from ibis.expr.operations.core import Binary, Unary, Value +from ibis.common.exceptions import IbisTypeError +from ibis.common.patterns import ValidationError +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.expr.operations.core import Binary, Column, Unary, Value from ibis.expr.operations.generic import _Negatable +from ibis.expr.operations.relations import Relation # noqa: TCH001 @public class LogicalBinary(Binary): - left = rlz.boolean - right = rlz.boolean + left: Value[dt.Boolean] + right: Value[dt.Boolean] output_dtype = dt.boolean @public class Not(Unary): - arg = rlz.boolean + arg: Value[dt.Boolean] output_dtype = dt.boolean @@ -47,8 +49,8 @@ class Xor(LogicalBinary): @public class Comparison(Binary): - left = rlz.any - right = rlz.any + left: Value + right: Value output_dtype = dt.boolean @@ -62,7 +64,7 @@ def __init__(self, left, right): Ibis to help the user avoid them? """ if not rlz.comparable(left, right): - raise TypeError( + raise IbisTypeError( f'Arguments {rlz._arg_type_error_format(left)} and ' f'{rlz._arg_type_error_format(right)} are not comparable' ) @@ -106,21 +108,21 @@ class IdenticalTo(Comparison): @public class Between(Value): - arg = rlz.any - lower_bound = rlz.any - upper_bound = rlz.any + arg: Value + lower_bound: Value + upper_bound: Value output_dtype = dt.boolean output_shape = rlz.shape_like("args") def __init__(self, arg, lower_bound, upper_bound): if not rlz.comparable(arg, lower_bound): - raise TypeError( + raise ValidationError( f'Arguments {rlz._arg_type_error_format(arg)} and ' f'{rlz._arg_type_error_format(lower_bound)} are not comparable' ) if not rlz.comparable(arg, upper_bound): - raise TypeError( + raise ValidationError( f'Arguments {rlz._arg_type_error_format(arg)} and ' f'{rlz._arg_type_error_format(upper_bound)} are not comparable' ) @@ -130,14 +132,12 @@ def __init__(self, arg, lower_bound, upper_bound): # TODO(kszucs): decompose it into at least two operations @public class Contains(Value): - value = rlz.any - options = rlz.one_of( - [ - rlz.tuple_of(rlz.any), - rlz.column(rlz.any), - rlz.array, - ] - ) + value: Value + options: Union[ + VarTuple[Value], + Column[dt.Any], + Value[dt.Array], + ] output_dtype = dt.boolean @@ -164,9 +164,9 @@ class Where(Value): Many backends implement this as a built-in function. """ - bool_expr = rlz.boolean - true_expr = rlz.any - false_null_expr = rlz.any + bool_expr: Value[dt.Boolean] + true_expr: Value + false_null_expr: Value output_shape = rlz.shape_like("args") @@ -177,11 +177,11 @@ def output_dtype(self): @public class ExistsSubquery(Value, _Negatable): - foreign_table = rlz.table - predicates = rlz.tuple_of(rlz.boolean) + foreign_table: Relation + predicates: VarTuple[Value[dt.Boolean]] output_dtype = dt.boolean - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar def negate(self) -> NotExistsSubquery: return NotExistsSubquery(*self.args) @@ -189,11 +189,11 @@ def negate(self) -> NotExistsSubquery: @public class NotExistsSubquery(Value, _Negatable): - foreign_table = rlz.table - predicates = rlz.tuple_of(rlz.boolean) + foreign_table: Relation + predicates: VarTuple[Value[dt.Boolean]] output_dtype = dt.boolean - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar def negate(self) -> ExistsSubquery: return ExistsSubquery(*self.args) @@ -237,15 +237,15 @@ class _UnresolvedSubquery(Value, _Negatable): resolved against the outer leaf table when `Selection`s are constructed. """ - tables = rlz.tuple_of(rlz.table) - predicates = rlz.tuple_of(rlz.boolean) + tables: VarTuple[Relation] + predicates: VarTuple[Value[dt.Boolean]] output_dtype = dt.boolean - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar @abc.abstractmethod def _resolve( - self, table: ir.Table + self, table ) -> type[ExistsSubquery] | type[NotExistsSubquery]: # pragma: no cover ... @@ -269,7 +269,7 @@ class UnresolvedNotExistsSubquery(_UnresolvedSubquery): def negate(self) -> UnresolvedExistsSubquery: return UnresolvedExistsSubquery(*self.args) - def _resolve(self, table: ir.Table) -> NotExistsSubquery: + def _resolve(self, table) -> NotExistsSubquery: from ibis.expr.operations.relations import TableNode assert isinstance(table, TableNode) diff --git a/ibis/expr/operations/maps.py b/ibis/expr/operations/maps.py index 024259f68166..3699422a4a0c 100644 --- a/ibis/expr/operations/maps.py +++ b/ibis/expr/operations/maps.py @@ -6,13 +6,12 @@ import ibis.expr.rules as rlz from ibis.common.annotations import attribute from ibis.expr.operations.core import Unary, Value -from ibis.expr.types.generic import null @public class Map(Value): - keys = rlz.array - values = rlz.array + keys: Value[dt.Array] + values: Value[dt.Array] output_shape = rlz.shape_like("args") @@ -26,15 +25,15 @@ def output_dtype(self): @public class MapLength(Unary): - arg = rlz.mapping + arg: Value[dt.Map] output_dtype = dt.int64 @public class MapGet(Value): - arg = rlz.mapping - key = rlz.one_of([rlz.string, rlz.integer]) - default = rlz.optional(rlz.any, default=null()) + arg: Value[dt.Map] + key: Value[dt.String | dt.Integer] + default: Value = None output_shape = rlz.shape_like("args") @@ -47,8 +46,8 @@ def output_dtype(self): @public class MapContains(Value): - arg = rlz.mapping - key = rlz.one_of([rlz.string, rlz.integer]) + arg: Value[dt.Map] + key: Value[dt.String | dt.Integer] output_shape = rlz.shape_like("args") output_dtype = dt.bool @@ -56,7 +55,7 @@ class MapContains(Value): @public class MapKeys(Unary): - arg = rlz.mapping + arg: Value[dt.Map] @attribute.default def output_dtype(self): @@ -65,7 +64,7 @@ def output_dtype(self): @public class MapValues(Unary): - arg = rlz.mapping + arg: Value[dt.Map] @attribute.default def output_dtype(self): @@ -74,8 +73,8 @@ def output_dtype(self): @public class MapMerge(Value): - left = rlz.mapping - right = rlz.mapping + left: Value[dt.Map] + right: Value[dt.Map] output_shape = rlz.shape_like("args") output_dtype = rlz.dtype_like("args") diff --git a/ibis/expr/operations/numeric.py b/ibis/expr/operations/numeric.py index 8b9115dd7a43..f0d077f6b8ed 100644 --- a/ibis/expr/operations/numeric.py +++ b/ibis/expr/operations/numeric.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from typing import Optional from public import public @@ -10,11 +11,15 @@ from ibis.common.annotations import attribute from ibis.expr.operations.core import Binary, Unary, Value +Integer = Value[dt.Integer] +SoftNumeric = Value[dt.Numeric | dt.Boolean] +StrictNumeric = Value[dt.Numeric] + @public class NumericBinary(Binary): - left = rlz.numeric - right = rlz.numeric + left: SoftNumeric + right: SoftNumeric @public @@ -60,7 +65,7 @@ class Modulus(NumericBinary): @public class Negate(Unary): - arg = rlz.one_of((rlz.numeric, rlz.interval)) + arg: Value[dt.Numeric | dt.Interval] output_dtype = rlz.dtype_like("arg") @@ -84,19 +89,22 @@ class NullIfZero(Unary): The input if not zero otherwise `NULL`. """ - arg = rlz.numeric + arg: SoftNumeric + output_dtype = rlz.dtype_like("arg") @public class IsNan(Unary): - arg = rlz.floating + arg: Value[dt.Floating] + output_dtype = dt.boolean @public class IsInf(Unary): - arg = rlz.floating + arg: Value[dt.Floating] + output_dtype = dt.boolean @@ -104,7 +112,8 @@ class IsInf(Unary): class Abs(Unary): """Absolute value.""" - arg = rlz.numeric + arg: SoftNumeric + output_dtype = rlz.dtype_like("arg") @@ -119,7 +128,7 @@ class Ceil(Unary): Other numeric values: yield integer (int32) """ - arg = rlz.numeric + arg: SoftNumeric @property def output_dtype(self): @@ -140,7 +149,7 @@ class Floor(Unary): Other numeric values: yield integer (int32) """ - arg = rlz.numeric + arg: SoftNumeric @property def output_dtype(self): @@ -152,9 +161,9 @@ def output_dtype(self): @public class Round(Value): - arg = rlz.numeric + arg: StrictNumeric # TODO(kszucs): the default should be 0 instead of being None - digits = rlz.optional(rlz.numeric) + digits: Optional[Integer] = None output_shape = rlz.shape_like("arg") @@ -170,9 +179,9 @@ def output_dtype(self): @public class Clip(Value): - arg = rlz.strict_numeric - lower = rlz.optional(rlz.strict_numeric) - upper = rlz.optional(rlz.strict_numeric) + arg: StrictNumeric + lower: Optional[StrictNumeric] = None + upper: Optional[StrictNumeric] = None output_dtype = rlz.dtype_like("arg") output_shape = rlz.shape_like("arg") @@ -180,9 +189,10 @@ class Clip(Value): @public class BaseConvert(Value): - arg = rlz.one_of([rlz.integer, rlz.string]) - from_base = rlz.integer - to_base = rlz.integer + # TODO(kszucs): this should be Integer simply + arg: Value[dt.Integer | dt.String] + from_base: Integer + to_base: Integer output_dtype = dt.string output_shape = rlz.shape_like("args") @@ -190,7 +200,7 @@ class BaseConvert(Value): @public class MathUnary(Unary): - arg = rlz.numeric + arg: SoftNumeric @attribute.default def output_dtype(self): @@ -211,7 +221,8 @@ class Exp(ExpandingMathUnary): @public class Sign(Unary): - arg = rlz.numeric + arg: SoftNumeric + output_dtype = rlz.dtype_like("arg") @@ -222,13 +233,12 @@ class Sqrt(MathUnary): @public class Logarithm(MathUnary): - arg = rlz.strict_numeric + arg: StrictNumeric @public class Log(Logarithm): - arg = rlz.strict_numeric - base = rlz.optional(rlz.strict_numeric) + base: Optional[StrictNumeric] = None @public @@ -268,8 +278,9 @@ class TrigonometricUnary(MathUnary): class TrigonometricBinary(Binary): """Trigonometric base binary.""" - left = rlz.numeric - right = rlz.numeric + left: SoftNumeric + right: SoftNumeric + output_dtype = dt.float64 @@ -315,14 +326,15 @@ class Tan(TrigonometricUnary): @public class BitwiseNot(Unary): - arg = rlz.integer + arg: Integer + output_dtype = rlz.numeric_like("args", operator.invert) @public class BitwiseBinary(Binary): - left = rlz.integer - right = rlz.integer + left: Integer + right: Integer @public diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index 15ba30c99b98..3b25a5dc5b7d 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -1,18 +1,22 @@ from __future__ import annotations +from typing import Literal, Optional + from public import public import ibis.common.exceptions as exc +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute -from ibis.expr.operations.core import Value +from ibis.expr.operations.core import Column, Value from ibis.expr.operations.generic import _Negatable +from ibis.expr.operations.relations import Relation # noqa: TCH001 @public class Reduction(Value): - output_shape = rlz.Shape.SCALAR + output_shape = ds.scalar @property def __window_op__(self): @@ -20,34 +24,34 @@ def __window_op__(self): class Filterable(Value): - where = rlz.optional(rlz.boolean) + where: Optional[Value[dt.Boolean]] = None @public class Count(Filterable, Reduction): - arg = rlz.column(rlz.any) + arg: Column[dt.Any] output_dtype = dt.int64 @public class CountStar(Filterable, Reduction): - arg = rlz.table + arg: Relation output_dtype = dt.int64 @public class CountDistinctStar(Filterable, Reduction): - arg = rlz.table + arg: Relation output_dtype = dt.int64 @public class Arbitrary(Filterable, Reduction): - arg = rlz.column(rlz.any) - how = rlz.isin({'first', 'last', 'heavy'}) + arg: Column[dt.Any] + how: Literal["first", "last", "heavy"] output_dtype = rlz.dtype_like('arg') @@ -56,7 +60,8 @@ class Arbitrary(Filterable, Reduction): class First(Filterable, Reduction): """Retrieve the first element.""" - arg = rlz.column(rlz.any) + arg: Column[dt.Any] + output_dtype = rlz.dtype_like("arg") @property @@ -74,7 +79,8 @@ def __window_op__(self): class Last(Filterable, Reduction): """Retrieve the last element.""" - arg = rlz.column(rlz.any) + arg: Column[dt.Any] + output_dtype = rlz.dtype_like("arg") @property @@ -102,7 +108,7 @@ class BitAnd(Filterable, Reduction): * MySQL [`BIT_AND`](https://dev.mysql.com/doc/refman/5.7/en/aggregate-functions.html#function_bit-and) """ - arg = rlz.column(rlz.integer) + arg: Column[dt.Integer] output_dtype = rlz.dtype_like('arg') @@ -120,7 +126,7 @@ class BitOr(Filterable, Reduction): * MySQL [`BIT_OR`](https://dev.mysql.com/doc/refman/5.7/en/aggregate-functions.html#function_bit-or) """ - arg = rlz.column(rlz.integer) + arg: Column[dt.Integer] output_dtype = rlz.dtype_like('arg') @@ -138,14 +144,14 @@ class BitXor(Filterable, Reduction): * MySQL [`BIT_XOR`](https://dev.mysql.com/doc/refman/5.7/en/aggregate-functions.html#function_bit-xor) """ - arg = rlz.column(rlz.integer) + arg: Column[dt.Integer] output_dtype = rlz.dtype_like('arg') @public class Sum(Filterable, Reduction): - arg = rlz.column(rlz.numeric) + arg: Column[dt.Numeric | dt.Boolean] @attribute.default def output_dtype(self): @@ -157,7 +163,7 @@ def output_dtype(self): @public class Mean(Filterable, Reduction): - arg = rlz.column(rlz.numeric) + arg: Column[dt.Numeric | dt.Boolean] @attribute.default def output_dtype(self): @@ -169,7 +175,7 @@ def output_dtype(self): @public class Median(Filterable, Reduction): - arg = rlz.column(rlz.numeric) + arg: Column[dt.Numeric | dt.Boolean] @attribute.default def output_dtype(self): @@ -178,30 +184,30 @@ def output_dtype(self): @public class Quantile(Filterable, Reduction): - arg = rlz.any - quantile = rlz.strict_numeric - interpolation = rlz.optional( - rlz.isin({'linear', 'lower', 'higher', 'midpoint', 'nearest'}) - ) + arg: Value + quantile: Value[dt.Numeric] + interpolation: Optional[ + Literal['linear', 'lower', 'higher', 'midpoint', 'nearest'] + ] = None output_dtype = dt.float64 @public class MultiQuantile(Filterable, Reduction): - arg = rlz.any - quantile = rlz.value(dt.Array(dt.float64)) - interpolation = rlz.optional( - rlz.isin({'linear', 'lower', 'higher', 'midpoint', 'nearest'}) - ) + arg: Value + quantile: Value[dt.Array[dt.Float64]] + interpolation: Optional[ + Literal['linear', 'lower', 'higher', 'midpoint', 'nearest'] + ] = None output_dtype = dt.Array(dt.float64) @public class VarianceBase(Filterable, Reduction): - arg = rlz.column(rlz.numeric) - how = rlz.isin({'sample', 'pop'}) + arg: Column[dt.Numeric | dt.Boolean] + how: Literal["sample", "pop"] @attribute.default def output_dtype(self): @@ -225,9 +231,9 @@ class Variance(VarianceBase): class Correlation(Filterable, Reduction): """Coefficient of correlation of a set of number pairs.""" - left = rlz.column(rlz.numeric) - right = rlz.column(rlz.numeric) - how = rlz.optional(rlz.isin({'sample', 'pop'}), default='sample') + left: Column[dt.Numeric | dt.Boolean] + right: Column[dt.Numeric | dt.Boolean] + how: Literal['sample', 'pop'] = 'sample' output_dtype = dt.float64 @@ -236,46 +242,46 @@ class Correlation(Filterable, Reduction): class Covariance(Filterable, Reduction): """Covariance of a set of number pairs.""" - left = rlz.column(rlz.numeric) - right = rlz.column(rlz.numeric) - how = rlz.isin({'sample', 'pop'}) + left: Column[dt.Numeric | dt.Boolean] + right: Column[dt.Numeric | dt.Boolean] + how: Literal['sample', 'pop'] output_dtype = dt.float64 @public class Mode(Filterable, Reduction): - arg = rlz.column(rlz.any) + arg: Column output_dtype = rlz.dtype_like('arg') @public class Max(Filterable, Reduction): - arg = rlz.column(rlz.any) + arg: Column output_dtype = rlz.dtype_like('arg') @public class Min(Filterable, Reduction): - arg = rlz.column(rlz.any) + arg: Column output_dtype = rlz.dtype_like('arg') @public class ArgMax(Filterable, Reduction): - arg = rlz.column(rlz.any) - key = rlz.column(rlz.any) + arg: Column + key: Column output_dtype = rlz.dtype_like("arg") @public class ArgMin(Filterable, Reduction): - arg = rlz.column(rlz.any) - key = rlz.column(rlz.any) + arg: Column + key: Column output_dtype = rlz.dtype_like("arg") @@ -287,7 +293,7 @@ class ApproxCountDistinct(Filterable, Reduction): Impala offers the NDV built-in function for this. """ - arg = rlz.column(rlz.any) + arg: Column # Impala 2.0 and higher returns a DOUBLE output_dtype = dt.int64 @@ -297,29 +303,29 @@ class ApproxCountDistinct(Filterable, Reduction): class ApproxMedian(Filterable, Reduction): """Compute the approximate median of a set of comparable values.""" - arg = rlz.column(rlz.any) + arg: Column output_dtype = rlz.dtype_like('arg') @public class GroupConcat(Filterable, Reduction): - arg = rlz.column(rlz.any) - sep = rlz.string + arg: Column + sep: Value[dt.String] output_dtype = dt.string @public class CountDistinct(Filterable, Reduction): - arg = rlz.column(rlz.any) + arg: Column output_dtype = dt.int64 @public class ArrayCollect(Filterable, Reduction): - arg = rlz.column(rlz.any) + arg: Column @attribute.default def output_dtype(self): @@ -328,7 +334,7 @@ def output_dtype(self): @public class All(Filterable, Reduction, _Negatable): - arg = rlz.column(rlz.boolean) + arg: Column[dt.Boolean] output_dtype = dt.boolean @@ -338,7 +344,7 @@ def negate(self): @public class NotAll(Filterable, Reduction, _Negatable): - arg = rlz.column(rlz.boolean) + arg: Column[dt.Boolean] output_dtype = dt.boolean @@ -348,7 +354,7 @@ def negate(self) -> Any: @public class Any(Filterable, Reduction, _Negatable): - arg = rlz.column(rlz.boolean) + arg: Column[dt.Boolean] output_dtype = dt.boolean @@ -358,7 +364,7 @@ def negate(self) -> NotAny: @public class NotAny(Filterable, Reduction, _Negatable): - arg = rlz.column(rlz.boolean) + arg: Column[dt.Boolean] output_dtype = dt.boolean diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index c444a7d88d73..f0363e9be951 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -1,31 +1,33 @@ from __future__ import annotations import abc -import collections import itertools from abc import abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import Union as UnionType from public import public import ibis.common.exceptions as com +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt import ibis.expr.operations as ops -import ibis.expr.rules as rlz -import ibis.expr.types as ir from ibis import util from ibis.common.annotations import attribute -from ibis.common.collections import frozendict +from ibis.common.collections import FrozenDict # noqa: TCH001 from ibis.common.grounds import Immutable +from ibis.common.patterns import Coercible, Pattern +from ibis.common.typing import VarTuple from ibis.expr.deferred import Deferred -from ibis.expr.operations.core import Named, Node, Value -from ibis.expr.operations.generic import TableColumn -from ibis.expr.operations.logical import Equals, ExistsSubquery, NotExistsSubquery +from ibis.expr.operations.core import Column, Named, Node, Scalar, Value +from ibis.expr.operations.sortkeys import SortKey from ibis.expr.schema import Schema if TYPE_CHECKING: import pandas as pd import pyarrow as pa + _table_names = (f'unbound_table_{i:d}' for i in itertools.count()) @@ -35,7 +37,21 @@ def genname(): @public -class TableNode(Node): +class Relation(Node, Coercible): + @classmethod + def __coerce__(cls, value): + import pandas as pd + + import ibis + import ibis.expr.types as ir + + if isinstance(value, pd.DataFrame): + return ibis.memtable(value).op() + elif isinstance(value, ir.Expr): + return value.op() + else: + return value + def order_by(self, sort_exprs): return Selection(self, [], sort_keys=sort_exprs) @@ -50,8 +66,11 @@ def to_expr(self): return ir.Table(self) +TableNode = Relation + + @public -class PhysicalTable(TableNode, Named): +class PhysicalTable(Relation, Named): pass @@ -59,25 +78,30 @@ class PhysicalTable(TableNode, Named): # should just extend TableNode @public class UnboundTable(PhysicalTable): - schema = rlz.coerced_to(Schema) - name = rlz.optional(rlz.instance_of(str), default=genname) + schema: Schema + name: Optional[str] = None + + def __init__(self, schema, name) -> None: + if name is None: + name = genname() + super().__init__(schema=schema, name=name) @public class DatabaseTable(PhysicalTable): - name = rlz.instance_of(str) - schema = rlz.instance_of(Schema) - source = rlz.client - namespace = rlz.optional(rlz.instance_of(str)) + name: str + schema: Schema + source: Any + namespace: Optional[str] = None @public class SQLQueryResult(TableNode): """A table sourced from the result set of a select query.""" - query = rlz.instance_of(str) - schema = rlz.instance_of(Schema) - source = rlz.client + query: str + schema: Schema + source: Any # TODO(kszucs): Add a pseudohashable wrapper and use that from InMemoryTable @@ -145,9 +169,9 @@ def to_pyarrow(self, schema: Schema) -> pa.Table: @public class InMemoryTable(PhysicalTable): - name = rlz.instance_of(str) - schema = rlz.instance_of(Schema) - data = rlz.instance_of(TableProxy) + name: str + schema: Schema + data: TableProxy # TODO(kszucs): desperately need to clean this up, the majority of this @@ -201,10 +225,10 @@ def _clean_join_predicates(left, right, predicates): @public -class Join(TableNode): - left = rlz.table - right = rlz.table - predicates = rlz.optional(lambda x, this: x, default=()) +class Join(Relation): + left: Relation + right: Relation + predicates: Any = () def __init__(self, left, right, predicates, **kwargs): # TODO(kszucs): predicates should be already a list of operations, need @@ -212,6 +236,7 @@ def __init__(self, left, right, predicates, **kwargs): # currently import ibis.expr.analysis as an import ibis.expr.operations as ops + import ibis.expr.types as ir # TODO(kszucs): need to factor this out to appropriate join predicate # rules @@ -303,8 +328,8 @@ class CrossJoin(Join): @public class AsOfJoin(Join): # TODO(kszucs): convert to proper predicate rules - by = rlz.optional(lambda x, this: x, default=()) - tolerance = rlz.optional(rlz.interval) + by: Any = () + tolerance: Optional[Value[dt.Interval]] = None def __init__(self, left, right, by, predicates, **kwargs): by = _clean_join_predicates(left, right, util.promote_list(by)) @@ -312,10 +337,10 @@ def __init__(self, left, right, by, predicates, **kwargs): @public -class SetOp(TableNode): - left = rlz.table - right = rlz.table - distinct = rlz.optional(rlz.instance_of(bool), default=False) +class SetOp(Relation): + left: Relation + right: Relation + distinct: bool = False def __init__(self, left, right, **kwargs): if left.schema != right.schema: @@ -348,10 +373,10 @@ class Difference(SetOp): @public -class Limit(TableNode): - table = rlz.table - n = rlz.instance_of(int) - offset = rlz.instance_of(int) +class Limit(Relation): + table: Relation + n: int + offset: int @property def schema(self): @@ -359,8 +384,8 @@ def schema(self): @public -class SelfReference(TableNode): - table = rlz.table +class SelfReference(Relation): + table: Relation @attribute.default def name(self) -> str: @@ -373,9 +398,9 @@ def schema(self): return self.table.schema -class Projection(TableNode): - table = rlz.table - selections = rlz.tuple_of(rlz.one_of((rlz.table, rlz.any))) +class Projection(Relation): + table: Relation + selections: VarTuple[Relation | Value] @attribute.default def schema(self): @@ -406,10 +431,8 @@ def _add_alias(op: ops.Value | ops.TableNode): @public class Selection(Projection): - predicates = rlz.optional(rlz.tuple_of(rlz.boolean), default=()) - sort_keys = rlz.optional( - rlz.tuple_of(rlz.sort_key_from(rlz.ref("table"))), default=() - ) + predicates: VarTuple[Value[dt.Boolean]] = () + sort_keys: VarTuple[SortKey] = () def __init__(self, table, selections, predicates, sort_keys, **kwargs): from ibis.expr.analysis import shares_all_roots, shares_some_roots @@ -438,7 +461,8 @@ def __init__(self, table, selections, predicates, sort_keys, **kwargs): def order_by(self, sort_exprs): from ibis.expr.analysis import shares_all_roots, sub_immediate_parents - keys = rlz.tuple_of(rlz.sort_key_from(rlz.just(self)), sort_exprs) + p = Pattern.from_typehint(VarTuple[SortKey]) + keys = p.validate(sort_exprs, {}) if not self.selections: if shares_all_roots(keys, table := self.table): @@ -460,8 +484,9 @@ def _projection(self): @public -class DummyTable(TableNode): - values = rlz.tuple_of(rlz.scalar(rlz.any), min_length=1) +class DummyTable(Relation): + # TODO(kszucs): verify that it has at least one element: Length(at_least=1) + values: VarTuple[Value[dt.Any, ds.Scalar]] @property def schema(self): @@ -469,15 +494,13 @@ def schema(self): @public -class Aggregation(TableNode): - table = rlz.table - metrics = rlz.optional(rlz.tuple_of(rlz.scalar(rlz.any)), default=()) - by = rlz.optional(rlz.tuple_of(rlz.column(rlz.any)), default=()) - having = rlz.optional(rlz.tuple_of(rlz.scalar(rlz.boolean)), default=()) - predicates = rlz.optional(rlz.tuple_of(rlz.boolean), default=()) - sort_keys = rlz.optional( - rlz.tuple_of(rlz.sort_key_from(rlz.ref("table"))), default=() - ) +class Aggregation(Relation): + table: Relation + metrics: VarTuple[Scalar] = () + by: VarTuple[Column] = () + having: VarTuple[Scalar[dt.Boolean]] = () + predicates: VarTuple[Value[dt.Boolean]] = () + sort_keys: VarTuple[SortKey] = () def __init__(self, table, metrics, by, having, predicates, sort_keys): from ibis.expr.analysis import shares_all_roots, shares_some_roots @@ -516,7 +539,8 @@ def schema(self): def order_by(self, sort_exprs): from ibis.expr.analysis import shares_all_roots, sub_immediate_parents - keys = rlz.tuple_of(rlz.sort_key_from(rlz.just(self)), sort_exprs) + p = Pattern.from_typehint(VarTuple[SortKey]) + keys = p.validate(sort_exprs, {}) if shares_all_roots(keys, table := self.table): sort_keys = tuple(self.sort_keys) + tuple( @@ -535,7 +559,7 @@ def order_by(self, sort_exprs): @public -class Distinct(TableNode): +class Distinct(Relation): """Distinct is a table-level unique-ing operation. In SQL, you might have: @@ -547,36 +571,22 @@ class Distinct(TableNode): FROM table """ - table = rlz.table + table: Relation @property def schema(self): return self.table.schema +# TODO(kszucs): split it into two operations, one working with a single replacement +# value and the other with a mapping +# TODO(kszucs): the single value case was limited to numeric and string types @public -class FillNa(TableNode): +class FillNa(Relation): """Fill null values in the table.""" - table = rlz.table - replacements = rlz.one_of( - ( - rlz.numeric, - rlz.string, - rlz.instance_of(collections.abc.Mapping), - ) - ) - - def __init__(self, table, replacements, **kwargs): - super().__init__( - table=table, - replacements=( - replacements - if not isinstance(replacements, collections.abc.Mapping) - else frozendict(replacements) - ), - **kwargs, - ) + table: Relation + replacements: UnionType[Value[dt.Numeric | dt.String], FrozenDict[str, Any]] @property def schema(self): @@ -584,12 +594,12 @@ def schema(self): @public -class DropNa(TableNode): +class DropNa(Relation): """Drop null values in the table.""" - table = rlz.table - how = rlz.isin({'any', 'all'}) - subset = rlz.optional(rlz.tuple_of(rlz.column(rlz.any))) + table: Relation + how: Literal["any", "all"] + subset: Optional[VarTuple[Column[dt.Any]]] = None @property def schema(self): @@ -600,8 +610,8 @@ def schema(self): class View(PhysicalTable): """A view created from an expression.""" - child = rlz.table - name = rlz.instance_of(str) + child: Relation + name: str @property def schema(self): @@ -612,9 +622,9 @@ def schema(self): class SQLStringView(PhysicalTable): """A view created from a SQL string.""" - child = rlz.table - name = rlz.instance_of(str) - query = rlz.instance_of(str) + child: Relation + name: str + query: str @attribute.default def schema(self): @@ -624,6 +634,9 @@ def schema(self): def _dedup_join_columns(expr, lname: str, rname: str): + from ibis.expr.operations.generic import TableColumn + from ibis.expr.operations.logical import Equals + op = expr.op() left = op.left.to_expr() right = op.right.to_expr() @@ -690,4 +703,4 @@ def _dedup_join_columns(expr, lname: str, rname: str): return expr.select(projections) -public(ExistsSubquery=ExistsSubquery, NotExistsSubquery=NotExistsSubquery) +public(TableNode=Relation) diff --git a/ibis/expr/operations/sortkeys.py b/ibis/expr/operations/sortkeys.py index 4f936e909632..6072e61b79b5 100644 --- a/ibis/expr/operations/sortkeys.py +++ b/ibis/expr/operations/sortkeys.py @@ -7,17 +7,48 @@ import ibis.expr.rules as rlz from ibis.expr.operations.core import Value +# TODO(kszucs): move the content of this file to generic.py + +_is_ascending = { + "asc": True, + "ascending": True, + "desc": False, + "descending": False, + 0: False, + 1: True, + False: False, + True: True, +} + + +# TODO(kszucs): consider to limit its shape to Columnar, we could treat random() +# as a columnar operation too @public class SortKey(Value): """A sort operation.""" - expr = rlz.any - ascending = rlz.optional(rlz.bool_, default=True) + expr: Value + ascending: bool = True output_dtype = rlz.dtype_like("expr") output_shape = rlz.shape_like("expr") + @classmethod + def __coerce__(cls, value, T=None, S=None): + if isinstance(value, tuple): + key, asc = value + else: + key, asc = value, True + + asc = _is_ascending[asc] + key = super().__coerce__(key, T=T, S=S) + + if isinstance(key, cls): + return key + else: + return cls(key, asc) + @property def name(self) -> str: return self.expr.name diff --git a/ibis/expr/operations/strings.py b/ibis/expr/operations/strings.py index caa6893eed25..e5d2d186ec05 100644 --- a/ibis/expr/operations/strings.py +++ b/ibis/expr/operations/strings.py @@ -1,16 +1,21 @@ from __future__ import annotations +from typing import Optional + from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute +from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Unary, Value @public class StringUnary(Unary): - arg = rlz.string + arg: Value[dt.String] + output_dtype = dt.string @@ -51,9 +56,9 @@ class Capitalize(StringUnary): @public class Substring(Value): - arg = rlz.string - start = rlz.integer - length = rlz.optional(rlz.integer) + arg: Value[dt.String] + start: Value[dt.Integer] + length: Optional[Value[dt.Integer]] = None output_dtype = dt.string output_shape = rlz.shape_like('arg') @@ -61,26 +66,28 @@ class Substring(Value): @public class StrRight(Value): - arg = rlz.string - nchars = rlz.integer + arg: Value[dt.String] + nchars: Value[dt.Integer] + output_shape = rlz.shape_like("arg") output_dtype = dt.string @public class Repeat(Value): - arg = rlz.string - times = rlz.integer + arg: Value[dt.String] + times: Value[dt.Integer] + output_shape = rlz.shape_like("arg") output_dtype = dt.string @public class StringFind(Value): - arg = rlz.string - substr = rlz.string - start = rlz.optional(rlz.integer) - end = rlz.optional(rlz.integer) + arg: Value[dt.String] + substr: Value[dt.String] + start: Optional[Value[dt.Integer]] = None + end: Optional[Value[dt.Integer]] = None output_shape = rlz.shape_like("arg") output_dtype = dt.int64 @@ -88,9 +95,9 @@ class StringFind(Value): @public class Translate(Value): - arg = rlz.string - from_str = rlz.string - to_str = rlz.string + arg: Value[dt.String] + from_str: Value[dt.String] + to_str: Value[dt.String] output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -98,9 +105,9 @@ class Translate(Value): @public class LPad(Value): - arg = rlz.string - length = rlz.integer - pad = rlz.optional(rlz.string) + arg: Value[dt.String] + length: Value[dt.Integer] + pad: Optional[Value[dt.String]] = None output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -108,9 +115,9 @@ class LPad(Value): @public class RPad(Value): - arg = rlz.string - length = rlz.integer - pad = rlz.optional(rlz.string) + arg: Value[dt.String] + length: Value[dt.Integer] + pad: Optional[Value[dt.String]] = None output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -118,8 +125,8 @@ class RPad(Value): @public class FindInSet(Value): - needle = rlz.string - values = rlz.tuple_of(rlz.string, min_length=1) + needle: Value[dt.String] + values: VarTuple[Value[dt.String]] output_shape = rlz.shape_like("needle") output_dtype = dt.int64 @@ -127,8 +134,8 @@ class FindInSet(Value): @public class StringJoin(Value): - sep = rlz.string - arg = rlz.tuple_of(rlz.string, min_length=1) + sep: Value[dt.String] + arg: VarTuple[Value[dt.String]] output_dtype = dt.string @@ -139,8 +146,8 @@ def output_shape(self): @public class ArrayStringJoin(Value): - sep = rlz.string - arg = rlz.value(dt.Array(dt.string)) + sep: Value[dt.String] + arg: Value[dt.Array[dt.String]] output_dtype = dt.string output_shape = rlz.shape_like("args") @@ -148,33 +155,36 @@ class ArrayStringJoin(Value): @public class StartsWith(Value): - arg = rlz.string - start = rlz.scalar(rlz.string) + arg: Value[dt.String] + start: Value[dt.String, ds.Scalar] + output_dtype = dt.boolean output_shape = rlz.shape_like("arg") @public class EndsWith(Value): - arg = rlz.string - end = rlz.scalar(rlz.string) + arg: Value[dt.String] + end: Value[dt.String, ds.Scalar] + output_dtype = dt.boolean output_shape = rlz.shape_like("arg") @public class FuzzySearch(Value): - arg = rlz.string - pattern = rlz.string + arg: Value[dt.String] + pattern: Value[dt.String] + output_dtype = dt.boolean output_shape = rlz.shape_like('arg') @public class StringSQLLike(FuzzySearch): - arg = rlz.string - pattern = rlz.string - escape = rlz.optional(rlz.instance_of(str)) + arg: Value[dt.String] + pattern: Value[dt.String] + escape: Optional[str] = None @public @@ -189,9 +199,9 @@ class RegexSearch(FuzzySearch): @public class RegexExtract(Value): - arg = rlz.string - pattern = rlz.string - index = rlz.integer + arg: Value[dt.String] + pattern: Value[dt.String] + index: Value[dt.Integer] output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -199,9 +209,9 @@ class RegexExtract(Value): @public class RegexReplace(Value): - arg = rlz.string - pattern = rlz.string - replacement = rlz.string + arg: Value[dt.String] + pattern: Value[dt.String] + replacement: Value[dt.String] output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -209,9 +219,9 @@ class RegexReplace(Value): @public class StringReplace(Value): - arg = rlz.string - pattern = rlz.string - replacement = rlz.string + arg: Value[dt.String] + pattern: Value[dt.String] + replacement: Value[dt.String] output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -219,8 +229,8 @@ class StringReplace(Value): @public class StringSplit(Value): - arg = rlz.string - delimiter = rlz.string + arg: Value[dt.String] + delimiter: Value[dt.String] output_shape = rlz.shape_like("arg") output_dtype = dt.Array(dt.string) @@ -228,14 +238,15 @@ class StringSplit(Value): @public class StringConcat(Value): - arg = rlz.tuple_of(rlz.string) + arg: VarTuple[Value[dt.String]] + output_shape = rlz.shape_like('arg') output_dtype = rlz.dtype_like('arg') @public class ExtractURLField(Value): - arg = rlz.string + arg: Value[dt.String] output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -273,7 +284,7 @@ class ExtractPath(ExtractURLField): @public class ExtractQuery(ExtractURLField): - key = rlz.optional(rlz.string) + key: Optional[Value[dt.String]] = None @public @@ -293,8 +304,8 @@ class StringAscii(StringUnary): @public class StringContains(Value): - haystack = rlz.string - needle = rlz.string + haystack: Value[dt.String] + needle: Value[dt.String] output_shape = rlz.shape_like("args") output_dtype = dt.bool @@ -302,8 +313,8 @@ class StringContains(Value): @public class Levenshtein(Value): - left = rlz.string - right = rlz.string + left: Value[dt.String] + right: Value[dt.String] output_dtype = dt.int64 output_shape = rlz.shape_like("args") diff --git a/ibis/expr/operations/structs.py b/ibis/expr/operations/structs.py index 0413de0be139..e1f0e9e5243a 100644 --- a/ibis/expr/operations/structs.py +++ b/ibis/expr/operations/structs.py @@ -2,16 +2,18 @@ from public import public +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute +from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Value @public class StructField(Value): - arg = rlz.struct - field = rlz.instance_of(str) + arg: Value[dt.Struct] + field: str output_shape = rlz.shape_like("arg") @@ -28,10 +30,10 @@ def name(self) -> str: @public class StructColumn(Value): - names = rlz.tuple_of(rlz.instance_of(str), min_length=1) - values = rlz.tuple_of(rlz.any, min_length=1) + names: VarTuple[str] + values: VarTuple[Value] - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar @attribute.default def output_dtype(self) -> dt.DataType: diff --git a/ibis/expr/operations/temporal.py b/ibis/expr/operations/temporal.py index 96d63081c1de..5955ff2d97b3 100644 --- a/ibis/expr/operations/temporal.py +++ b/ibis/expr/operations/temporal.py @@ -3,10 +3,12 @@ import operator from public import public +from typing_extensions import Annotated import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute +from ibis.common.patterns import As, Attrs from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit from ibis.expr.operations.core import Binary, Unary, Value from ibis.expr.operations.logical import Between @@ -14,18 +16,18 @@ @public class TemporalUnary(Unary): - arg = rlz.temporal + arg: Value[dt.Temporal] @public class TimestampUnary(Unary): - arg = rlz.timestamp + arg: Value[dt.Timestamp] @public class TimestampTruncate(Value): - arg = rlz.timestamp - unit = rlz.coerced_to(IntervalUnit) + arg: Value[dt.Timestamp] + unit: IntervalUnit output_shape = rlz.shape_like("arg") output_dtype = dt.timestamp @@ -33,8 +35,8 @@ class TimestampTruncate(Value): @public class DateTruncate(Value): - arg = rlz.date - unit = rlz.coerced_to(DateUnit) + arg: Value[dt.Date] + unit: DateUnit output_shape = rlz.shape_like("arg") output_dtype = dt.date @@ -42,8 +44,8 @@ class DateTruncate(Value): @public class TimeTruncate(Value): - arg = rlz.time - unit = rlz.coerced_to(TimeUnit) + arg: Value[dt.Time] + unit: TimeUnit output_shape = rlz.shape_like("arg") output_dtype = dt.time @@ -51,8 +53,8 @@ class TimeTruncate(Value): @public class Strftime(Value): - arg = rlz.temporal - format_str = rlz.string + arg: Value[dt.Temporal] + format_str: Value[dt.String] output_shape = rlz.shape_like("arg") output_dtype = dt.string @@ -60,8 +62,8 @@ class Strftime(Value): @public class StringToTimestamp(Value): - arg = rlz.string - format_str = rlz.string + arg: Value[dt.String] + format_str: Value[dt.String] output_shape = rlz.shape_like("arg") output_dtype = dt.Timestamp(timezone='UTC') @@ -74,12 +76,12 @@ class ExtractTemporalField(TemporalUnary): @public class ExtractDateField(ExtractTemporalField): - arg = rlz.one_of([rlz.date, rlz.timestamp]) + arg: Value[dt.Date | dt.Timestamp] @public class ExtractTimeField(ExtractTemporalField): - arg = rlz.one_of([rlz.time, rlz.timestamp]) + arg: Value[dt.Time | dt.Timestamp] @public @@ -144,13 +146,15 @@ class ExtractMillisecond(ExtractTimeField): @public class DayOfWeekIndex(Unary): - arg = rlz.one_of([rlz.date, rlz.timestamp]) + arg: Value[dt.Date | dt.Timestamp] + output_dtype = dt.int16 @public class DayOfWeekName(Unary): - arg = rlz.one_of([rlz.date, rlz.timestamp]) + arg: Value[dt.Date | dt.Timestamp] + output_dtype = dt.string @@ -166,9 +170,9 @@ class Date(Unary): @public class DateFromYMD(Value): - year = rlz.integer - month = rlz.integer - day = rlz.integer + year: Value[dt.Integer] + month: Value[dt.Integer] + day: Value[dt.Integer] output_dtype = dt.date output_shape = rlz.shape_like("args") @@ -176,9 +180,9 @@ class DateFromYMD(Value): @public class TimeFromHMS(Value): - hours = rlz.integer - minutes = rlz.integer - seconds = rlz.integer + hours: Value[dt.Integer] + minutes: Value[dt.Integer] + seconds: Value[dt.Integer] output_dtype = dt.time output_shape = rlz.shape_like("args") @@ -186,12 +190,12 @@ class TimeFromHMS(Value): @public class TimestampFromYMDHMS(Value): - year = rlz.integer - month = rlz.integer - day = rlz.integer - hours = rlz.integer - minutes = rlz.integer - seconds = rlz.integer + year: Value[dt.Integer] + month: Value[dt.Integer] + day: Value[dt.Integer] + hours: Value[dt.Integer] + minutes: Value[dt.Integer] + seconds: Value[dt.Integer] output_dtype = dt.timestamp output_shape = rlz.shape_like("args") @@ -199,77 +203,86 @@ class TimestampFromYMDHMS(Value): @public class TimestampFromUNIX(Value): - arg = rlz.any - unit = rlz.coerced_to(TimestampUnit) + arg: Value + unit: TimestampUnit output_dtype = dt.timestamp output_shape = rlz.shape_like('arg') +TimeInterval = Annotated[dt.Interval, Attrs(unit=As(TimeUnit))] +DateInterval = Annotated[dt.Interval, Attrs(unit=As(DateUnit))] + + @public class DateAdd(Binary): - left = rlz.date - right = rlz.interval(units={'Y', 'Q', 'M', 'W', 'D'}) + left: Value[dt.Date] + right: Value[DateInterval] + output_dtype = rlz.dtype_like('left') @public class DateSub(Binary): - left = rlz.date - right = rlz.interval(units={'Y', 'Q', 'M', 'W', 'D'}) + left: Value[dt.Date] + right: Value[DateInterval] + output_dtype = rlz.dtype_like('left') @public class DateDiff(Binary): - left = rlz.date - right = rlz.date + left: Value[dt.Date] + right: Value[dt.Date] + output_dtype = dt.Interval('D') @public class TimeAdd(Binary): - left = rlz.time - right = rlz.interval(units={'h', 'm', 's', 'ms', 'us', 'ns'}) + left: Value[dt.Time] + right: Value[TimeInterval] + output_dtype = rlz.dtype_like('left') @public class TimeSub(Binary): - left = rlz.time - right = rlz.interval(units={'h', 'm', 's', 'ms', 'us', 'ns'}) + left: Value[dt.Time] + right: Value[TimeInterval] + output_dtype = rlz.dtype_like('left') @public class TimeDiff(Binary): - left = rlz.time - right = rlz.time + left: Value[dt.Time] + right: Value[dt.Time] + output_dtype = dt.Interval('s') @public class TimestampAdd(Binary): - left = rlz.timestamp - right = rlz.interval( - units={'Y', 'Q', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns'} - ) + left: Value[dt.Timestamp] + right: Value[dt.Interval] + output_dtype = rlz.dtype_like('left') @public class TimestampSub(Binary): - left = rlz.timestamp - right = rlz.interval( - units={'Y', 'Q', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns'} - ) + left: Value[dt.Timestamp] + right: Value[dt.Interval] + output_dtype = rlz.dtype_like('left') @public class TimestampDiff(Binary): - left = rlz.timestamp - right = rlz.timestamp + left: Value[dt.Timestamp] + right: Value[dt.Timestamp] + output_dtype = dt.Interval('s') @@ -289,36 +302,36 @@ def output_dtype(self): @public class IntervalAdd(IntervalBinary): - left = rlz.interval - right = rlz.interval + left: Value[dt.Interval] + right: Value[dt.Interval] op = operator.add @public class IntervalSubtract(IntervalBinary): - left = rlz.interval - right = rlz.interval + left: Value[dt.Interval] + right: Value[dt.Interval] op = operator.sub @public class IntervalMultiply(IntervalBinary): - left = rlz.interval - right = rlz.numeric + left: Value[dt.Interval] + right: Value[dt.Numeric | dt.Boolean] op = operator.mul @public class IntervalFloorDivide(IntervalBinary): - left = rlz.interval - right = rlz.numeric + left: Value[dt.Interval] + right: Value[dt.Numeric | dt.Boolean] op = operator.floordiv @public class IntervalFromInteger(Value): - arg = rlz.integer - unit = rlz.coerced_to(IntervalUnit) + arg: Value[dt.Integer] + unit: IntervalUnit output_shape = rlz.shape_like("arg") @@ -333,9 +346,9 @@ def resolution(self): @public class BetweenTime(Between): - arg = rlz.one_of([rlz.timestamp, rlz.time]) - lower_bound = rlz.one_of([rlz.time, rlz.string]) - upper_bound = rlz.one_of([rlz.time, rlz.string]) + arg: Value[dt.Time | dt.Timestamp] + lower_bound: Value[dt.Time | dt.String] + upper_bound: Value[dt.Time | dt.String] public(ExtractTimestampField=ExtractTemporalField) diff --git a/ibis/expr/operations/tests/__init__.py b/ibis/expr/operations/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/ibis/expr/operations/tests/test_generic.py b/ibis/expr/operations/tests/test_generic.py new file mode 100644 index 000000000000..88c272b3e9a8 --- /dev/null +++ b/ibis/expr/operations/tests/test_generic.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import pytest + +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.common.patterns import ( + CoercedTo, + GenericCoercedTo, + Pattern, + ValidationError, +) + +# TODO(kszucs): actually we should only allow datatype classes not instances + + +@pytest.mark.parametrize( + ("value", "dtype"), + [ + (1, dt.int8), + (1.0, dt.double), + (True, dt.boolean), + ("foo", dt.string), + (b"foo", dt.binary), + ((1, 2), dt.Array(dt.int8)), + ], +) +def test_literal_coercion_type_inference(value, dtype): + assert ops.Literal.__coerce__(value) == ops.Literal(value, dtype) + assert ops.Literal.__coerce__(value, dtype) == ops.Literal(value, dtype) + + +def test_coerced_to_literal(): + p = CoercedTo(ops.Literal) + one = ops.Literal(1, dt.int8) + assert p.validate(ops.Literal(1, dt.int8), {}) == one + assert p.validate(1, {}) == one + assert p.validate(False, {}) == ops.Literal(False, dt.boolean) + + p = GenericCoercedTo(ops.Literal[dt.Int8]) + assert p.validate(ops.Literal(1, dt.int8), {}) == one + + p = Pattern.from_typehint(ops.Literal[dt.Int8]) + assert p == GenericCoercedTo(ops.Literal[dt.Int8]) + + one = ops.Literal(1, dt.int16) + with pytest.raises(ValidationError): + p.validate(ops.Literal(1, dt.int16), {}) + + +def test_coerced_to_value(): + one = ops.Literal(1, dt.int8) + + p = Pattern.from_typehint(ops.Value) + assert p.validate(1, {}) == one + + p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Any]) + assert p.validate(1, {}) == one + + p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Scalar]) + assert p.validate(1, {}) == one + + p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Columnar]) + with pytest.raises(ValidationError): + p.validate(1, {}) + + # dt.Integer is not instantiable so it will be only used for checking + # that the produced literal has any integer datatype + p = Pattern.from_typehint(ops.Value[dt.Integer, ds.Any]) + assert p.validate(1, {}) == one + + # same applies here, the coercion itself will use only the inferred datatype + # but then the result is checked against the given typehint + p = Pattern.from_typehint(ops.Value[dt.Int8 | dt.Int16, ds.Any]) + assert p.validate(1, {}) == one + assert p.validate(128, {}) == ops.Literal(128, dt.int16) + + p1 = Pattern.from_typehint(ops.Value[dt.Int8, ds.Any]) + p2 = Pattern.from_typehint(ops.Value[dt.Int16, ds.Scalar]) + assert p1.validate(1, {}) == one + # this is actually supported by creating an explicit dtype + # in Value.__coerce__ based on the `T` keyword argument + assert p2.validate(1, {}) == ops.Literal(1, dt.int16) + assert p2.validate(128, {}) == ops.Literal(128, dt.int16) + + p = p1 | p2 + assert p.validate(1, {}) == one + + +@pytest.mark.pandas +def test_coerced_to_interval_value(): + import pandas as pd + + p = Pattern.from_typehint(ops.Value[dt.Interval, ds.Any]) + + value = pd.Timedelta("1s") + result = p.match(value, {}) + assert result.value == 1 + assert result.dtype == dt.Interval("s") + + value = pd.Timedelta("1h 1m 1s") + result = p.match(value, {}) + assert result.value == 3661 + assert result.dtype == dt.Interval("s") diff --git a/ibis/expr/operations/udf.py b/ibis/expr/operations/udf.py index d1ee736d6254..19ff7b27fc04 100644 --- a/ibis/expr/operations/udf.py +++ b/ibis/expr/operations/udf.py @@ -13,6 +13,7 @@ import ibis.expr.operations as ops import ibis.expr.rules as rlz from ibis import util +from ibis.common.annotations import Argument from ibis.common.collections import FrozenDict if TYPE_CHECKING: @@ -125,10 +126,11 @@ def make_node( if (raw_dtype := annotations.get(name)) is None: raise exc.MissingParameterAnnotationError(fn, name) - arg = rlz.value(dt.dtype(raw_dtype)) - if (default := param.default) is not EMPTY: - arg = rlz.optional(arg, default=default) - fields[name] = arg + arg = rlz.ValueOf(dt.dtype(raw_dtype)) + if (default := param.default) is EMPTY: + fields[name] = Argument.required(validator=arg) + else: + fields[name] = Argument.default(validator=arg, default=default) fields["output_dtype"] = dt.dtype(return_annotation) @@ -137,6 +139,7 @@ def make_node( fields["__func__"] = property(fget=lambda _, fn=fn: fn) fields["__config__"] = FrozenDict(args=args, kwargs=FrozenDict(**kwargs)) fields["__udf_namespace__"] = kwargs.get("schema") + fields["__module__"] = fn.__module__ return type(fn.__name__, (ScalarUDF,), fields) diff --git a/ibis/expr/operations/vectorized.py b/ibis/expr/operations/vectorized.py index f41a055ff1dd..4dec58e645f9 100644 --- a/ibis/expr/operations/vectorized.py +++ b/ibis/expr/operations/vectorized.py @@ -1,22 +1,24 @@ from __future__ import annotations -from types import FunctionType, LambdaType +from types import FunctionType, LambdaType # noqa: TCH003 +from typing import Union from public import public -from ibis.expr import rules as rlz +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt +from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.analytic import Analytic -from ibis.expr.operations.core import Value +from ibis.expr.operations.core import Column, Value from ibis.expr.operations.reductions import Reduction class VectorizedUDF(Value): - func = rlz.instance_of((FunctionType, LambdaType)) - func_args = rlz.tuple_of(rlz.column(rlz.any)) - # TODO(kszucs): should rename these arguments to - # input_dtypes and return_dtype - input_type = rlz.tuple_of(rlz.datatype) - return_type = rlz.datatype + func: Union[FunctionType, LambdaType] + func_args: VarTuple[Column] + # TODO(kszucs): should rename these arguments to input_dtypes and return_dtype + input_type: VarTuple[dt.DataType] + return_type: dt.DataType @property def output_dtype(self): @@ -27,14 +29,14 @@ def output_dtype(self): class ElementWiseVectorizedUDF(VectorizedUDF): """Node for element wise UDF.""" - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar @public class ReductionVectorizedUDF(VectorizedUDF, Reduction): """Node for reduction UDF.""" - output_shape = rlz.Shape.SCALAR + output_shape = ds.scalar # TODO(kszucs): revisit @@ -42,4 +44,4 @@ class ReductionVectorizedUDF(VectorizedUDF, Reduction): class AnalyticVectorizedUDF(VectorizedUDF, Analytic): """Node for analytics UDF.""" - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar diff --git a/ibis/expr/operations/window.py b/ibis/expr/operations/window.py index b3b86c6e286c..772fb415083c 100644 --- a/ibis/expr/operations/window.py +++ b/ibis/expr/operations/window.py @@ -1,41 +1,72 @@ from __future__ import annotations from abc import abstractmethod +from typing import Optional from public import public +from typing_extensions import TypeVar import ibis.common.exceptions as com +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -from ibis.expr.operations import Value +from ibis.common.patterns import CoercionError +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.expr.operations.core import Column, Value +from ibis.expr.operations.generic import Literal +from ibis.expr.operations.numeric import Negate +from ibis.expr.operations.relations import Relation # noqa: TCH001 +from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001 + +T = TypeVar("T", bound=dt.Numeric | dt.Interval, covariant=True) +S = TypeVar("S", bound=ds.DataShape, default=ds.Any, covariant=True) @public -class WindowBoundary(Value): +class WindowBoundary(Value[T, S]): # TODO(kszucs): consider to prefer Concrete base class here # pretty similar to SortKey and Alias operations which wrap a single value - value = rlz.one_of([rlz.numeric, rlz.interval]) - preceding = rlz.bool_ - - output_shape = rlz.shape_like("value") - output_dtype = rlz.dtype_like("value") + value: Value[T, S] + preceding: bool @property def following(self) -> bool: return not self.preceding + @property + def output_shape(self) -> S: + return self.value.output_shape + + @property + def output_dtype(self) -> T: + return self.value.output_dtype + + @classmethod + def __coerce__(cls, value, **kwargs): + arg = super().__coerce__(value, **kwargs) + + if isinstance(arg, cls): + return arg + elif isinstance(arg, Negate): + return cls(arg.arg, preceding=True) + elif isinstance(arg, Literal): + new = arg.copy(value=abs(arg.value)) + return cls(new, preceding=arg.value < 0) + elif isinstance(arg, Value): + return cls(arg, preceding=False) + else: + raise CoercionError(f'Invalid window boundary type: {type(arg)}') + @public class WindowFrame(Value): """A window frame operation bound to a table.""" - table = rlz.table - group_by = rlz.optional(rlz.tuple_of(rlz.any), default=()) - order_by = rlz.optional( - rlz.tuple_of(rlz.sort_key_from(rlz.ref("table"))), default=() - ) + table: Relation + group_by: VarTuple[Column] = () + order_by: VarTuple[SortKey] = () - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar def __init__(self, start, end, **kwargs): if start and end and start.output_dtype != end.output_dtype: @@ -61,9 +92,9 @@ def end(self): @public class RowsWindowFrame(WindowFrame): how = "rows" - start = rlz.optional(rlz.row_window_boundary) - end = rlz.optional(rlz.row_window_boundary) - max_lookback = rlz.optional(rlz.interval) + start: Optional[WindowBoundary[dt.Integer]] = None + end: Optional[WindowBoundary] = None + max_lookback: Optional[Value[dt.Interval]] = None def __init__(self, max_lookback, order_by, **kwargs): if max_lookback: @@ -83,20 +114,27 @@ def __init__(self, max_lookback, order_by, **kwargs): @public class RangeWindowFrame(WindowFrame): how = "range" - start = rlz.optional(rlz.range_window_boundary) - end = rlz.optional(rlz.range_window_boundary) + start: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None + end: Optional[WindowBoundary[dt.Numeric | dt.Interval]] = None @public class WindowFunction(Value): - func = rlz.analytic - frame = rlz.instance_of(WindowFrame) + func: Value + frame: WindowFrame output_dtype = rlz.dtype_like("func") - output_shape = rlz.Shape.COLUMNAR + output_shape = ds.columnar def __init__(self, func, frame): - from ibis.expr.analysis import propagate_down_window, shares_all_roots + from ibis.expr.analysis import ( + is_analytic, + propagate_down_window, + shares_all_roots, + ) + + if not is_analytic(func): + raise com.IbisTypeError("Window function expression must be analytic") func = propagate_down_window(func, frame) if not shares_all_roots(func, frame): diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index fd695653c194..903568e4b2fe 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -1,50 +1,16 @@ from __future__ import annotations -import enum import operator from itertools import product, starmap from public import public -import ibis.common.exceptions as com import ibis.expr.datatypes as dt -import ibis.expr.schema as sch -import ibis.expr.types as ir +import ibis.expr.operations as ops from ibis import util -from ibis.common.annotations import attribute, optional +from ibis.common.annotations import attribute +from ibis.common.patterns import CoercionError, Matcher, NoMatch from ibis.common.temporal import IntervalUnit -from ibis.common.validators import ( - bool_, - callable_with, # noqa: F401 - coerced_to, # noqa: F401 - equal_to, # noqa: F401 - instance_of, - isin, - lazy_instance_of, - map_to, - one_of, - option, # noqa: F401 - pair_of, # noqa: F401 - ref, - str_, - tuple_of, - validator, -) -from ibis.expr.deferred import Deferred - - -# TODO(kszucs): consider to rename to datashape -@public -class Shape(enum.IntEnum): - SCALAR = 0 - COLUMNAR = 1 - # TABULAR = 2 - - def is_scalar(self): - return self is Shape.SCALAR - - def is_columnar(self): - return self is Shape.COLUMNAR @public @@ -88,215 +54,6 @@ def comparable(left, right): return castable(left, right) or castable(right, left) -class rule(validator): - def _erase_expr(self, value): - return value.op() if isinstance(value, ir.Expr) else value - - def __call__(self, *args, **kwargs): - args = map(self._erase_expr, args) - kwargs = {k: self._erase_expr(v) for k, v in kwargs.items()} - result = super().__call__(*args, **kwargs) - assert not isinstance(result, ir.Expr) - return result - - -# --------------------------------------------------------------------- -# Input type validators / coercion functions - - -@validator -def expr_of(inner, value, **kwargs): - value = inner(value, **kwargs) - return value if isinstance(value, ir.Expr) else value.to_expr() - - -@rule -def just(arg): - return lambda **_: arg - - -@rule -def sort_key_from(table_ref, key, **kwargs): - import ibis.expr.operations as ops - - is_ascending = { - "asc": True, - "ascending": True, - "desc": False, - "descending": False, - 0: False, - 1: True, - False: False, - True: True, - } - - if isinstance(key, ops.SortKey): - return key - elif isinstance(key, tuple): - key, order = key - else: - key, order = key, True - - if isinstance(order, str): - order = order.lower() - order = map_to(is_ascending, order) - - return ops.SortKey(key, ascending=order) - - -@rule -def datatype(arg, **kwargs): - return dt.dtype(arg) - - -# TODO(kszucs): make type argument the first and mandatory, similarly to the -# value rule, move out the type inference to `ir.literal()` method -# TODO(kszucs): may not make sense to support an explicit datatype here, we -# could do the coercion in the API function ibis.literal() -@rule -def literal(dtype, value, **kwargs): - import ibis.expr.operations as ops - - if isinstance(value, ops.Literal): - return value - - dtype = dt.infer(value) if dtype is None else dt.dtype(dtype) - value = dt.normalize(dtype, value) - - return ops.Literal(value, dtype=dtype) - - -@rule -def value(dtype, arg, **kwargs): - """Validates that the given argument is a Value with a particular datatype. - - Parameters - ---------- - dtype - DataType subclass or DataType instance - arg - If a Python literal is given the validator tries to coerce it to an ibis - literal. - kwargs - Keyword arguments - - Returns - ------- - ir.Value - An ibis value expression with the specified datatype - """ - import ibis.expr.operations as ops - - if isinstance(arg, Deferred): - raise com.IbisTypeError( - "Deferred input is not allowed, try passing a lambda function instead. " - "For example, instead of writing `f(_.a)` write `lambda t: f(t.a)`" - ) - - if not isinstance(arg, ops.Value): - # coerce python literal to ibis literal - arg = literal(None, arg) - - if dtype is None: - # no datatype restriction - return arg - elif isinstance(dtype, type): - # dtype class has been specified like dt.Interval or dt.Decimal - if not issubclass(dtype, dt.DataType): - raise com.IbisTypeError( - f"Datatype specification {dtype} is not a subclass dt.DataType" - ) - elif isinstance(arg.output_dtype, dtype): - return arg - else: - raise com.IbisTypeError( - f'Given argument with datatype {arg.output_dtype} is not ' - f'subtype of {dtype}' - ) - elif isinstance(dtype, (dt.DataType, str)): - # dtype instance or string has been specified and arg's dtype is - # implicitly castable to it, like dt.int8 is castable to dt.int64 - dtype = dt.dtype(dtype) - # retrieve literal values for implicit cast check - value = getattr(arg, 'value', None) - if dt.castable(arg.output_dtype, dtype, value=value): - return arg - else: - raise com.IbisTypeError( - f'Given argument with datatype {arg.output_dtype} is not ' - f'implicitly castable to {dtype}' - ) - else: - raise com.IbisTypeError(f'Invalid datatype specification {dtype}') - - -@rule -def scalar(inner, arg, **kwargs): - arg = inner(arg, **kwargs) - if arg.output_shape.is_scalar(): - return arg - else: - raise com.IbisTypeError(f"{arg} is not a scalar") - - -@rule -def column(inner, arg, **kwargs): - arg = inner(arg, **kwargs) - if arg.output_shape.is_columnar(): - return arg - else: - raise com.IbisTypeError(f"{arg} is not a column") - - -any = value(None) -double = value(dt.double) -string = value(dt.string) -boolean = value(dt.boolean) -integer = value(dt.int64) -decimal = value(dt.Decimal) -floating = value(dt.float64) -date = value(dt.date) -time = value(dt.time) -timestamp = value(dt.Timestamp) -temporal = one_of([timestamp, date, time]) -json = value(dt.json) - -strict_numeric = one_of([integer, floating, decimal]) -soft_numeric = one_of([integer, floating, decimal, boolean]) -numeric = soft_numeric - -array = value(dt.Array) -struct = value(dt.Struct) -mapping = value(dt.Map) - -geospatial = value(dt.GeoSpatial) -point = value(dt.Point) -linestring = value(dt.LineString) -polygon = value(dt.Polygon) -multilinestring = value(dt.MultiLineString) -multipoint = value(dt.MultiPoint) -multipolygon = value(dt.MultiPolygon) - - -@public -@rule -def interval(arg, units=None, **kwargs): - arg = value(dt.Interval, arg) - unit = arg.output_dtype.unit.short - if units is not None and unit not in units: - msg = 'Interval unit `{}` is not among the allowed ones {}' - raise com.IbisTypeError(msg.format(unit, units)) - return arg - - -@public -@rule -def client(arg, **kwargs): - from ibis.backends.base import BaseBackend - - return instance_of(BaseBackend, arg) - - # --------------------------------------------------------------------- # Output type functions @@ -408,142 +165,41 @@ def _promote_interval_resolution(units: list[IntervalUnit]) -> IntervalUnit: raise AssertionError('unreachable') -# TODO(kszucs): it could be as simple as rlz.instance_of(ops.TableNode) -# we have a single test case testing the schema superset condition, not -# used anywhere else -@public -@rule -def table(arg, schema=None, **kwargs): - """A table argument. - - Parameters - ---------- - arg - A table node - schema - A validator for the table's columns. Only column subset validators are - currently supported. Accepts any arguments that `sch.schema` accepts. - See the example for usage. - kwargs - Keyword arguments - - The following op will accept an argument named `'table'`. Note that the - `schema` argument specifies rules for columns that are required to be in - the table: `time`, `group` and `value1`. These must match the types - specified in the column rules. Column `value2` is optional, but if present - it must be of the specified type. The table may have extra columns not - specified in the schema. - """ - import pandas as pd - - import ibis - import ibis.expr.operations as ops - - if isinstance(arg, pd.DataFrame): - arg = ibis.memtable(arg).op() - - if not isinstance(arg, ops.TableNode): - raise com.IbisTypeError( - f'Argument is not a table; got type {type(arg).__name__}' - ) - - if schema is not None: - if arg.schema >= sch.schema(schema): - return arg - - raise com.IbisTypeError( - f'Argument is not a table with column subset of {schema}' - ) - return arg - - -@public -@rule -def reduction(arg, **kwargs): - from ibis.expr.analysis import is_reduction - - if not is_reduction(arg): - raise com.IbisTypeError("`argument` must be a reduction") - - return arg - - -@public -@rule -def analytic(arg, **kwargs): - from ibis.expr.analysis import is_analytic +def _arg_type_error_format(op): + from ibis.expr.operations.generic import Literal - if not is_analytic(arg): - raise com.IbisInputError('Expression does not contain a valid window operation') + if isinstance(op, Literal): + return f"Literal({op.value}):{op.output_dtype}" + else: + return f"{op.name}:{op.output_dtype}" - return arg +class ValueOf(Matcher): + """Match a value of a specific type **instance**. -@public -@rule -def window_boundary(inner, arg, **kwargs): - import ibis.expr.operations as ops + This is different from the Value[T] annotations which construct + GenericCoercedTo(Value[T]) validators working with datatype types + rather than instances. - arg = inner(arg, **kwargs) - - if isinstance(arg, ops.WindowBoundary): - return arg - elif isinstance(arg, ops.Negate): - return ops.WindowBoundary(arg.arg, preceding=True) - elif isinstance(arg, ops.Literal): - new = arg.copy(value=abs(arg.value)) - return ops.WindowBoundary(new, preceding=arg.value < 0) - elif isinstance(arg, ops.Value): - return ops.WindowBoundary(arg, preceding=False) - else: - raise TypeError(f'Invalid window boundary type: {type(arg)}') - - -row_window_boundary = window_boundary(integer) -range_window_boundary = window_boundary(one_of([numeric, interval])) + Parameters + ---------- + dtype : DataType | None + The datatype the constructed Value instance should conform to. + """ + __slots__ = ('dtype',) -def _arg_type_error_format(op): - from ibis.expr.operations.generic import Literal + def __init__(self, dtype=None): + dtype = None if dtype is None else dt.dtype(dtype) + super().__init__(dtype) - if isinstance(op, Literal): - return f"Literal({op.value}):{op.output_dtype}" - else: - return f"{op.name}:{op.output_dtype}" + def match(self, value, context): + try: + value = ops.Value.__coerce__(value, self.dtype) + except CoercionError: + return NoMatch + if self.dtype and not value.output_dtype.castable(self.dtype): + return NoMatch -public( - any=any, - array=array, - bool=bool_, - boolean=boolean, - date=date, - decimal=decimal, - double=double, - floating=floating, - geospatial=geospatial, - integer=integer, - isin=isin, - json=json, - lazy_instance_of=lazy_instance_of, - linestring=linestring, - mapping=mapping, - multilinestring=multilinestring, - multipoint=multipoint, - numeric=numeric, - optional=optional, - point=point, - polygon=polygon, - ref=ref, - soft_numeric=soft_numeric, - str_=str_, - strict_numeric=strict_numeric, - string=string, - struct=struct, - temporal=temporal, - time=time, - timestamp=timestamp, - tuple_of=tuple_of, - row_window_boundary=row_window_boundary, - range_window_boundary=range_window_boundary, -) + return value diff --git a/ibis/expr/schema.py b/ibis/expr/schema.py index feda014f38bd..864ba76a398b 100644 --- a/ibis/expr/schema.py +++ b/ibis/expr/schema.py @@ -9,7 +9,7 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.common.exceptions import InputTypeError, IntegrityError from ibis.common.grounds import Concrete -from ibis.common.validators import Coercible +from ibis.common.patterns import Coercible from ibis.util import deprecated, indent if TYPE_CHECKING: @@ -49,6 +49,8 @@ def __getitem__(self, name: str) -> dt.DataType: @classmethod def __coerce__(cls, value) -> Schema: + if isinstance(value, cls): + return value return schema(value) @attribute.default diff --git a/ibis/expr/tests/test_datashape.py b/ibis/expr/tests/test_datashape.py new file mode 100644 index 000000000000..eb5af285921c --- /dev/null +++ b/ibis/expr/tests/test_datashape.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from ibis.expr.datashape import ( + Any, + Columnar, + DataShape, + Scalar, + Tabular, + columnar, + scalar, + tabular, +) + + +def test_scalar_shape(): + s = Scalar() + assert s.ndim == 0 + assert s.is_scalar() + assert not s.is_columnar() + assert not s.is_tabular() + + +def test_columnar_shape(): + c = Columnar() + assert c.ndim == 1 + assert not c.is_scalar() + assert c.is_columnar() + assert not c.is_tabular() + + +def test_tabular_shape(): + t = Tabular() + assert t.ndim == 2 + assert not t.is_scalar() + assert not t.is_columnar() + assert t.is_tabular() + + +def test_shapes_are_singletons(): + assert Scalar() is scalar + assert Scalar() is Scalar() + assert Columnar() is columnar + assert Columnar() is Columnar() + assert Tabular() is tabular + assert Tabular() is Tabular() + + +def test_shape_comparison(): + assert Scalar() < Columnar() + assert Scalar() <= Columnar() + assert Columnar() > Scalar() + assert Columnar() >= Scalar() + assert Scalar() != Columnar() + assert Scalar() == Scalar() + assert Columnar() == Columnar() + assert Tabular() == Tabular() + assert Tabular() != Columnar() + assert Tabular() != Scalar() + assert Tabular() > Columnar() + assert Tabular() > Scalar() + assert Tabular() >= Columnar() + assert Tabular() >= Scalar() + + +def test_shapes_are_hashable(): + assert hash(Scalar()) == hash(Scalar()) + assert hash(Columnar()) == hash(Columnar()) + assert hash(Tabular()) == hash(Tabular()) + assert hash(Scalar()) != hash(Columnar()) + assert hash(Scalar()) != hash(Tabular()) + assert hash(Columnar()) != hash(Tabular()) + assert len({Scalar(), Columnar(), Tabular()}) == 3 + + +def test_backward_compat_aliases(): + assert DataShape.SCALAR is scalar + assert DataShape.COLUMNAR is columnar + assert DataShape.TABULAR is tabular + + +def test_any_alias_for_datashape(): + # useful for typehints like `ds.Any` + assert DataShape is Any diff --git a/ibis/expr/tests/test_rules.py b/ibis/expr/tests/test_rules.py deleted file mode 100644 index 72fb0818f6e3..000000000000 --- a/ibis/expr/tests/test_rules.py +++ /dev/null @@ -1,292 +0,0 @@ -from __future__ import annotations - -import decimal - -import parsy -import pytest -from pytest import param -from toolz import identity - -import ibis -import ibis.expr.datatypes as dt -import ibis.expr.rules as rlz -import ibis.expr.types as ir -from ibis.common.exceptions import IbisTypeError - -table = ibis.table( - [('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')] -) - -similar_table = ibis.table( - [('int_col', 'int64'), ('string_col', 'string'), ('double_col', 'double')] -) - - -@pytest.mark.parametrize( - ('value', 'expected'), - [ - (dt.int32, dt.int32), - ('int64', dt.int64), - ('array', dt.Array(dt.string)), - (int, dt.int64), - (float, dt.float64), - ], -) -def test_valid_datatype(value, expected): - assert rlz.datatype(value) == expected - - -@pytest.mark.parametrize( - ('value', 'expected'), - [ - ('exception', parsy.ParseError), - ('array', parsy.ParseError), - ], -) -def test_invalid_datatype(value, expected): - with pytest.raises(expected): - assert rlz.datatype(value) - - -def test_string_literal_from_integer(): - lit = rlz.literal(dt.string, 1) - assert type(lit.value) is str - assert lit.value == "1" - - -@pytest.mark.parametrize( - ('klass', 'value', 'expected'), - [(int, 32, 32), (str, 'foo', 'foo'), (dt.Integer, dt.int8, dt.int8)], -) -def test_valid_instance_of(klass, value, expected): - assert rlz.instance_of(klass, value) == expected - - -@pytest.mark.parametrize( - ('klass', 'value', 'expected'), - [ - (ir.Table, object, IbisTypeError), - (ir.IntegerValue, 4, IbisTypeError), - ], -) -def test_invalid_instance_of(klass, value, expected): - with pytest.raises(expected): - assert rlz.instance_of(klass, value) - - -def test_lazy_instance_of(): - rule = rlz.lazy_instance_of("decimal.Decimal") - assert "decimal.Decimal" in repr(rule) - d = decimal.Decimal(1) - assert rule(d) == d - with pytest.raises(IbisTypeError, match=r"decimal\.Decimal"): - rule(1) - - -@pytest.mark.parametrize( - ('dtype', 'value', 'expected'), - [ - param(dt.int8, 26, ibis.literal(26)), - param(dt.int16, 26, ibis.literal(26)), - param(dt.int32, 26, ibis.literal(26)), - param(dt.int64, 26, ibis.literal(26)), - param(dt.uint8, 26, ibis.literal(26)), - param(dt.uint16, 26, ibis.literal(26)), - param(dt.uint32, 26, ibis.literal(26)), - param(dt.uint64, 26, ibis.literal(26)), - param(dt.float32, 26, ibis.literal(26)), - param(dt.float64, 26.4, ibis.literal(26.4)), - param(dt.double, 26.3, ibis.literal(26.3)), - param(dt.string, 'bar', ibis.literal('bar')), - param( - dt.Array(dt.float64), - [3.4, 5.6], - ibis.literal([3.4, 5.6]), - ), - param( - dt.Map(dt.string, dt.Array(dt.boolean)), - {'a': [True, False], 'b': [True]}, - ibis.literal({'a': [True, False], 'b': [True]}), - id='map_literal', - ), - ], -) -def test_valid_value(dtype, value, expected): - result = rlz.value(dtype, value) - assert result == expected.op() - - -@pytest.mark.parametrize( - ('dtype', 'value', 'expected'), - [ - (dt.uint8, -3, IbisTypeError), - (dt.int32, {}, IbisTypeError), - (dt.string, 1, IbisTypeError), - (dt.Array(dt.float64), ['s'], IbisTypeError), - ( - dt.Map(dt.string, dt.Array(dt.boolean)), - {'a': [True, False], 'b': ['B']}, - IbisTypeError, - ), - ], -) -def test_invalid_value(dtype, value, expected): - with pytest.raises(expected): - rlz.value(dtype, value) - - -@pytest.mark.parametrize( - ('validator', 'values', 'expected'), - [ - param( - rlz.tuple_of(rlz.integer), - (3, 2), - (ibis.literal(3), ibis.literal(2)), - id="tuple_int", - ), - param( - rlz.tuple_of(rlz.integer), - (3, None), - (ibis.literal(3), ibis.NA), - id="tuple_int_null", - ), - param( - rlz.tuple_of(rlz.string), - ('a',), - (ibis.literal('a'),), - id="tuple_string_one", - ), - param( - rlz.tuple_of(rlz.string), - ['a', 'b'], - (ibis.literal('a'), ibis.literal('b')), - id="tuple_string_two", - ), - param( - rlz.tuple_of(rlz.boolean, min_length=2), - [True, False], - (ibis.literal(True), ibis.literal(False)), - id="tuple_boolean", - ), - param( - rlz.tuple_of(rlz.string), - ["bar", table.string_col, "foo"], - (ibis.literal("bar"), table.string_col, ibis.literal("foo")), - ), - ], -) -def test_valid_tuple_of(validator, values, expected): - result = validator(values) - assert isinstance(result, tuple) - - -def test_valid_tuple_of_extra(): - validator = rlz.tuple_of(identity) - assert validator((3, 2)) == (3, 2) - - validator = rlz.tuple_of(rlz.tuple_of(rlz.string)) - result = validator([[], ['a']]) - assert result[1][0].equals(ibis.literal('a').op()) - - -@pytest.mark.parametrize( - ('validator', 'values'), - [ - (rlz.tuple_of(rlz.double, min_length=2), [1]), - (rlz.tuple_of(rlz.integer), 1.1), - (rlz.tuple_of(rlz.string), 'asd'), - (rlz.tuple_of(identity), 3), - ], -) -def test_invalid_tuple_of(validator, values): - with pytest.raises(IbisTypeError): - validator(values) - - -@pytest.mark.parametrize( - ('units', 'value', 'expected'), - [ - ({'H', 'D'}, ibis.interval(days=3), ibis.interval(days=3)), - (['Y'], ibis.interval(years=3), ibis.interval(years=3)), - ], -) -def test_valid_interval(units, value, expected): - result = rlz.interval(value, units=units) - assert result.equals(expected.op()) - - -@pytest.mark.parametrize( - ('units', 'value', 'expected'), - [ - ({'Y'}, ibis.interval(hours=1), IbisTypeError), - ({'Y', 'M', 'D'}, ibis.interval(hours=1), IbisTypeError), - ({'Q', 'W', 'D'}, ibis.interval(seconds=1), IbisTypeError), - ], -) -def test_invalid_interval(units, value, expected): - with pytest.raises(expected): - rlz.interval(value, units=units) - - -@pytest.mark.parametrize( - ('validator', 'value', 'expected'), - [ - (rlz.column(rlz.any), table.int_col, table.int_col), - (rlz.column(rlz.string), table.string_col, table.string_col), - (rlz.scalar(rlz.integer), ibis.literal(3), ibis.literal(3)), - (rlz.scalar(rlz.any), 'caracal', ibis.literal('caracal')), - ], -) -def test_valid_column_or_scalar(validator, value, expected): - result = validator(value) - assert result.equals(expected.op()) - - -@pytest.mark.parametrize( - ('validator', 'value', 'expected'), - [ - (rlz.column(rlz.integer), table.double_col, IbisTypeError), - (rlz.column(rlz.any), ibis.literal(3), IbisTypeError), - (rlz.column(rlz.integer), ibis.literal(3), IbisTypeError), - ], -) -def test_invalid_column_or_scalar(validator, value, expected): - with pytest.raises(expected): - validator(value) - - -@pytest.mark.parametrize( - 'table', - [ - ibis.table([('group', dt.int64), ('value', dt.double)]), - ibis.table([('group', dt.int64), ('value', dt.double), ('value2', dt.double)]), - ], -) -def test_table_with_schema(table): - validator = rlz.table(schema=[('group', dt.int64), ('value', dt.double)]) - assert validator(table) == table.op() - - -@pytest.mark.parametrize( - 'table', [ibis.table([('group', dt.int64), ('value', dt.timestamp)])] -) -def test_table_with_schema_invalid(table): - validator = rlz.table(schema=[('group', dt.double), ('value', dt.timestamp)]) - with pytest.raises(ValueError): - validator(table) - - -@pytest.mark.parametrize( - ('validator', 'input'), - [ - (rlz.tuple_of(rlz.integer), (3, 2)), - (rlz.instance_of(int), 32), - ], -) -def test_optional(validator, input): - expected = validator(input) - if isinstance(expected, ibis.Expr): - assert rlz.optional(validator).validate(input).equals(expected) - else: - assert rlz.optional(validator).validate(input) == expected - assert rlz.optional(validator).validate(None) is None diff --git a/ibis/expr/tests/test_schema.py b/ibis/expr/tests/test_schema.py index 4183eae019cd..755c28f8717b 100644 --- a/ibis/expr/tests/test_schema.py +++ b/ibis/expr/tests/test_schema.py @@ -10,10 +10,10 @@ import pytest import ibis.expr.datatypes as dt -import ibis.expr.rules as rlz import ibis.expr.schema as sch from ibis.common.exceptions import IntegrityError from ibis.common.grounds import Annotable +from ibis.common.patterns import CoercedTo has_pandas = False with contextlib.suppress(ImportError): @@ -338,7 +338,7 @@ class ObjectWithSchema(Annotable): def test_schema_is_coercible(): s = sch.Schema({'a': dt.int64, 'b': dt.Array(dt.int64)}) - assert rlz.coerced_to(sch.Schema, PreferenceA) == s + assert CoercedTo(sch.Schema).validate(PreferenceA, {}) == s o = ObjectWithSchema(schema=PreferenceA) assert o.schema == s diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index 9354a8301f4d..247ada3bdecf 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable, Iterable - +import functools from public import public import ibis.expr.operations as ops @@ -392,7 +392,12 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue: │ [] │ └───────────────────────┘ """ - return ops.ArrayMap(self, func=func).to_expr() + + @functools.wraps(func) + def wrapped(x): + return func(x.to_expr()) + + return ops.ArrayMap(self, func=wrapped).to_expr() def filter( self, predicate: Callable[[ir.Value], bool | ir.BooleanValue] @@ -435,7 +440,12 @@ def filter( │ [] │ └──────────────────────┘ """ - return ops.ArrayFilter(self, func=predicate).to_expr() + + @functools.wraps(predicate) + def wrapped(x): + return predicate(x.to_expr()) + + return ops.ArrayFilter(self, func=wrapped).to_expr() def contains(self, other: ir.Value) -> ir.BooleanValue: """Return whether the array contains `other`. diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index 917c021e228f..685b7419180c 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -12,7 +12,9 @@ from ibis.common.grounds import Immutable from ibis.config import _default_backend, options from ibis.util import experimental +from ibis.common.patterns import ValidationError, Coercible, CoercionError from rich.jupyter import JupyterMixin +from ibis.common.patterns import Coercible, CoercionError if TYPE_CHECKING: import pandas as pd @@ -35,7 +37,7 @@ def _repr_mimebundle_(self, *args, **kwargs): # TODO(kszucs): consider to subclass from Annotable with a single _arg field @public -class Expr(Immutable): +class Expr(Immutable, Coercible): """Base expression class.""" __slots__ = ("_arg",) @@ -43,6 +45,15 @@ class Expr(Immutable): def __init__(self, arg: ops.Node) -> None: object.__setattr__(self, "_arg", arg) + @classmethod + def __coerce__(cls, value): + if isinstance(value, cls): + return value + elif isinstance(value, ops.Node): + return value.to_expr() + else: + raise CoercionError("Unable to coerce value to an expression") + def __repr__(self) -> str: from ibis.expr.types.pretty import simple_console @@ -556,13 +567,13 @@ def _binop( >>> import ibis.expr.operations as ops >>> expr = _binop(ops.TimeAdd, ibis.time("01:00"), ibis.interval(hours=1)) >>> expr - TimeAdd('01:00', 1): '01:00' + 1 h + TimeAdd(datetime.time(1, 0), 1): datetime.time(1, 0) + 1 h >>> _binop(ops.TimeAdd, 1, ibis.interval(hours=1)) - NotImplemented + TimeAdd(datetime.time(0, 0, 1), 1): datetime.time(0, 0, 1) + 1 h """ try: node = op_class(left, right) - except (IbisTypeError, NotImplementedError): + except (ValidationError, NotImplementedError): return NotImplemented else: return node.to_expr() diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 3b2abccd778f..2e9dce63742b 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1662,12 +1662,19 @@ def literal(value: Any, type: dt.DataType | str | None = None) -> Scalar: ... TypeError: Value 'foobar' cannot be safely coerced to int64 """ - import ibis.expr.rules as rlz - if isinstance(value, Expr): - value = value.op() + node = value.op() + if not isinstance(node, ops.Literal): + raise TypeError(f"Ibis expression {value!r} is not a Literal") + if type is None or node.output_dtype.castable(dt.dtype(type)): + return value + else: + raise TypeError( + f"Ibis literal {value!r} cannot be safely coerced to datatype {type}" + ) - return rlz.literal(type, value).to_expr() + dtype = dt.infer(value) if type is None else dt.dtype(type) + return ops.Literal(value, dtype=dtype).to_expr() public( diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index afb5474912e7..e2ec992c772e 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -205,7 +205,6 @@ def mutate( Table A table expression with window functions applied """ - exprs = self._selectables(*exprs, **kwexprs) return self.table.mutate(exprs) diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index bcad9912ac17..76aa268ad176 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1108,50 +1108,36 @@ def convert_base( def __and__(self, other: IntegerValue) -> IntegerValue | NotImplemented: """Bitwise and `self` with `other`.""" - from ibis.expr import operations as ops - return _binop(ops.BitwiseAnd, self, other) __rand__ = __and__ def __or__(self, other: IntegerValue) -> IntegerValue | NotImplemented: """Bitwise or `self` with `other`.""" - from ibis.expr import operations as ops - return _binop(ops.BitwiseOr, self, other) __ror__ = __or__ def __xor__(self, other: IntegerValue) -> IntegerValue | NotImplemented: """Bitwise xor `self` with `other`.""" - from ibis.expr import operations as ops - return _binop(ops.BitwiseXor, self, other) __rxor__ = __xor__ def __lshift__(self, other: IntegerValue) -> IntegerValue | NotImplemented: """Bitwise left shift `self` with `other`.""" - from ibis.expr import operations as ops - return _binop(ops.BitwiseLeftShift, self, other) def __rlshift__(self, other: IntegerValue) -> IntegerValue | NotImplemented: """Bitwise left shift `self` with `other`.""" - from ibis.expr import operations as ops - return _binop(ops.BitwiseLeftShift, other, self) def __rshift__(self, other: IntegerValue) -> IntegerValue | NotImplemented: """Bitwise right shift `self` with `other`.""" - from ibis.expr import operations as ops - return _binop(ops.BitwiseRightShift, self, other) def __rrshift__(self, other: IntegerValue) -> IntegerValue | NotImplemented: """Bitwise right shift `self` with `other`.""" - from ibis.expr import operations as ops - return _binop(ops.BitwiseRightShift, other, self) def __invert__(self) -> IntegerValue: @@ -1162,8 +1148,6 @@ def __invert__(self) -> IntegerValue: IntegerValue Inverted bits of `self`. """ - from ibis.expr import operations as ops - try: node = ops.BitwiseNot(self) except (IbisTypeError, NotImplementedError): diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 909959b045d1..cc95f37d77b6 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -20,6 +20,7 @@ from ibis import util from ibis.expr.deferred import Deferred from ibis.expr.types.core import Expr, _FixedTextJupyterMixin +from ibis.expr.types.generic import literal if TYPE_CHECKING: import pandas as pd @@ -35,19 +36,20 @@ def _ensure_expr(table, expr): - import ibis.expr.rules as rlz from ibis.selectors import Selector # This is different than self._ensure_expr, since we don't want to # treat `str` or `int` values as column indices - if util.is_function(expr): + if isinstance(expr, Expr): + return expr + elif util.is_function(expr): return expr(table) elif isinstance(expr, Deferred): return expr.resolve(table) elif isinstance(expr, Selector): return expr.expand(table) else: - return rlz.any(expr).to_expr() + return literal(expr) def _regular_join_method( @@ -3649,6 +3651,7 @@ def _resolve_predicates( import ibis.expr.analysis as an import ibis.expr.types as ir + # TODO(kszucs): clean this up, too much flattening and resolving happens here predicates = [ pred.op() for preds in map( diff --git a/ibis/expr/types/strings.py b/ibis/expr/types/strings.py index 333ea198d40a..d1b11cc80d41 100644 --- a/ibis/expr/types/strings.py +++ b/ibis/expr/types/strings.py @@ -12,7 +12,7 @@ from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: - from ibis.expr import types as ir + import ibis.expr.types as ir @public diff --git a/ibis/expr/types/temporal.py b/ibis/expr/types/temporal.py index defa41bcfb03..bb6af3c2baec 100644 --- a/ibis/expr/types/temporal.py +++ b/ibis/expr/types/temporal.py @@ -1,22 +1,25 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Union from public import public -if TYPE_CHECKING: - import pandas as pd - - import ibis.expr.types as ir - +import ibis.expr.datashape as ds import ibis.expr.operations as ops from ibis.expr.types.core import _binop from ibis.expr.types.generic import Column, Scalar, Value -from ibis import util +from ibis.common.annotations import annotated +from ibis.common.patterns import Pattern import ibis.expr.datatypes as dt +from ibis import util from ibis.common.temporal import IntervalUnit +if TYPE_CHECKING: + import pandas as pd + + import ibis.expr.types as ir + @public class TemporalValue(Value): @@ -161,8 +164,6 @@ def between( Whether `self` is between `lower` and `upper`, adjusting `timezone` as needed. """ - import ibis.expr.datatypes as dt - op = self.op() if isinstance(op, ops.Time): # Here we pull out the first argument to the underlying Time @@ -226,14 +227,9 @@ def __add__( Value : TimeValue | NotImplemented """ - def __sub__( - self, - other: TimeValue | IntervalValue, - ) -> IntervalValue | TimeValue | NotImplemented: + @annotated + def __sub__(self, other: ops.Value[dt.Interval | dt.Time, ds.Any]): """Subtract a time or an interval from a time expression.""" - import ibis.expr.rules as rlz - - other = rlz.any(other) if other.output_dtype.is_time(): op = ops.TimeDiff @@ -255,14 +251,9 @@ def __sub__( Value : IntervalValue | TimeValue | NotImplemented """ - def __rsub__( - self, - other: TimeValue | IntervalValue, - ) -> IntervalValue | TimeValue | NotImplemented: + @annotated + def __rsub__(self, other: ops.Value[dt.Interval | dt.Time, ds.Any]): """Subtract a time or an interval from a time expression.""" - import ibis.expr.rules as rlz - - other = rlz.any(other) if other.output_dtype.is_time(): op = ops.TimeDiff @@ -321,18 +312,9 @@ def __add__( Value : DateValue | NotImplemented """ - def __sub__( - self, - other: datetime.date - | DateValue - | datetime.timedelta - | pd.Timedelta - | IntervalValue, - ) -> IntervalValue | DateValue | NotImplemented: + @annotated + def __sub__(self, other: ops.Value[dt.Date | dt.Interval, ds.Any]): """Subtract a date or an interval from a date.""" - import ibis.expr.rules as rlz - - other = rlz.one_of([rlz.date, rlz.interval], other) if other.output_dtype.is_date(): op = ops.DateDiff @@ -354,18 +336,9 @@ def __sub__( Value : DateValue | NotImplemented """ - def __rsub__( - self, - other: datetime.date - | DateValue - | datetime.timedelta - | pd.Timedelta - | IntervalValue, - ) -> IntervalValue | DateValue | NotImplemented: + @annotated + def __rsub__(self, other: ops.Value[dt.Date | dt.Interval, ds.Any]): """Subtract a date or an interval from a date.""" - import ibis.expr.rules as rlz - - other = rlz.one_of([rlz.date, rlz.interval], other) if other.output_dtype.is_date(): op = ops.DateDiff @@ -437,21 +410,11 @@ def __add__( Value : TimestampValue | NotImplemented """ - def __sub__( - self, - other: datetime.datetime - | pd.Timestamp - | TimestampValue - | datetime.timedelta - | pd.Timedelta - | IntervalValue, - ) -> IntervalValue | TimestampValue | NotImplemented: + @annotated + def __sub__(self, other: ops.Value[dt.Timestamp | dt.Interval, ds.Any]): """Subtract a timestamp or an interval from a timestamp.""" - import ibis.expr.rules as rlz - - right = rlz.any(other) - if right.output_dtype.is_timestamp(): + if other.output_dtype.is_timestamp(): op = ops.TimestampDiff else: op = ops.TimestampSub # let the operation validate @@ -471,21 +434,11 @@ def __sub__( Value : IntervalValue | TimestampValue | NotImplemented """ - def __rsub__( - self, - other: datetime.datetime - | pd.Timestamp - | TimestampValue - | datetime.timedelta - | pd.Timedelta - | IntervalValue, - ) -> IntervalValue | TimestampValue | NotImplemented: + @annotated + def __rsub__(self, other: ops.Value[dt.Timestamp | dt.Interval, ds.Any]): """Subtract a timestamp or an interval from a timestamp.""" - import ibis.expr.rules as rlz - - right = rlz.any(other) - if right.output_dtype.is_timestamp(): + if other.output_dtype.is_timestamp(): op = ops.TimestampDiff else: op = ops.TimestampSub # let the operation validate diff --git a/ibis/formats/tests/test_numpy.py b/ibis/formats/tests/test_numpy.py index ad22b2402d1f..67fa79e9629a 100644 --- a/ibis/formats/tests/test_numpy.py +++ b/ibis/formats/tests/test_numpy.py @@ -133,4 +133,4 @@ def test_dtype_from_numpy_dtype_timedelta(): if vparse(pytest.importorskip("pyarrow").__version__) < vparse("9"): pytest.skip("pyarrow < 9 globally mutates the timedelta64 numpy dtype") - assert NumpyType.to_ibis(np.dtype(np.timedelta64)) == dt.Interval("s") + assert NumpyType.to_ibis(np.dtype(np.timedelta64)) == dt.Interval(unit='s') diff --git a/ibis/legacy/udf/validate.py b/ibis/legacy/udf/validate.py index fa46c8fea434..19695579d25c 100644 --- a/ibis/legacy/udf/validate.py +++ b/ibis/legacy/udf/validate.py @@ -8,12 +8,10 @@ from __future__ import annotations from inspect import Parameter, Signature, signature -from typing import TYPE_CHECKING, Any, Callable +from typing import Any, Callable import ibis.common.exceptions as com - -if TYPE_CHECKING: - from ibis.expr.datatypes import DataType +import ibis.expr.datatypes as dt def _parameter_count(funcsig: Signature) -> int: @@ -37,7 +35,7 @@ def _parameter_count(funcsig: Signature) -> int: ) -def validate_input_type(input_type: list[DataType], func: Callable) -> Signature: +def validate_input_type(input_type: list[dt.DataType], func: Callable) -> Signature: """Check that the declared number of inputs and signature of func are compatible. If the signature of `func` uses *args, then no check is done (since no diff --git a/ibis/selectors.py b/ibis/selectors.py index 2c36dd1fc6f1..9ff14eba2a11 100644 --- a/ibis/selectors.py +++ b/ibis/selectors.py @@ -62,7 +62,7 @@ from ibis.common.annotations import attribute from ibis.common.collections import frozendict from ibis.common.grounds import Concrete, Singleton -from ibis.common.validators import Coercible +from ibis.common.patterns import Coercible from ibis.expr.deferred import Deferred diff --git a/ibis/tests/expr/test_analytics.py b/ibis/tests/expr/test_analytics.py index 7105abe732e5..331978c18409 100644 --- a/ibis/tests/expr/test_analytics.py +++ b/ibis/tests/expr/test_analytics.py @@ -17,6 +17,7 @@ import ibis import ibis.expr.types as ir +from ibis.common.patterns import ValidationError from ibis.tests.expr.mocks import MockBackend from ibis.tests.util import assert_equal @@ -66,22 +67,22 @@ def test_bucket(alltypes): def test_bucket_error_cases(alltypes): d = alltypes.double_col - with pytest.raises(ValueError): + with pytest.raises(ValidationError): d.bucket([]) - with pytest.raises(ValueError): + with pytest.raises(ValidationError): d.bucket([1, 2], closed="foo") # it works! d.bucket([10], include_under=True, include_over=True) - with pytest.raises(ValueError): + with pytest.raises(ValidationError): d.bucket([10]) - with pytest.raises(ValueError): + with pytest.raises(ValidationError): d.bucket([10], include_under=True) - with pytest.raises(ValueError): + with pytest.raises(ValidationError): d.bucket([10], include_over=True) diff --git a/ibis/tests/expr/test_decimal.py b/ibis/tests/expr/test_decimal.py index 4d1eaeade1b6..0190ffd59b11 100644 --- a/ibis/tests/expr/test_decimal.py +++ b/ibis/tests/expr/test_decimal.py @@ -7,6 +7,7 @@ import ibis import ibis.expr.datatypes as dt import ibis.expr.types as ir +from ibis.common.patterns import ValidationError from ibis.expr import api @@ -133,7 +134,7 @@ def test_invalid_precision_scale_combo(precision, scale): [(38.1, 3), (38, 3.1)], # non integral precision # non integral scale ) def test_invalid_precision_scale_type(precision, scale): - with pytest.raises(TypeError): + with pytest.raises(ValidationError): dt.Decimal(precision, scale) diff --git a/ibis/tests/expr/test_literal.py b/ibis/tests/expr/test_literal.py index 8fea1974f5be..045ce26f4074 100644 --- a/ibis/tests/expr/test_literal.py +++ b/ibis/tests/expr/test_literal.py @@ -160,7 +160,7 @@ def test_map_literal_non_castable(value): def test_literal_mixed_type_fails(): data = [1, 'a'] - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="Cannot compute precedence"): ibis.literal(data) diff --git a/ibis/tests/expr/test_operations.py b/ibis/tests/expr/test_operations.py index ffeabbbc8e46..7d29cd67f2d1 100644 --- a/ibis/tests/expr/test_operations.py +++ b/ibis/tests/expr/test_operations.py @@ -1,14 +1,17 @@ from __future__ import annotations +from typing import Optional, Tuple + import numpy as np import pytest import ibis +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.rules as rlz import ibis.expr.types as ir -from ibis.common.exceptions import IbisTypeError +from ibis.common.patterns import ValidationError t = ibis.table([('a', 'int64')], name='t') @@ -17,77 +20,74 @@ two = ir.literal(2) three = ir.literal(3) -operations = [ - ops.Cast(three, to='int64'), - ops.TypeOf(arg=2), - ops.Negate(4), - ops.Negate(4.0), - ops.NullIfZero(0), - ops.NullIfZero(1), - ops.IsNull(ir.null()), - ops.NotNull(ir.null()), - ops.ZeroIfNull(ir.null()), - ops.IfNull(1, ops.NullIfZero(0).to_expr()), - ops.NullIf(ir.null(), ops.NullIfZero(0).to_expr()), - ops.IsNan(np.nan), - ops.IsInf(np.inf), - ops.Ceil(4.5), - ops.Floor(4.5), - ops.Round(3.43456), - ops.Round(3.43456, 2), - ops.Round(3.43456, digits=1), - ops.Clip(123, lower=30), - ops.Clip(123, lower=30, upper=100), - ops.BaseConvert('EEE', from_base=16, to_base=10), - ops.Logarithm(100), - ops.Log(100), - ops.Log(100, base=2), - ops.Ln(100), - ops.Log2(100), - ops.Log10(100), - ops.Uppercase('asd'), - ops.Lowercase('asd'), - ops.Reverse('asd'), - ops.Strip('asd'), - ops.LStrip('asd'), - ops.RStrip('asd'), - ops.Capitalize('asd'), - ops.Substring('asd', start=1), - ops.Substring('asd', 1), - ops.Substring('asd', 1, length=2), - ops.StrRight('asd', nchars=2), - ops.Repeat('asd', times=4), - ops.StringFind('asd', 'sd', start=1), - ops.Translate('asd', from_str='bd', to_str='ce'), - ops.LPad('asd', length=2, pad='ss'), - ops.RPad('asd', length=2, pad='ss'), - ops.StringJoin(',', ['asd', 'bsdf']), - ops.FuzzySearch('asd', pattern='n'), - ops.StringSQLLike('asd', pattern='as', escape='asd'), - ops.RegexExtract('asd', pattern='as', index=1), - ops.RegexReplace('asd', 'as', 'a'), - ops.StringReplace('asd', 'as', 'a'), - ops.StringSplit('asd', 's'), - ops.StringConcat(('s', 'e')), - ops.StartsWith('asd', 'as'), - ops.EndsWith('asd', 'xyz'), - ops.Not(false), - ops.And(false, true), - ops.Or(false, true), - ops.GreaterEqual(three, two), - ops.Sum(t.a), - t.a.op(), -] - - -@pytest.fixture(scope='module', params=operations) -def op(request): - with pytest.warns(FutureWarning): - # .op().op() deprecated, do not use for new code - r = request.param.op() - assert r is request.param - - return request.param + +@pytest.fixture(scope='module') +def operations(request): + true = ir.literal(True) + false = ir.literal(False) + two = ir.literal(2) + three = ir.literal(3) + return [ + ops.Cast(three, to='int64'), + ops.TypeOf(arg=2), + ops.Negate(4), + ops.Negate(4.0), + ops.NullIfZero(0), + ops.NullIfZero(1), + ops.IsNull(ir.null()), + ops.NotNull(ir.null()), + ops.ZeroIfNull(ir.null()), + ops.IfNull(1, ops.NullIfZero(0).to_expr()), + ops.NullIf(ir.null(), ops.NullIfZero(0).to_expr()), + ops.IsNan(np.nan), + ops.IsInf(np.inf), + ops.Ceil(4.5), + ops.Floor(4.5), + ops.Round(3.43456), + ops.Round(3.43456, 2), + ops.Round(3.43456, digits=1), + ops.Clip(123, lower=30), + ops.Clip(123, lower=30, upper=100), + ops.BaseConvert('EEE', from_base=16, to_base=10), + ops.Logarithm(100), + ops.Log(100), + ops.Log(100, base=2), + ops.Ln(100), + ops.Log2(100), + ops.Log10(100), + ops.Uppercase('asd'), + ops.Lowercase('asd'), + ops.Reverse('asd'), + ops.Strip('asd'), + ops.LStrip('asd'), + ops.RStrip('asd'), + ops.Capitalize('asd'), + ops.Substring('asd', start=1), + ops.Substring('asd', 1), + ops.Substring('asd', 1, length=2), + ops.StrRight('asd', nchars=2), + ops.Repeat('asd', times=4), + ops.StringFind('asd', 'sd', start=1), + ops.Translate('asd', from_str='bd', to_str='ce'), + ops.LPad('asd', length=2, pad='ss'), + ops.RPad('asd', length=2, pad='ss'), + ops.StringJoin(',', ['asd', 'bsdf']), + ops.FuzzySearch('asd', pattern='n'), + ops.StringSQLLike('asd', pattern='as', escape='asd'), + ops.RegexExtract('asd', pattern='as', index=1), + ops.RegexReplace('asd', 'as', 'a'), + ops.StringReplace('asd', 'as', 'a'), + ops.StringSplit('asd', 's'), + ops.StringConcat(('s', 'e')), + ops.StartsWith('asd', 'as'), + ops.EndsWith('asd', 'xyz'), + ops.Not(false), + ops.And(false, true), + ops.Or(false, true), + ops.GreaterEqual(three, two), + ops.Sum(t.a), + t.a.op(), + ] class Expr: @@ -110,7 +110,7 @@ class NamedValue(Base): class Values(Base): - lst = rlz.tuple_of(rlz.instance_of(ops.Node)) + lst: Tuple[ops.Node, ...] one = NamedValue(value=1, name=Name("one")) @@ -170,13 +170,44 @@ class Aliased(Base): assert expected == new_values +def test_value_annotations(): + class Op1(ops.Value): + arg: ops.Value + + output_dtype = dt.int64 + output_shape = ds.scalar + + class Op2(ops.Value): + arg: ops.Value[dt.Any, ds.Any] + + output_dtype = dt.int64 + output_shape = ds.scalar + + class Op3(ops.Value): + arg: ops.Value[dt.Integer, ds.Any] + + output_dtype = dt.int64 + output_shape = ds.scalar + + class Op4(ops.Value): + arg: ops.Value[dt.Integer, ds.Scalar] + + output_dtype = dt.int64 + output_shape = ds.scalar + + assert Op1(1).arg.dtype == dt.int8 + assert Op2(1).arg.dtype == dt.int8 + assert Op3(1).arg.dtype == dt.int8 + assert Op4(1).arg.dtype == dt.int8 + + def test_operation(): class Logarithm(ir.Expr): pass class Log(ops.Node): - arg = rlz.double() - base = rlz.optional(rlz.double()) + arg: ops.Value[dt.Float64, ds.Any] + base: Optional[ops.Value[dt.Float64, ds.Any]] = None def to_expr(self): return Logarithm(self) @@ -188,27 +219,28 @@ def to_expr(self): assert isinstance(Log(arg=100).to_expr(), Logarithm) -def test_operation_nodes_are_slotted(op): - assert hasattr(op, "__slots__") - assert not hasattr(op, "__dict__") +def test_operation_nodes_are_slotted(operations): + for op in operations: + assert hasattr(op, "__slots__") + assert not hasattr(op, "__dict__") def test_instance_of_operation(): class MyOperation(ops.Node): - arg = rlz.instance_of(ir.IntegerValue) + arg: ir.IntegerValue def to_expr(self): return ir.IntegerScalar(self) MyOperation(ir.literal(5)) - with pytest.raises(IbisTypeError): + with pytest.raises(ValidationError): MyOperation(ir.literal('string')) def test_array_input(): class MyOp(ops.Value): - value = rlz.value(dt.Array(dt.double)) + value: ops.Value[dt.Array[dt.Float64], ds.Any] output_dtype = rlz.dtype_like('value') output_shape = rlz.shape_like('value') @@ -236,7 +268,7 @@ def to_expr(self): @pytest.fixture(scope='session') def dummy_op(): class DummyOp(ops.Value): - arg = rlz.any + arg: ops.Value return DummyOp @@ -272,12 +304,12 @@ def test_expression_class_aliases(): def test_sortkey_propagates_dtype_and_shape(): k = ops.SortKey(ibis.literal(1), ascending=True) assert k.output_dtype == dt.int8 - assert k.output_shape == rlz.Shape.SCALAR + assert k.output_shape.is_scalar() t = ibis.table([('a', 'int16')], name='t') k = ops.SortKey(t.a, ascending=True) assert k.output_dtype == dt.int16 - assert k.output_shape == rlz.Shape.COLUMNAR + assert k.output_shape.is_columnar() def test_getitem_on_column_is_error(): diff --git a/ibis/tests/expr/test_set_operations.py b/ibis/tests/expr/test_set_operations.py index f9e88ebc5d09..ae4cd35bf800 100644 --- a/ibis/tests/expr/test_set_operations.py +++ b/ibis/tests/expr/test_set_operations.py @@ -3,7 +3,6 @@ import pytest import ibis -import ibis.expr.operations as ops from ibis.common.exceptions import RelationError @@ -62,5 +61,5 @@ def test_operation_supports_schemas_with_different_field_order(method): assert u2.schema == a.schema() assert u2.left == a.op() - columns = [ops.TableColumn(c, name) for name in 'abc'] - assert u2.right == ops.Selection(c.op(), columns) + reprojected = c.select(['a', 'b', 'c']) + assert u2.right == reprojected.op() diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index bd44049bb89e..58150ed5c065 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -21,6 +21,7 @@ from ibis import _ from ibis import literal as L from ibis.common.exceptions import RelationError +from ibis.common.patterns import ValidationError from ibis.expr import api from ibis.expr.types import Column, Table from ibis.tests.expr.mocks import MockAlchemyBackend, MockBackend @@ -468,10 +469,11 @@ def test_order_by(table): key3 = result3.op().sort_keys[0] key4 = result4.op().sort_keys[0] - assert key2.descending - assert key3.descending - assert key4.descending - assert_equal(result2, result3) + assert key2.descending is True + assert key3.descending is True + assert key4.descending is True + assert key2.expr.equals(key3.expr) + assert key2.expr.equals(key4.expr) def test_order_by_desc_deferred_sort_key(table): @@ -678,7 +680,7 @@ def test_aggregate_post_predicate(table, case_fn): by = ['g'] having = [case_fn(table)] - with pytest.raises(com.IbisTypeError): + with pytest.raises(ValidationError): table.aggregate(metrics, by=by, having=having) @@ -1105,7 +1107,7 @@ def test_join_invalid_expr_type(con): invalid_right = left.foo_id join_key = ['bar_id'] - with pytest.raises(com.IbisTypeError, match="Argument is not a table"): + with pytest.raises(ValidationError): left.inner_join(invalid_right, join_key) @@ -1699,7 +1701,7 @@ def test_filter_with_literal(value, api): # ints are invalid predicates int_val = ibis.literal(int(value)) - with pytest.raises((NotImplementedError, com.IbisTypeError)): + with pytest.raises((NotImplementedError, ValidationError, com.IbisTypeError)): api(t, int_val) @@ -1819,7 +1821,7 @@ def test_pivot_wider(): def test_invalid_deferred(): t = ibis.table(dict(value="int", lagged_value="int"), name="t") - with pytest.raises(com.IbisTypeError, match="Deferred input is not allowed"): + with pytest.raises(ValidationError, match="doesn't match"): ibis.greatest(t.value, ibis._.lagged_value) diff --git a/ibis/tests/expr/test_udf.py b/ibis/tests/expr/test_udf.py index cdb556adfe5f..a13d4c6a89e7 100644 --- a/ibis/tests/expr/test_udf.py +++ b/ibis/tests/expr/test_udf.py @@ -3,10 +3,10 @@ import pytest import ibis -import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir +from ibis.common.patterns import ValidationError @pytest.fixture @@ -44,7 +44,7 @@ def test_vectorized_udf_operations(table, klass, output_type): assert isinstance(udf.to_expr(), output_type) - with pytest.raises(com.IbisTypeError): + with pytest.raises(ValidationError): # wrong function type klass( func=1, @@ -53,7 +53,7 @@ def test_vectorized_udf_operations(table, klass, output_type): return_type=dt.int8(), ) - with pytest.raises(com.IbisTypeError): + with pytest.raises(ValidationError): # scalar type instead of column type klass( func=lambda a, b, c: a, @@ -62,7 +62,7 @@ def test_vectorized_udf_operations(table, klass, output_type): return_type=dt.int8(), ) - with pytest.raises(com.IbisTypeError): + with pytest.raises(ValidationError): # wrong input type klass( func=lambda a, b, c: a, @@ -71,7 +71,7 @@ def test_vectorized_udf_operations(table, klass, output_type): return_type=dt.int8(), ) - with pytest.raises(com.IbisTypeError): + with pytest.raises(ValidationError): # wrong return type klass( func=lambda a, b, c: a, diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 86f1e8b4c17f..676b445398a5 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -17,13 +17,14 @@ import ibis import ibis.common.exceptions as com import ibis.expr.analysis as an +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops -import ibis.expr.rules as rlz import ibis.expr.types as ir from ibis import _, literal from ibis.common.collections import frozendict from ibis.common.exceptions import IbisTypeError +from ibis.common.patterns import ValidationError from ibis.expr import api from ibis.tests.util import assert_equal @@ -33,12 +34,6 @@ def test_null(): assert ibis.null().op() == ops.Literal(None, dtype=dt.null) -def test_literal_mixed_type_fails(): - data = [1, 'a'] - with pytest.raises(TypeError): - ibis.literal(data) - - @pytest.mark.parametrize( ['value', 'expected_type'], [ @@ -395,7 +390,7 @@ def test_log(table, log, column): def test_log_string(table): g = table.g - with pytest.raises(IbisTypeError): + with pytest.raises(ValidationError): ops.Log(g, None).to_expr() @@ -403,14 +398,14 @@ def test_log_string(table): def test_log_variants_string(table, klass): g = table.g - with pytest.raises(IbisTypeError): + with pytest.raises(ValidationError): klass(g).to_expr() def test_log_boolean(table, log): # boolean not implemented for these h = table['h'] - with pytest.raises(IbisTypeError): + with pytest.raises(ValidationError): log(h) @@ -623,18 +618,18 @@ def test_null_column_union(): def test_string_compare_numeric_array(table): - with pytest.raises(TypeError): + with pytest.raises(com.IbisTypeError): table.g == table.f # noqa: B015 - with pytest.raises(TypeError): + with pytest.raises(com.IbisTypeError): table.g == table.c # noqa: B015 def test_string_compare_numeric_literal(table): - with pytest.raises(TypeError): + with pytest.raises(com.IbisTypeError): table.g == ibis.literal(1.5) # noqa: B015 - with pytest.raises(TypeError): + with pytest.raises(com.IbisTypeError): table.g == ibis.literal(5) # noqa: B015 @@ -655,13 +650,13 @@ def test_between(table): assert isinstance(result, ir.BooleanScalar) # Cases where between should immediately fail, e.g. incomparables - with pytest.raises(TypeError): + with pytest.raises(ValidationError): table.f.between('0', '1') - with pytest.raises(TypeError): + with pytest.raises(ValidationError): table.f.between(0, '1') - with pytest.raises(TypeError): + with pytest.raises(ValidationError): table.f.between('0', 1) @@ -679,7 +674,7 @@ def test_binop_string_type_error(table, operation, left, right): a = table[left] b = table[right] - with pytest.raises(TypeError): + with pytest.raises((TypeError, ValidationError)): operation(a, b) @@ -1031,53 +1026,23 @@ def test_scalar_parameter_invalid_compare(left, right): ) def test_between_time_failure_time(case, creator, left, right): value = creator(case) - with pytest.raises(TypeError): + with pytest.raises(ValidationError): value.between(left, right) -def test_custom_type_binary_operations(): - class Foo(ir.Expr): - def __add__(self, other): - op = self.op() - return type(op)(op.value + other).to_expr() - - __radd__ = __add__ - - class FooNode(ops.Node): - value = rlz.integer - - def to_expr(self): - return Foo(self) - - left = ibis.literal(2) - right = FooNode(3).to_expr() - result = left + right - assert isinstance(result, Foo) - assert isinstance(result.op(), FooNode) - - left = FooNode(3).to_expr() - right = ibis.literal(2) - result = left + right - assert isinstance(result, Foo) - assert isinstance(result.op(), FooNode) - - def test_empty_array_as_argument(): class Foo(ir.Expr): pass class FooNode(ops.Node): - value = rlz.value(dt.Array(dt.int64)) + value: ops.Value[dt.Array[dt.Int64], ds.Any] def to_expr(self): return Foo(self) node = FooNode([]) - value = node.value - expected = literal([]).cast(dt.Array(dt.int64)).op() - - assert value.output_dtype.equals(dt.Array(dt.null)) - assert ops.Cast(value, dt.Array(dt.int64)).equals(expected) + assert node.value.value == () + assert node.value.dtype == dt.Array(dt.Int64) def test_nullable_column_propagated(): @@ -1150,6 +1115,7 @@ def test_timestamp_timezone_type(tz): def test_map_get_broadcast(): t = ibis.table([('a', 'string')], name='t') lookup_table = ibis.literal({'a': 1, 'b': 2}) + expr = lookup_table.get(t.a) assert isinstance(expr, ir.IntegerColumn) @@ -1234,7 +1200,7 @@ def test_map_get_with_null_on_null_type_with_non_null(): def test_map_get_with_incompatible_value(): value = ibis.literal({'A': 1000, 'B': 2000}) - with pytest.raises(IbisTypeError): + with pytest.raises(TypeError): value.get('C', ['A']) @@ -1538,6 +1504,10 @@ def double_float(x): return x * 2.0 +def is_negative(x): + return x < 0 + + def test_array_map(): arr = ibis.array([1, 2, 3]) @@ -1548,6 +1518,12 @@ def test_array_map(): assert result_float.type() == dt.Array(dt.float64) +def test_array_filter(): + arr = ibis.array([1, 2, 3]) + result = arr.filter(is_negative) + assert result.type() == arr.type() + + @pytest.mark.parametrize( ("func", "expected_type"), [ @@ -1601,12 +1577,12 @@ def test_numpy_ufuncs_dont_cast_columns(): [operator.lt, operator.gt, operator.ge, operator.le, operator.eq, operator.ne], ) def test_logical_comparison_rlz_incompatible_error(table, operation): - with pytest.raises(TypeError, match=r"b:int16 and Literal\(foo\):string"): + with pytest.raises(com.IbisTypeError, match=r"b:int16 and Literal\(foo\):string"): operation(table.b, "foo") def test_case_rlz_incompatible_error(table): - with pytest.raises(TypeError, match=r"a:int8 and Literal\(foo\):string"): + with pytest.raises(com.IbisTypeError, match=r"a:int8 and Literal\(foo\):string"): table.a == 'foo' # noqa: B015 diff --git a/ibis/tests/expr/test_visualize.py b/ibis/tests/expr/test_visualize.py index f7e59560de7a..1478ec9df6c4 100644 --- a/ibis/tests/expr/test_visualize.py +++ b/ibis/tests/expr/test_visualize.py @@ -5,8 +5,9 @@ import pytest import ibis +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt import ibis.expr.operations as ops -import ibis.expr.rules as rlz import ibis.expr.types as ir pytest.importorskip('graphviz') @@ -47,8 +48,8 @@ class MyExpr(ir.Expr): pass class MyExprNode(ops.Node): - foo = rlz.string - bar = rlz.numeric + foo: ops.Value[dt.String, ds.Any] + bar: ops.Value[dt.Numeric, ds.Any] def to_expr(self): return MyExpr(self) @@ -68,8 +69,8 @@ def schema(self): raise NotImplementedError class MyExprNode(ops.Node): - foo = rlz.string - bar = rlz.numeric + foo: ops.Value[dt.String, ds.Any] + bar: ops.Value[dt.Numeric, ds.Any] def to_expr(self): return MyExpr(self) diff --git a/ibis/tests/expr/test_window_frames.py b/ibis/tests/expr/test_window_frames.py index 5095a3f453e8..b3979f128cef 100644 --- a/ibis/tests/expr/test_window_frames.py +++ b/ibis/tests/expr/test_window_frames.py @@ -7,8 +7,49 @@ import ibis import ibis.expr.builders as bl +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.exceptions import IbisInputError, IbisTypeError +from ibis.common.patterns import Pattern, ValidationError + + +def test_window_boundary(): + # the boundary value must be either numeric or interval + b = ops.WindowBoundary(5, preceding=False) + assert b.value == ops.Literal(5, dtype=dt.int8) + + b = ops.WindowBoundary(3.12, preceding=True) + assert b.value == ops.Literal(3.12, dtype=dt.double) + + oneday = ops.Literal(1, dtype=dt.Interval('D')) + b = ops.WindowBoundary(oneday, preceding=False) + assert b.value == oneday + + with pytest.raises(ValidationError): + ops.WindowBoundary('foo', preceding=True) + + +def test_window_boundary_typevars(): + lit = ops.Literal(1, dtype=dt.Interval('D')) + + p = Pattern.from_typehint(ops.WindowBoundary[dt.Integer, ds.Any]) + b = ops.WindowBoundary(5, preceding=False) + assert p.validate(b, {}) == b + with pytest.raises(ValidationError): + p.validate(ops.WindowBoundary(5.0, preceding=False), {}) + with pytest.raises(ValidationError): + p.validate(ops.WindowBoundary(lit, preceding=True), {}) + + p = Pattern.from_typehint(ops.WindowBoundary[dt.Interval, ds.Any]) + b = ops.WindowBoundary(lit, preceding=True) + assert p.validate(b, {}) == b + + +def test_window_boundary_coercions(): + RowsWindowBoundary = ops.WindowBoundary[dt.Integer, ds.Any] + p = Pattern.from_typehint(RowsWindowBoundary) + assert p.validate(1, {}) == RowsWindowBoundary(ops.Literal(1, dtype=dt.int8), False) def test_window_builder_rows(): @@ -53,9 +94,9 @@ def test_window_builder_rows(): assert w6.end is None assert w6.how == 'rows' - with pytest.raises(TypeError): + with pytest.raises(ValidationError): w0.rows(5, ibis.interval(days=1)) - with pytest.raises(TypeError): + with pytest.raises(ValidationError): w0.rows(ibis.interval(days=1), 10) diff --git a/ibis/tests/test_config.py b/ibis/tests/test_config.py index 0fd674047d22..c13a5bd4db47 100644 --- a/ibis/tests/test_config.py +++ b/ibis/tests/test_config.py @@ -2,13 +2,14 @@ import pytest +from ibis.common.patterns import ValidationError from ibis.config import options def test_sql_config(monkeypatch): assert options.sql.default_limit is None - with pytest.raises(TypeError): + with pytest.raises(ValidationError): options.sql.default_limit = -1 monkeypatch.setattr(options.sql, 'default_limit', 100) diff --git a/ibis/util.py b/ibis/util.py index 8f1e1e03ec4c..ec16c2f37053 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -19,7 +19,6 @@ Any, Callable, Iterator, - Mapping, Sequence, TypeVar, ) @@ -31,10 +30,6 @@ from numbers import Real from pathlib import Path - import ibis.expr.operations as ops - - Graph = Mapping[ops.Node, Sequence[ops.Node]] - T = TypeVar("T", covariant=True) U = TypeVar("U", covariant=True) K = TypeVar("K") diff --git a/pyproject.toml b/pyproject.toml index b33afe47c37d..706fd449dcd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -448,6 +448,7 @@ ignore = [ "SIM117", # nested with statements "SIM118", # remove .keys() calls from dictionaries "SIM300", # yoda conditions + "UP007", # Optional[str] -> str | None ] exclude = ["*_py310.py", "ibis/tests/*/snapshots/*"] target-version = "py38"