diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index f0b82b43680f..c508b696666f 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -947,6 +947,39 @@ def test_memtable_construct(backend, con, monkeypatch): ) +@pytest.mark.parametrize( + "df, columns, expected", + [ + (pd.DataFrame([("a", 1.0)], columns=["d", "f"]), ["a", "b"], ["a", "b"]), + (pd.DataFrame([("a", 1.0)]), ["A", "B"], ["A", "B"]), + (pd.DataFrame([("a", 1.0)], columns=["c", "d"]), None, ["c", "d"]), + ([("a", "1.0")], None, ["col0", "col1"]), + ([("a", "1.0")], ["d", "e"], ["d", "e"]), + ], +) +def test_memtable_column_naming(backend, con, monkeypatch, df, columns, expected): + monkeypatch.setattr(ibis.options, "default_backend", con) + + t = ibis.memtable(df, columns=columns) + assert all(t.to_pandas().columns == expected) + + +@pytest.mark.parametrize( + "df, columns", + [ + (pd.DataFrame([("a", 1.0)], columns=["d", "f"]), ["a"]), + (pd.DataFrame([("a", 1.0)]), ["A", "B", "C"]), + ([("a", "1.0")], ["col0", "col1", "col2"]), + ([("a", "1.0")], ["d"]), + ], +) +def test_memtable_column_naming_mismatch(backend, con, monkeypatch, df, columns): + monkeypatch.setattr(ibis.options, "default_backend", con) + + with pytest.raises(ValueError): + ibis.memtable(df, columns=columns) + + @pytest.mark.notimpl( ["dask", "datafusion", "pandas", "polars"], raises=NotImplementedError, diff --git a/ibis/expr/api.py b/ibis/expr/api.py index f993667ab76b..5bd3e04e997c 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -371,7 +371,8 @@ def memtable( Do not depend on the underlying storage type (e.g., pyarrow.Table), it's subject to change across non-major releases. columns - Optional [](`typing.Iterable`) of [](`str`) column names. + Optional [](`typing.Iterable`) of [](`str`) column names. If provided, + must match the number of columns in `data`. schema Optional [`Schema`](./schemas.qmd#ibis.expr.schema.Schema). The functions use `data` to infer a schema if not passed. @@ -468,7 +469,11 @@ def _memtable_from_dataframe( from ibis.expr.operations.relations import PandasDataFrameProxy - df = pd.DataFrame(data, columns=columns) + if not isinstance(data, pd.DataFrame): + df = pd.DataFrame(data, columns=columns) + else: + df = data + if df.columns.inferred_type != "string": cols = df.columns newcols = getattr( @@ -478,6 +483,15 @@ def _memtable_from_dataframe( ) df = df.rename(columns=dict(zip(cols, newcols))) + if columns is not None: + if (provided_col := len(columns)) != (exist_col := len(df.columns)): + raise ValueError( + "Provided `columns` must have an entry for each column in `data`.\n" + f"`columns` has {provided_col} elements but `data` has {exist_col} columns." + ) + + df = df.rename(columns=dict(zip(df.columns, columns))) + # verify that the DataFrame has no duplicate column names because ibis # doesn't allow that cols = df.columns