diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py
index c0a5cba19c8f..399ab41b59f9 100644
--- a/ibis/backends/duckdb/__init__.py
+++ b/ibis/backends/duckdb/__init__.py
@@ -187,9 +187,8 @@ def register(
             parquet/csv files, an iterable of parquet or CSV files, a pandas
             dataframe, a pyarrow table or dataset, or a postgres URI.
         table_name
-            An optional name to use for the created table. This defaults to the
-            filename if a path (with hyphens replaced with underscores), or
-            sequentially generated name otherwise.
+            An optional name to use for the created table. This defaults to a
+            sequentially generated name.
         **kwargs
             Additional keyword arguments passed to DuckDB loading functions for
             CSV or parquet. See https://duckdb.org/docs/data/csv and
diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py
index a15b5bbdb891..4036cc91a759 100644
--- a/ibis/backends/pyspark/__init__.py
+++ b/ibis/backends/pyspark/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import itertools
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
@@ -13,6 +14,7 @@
 import ibis.expr.operations as ops
 import ibis.expr.schema as sch
 import ibis.expr.types as ir
+from ibis import util
 from ibis.backends.base.df.scope import Scope
 from ibis.backends.base.df.timecontext import canonicalize_context, localize_context
 from ibis.backends.base.sql import BaseSQLBackend
@@ -37,6 +39,16 @@
     'escape': '"',
 }
 
+pa_n = itertools.count(0)
+csv_n = itertools.count(0)
+
+
+def normalize_filenames(source_list):
+    # Promote to list
+    source_list = util.promote_list(source_list)
+
+    return list(map(util.normalize_filename, source_list))
+
 
 class _PySparkCursor:
     """Spark cursor.
@@ -574,3 +586,121 @@ def _clean_up_cached_table(self, op):
         assert t.is_cached
         t.unpersist()
         assert not t.is_cached
+
+    def read_parquet(
+        self,
+        source: str | Path,
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a parquet file as a table in the current database.
+
+        Parameters
+        ----------
+        source
+            The data source. May be a path to a file or directory of parquet files.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        kwargs
+            Additional keyword arguments passed to PySpark.
+            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.parquet.html
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        source = util.normalize_filename(source)
+        spark_df = self._session.read.parquet(source, **kwargs)
+        table_name = table_name or f"ibis_read_parquet_{next(pa_n)}"
+
+        spark_df.createOrReplaceTempView(table_name)
+        return self.table(table_name)
+
+    def read_csv(
+        self,
+        source_list: str | list[str] | tuple[str],
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a CSV file as a table in the current database.
+
+        Parameters
+        ----------
+        source_list
+            The data source(s). May be a path to a file or directory of CSV files, or an
+            iterable of CSV files.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        kwargs
+            Additional keyword arguments passed to PySpark loading function.
+            https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.csv.html
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+        source_list = normalize_filenames(source_list)
+        spark_df = self._session.read.csv(source_list, **kwargs)
+        table_name = table_name or f"ibis_read_csv_{next(csv_n)}"
+
+        spark_df.createOrReplaceTempView(table_name)
+        return self.table(table_name)
+
+    def register(
+        self,
+        source: str | Path | Any,
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> ir.Table:
+        """Register a data source as a table in the current database.
+
+        Parameters
+        ----------
+        source
+            The data source(s). May be a path to a file or directory of
+            parquet/csv files, or an iterable of CSV files.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        **kwargs
+            Additional keyword arguments passed to PySpark loading functions for
+            CSV or parquet.
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+        """
+
+        if isinstance(source, (str, Path)):
+            first = str(source)
+        elif isinstance(source, (list, tuple)):
+            first = source[0]
+        else:
+            self._register_failure()
+
+        if first.startswith(("parquet://", "parq://")) or first.endswith(
+            ("parq", "parquet")
+        ):
+            return self.read_parquet(source, table_name=table_name, **kwargs)
+        elif first.startswith(
+            ("csv://", "csv.gz://", "txt://", "txt.gz://")
+        ) or first.endswith(("csv", "csv.gz", "tsv", "tsv.gz", "txt", "txt.gz")):
+            return self.read_csv(source, table_name=table_name, **kwargs)
+        else:
+            self._register_failure()  # noqa: RET503
+
+    def _register_failure(self):
+        import inspect
+
+        msg = ", ".join(
+            name for name, _ in inspect.getmembers(self) if name.startswith("read_")
+        )
+        raise ValueError(
+            f"Cannot infer appropriate read function for input, "
+            f"please call one of {msg} directly"
+        )
diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py
index 9e2b0c2f0814..2b05ac01554e 100644
--- a/ibis/backends/tests/test_register.py
+++ b/ibis/backends/tests/test_register.py
@@ -41,7 +41,15 @@ def gzip_csv(data_directory, tmp_path):
     ("fname", "in_table_name", "out_table_name"),
     [
         param("diamonds.csv", None, "ibis_read_csv_", id="default"),
-        param("csv://diamonds.csv", "Diamonds2", "Diamonds2", id="csv_name"),
+        param(
+            "csv://diamonds.csv",
+            "Diamonds2",
+            "Diamonds2",
+            id="csv_name",
+            marks=pytest.mark.notyet(
+                ["pyspark"], reason="pyspark lowercases view names"
+            ),
+        ),
         param(
             "file://diamonds.csv",
             "fancy_stones",
@@ -53,11 +61,14 @@ def gzip_csv(data_directory, tmp_path):
             "fancy stones",
             "fancy stones",
             id="file_atypical_name",
+            marks=pytest.mark.notyet(
+                ["pyspark"], reason="no spaces allowed in view names"
+            ),
         ),
         param(
             ["file://diamonds.csv", "diamonds.csv"],
-            "fancy stones",
-            "fancy stones",
+            "fancy_stones2",
+            "fancy_stones2",
             id="multi_csv",
             marks=pytest.mark.notyet(
                 ["polars", "datafusion"],
@@ -76,7 +87,6 @@ def gzip_csv(data_directory, tmp_path):
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
         "snowflake",
         "sqlite",
         "trino",
@@ -102,7 +112,6 @@ def test_register_csv(con, data_directory, fname, in_table_name, out_table_name)
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
         "snowflake",
         "sqlite",
         "trino",
@@ -125,7 +134,6 @@ def test_register_csv_gz(con, data_directory, gzip_csv):
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
         "snowflake",
         "sqlite",
         "trino",
@@ -179,7 +187,6 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]:
         "mysql",
         "pandas",
         "postgres",
-        "pyspark",
"pyspark", "snowflake", "sqlite", "trino", @@ -381,3 +388,90 @@ def test_register_garbage(con, monkeypatch): with pytest.raises(FileNotFoundError): con.read_parquet("garbage_notafile") + + +@pytest.mark.parametrize( + ("fname", "in_table_name", "out_table_name"), + [ + ( + "functional_alltypes.parquet", + None, + "ibis_read_parquet", + ), + ("functional_alltypes.parquet", "funk_all", "funk_all"), + ], +) +@pytest.mark.notyet( + [ + "bigquery", + "clickhouse", + "dask", + "impala", + "mssql", + "mysql", + "pandas", + "postgres", + "snowflake", + "sqlite", + "trino", + ] +) +def test_read_parquet( + con, tmp_path, data_directory, fname, in_table_name, out_table_name +): + pq = pytest.importorskip("pyarrow.parquet") + + fname = Path(fname) + table = read_table(data_directory / fname.name) + + pq.write_table(table, tmp_path / fname.name) + + with pushd(data_directory): + if con.name == "pyspark": + # pyspark doesn't respect CWD + fname = str(Path(fname).absolute()) + table = con.read_parquet(fname, table_name=in_table_name) + + assert any(t.startswith(out_table_name) for t in con.list_tables()) + + if con.name != "datafusion": + table.count().execute() + + +@pytest.mark.parametrize( + ("fname", "in_table_name", "out_table_name"), + [ + param("diamonds.csv", None, "ibis_read_csv_", id="default"), + param( + "diamonds.csv", + "fancy_stones", + "fancy_stones", + id="file_name", + ), + ], +) +@pytest.mark.notyet( + [ + "bigquery", + "clickhouse", + "dask", + "impala", + "mssql", + "mysql", + "pandas", + "postgres", + "snowflake", + "sqlite", + "trino", + ] +) +def test_read_csv(con, data_directory, fname, in_table_name, out_table_name): + with pushd(data_directory): + if con.name == "pyspark": + # pyspark doesn't respect CWD + fname = str(Path(fname).absolute()) + table = con.read_csv(fname, table_name=in_table_name) + + assert any(t.startswith(out_table_name) for t in con.list_tables()) + if con.name != "datafusion": + table.count().execute()