Skip to content

Commit

Permalink
Fix from_list/2 of list of structs when first is empty (#849)
Browse files Browse the repository at this point in the history
* Fix `from_list/2` of list of structs when fist is empty

Closes #847

* Update test/explorer/series/list_test.exs
  • Loading branch information
philss authored Feb 7, 2024
1 parent 5dbdbbd commit 24fa10d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
10 changes: 9 additions & 1 deletion lib/explorer/polars_backend/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ defmodule Explorer.PolarsBackend.Shared do
Native.s_from_list_of_series(name, series)
end

def from_list(list, {:struct, fields} = _dtype, name) when is_list(list) do
def from_list(list, {:struct, fields} = dtype, name) when is_list(list) do
series =
for {column, values} <- Table.to_columns(list) do
column = to_string(column)
Expand All @@ -135,6 +135,14 @@ defmodule Explorer.PolarsBackend.Shared do
end

Native.s_from_list_of_series_as_structs(name, series)
|> then(fn polars_series ->
if Native.s_dtype(polars_series) != dtype do
{:ok, casted} = Native.s_cast(polars_series, dtype)
casted
else
polars_series
end
end)
end

def from_list(list, dtype, name) when is_list(list) do
Expand Down
16 changes: 16 additions & 0 deletions test/explorer/series/list_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,22 @@ defmodule Explorer.Series.ListTest do
"the value \"z\" does not match the inferred dtype {:s, 64}",
fn -> Series.from_list([[[[[1, 2], ["z", "b"]]]]]) end
end

test "list of structs" do
series =
Series.from_list([[%{"a" => 42}], []], dtype: {:list, {:struct, %{"a" => :integer}}})

assert Series.dtype(series) == {:list, {:struct, %{"a" => {:s, 64}}}}
assert Series.to_list(series) == [[%{"a" => 42}], []]
end

test "list of structs with first empty" do
series =
Series.from_list([[], [%{"a" => 42}], []], dtype: {:list, {:struct, %{"a" => :integer}}})

assert Series.dtype(series) == {:list, {:struct, %{"a" => {:s, 64}}}}
assert Series.to_list(series) == [[], [%{"a" => 42}], []]
end
end

describe "cast/2" do
Expand Down

0 comments on commit 24fa10d

Please sign in to comment.