From 98e11739fcf23665759adfdc07b3de6cc21ffe0d Mon Sep 17 00:00:00 2001 From: Kexiang Wang Date: Sun, 28 Jan 2024 14:24:02 -0500 Subject: [PATCH] feat(risingwave): init impl for Risingwave (#7954) This is the initial PR for ibis to support risingwave. [RisingWave](https://github.com/risingwavelabs/risingwave) is a distributed SQL streaming database engineered to provide the simplest and most cost-efficient approach for processing and managing streaming data with utmost reliability. After a few weeks of investigation, here's my phasic results. 1. As Risingwave is largely compatible with Postgres, Ibis can be easily extended to support Risingwave. The main work is to add a new dialect for Risingwave, which is similar to the `postgres` dialect. I've almost finished this part in this PR. With this PR, Ibis can be used to connect to Risingwave and run some basic queries. I've manually tested some queries which works well and add some tests imitating Postgres Backend. I would appreciate if you can help test more queries. 2. As ibis relies on SQLalchemy to support Postgres, we follow its implementation to support Risingwave. However, there are also some differences in semantics between Risingwave and Postgres, which require some modification in either ibis or SQLalchemy. `sqlalchemy-risingwave` is designed to reduce this mismatch. So in this PR, I introduce the new dependencies `sqlalchemy-risingwave` to ibis. 3. Ibis has no support for Materialized View natively. However Materialized View is a core concept in RW, people use RW because of its convenient auto-updated Materialized View. Now, if a user wants to create a new MV, he needs to use a raw SQL. Adding DDLs like `CreateMaterializedView`, `CreateSource` and `CreateSink` in the [base ddl file](https://github.com/ibis-project/ibis/blob/main/ibis/backends/base/sql/ddl.py) may help. We would appreciate it if you can help offer some suggestions. Besides this, I also met some obstacle that may need your help. 1. Risingwave hasn't supported `TEMPORARY VIEW` yet, so I changed some implementations relying on `TEMPORARY VIEW` to a normal view. For example, for the `_metadata()` functions, RW backend's implementation is `con.exec_driver_sql(f"CREATE VIEW IF NOT EXISTS {name} AS {query}")`. While in PG it's `con.exec_driver_sql(f"CREATE TEMPORARY VIEW {name} AS {query}")`. Do you have any suggestions for other ways to work around this? Besides, I didn't quite get what `_metadata()` is doing here. I would appreciate it if you could explain it a bit. 2. There's some mismatch between the `postgres` dialect and `risingwave` dialect, which are still not fully tested in this PR. We'll continue to work on it. 3. ~~This PR requires some new features of Risingwave v1.6.0 and sqlalchemy-risingwave v1.0.0 which are not released yet. They'll be released soon.~~ Done. BTW, how should I indicate this backend is only for risingwave > 1.6? 4. I don't quite understand the test pipeline of Ibis. I copied the test cases from the `postgres` dialect and modified them to fit the `risingwave` dialect, and some of them are commented temporarily due to the lack of support. I also added an SQL script to help set up the test environment, which creates tables and loads data. But I don't know how to run it in the test pipeline. Any suggestions or guidance are welcomed. I suppose the test pipeline would require a docker image of Risingwave. We can provide one if needed. 5. I'm a newbie in the ibis community, this PR may not be perfect considering others. Any comments are welcomed and I sincerely appreciate your time and patience. closes #8038 --- .github/workflows/ibis-backends.yml | 14 + ci/schema/risingwave.sql | 177 +++ compose.yaml | 85 ++ docker/risingwave/risingwave.toml | 2 + ibis/backends/risingwave/__init__.py | 282 +++++ ibis/backends/risingwave/compiler.py | 34 + ibis/backends/risingwave/datatypes.py | 83 ++ ibis/backends/risingwave/registry.py | 848 ++++++++++++++ ibis/backends/risingwave/tests/__init__.py | 0 ibis/backends/risingwave/tests/conftest.py | 124 ++ .../test_client/test_compile_toplevel/out.sql | 2 + .../test_analytic_functions/out.sql | 7 + .../test_union_cte/False/out.sql | 1 + .../test_union_cte/True/out.sql | 1 + ibis/backends/risingwave/tests/test_client.py | 158 +++ .../risingwave/tests/test_functions.py | 1032 +++++++++++++++++ ibis/backends/risingwave/tests/test_json.py | 17 + ibis/backends/tests/test_aggregation.py | 155 ++- ibis/backends/tests/test_array.py | 145 ++- ibis/backends/tests/test_benchmarks.py | 900 -------------- ibis/backends/tests/test_binary.py | 1 + ibis/backends/tests/test_client.py | 67 +- ibis/backends/tests/test_column.py | 1 + ibis/backends/tests/test_dot_sql.py | 17 +- ibis/backends/tests/test_examples.py | 2 +- ibis/backends/tests/test_export.py | 49 +- ibis/backends/tests/test_generic.py | 104 +- ibis/backends/tests/test_json.py | 8 +- ibis/backends/tests/test_map.py | 96 +- ibis/backends/tests/test_network.py | 5 + ibis/backends/tests/test_numeric.py | 98 +- ibis/backends/tests/test_param.py | 18 + ibis/backends/tests/test_register.py | 38 +- ibis/backends/tests/test_set_ops.py | 98 +- ibis/backends/tests/test_sql.py | 2 +- ibis/backends/tests/test_string.py | 89 +- ibis/backends/tests/test_struct.py | 8 +- ibis/backends/tests/test_temporal.py | 136 ++- ibis/backends/tests/test_timecontext.py | 1 + ibis/backends/tests/test_udf.py | 1 + ibis/backends/tests/test_uuid.py | 6 + ibis/backends/tests/test_vectorized_udf.py | 2 +- ibis/backends/tests/test_window.py | 107 +- pyproject.toml | 4 + 44 files changed, 3993 insertions(+), 1032 deletions(-) create mode 100644 ci/schema/risingwave.sql create mode 100644 docker/risingwave/risingwave.toml create mode 100644 ibis/backends/risingwave/__init__.py create mode 100644 ibis/backends/risingwave/compiler.py create mode 100644 ibis/backends/risingwave/datatypes.py create mode 100644 ibis/backends/risingwave/registry.py create mode 100644 ibis/backends/risingwave/tests/__init__.py create mode 100644 ibis/backends/risingwave/tests/conftest.py create mode 100644 ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql create mode 100644 ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql create mode 100644 ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql create mode 100644 ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql create mode 100644 ibis/backends/risingwave/tests/test_client.py create mode 100644 ibis/backends/risingwave/tests/test_functions.py create mode 100644 ibis/backends/risingwave/tests/test_json.py delete mode 100644 ibis/backends/tests/test_benchmarks.py diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index dc8fd09287ff..6cf7ed40f699 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -122,6 +122,12 @@ jobs: - postgres sys-deps: - libgeos-dev + - name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - name: impala title: Impala serial: true @@ -211,6 +217,14 @@ jobs: - postgres sys-deps: - libgeos-dev + - os: windows-latest + backend: + name: risingwave + title: Risingwave + services: + - risingwave + extras: + - risingwave - os: windows-latest backend: name: postgres diff --git a/ci/schema/risingwave.sql b/ci/schema/risingwave.sql new file mode 100644 index 000000000000..cedfa8449d60 --- /dev/null +++ b/ci/schema/risingwave.sql @@ -0,0 +1,177 @@ +SET RW_IMPLICIT_FLUSH=true; + +DROP TABLE IF EXISTS diamonds CASCADE; + +CREATE TABLE diamonds ( + carat FLOAT, + cut TEXT, + color TEXT, + clarity TEXT, + depth FLOAT, + "table" FLOAT, + price BIGINT, + x FLOAT, + y FLOAT, + z FLOAT +) WITH ( + connector = 'posix_fs', + match_pattern = 'diamonds.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS astronauts CASCADE; + +CREATE TABLE astronauts ( + "id" BIGINT, + "number" BIGINT, + "nationwide_number" BIGINT, + "name" VARCHAR, + "original_name" VARCHAR, + "sex" VARCHAR, + "year_of_birth" BIGINT, + "nationality" VARCHAR, + "military_civilian" VARCHAR, + "selection" VARCHAR, + "year_of_selection" BIGINT, + "mission_number" BIGINT, + "total_number_of_missions" BIGINT, + "occupation" VARCHAR, + "year_of_mission" BIGINT, + "mission_title" VARCHAR, + "ascend_shuttle" VARCHAR, + "in_orbit" VARCHAR, + "descend_shuttle" VARCHAR, + "hours_mission" DOUBLE PRECISION, + "total_hrs_sum" DOUBLE PRECISION, + "field21" BIGINT, + "eva_hrs_mission" DOUBLE PRECISION, + "total_eva_hrs" DOUBLE PRECISION +) WITH ( + connector = 'posix_fs', + match_pattern = 'astronauts.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS batting CASCADE; + +CREATE TABLE batting ( + "playerID" TEXT, + "yearID" BIGINT, + stint BIGINT, + "teamID" TEXT, + "lgID" TEXT, + "G" BIGINT, + "AB" BIGINT, + "R" BIGINT, + "H" BIGINT, + "X2B" BIGINT, + "X3B" BIGINT, + "HR" BIGINT, + "RBI" BIGINT, + "SB" BIGINT, + "CS" BIGINT, + "BB" BIGINT, + "SO" BIGINT, + "IBB" BIGINT, + "HBP" BIGINT, + "SH" BIGINT, + "SF" BIGINT, + "GIDP" BIGINT +) WITH ( + connector = 'posix_fs', + match_pattern = 'batting.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS awards_players CASCADE; + +CREATE TABLE awards_players ( + "playerID" TEXT, + "awardID" TEXT, + "yearID" BIGINT, + "lgID" TEXT, + tie TEXT, + notes TEXT +) WITH ( + connector = 'posix_fs', + match_pattern = 'awards_players.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS functional_alltypes CASCADE; + +CREATE TABLE functional_alltypes ( + id INTEGER, + bool_col BOOLEAN, + tinyint_col SMALLINT, + smallint_col SMALLINT, + int_col INTEGER, + bigint_col BIGINT, + float_col REAL, + double_col DOUBLE PRECISION, + date_string_col TEXT, + string_col TEXT, + timestamp_col TIMESTAMP WITHOUT TIME ZONE, + year INTEGER, + month INTEGER +) WITH ( + connector = 'posix_fs', + match_pattern = 'functional_alltypes.csv', + posix_fs.root = '/data', +) FORMAT PLAIN ENCODE CSV ( without_header = 'false', delimiter = ',' ); + +DROP TABLE IF EXISTS tzone CASCADE; + +CREATE TABLE tzone ( + ts TIMESTAMP WITH TIME ZONE, + key TEXT, + value DOUBLE PRECISION +); + +INSERT INTO tzone + SELECT + CAST('2017-05-28 11:01:31.000400' AS TIMESTAMP WITH TIME ZONE) + + t * INTERVAL '1 day 1 second' AS ts, + CHR(97 + t) AS key, + t + t / 10.0 AS value + FROM generate_series(0, 9) AS t; + +DROP TABLE IF EXISTS array_types CASCADE; + +CREATE TABLE IF NOT EXISTS array_types ( + x BIGINT[], + y TEXT[], + z DOUBLE PRECISION[], + grouper TEXT, + scalar_column DOUBLE PRECISION, + multi_dim BIGINT[][] +); + +INSERT INTO array_types VALUES + (ARRAY[1, 2, 3], ARRAY['a', 'b', 'c'], ARRAY[1.0, 2.0, 3.0], 'a', 1.0, ARRAY[ARRAY[NULL::BIGINT, NULL, NULL], ARRAY[1, 2, 3]]), + (ARRAY[4, 5], ARRAY['d', 'e'], ARRAY[4.0, 5.0], 'a', 2.0, ARRAY[]::BIGINT[][]), + (ARRAY[6, NULL], ARRAY['f', NULL], ARRAY[6.0, NULL], 'a', 3.0, ARRAY[NULL, ARRAY[]::BIGINT[], NULL]), + (ARRAY[NULL, 1, NULL], ARRAY[NULL, 'a', NULL], ARRAY[]::DOUBLE PRECISION[], 'b', 4.0, ARRAY[ARRAY[1], ARRAY[2], ARRAY[NULL::BIGINT], ARRAY[3]]), + (ARRAY[2, NULL, 3], ARRAY['b', NULL, 'c'], NULL, 'b', 5.0, NULL), + (ARRAY[4, NULL, NULL, 5], ARRAY['d', NULL, NULL, 'e'], ARRAY[4.0, NULL, NULL, 5.0], 'c', 6.0, ARRAY[ARRAY[1, 2, 3]]); + +DROP TABLE IF EXISTS json_t CASCADE; + +CREATE TABLE IF NOT EXISTS json_t (js JSONB); + +INSERT INTO json_t VALUES + ('{"a": [1,2,3,4], "b": 1}'), + ('{"a":null,"b":2}'), + ('{"a":"foo", "c":null}'), + ('null'), + ('[42,47,55]'), + ('[]'); + +DROP TABLE IF EXISTS win CASCADE; +CREATE TABLE win (g TEXT, x BIGINT, y BIGINT); +INSERT INTO win VALUES + ('a', 0, 3), + ('a', 1, 2), + ('a', 2, 0), + ('a', 3, 1), + ('a', 4, 1); diff --git a/compose.yaml b/compose.yaml index 9d110ccbc7b6..388ee3e9f115 100644 --- a/compose.yaml +++ b/compose.yaml @@ -538,6 +538,88 @@ services: networks: - impala + risingwave-minio: + image: "quay.io/minio/minio:latest" + command: + - server + - "--address" + - "0.0.0.0:9301" + - "--console-address" + - "0.0.0.0:9400" + - /data + expose: + - "9301" + - "9400" + ports: + - "9301:9301" + - "9400:9400" + depends_on: [] + volumes: + - "risingwave-minio:/data" + entrypoint: /bin/sh -c "set -e; mkdir -p \"/data/hummock001\"; /usr/bin/docker-entrypoint.sh \"$$0\" \"$$@\" " + environment: + MINIO_CI_CD: "1" + MINIO_ROOT_PASSWORD: hummockadmin + MINIO_ROOT_USER: hummockadmin + MINIO_DOMAIN: "risingwave-minio" + container_name: risingwave-minio + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/9301; exit $$?;' + interval: 5s + timeout: 5s + retries: 20 + restart: always + networks: + - risingwave + + risingwave: + image: ghcr.io/risingwavelabs/risingwave:nightly-20240122 + command: "standalone --meta-opts=\" \ + --advertise-addr 0.0.0.0:5690 \ + --backend mem \ + --state-store hummock+minio://hummockadmin:hummockadmin@risingwave-minio:9301/hummock001 \ + --data-directory hummock_001 \ + --config-path /risingwave.toml\" \ + --compute-opts=\" \ + --config-path /risingwave.toml \ + --advertise-addr 0.0.0.0:5688 \ + --role both \" \ + --frontend-opts=\" \ + --config-path /risingwave.toml \ + --listen-addr 0.0.0.0:4566 \ + --advertise-addr 0.0.0.0:4566 \" \ + --compactor-opts=\" \ + --advertise-addr 0.0.0.0:6660 \"" + expose: + - "4566" + ports: + - "4566:4566" + depends_on: + - risingwave-minio + volumes: + - "./docker/risingwave/risingwave.toml:/risingwave.toml" + - risingwave:/data + environment: + RUST_BACKTRACE: "1" + # If ENABLE_TELEMETRY is not set, telemetry will start by default + ENABLE_TELEMETRY: ${ENABLE_TELEMETRY:-true} + container_name: risingwave + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/6660; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/5688; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/4566; exit $$?;' + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/5690; exit $$?;' + interval: 5s + timeout: 5s + retries: 20 + restart: always + networks: + - risingwave + networks: impala: # docker defaults to naming networks "$PROJECT_$NETWORK" but the Java Hive @@ -554,6 +636,7 @@ networks: oracle: exasol: flink: + risingwave: volumes: broker_var: @@ -572,3 +655,5 @@ volumes: minio: exasol: impala: + risingwave-minio: + risingwave: diff --git a/docker/risingwave/risingwave.toml b/docker/risingwave/risingwave.toml new file mode 100644 index 000000000000..43d57926ed16 --- /dev/null +++ b/docker/risingwave/risingwave.toml @@ -0,0 +1,2 @@ +# RisingWave config file to be mounted into the Docker containers. +# See https://github.com/risingwavelabs/risingwave/blob/main/src/config/example.toml for example diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py new file mode 100644 index 000000000000..04de491f6dfe --- /dev/null +++ b/ibis/backends/risingwave/__init__.py @@ -0,0 +1,282 @@ +"""Risingwave backend.""" + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Callable, Literal + +import sqlalchemy as sa + +import ibis.common.exceptions as exc +import ibis.expr.operations as ops +from ibis import util +from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend +from ibis.backends.risingwave.compiler import RisingwaveCompiler +from ibis.backends.risingwave.datatypes import RisingwaveType +from ibis.common.exceptions import InvalidDecoratorError + +if TYPE_CHECKING: + from collections.abc import Iterable + + import ibis.expr.datatypes as dt + + +def _verify_source_line(func_name: str, line: str): + if line.startswith("@"): + raise InvalidDecoratorError(func_name, line) + return line + + +class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema): + name = "risingwave" + compiler = RisingwaveCompiler + supports_temporary_tables = False + supports_create_or_replace = False + supports_python_udfs = False + + def do_connect( + self, + host: str | None = None, + user: str | None = None, + password: str | None = None, + port: int = 5432, + database: str | None = None, + schema: str | None = None, + url: str | None = None, + driver: Literal["psycopg2"] = "psycopg2", + ) -> None: + """Create an Ibis client connected to Risingwave database. + + Parameters + ---------- + host + Hostname + user + Username + password + Password + port + Port number + database + Database to connect to + schema + Risingwave schema to use. If `None`, use the default `search_path`. + url + SQLAlchemy connection string. + + If passed, the other connection arguments are ignored. + driver + Database driver + + Examples + -------- + >>> import os + >>> import getpass + >>> import ibis + >>> host = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost") + >>> user = os.environ.get("IBIS_TEST_RISINGWAVE_USER", getpass.getuser()) + >>> password = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD") + >>> database = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev") + >>> con = connect(database=database, host=host, user=user, password=password) + >>> con.list_tables() # doctest: +ELLIPSIS + [...] + >>> t = con.table("functional_alltypes") + >>> t + RisingwaveTable[table] + name: functional_alltypes + schema: + id : int32 + bool_col : boolean + tinyint_col : int16 + smallint_col : int16 + int_col : int32 + bigint_col : int64 + float_col : float32 + double_col : float64 + date_string_col : string + string_col : string + timestamp_col : timestamp + year : int32 + month : int32 + """ + if driver != "psycopg2": + raise NotImplementedError("psycopg2 is currently the only supported driver") + + alchemy_url = self._build_alchemy_url( + url=url, + host=host, + port=port, + user=user, + password=password, + database=database, + driver=f"risingwave+{driver}", + ) + + connect_args = {} + if schema is not None: + connect_args["options"] = f"-csearch_path={schema}" + + engine = sa.create_engine( + alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool + ) + + @sa.event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + with dbapi_connection.cursor() as cur: + cur.execute("SET TIMEZONE = UTC") + + super().do_connect(engine) + + def list_tables(self, like=None, schema=None): + """List the tables in the database. + + Parameters + ---------- + like + A pattern to use for listing tables. + schema + The schema to perform the list against. + + ::: {.callout-warning} + ## `schema` refers to database hierarchy + + The `schema` parameter does **not** refer to the column names and + types of `table`. + ::: + """ + tables = self.inspector.get_table_names(schema=schema) + views = self.inspector.get_view_names(schema=schema) + return self._filter_with_like(tables + views, like) + + def list_databases(self, like=None) -> list[str]: + # http://dba.stackexchange.com/a/1304/58517 + dbs = sa.table( + "pg_database", + sa.column("datname", sa.TEXT()), + sa.column("datistemplate", sa.BOOLEAN()), + schema="pg_catalog", + ) + query = sa.select(dbs.c.datname).where(sa.not_(dbs.c.datistemplate)) + with self.begin() as con: + databases = list(con.execute(query).scalars()) + + return self._filter_with_like(databases, like) + + @property + def current_database(self) -> str: + return self._scalar_query(sa.select(sa.func.current_database())) + + @property + def current_schema(self) -> str: + return self._scalar_query(sa.select(sa.func.current_schema())) + + def function(self, name: str, *, schema: str | None = None) -> Callable: + query = sa.text( + """ +SELECT + n.nspname as schema, + pg_catalog.pg_get_function_result(p.oid) as return_type, + string_to_array(pg_catalog.pg_get_function_arguments(p.oid), ', ') as signature, + CASE p.prokind + WHEN 'a' THEN 'agg' + WHEN 'w' THEN 'window' + WHEN 'p' THEN 'proc' + ELSE 'func' + END as "Type" +FROM pg_catalog.pg_proc p +LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace +WHERE p.proname = :name +""" + + "AND n.nspname OPERATOR(pg_catalog.~) :schema COLLATE pg_catalog.default" + * (schema is not None) + ).bindparams(name=name, schema=f"^({schema})$") + + def split_name_type(arg: str) -> tuple[str, dt.DataType]: + name, typ = arg.split(" ", 1) + return name, RisingwaveType.from_string(typ) + + with self.begin() as con: + rows = con.execute(query).mappings().fetchall() + + if not rows: + name = f"{schema}.{name}" if schema else name + raise exc.MissingUDFError(name) + elif len(rows) > 1: + raise exc.AmbiguousUDFError(name) + + [row] = rows + return_type = RisingwaveType.from_string(row["return_type"]) + signature = list(map(split_name_type, row["signature"])) + + # dummy callable + def fake_func(*args, **kwargs): + ... + + fake_func.__name__ = name + fake_func.__signature__ = inspect.Signature( + [ + inspect.Parameter( + name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typ + ) + for name, typ in signature + ], + return_annotation=return_type, + ) + fake_func.__annotations__ = {"return": return_type, **dict(signature)} + op = ops.udf.scalar.builtin(fake_func, schema=schema) + return op + + def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: + name = util.gen_name("risingwave_metadata") + type_info_sql = """\ + SELECT + attname, + format_type(atttypid, atttypmod) AS type + FROM pg_attribute + WHERE attrelid = CAST(:name AS regclass) + AND attnum > 0 + AND NOT attisdropped + ORDER BY attnum""" + if self.inspector.has_table(query): + query = f"TABLE {query}" + + text = sa.text(type_info_sql).bindparams(name=name) + with self.begin() as con: + con.exec_driver_sql(f"CREATE VIEW IF NOT EXISTS {name} AS {query}") + try: + yield from ( + (col, RisingwaveType.from_string(typestr)) + for col, typestr in con.execute(text) + ) + finally: + con.exec_driver_sql(f"DROP VIEW IF EXISTS {name}") + + def _get_temp_view_definition( + self, name: str, definition: sa.sql.compiler.Compiled + ) -> str: + yield f"DROP VIEW IF EXISTS {name}" + yield f"CREATE TEMPORARY VIEW {name} AS {definition}" + + def create_schema( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + if database is not None and database != self.current_database: + raise exc.UnsupportedOperationError( + "Risingwave does not support creating a schema in a different database" + ) + if_not_exists = "IF NOT EXISTS " * force + name = self._quote(name) + with self.begin() as con: + con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") + + def drop_schema( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + if database is not None and database != self.current_database: + raise exc.UnsupportedOperationError( + "Risingwave does not support dropping a schema in a different database" + ) + name = self._quote(name) + if_exists = "IF EXISTS " * force + with self.begin() as con: + con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}") diff --git a/ibis/backends/risingwave/compiler.py b/ibis/backends/risingwave/compiler.py new file mode 100644 index 000000000000..b4bcd9c0b9d5 --- /dev/null +++ b/ibis/backends/risingwave/compiler.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import ibis.expr.operations as ops +from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator +from ibis.backends.risingwave.datatypes import RisingwaveType +from ibis.backends.risingwave.registry import operation_registry +from ibis.expr.rewrites import rewrite_sample + + +class RisingwaveExprTranslator(AlchemyExprTranslator): + _registry = operation_registry.copy() + _rewrites = AlchemyExprTranslator._rewrites.copy() + _has_reduction_filter_syntax = True + _supports_tuple_syntax = True + _dialect_name = "risingwave" + + # it does support it, but we can't use it because of support for pivot + supports_unnest_in_select = False + + type_mapper = RisingwaveType + + +rewrites = RisingwaveExprTranslator.rewrites + + +@rewrites(ops.Any) +@rewrites(ops.All) +def _any_all_no_op(expr): + return expr + + +class RisingwaveCompiler(AlchemyCompiler): + translator_class = RisingwaveExprTranslator + rewrites = AlchemyCompiler.rewrites | rewrite_sample diff --git a/ibis/backends/risingwave/datatypes.py b/ibis/backends/risingwave/datatypes.py new file mode 100644 index 000000000000..389210486a6f --- /dev/null +++ b/ibis/backends/risingwave/datatypes.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as psql +import sqlalchemy.types as sat + +import ibis.expr.datatypes as dt +from ibis.backends.base.sql.alchemy.datatypes import AlchemyType +from ibis.backends.base.sqlglot.datatypes import PostgresType as SqlglotPostgresType + +_from_postgres_types = { + psql.DOUBLE_PRECISION: dt.Float64, + psql.JSONB: dt.JSON, + psql.JSON: dt.JSON, + psql.BYTEA: dt.Binary, +} + + +_postgres_interval_fields = { + "YEAR": "Y", + "MONTH": "M", + "DAY": "D", + "HOUR": "h", + "MINUTE": "m", + "SECOND": "s", + "YEAR TO MONTH": "M", + "DAY TO HOUR": "h", + "DAY TO MINUTE": "m", + "DAY TO SECOND": "s", + "HOUR TO MINUTE": "m", + "HOUR TO SECOND": "s", + "MINUTE TO SECOND": "s", +} + + +class RisingwaveType(AlchemyType): + dialect = "risingwave" + + @classmethod + def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine: + if dtype.is_floating(): + if isinstance(dtype, dt.Float64): + return psql.DOUBLE_PRECISION + else: + return psql.REAL + elif dtype.is_array(): + # Unwrap the array element type because sqlalchemy doesn't allow arrays of + # arrays. This doesn't affect the underlying data. + while dtype.is_array(): + dtype = dtype.value_type + return sa.ARRAY(cls.from_ibis(dtype)) + elif dtype.is_map(): + if not (dtype.key_type.is_string() and dtype.value_type.is_string()): + raise TypeError( + f"Risingwave only supports map, got: {dtype}" + ) + return psql.HSTORE() + elif dtype.is_uuid(): + return psql.UUID() + else: + return super().from_ibis(dtype) + + @classmethod + def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType: + if dtype := _from_postgres_types.get(type(typ)): + return dtype(nullable=nullable) + elif isinstance(typ, psql.HSTORE): + return dt.Map(dt.string, dt.string, nullable=nullable) + elif isinstance(typ, psql.INTERVAL): + field = typ.fields.upper() + if (unit := _postgres_interval_fields.get(field, None)) is None: + raise ValueError(f"Unknown Risingwave interval field {field!r}") + elif unit in {"Y", "M"}: + raise ValueError( + "Variable length intervals are not yet supported with Risingwave" + ) + return dt.Interval(unit=unit, nullable=nullable) + else: + return super().to_ibis(typ, nullable=nullable) + + @classmethod + def from_string(cls, type_string: str) -> RisingwaveType: + return SqlglotPostgresType.from_string(type_string) diff --git a/ibis/backends/risingwave/registry.py b/ibis/backends/risingwave/registry.py new file mode 100644 index 000000000000..a6cb67ca83a8 --- /dev/null +++ b/ibis/backends/risingwave/registry.py @@ -0,0 +1,848 @@ +from __future__ import annotations + +import functools +import itertools +import locale +import operator +import platform +import re +import string + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pg +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import GenericFunction + +import ibis.common.exceptions as com +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops + +# used for literal translate +from ibis.backends.base.sql.alchemy import ( + fixed_arity, + get_sqla_table, + reduction, + sqlalchemy_operation_registry, + sqlalchemy_window_functions_registry, + unary, + varargs, +) +from ibis.backends.base.sql.alchemy.registry import ( + _bitwise_op, + _extract, + get_col, +) + +operation_registry = sqlalchemy_operation_registry.copy() +operation_registry.update(sqlalchemy_window_functions_registry) + +_truncate_precisions = { + "us": "microseconds", + "ms": "milliseconds", + "s": "second", + "m": "minute", + "h": "hour", + "D": "day", + "W": "week", + "M": "month", + "Q": "quarter", + "Y": "year", +} + + +def _timestamp_truncate(t, op): + sa_arg = t.translate(op.arg) + try: + precision = _truncate_precisions[op.unit.short] + except KeyError: + raise com.UnsupportedOperationError(f"Unsupported truncate unit {op.unit!r}") + return sa.func.date_trunc(precision, sa_arg) + + +def _timestamp_bucket(t, op): + arg = t.translate(op.arg) + interval = t.translate(op.interval) + + origin = sa.literal_column("timestamp '1970-01-01 00:00:00'") + + if op.offset is not None: + origin = origin + t.translate(op.offset) + return sa.func.date_bin(interval, arg, origin) + + +def _typeof(t, op): + sa_arg = t.translate(op.arg) + typ = sa.cast(sa.func.pg_typeof(sa_arg), sa.TEXT) + + # select pg_typeof('thing') returns unknown so we have to check the child's + # type for nullness + return sa.case( + ((typ == "unknown") & (op.arg.dtype != dt.null), "text"), + ((typ == "unknown") & (op.arg.dtype == dt.null), "null"), + else_=typ, + ) + + +_strftime_to_postgresql_rules = { + "%a": "TMDy", # TM does it in a locale dependent way + "%A": "TMDay", + "%w": "D", # 1-based day of week, see below for how we make this 0-based + "%d": "DD", # day of month + "%-d": "FMDD", # - is no leading zero for Python same for FM in postgres + "%b": "TMMon", # Sep + "%B": "TMMonth", # September + "%m": "MM", # 01 + "%-m": "FMMM", # 1 + "%y": "YY", # 15 + "%Y": "YYYY", # 2015 + "%H": "HH24", # 09 + "%-H": "FMHH24", # 9 + "%I": "HH12", # 09 + "%-I": "FMHH12", # 9 + "%p": "AM", # AM or PM + "%M": "MI", # zero padded minute + "%-M": "FMMI", # Minute + "%S": "SS", # zero padded second + "%-S": "FMSS", # Second + "%f": "US", # zero padded microsecond + "%z": "OF", # utf offset + "%Z": "TZ", # uppercase timezone name + "%j": "DDD", # zero padded day of year + "%-j": "FMDDD", # day of year + "%U": "WW", # 1-based week of year + # 'W': ?, # meh +} + +try: + _strftime_to_postgresql_rules.update( + { + "%c": locale.nl_langinfo(locale.D_T_FMT), # locale date and time + "%x": locale.nl_langinfo(locale.D_FMT), # locale date + "%X": locale.nl_langinfo(locale.T_FMT), # locale time + } + ) +except AttributeError: + HAS_LANGINFO = False +else: + HAS_LANGINFO = True + + +# translate strftime spec into mostly equivalent Risingwave spec +_scanner = re.Scanner( # type: ignore # re does have a Scanner attribute + # double quotes need to be escaped + [('"', lambda *_: r"\"")] + + [ + ( + "|".join( + map( + "(?:{})".format, + itertools.chain( + _strftime_to_postgresql_rules.keys(), + [ + # "%e" is in the C standard and Python actually + # generates this if your spec contains "%c" but we + # don't officially support it as a specifier so we + # need to special case it in the scanner + "%e", + r"\s+", + rf"[{re.escape(string.punctuation)}]", + rf"[^{re.escape(string.punctuation)}\s]+", + ], + ), + ) + ), + lambda _, token: token, + ) + ] +) + + +_lexicon_values = frozenset(_strftime_to_postgresql_rules.values()) + +_locale_specific_formats = frozenset(["%c", "%x", "%X"]) +_strftime_blacklist = frozenset(["%w", "%U", "%e"]) | _locale_specific_formats + + +def _reduce_tokens(tokens, arg): + # current list of tokens + curtokens = [] + + # reduced list of tokens that accounts for blacklisted values + reduced = [] + + non_special_tokens = frozenset(_strftime_to_postgresql_rules) - _strftime_blacklist + + # TODO: how much of a hack is this? + for token in tokens: + if token in _locale_specific_formats and not HAS_LANGINFO: + raise com.UnsupportedOperationError( + f"Format string component {token!r} is not supported on {platform.system()}" + ) + # we are a non-special token %A, %d, etc. + if token in non_special_tokens: + curtokens.append(_strftime_to_postgresql_rules[token]) + + # we have a string like DD, to escape this we + # surround it with double quotes + elif token in _lexicon_values: + curtokens.append(f'"{token}"') + + # we have a token that needs special treatment + elif token in _strftime_blacklist: + if token == "%w": + value = sa.extract("dow", arg) # 0 based day of week + elif token == "%U": + value = sa.cast(sa.func.to_char(arg, "WW"), sa.SMALLINT) - 1 + elif token in ("%c", "%x", "%X"): + # re scan and tokenize this pattern + try: + new_pattern = _strftime_to_postgresql_rules[token] + except KeyError: + raise ValueError( + "locale specific date formats (%%c, %%x, %%X) are " + "not yet implemented for %s" % platform.system() + ) + + new_tokens, _ = _scanner.scan(new_pattern) + value = functools.reduce( + sa.sql.ColumnElement.concat, + _reduce_tokens(new_tokens, arg), + ) + elif token == "%e": + # pad with spaces instead of zeros + value = sa.func.replace(sa.func.to_char(arg, "DD"), "0", " ") + + reduced += [ + sa.func.to_char(arg, "".join(curtokens)), + sa.cast(value, sa.TEXT), + ] + + # empty current token list in case there are more tokens + del curtokens[:] + + # uninteresting text + else: + curtokens.append(token) + # append result to r if we had more tokens or if we have no + # blacklisted tokens + if curtokens: + reduced.append(sa.func.to_char(arg, "".join(curtokens))) + return reduced + + +def _strftime(arg, pattern): + tokens, _ = _scanner.scan(pattern.value) + reduced = _reduce_tokens(tokens, arg) + return functools.reduce(sa.sql.ColumnElement.concat, reduced) + + +def _find_in_set(t, op): + # TODO + # this operation works with any type, not just strings. should the + # operation itself also have this property? + return ( + sa.func.coalesce( + sa.func.array_position( + pg.array(list(map(t.translate, op.values))), + t.translate(op.needle), + ), + 0, + ) + - 1 + ) + + +def _log(t, op): + arg, base = op.args + sa_arg = t.translate(arg) + if base is not None: + sa_base = t.translate(base) + return sa.cast( + sa.func.log(sa.cast(sa_base, sa.NUMERIC), sa.cast(sa_arg, sa.NUMERIC)), + t.get_sqla_type(op.dtype), + ) + return sa.func.ln(sa_arg) + + +def _regex_extract(arg, pattern, index): + # wrap in parens to support 0th group being the whole string + pattern = "(" + pattern + ")" + # arrays are 1-based in postgres + index = index + 1 + does_match = sa.func.textregexeq(arg, pattern) + matches = sa.func.regexp_match(arg, pattern, type_=pg.ARRAY(sa.TEXT)) + return sa.case((does_match, matches[index]), else_=None) + + +def _array_repeat(t, op): + """Repeat an array.""" + arg = t.translate(op.arg) + times = t.translate(op.times) + + array_length = sa.func.cardinality(arg) + array = sa.sql.elements.Grouping(arg) if isinstance(op.arg, ops.Literal) else arg + + # sequence from 1 to the total number of elements desired in steps of 1. + series = sa.func.generate_series(1, times * array_length).table_valued() + + # if our current index modulo the array's length is a multiple of the + # array's length, then the index is the array's length + index = sa.func.coalesce( + sa.func.nullif(series.column % array_length, 0), array_length + ) + + # tie it all together in a scalar subquery and collapse that into an ARRAY + return sa.func.array(sa.select(array[index]).scalar_subquery()) + + +def _table_column(t, op): + ctx = t.context + table = op.table + + sa_table = get_sqla_table(ctx, table) + out_expr = get_col(sa_table, op) + + if op.dtype.is_timestamp(): + timezone = op.dtype.timezone + if timezone is not None: + out_expr = out_expr.op("AT TIME ZONE")(timezone).label(op.name) + + # If the column does not originate from the table set in the current SELECT + # context, we should format as a subquery + if t.permit_subquery and ctx.is_foreign_expr(table): + return sa.select(out_expr) + + return out_expr + + +def _round(t, op): + arg, digits = op.args + sa_arg = t.translate(arg) + + if digits is None: + return sa.func.round(sa_arg) + + # postgres doesn't allow rounding of double precision values to a specific + # number of digits (though simple truncation on doubles is allowed) so + # we cast to numeric and then cast back if necessary + result = sa.func.round(sa.cast(sa_arg, sa.NUMERIC), t.translate(digits)) + if digits is not None and arg.dtype.is_decimal(): + return result + result = sa.cast(result, pg.DOUBLE_PRECISION()) + return result + + +def _mod(t, op): + left, right = map(t.translate, op.args) + + # postgres doesn't allow modulus of double precision values, so upcast and + # then downcast later if necessary + if not op.dtype.is_integer(): + left = sa.cast(left, sa.NUMERIC) + right = sa.cast(right, sa.NUMERIC) + + result = left % right + if op.dtype.is_float64(): + return sa.cast(result, pg.DOUBLE_PRECISION()) + else: + return result + + +def _neg_idx_to_pos(array, idx): + return sa.case((idx < 0, sa.func.cardinality(array) + idx), else_=idx) + + +def _array_slice(*, index_converter, array_length, func): + def translate(t, op): + arg = t.translate(op.arg) + + arg_length = array_length(arg) + + if (start := op.start) is None: + start = 0 + else: + start = t.translate(start) + start = sa.func.least(arg_length, index_converter(arg, start)) + + if (stop := op.stop) is None: + stop = arg_length + else: + stop = index_converter(arg, t.translate(stop)) + + return func(arg, start + 1, stop) + + return translate + + +def _array_index(*, index_converter, func): + def translate(t, op): + sa_array = t.translate(op.arg) + sa_index = t.translate(op.index) + if isinstance(op.arg, ops.Literal): + sa_array = sa.sql.elements.Grouping(sa_array) + return func(sa_array, index_converter(sa_array, sa_index) + 1) + + return translate + + +def _literal(t, op): + dtype = op.dtype + value = op.value + + if value is None: + return ( + sa.null() if dtype.is_null() else sa.cast(sa.null(), t.get_sqla_type(dtype)) + ) + if dtype.is_interval(): + return sa.literal_column(f"INTERVAL '{value} {dtype.resolution}'") + elif dtype.is_array(): + return pg.array(value) + elif dtype.is_map(): + return pg.hstore(list(value.keys()), list(value.values())) + elif dtype.is_time(): + return sa.func.make_time( + value.hour, value.minute, value.second + value.microsecond / 1e6 + ) + elif dtype.is_date(): + return sa.func.make_date(value.year, value.month, value.day) + elif dtype.is_timestamp(): + if (tz := dtype.timezone) is not None: + return sa.func.to_timestamp(value.timestamp()).op("AT TIME ZONE")(tz) + return sa.cast(sa.literal(value.isoformat()), sa.TIMESTAMP()) + else: + return sa.literal(value) + + +def _string_agg(t, op): + agg = sa.func.string_agg(t.translate(op.arg), t.translate(op.sep)) + if (where := op.where) is not None: + return agg.filter(t.translate(where)) + return agg + + +def _corr(t, op): + if op.how == "sample": + raise ValueError( + f"{t.__class__.__name__} only implements population correlation " + "coefficient" + ) + return _binary_variance_reduction(sa.func.corr)(t, op) + + +def _covar(t, op): + suffix = {"sample": "samp", "pop": "pop"} + how = suffix.get(op.how, "samp") + func = getattr(sa.func, f"covar_{how}") + return _binary_variance_reduction(func)(t, op) + + +def _mode(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + return sa.func.mode().within_group(t.translate(arg)) + + +def _quantile(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + if arg.dtype.is_numeric(): + func = sa.func.percentile_cont + else: + func = sa.func.percentile_disc + return func(t.translate(op.quantile)).within_group(t.translate(arg)) + + +def _median(t, op): + arg = op.arg + if (where := op.where) is not None: + arg = ops.IfElse(where, arg, None) + + if arg.dtype.is_numeric(): + func = sa.func.percentile_cont + else: + func = sa.func.percentile_disc + return func(0.5).within_group(t.translate(arg)) + + +def _binary_variance_reduction(func): + def variance_compiler(t, op): + x = op.left + if (x_type := x.dtype).is_boolean(): + x = ops.Cast(x, dt.Int32(nullable=x_type.nullable)) + + y = op.right + if (y_type := y.dtype).is_boolean(): + y = ops.Cast(y, dt.Int32(nullable=y_type.nullable)) + + if t._has_reduction_filter_syntax: + result = func(t.translate(x), t.translate(y)) + + if (where := op.where) is not None: + return result.filter(t.translate(where)) + return result + else: + if (where := op.where) is not None: + x = ops.IfElse(where, x, None) + y = ops.IfElse(where, y, None) + return func(t.translate(x), t.translate(y)) + + return variance_compiler + + +def _arg_min_max(sort_func): + def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: + arg = t.translate(op.arg) + key = t.translate(op.key) + + conditions = [arg != sa.null(), key != sa.null()] + + agg = sa.func.array_agg(pg.aggregate_order_by(arg, sort_func(key))) + + if (where := op.where) is not None: + conditions.append(t.translate(where)) + return agg.filter(sa.and_(*conditions))[1] + + return translate + + +def _arbitrary(t, op): + if (how := op.how) == "heavy": + raise com.UnsupportedOperationError( + f"risingwave backend doesn't support how={how!r} for the arbitrary() aggregate" + ) + func = getattr(sa.func, op.how) + return t._reduction(func, op) + + +class rw_struct_field(GenericFunction): + inherit_cache = True + + +@compiles(rw_struct_field) +def compile_struct_field_postgresql(element, compiler, **kw): + arg, field = element.clauses + return f"({compiler.process(arg, **kw)}).{field.name}" + + +def _struct_field(t, op): + arg = op.arg + idx = arg.dtype.names.index(op.field) + 1 + field_name = sa.literal_column(f"f{idx:d}") + return rw_struct_field( + t.translate(arg), field_name, type_=t.get_sqla_type(op.dtype) + ) + + +def _struct_column(t, op): + types = op.dtype.types + return sa.func.row( + # we have to cast here, otherwise risingwave refuses to allow the statement + *map(t.translate, map(ops.Cast, op.values, types)), + type_=t.get_sqla_type( + dt.Struct({f"f{i:d}": typ for i, typ in enumerate(types, start=1)}) + ), + ) + + +def _unnest(t, op): + arg = op.arg + row_type = arg.dtype.value_type + + types = getattr(row_type, "types", (row_type,)) + + is_struct = row_type.is_struct() + derived = ( + sa.func.unnest(t.translate(arg)) + .table_valued( + *( + sa.column(f"f{i:d}", stype) + for i, stype in enumerate(map(t.get_sqla_type, types), start=1) + ) + ) + .render_derived(with_types=is_struct) + ) + + # wrap in a row column so that we can return a single column from this rule + if not is_struct: + return derived.c[0] + return sa.func.row(*derived.c) + + +def _array_sort(arg): + flat = sa.func.unnest(arg).column_valued() + return sa.func.array(sa.select(flat).order_by(flat).scalar_subquery()) + + +def _array_position(haystack, needle): + t = ( + sa.func.unnest(haystack) + .table_valued("value", with_ordinality="idx", name="haystack") + .render_derived() + ) + idx = t.c.idx - 1 + return sa.func.coalesce( + sa.select(idx).where(t.c.value == needle).limit(1).scalar_subquery(), -1 + ) + + +def _array_map(t, op): + return sa.func.array( + # this translates to the function call, with column names the same as + # the parameter names in the lambda + sa.select(t.translate(op.body)) + .select_from( + # unnest the input array + sa.func.unnest(t.translate(op.arg)) + # name the columns of the result the same as the lambda parameter + # so that we can reference them as such in the outer query + .table_valued(op.param) + .render_derived() + ) + .scalar_subquery() + ) + + +def _array_filter(t, op): + param = op.param + return sa.func.array( + sa.select(sa.column(param, type_=t.get_sqla_type(op.arg.dtype.value_type))) + .select_from( + sa.func.unnest(t.translate(op.arg)).table_valued(param).render_derived() + ) + .where(t.translate(op.body)) + .scalar_subquery() + ) + + +def zero_value(dtype): + if dtype.is_interval(): + return sa.func.make_interval() + return 0 + + +def interval_sign(v): + zero = sa.func.make_interval() + return sa.case((v == zero, 0), (v < zero, -1), (v > zero, 1)) + + +def _sign(value, dtype): + if dtype.is_interval(): + return interval_sign(value) + return sa.func.sign(value) + + +def _range(t, op): + start = t.translate(op.start) + stop = t.translate(op.stop) + step = t.translate(op.step) + satype = t.get_sqla_type(op.dtype) + seq = sa.func.generate_series(start, stop, step, type_=satype) + zero = zero_value(op.step.dtype) + return sa.case( + ( + sa.and_( + sa.func.nullif(step, zero).is_not(None), + _sign(step, op.step.dtype) == _sign(stop - start, op.step.dtype), + ), + sa.func.array_remove( + sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype + ), + ), + else_=sa.cast(pg.array([]), satype), + ) + + +operation_registry.update( + { + ops.Literal: _literal, + # We override this here to support time zones + ops.TableColumn: _table_column, + ops.Argument: lambda t, op: sa.column( + op.param, type_=t.get_sqla_type(op.dtype) + ), + # types + ops.TypeOf: _typeof, + # Floating + ops.IsNan: fixed_arity(lambda arg: arg == float("nan"), 1), + ops.IsInf: fixed_arity( + lambda arg: sa.or_(arg == float("inf"), arg == float("-inf")), 1 + ), + # boolean reductions + ops.Any: reduction(sa.func.bool_or), + ops.All: reduction(sa.func.bool_and), + # strings + ops.GroupConcat: _string_agg, + ops.Capitalize: unary(sa.func.initcap), + ops.RegexSearch: fixed_arity(lambda x, y: x.op("~")(y), 2), + # postgres defaults to replacing only the first occurrence + ops.RegexReplace: fixed_arity( + lambda string, pattern, replacement: sa.func.regexp_replace( + string, pattern, replacement, "g" + ), + 3, + ), + ops.Translate: fixed_arity(sa.func.translate, 3), + ops.RegexExtract: fixed_arity(_regex_extract, 3), + ops.StringSplit: fixed_arity( + lambda col, sep: sa.func.string_to_array( + col, sep, type_=sa.ARRAY(col.type) + ), + 2, + ), + ops.FindInSet: _find_in_set, + # math + ops.Log: _log, + ops.Log2: unary(lambda x: sa.func.log(2, x)), + ops.Log10: unary(sa.func.log), + ops.Round: _round, + ops.Modulus: _mod, + # dates and times + ops.DateFromYMD: fixed_arity(sa.func.make_date, 3), + ops.DateTruncate: _timestamp_truncate, + ops.TimestampTruncate: _timestamp_truncate, + ops.TimestampBucket: _timestamp_bucket, + ops.IntervalFromInteger: ( + lambda t, op: t.translate(op.arg) + * sa.text(f"INTERVAL '1 {op.dtype.resolution}'") + ), + ops.DateAdd: fixed_arity(operator.add, 2), + ops.DateSub: fixed_arity(operator.sub, 2), + ops.DateDiff: fixed_arity(operator.sub, 2), + ops.TimestampAdd: fixed_arity(operator.add, 2), + ops.TimestampSub: fixed_arity(operator.sub, 2), + ops.TimestampDiff: fixed_arity(operator.sub, 2), + ops.Strftime: fixed_arity(_strftime, 2), + ops.ExtractEpochSeconds: fixed_arity( + lambda arg: sa.cast(sa.extract("epoch", arg), sa.INTEGER), 1 + ), + ops.ExtractDayOfYear: _extract("doy"), + ops.ExtractWeekOfYear: _extract("week"), + # extracting the second gives us the fractional part as well, so smash that + # with a cast to SMALLINT + ops.ExtractSecond: fixed_arity( + lambda arg: sa.cast(sa.func.floor(sa.extract("second", arg)), sa.SMALLINT), + 1, + ), + # we get total number of milliseconds including seconds with extract so we + # mod 1000 + ops.ExtractMillisecond: fixed_arity( + lambda arg: sa.cast( + sa.func.floor(sa.extract("millisecond", arg)) % 1000, + sa.SMALLINT, + ), + 1, + ), + ops.DayOfWeekIndex: fixed_arity( + lambda arg: sa.cast( + sa.cast(sa.extract("dow", arg) + 6, sa.SMALLINT) % 7, sa.SMALLINT + ), + 1, + ), + ops.DayOfWeekName: fixed_arity( + lambda arg: sa.func.trim(sa.func.to_char(arg, "Day")), 1 + ), + ops.TimeFromHMS: fixed_arity(sa.func.make_time, 3), + # array operations + ops.ArrayLength: unary(sa.func.cardinality), + ops.ArrayCollect: reduction(sa.func.array_agg), + ops.Array: (lambda t, op: pg.array(list(map(t.translate, op.exprs)))), + ops.ArraySlice: _array_slice( + index_converter=_neg_idx_to_pos, + array_length=sa.func.cardinality, + func=lambda arg, start, stop: arg[start:stop], + ), + ops.ArrayIndex: _array_index( + index_converter=_neg_idx_to_pos, func=lambda arg, index: arg[index] + ), + ops.ArrayConcat: varargs(lambda *args: functools.reduce(operator.add, args)), + ops.ArrayRepeat: _array_repeat, + ops.Unnest: _unnest, + ops.Covariance: _covar, + ops.Correlation: _corr, + ops.BitwiseXor: _bitwise_op("#"), + ops.Mode: _mode, + ops.ApproxMedian: _median, + ops.Median: _median, + ops.Quantile: _quantile, + ops.MultiQuantile: _quantile, + ops.TimestampNow: lambda t, op: sa.literal_column( + "CURRENT_TIMESTAMP", type_=t.get_sqla_type(op.dtype) + ), + ops.MapGet: fixed_arity( + lambda arg, key, default: sa.case( + (arg.has_key(key), arg[key]), else_=default + ), + 3, + ), + ops.MapContains: fixed_arity(pg.HSTORE.Comparator.has_key, 2), + ops.MapKeys: unary(pg.HSTORE.Comparator.keys), + ops.MapValues: unary(pg.HSTORE.Comparator.vals), + ops.MapMerge: fixed_arity(operator.add, 2), + ops.MapLength: unary(lambda arg: sa.func.cardinality(arg.keys())), + ops.Map: fixed_arity(pg.hstore, 2), + ops.ArgMin: _arg_min_max(sa.asc), + ops.ArgMax: _arg_min_max(sa.desc), + ops.ArrayStringJoin: fixed_arity( + lambda sep, arr: sa.func.array_to_string(arr, sep), 2 + ), + ops.Strip: unary(lambda arg: sa.func.trim(arg, string.whitespace)), + ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)), + ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)), + ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2), + ops.Arbitrary: _arbitrary, + ops.StructColumn: _struct_column, + ops.StructField: _struct_field, + ops.First: reduction(sa.func.first), + ops.Last: reduction(sa.func.last), + ops.ExtractMicrosecond: fixed_arity( + lambda arg: sa.extract("microsecond", arg) % 1_000_000, 1 + ), + ops.Levenshtein: fixed_arity(sa.func.levenshtein, 2), + ops.ArraySort: fixed_arity(_array_sort, 1), + ops.ArrayIntersect: fixed_arity( + lambda left, right: sa.func.array( + sa.intersect( + sa.select(sa.func.unnest(left).column_valued()), + sa.select(sa.func.unnest(right).column_valued()), + ).scalar_subquery() + ), + 2, + ), + ops.ArrayRemove: fixed_arity( + lambda left, right: sa.func.array( + sa.except_( + sa.select(sa.func.unnest(left).column_valued()), sa.select(right) + ).scalar_subquery() + ), + 2, + ), + ops.ArrayUnion: fixed_arity( + lambda left, right: sa.func.array( + sa.union( + sa.select(sa.func.unnest(left).column_valued()), + sa.select(sa.func.unnest(right).column_valued()), + ).scalar_subquery() + ), + 2, + ), + ops.ArrayDistinct: fixed_arity( + lambda arg: sa.case( + (arg.is_(sa.null()), sa.null()), + else_=sa.func.array( + sa.select( + sa.distinct(sa.func.unnest(arg).column_valued()) + ).scalar_subquery() + ), + ), + 1, + ), + ops.ArrayPosition: fixed_arity(_array_position, 2), + ops.ArrayMap: _array_map, + ops.ArrayFilter: _array_filter, + ops.IntegerRange: _range, + ops.TimestampRange: _range, + ops.RegexSplit: fixed_arity(sa.func.regexp_split_to_array, 2), + } +) diff --git a/ibis/backends/risingwave/tests/__init__.py b/ibis/backends/risingwave/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/ibis/backends/risingwave/tests/conftest.py b/ibis/backends/risingwave/tests/conftest.py new file mode 100644 index 000000000000..35cfe6b8e1db --- /dev/null +++ b/ibis/backends/risingwave/tests/conftest.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any + +import pytest +import sqlalchemy as sa + +import ibis +from ibis.backends.conftest import init_database +from ibis.backends.tests.base import ServiceBackendTest + +if TYPE_CHECKING: + from collections.abc import Iterable + from pathlib import Path + +PG_USER = os.environ.get("IBIS_TEST_RISINGWAVE_USER", os.environ.get("PGUSER", "root")) +PG_PASS = os.environ.get( + "IBIS_TEST_RISINGWAVE_PASSWORD", os.environ.get("PGPASSWORD", "") +) +PG_HOST = os.environ.get( + "IBIS_TEST_RISINGWAVE_HOST", os.environ.get("PGHOST", "localhost") +) +PG_PORT = os.environ.get("IBIS_TEST_RISINGWAVE_PORT", os.environ.get("PGPORT", 4566)) +IBIS_TEST_RISINGWAVE_DB = os.environ.get( + "IBIS_TEST_RISINGWAVE_DATABASE", os.environ.get("PGDATABASE", "dev") +) + + +class TestConf(ServiceBackendTest): + # postgres rounds half to even for double precision and half away from zero + # for numeric and decimal + + returned_timestamp_unit = "s" + supports_structs = False + rounding_method = "half_to_even" + service_name = "risingwave" + deps = "psycopg2", "sqlalchemy" + + @property + def test_files(self) -> Iterable[Path]: + return self.data_dir.joinpath("csv").glob("*.csv") + + def _load_data( + self, + *, + user: str = PG_USER, + password: str = PG_PASS, + host: str = PG_HOST, + port: int = PG_PORT, + database: str = IBIS_TEST_RISINGWAVE_DB, + **_: Any, + ) -> None: + """Load test data into a Risingwave backend instance. + + Parameters + ---------- + data_dir + Location of test data + script_dir + Location of scripts defining schemas + """ + init_database( + url=sa.engine.make_url( + f"risingwave://{user}:{password}@{host}:{port:d}/{database}" + ), + database=database, + schema=self.ddl_script, + isolation_level="AUTOCOMMIT", + recreate=False, + ) + + @staticmethod + def connect(*, tmpdir, worker_id, port: int | None = None, **kw): + con = ibis.risingwave.connect( + host=PG_HOST, + port=port or PG_PORT, + user=PG_USER, + password=PG_PASS, + database=IBIS_TEST_RISINGWAVE_DB, + **kw, + ) + cursor = con.raw_sql("SET RW_IMPLICIT_FLUSH TO true;") + cursor.close() + return con + + +@pytest.fixture(scope="session") +def con(tmp_path_factory, data_dir, worker_id): + return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection + + +@pytest.fixture(scope="module") +def db(con): + return con.database() + + +@pytest.fixture(scope="module") +def alltypes(db): + return db.functional_alltypes + + +@pytest.fixture(scope="module") +def df(alltypes): + return alltypes.execute() + + +@pytest.fixture(scope="module") +def alltypes_sqla(con, alltypes): + name = alltypes.op().name + return con._get_sqla_table(name) + + +@pytest.fixture(scope="module") +def intervals(con): + return con.table("intervals") + + +@pytest.fixture +def translate(): + from ibis.backends.risingwave import Backend + + context = Backend.compiler.make_context() + return lambda expr: Backend.compiler.translator_class(expr, context).get_result() diff --git a/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql new file mode 100644 index 000000000000..cfbcf133a863 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_client/test_compile_toplevel/out.sql @@ -0,0 +1,2 @@ +SELECT sum(t0.foo) AS "Sum(foo)" +FROM t0 AS t0 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql new file mode 100644 index 000000000000..c00dec1bed25 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_analytic_functions/out.sql @@ -0,0 +1,7 @@ +SELECT + RANK() OVER (ORDER BY t0.double_col ASC) - 1 AS rank, + DENSE_RANK() OVER (ORDER BY t0.double_col ASC) - 1 AS dense_rank, + CUME_DIST() OVER (ORDER BY t0.double_col ASC) AS cume_dist, + NTILE(7) OVER (ORDER BY t0.double_col ASC) - 1 AS ntile, + PERCENT_RANK() OVER (ORDER BY t0.double_col ASC) AS percent_rank +FROM functional_alltypes AS t0 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql new file mode 100644 index 000000000000..34761d9a76e0 --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/False/out.sql @@ -0,0 +1 @@ +WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION ALL SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION ALL SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql new file mode 100644 index 000000000000..6ce31e7468bb --- /dev/null +++ b/ibis/backends/risingwave/tests/snapshots/test_functions/test_union_cte/True/out.sql @@ -0,0 +1 @@ +WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1 \ No newline at end of file diff --git a/ibis/backends/risingwave/tests/test_client.py b/ibis/backends/risingwave/tests/test_client.py new file mode 100644 index 000000000000..b5c7cfa98560 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_client.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import os + +import pandas as pd +import pytest +from pytest import param + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.types as ir +from ibis.tests.util import assert_equal + +pytest.importorskip("psycopg2") +sa = pytest.importorskip("sqlalchemy") + +from sqlalchemy.dialects import postgresql # noqa: E402 + +RISINGWAVE_TEST_DB = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev") +IBIS_RISINGWAVE_HOST = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost") +IBIS_RISINGWAVE_PORT = os.environ.get("IBIS_TEST_RISINGWAVE_PORT", "4566") +IBIS_RISINGWAVE_USER = os.environ.get("IBIS_TEST_RISINGWAVE_USER", "root") +IBIS_RISINGWAVE_PASS = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD", "") + + +def test_table(alltypes): + assert isinstance(alltypes, ir.Table) + + +def test_array_execute(alltypes): + d = alltypes.limit(10).double_col + s = d.execute() + assert isinstance(s, pd.Series) + assert len(s) == 10 + + +def test_literal_execute(con): + expr = ibis.literal("1234") + result = con.execute(expr) + assert result == "1234" + + +def test_simple_aggregate_execute(alltypes): + d = alltypes.double_col.sum() + v = d.execute() + assert isinstance(v, float) + + +def test_list_tables(con): + assert con.list_tables() + assert len(con.list_tables(like="functional")) == 1 + + +def test_compile_toplevel(snapshot): + t = ibis.table([("foo", "double")], name="t0") + + expr = t.foo.sum() + result = ibis.postgres.compile(expr) + snapshot.assert_match(str(result), "out.sql") + + +def test_list_databases(con): + assert RISINGWAVE_TEST_DB is not None + assert RISINGWAVE_TEST_DB in con.list_databases() + + +def test_schema_type_conversion(con): + typespec = [ + # name, type, nullable + ("jsonb", postgresql.JSONB, True, dt.JSON), + ] + + sqla_types = [] + ibis_types = [] + for name, t, nullable, ibis_type in typespec: + sqla_types.append(sa.Column(name, t, nullable=nullable)) + ibis_types.append((name, ibis_type(nullable=nullable))) + + # Create a table with placeholder stubs for JSON, JSONB, and UUID. + table = sa.Table("tname", sa.MetaData(), *sqla_types) + + # Check that we can correctly create a schema with dt.any for the + # missing types. + schema = con._schema_from_sqla_table(table) + expected = ibis.schema(ibis_types) + + assert_equal(schema, expected) + + +@pytest.mark.parametrize("params", [{}, {"database": RISINGWAVE_TEST_DB}]) +def test_create_and_drop_table(con, temp_table, params): + sch = ibis.schema( + [ + ("first_name", "string"), + ("last_name", "string"), + ("department_name", "string"), + ("salary", "float64"), + ] + ) + + con.create_table(temp_table, schema=sch, **params) + assert con.table(temp_table, **params) is not None + + con.drop_table(temp_table, **params) + + with pytest.raises(sa.exc.NoSuchTableError): + con.table(temp_table, **params) + + +@pytest.mark.parametrize( + ("pg_type", "expected_type"), + [ + param(pg_type, ibis_type, id=pg_type.lower()) + for (pg_type, ibis_type) in [ + ("boolean", dt.boolean), + ("bytea", dt.binary), + ("bigint", dt.int64), + ("smallint", dt.int16), + ("integer", dt.int32), + ("text", dt.string), + ("real", dt.float32), + ("double precision", dt.float64), + ("character varying", dt.string), + ("date", dt.date), + ("time", dt.time), + ("time without time zone", dt.time), + ("timestamp without time zone", dt.timestamp), + ("timestamp with time zone", dt.Timestamp("UTC")), + ("interval", dt.Interval("s")), + ("numeric", dt.decimal), + ("jsonb", dt.json), + ] + ], +) +def test_get_schema_from_query(con, pg_type, expected_type): + name = con._quote(ibis.util.guid()) + with con.begin() as c: + c.exec_driver_sql(f"CREATE TABLE {name} (x {pg_type}, y {pg_type}[])") + expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type))) + result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}") + assert result_schema == expected_schema + with con.begin() as c: + c.exec_driver_sql(f"DROP TABLE {name}") + + +@pytest.mark.xfail(reason="unsupported insert with CTEs") +def test_insert_with_cte(con): + X = con.create_table("X", schema=ibis.schema(dict(id="int")), temp=False) + expr = X.join(X.mutate(a=X["id"] + 1), ["id"]) + Y = con.create_table("Y", expr, temp=False) + assert Y.execute().empty + con.drop_table("Y") + con.drop_table("X") + + +def test_connect_url_with_empty_host(): + con = ibis.connect("risingwave:///dev") + assert con.con.url.host is None diff --git a/ibis/backends/risingwave/tests/test_functions.py b/ibis/backends/risingwave/tests/test_functions.py new file mode 100644 index 000000000000..c8874e390c60 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_functions.py @@ -0,0 +1,1032 @@ +from __future__ import annotations + +import operator +import string +import warnings +from datetime import datetime + +import numpy as np +import pandas as pd +import pandas.testing as tm +import pytest +from pytest import param + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.types as ir +from ibis import config +from ibis import literal as L + +pytest.importorskip("psycopg2") +sa = pytest.importorskip("sqlalchemy") + +from sqlalchemy.dialects import postgresql # noqa: E402 + + +@pytest.mark.parametrize( + ("left_func", "right_func"), + [ + param( + lambda t: t.double_col.cast("int8"), + lambda at: sa.cast(at.c.double_col, sa.SMALLINT), + id="double_to_int8", + ), + param( + lambda t: t.double_col.cast("int16"), + lambda at: sa.cast(at.c.double_col, sa.SMALLINT), + id="double_to_int16", + ), + param( + lambda t: t.string_col.cast("double"), + lambda at: sa.cast(at.c.string_col, postgresql.DOUBLE_PRECISION), + id="string_to_double", + ), + param( + lambda t: t.string_col.cast("float32"), + lambda at: sa.cast(at.c.string_col, postgresql.REAL), + id="string_to_float", + ), + param( + lambda t: t.string_col.cast("decimal"), + lambda at: sa.cast(at.c.string_col, sa.NUMERIC()), + id="string_to_decimal_no_params", + ), + param( + lambda t: t.string_col.cast("decimal(9, 3)"), + lambda at: sa.cast(at.c.string_col, sa.NUMERIC(9, 3)), + id="string_to_decimal_params", + ), + ], +) +def test_cast(alltypes, alltypes_sqla, translate, left_func, right_func): + left = left_func(alltypes) + right = right_func(alltypes_sqla.alias("t0")) + assert str(translate(left.op()).compile()) == str(right.compile()) + + +def test_date_cast(alltypes, alltypes_sqla, translate): + result = alltypes.date_string_col.cast("date") + expected = sa.cast(alltypes_sqla.alias("t0").c.date_string_col, sa.DATE) + assert str(translate(result.op())) == str(expected) + + +@pytest.mark.parametrize( + "column", + [ + "id", + "bool_col", + "tinyint_col", + "smallint_col", + "int_col", + "bigint_col", + "float_col", + "double_col", + "date_string_col", + "string_col", + "timestamp_col", + "year", + "month", + ], +) +def test_noop_cast(alltypes, alltypes_sqla, translate, column): + col = alltypes[column] + result = col.cast(col.type()) + expected = alltypes_sqla.alias("t0").c[column] + assert result.equals(col) + assert str(translate(result.op())) == str(expected) + + +def test_timestamp_cast_noop(alltypes, alltypes_sqla, translate): + # See GH #592 + result1 = alltypes.timestamp_col.cast("timestamp") + result2 = alltypes.int_col.cast("timestamp") + + assert isinstance(result1, ir.TimestampColumn) + assert isinstance(result2, ir.TimestampColumn) + + expected1 = alltypes_sqla.alias("t0").c.timestamp_col + expected2 = sa.cast( + sa.func.to_timestamp(alltypes_sqla.alias("t0").c.int_col), sa.TIMESTAMP() + ) + + assert str(translate(result1.op())) == str(expected1) + assert str(translate(result2.op())) == str(expected2) + + +@pytest.mark.parametrize(("value", "expected"), [(0, None), (5.5, 5.5)]) +def test_nullif_zero(con, value, expected): + assert con.execute(L(value).nullif(0)) == expected + + +@pytest.mark.parametrize(("value", "expected"), [("foo_bar", 7), ("", 0)]) +def test_string_length(con, value, expected): + assert con.execute(L(value).length()) == expected + + +@pytest.mark.parametrize( + ("op", "expected"), + [ + param(operator.methodcaller("left", 3), "foo", id="left"), + param(operator.methodcaller("right", 3), "bar", id="right"), + param(operator.methodcaller("substr", 0, 3), "foo", id="substr_0_3"), + param(operator.methodcaller("substr", 4, 3), "bar", id="substr_4, 3"), + param(operator.methodcaller("substr", 1), "oo_bar", id="substr_1"), + ], +) +def test_string_substring(con, op, expected): + value = L("foo_bar") + assert con.execute(op(value)) == expected + + +@pytest.mark.parametrize( + ("opname", "expected"), + [("lstrip", "foo "), ("rstrip", " foo"), ("strip", "foo")], +) +def test_string_strip(con, opname, expected): + op = operator.methodcaller(opname) + value = L(" foo ") + assert con.execute(op(value)) == expected + + +@pytest.mark.parametrize( + ("opname", "count", "char", "expected"), + [("lpad", 6, " ", " foo"), ("rpad", 6, " ", "foo ")], +) +def test_string_pad(con, opname, count, char, expected): + op = operator.methodcaller(opname, count, char) + value = L("foo") + assert con.execute(op(value)) == expected + + +def test_string_reverse(con): + assert con.execute(L("foo").reverse()) == "oof" + + +def test_string_upper(con): + assert con.execute(L("foo").upper()) == "FOO" + + +def test_string_lower(con): + assert con.execute(L("FOO").lower()) == "foo" + + +@pytest.mark.parametrize( + ("haystack", "needle", "expected"), + [ + ("foobar", "bar", True), + ("foobar", "foo", True), + ("foobar", "baz", False), + ("100%", "%", True), + ("a_b_c", "_", True), + ], +) +def test_string_contains(con, haystack, needle, expected): + value = L(haystack) + expr = value.contains(needle) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("value", "expected"), + [("foo bar foo", "Foo Bar Foo"), ("foobar Foo", "Foobar Foo")], +) +def test_capitalize(con, value, expected): + assert con.execute(L(value).capitalize()) == expected + + +def test_repeat(con): + expr = L("bar ").repeat(3) + assert con.execute(expr) == "bar bar bar " + + +def test_re_replace(con): + expr = L("fudge|||chocolate||candy").re_replace("\\|{2,3}", ", ") + assert con.execute(expr) == "fudge, chocolate, candy" + + +def test_translate(con): + expr = L("faab").translate("a", "b") + assert con.execute(expr) == "fbbb" + + +@pytest.mark.parametrize( + ("raw_value", "expected"), [("a", 0), ("b", 1), ("d", -1), (None, 3)] +) +def test_find_in_set(con, raw_value, expected): + value = L(raw_value, dt.string) + haystack = ["a", "b", "c", None] + expr = value.find_in_set(haystack) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("raw_value", "opname", "expected"), + [ + (None, "isnull", True), + (1, "isnull", False), + (None, "notnull", False), + (1, "notnull", True), + ], +) +def test_isnull_notnull(con, raw_value, opname, expected): + lit = L(raw_value) + op = operator.methodcaller(opname) + expr = op(lit) + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(L("foobar").find("bar"), 3, id="find_pos"), + param(L("foobar").find("baz"), -1, id="find_neg"), + param(L("foobar").like("%bar"), True, id="like_left_pattern"), + param(L("foobar").like("foo%"), True, id="like_right_pattern"), + param(L("foobar").like("%baz%"), False, id="like_both_sides_pattern"), + param(L("foobar").like(["%bar"]), True, id="like_list_left_side"), + param(L("foobar").like(["foo%"]), True, id="like_list_right_side"), + param(L("foobar").like(["%baz%"]), False, id="like_list_both_sides"), + param(L("foobar").like(["%bar", "foo%"]), True, id="like_list_multiple"), + param(L("foobarfoo").replace("foo", "H"), "HbarH", id="replace"), + param(L("a").ascii_str(), ord("a"), id="ascii_str"), + ], +) +def test_string_functions(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(L("abcd").re_search("[a-z]"), True, id="re_search_match"), + param(L("abcd").re_search(r"[\d]+"), False, id="re_search_no_match"), + param(L("1222").re_search(r"[\d]+"), True, id="re_search_match_number"), + ], +) +def test_regexp(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.NA.fillna(5), 5, id="filled"), + param(L(5).fillna(10), 5, id="not_filled"), + param(L(5).nullif(5), None, id="nullif_null"), + param(L(10).nullif(5), 10, id="nullif_not_null"), + ], +) +def test_fillna_nullif(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.coalesce(5, None, 4), 5, id="first"), + param(ibis.coalesce(ibis.NA, 4, ibis.NA), 4, id="second"), + param(ibis.coalesce(ibis.NA, ibis.NA, 3.14), 3.14, id="third"), + ], +) +def test_coalesce(con, expr, expected): + assert con.execute(expr) == expected + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + param(ibis.coalesce(ibis.NA, ibis.NA), None, id="all_null"), + param( + ibis.coalesce( + ibis.NA.cast("int8"), + ibis.NA.cast("int8"), + ibis.NA.cast("int8"), + ), + None, + id="all_nulls_with_all_cast", + ), + ], +) +def test_coalesce_all_na(con, expr, expected): + assert con.execute(expr) is None + + +def test_coalesce_all_na_double(con): + expr = ibis.coalesce(ibis.NA, ibis.NA, ibis.NA.cast("double")) + assert np.isnan(con.execute(expr)) + + +def test_numeric_builtins_work(alltypes, df): + expr = alltypes.double_col.fillna(0) + result = expr.execute() + expected = df.double_col.fillna(0) + expected.name = "Coalesce()" + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("op", "pandas_op"), + [ + param( + lambda t: (t.double_col > 20).ifelse(10, -20), + lambda df: pd.Series(np.where(df.double_col > 20, 10, -20), dtype="int8"), + id="simple", + ), + param( + lambda t: (t.double_col > 20).ifelse(10, -20).abs(), + lambda df: pd.Series( + np.where(df.double_col > 20, 10, -20), dtype="int8" + ).abs(), + id="abs", + ), + ], +) +def test_ifelse(alltypes, df, op, pandas_op): + expr = op(alltypes) + result = expr.execute() + result.name = None + expected = pandas_op(df) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "pandas_func"), + [ + # tier and histogram + param( + lambda d: d.bucket([0, 10, 25, 50, 100]), + lambda s: pd.cut(s, [0, 10, 25, 50, 100], right=False, labels=False).astype( + "int8" + ), + id="include_over_false", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], include_over=True), + lambda s: pd.cut( + s, [0, 10, 25, 50, np.inf], right=False, labels=False + ).astype("int8"), + id="include_over_true", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], close_extreme=False), + lambda s: pd.cut(s, [0, 10, 25, 50], right=False, labels=False), + id="close_extreme_false", + ), + param( + lambda d: d.bucket([0, 10, 25, 50], closed="right", close_extreme=False), + lambda s: pd.cut( + s, + [0, 10, 25, 50], + include_lowest=False, + right=True, + labels=False, + ), + id="closed_right", + ), + param( + lambda d: d.bucket([10, 25, 50, 100], include_under=True), + lambda s: pd.cut(s, [0, 10, 25, 50, 100], right=False, labels=False).astype( + "int8" + ), + id="include_under_true", + ), + ], +) +def test_bucket(alltypes, df, func, pandas_func): + expr = func(alltypes.double_col) + result = expr.execute() + expected = pandas_func(df.double_col) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_category_label(alltypes, df): + t = alltypes + d = t.double_col + + bins = [0, 10, 25, 50, 100] + labels = ["a", "b", "c", "d"] + bucket = d.bucket(bins) + expr = bucket.label(labels) + result = expr.execute() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result = pd.Series(pd.Categorical(result, ordered=True)) + + result.name = "double_col" + + expected = pd.cut(df.double_col, bins, labels=labels, right=False) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("distinct", [True, False]) +def test_union_cte(alltypes, distinct, snapshot): + t = alltypes + expr1 = t.group_by(t.string_col).aggregate(metric=t.double_col.sum()) + expr2 = expr1.view() + expr3 = expr1.view() + expr = expr1.union(expr2, distinct=distinct).union(expr3, distinct=distinct) + result = " ".join( + line.strip() + for line in str( + expr.compile().compile(compile_kwargs={"literal_binds": True}) + ).splitlines() + ) + snapshot.assert_match(result, "out.sql") + + +@pytest.mark.parametrize( + ("func", "pandas_func"), + [ + param( + lambda t, cond: t.bool_col.count(), + lambda df, cond: df.bool_col.count(), + id="count", + ), + param( + lambda t, cond: t.double_col.mean(), + lambda df, cond: df.double_col.mean(), + id="mean", + ), + param( + lambda t, cond: t.double_col.min(), + lambda df, cond: df.double_col.min(), + id="min", + ), + param( + lambda t, cond: t.double_col.max(), + lambda df, cond: df.double_col.max(), + id="max", + ), + param( + lambda t, cond: t.double_col.var(), + lambda df, cond: df.double_col.var(), + id="var", + ), + param( + lambda t, cond: t.double_col.std(), + lambda df, cond: df.double_col.std(), + id="std", + ), + param( + lambda t, cond: t.double_col.var(how="sample"), + lambda df, cond: df.double_col.var(ddof=1), + id="samp_var", + ), + param( + lambda t, cond: t.double_col.std(how="pop"), + lambda df, cond: df.double_col.std(ddof=0), + id="pop_std", + ), + param( + lambda t, cond: t.bool_col.count(where=cond), + lambda df, cond: df.bool_col[cond].count(), + id="count_where", + ), + param( + lambda t, cond: t.double_col.mean(where=cond), + lambda df, cond: df.double_col[cond].mean(), + id="mean_where", + ), + param( + lambda t, cond: t.double_col.min(where=cond), + lambda df, cond: df.double_col[cond].min(), + id="min_where", + ), + param( + lambda t, cond: t.double_col.max(where=cond), + lambda df, cond: df.double_col[cond].max(), + id="max_where", + ), + param( + lambda t, cond: t.double_col.var(where=cond), + lambda df, cond: df.double_col[cond].var(), + id="var_where", + ), + param( + lambda t, cond: t.double_col.std(where=cond), + lambda df, cond: df.double_col[cond].std(), + id="std_where", + ), + param( + lambda t, cond: t.double_col.var(where=cond, how="sample"), + lambda df, cond: df.double_col[cond].var(), + id="samp_var_where", + ), + param( + lambda t, cond: t.double_col.std(where=cond, how="pop"), + lambda df, cond: df.double_col[cond].std(ddof=0), + id="pop_std_where", + ), + ], +) +def test_aggregations(alltypes, df, func, pandas_func): + table = alltypes.limit(100) + df = df.head(table.count().execute()) + + cond = table.string_col.isin(["1", "7"]) + expr = func(table, cond) + result = expr.execute() + expected = pandas_func(df, cond.execute()) + + np.testing.assert_allclose(result, expected) + + +def test_not_contains(alltypes, df): + n = 100 + table = alltypes.limit(n) + expr = table.string_col.notin(["1", "7"]) + result = expr.execute() + expected = ~df.head(n).string_col.isin(["1", "7"]) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_group_concat(alltypes, df): + expr = alltypes.string_col.group_concat() + result = expr.execute() + expected = ",".join(df.string_col.dropna()) + assert result == expected + + +def test_distinct_aggregates(alltypes, df): + expr = alltypes.limit(100).double_col.nunique() + result = expr.execute() + assert result == df.head(100).double_col.nunique() + + +def test_not_exists(alltypes, df): + t = alltypes + t2 = t.view() + + expr = t[~((t.string_col == t2.string_col).any())] + result = expr.execute() + + left, right = df, t2.execute() + expected = left[left.string_col != right.string_col] + + tm.assert_frame_equal(result, expected, check_index_type=False, check_dtype=False) + + +def test_interactive_repr_shows_error(alltypes): + # #591. Doing this in Postgres because so many built-in functions are + # not available + + expr = alltypes.int_col.convert_base(10, 2) + + with config.option_context("interactive", True): + result = repr(expr) + + assert "no translation rule" in result.lower() + + +def test_subquery(alltypes, df): + t = alltypes + + expr = t.mutate(d=t.double_col.fillna(0)).limit(1000).group_by("string_col").size() + result = expr.execute().sort_values("string_col").reset_index(drop=True) + expected = ( + df.assign(d=df.double_col.fillna(0)) + .head(1000) + .groupby("string_col") + .string_col.count() + .rename("CountStar()") + .reset_index() + .sort_values("string_col") + .reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +def test_simple_window(alltypes, func, df): + t = alltypes + f = getattr(t.double_col, func) + df_f = getattr(df.double_col, func) + result = t.select((t.double_col - f()).name("double_col")).execute().double_col + expected = df.double_col - df_f() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_rolling_window(alltypes, func, df): + t = alltypes + df = ( + df[["double_col", "timestamp_col"]] + .sort_values("timestamp_col") + .reset_index(drop=True) + ) + window = ibis.window(order_by=t.timestamp_col, preceding=6, following=0) + f = getattr(t.double_col, func) + df_f = getattr(df.double_col.rolling(7, min_periods=0), func) + result = t.select(f().over(window).name("double_col")).execute().double_col + expected = df_f() + tm.assert_series_equal(result, expected) + + +def test_rolling_window_with_mlb(alltypes): + t = alltypes + window = ibis.trailing_window( + preceding=ibis.rows_with_max_lookback(3, ibis.interval(days=5)), + order_by=t.timestamp_col, + ) + expr = t["double_col"].sum().over(window) + with pytest.raises(NotImplementedError): + expr.execute() + + +@pytest.mark.parametrize("func", ["mean", "sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_partitioned_window(alltypes, func, df): + t = alltypes + window = ibis.window( + group_by=t.string_col, + order_by=t.timestamp_col, + preceding=6, + following=0, + ) + + def roller(func): + def rolled(df): + torder = df.sort_values("timestamp_col") + rolling = torder.double_col.rolling(7, min_periods=0) + return getattr(rolling, func)() + + return rolled + + f = getattr(t.double_col, func) + expr = f().over(window).name("double_col") + result = t.select(expr).execute().double_col + expected = df.groupby("string_col").apply(roller(func)).reset_index(drop=True) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_cumulative_simple_window(alltypes, func, df): + t = alltypes + f = getattr(t.double_col, func) + col = t.double_col - f().over(ibis.cumulative_window()) + expr = t.select(col.name("double_col")) + result = expr.execute().double_col + expected = df.double_col - getattr(df.double_col, "cum%s" % func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["sum", "min", "max"]) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_cumulative_ordered_window(alltypes, func, df): + t = alltypes + df = df.sort_values("timestamp_col").reset_index(drop=True) + window = ibis.cumulative_window(order_by=t.timestamp_col) + f = getattr(t.double_col, func) + expr = t.select((t.double_col - f().over(window)).name("double_col")) + result = expr.execute().double_col + expected = df.double_col - getattr(df.double_col, "cum%s" % func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "shift_amount"), [("lead", -1), ("lag", 1)], ids=["lead", "lag"] +) +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_analytic_shift_functions(alltypes, df, func, shift_amount): + method = getattr(alltypes.double_col, func) + expr = method(1) + result = expr.execute().rename("double_col") + expected = df.double_col.shift(shift_amount) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "expected_index"), [("first", -1), ("last", 0)], ids=["first", "last"] +) +@pytest.mark.xfail(reason="Unsupported expr: (first(t0.double_col) + 1) - 1") +def test_first_last_value(alltypes, df, func, expected_index): + col = alltypes.order_by(ibis.desc(alltypes.string_col)).double_col + method = getattr(col, func) + # test that we traverse into expression trees + expr = (1 + method()) - 1 + result = expr.execute() + expected = df.double_col.iloc[expected_index] + assert result == expected + + +def test_null_column(alltypes): + t = alltypes + nrows = t.count().execute() + expr = t.mutate(na_column=ibis.NA).na_column + result = expr.execute() + tm.assert_series_equal(result, pd.Series([None] * nrows, name="na_column")) + + +@pytest.mark.xfail( + reason="Window function with empty PARTITION BY is not supported yet" +) +def test_window_with_arithmetic(alltypes, df): + t = alltypes + w = ibis.window(order_by=t.timestamp_col) + expr = t.mutate(new_col=ibis.row_number().over(w) / 2) + + df = df[["timestamp_col"]].sort_values("timestamp_col").reset_index(drop=True) + expected = df.assign(new_col=[x / 2.0 for x in range(len(df))]) + result = expr["timestamp_col", "new_col"].execute() + tm.assert_frame_equal(result, expected) + + +def test_anonymous_aggregate(alltypes, df): + t = alltypes + expr = t[t.double_col > t.double_col.mean()] + result = expr.execute() + expected = df[df.double_col > df.double_col.mean()].reset_index(drop=True) + tm.assert_frame_equal(result, expected) + + +@pytest.fixture +def array_types(con): + return con.table("array_types") + + +@pytest.mark.xfail( + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype" +) +def test_array_length(array_types): + expr = array_types.select( + array_types.x.length().name("x_length"), + array_types.y.length().name("y_length"), + array_types.z.length().name("z_length"), + ) + result = expr.execute() + expected = pd.DataFrame( + { + "x_length": [3, 2, 2, 3, 3, 4], + "y_length": [3, 2, 2, 3, 3, 4], + "z_length": [3, 2, 2, 0, None, 4], + } + ) + result_sorted = result.sort_values( + by=["x_length", "y_length", "z_length"], na_position="first" + ).reset_index(drop=True) + expected_sorted = expected.sort_values( + by=["x_length", "y_length", "z_length"], na_position="first" + ).reset_index(drop=True) + tm.assert_frame_equal(result_sorted, expected_sorted) + + +def custom_sort_none_first(arr): + return sorted(arr, key=lambda x: (x is not None, x)) + + +def test_head(con): + t = con.table("functional_alltypes") + result = t.head().execute() + expected = t.limit(5).execute() + tm.assert_frame_equal(result, expected) + + +def test_identical_to(con, df): + # TODO: abstract this testing logic out into parameterized fixtures + t = con.table("functional_alltypes") + dt = df[["tinyint_col", "double_col"]] + expr = t.tinyint_col.identical_to(t.double_col) + result = expr.execute() + expected = (dt.tinyint_col.isnull() & dt.double_col.isnull()) | ( + dt.tinyint_col == dt.double_col + ) + expected.name = result.name + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("opname", ["invert", "neg"]) +def test_not_and_negate_bool(con, opname, df): + op = getattr(operator, opname) + t = con.table("functional_alltypes").limit(10) + expr = t.select(op(t.bool_col).name("bool_col")) + result = expr.execute().bool_col + expected = op(df.head(10).bool_col) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "field", + [ + "tinyint_col", + "smallint_col", + "int_col", + "bigint_col", + "float_col", + "double_col", + "year", + "month", + ], +) +def test_negate_non_boolean(con, field, df): + t = con.table("functional_alltypes").limit(10) + expr = t.select((-t[field]).name(field)) + result = expr.execute()[field] + expected = -df.head(10)[field] + tm.assert_series_equal(result, expected) + + +def test_negate_boolean(con, df): + t = con.table("functional_alltypes").limit(10) + expr = t.select((-t.bool_col).name("bool_col")) + result = expr.execute().bool_col + expected = -df.head(10).bool_col + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("opname", ["sum", "mean", "min", "max", "std", "var"]) +def test_boolean_reduction(alltypes, opname, df): + op = operator.methodcaller(opname) + expr = op(alltypes.bool_col) + result = expr.execute() + assert result == op(df.bool_col) + + +def test_timestamp_with_timezone(con): + t = con.table("tzone") + result = t.ts.execute() + assert str(result.dtype.tz) + + +@pytest.fixture( + params=[ + None, + "UTC", + "America/New_York", + "America/Los_Angeles", + "Europe/Paris", + "Chile/Continental", + "Asia/Tel_Aviv", + "Asia/Tokyo", + "Africa/Nairobi", + "Australia/Sydney", + ] +) +def tz(request): + return request.param + + +@pytest.fixture +def tzone_compute(con, temp_table, tz): + schema = ibis.schema([("ts", dt.Timestamp(tz)), ("b", "double"), ("c", "string")]) + con.create_table(temp_table, schema=schema, temp=False) + t = con.table(temp_table) + + n = 10 + df = pd.DataFrame( + { + "ts": pd.date_range("2017-04-01", periods=n, tz=tz).values, + "b": np.arange(n).astype("float64"), + "c": list(string.ascii_lowercase[:n]), + } + ) + + df.to_sql( + temp_table, + con.con, + index=False, + if_exists="append", + dtype={"ts": sa.TIMESTAMP(timezone=True), "b": sa.FLOAT, "c": sa.TEXT}, + ) + + yield t + con.drop_table(temp_table) + + +def test_ts_timezone_is_preserved(tzone_compute, tz): + assert dt.Timestamp(tz).equals(tzone_compute.ts.type()) + + +def test_timestamp_with_timezone_select(tzone_compute, tz): + ts = tzone_compute.ts.execute() + assert str(getattr(ts.dtype, "tz", None)) == str(tz) + + +@pytest.mark.parametrize( + ("left", "right", "type"), + [ + param( + L("2017-04-01 01:02:33"), + datetime(2017, 4, 1, 1, 3, 34), + dt.timestamp, + id="ibis_timestamp", + ), + param( + datetime(2017, 4, 1, 1, 3, 34), + L("2017-04-01 01:02:33"), + dt.timestamp, + id="python_datetime", + ), + ], +) +@pytest.mark.parametrize("opname", ["eq", "ne", "lt", "le", "gt", "ge"]) +def test_string_temporal_compare(con, opname, left, right, type): + op = getattr(operator, opname) + expr = op(left, right) + result = con.execute(expr) + left_raw = con.execute(L(left).cast(type)) + right_raw = con.execute(L(right).cast(type)) + expected = op(left_raw, right_raw) + assert result == expected + + +@pytest.mark.parametrize( + ("left", "right"), + [ + param( + L("2017-03-31 00:02:33").cast(dt.timestamp), + datetime(2017, 4, 1, 1, 3, 34), + id="ibis_timestamp", + ), + param( + datetime(2017, 3, 31, 0, 2, 33), + L("2017-04-01 01:03:34").cast(dt.timestamp), + id="python_datetime", + ), + ], +) +@pytest.mark.parametrize( + "op", + [ + param( + lambda left, right: ibis.timestamp("2017-04-01 00:02:34").between( + left, right + ), + id="timestamp", + ), + param( + lambda left, right: ( + ibis.timestamp("2017-04-01").cast(dt.date).between(left, right) + ), + id="date", + ), + ], +) +def test_string_temporal_compare_between(con, op, left, right): + expr = op(left, right) + result = con.execute(expr) + assert isinstance(result, (bool, np.bool_)) + assert result + + +@pytest.mark.xfail( + reason="function make_date(integer, integer, integer) does not exist" +) +def test_scalar_parameter(con): + start_string, end_string = "2009-03-01", "2010-07-03" + + start = ibis.param(dt.date) + end = ibis.param(dt.date) + t = con.table("functional_alltypes") + col = t.date_string_col.cast("date") + expr = col.between(start, end).name("res") + expected_expr = col.between(start_string, end_string).name("res") + + result = expr.execute(params={start: start_string, end: end_string}) + expected = expected_expr.execute() + tm.assert_series_equal(result, expected) + + +def test_string_to_binary_cast(con): + t = con.table("functional_alltypes").limit(10) + expr = t.string_col.cast("binary") + result = expr.execute() + name = expr.get_name() + sql_string = ( + f"SELECT decode(string_col, 'escape') AS \"{name}\" " + "FROM functional_alltypes LIMIT 10" + ) + with con.begin() as c: + cur = c.exec_driver_sql(sql_string) + raw_data = [row[0][0] for row in cur] + expected = pd.Series(raw_data, name=name) + tm.assert_series_equal(result, expected) + + +def test_string_to_binary_round_trip(con): + t = con.table("functional_alltypes").limit(10) + expr = t.string_col.cast("binary").cast("string") + result = expr.execute() + name = expr.get_name() + sql_string = ( + "SELECT encode(decode(string_col, 'escape'), 'escape') AS " + f'"{name}"' + "FROM functional_alltypes LIMIT 10" + ) + with con.begin() as c: + cur = c.exec_driver_sql(sql_string) + expected = pd.Series([row[0][0] for row in cur], name=name) + tm.assert_series_equal(result, expected) diff --git a/ibis/backends/risingwave/tests/test_json.py b/ibis/backends/risingwave/tests/test_json.py new file mode 100644 index 000000000000..18edda8e3741 --- /dev/null +++ b/ibis/backends/risingwave/tests/test_json.py @@ -0,0 +1,17 @@ +"""Tests for json data types.""" +from __future__ import annotations + +import json + +import pytest +from pytest import param + +import ibis + + +@pytest.mark.parametrize("data", [param({"status": True}, id="status")]) +def test_json(data, alltypes): + lit = ibis.literal(json.dumps(data), type="json").name("tmp") + expr = alltypes[[alltypes.id, lit]].head(1) + df = expr.execute() + assert df["tmp"].iloc[0] == data diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 90a3850b26b9..c3c98af4657f 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -49,6 +49,7 @@ def mean_udf(s): "bigquery", "datafusion", "postgres", + "risingwave", "clickhouse", "impala", "duckdb", @@ -205,6 +206,7 @@ def test_aggregate_grouped(backend, alltypes, df, result_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -518,39 +520,51 @@ def mean_and_std(v): lambda t, where: t.double_col.arbitrary(where=where), lambda t, where: t.double_col[where].iloc[0], id="arbitrary_default", - marks=pytest.mark.notimpl( - [ - "impala", - "mysql", - "polars", - "datafusion", - "mssql", - "druid", - "oracle", - "exasol", - "flink", - ], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "mysql", + "polars", + "datafusion", + "mssql", + "druid", + "oracle", + "exasol", + "flink", + ], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.arbitrary(how="first", where=where), lambda t, where: t.double_col[where].iloc[0], id="arbitrary_first", - marks=pytest.mark.notimpl( - [ - "impala", - "mysql", - "polars", - "datafusion", - "mssql", - "druid", - "oracle", - "exasol", - "flink", - ], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "mysql", + "polars", + "datafusion", + "mssql", + "druid", + "oracle", + "exasol", + "flink", + ], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.arbitrary(how="last", where=where), @@ -576,6 +590,10 @@ def mean_and_std(v): raises=com.UnsupportedOperationError, reason="backend only supports the `first` option for `.arbitrary()", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), ], ), param( @@ -602,7 +620,14 @@ def mean_and_std(v): raises=com.OperationNotDefinedError, ), pytest.mark.notimpl( - ["bigquery", "duckdb", "postgres", "pyspark", "trino"], + [ + "bigquery", + "duckdb", + "postgres", + "risingwave", + "pyspark", + "trino", + ], raises=com.UnsupportedOperationError, reason="how='heavy' not supported in the backend", ), @@ -617,19 +642,31 @@ def mean_and_std(v): lambda t, where: t.double_col.first(where=where), lambda t, where: t.double_col[where].iloc[0], id="first", - marks=pytest.mark.notimpl( - ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.double_col.last(where=where), lambda t, where: t.double_col[where].iloc[-1], id="last", - marks=pytest.mark.notimpl( - ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], - raises=com.OperationNotDefinedError, - ), + marks=[ + pytest.mark.notimpl( + ["dask", "druid", "impala", "mssql", "mysql", "oracle", "flink"], + raises=com.OperationNotDefinedError, + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + ), + ], ), param( lambda t, where: t.bigint_col.bit_and(where=where), @@ -899,6 +936,11 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond): reason="backend doesn't implement approximate quantiles yet", raises=com.OperationNotDefinedError, ), + pytest.mark.broken( + ["risingwave"], + reason="Invalid input syntax: direct arg in `percentile_cont` must be castable to float64", + raises=sa.exc.InternalError, + ), ], ), ], @@ -947,6 +989,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -962,6 +1009,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -982,6 +1034,16 @@ def test_quantile( raises=(ValueError, AttributeError), reason="ClickHouse only implements `sample` correlation coefficient", ), + pytest.mark.notyet( + ["pyspark"], + raises=ValueError, + reason="PySpark only implements sample correlation", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1008,7 +1070,7 @@ def test_quantile( reason="Correlation with how='sample' is not supported.", ), pytest.mark.notyet( - ["oracle"], + ["oracle", "risingwave"], raises=ValueError, reason="XXXXSQLExprTranslator only implements population correlation coefficient", ), @@ -1032,6 +1094,11 @@ def test_quantile( ["mysql", "impala", "sqlite", "flink"], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), param( @@ -1056,6 +1123,16 @@ def test_quantile( raises=ValueError, reason="ClickHouse only implements `sample` correlation coefficient", ), + pytest.mark.notyet( + ["pyspark"], + raises=ValueError, + reason="PySpark only implements sample correlation", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function covar_pop(integer, integer) does not exist", + ), ], ), ], @@ -1374,6 +1451,7 @@ def test_topk_filter_op(con, alltypes, df, result_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -1414,6 +1492,7 @@ def test_aggregate_list_like(backend, alltypes, df, agg_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index a784eddf3c44..5a49efb57ba5 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -9,6 +9,7 @@ import pandas.testing as tm import pytest import pytz +import sqlalchemy as sa import toolz from pytest import param @@ -131,6 +132,11 @@ def test_array_concat_variadic(con): # Issues #2370 @pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError) @pytest.mark.notyet(["trino"], raises=TrinoUserError) +@pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: cannot determine type of empty array", +) def test_array_concat_some_empty(con): left = ibis.literal([]) right = ibis.literal([2, 1]) @@ -209,6 +215,11 @@ def test_array_index(con, idx): reason="backend does not support nullable nested types", raises=AssertionError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) @pytest.mark.never( ["bigquery"], reason="doesn't support arrays of arrays", raises=AssertionError ) @@ -240,6 +251,11 @@ def test_array_discovery(backend): ) @pytest.mark.notimpl(["dask"], raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_simple(backend): array_types = backend.array_types expected = ( @@ -257,6 +273,11 @@ def test_unnest_simple(backend): @builtin_array @pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_complex(backend): array_types = backend.array_types df = array_types.execute() @@ -295,6 +316,11 @@ def test_unnest_complex(backend): ) @pytest.mark.notimpl(["dask"], raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_idempotent(backend): array_types = backend.array_types df = array_types.execute() @@ -316,6 +342,11 @@ def test_unnest_idempotent(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_no_nulls(backend): array_types = backend.array_types df = array_types.execute() @@ -343,6 +374,11 @@ def test_unnest_no_nulls(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_unnest_default_name(backend): array_types = backend.array_types df = array_types.execute() @@ -378,6 +414,11 @@ def test_unnest_default_name(backend): ["datafusion", "flink"], raises=Exception, reason="array_types table isn't defined" ) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_array_slice(backend, start, stop): array_types = backend.array_types expr = array_types.select(sliced=array_types.y[start:stop]) @@ -392,6 +433,11 @@ def test_array_slice(backend, start, stop): @pytest.mark.notimpl( ["datafusion", "polars", "snowflake", "sqlite"], raises=com.OperationNotDefinedError ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.notimpl( ["dask", "pandas"], raises=com.OperationNotDefinedError, @@ -419,6 +465,7 @@ def test_array_slice(backend, start, stop): ], ) def test_array_map(con, input, output): + t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) expected = pd.DataFrame(output) @@ -478,6 +525,11 @@ def test_array_filter(con, input, output): @builtin_array @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="ValueError: Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_array_contains(backend, con): t = backend.array_types expr = t.x.contains(1) @@ -501,6 +553,11 @@ def test_array_position(backend, con): @builtin_array @pytest.mark.notimpl(["dask", "polars"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) def test_array_remove(con): t = ibis.memtable({"a": [[3, 2], [], [42, 2], [2, 2], []]}) expr = t.a.remove(2) @@ -531,6 +588,11 @@ def test_array_remove(con): raises=(AssertionError, GoogleBadRequest), reason="bigquery doesn't support null elements in arrays", ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.parametrize( ("input", "expected"), [ @@ -558,6 +620,11 @@ def test_array_unique(con, input, expected): @pytest.mark.notimpl( ["dask", "datafusion", "polars"], raises=com.OperationNotDefinedError ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14735", +) def test_array_sort(backend, con): t = ibis.memtable({"a": [[3, 2], [], [42, 42], []], "id": range(4)}) expr = t.mutate(a=t.a.sort()).order_by("id") @@ -593,6 +660,11 @@ def test_array_union(con): @pytest.mark.notimpl( ["sqlite"], raises=com.UnsupportedBackendType, reason="Unsupported type: Array..." ) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="TODO(Kexiang): seems a bug", +) @pytest.mark.parametrize( "data", [ @@ -631,6 +703,7 @@ def test_array_intersect(con, data): reason="ClickHouse won't accept dicts for struct type values", ) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) +@pytest.mark.notimpl(["risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["trino"], reason="inserting maps into structs doesn't work", raises=TrinoUserError @@ -648,9 +721,23 @@ def test_unnest_struct(con): @builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "druid", "oracle", "pandas", "polars", "postgres"], + [ + "dask", + "datafusion", + "druid", + "oracle", + "pandas", + "polars", + "postgres", + "risingwave", + ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="Do not nest ARRAY types; ARRAY(basetype) handles multi-dimensional arrays of basetype", +) def test_zip(backend): t = backend.array_types @@ -676,6 +763,7 @@ def test_zip(backend): reason="https://github.com/ClickHouse/ClickHouse/issues/41112", ) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) +@pytest.mark.notimpl(["risingwave"], raises=sa.exc.ProgrammingError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["polars"], @@ -733,7 +821,7 @@ def flatten_data(): ["bigquery"], reason="BigQuery doesn't support arrays of arrays", raises=TypeError ) @pytest.mark.notyet( - ["postgres"], + ["postgres", "risingwave"], reason="Postgres doesn't truly support arrays of arrays", raises=(com.OperationNotDefinedError, PsycoPg2IndeterminateDatatype), ) @@ -804,6 +892,7 @@ def test_range_single_argument(con, n): ) @pytest.mark.parametrize("n", [-2, 0, 2]) @pytest.mark.notimpl(["polars", "flink", "dask"], raises=com.OperationNotDefinedError) +@pytest.mark.skip("risingwave") def test_range_single_argument_unnest(backend, con, n): expr = ibis.range(n).unnest() result = con.execute(expr) @@ -850,6 +939,11 @@ def test_range_start_stop_step(con, start, stop, step): ["datafusion"], raises=com.OperationNotDefinedError, reason="not supported upstream" ) @pytest.mark.notimpl(["flink", "dask"], raises=com.OperationNotDefinedError) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Invalid parameter step: step size cannot equal zero", +) def test_range_start_stop_step_zero(con, start, stop): expr = ibis.range(start, stop, 0) result = con.execute(expr) @@ -957,6 +1051,11 @@ def swap(token): ibis.interval(hours=1), "1H", id="pos", + marks=pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_interval() does not exist", + ), ), param( datetime(2017, 1, 2), @@ -967,7 +1066,12 @@ def swap(token): marks=[ pytest.mark.broken( ["polars"], raises=AssertionError, reason="returns an empty array" - ) + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), param( @@ -983,6 +1087,11 @@ def swap(token): ["clickhouse", "snowflake"], raises=com.UnsupportedOperationError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), ], @@ -1008,7 +1117,14 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): datetime(2017, 1, 2, tzinfo=pytz.UTC), ibis.interval(hours=0), id="pos", - marks=[pytest.mark.notyet(["polars"], raises=PolarsComputeError)], + marks=[ + pytest.mark.notyet(["polars"], raises=PolarsComputeError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_interval() does not exist", + ), + ], ), param( datetime(2017, 1, 1, tzinfo=pytz.UTC), @@ -1022,6 +1138,11 @@ def test_timestamp_range(con, start, stop, step, freq, tzinfo): ["clickhouse", "snowflake"], raises=com.UnsupportedOperationError, ), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function neg(interval) does not exist", + ), ], ), ], @@ -1051,3 +1172,19 @@ def test_repr_timestamp_array(con, monkeypatch): expr = ibis.array(pd.date_range("2010-01-01", "2010-01-03", freq="D").tolist()) assert "Translation to backend failed" not in repr(expr) + + +@pytest.mark.notyet( + ["dask", "datafusion", "flink", "pandas", "polars"], + raises=com.OperationNotDefinedError, +) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.OperationalError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14734", +) +def test_unnest_range(con): + expr = ibis.range(2).unnest().name("x").as_table().mutate({"y": 1.0}) + result = con.execute(expr) + expected = pd.DataFrame({"x": np.array([0, 1], dtype="int8"), "y": [1.0, 1.0]}) + tm.assert_frame_equal(result, expected) diff --git a/ibis/backends/tests/test_benchmarks.py b/ibis/backends/tests/test_benchmarks.py deleted file mode 100644 index 4805a5ccf5bc..000000000000 --- a/ibis/backends/tests/test_benchmarks.py +++ /dev/null @@ -1,900 +0,0 @@ -from __future__ import annotations - -import copy -import functools -import inspect -import itertools -import os -import string - -import numpy as np -import pandas as pd -import pytest -import sqlalchemy as sa -from packaging.version import parse as vparse - -import ibis -import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -import ibis.expr.types as ir -from ibis.backends.base import _get_backend_names - -# from ibis.backends.pandas.udf import udf - -# FIXME(kszucs): pytestmark = pytest.mark.benchmark -pytestmark = pytest.mark.skip(reason="the backends must be rewritten first") - - -def make_t(): - return ibis.table( - [ - ("_timestamp", "int32"), - ("dim1", "int32"), - ("dim2", "int32"), - ("valid_seconds", "int32"), - ("meas1", "int32"), - ("meas2", "int32"), - ("year", "int32"), - ("month", "int32"), - ("day", "int32"), - ("hour", "int32"), - ("minute", "int32"), - ], - name="t", - ) - - -@pytest.fixture(scope="module") -def t(): - return make_t() - - -def make_base(t): - return t[ - ( - (t.year > 2016) - | ((t.year == 2016) & (t.month > 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day > 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day == 6) & (t.hour > 6)) - | ( - (t.year == 2016) - & (t.month == 6) - & (t.day == 6) - & (t.hour == 6) - & (t.minute >= 5) - ) - ) - & ( - (t.year < 2016) - | ((t.year == 2016) & (t.month < 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day < 6)) - | ((t.year == 2016) & (t.month == 6) & (t.day == 6) & (t.hour < 6)) - | ( - (t.year == 2016) - & (t.month == 6) - & (t.day == 6) - & (t.hour == 6) - & (t.minute <= 5) - ) - ) - ] - - -@pytest.fixture(scope="module") -def base(t): - return make_base(t) - - -def make_large_expr(base): - src_table = base - src_table = src_table.mutate( - _timestamp=(src_table["_timestamp"] - src_table["_timestamp"] % 3600) - .cast("int32") - .name("_timestamp"), - valid_seconds=300, - ) - - aggs = [] - for meas in ["meas1", "meas2"]: - aggs.append(src_table[meas].sum().cast("float").name(meas)) - src_table = src_table.aggregate( - aggs, by=["_timestamp", "dim1", "dim2", "valid_seconds"] - ) - - part_keys = ["year", "month", "day", "hour", "minute"] - ts_col = src_table["_timestamp"].cast("timestamp") - new_cols = {} - for part_key in part_keys: - part_col = getattr(ts_col, part_key)() - new_cols[part_key] = part_col - src_table = src_table.mutate(**new_cols) - return src_table[ - [ - "_timestamp", - "dim1", - "dim2", - "meas1", - "meas2", - "year", - "month", - "day", - "hour", - "minute", - ] - ] - - -@pytest.fixture(scope="module") -def large_expr(base): - return make_large_expr(base) - - -@pytest.mark.benchmark(group="construction") -@pytest.mark.parametrize( - "construction_fn", - [ - pytest.param(lambda *_: make_t(), id="small"), - pytest.param(lambda t, *_: make_base(t), id="medium"), - pytest.param(lambda _, base: make_large_expr(base), id="large"), - ], -) -def test_construction(benchmark, construction_fn, t, base): - benchmark(construction_fn, t, base) - - -@pytest.mark.benchmark(group="builtins") -@pytest.mark.parametrize( - "expr_fn", - [ - pytest.param(lambda t, _base, _large_expr: t, id="small"), - pytest.param(lambda _t, base, _large_expr: base, id="medium"), - pytest.param(lambda _t, _base, large_expr: large_expr, id="large"), - ], -) -@pytest.mark.parametrize("builtin", [hash, str]) -def test_builtins(benchmark, expr_fn, builtin, t, base, large_expr): - expr = expr_fn(t, base, large_expr) - benchmark(builtin, expr) - - -_backends = set(_get_backend_names()) -# compile is a no-op -_backends.remove("pandas") - -_XFAIL_COMPILE_BACKENDS = {"dask", "pyspark", "polars"} - - -@pytest.mark.benchmark(group="compilation") -@pytest.mark.parametrize( - "module", - [ - pytest.param( - mod, - marks=pytest.mark.xfail( - condition=mod in _XFAIL_COMPILE_BACKENDS, - reason=f"{mod} backend doesn't support compiling UnboundTable", - ), - ) - for mod in _backends - ], -) -@pytest.mark.parametrize( - "expr_fn", - [ - pytest.param(lambda t, _base, _large_expr: t, id="small"), - pytest.param(lambda _t, base, _large_expr: base, id="medium"), - pytest.param(lambda _t, _base, large_expr: large_expr, id="large"), - ], -) -def test_compile(benchmark, module, expr_fn, t, base, large_expr): - try: - mod = getattr(ibis, module) - except (AttributeError, ImportError) as e: - pytest.skip(str(e)) - else: - expr = expr_fn(t, base, large_expr) - try: - benchmark(mod.compile, expr) - except (sa.exc.NoSuchModuleError, ImportError) as e: # delayed imports - pytest.skip(str(e)) - - -@pytest.fixture(scope="module") -def pt(): - n = 60_000 - data = pd.DataFrame( - { - "key": np.random.choice(16000, size=n), - "low_card_key": np.random.choice(30, size=n), - "value": np.random.rand(n), - "timestamps": pd.date_range( - start="2023-05-05 16:37:57", periods=n, freq="s" - ).values, - "timestamp_strings": pd.date_range( - start="2023-05-05 16:37:39", periods=n, freq="s" - ).values.astype(str), - "repeated_timestamps": pd.date_range(start="2018-09-01", periods=30).repeat( - int(n / 30) - ), - } - ) - - return ibis.pandas.connect(dict(df=data)).table("df") - - -def high_card_group_by(t): - return t.group_by(t.key).aggregate(avg_value=t.value.mean()) - - -def cast_to_dates(t): - return t.timestamps.cast(dt.date) - - -def cast_to_dates_from_strings(t): - return t.timestamp_strings.cast(dt.date) - - -def multikey_group_by_with_mutate(t): - return ( - t.mutate(dates=t.timestamps.cast("date")) - .group_by(["low_card_key", "dates"]) - .aggregate(avg_value=lambda t: t.value.mean()) - ) - - -def simple_sort(t): - return t.order_by([t.key]) - - -def simple_sort_projection(t): - return t[["key", "value"]].order_by(["key"]) - - -def multikey_sort(t): - return t.order_by(["low_card_key", "key"]) - - -def multikey_sort_projection(t): - return t[["low_card_key", "key", "value"]].order_by(["low_card_key", "key"]) - - -def low_card_rolling_window(t): - return ibis.trailing_range_window( - ibis.interval(days=2), - order_by=t.repeated_timestamps, - group_by=t.low_card_key, - ) - - -def low_card_grouped_rolling(t): - return t.value.mean().over(low_card_rolling_window(t)) - - -def high_card_rolling_window(t): - return ibis.trailing_range_window( - ibis.interval(days=2), - order_by=t.repeated_timestamps, - group_by=t.key, - ) - - -def high_card_grouped_rolling(t): - return t.value.mean().over(high_card_rolling_window(t)) - - -# @udf.reduction(["double"], "double") -# def my_mean(series): -# return series.mean() - - -def low_card_grouped_rolling_udf_mean(t): - return my_mean(t.value).over(low_card_rolling_window(t)) - - -def high_card_grouped_rolling_udf_mean(t): - return my_mean(t.value).over(high_card_rolling_window(t)) - - -# @udf.analytic(["double"], "double") -# def my_zscore(series): -# return (series - series.mean()) / series.std() - - -def low_card_window(t): - return ibis.window(group_by=t.low_card_key) - - -def high_card_window(t): - return ibis.window(group_by=t.key) - - -def low_card_window_analytics_udf(t): - return my_zscore(t.value).over(low_card_window(t)) - - -def high_card_window_analytics_udf(t): - return my_zscore(t.value).over(high_card_window(t)) - - -# @udf.reduction(["double", "double"], "double") -# def my_wm(v, w): -# return np.average(v, weights=w) - - -def low_card_grouped_rolling_udf_wm(t): - return my_wm(t.value, t.value).over(low_card_rolling_window(t)) - - -def high_card_grouped_rolling_udf_wm(t): - return my_wm(t.value, t.value).over(low_card_rolling_window(t)) - - -broken_pandas_grouped_rolling = pytest.mark.xfail( - condition=vparse("1.4") <= vparse(pd.__version__) < vparse("1.4.2"), - raises=ValueError, - reason="https://github.com/pandas-dev/pandas/pull/44068", -) - - -@pytest.mark.benchmark(group="execution") -@pytest.mark.parametrize( - "expression_fn", - [ - pytest.param(high_card_group_by, id="high_card_group_by"), - pytest.param(cast_to_dates, id="cast_to_dates"), - pytest.param(cast_to_dates_from_strings, id="cast_to_dates_from_strings"), - pytest.param(multikey_group_by_with_mutate, id="multikey_group_by_with_mutate"), - pytest.param(simple_sort, id="simple_sort"), - pytest.param(simple_sort_projection, id="simple_sort_projection"), - pytest.param(multikey_sort, id="multikey_sort"), - pytest.param(multikey_sort_projection, id="multikey_sort_projection"), - pytest.param( - low_card_grouped_rolling, - id="low_card_grouped_rolling", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling, - id="high_card_grouped_rolling", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - low_card_grouped_rolling_udf_mean, - id="low_card_grouped_rolling_udf_mean", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling_udf_mean, - id="high_card_grouped_rolling_udf_mean", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param(low_card_window_analytics_udf, id="low_card_window_analytics_udf"), - pytest.param( - high_card_window_analytics_udf, id="high_card_window_analytics_udf" - ), - pytest.param( - low_card_grouped_rolling_udf_wm, - id="low_card_grouped_rolling_udf_wm", - marks=[broken_pandas_grouped_rolling], - ), - pytest.param( - high_card_grouped_rolling_udf_wm, - id="high_card_grouped_rolling_udf_wm", - marks=[broken_pandas_grouped_rolling], - ), - ], -) -def test_execute(benchmark, expression_fn, pt): - expr = expression_fn(pt) - benchmark(expr.execute) - - -@pytest.fixture(scope="module") -def part(): - return ibis.table( - dict( - p_partkey="int64", - p_size="int64", - p_type="string", - p_mfgr="string", - ), - name="part", - ) - - -@pytest.fixture(scope="module") -def supplier(): - return ibis.table( - dict( - s_suppkey="int64", - s_nationkey="int64", - s_name="string", - s_acctbal="decimal(15, 3)", - s_address="string", - s_phone="string", - s_comment="string", - ), - name="supplier", - ) - - -@pytest.fixture(scope="module") -def partsupp(): - return ibis.table( - dict( - ps_partkey="int64", - ps_suppkey="int64", - ps_supplycost="decimal(15, 3)", - ), - name="partsupp", - ) - - -@pytest.fixture(scope="module") -def nation(): - return ibis.table( - dict(n_nationkey="int64", n_regionkey="int64", n_name="string"), - name="nation", - ) - - -@pytest.fixture(scope="module") -def region(): - return ibis.table(dict(r_regionkey="int64", r_name="string"), name="region") - - -@pytest.fixture(scope="module") -def tpc_h02(part, supplier, partsupp, nation, region): - REGION = "EUROPE" - SIZE = 25 - TYPE = "BRASS" - - expr = ( - part.join(partsupp, part.p_partkey == partsupp.ps_partkey) - .join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) - .join(nation, supplier.s_nationkey == nation.n_nationkey) - .join(region, nation.n_regionkey == region.r_regionkey) - ) - - subexpr = ( - partsupp.join(supplier, supplier.s_suppkey == partsupp.ps_suppkey) - .join(nation, supplier.s_nationkey == nation.n_nationkey) - .join(region, nation.n_regionkey == region.r_regionkey) - ) - - subexpr = subexpr[ - (subexpr.r_name == REGION) & (expr.p_partkey == subexpr.ps_partkey) - ] - - filters = [ - expr.p_size == SIZE, - expr.p_type.like(f"%{TYPE}"), - expr.r_name == REGION, - expr.ps_supplycost == subexpr.ps_supplycost.min(), - ] - q = expr.filter(filters) - - q = q.select( - [ - q.s_acctbal, - q.s_name, - q.n_name, - q.p_partkey, - q.p_mfgr, - q.s_address, - q.s_phone, - q.s_comment, - ] - ) - - return q.order_by( - [ - ibis.desc(q.s_acctbal), - q.n_name, - q.s_name, - q.p_partkey, - ] - ).limit(100) - - -@pytest.mark.benchmark(group="repr") -def test_repr_tpc_h02(benchmark, tpc_h02): - benchmark(repr, tpc_h02) - - -@pytest.mark.benchmark(group="repr") -def test_repr_huge_union(benchmark): - n = 10 - raw_types = [ - "int64", - "float64", - "string", - "array, b: map>>>", - ] - tables = [ - ibis.table( - list(zip(string.ascii_letters, itertools.cycle(raw_types))), - name=f"t{i:d}", - ) - for i in range(n) - ] - expr = functools.reduce(ir.Table.union, tables) - benchmark(repr, expr) - - -@pytest.mark.benchmark(group="node_args") -def test_op_argnames(benchmark): - t = ibis.table([("a", "int64")]) - expr = t[["a"]] - benchmark(lambda op: op.argnames, expr.op()) - - -@pytest.mark.benchmark(group="node_args") -def test_op_args(benchmark): - t = ibis.table([("a", "int64")]) - expr = t[["a"]] - benchmark(lambda op: op.args, expr.op()) - - -@pytest.mark.benchmark(group="datatype") -def test_complex_datatype_parse(benchmark): - type_str = "array, b: map>>>" - expected = dt.Array( - dt.Struct(dict(a=dt.Array(dt.string), b=dt.Map(dt.string, dt.Array(dt.int64)))) - ) - assert dt.parse(type_str) == expected - benchmark(dt.parse, type_str) - - -@pytest.mark.benchmark(group="datatype") -@pytest.mark.parametrize("func", [str, hash]) -def test_complex_datatype_builtins(benchmark, func): - datatype = dt.Array( - dt.Struct(dict(a=dt.Array(dt.string), b=dt.Map(dt.string, dt.Array(dt.int64)))) - ) - benchmark(func, datatype) - - -@pytest.mark.benchmark(group="equality") -def test_large_expr_equals(benchmark, tpc_h02): - benchmark(ir.Expr.equals, tpc_h02, copy.deepcopy(tpc_h02)) - - -@pytest.mark.benchmark(group="datatype") -@pytest.mark.parametrize( - "dtypes", - [ - pytest.param( - [ - obj - for _, obj in inspect.getmembers( - dt, - lambda obj: isinstance(obj, dt.DataType), - ) - ], - id="singletons", - ), - pytest.param( - dt.Array( - dt.Struct( - dict( - a=dt.Array(dt.string), - b=dt.Map(dt.string, dt.Array(dt.int64)), - ) - ) - ), - id="complex", - ), - ], -) -def test_eq_datatypes(benchmark, dtypes): - def eq(a, b): - assert a == b - - benchmark(eq, dtypes, copy.deepcopy(dtypes)) - - -def multiple_joins(table, num_joins): - for _ in range(num_joins): - table = table.mutate(dummy=ibis.literal("")) - table = table.left_join(table, ["dummy"])[[table]] - - -@pytest.mark.parametrize("num_joins", [1, 10]) -@pytest.mark.parametrize("num_columns", [1, 10, 100]) -def test_multiple_joins(benchmark, num_joins, num_columns): - table = ibis.table( - {f"col_{i:d}": "string" for i in range(num_columns)}, - name="t", - ) - benchmark(multiple_joins, table, num_joins) - - -@pytest.fixture -def customers(): - return ibis.table( - dict( - customerid="int32", - name="string", - address="string", - citystatezip="string", - birthdate="date", - phone="string", - timezone="string", - lat="float64", - long="float64", - ), - name="customers", - ) - - -@pytest.fixture -def orders(): - return ibis.table( - dict( - orderid="int32", - customerid="int32", - ordered="timestamp", - shipped="timestamp", - items="string", - total="float64", - ), - name="orders", - ) - - -@pytest.fixture -def orders_items(): - return ibis.table( - dict(orderid="int32", sku="string", qty="int32", unit_price="float64"), - name="orders_items", - ) - - -@pytest.fixture -def products(): - return ibis.table( - dict( - sku="string", - desc="string", - weight_kg="float64", - cost="float64", - dims_cm="string", - ), - name="products", - ) - - -@pytest.mark.benchmark(group="compilation") -@pytest.mark.parametrize( - "module", - [ - pytest.param( - mod, - marks=pytest.mark.xfail( - condition=mod in _XFAIL_COMPILE_BACKENDS, - reason=f"{mod} backend doesn't support compiling UnboundTable", - ), - ) - for mod in _backends - ], -) -def test_compile_with_drops( - benchmark, module, customers, orders, orders_items, products -): - expr = ( - customers.join(orders, "customerid") - .join(orders_items, "orderid") - .join(products, "sku") - .drop("customerid", "qty", "total", "items") - .drop("dims_cm", "cost") - .mutate(o_date=lambda t: t.shipped.date()) - .filter(lambda t: t.ordered == t.shipped) - ) - - try: - mod = getattr(ibis, module) - except (AttributeError, ImportError) as e: - pytest.skip(str(e)) - else: - try: - benchmark(mod.compile, expr) - except sa.exc.NoSuchModuleError as e: - pytest.skip(str(e)) - - -def test_repr_join(benchmark, customers, orders, orders_items, products): - expr = ( - customers.join(orders, "customerid") - .join(orders_items, "orderid") - .join(products, "sku") - .drop("customerid", "qty", "total", "items") - ) - op = expr.op() - benchmark(repr, op) - - -@pytest.mark.parametrize("overwrite", [True, False], ids=["overwrite", "no_overwrite"]) -def test_insert_duckdb(benchmark, overwrite, tmp_path): - pytest.importorskip("duckdb") - - n_rows = int(1e4) - table_name = "t" - schema = ibis.schema(dict(a="int64", b="int64", c="int64")) - t = ibis.memtable(dict.fromkeys(list("abc"), range(n_rows)), schema=schema) - - con = ibis.duckdb.connect(tmp_path / "test_insert.ddb") - con.create_table(table_name, schema=schema) - benchmark(con.insert, table_name, t, overwrite=overwrite) - - -def test_snowflake_medium_sized_to_pandas(benchmark): - pytest.importorskip("snowflake.connector") - - if (url := os.environ.get("SNOWFLAKE_URL")) is None: - pytest.skip("SNOWFLAKE_URL environment variable not set") - - con = ibis.connect(url) - - # LINEITEM at scale factor 1 is around 6MM rows, but we limit to 1,000,000 - # to make the benchmark fast enough for development, yet large enough to show a - # difference if there's a performance hit - lineitem = con.table("LINEITEM", schema="SNOWFLAKE_SAMPLE_DATA.TPCH_SF1").limit( - 1_000_000 - ) - - benchmark.pedantic(lineitem.to_pandas, rounds=5, iterations=1, warmup_rounds=1) - - -def test_parse_many_duckdb_types(benchmark): - parse = pytest.importorskip("ibis.backends.duckdb.datatypes").DuckDBType.from_string - - def parse_many(types): - list(map(parse, types)) - - types = ["VARCHAR", "INTEGER", "DOUBLE", "BIGINT"] * 1000 - benchmark(parse_many, types) - - -@pytest.fixture(scope="session") -def sql() -> str: - return """ - SELECT t1.id as t1_id, x, t2.id as t2_id, y - FROM t1 INNER JOIN t2 - ON t1.id = t2.id - """ - - -@pytest.fixture(scope="session") -def ddb(tmp_path_factory): - duckdb = pytest.importorskip("duckdb") - - N = 20_000_000 - - con = duckdb.connect() - - path = str(tmp_path_factory.mktemp("duckdb") / "data.ddb") - sql = ( - lambda var, table, n=N: f""" - CREATE TABLE {table} AS - SELECT ROW_NUMBER() OVER () AS id, {var} - FROM ( - SELECT {var} - FROM RANGE({n}) _ ({var}) - ORDER BY RANDOM() - ) - """ - ) - - with duckdb.connect(path) as con: - con.execute(sql("x", table="t1")) - con.execute(sql("y", table="t2")) - return path - - -def test_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: - # yes, we're benchmarking duckdb here, not ibis - # - # we do this to get a baseline for comparison - duckdb = pytest.importorskip("duckdb") - con = duckdb.connect(ddb, read_only=True) - - benchmark(lambda sql: con.sql(sql).to_arrow_table(), sql) - - -def test_ibis_duckdb_to_pyarrow(benchmark, sql, ddb) -> None: - pytest.importorskip("duckdb") - - con = ibis.duckdb.connect(ddb, read_only=True) - - expr = con.sql(sql) - benchmark(expr.to_pyarrow) - - -@pytest.fixture -def diffs(): - return ibis.table( - { - "id": "int64", - "validation_name": "string", - "difference": "float64", - "pct_difference": "float64", - "pct_threshold": "float64", - "validation_status": "string", - }, - name="diffs", - ) - - -@pytest.fixture -def srcs(): - return ibis.table( - { - "id": "int64", - "validation_name": "string", - "validation_type": "string", - "aggregation_type": "string", - "table_name": "string", - "column_name": "string", - "primary_keys": "string", - "num_random_rows": "string", - "agg_value": "float64", - }, - name="srcs", - ) - - -@pytest.fixture -def nrels(): - return 300 - - -def make_big_union(t, nrels): - return ibis.union(*[t] * nrels) - - -@pytest.fixture -def src(srcs, nrels): - return make_big_union(srcs, nrels) - - -@pytest.fixture -def diff(diffs, nrels): - return make_big_union(diffs, nrels) - - -def test_big_eq_expr(benchmark, src, diff): - benchmark(ops.core.Node.equals, src.op(), diff.op()) - - -def test_big_join_expr(benchmark, src, diff): - benchmark(ir.Table.join, src, diff, ["validation_name"], how="outer") - - -def test_big_join_execute(benchmark, nrels): - pytest.importorskip("duckdb") - - con = ibis.duckdb.connect() - - # cache to avoid a request-per-union operand - src = make_big_union( - con.read_csv( - "https://github.com/ibis-project/ibis/files/12580336/source_pivot.csv" - ) - .rename(id="column0") - .cache(), - nrels, - ) - - diff = make_big_union( - con.read_csv( - "https://github.com/ibis-project/ibis/files/12580340/differences_pivot.csv" - ) - .rename(id="column0") - .cache(), - nrels, - ) - - expr = src.join(diff, ["validation_name"], how="outer") - t = benchmark.pedantic(expr.to_pyarrow, rounds=1, iterations=1, warmup_rounds=1) - assert len(t) diff --git a/ibis/backends/tests/test_binary.py b/ibis/backends/tests/test_binary.py index 0a5790c64631..1d9f7cfa0516 100644 --- a/ibis/backends/tests/test_binary.py +++ b/ibis/backends/tests/test_binary.py @@ -15,6 +15,7 @@ "sqlite": "blob", "trino": "varbinary", "postgres": "bytea", + "risingwave": "bytea", "flink": "BINARY(1) NOT NULL", } diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 3f2336cafe9b..2d8674eb8384 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -306,6 +306,11 @@ def tmpcon(alchemy_con): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_create_temporary_table_from_schema(tmpcon, new_schema): temp_table = f"_{guid()}" table = tmpcon.create_table(temp_table, schema=new_schema, temp=True) @@ -338,6 +343,7 @@ def test_create_temporary_table_from_schema(tmpcon, new_schema): "pandas", "polars", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -367,6 +373,11 @@ def test_rename_table(con, temp_table, temp_table_orig): raises=com.IbisError, reason="`tbl_properties` is required when creating table with schema", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason='Feature is not yet implemented: column constraints "NOT NULL"', +) def test_nullable_input_output(con, temp_table): sch = ibis.schema( [("foo", "int64"), ("bar", dt.int64(nullable=False)), ("baz", "boolean")] @@ -412,7 +423,7 @@ def test_create_drop_view(ddl_con, temp_view): assert set(t_expr.schema().names) == set(v_expr.schema().names) -@mark.notimpl(["postgres", "polars"]) +@mark.notimpl(["postgres", "risingwave", "polars"]) @mark.notimpl( ["datafusion"], raises=NotImplementedError, @@ -622,6 +633,7 @@ def test_list_databases(alchemy_con): test_databases = { "sqlite": {"main"}, "postgres": {"postgres", "ibis_testing"}, + "risingwave": {"dev"}, "mssql": {"ibis_testing"}, "mysql": {"ibis_testing", "information_schema"}, "duckdb": {"memory"}, @@ -634,7 +646,7 @@ def test_list_databases(alchemy_con): @pytest.mark.never( - ["bigquery", "postgres", "mssql", "mysql", "oracle"], + ["bigquery", "postgres", "risingwave", "mssql", "mysql", "oracle"], reason="backend does not support client-side in-memory tables", raises=(sa.exc.OperationalError, TypeError, sa.exc.InterfaceError), ) @@ -707,6 +719,11 @@ def test_unsigned_integer_type(alchemy_con, alchemy_temp_table): marks=mark.postgres, id="postgresql", ), + param( + "postgresql://root:@localhost:4566/dev", + marks=mark.risingwave, + id="risingwave", + ), param( "pyspark://?spark.app.name=test-pyspark", marks=[ @@ -1123,6 +1140,11 @@ def test_set_backend_name(name, monkeypatch): marks=mark.postgres, id="postgres", ), + param( + "postgres://root:@localhost:4566/dev", + marks=mark.risingwave, + id="risingwave", + ), ], ) def test_set_backend_url(url, monkeypatch): @@ -1147,6 +1169,7 @@ def test_set_backend_url(url, monkeypatch): "pandas", "polars", "postgres", + "risingwave", "pyspark", "sqlite", ], @@ -1183,6 +1206,11 @@ def test_create_table_timestamp(con, temp_table): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_ref_count(backend, con, alltypes): non_persisted_table = alltypes.mutate(test_column="calculation") persisted_table = non_persisted_table.cache() @@ -1203,6 +1231,11 @@ def test_persist_expression_ref_count(backend, con, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression(backend, alltypes): non_persisted_table = alltypes.mutate(test_column="calculation", other_calc="xyz") persisted_table = non_persisted_table.cache() @@ -1217,6 +1250,11 @@ def test_persist_expression(backend, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_contextmanager(backend, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc" @@ -1233,6 +1271,11 @@ def test_persist_expression_contextmanager(backend, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 2" @@ -1251,6 +1294,11 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes): ["mssql"], reason="mssql supports support temporary tables through naming conventions", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") def test_persist_expression_multiple_refs(backend, con, alltypes): non_cached_table = alltypes.mutate( @@ -1288,6 +1336,11 @@ def test_persist_expression_multiple_refs(backend, con, alltypes): reason="mssql supports support temporary tables through naming conventions", ) @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_repeated_cache(alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 2" @@ -1307,6 +1360,11 @@ def test_persist_expression_repeated_cache(alltypes): ["oracle"], reason="Oracle error message for a missing table/view doesn't include the name of the table", ) +@pytest.mark.never( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: CREATE TEMPORARY TABLE", +) def test_persist_expression_release(con, alltypes): non_cached_table = alltypes.mutate( test_column="calculation", other_column="big calc 3" @@ -1391,6 +1449,11 @@ def test_create_schema(con_create_schema): assert schema not in con_create_schema.list_schemas() +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: information_schema.schemata is not supported,", +) def test_list_schemas(con_create_schema): schemas = con_create_schema.list_schemas() assert len(schemas) == len(set(schemas)) diff --git a/ibis/backends/tests/test_column.py b/ibis/backends/tests/test_column.py index f26b2a876ded..f6b4bd8ee0f4 100644 --- a/ibis/backends/tests/test_column.py +++ b/ibis/backends/tests/test_column.py @@ -19,6 +19,7 @@ "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "trino", diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index 4690702f0c18..a189e17d8760 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -238,7 +238,7 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): backend.assert_series_equal(foo2.x.execute(), expected2) -_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink"} +_NO_SQLGLOT_DIALECT = {"pandas", "dask", "druid", "flink", "risingwave"} no_sqlglot_dialect = sorted( # TODO(cpcloud): remove the strict=False hack once backends are ported to # sqlglot @@ -255,6 +255,11 @@ def test_dot_sql_reuse_alias_with_different_types(backend, alltypes, df): ], ) @pytest.mark.notyet(["polars"], raises=PolarsComputeError) +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) @dot_sql_notimpl @dot_sql_never @pytest.mark.notyet(["druid"], reason="druid doesn't respect column name case") @@ -282,6 +287,11 @@ def test_table_dot_sql_transpile(backend, alltypes, dialect, df): ["druid"], raises=AttributeError, reason="druid doesn't respect column names" ) @pytest.mark.notyet(["snowflake", "bigquery"]) +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) @dot_sql_notimpl @dot_sql_never def test_con_dot_sql_transpile(backend, con, dialect, df): @@ -301,6 +311,11 @@ def test_con_dot_sql_transpile(backend, con, dialect, df): @dot_sql_never @pytest.mark.notimpl(["druid", "flink", "polars", "exasol"]) @pytest.mark.notyet(["snowflake"], reason="snowflake column names are case insensitive") +@pytest.mark.notyet( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_order_by_no_projection(backend): con = backend.connection expr = ( diff --git a/ibis/backends/tests/test_examples.py b/ibis/backends/tests/test_examples.py index b73ea8e2f3da..113fe70102a4 100644 --- a/ibis/backends/tests/test_examples.py +++ b/ibis/backends/tests/test_examples.py @@ -16,7 +16,7 @@ reason="nix on linux cannot download duckdb extensions or data due to sandboxing", ) @pytest.mark.notimpl(["dask", "exasol", "pyspark"]) -@pytest.mark.notyet(["clickhouse", "druid", "impala", "mssql", "trino"]) +@pytest.mark.notyet(["clickhouse", "druid", "impala", "mssql", "trino", "risingwave"]) @pytest.mark.parametrize( ("example", "columns"), [ diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 658850c6096e..c6346b645a97 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -4,6 +4,7 @@ import pyarrow as pa import pyarrow.csv as pcsv import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -254,6 +255,7 @@ def test_table_to_parquet_writer_kwargs(version, tmp_path, backend, awards_playe "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -348,6 +350,11 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): marks=[ pytest.mark.notyet(["flink"], raises=NotImplementedError), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.DBAPIError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(38,9)", + ), ], ), param( @@ -369,6 +376,11 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): ), pytest.mark.notyet(["flink"], raises=NotImplementedError), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.DBAPIError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(76,38)", + ), ], ), ], @@ -390,6 +402,7 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "bigquery", @@ -488,7 +501,22 @@ def test_to_pandas_batches_empty_table(backend, con): @pytest.mark.notimpl(["flink"]) -@pytest.mark.parametrize("n", [None, 1]) +@pytest.mark.parametrize( + "n", + [ + param( + None, + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit null", + ), + ], + ), + 1, + ], +) def test_to_pandas_batches_nonempty_table(backend, con, n): t = backend.functional_alltypes.limit(n) n = t.count().execute() @@ -498,7 +526,24 @@ def test_to_pandas_batches_nonempty_table(backend, con, n): @pytest.mark.notimpl(["flink"]) -@pytest.mark.parametrize("n", [None, 0, 1, 2]) +@pytest.mark.parametrize( + "n", + [ + param( + None, + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit null", + ), + ], + ), + 0, + 1, + 2, + ], +) def test_to_pandas_batches_column(backend, con, n): t = backend.functional_alltypes.limit(n).timestamp_col n = t.count().execute() diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 241c57ce8870..1ff8122967f1 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa import toolz from pytest import param @@ -43,6 +44,7 @@ "sqlite": "null", "trino": "unknown", "postgres": "null", + "risingwave": "null", } @@ -66,6 +68,7 @@ def test_null_literal(con, backend): "trino": "boolean", "duckdb": "BOOLEAN", "postgres": "boolean", + "risingwave": "boolean", "flink": "BOOLEAN NOT NULL", } @@ -150,6 +153,7 @@ def test_isna(backend, alltypes, col, value, filt): "duckdb", "impala", "postgres", + "risingwave", "mysql", "snowflake", "polars", @@ -307,6 +311,7 @@ def test_filter(backend, alltypes, sorted_df, predicate_fn, expected_fn): "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", @@ -540,6 +545,11 @@ def test_order_by(backend, alltypes, df, key, df_kwargs): @pytest.mark.notimpl(["dask", "pandas", "polars", "mssql", "druid"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_order_by_random(alltypes): expr = alltypes.filter(_.id < 100).order_by(ibis.random()).limit(5) r1 = expr.execute() @@ -761,6 +771,11 @@ def test_correlated_subquery(alltypes): @pytest.mark.notimpl(["polars", "pyspark"]) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason='DataFrame.iloc[:, 0] (column name="playerID") are different', +) def test_uncorrelated_subquery(backend, batting, batting_df): subset_batting = batting[batting.yearID <= 2000] expr = batting[_.yearID == subset_batting.yearID.max()]["playerID", "yearID"] @@ -833,6 +848,11 @@ def test_typeof(con): @pytest.mark.notimpl(["pyspark"], condition=is_older_than("pyspark", "3.5.0")) @pytest.mark.notyet(["dask"], reason="not supported by the backend") @pytest.mark.notyet(["exasol"], raises=ExaQueryError, reason="not supported by exasol") +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="https://github.com/risingwavelabs/risingwave/issues/1343", +) def test_isin_uncorrelated( backend, batting, awards_players, batting_df, awards_players_df ): @@ -985,6 +1005,11 @@ def test_memtable_column_naming_mismatch(backend, con, monkeypatch, df, columns) ) @pytest.mark.notimpl(["druid", "flink"], reason="no sqlglot dialect", raises=ValueError) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_many_subqueries(con, snapshot): def query(t, group_cols): t2 = t.mutate(key=ibis.row_number().over(ibis.window(order_by=group_cols))) @@ -1012,6 +1037,11 @@ def query(t, group_cols): reason="invalid code generated for unnesting a struct", raises=TrinoUserError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected ), found: TEXT at line:3, column:219 Near "))]) AS anon_1(f1"', +) def test_pivot_longer(backend): diamonds = backend.diamonds df = diamonds.execute() @@ -1126,6 +1156,11 @@ def test_pivot_wider(backend): ["exasol"], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function last(double precision) does not exist, do you mean left or least", +) def test_distinct_on_keep(backend, on, keep): from ibis import _ @@ -1191,6 +1226,11 @@ def test_distinct_on_keep(backend, on, keep): raises=com.OperationNotDefinedError, reason="backend doesn't implement deduplication", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function first(double precision) does not exist", +) def test_distinct_on_keep_is_none(backend, on): from ibis import _ @@ -1209,7 +1249,7 @@ def test_distinct_on_keep_is_none(backend, on): assert len(result) == len(expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "flink", "exasol"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "flink", "exasol"]) @pytest.mark.notyet( [ "sqlite", @@ -1226,7 +1266,7 @@ def test_hash_consistent(backend, alltypes): assert h1.dtype in ("i8", "uint64") # polars likes returning uint64 for this -@pytest.mark.notimpl(["pandas", "dask", "oracle", "snowflake", "sqlite"]) +@pytest.mark.notimpl(["pandas", "dask", "oracle", "risingwave", "snowflake", "sqlite"]) @pytest.mark.parametrize( ("from_val", "to_type", "expected"), [ @@ -1275,6 +1315,7 @@ def test_try_cast(con, from_val, to_type, expected): "oracle", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", ] @@ -1312,6 +1353,7 @@ def test_try_cast_null(con, from_val, to_type): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "exasol", @@ -1337,6 +1379,7 @@ def test_try_cast_table(backend, con): "mysql", "oracle", "postgres", + "risingwave", "snowflake", "sqlite", "exasol", @@ -1377,9 +1420,31 @@ def test_try_cast_func(con, from_val, to_type, func): ### NONE/ZERO start # no stop param(slice(None, 0), lambda _: 0, id="[:0]"), - param(slice(None, None), lambda t: t.count().to_pandas(), id="[:]"), + param( + slice(None, None), + lambda t: t.count().to_pandas(), + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), + ], + id="[:]", + ), param(slice(0, 0), lambda _: 0, id="[0:0]"), - param(slice(0, None), lambda t: t.count().to_pandas(), id="[0:]"), + param( + slice(0, None), + lambda t: t.count().to_pandas(), + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), + ], + id="[0:]", + ), # positive stop param(slice(None, 2), lambda _: 2, id="[:2]"), param(slice(0, 2), lambda _: 2, id="[0:2]"), @@ -1434,6 +1499,11 @@ def test_try_cast_func(con, from_val, to_type, func): reason="impala doesn't support OFFSET without ORDER BY", ), pytest.mark.notyet(["oracle"], raises=com.UnsupportedArgumentError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", + ), ], ), # positive stop @@ -1520,6 +1590,11 @@ def test_static_table_slice(backend, slc, expected_count_fn): raises=com.UnsupportedArgumentError, reason="Removed half-baked dynamic offset functionality for now", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", +) @pytest.mark.notyet( ["trino"], raises=TrinoUserError, @@ -1610,6 +1685,11 @@ def test_dynamic_table_slice(backend, slc, expected_count_fn): reason="doesn't support dynamic limit/offset; compiles incorrectly in sqlglot", raises=AssertionError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="risingwave doesn't support limit/offset", +) def test_dynamic_table_slice_with_computed_offset(backend): t = backend.functional_alltypes @@ -1629,6 +1709,11 @@ def test_dynamic_table_slice_with_computed_offset(backend): @pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_sample(backend): t = backend.functional_alltypes.filter(_.int_col >= 2) @@ -1645,6 +1730,11 @@ def test_sample(backend): @pytest.mark.notimpl(["druid", "flink", "polars", "snowflake"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_sample_memtable(con, backend): df = pd.DataFrame({"x": [1, 2, 3, 4]}) res = con.execute(ibis.memtable(df).sample(0.5)) @@ -1665,6 +1755,7 @@ def test_sample_memtable(con, backend): "oracle", "polars", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -1700,6 +1791,11 @@ def test_substitute(backend): ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" ) @pytest.mark.notimpl(["flink"], reason="no sqlglot dialect", raises=ValueError) +@pytest.mark.notimpl( + ["risingwave"], + raises=ValueError, + reason="risingwave doesn't support sqlglot.dialects.dialect.Dialect", +) def test_simple_memtable_construct(con): t = ibis.memtable({"a": [1, 2]}) expr = t.a diff --git a/ibis/backends/tests/test_json.py b/ibis/backends/tests/test_json.py index 78d379ae0bde..98c72e6934d5 100644 --- a/ibis/backends/tests/test_json.py +++ b/ibis/backends/tests/test_json.py @@ -40,13 +40,17 @@ ["flink"], reason="https://github.com/ibis-project/ibis/pull/6920#discussion_r1373212503", ) +@pytest.mark.broken( + ["risingwave"], + reason="TODO(Kexiang): order mismatch in array", +) def test_json_getitem(json_t, expr_fn, expected): expr = expr_fn(json_t) result = expr.execute() tm.assert_series_equal(result.fillna(pd.NA), expected.fillna(pd.NA)) -@pytest.mark.notimpl(["dask", "mysql", "pandas"]) +@pytest.mark.notimpl(["dask", "mysql", "pandas", "risingwave"]) @pytest.mark.notyet(["bigquery", "sqlite"], reason="doesn't support maps") @pytest.mark.notyet(["postgres"], reason="only supports map") @pytest.mark.notyet( @@ -70,7 +74,7 @@ def test_json_map(backend, json_t): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "mysql", "pandas"]) +@pytest.mark.notimpl(["dask", "mysql", "pandas", "risingwave"]) @pytest.mark.notyet(["sqlite"], reason="doesn't support arrays") @pytest.mark.notyet( ["pyspark", "trino", "flink"], reason="should work but doesn't deserialize JSON" diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index cb0386dc0ff8..6c7a5e717f12 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -34,6 +35,11 @@ def test_map_table(backend): @pytest.mark.xfail_version( duckdb=["duckdb<0.8.0"], raises=exc.UnsupportedOperationError ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_column_map_values(backend): table = backend.map expr = table.select("idx", vals=table.kv.values()).order_by("idx") @@ -64,6 +70,11 @@ def test_column_map_merge(backend): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_keys(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.keys().name("tmp") @@ -79,6 +90,11 @@ def test_literal_map_keys(con): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_values(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.values().name("tmp") @@ -87,7 +103,7 @@ def test_literal_map_values(con): assert np.array_equal(result, ["a", "b"]) -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -103,7 +119,9 @@ def test_scalar_isin_literal_map_keys(con): assert con.execute(false) == False # noqa: E712 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -124,6 +142,11 @@ def test_map_scalar_contains_key_scalar(con): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_scalar_contains_key_column(backend, alltypes, df): value = {"1": "a", "3": "c"} mapping = ibis.literal(value) @@ -133,7 +156,9 @@ def test_map_scalar_contains_key_column(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -149,7 +174,9 @@ def test_map_column_contains_key_scalar(backend, alltypes, df): backend.assert_series_equal(result, series) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -164,7 +191,9 @@ def test_map_column_contains_key_column(alltypes): assert result.all() -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -183,6 +212,11 @@ def test_literal_map_merge(con): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_getitem_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -200,6 +234,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_literal_map_get_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -220,19 +259,27 @@ def test_literal_map_get_broadcast(backend, alltypes, df): [1, 2], id="string", marks=pytest.mark.notyet( - ["postgres"], reason="only support maps of string -> string" + ["postgres", "risingwave"], + reason="only support maps of string -> string", ), ), param(["a", "b"], ["1", "2"], id="int"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_construct_dict(con, keys, values): expr = ibis.map(keys, values) result = con.execute(expr.name("tmp")) assert result == dict(zip(keys, values)) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=exc.OperationNotDefinedError, @@ -246,7 +293,9 @@ def test_map_construct_array_column(con, alltypes, df): assert result.to_list() == expected.to_list() -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -258,7 +307,9 @@ def test_map_get_with_compatible_value_smaller(con): assert con.execute(expr) == 3 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -270,7 +321,9 @@ def test_map_get_with_compatible_value_bigger(con): assert con.execute(expr) == 3000 -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -283,7 +336,9 @@ def test_map_get_with_incompatible_value_different_kind(con): @pytest.mark.parametrize("null_value", [None, ibis.NA]) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -303,6 +358,11 @@ def test_map_get_with_null_on_not_nullable(con, null_value): raises=NotImplementedError, reason="No translation rule for map", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_get_with_null_on_null_type_with_null(con, null_value): value = ibis.literal({"A": None, "B": None}) expr = value.get("C", null_value) @@ -310,7 +370,9 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value): assert pd.isna(result) -@pytest.mark.notyet(["postgres"], reason="only support maps of string -> string") +@pytest.mark.notyet( + ["postgres", "risingwave"], reason="only support maps of string -> string" +) @pytest.mark.notimpl( ["flink"], raises=NotImplementedError, @@ -327,6 +389,11 @@ def test_map_get_with_null_on_null_type_with_non_null(con): raises=exc.IbisError, reason="`tbl_properties` is required when creating table with schema", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_create_table(con, temp_table): t = con.create_table( temp_table, @@ -340,6 +407,11 @@ def test_map_create_table(con, temp_table): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function hstore(character varying[], character varying[]) does not exist", +) def test_map_length(con): expr = ibis.literal(dict(a="A", b="B")).length() assert con.execute(expr) == 2 diff --git a/ibis/backends/tests/test_network.py b/ibis/backends/tests/test_network.py index e6048ee907b2..dca5815c6855 100644 --- a/ibis/backends/tests/test_network.py +++ b/ibis/backends/tests/test_network.py @@ -20,6 +20,7 @@ "trino": "varchar(17)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(17) NOT NULL", } @@ -50,6 +51,7 @@ def test_macaddr_literal(con, backend): "trino": "127.0.0.1", "impala": "127.0.0.1", "postgres": "127.0.0.1", + "risingwave": "127.0.0.1", "pandas": "127.0.0.1", "pyspark": "127.0.0.1", "mysql": "127.0.0.1", @@ -67,6 +69,7 @@ def test_macaddr_literal(con, backend): "trino": "varchar(9)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(9) NOT NULL", }, id="ipv4", @@ -82,6 +85,7 @@ def test_macaddr_literal(con, backend): "trino": "2001:db8::1", "impala": "2001:db8::1", "postgres": "2001:db8::1", + "risingwave": "2001:db8::1", "pandas": "2001:db8::1", "pyspark": "2001:db8::1", "mysql": "2001:db8::1", @@ -99,6 +103,7 @@ def test_macaddr_literal(con, backend): "trino": "varchar(11)", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(11) NOT NULL", }, id="ipv6", diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 049aa20f0a8d..5537155bc64b 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -52,6 +52,7 @@ "trino": "integer", "duckdb": "TINYINT", "postgres": "integer", + "risingwave": "integer", "flink": "TINYINT NOT NULL", }, id="int8", @@ -67,6 +68,7 @@ "trino": "integer", "duckdb": "SMALLINT", "postgres": "integer", + "risingwave": "integer", "flink": "SMALLINT NOT NULL", }, id="int16", @@ -82,6 +84,7 @@ "trino": "integer", "duckdb": "INTEGER", "postgres": "integer", + "risingwave": "integer", "flink": "INT NOT NULL", }, id="int32", @@ -97,6 +100,7 @@ "trino": "integer", "duckdb": "BIGINT", "postgres": "integer", + "risingwave": "integer", "flink": "BIGINT NOT NULL", }, id="int64", @@ -112,6 +116,7 @@ "trino": "integer", "duckdb": "UTINYINT", "postgres": "integer", + "risingwave": "integer", "flink": "TINYINT NOT NULL", }, id="uint8", @@ -127,6 +132,7 @@ "trino": "integer", "duckdb": "USMALLINT", "postgres": "integer", + "risingwave": "integer", "flink": "SMALLINT NOT NULL", }, id="uint16", @@ -142,6 +148,7 @@ "trino": "integer", "duckdb": "UINTEGER", "postgres": "integer", + "risingwave": "integer", "flink": "INT NOT NULL", }, id="uint32", @@ -157,6 +164,7 @@ "trino": "integer", "duckdb": "UBIGINT", "postgres": "integer", + "risingwave": "integer", "flink": "BIGINT NOT NULL", }, id="uint64", @@ -172,6 +180,7 @@ "trino": "real", "duckdb": "FLOAT", "postgres": "numeric", + "risingwave": "numeric", "flink": "FLOAT NOT NULL", }, marks=[ @@ -199,6 +208,7 @@ "trino": "real", "duckdb": "FLOAT", "postgres": "numeric", + "risingwave": "numeric", "flink": "FLOAT NOT NULL", }, id="float32", @@ -214,6 +224,7 @@ "trino": "double", "duckdb": "DOUBLE", "postgres": "numeric", + "risingwave": "numeric", "flink": "DOUBLE NOT NULL", }, id="float64", @@ -245,6 +256,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "duckdb": decimal.Decimal("1.1"), "impala": decimal.Decimal("1"), "postgres": decimal.Decimal("1.1"), + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": decimal.Decimal("1"), @@ -263,6 +275,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(18,3)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", "flink": "DECIMAL(38, 18) NOT NULL", }, marks=[ @@ -285,6 +298,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "duckdb": decimal.Decimal("1.100000000"), "impala": decimal.Decimal("1.1"), "postgres": decimal.Decimal("1.1"), + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "mysql": decimal.Decimal("1.1"), @@ -305,6 +319,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(38,9)", "duckdb": "DECIMAL(38,9)", "postgres": "numeric", + "risingwave": "numeric", "flink": "DECIMAL(38, 9) NOT NULL", }, marks=[pytest.mark.notimpl(["exasol"], raises=ExaQueryError)], @@ -318,6 +333,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": decimal.Decimal("1.1"), "dask": decimal.Decimal("1.1"), "postgres": decimal.Decimal("1.1"), + "risingwave": 1.1, "pandas": decimal.Decimal("1.1"), "pyspark": decimal.Decimal("1.1"), "clickhouse": decimal.Decimal( @@ -333,6 +349,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "trino": "decimal(2,1)", "duckdb": "DECIMAL(18,3)", "postgres": "numeric", + "risingwave": "numeric", }, marks=[ pytest.mark.notimpl(["exasol"], raises=ExaQueryError), @@ -370,6 +387,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "bigquery": float("inf"), "sqlite": decimal.Decimal("Infinity"), "postgres": decimal.Decimal("Infinity"), + "risingwave": float("nan"), "pandas": decimal.Decimal("Infinity"), "dask": decimal.Decimal("Infinity"), "pyspark": decimal.Decimal("Infinity"), @@ -380,6 +398,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": "real", "postgres": "numeric", "duckdb": "FLOAT", + "risingwave": "numeric", }, marks=[ pytest.mark.broken( @@ -439,6 +458,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "bigquery": float("-inf"), "sqlite": decimal.Decimal("-Infinity"), "postgres": decimal.Decimal("-Infinity"), + "risingwave": float("nan"), "pandas": decimal.Decimal("-Infinity"), "dask": decimal.Decimal("-Infinity"), "pyspark": decimal.Decimal("-Infinity"), @@ -449,6 +469,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": "real", "postgres": "numeric", "duckdb": "FLOAT", + "risingwave": "numeric", }, marks=[ pytest.mark.broken( @@ -509,6 +530,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "snowflake": float("nan"), "sqlite": None, "postgres": float("nan"), + "risingwave": float("nan"), "pandas": decimal.Decimal("NaN"), "dask": decimal.Decimal("NaN"), "pyspark": decimal.Decimal("NaN"), @@ -521,6 +543,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "sqlite": "null", "postgres": "numeric", "duckdb": "FLOAT", + "risingwave": "numeric", }, marks=[ pytest.mark.broken( @@ -730,14 +753,55 @@ def test_isnan_isinf( L(5.556).log(2), math.log(5.556, 2), id="log-base", - marks=[pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)], + marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), + ], + ), + param( + L(5.556).ln(), + math.log(5.556), + id="ln", ), param(L(5.556).ln(), math.log(5.556), id="ln"), param( L(5.556).log2(), math.log(5.556, 2), id="log2", - marks=[pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)], + marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), + ], + ), + param( + L(5.556).log10(), + math.log10(5.556), + marks=pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError), + id="log10", + ), + param( + L(5.556).radians(), + math.radians(5.556), + id="radians", + ), + param( + L(5.556).degrees(), + math.degrees(5.556), + id="degrees", + ), + param( + L(11) % 3, + 11 % 3, + marks=pytest.mark.notimpl(["exasol"], raises=ExaQueryError), + id="mod", ), param(L(5.556).log10(), math.log10(5.556), id="log10"), param(L(5.556).radians(), math.radians(5.556), id="radians"), @@ -874,7 +938,14 @@ def test_simple_math_functions_columns( param( lambda t: t.double_col.add(1).log(2), lambda t: np.log2(t.double_col + 1), - marks=[pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)], + marks=[ + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), + ], id="log2", ), param( @@ -908,6 +979,11 @@ def test_simple_math_functions_columns( reason="Base greatest(9000, t0.bigint_col) for logarithm not supported!", ), pytest.mark.notimpl(["polars"], raises=com.UnsupportedArgumentError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function log10(numeric, numeric) does not exist", + ), ], ), ], @@ -1131,6 +1207,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1143,6 +1220,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1155,6 +1233,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1167,6 +1246,7 @@ def test_floating_mod(backend, alltypes, df): reason="Oracle doesn't do integer division by zero", ), pytest.mark.never(["impala"], reason="doesn't allow divide by zero"), + pytest.mark.notyet(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -1211,6 +1291,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): ( { "postgres": None, + "risingwave": None, "mysql": 10, "snowflake": 38, "trino": 18, @@ -1220,6 +1301,7 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): }, { "postgres": None, + "risingwave": None, "mysql": 0, "snowflake": 0, "trino": 3, @@ -1254,6 +1336,11 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator): ], reason="Not SQLAlchemy backends", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: unsupported data type: NUMERIC(5)", +) def test_sa_default_numeric_precision_and_scale( con, backend, default_precisions, default_scales, temp_table ): @@ -1289,6 +1376,11 @@ def test_sa_default_numeric_precision_and_scale( @pytest.mark.notimpl(["dask", "pandas", "polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function random() does not exist", +) def test_random(con): expr = ibis.random() result = con.execute(expr) diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index 3fea928999e2..c0f3a98b3a04 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -38,6 +39,11 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value): ) @pytest.mark.notimpl(["trino", "druid"]) @pytest.mark.broken(["oracle"], raises=OracleDatabaseError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_scalar_parameter(backend, alltypes, start_string, end_string): start, end = ibis.param(dt.date), ibis.param(dt.date) @@ -76,6 +82,7 @@ def test_scalar_param_array(con): "impala", "flink", "postgres", + "risingwave", "druid", "oracle", "exasol", @@ -108,6 +115,11 @@ def test_scalar_param_struct(con): "sql= SELECT MAP_FROM_ARRAYS(ARRAY['a', 'b', 'c'], ARRAY['ghi', 'def', 'abc']) '[' 'b' ']' AS `MapGet(param_0, 'b', None)`" ), ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_scalar_param_map(con): value = {"a": "ghi", "b": "def", "c": "abc"} param = ibis.param(dt.Map(dt.string, dt.string)) @@ -168,6 +180,11 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col): ids=["string", "date", "datetime"], ) @pytest.mark.notimpl(["druid", "oracle"]) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_scalar_param_date(backend, alltypes, value): param = ibis.param("date") ds_col = alltypes.date_string_col @@ -203,6 +220,7 @@ def test_scalar_param_date(backend, alltypes, value): @pytest.mark.notimpl( [ "postgres", + "risingwave", "datafusion", "clickhouse", "polars", diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 4e1739e30cf0..d967e45d80c7 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -93,6 +93,7 @@ def gzip_csv(data_dir, tmp_path): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -119,6 +120,7 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -142,6 +144,7 @@ def test_register_csv_gz(con, data_dir, gzip_csv): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -198,6 +201,7 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -234,6 +238,7 @@ def test_register_parquet( "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -273,6 +278,7 @@ def test_register_iterator_parquet( "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -303,6 +309,7 @@ def test_register_pandas(con): "mysql", "pandas", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -328,6 +335,7 @@ def test_register_pyarrow_tables(con): "mysql", "pandas", "postgres", + "risingwave", "snowflake", "sqlite", "trino", @@ -370,6 +378,7 @@ def test_csv_reregister_schema(con, tmp_path): "pandas", "polars", "postgres", + "risingwave", "pyspark", "snowflake", "sqlite", @@ -400,7 +409,7 @@ def test_register_garbage(con, monkeypatch): ], ) @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"] + ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): pq = pytest.importorskip("pyarrow.parquet") @@ -431,7 +440,17 @@ def ft_data(data_dir): @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] + [ + "flink", + "impala", + "mssql", + "mysql", + "pandas", + "postgres", + "risingwave", + "sqlite", + "trino", + ] ) def test_read_parquet_glob(con, tmp_path, ft_data): pq = pytest.importorskip("pyarrow.parquet") @@ -450,7 +469,17 @@ def test_read_parquet_glob(con, tmp_path, ft_data): @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "pandas", "postgres", "sqlite", "trino"] + [ + "flink", + "impala", + "mssql", + "mysql", + "pandas", + "postgres", + "risingwave", + "sqlite", + "trino", + ] ) def test_read_csv_glob(con, tmp_path, ft_data): pc = pytest.importorskip("pyarrow.csv") @@ -479,6 +508,7 @@ def test_read_csv_glob(con, tmp_path, ft_data): "mysql", "pandas", "postgres", + "risingwave", "sqlite", "trino", ] @@ -527,7 +557,7 @@ def num_diamonds(data_dir): [param(None, id="default"), param("fancy_stones", id="file_name")], ) @pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "sqlite", "trino"] + ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) def test_read_csv(con, data_dir, in_table_name, num_diamonds): fname = "diamonds.csv" diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 0fb52fa10f0f..41102559ad9c 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -67,19 +68,26 @@ def test_union_mixed_distinct(backend, union_subsets): [ param( False, - marks=pytest.mark.notyet( - [ - "impala", - "bigquery", - "dask", - "pandas", - "sqlite", - "snowflake", - "mssql", - "exasol", - ], - reason="backend doesn't support INTERSECT ALL", - ), + marks=[ + pytest.mark.notyet( + [ + "impala", + "bigquery", + "dask", + "pandas", + "sqlite", + "snowflake", + "mssql", + "exasol", + ], + reason="backend doesn't support INTERSECT ALL", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: INTERSECT all", + ), + ], id="all", ), param(True, id="distinct"), @@ -114,19 +122,26 @@ def test_intersect(backend, alltypes, df, distinct): [ param( False, - marks=pytest.mark.notyet( - [ - "impala", - "bigquery", - "dask", - "pandas", - "sqlite", - "snowflake", - "mssql", - "exasol", - ], - reason="backend doesn't support EXCEPT ALL", - ), + marks=[ + pytest.mark.notyet( + [ + "impala", + "bigquery", + "dask", + "pandas", + "sqlite", + "snowflake", + "mssql", + "exasol", + ], + reason="backend doesn't support EXCEPT ALL", + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: EXCEPT all", + ), + ], id="all", ), param(True, id="distinct"), @@ -193,18 +208,25 @@ def test_top_level_union(backend, con, alltypes, distinct): True, param( False, - marks=pytest.mark.notimpl( - [ - "impala", - "bigquery", - "dask", - "mssql", - "pandas", - "snowflake", - "sqlite", - "exasol", - ] - ), + marks=[ + pytest.mark.notimpl( + [ + "impala", + "bigquery", + "dask", + "mssql", + "pandas", + "snowflake", + "sqlite", + "exasol", + ] + ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: INTERSECT all", + ), + ], ), ], ) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 9fdecc6100c1..edf4617ce76f 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -11,7 +11,7 @@ sa = pytest.importorskip("sqlalchemy") sg = pytest.importorskip("sqlglot") -pytestmark = pytest.mark.notimpl(["flink"]) +pytestmark = pytest.mark.notimpl(["flink", "risingwave"]) simple_literal = param(ibis.literal(1), id="simple_literal") array_literal = param( diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 3ff40d13ce67..7b9333345775 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -33,6 +34,7 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(6) NOT NULL", }, id="string", @@ -48,14 +50,22 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(7) NOT NULL", }, id="string-quote1", - marks=pytest.mark.broken( - ["oracle"], - raises=OracleDatabaseError, - reason="ORA-01741: illegal zero length identifier", - ), + marks=[ + pytest.mark.broken( + ["oracle"], + raises=OracleDatabaseError, + reason="ORA-01741: illegal zero length identifier", + ), + pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', + ), + ], ), param( 'STRI"NG', @@ -68,14 +78,22 @@ "duckdb": "VARCHAR", "impala": "STRING", "postgres": "text", + "risingwave": "text", "flink": "CHAR(7) NOT NULL", }, id="string-quote2", - marks=pytest.mark.broken( - ["oracle"], - raises=OracleDatabaseError, - reason="ORA-25716", - ), + marks=[ + pytest.mark.broken( + ["oracle"], + raises=OracleDatabaseError, + reason="ORA-25716", + ), + pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason='sql parser error: Expected end of statement, found: "NG\'" at line:1, column:31 Near "SELECT \'STRI"NG\' AS "\'STRI""', + ), + ], ), ], ) @@ -215,6 +233,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -225,6 +248,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -240,6 +268,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["druid"], reason="No posix support", raises=AssertionError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -250,6 +283,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -262,6 +300,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -274,6 +317,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -286,6 +334,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -296,6 +349,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -306,6 +364,11 @@ def uses_java_re(t): pytest.mark.notimpl( ["mssql", "exasol"], raises=com.OperationNotDefinedError ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function textregexeq(character varying, character varying) does not exist", + ), ], ), param( @@ -846,6 +909,7 @@ def test_substr_with_null_values(backend, alltypes, df): "mysql", "polars", "postgres", + "risingwave", "pyspark", "druid", "oracle", @@ -917,6 +981,11 @@ def test_multiple_subs(con): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function levenshtein(character varying, character varying) does not exist", +) @pytest.mark.parametrize( "right", ["sitting", ibis.literal("sitting")], ids=["python", "ibis"] ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index d368c6b16305..95f7df9f4ea5 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -65,7 +65,7 @@ def test_all_fields(struct, struct_df): _NULL_STRUCT_LITERAL = ibis.NA.cast("struct") -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" @@ -79,7 +79,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" @@ -95,7 +95,7 @@ def test_null_literal(backend, con, field): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from literals" ) @@ -111,7 +111,7 @@ def test_struct_column(backend, alltypes, df): tm.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "polars"]) +@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from collect" ) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 8b5f7e897815..c8e2210b4129 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa import sqlglot as sg from pytest import param @@ -155,6 +156,11 @@ def test_timestamp_extract(backend, alltypes, df, attr): "Ref: https://nightlies.apache.org/flink/flink-docs-release-1.13/docs/dev/table/functions/systemfunctions/#temporal-functions" ), ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -434,6 +440,7 @@ def test_timestamp_extract_week_of_year(backend, alltypes, df): "impala", "mysql", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -634,6 +641,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'WEEK'. Was expecting one of: DAY, DAYS, HOUR", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: week", + ), ], ), param("D", pd.offsets.DateOffset), @@ -652,6 +664,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'MILLISECOND'. Was expecting one of: DAY, DAYS, HOUR, ...", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: millisecond", + ), ], ), param( @@ -671,6 +688,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=Py4JJavaError, reason="ParseException: Encountered 'MICROSECOND'. Was expecting one of: DAY, DAYS, HOUR, ...", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: microsecond", + ), ], ), ], @@ -721,7 +743,14 @@ def convert_to_offset(offset, displacement_type=displacement_type): ), param( "W", - marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + marks=[ + pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Bind error: Invalid unit: week", + ), + ], ), "D", ], @@ -819,7 +848,7 @@ def convert_to_offset(x): id="timestamp-add-interval-binop", marks=[ pytest.mark.notimpl( - ["dask", "snowflake", "sqlite", "bigquery", "exasol"], + ["dask", "snowflake", "sqlite", "bigquery", "exasol", "risingwave"], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -839,7 +868,14 @@ def convert_to_offset(x): id="timestamp-add-interval-binop-different-units", marks=[ pytest.mark.notimpl( - ["sqlite", "polars", "snowflake", "bigquery", "exasol"], + [ + "sqlite", + "polars", + "snowflake", + "bigquery", + "exasol", + "risingwave", + ], raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["impala"], raises=com.UnsupportedOperationError), @@ -963,6 +999,11 @@ def convert_to_offset(x): raises=AttributeError, reason="'StringColumn' object has no attribute 'date'", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", + ), pytest.mark.broken( ["flink"], raises=com.UnsupportedOperationError, @@ -1409,6 +1450,13 @@ def test_interval_add_cast_column(backend, alltypes, df): raises=com.UnsupportedArgumentError, reason="Polars does not support columnar argument StringConcat()", ), + pytest.mark.notimpl( + [ + "risingwave", + ], + raises=AttributeError, + reason="Neither 'concat' object nor 'Comparator' object has an attribute 'value'", + ), pytest.mark.notyet(["dask"], raises=com.OperationNotDefinedError), pytest.mark.notyet(["impala"], raises=com.UnsupportedOperationError), pytest.mark.notimpl(["druid", "flink"], raises=AttributeError), @@ -1511,7 +1559,7 @@ def test_strftime(backend, alltypes, df, expr_fn, pandas_pattern): ], ) @pytest.mark.notimpl( - ["mysql", "postgres", "sqlite", "druid", "oracle"], + ["mysql", "postgres", "risingwave", "sqlite", "druid", "oracle"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) @@ -1593,6 +1641,7 @@ def test_integer_to_timestamp(backend, con, unit): [ "dask", "pandas", + "risingwave", "clickhouse", "sqlite", "datafusion", @@ -1631,6 +1680,11 @@ def test_string_to_timestamp(alltypes, fmt): reason="DayOfWeekName is not supported in Flink", ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", +) def test_day_of_week_scalar(con, date, expected_index, expected_day): expr = ibis.literal(date).cast(dt.date) result_index = con.execute(expr.day_of_week.index().name("tmp")) @@ -1656,6 +1710,11 @@ def test_day_of_week_scalar(con, date, expected_index, expected_day): ), ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", +) def test_day_of_week_column(backend, alltypes, df): expr = alltypes.timestamp_col.day_of_week @@ -1692,6 +1751,11 @@ def test_day_of_week_column(backend, alltypes, df): "Ref: https://nightlies.apache.org/flink/flink-docs-release-1.13/docs/dev/table/functions/systemfunctions/#temporal-functions" ), ), + pytest.mark.broken( + ["risingwave"], + raises=AssertionError, + reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14670", + ), ], ), ], @@ -1759,6 +1823,7 @@ def test_now_from_projection(alltypes): "snowflake": "DATE", "sqlite": "text", "trino": "date", + "risingwave": "date", } @@ -1769,6 +1834,11 @@ def test_now_from_projection(alltypes): @pytest.mark.notimpl( ["oracle"], raises=OracleDatabaseError, reason="ORA-00936 missing expression" ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_literal(con, backend): expr = ibis.date(2022, 2, 4) result = con.execute(expr) @@ -1788,6 +1858,7 @@ def test_date_literal(con, backend): "trino": "timestamp(3)", "duckdb": "TIMESTAMP", "postgres": "timestamp without time zone", + "risingwave": "timestamp without time zone", "flink": "TIMESTAMP(6) NOT NULL", } @@ -1797,6 +1868,11 @@ def test_date_literal(con, backend): raises=com.OperationNotDefinedError, ) @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", +) def test_timestamp_literal(con, backend): expr = ibis.timestamp(2022, 2, 4, 16, 20, 0) result = con.execute(expr) @@ -1850,6 +1926,11 @@ def test_timestamp_literal(con, backend): ", , , )" ), ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist", +) def test_timestamp_with_timezone_literal(con, timezone, expected): expr = ibis.timestamp(2022, 2, 4, 16, 20, 0).cast(dt.Timestamp(timezone=timezone)) result = con.execute(expr) @@ -1866,6 +1947,7 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): "trino": "time(3)", "duckdb": "TIME", "postgres": "time without time zone", + "risingwave": "time without time zone", } @@ -1877,6 +1959,11 @@ def test_timestamp_with_timezone_literal(con, timezone, expected): ["clickhouse", "impala", "exasol"], raises=com.OperationNotDefinedError ) @pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_time(integer, integer, integer) does not exist", +) def test_time_literal(con, backend): expr = ibis.time(16, 20, 0) result = con.execute(expr) @@ -1953,6 +2040,7 @@ def test_extract_time_from_timestamp(con, microsecond): "trino": "interval day to second", "duckdb": "INTERVAL", "postgres": "interval", + "risingwave": "interval", } @@ -2021,6 +2109,11 @@ def test_interval_literal(con, backend): @pytest.mark.broken( ["oracle"], raises=OracleDatabaseError, reason="ORA-00936: missing expression" ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_column_from_ymd(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.date(c.year(), c.month(), c.day()) @@ -2041,6 +2134,11 @@ def test_date_column_from_ymd(backend, con, alltypes, df): reason="StringColumn' object has no attribute 'year'", ) @pytest.mark.notyet(["impala", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_timestamp(smallint, smallint, smallint, smallint, smallint, smallint) does not exist", +) def test_timestamp_column_from_ymdhms(backend, con, alltypes, df): c = alltypes.timestamp_col expr = ibis.timestamp( @@ -2169,6 +2267,11 @@ def build_date_col(t): param(lambda _: DATE, build_date_col, id="date_column"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn): left = left_fn(alltypes) right = right_fn(alltypes) @@ -2291,6 +2394,11 @@ def test_large_timestamp(con): raises=AssertionError, ), pytest.mark.notimpl(["exasol"], raises=AssertionError), + pytest.mark.notyet( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Parse error: timestamp without time zone Can't cast string to timestamp (expected format is YYYY-MM-DD HH:MM:SS[.D+{up to 6 digits}] or YYYY-MM-DD HH:MM or YYYY-MM-DD or ISO 8601 format)", + ), ], ), ], @@ -2321,6 +2429,11 @@ def test_timestamp_precision_output(con, ts, scale, unit): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notyet( + ["risingwave"], + reason="postgres doesn't have any easy way to accurately compute the delta in specific units", + raises=com.OperationNotDefinedError, +) @pytest.mark.parametrize( ("start", "end", "unit", "expected"), [ @@ -2467,6 +2580,11 @@ def test_delta(con, start, end, unit, expected): ], ) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", +) def test_timestamp_bucket(backend, kws, pd_freq): ts = backend.functional_alltypes.timestamp_col.execute().rename("ts") res = backend.functional_alltypes.timestamp_col.bucket(**kws).execute().rename("ts") @@ -2502,6 +2620,11 @@ def test_timestamp_bucket(backend, kws, pd_freq): ) @pytest.mark.parametrize("offset_mins", [2, -2], ids=["pos", "neg"]) @pytest.mark.notimpl(["exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function date_bin(interval, timestamp without time zone, timestamp without time zone) does not exist", +) def test_timestamp_bucket_offset(backend, offset_mins): ts = backend.functional_alltypes.timestamp_col expr = ts.bucket(minutes=5, offset=ibis.interval(minutes=offset_mins)) @@ -2612,6 +2735,11 @@ def test_time_literal_sql(dialect, snapshot, micros): param(datetime.date.fromisoformat, id="fromstring"), ], ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="function make_date(integer, integer, integer) does not exist", +) def test_date_scalar(con, value, func): expr = ibis.date(func(value)).name("tmp") diff --git a/ibis/backends/tests/test_timecontext.py b/ibis/backends/tests/test_timecontext.py index 2488ee5b3659..88376a4f961b 100644 --- a/ibis/backends/tests/test_timecontext.py +++ b/ibis/backends/tests/test_timecontext.py @@ -19,6 +19,7 @@ "impala", "mysql", "postgres", + "risingwave", "sqlite", "snowflake", "polars", diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index f5a3a5d01e61..75021f577133 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -19,6 +19,7 @@ "oracle", "pandas", "trino", + "risingwave", ] ) diff --git a/ibis/backends/tests/test_uuid.py b/ibis/backends/tests/test_uuid.py index a01a1c124ad7..3f2d49cd6d59 100644 --- a/ibis/backends/tests/test_uuid.py +++ b/ibis/backends/tests/test_uuid.py @@ -4,6 +4,7 @@ import uuid import pytest +import sqlalchemy import ibis import ibis.common.exceptions as com @@ -28,6 +29,11 @@ @pytest.mark.notimpl(["polars"], raises=NotImplementedError) @pytest.mark.notimpl(["datafusion"], raises=Exception) +@pytest.mark.notimpl( + ["risingwave"], + raises=sqlalchemy.exc.InternalError, + reason="Feature is not yet implemented: unsupported data type: UUID", +) def test_uuid_literal(con, backend): backend_name = backend.name() diff --git a/ibis/backends/tests/test_vectorized_udf.py b/ibis/backends/tests/test_vectorized_udf.py index c2414db0b58d..fa6728acb7f2 100644 --- a/ibis/backends/tests/test_vectorized_udf.py +++ b/ibis/backends/tests/test_vectorized_udf.py @@ -10,7 +10,7 @@ import ibis.expr.datatypes as dt from ibis.legacy.udf.vectorized import analytic, elementwise, reduction -pytestmark = pytest.mark.notimpl(["druid", "oracle"]) +pytestmark = pytest.mark.notimpl(["druid", "oracle", "risingwave"]) def _format_udf_return_type(func, result_formatter): diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index 1e3627afab8f..f5532f7b7e8b 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytest +import sqlalchemy as sa from pytest import param import ibis @@ -137,6 +138,11 @@ def calc_zscore(s): raises=com.OperationNotDefinedError, ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", + ), ], ), param( @@ -148,6 +154,12 @@ def calc_zscore(s): ["clickhouse", "exasol"], raises=com.OperationNotDefinedError ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: cume_dist", + ), ], ), param( @@ -174,6 +186,11 @@ def calc_zscore(s): raises=com.UnsupportedOperationError, reason="Windows in Flink can only be ordered by a single time column", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: ntile", + ), ], ), param( @@ -212,6 +229,7 @@ def calc_zscore(s): ), pytest.mark.notimpl(["dask"], raises=NotImplementedError), pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError), + pytest.mark.notimpl(["risingwave"], raises=sa.exc.InternalError), ], ), param( @@ -373,7 +391,14 @@ def test_grouped_bounded_expanding_window( lambda t, win: t.double_col.mean().over(win), lambda df: (df.double_col.expanding().mean()), id="mean", - marks=[pytest.mark.notimpl(["dask"], raises=NotImplementedError)], + marks=[ + pytest.mark.notimpl(["dask"], raises=NotImplementedError), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), + ], ), param( # Disabled on PySpark and Spark backends because in pyspark<3.0.0, @@ -393,6 +418,7 @@ def test_grouped_bounded_expanding_window( "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "datafusion", @@ -548,6 +574,7 @@ def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn): "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -625,6 +652,11 @@ def test_grouped_unbounded_window( @pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["flink"], raises=com.UnsupportedOperationError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_simple_ungrouped_unbound_following_window( backend, alltypes, ibis_method, pandas_fn ): @@ -652,6 +684,11 @@ def test_simple_ungrouped_unbound_following_window( ["mssql"], raises=Exception, reason="order by constant is not supported" ) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_simple_ungrouped_window_with_scalar_order_by(alltypes): t = alltypes[alltypes.double_col < 50].order_by("id") w = ibis.window(rows=(0, None), order_by=ibis.NA) @@ -680,6 +717,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="default window semantics are different", raises=AssertionError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -713,6 +755,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): raises=Py4JJavaError, reason="CalciteContextException: Argument to function 'NTILE' must be a literal", ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: ntile", + ), ], ), param( @@ -732,6 +779,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -763,6 +811,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "sqlite", "snowflake", "trino", @@ -783,6 +832,13 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): lambda df: df.float_col.shift(1), True, id="ordered-lag", + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), + ], ), param( lambda t, win: t.float_col.lag().over(win), @@ -812,6 +868,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="backend requires ordering", raises=SnowflakeProgrammingError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -819,6 +880,13 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): lambda df: df.float_col.shift(-1), True, id="ordered-lead", + marks=[ + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), + ], ), param( lambda t, win: t.float_col.lead().over(win), @@ -851,6 +919,11 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): reason="backend requires ordering", raises=SnowflakeProgrammingError, ), + pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", + ), ], ), param( @@ -870,6 +943,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -902,6 +976,7 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): "mysql", "oracle", "postgres", + "risingwave", "pyspark", "sqlite", "snowflake", @@ -968,6 +1043,11 @@ def test_ungrouped_unbounded_window( raises=MySQLOperationalError, reason="https://github.com/tobymao/sqlglot/issues/2779", ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: window frame in `RANGE` mode is not supported yet", +) def test_grouped_bounded_range_window(backend, alltypes, df): # Explanation of the range window spec below: # @@ -1023,6 +1103,11 @@ def gb_fn(df): reason="clickhouse doesn't implement percent_rank", raises=com.OperationNotDefinedError, ) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", +) def test_percent_rank_whole_table_no_order_by(backend, alltypes, df): expr = alltypes.mutate(val=lambda t: t.id.percent_rank()) @@ -1069,6 +1154,11 @@ def agg(df): @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_mutate_window_filter(backend, alltypes): t = alltypes win = ibis.window(order_by=[t.id]) @@ -1143,6 +1233,11 @@ def test_first_last(backend): raises=ExaQueryError, reason="database can't handle UTC timestamps in DataFrames", ) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="sql parser error: Expected literal int, found: INTERVAL at line:1, column:99", +) def test_range_expression_bounds(backend): t = ibis.memtable( { @@ -1187,6 +1282,11 @@ def test_range_expression_bounds(backend): @pytest.mark.broken( ["mssql"], reason="lack of support for booleans", raises=PyODBCProgrammingError ) +@pytest.mark.broken( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Unrecognized window function: percent_rank", +) def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): # GH #7631 t = alltypes @@ -1217,6 +1317,11 @@ def test_rank_followed_by_over_call_merge_frames(backend, alltypes, df): ) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.notyet(["flink"], raises=com.UnsupportedOperationError) +@pytest.mark.notimpl( + ["risingwave"], + raises=sa.exc.InternalError, + reason="Feature is not yet implemented: Window function with empty PARTITION BY is not supported yet", +) def test_windowed_order_by_sequence_is_preserved(con): table = ibis.memtable({"bool_col": [True, False, False, None, True]}) window = ibis.window( diff --git a/pyproject.toml b/pyproject.toml index 6b5ef6dffa75..c457d899e75e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ snowflake-connector-python = { version = ">=3.0.2,<4,!=3.3.0b1", optional = true sqlalchemy = { version = ">=1.4,<3", optional = true } sqlalchemy-views = { version = ">=0.3.1,<1", optional = true } trino = { version = ">=0.321,<1", optional = true } +sqlalchemy-risingwave = { version = ">=1.0.0,<2", optional = true } [tool.poetry.group.dev.dependencies] codespell = { version = ">=2.2.6,<3", extras = [ @@ -193,6 +194,7 @@ postgres = ["psycopg2"] pyspark = ["pyspark", "packaging"] snowflake = ["snowflake-connector-python", "packaging"] sqlite = ["regex"] +risingwave = ["psycopg2"] trino = ["trino"] # non-backend extras visualization = ["graphviz"] @@ -216,6 +218,7 @@ oracle = "ibis.backends.oracle" pandas = "ibis.backends.pandas" polars = "ibis.backends.polars" postgres = "ibis.backends.postgres" +risingwave = "ibis.backends.risingwave" pyspark = "ibis.backends.pyspark" snowflake = "ibis.backends.snowflake" sqlite = "ibis.backends.sqlite" @@ -352,6 +355,7 @@ markers = [ "pandas: Pandas tests", "polars: Polars tests", "postgres: PostgreSQL tests", + "risingwave: Risingwave tests", "pyspark: PySpark tests", "snowflake: Snowflake tests", "sqlite: SQLite tests",