Skip to content

Commit

Permalink
fix(memtable): ensure column names match provided data
Browse files Browse the repository at this point in the history
Previously, the `pandas` memtable constructor was performing two tasks
inconsistently, one was to subselect columns out of the provided
dataframe, the other was to rename those columns.
This led to some weird behavior where a mismatch in provided names could
lead to a dataframe consisting of NaNs.

Now, `columns` can only be provided to rename existing columns and there
is no subselection behavior.  If the length of the `columns` iterable
does not match the number of columns in the provided data, we error.
  • Loading branch information
gforsyth authored and cpcloud committed Dec 8, 2023
1 parent 241c8be commit faf99df
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
33 changes: 33 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,39 @@ def test_memtable_construct(backend, con, monkeypatch):
)


@pytest.mark.parametrize(
"df, columns, expected",
[
(pd.DataFrame([("a", 1.0)], columns=["d", "f"]), ["a", "b"], ["a", "b"]),
(pd.DataFrame([("a", 1.0)]), ["A", "B"], ["A", "B"]),
(pd.DataFrame([("a", 1.0)], columns=["c", "d"]), None, ["c", "d"]),
([("a", "1.0")], None, ["col0", "col1"]),
([("a", "1.0")], ["d", "e"], ["d", "e"]),
],
)
def test_memtable_column_naming(backend, con, monkeypatch, df, columns, expected):
monkeypatch.setattr(ibis.options, "default_backend", con)

t = ibis.memtable(df, columns=columns)
assert all(t.to_pandas().columns == expected)


@pytest.mark.parametrize(
"df, columns",
[
(pd.DataFrame([("a", 1.0)], columns=["d", "f"]), ["a"]),
(pd.DataFrame([("a", 1.0)]), ["A", "B", "C"]),
([("a", "1.0")], ["col0", "col1", "col2"]),
([("a", "1.0")], ["d"]),
],
)
def test_memtable_column_naming_mismatch(backend, con, monkeypatch, df, columns):
monkeypatch.setattr(ibis.options, "default_backend", con)

with pytest.raises(ValueError):
ibis.memtable(df, columns=columns)


@pytest.mark.notimpl(
["dask", "datafusion", "pandas", "polars"],
raises=NotImplementedError,
Expand Down
18 changes: 16 additions & 2 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,8 @@ def memtable(
Do not depend on the underlying storage type (e.g., pyarrow.Table), it's subject
to change across non-major releases.
columns
Optional [](`typing.Iterable`) of [](`str`) column names.
Optional [](`typing.Iterable`) of [](`str`) column names. If provided,
must match the number of columns in `data`.
schema
Optional [`Schema`](./schemas.qmd#ibis.expr.schema.Schema).
The functions use `data` to infer a schema if not passed.
Expand Down Expand Up @@ -468,7 +469,11 @@ def _memtable_from_dataframe(

from ibis.expr.operations.relations import PandasDataFrameProxy

df = pd.DataFrame(data, columns=columns)
if not isinstance(data, pd.DataFrame):
df = pd.DataFrame(data, columns=columns)
else:
df = data

if df.columns.inferred_type != "string":
cols = df.columns
newcols = getattr(
Expand All @@ -478,6 +483,15 @@ def _memtable_from_dataframe(
)
df = df.rename(columns=dict(zip(cols, newcols)))

if columns is not None:
if (provided_col := len(columns)) != (exist_col := len(df.columns)):
raise ValueError(
"Provided `columns` must have an entry for each column in `data`.\n"
f"`columns` has {provided_col} elements but `data` has {exist_col} columns."
)

df = df.rename(columns=dict(zip(df.columns, columns)))

# verify that the DataFrame has no duplicate column names because ibis
# doesn't allow that
cols = df.columns
Expand Down

0 comments on commit faf99df

Please sign in to comment.