Skip to content

Commit

Permalink
fix(pandas): don't silently ignore result column name mismatches
Browse files: browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 6, 2024
1 parent 85e1dcc commit 48be246
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
7 changes: 4 additions & 3 deletions ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def raw_sql(self, query: str):
def _fetch_from_cursor(self, cursor, schema):
from ibis.formats.pandas import PandasData

results = fetchall(cursor)
results = fetchall(cursor, schema.names)
return PandasData.convert_table(results, schema)

@contextlib.contextmanager
Expand Down Expand Up @@ -1260,9 +1260,10 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
cur.execute(insert_stmt, row)


def fetchall(cur):
def fetchall(cur, names=None):
batches = cur.fetchcolumnar()
names = list(map(operator.itemgetter(0), cur.description))
if names is None:
names = list(map(operator.itemgetter(0), cur.description))
df = _column_batches_to_dataframe(names, batches)
return df

Expand Down
11 changes: 9 additions & 2 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,12 +466,19 @@ def _to_dataframe(
streaming: bool = False,
**kwargs: Any,
) -> pl.DataFrame:
lf = self.compile(expr, params=params, **kwargs)
table_expr = expr.as_table()
lf = self.compile(table_expr, params=params, **kwargs)
if limit == "default":
limit = ibis.options.sql.default_limit
if limit is not None:
lf = lf.limit(limit)
return lf.collect(streaming=streaming)
df = lf.collect(streaming=streaming)
# XXX: Polars sometimes returns data with the incorrect column names.
# For now we catch this case and rename them here if needed.
expected_cols = tuple(table_expr.columns)
if tuple(df.columns) != expected_cols:
df = df.rename(dict(zip(df.columns, expected_cols)))
return df

def execute(
self,
Expand Down
9 changes: 3 additions & 6 deletions ibis/formats/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,11 @@ def infer_table(cls, df):

@classmethod
def convert_table(cls, df, schema):
if len(schema) != len(df.columns):
raise ValueError(
"schema column count does not match input data column count"
)
if schema.names != tuple(df.columns):
raise ValueError("schema names don't match input data columns")

columns = {
name: cls.convert_column(series, dtype)
for (name, dtype), (_, series) in zip(schema.items(), df.items())
name: cls.convert_column(df[name], dtype) for name, dtype in schema.items()
}
df = pd.DataFrame(columns)

Expand Down
7 changes: 7 additions & 0 deletions ibis/formats/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,10 @@ def test_convert_dataframe_with_timezone():
desired_schema = ibis.schema(dict(time='timestamp("EST")'))
result = PandasData.convert_table(df.copy(), desired_schema)
tm.assert_frame_equal(expected, result)


def test_schema_doesnt_match_input_columns():
    """Check that convert_table rejects a frame whose columns differ from the schema names."""
    # Same column count on both sides, so only the names disagree.
    frame = pd.DataFrame({"x": [1], "y": [2]})
    mismatched_schema = sch.Schema({"a": "int64", "b": "int64"})

    expected_error = pytest.raises(ValueError, match="schema names don't match")
    with expected_error:
        PandasData.convert_table(frame, mismatched_schema)

0 comments on commit 48be246

Please sign in to comment.