From 86f96e8b79fe85ea03481f322fbc4287fb408739 Mon Sep 17 00:00:00 2001
From: Alexander Beedie
Date: Sat, 4 Feb 2023 00:13:22 +0400
Subject: [PATCH] fix(python): support dict-expansion (scalars) when mixed
 with numpy arrays

---
 py-polars/polars/internals/construction.py | 95 +++++++++++-----------
 py-polars/tests/unit/test_df.py            | 30 ++++++-
 2 files changed, 74 insertions(+), 51 deletions(-)

diff --git a/py-polars/polars/internals/construction.py b/py-polars/polars/internals/construction.py
index 3b8f10d137d8..8c095c9cb778 100644
--- a/py-polars/polars/internals/construction.py
+++ b/py-polars/polars/internals/construction.py
@@ -685,68 +685,67 @@ def dict_to_pydf(
     nan_to_null: bool = False,
 ) -> PyDataFrame:
     """Construct a PyDataFrame from a dictionary of sequences."""
-    if not schema:
-        schema = list(data)
-    if schema:
-        # the columns arg may also set the dtype/column order of the series
-        if isinstance(schema, dict) and data:
-            if not all((col in schema) for col in data):
-                raise ValueError(
-                    "The given column-schema names do not match the data dictionary"
-                )
-            data = {col: data[col] for col in schema}
-
-        columns, schema_overrides = _unpack_schema(
-            schema, lookup_names=data.keys(), schema_overrides=schema_overrides
-        )
-        if not data and schema_overrides:
-            data_series = [
-                pli.Series(
-                    name, [], dtype=schema_overrides.get(name), nan_to_null=nan_to_null
-                )._s
-                for name in columns
-            ]
-        else:
-            data_series = [
-                s._s
-                for s in _expand_dict_scalars(
-                    data, schema_overrides, nan_to_null=nan_to_null
-                ).values()
-            ]
+    if isinstance(schema, dict) and data:
+        if not all((col in schema) for col in data):
+            raise ValueError(
+                "The given column-schema names do not match the data dictionary"
+            )
+        data = {col: data[col] for col in schema}
 
-        data_series = _handle_columns_arg(data_series, columns=columns, from_dict=True)
-        return PyDataFrame(data_series)
+    column_names, schema_overrides = _unpack_schema(
+        schema, lookup_names=data.keys(), schema_overrides=schema_overrides
+    )
+    if not column_names:
+        column_names = list(data)
 
     if _NUMPY_AVAILABLE:
-        count_numpy = 0
-        for val in data.values():
-            # only start a thread pool from a reasonable size.
-            count_numpy += int(
+        # if there are 3 or more numpy arrays of sufficient size, we multi-thread:
+        count_numpy = sum(
+            int(
                 _check_for_numpy(val)
                 and isinstance(val, np.ndarray)
                 and len(val) > 1000
             )
-
-        # if we have more than 3 numpy arrays we multi-thread
-        if count_numpy > 2:
-            # yes, multi-threading was easier in python here
-            # we cannot run multiple threads that run python code
-            # and release the gil in pyo3
-            # it will deadlock.
-
-            # dummy is threaded
+            for val in data.values()
+        )
+        if count_numpy >= 3:
+            # yes, multi-threading was easier in python here; we cannot have multiple
+            # threads running python and release the gil in pyo3 (it will deadlock).
+
+            # (note: 'dummy' is threaded)
             import multiprocessing.dummy
 
             pool_size = threadpool_size()
             with multiprocessing.dummy.Pool(pool_size) as pool:
-                data_series = pool.map(
-                    lambda t: pli.Series(t[0], t[1])._s,
-                    [(k, v) for k, v in data.items()],
+                data = dict(
+                    zip(
+                        column_names,
+                        pool.map(
+                            lambda t: pli.Series(t[0], t[1])
+                            if isinstance(t[1], np.ndarray)
+                            else t[1],
+                            [(k, v) for k, v in data.items()],
+                        ),
+                    )
                 )
-                return PyDataFrame(data_series)
 
-    data = _expand_dict_scalars(data)
-    return PyDataFrame.read_dict(data)
+    if not data and schema_overrides:
+        data_series = [
+            pli.Series(
+                name, [], dtype=schema_overrides.get(name), nan_to_null=nan_to_null
+            )._s
+            for name in column_names
+        ]
+    else:
+        data_series = [
+            s._s
+            for s in _expand_dict_scalars(
+                data, schema_overrides, nan_to_null=nan_to_null
+            ).values()
+        ]
+
+    data_series = _handle_columns_arg(data_series, columns=column_names, from_dict=True)
+    return PyDataFrame(data_series)
 
 
 def sequence_to_pydf(
diff --git a/py-polars/tests/unit/test_df.py b/py-polars/tests/unit/test_df.py
index 5f111935e0dd..ea6e254dcc0e 100644
--- a/py-polars/tests/unit/test_df.py
+++ b/py-polars/tests/unit/test_df.py
@@ -276,7 +276,7 @@ def test_from_dict_with_column_order() -> None:
 def test_from_dict_with_scalars() -> None:
     import polars as pl
 
-    # one or more valid arrays, with some scalars
+    # one or more valid arrays, with some scalars (inc. None)
     df1 = pl.DataFrame(
         {"key": ["aa", "bb", "cc"], "misc": "xyz", "other": None, "value": 0}
     )
@@ -300,8 +300,8 @@ def test_from_dict_with_scalars() -> None:
     df3 = pl.DataFrame({"vals": map(float, [1, 2, 3])})
     assert df3.to_dict(False) == {"vals": [1.0, 2.0, 3.0]}
 
-    # ensure we don't accidentally consume or expand map/range/generator cols
-    # and can properly apply schema dtype/ordering directives (via 'columns')
+    # ensure we don't accidentally consume or expand map/range/generator
+    # cols, and can properly apply schema dtype/ordering directives
     df4 = pl.DataFrame(
         {
             "key": range(1, 4),
@@ -348,6 +348,30 @@ def test_from_dict_with_scalars() -> None:
         "z": pl.Utf8,
     }
 
+    # mixed with numpy cols...
+    df6 = pl.DataFrame(
+        {"x": np.ones(3), "y": np.zeros(3), "z": 1.0},
+    )
+    assert df6.rows() == [(1.0, 0.0, 1.0), (1.0, 0.0, 1.0), (1.0, 0.0, 1.0)]
+
+    # ...and trigger multithreaded load codepath
+    df7 = pl.DataFrame(
+        {
+            "w": np.zeros(1001, dtype=np.uint8),
+            "x": np.ones(1001, dtype=np.uint8),
+            "y": np.zeros(1001, dtype=np.uint8),
+            "z": 1,
+        },
+        schema_overrides={"z": pl.UInt8},
+    )
+    assert df7[999:].rows() == [(0, 1, 0, 1), (0, 1, 0, 1)]
+    assert df7.schema == {
+        "w": pl.UInt8,
+        "x": pl.UInt8,
+        "y": pl.UInt8,
+        "z": pl.UInt8,
+    }
+
 
 def test_dataframe_membership_operator() -> None:
     # cf. issue #4032
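
-- 
Usage sketch, assuming a polars build with this patch applied: scalar values
in the data dict are broadcast to the length of the accompanying numpy array
columns, including on the multi-threaded load path (3 or more numpy arrays of
more than 1000 elements). The frames below mirror the df6/df7 test cases in
the patch and are illustrative only.

    import numpy as np
    import polars as pl

    # the scalar "z" is expanded to match the 3-row numpy columns
    df = pl.DataFrame({"x": np.ones(3), "y": np.zeros(3), "z": 1.0})
    assert df.rows() == [(1.0, 0.0, 1.0)] * 3

    # three numpy arrays of >1000 elements trigger the threaded codepath;
    # the scalar "z" still broadcasts, and schema_overrides still applies
    df2 = pl.DataFrame(
        {
            "w": np.zeros(1001, dtype=np.uint8),
            "x": np.ones(1001, dtype=np.uint8),
            "y": np.zeros(1001, dtype=np.uint8),
            "z": 1,
        },
        schema_overrides={"z": pl.UInt8},
    )
    assert df2.schema == {"w": pl.UInt8, "x": pl.UInt8, "y": pl.UInt8, "z": pl.UInt8}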