Skip to content

Commit

Permalink
Remove arg swap hack from Series.add (#731)
Browse files Browse the repository at this point in the history
* remove arg swap hack

* tweak wording
  • Loading branch information
billylanchantin authored Nov 8, 2023
1 parent f0d981d commit 1d5ad56
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 28 deletions.
14 changes: 1 addition & 13 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ defmodule Explorer.PolarsBackend.Expression do
@type t :: %__MODULE__{resource: reference()}

@all_expressions [
add: 2,
all_equal: 2,
argmax: 1,
argmin: 1,
Expand Down Expand Up @@ -130,7 +131,6 @@ defmodule Explorer.PolarsBackend.Expression do
]

@custom_expressions [
add: 2,
divide: 2,
multiply: 2,
cast: 2,
Expand Down Expand Up @@ -244,18 +244,6 @@ defmodule Explorer.PolarsBackend.Expression do
end
end

def to_expr(%LazySeries{op: :add, args: [left, right]}) do
# `duration + date` is not supported by polars for some reason.
# `date + duration` is, so we're swapping arguments as a work around.
[left, right] =
case [dtype(left), dtype(right)] do
[{:duration, _}, :date] -> [right, left]
_ -> [left, right]
end

Native.expr_add(to_expr(left), to_expr(right))
end

def to_expr(%LazySeries{op: :multiply, args: [left, right] = args}) do
expr = Native.expr_multiply(to_expr(left), to_expr(right))

Expand Down
15 changes: 2 additions & 13 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -284,19 +284,8 @@ defmodule Explorer.PolarsBackend.Series do
# Arithmetic

@impl true
def add(out_dtype, left, right) do
left = matching_size!(left, right)

# `duration + date` is not supported by polars for some reason.
# `date + duration` is, so we're swapping arguments as a work around.
[left, right] =
case {out_dtype, dtype(left), dtype(right)} do
{:date, {:duration, _}, :date} -> [right, left]
_ -> [left, right]
end

Shared.apply_series(left, :s_add, [right.data])
end
def add(_out_dtype, left, right),
do: Shared.apply_series(matching_size!(left, right), :s_add, [right.data])

@impl true
def subtract(_out_dtype, left, right),
Expand Down
4 changes: 2 additions & 2 deletions test/explorer/series/duration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,8 @@ defmodule Explorer.Series.DurationTest do
assert Series.to_list(df["ns"]) == [ns]
end

# There is currently an issue with Polars where `duration + date` is not supported but
# `date + duration` is. There is a workaround in where we swap the args.
# There used to be an issue with Polars where `duration + date` was not supported but
# `date + duration` was. This test was for a workaround (longer present) to that issue.
test "mutate/2 with duration + date" do
require Explorer.DataFrame
alias Explorer.DataFrame, as: DF
Expand Down

0 comments on commit 1d5ad56

Please sign in to comment.