Skip to content

Commit

Permalink
feat(cast): add table-level try_cast
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth authored and cpcloud committed Jun 29, 2023
1 parent 15d9e50 commit 5e4d16b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
27 changes: 27 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,33 @@ def test_try_cast_expected(con, from_val, to_type, expected):
assert con.execute(ibis.literal(from_val).try_cast(to_type)) == expected


@pytest.mark.notimpl(
    [
        "pandas",
        "dask",
        "bigquery",
        "datafusion",
        "druid",
        "impala",
        "mssql",
        "mysql",
        "oracle",
        "postgres",
        "pyspark",
        "snowflake",
        "sqlite",
    ]
)
def test_try_cast_table(con):
    """Check table-level ``try_cast`` against an in-memory table.

    Column ``a`` holds integer-like strings plus a missing value; column
    ``b`` holds float-like strings plus one entry (``"goodbye"``) that is
    not a number. The executed result is compared with a frame in which
    those two entries are missing.
    """
    raw = pd.DataFrame({"a": ["1", "2", None], "b": ["1.0", "2.2", "goodbye"]})

    # Both expected columns are written as floats with ``None`` for the
    # entries that have no numeric value.
    expected = pd.DataFrame({"a": [1.0, 2.0, None], "b": [1.0, 2.2, None]})

    casted = ibis.memtable(raw).try_cast({"a": "int", "b": "float"})
    result = con.execute(casted)

    tm.assert_frame_equal(result, expected)


@pytest.mark.notimpl(
[
"pandas",
Expand Down
39 changes: 38 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,43 @@ def cast(self, schema: SupportsSchema) -> Table:
...
ibis.common.exceptions.IbisError: Cast schema has fields that are not in the table: ['foo']
"""
return self._cast(schema, cast_method="cast")

def try_cast(self, schema: SupportsSchema) -> Table:
    """Try to cast the columns of a table.

    Where the cast fails for a row, the value is returned as `NULL` or
    `NaN`, depending on backend behavior.

    Parameters
    ----------
    schema
        Mapping, schema or iterable of pairs to use for casting

    Returns
    -------
    Table
        Casted table

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"a": ["1", "2", "3"], "b": ["2.2", "3.3", "book"]})
    >>> t.try_cast({"a": "int", "b": "float"})
    ┏━━━━━━━┳━━━━━━━━━┓
    ┃ a     ┃ b       ┃
    ┡━━━━━━━╇━━━━━━━━━┩
    │ int64 │ float64 │
    ├───────┼─────────┤
    │     1 │     2.2 │
    │     2 │     3.3 │
    │     3 │     nan │
    └───────┴─────────┘
    """
    # Delegate to the shared casting helper, selecting the column-level
    # ``try_cast`` method by name.
    return self._cast(schema=schema, cast_method="try_cast")

def _cast(self, schema: SupportsSchema, cast_method: str = "cast") -> Table:
schema = sch.schema(schema)

cols = []
Expand All @@ -237,7 +274,7 @@ def cast(self, schema: SupportsSchema) -> Table:

for col in columns:
if (new_type := schema.get(col)) is not None:
new_col = self[col].cast(new_type).name(col)
new_col = getattr(self[col], cast_method)(new_type).name(col)
else:
new_col = col
cols.append(new_col)
Expand Down

0 comments on commit 5e4d16b

Please sign in to comment.