From 87d13903993d5339723ecd859bd362a64e780aec Mon Sep 17 00:00:00 2001 From: Robin Kahlow Date: Wed, 15 Jun 2022 21:30:32 +0100 Subject: [PATCH] Add polars plugin (#1061) * add polars plugin Signed-off-by: Robin Kahlow * support for older polars versions, add info about what polars is Signed-off-by: Robin Kahlow * run make fmt Signed-off-by: Robin Kahlow * structured dataset instead of schema transformer Signed-off-by: Robin Kahlow * polars html describe only Signed-off-by: Robin Kahlow * set polars min to 0.7.13 (.describe() added) Signed-off-by: Robin Kahlow * set polars min to 0.8.27 (.transpose() added) Signed-off-by: Robin Kahlow * add gcs, fix encode local/remote dir Signed-off-by: Robin Kahlow * black and isort Signed-off-by: Robin Kahlow * add polars plugin to pythonbuild.yml Signed-off-by: Robin Kahlow --- .github/workflows/pythonbuild.yml | 1 + .../types/structured/structured_dataset.py | 5 + plugins/flytekit-polars/README.md | 10 + .../flytekitplugins/polars/__init__.py | 14 ++ .../flytekitplugins/polars/sd_transformers.py | 73 +++++++ plugins/flytekit-polars/requirements.in | 2 + plugins/flytekit-polars/requirements.txt | 186 ++++++++++++++++++ plugins/flytekit-polars/setup.py | 38 ++++ plugins/flytekit-polars/tests/__init__.py | 0 .../tests/test_polars_plugin_sd.py | 64 ++++++ 10 files changed, 393 insertions(+) create mode 100644 plugins/flytekit-polars/README.md create mode 100644 plugins/flytekit-polars/flytekitplugins/polars/__init__.py create mode 100644 plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py create mode 100644 plugins/flytekit-polars/requirements.in create mode 100644 plugins/flytekit-polars/requirements.txt create mode 100644 plugins/flytekit-polars/setup.py create mode 100644 plugins/flytekit-polars/tests/__init__.py create mode 100644 plugins/flytekit-polars/tests/test_polars_plugin_sd.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index ebf241cf10..ca6c85bb3e 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -79,6 +79,7 @@ jobs: - flytekit-modin - flytekit-pandera - flytekit-papermill + - flytekit-polars - flytekit-snowflake - flytekit-spark - flytekit-sqlalchemy diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index a4488fbc1e..af60599f0c 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -19,6 +19,8 @@ if importlib.util.find_spec("pyspark") is not None: import pyspark +if importlib.util.find_spec("polars") is not None: + import polars as pl from dataclasses_json import config, dataclass_json from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin @@ -647,6 +649,9 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ return pd.DataFrame(df).describe().to_html() elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame): return pd.DataFrame(df.schema, columns=["StructField"]).to_html() + elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame): + describe_df = df.describe() + return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) else: raise NotImplementedError("Conversion to html string should be implemented") diff --git a/plugins/flytekit-polars/README.md b/plugins/flytekit-polars/README.md new file mode 100644 index 0000000000..011a447582 --- /dev/null +++ b/plugins/flytekit-polars/README.md @@ -0,0 +1,10 @@ +# Flytekit Polars Plugin +[Polars](https://github.com/pola-rs/polars) is a blazingly fast DataFrames library implemented in Rust using Apache Arrow Columnar Format as memory model. + +This plugin supports `polars.DataFrame` as a data type with [StructuredDataset](https://docs.flyte.org/projects/cookbook/en/latest/auto/core/type_system/structured_dataset.html). + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-polars +``` diff --git a/plugins/flytekit-polars/flytekitplugins/polars/__init__.py b/plugins/flytekit-polars/flytekitplugins/polars/__init__.py new file mode 100644 index 0000000000..85948bed73 --- /dev/null +++ b/plugins/flytekit-polars/flytekitplugins/polars/__init__.py @@ -0,0 +1,14 @@ +""" +.. currentmodule:: flytekitplugins.polars + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + PolarsDataFrameToParquetEncodingHandler + ParquetToPolarsDataFrameDecodingHandler +""" + +from .sd_transformers import ParquetToPolarsDataFrameDecodingHandler, PolarsDataFrameToParquetEncodingHandler diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py new file mode 100644 index 0000000000..1a667fe699 --- /dev/null +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -0,0 +1,73 @@ +import typing + +import polars as pl + +from flytekit import FlyteContext +from flytekit.models import literals +from flytekit.models.literals import StructuredDatasetMetadata +from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.structured_dataset import ( + GCS, + LOCAL, + PARQUET, + S3, + StructuredDataset, + StructuredDatasetDecoder, + StructuredDatasetEncoder, + StructuredDatasetTransformerEngine, +) + + +class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): + def __init__(self, protocol: str): + super().__init__(pl.DataFrame, protocol, PARQUET) + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: + df = typing.cast(pl.DataFrame, structured_dataset.dataframe) + + local_dir = ctx.file_access.get_random_local_directory() + local_path = f"{local_dir}/00000" + + # Polars 0.13.12 deprecated to_parquet in favor of write_parquet + if hasattr(df, "write_parquet"): + df.write_parquet(local_path) + else: + df.to_parquet(local_path) + remote_dir = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + ctx.file_access.upload_directory(local_dir, remote_dir) + return literals.StructuredDataset(uri=remote_dir, metadata=StructuredDatasetMetadata(structured_dataset_type)) + + +class ParquetToPolarsDataFrameDecodingHandler(StructuredDatasetDecoder): + def __init__(self, protocol: str): + super().__init__(pl.DataFrame, protocol, PARQUET) + + def decode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, + ) -> pl.DataFrame: + local_dir = ctx.file_access.get_random_local_directory() + ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) + path = f"{local_dir}/00000" + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] + return pl.read_parquet(path, columns=columns) + return pl.read_parquet(path) + + +for protocol in [LOCAL, S3]: + StructuredDatasetTransformerEngine.register( + PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=True + ) + StructuredDatasetTransformerEngine.register( + ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=True + ) +StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler(GCS), default_for_type=False) +StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler(GCS), default_for_type=False) diff --git a/plugins/flytekit-polars/requirements.in b/plugins/flytekit-polars/requirements.in new file mode 100644 index 0000000000..8425c5645f --- /dev/null +++ b/plugins/flytekit-polars/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-polars diff --git a/plugins/flytekit-polars/requirements.txt b/plugins/flytekit-polars/requirements.txt new file mode 100644 index 0000000000..9bd7819a1e --- /dev/null +++ b/plugins/flytekit-polars/requirements.txt @@ -0,0 +1,186 @@ +# +# This file is autogenerated by pip-compile with python 3.8 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-polars + # via -r requirements.in +arrow==1.2.2 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2021.10.8 + # via requests +cffi==1.15.0 + # via cryptography +chardet==4.0.0 + # via binaryornot +charset-normalizer==2.0.12 + # via requests +click==8.1.2 + # via + # cookiecutter + # flytekit +cloudpickle==2.0.0 + # via flytekit +cookiecutter==1.7.3 + # via flytekit +croniter==1.3.4 + # via flytekit +cryptography==36.0.2 + # via secretstorage +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==5.0.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.13 + # via flytekit +flyteidl==1.0.1 + # via flytekit +flytekit==1.1.0b2 + # via flytekitplugins-polars +googleapis-common-protos==1.56.0 + # via + # flyteidl + # grpcio-status +grpcio==1.43.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.43.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.11.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.1 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keyring==23.5.0 + # via flytekit +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.15.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.1.0 + # via flytekit +numpy==1.22.3 + # via + # pandas + # polars + # pyarrow +packaging==21.3 + # via marshmallow +pandas==1.4.1 + # via flytekit +polars==0.13.44 + # via flytekitplugins-polars +poyo==0.5.0 + # via cookiecutter +protobuf==3.20.1 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyparsing==3.0.8 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==6.1.1 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.1 + # via + # flytekit + # pandas +pyyaml==6.0 + # via flytekit +regex==2022.3.15 + # via docker-image-py +requests==2.27.1 + # via + # cookiecutter + # docker + # flytekit + # responses +responses==0.20.0 + # via flytekit +retry==0.9.2 + # via flytekit +secretstorage==3.3.2 + # via keyring +six==1.16.0 + # via + # cookiecutter + # grpcio + # python-dateutil +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +typing-extensions==4.2.0 + # via + # flytekit + # polars + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.9 + # via + # flytekit + # requests + # responses +websocket-client==1.3.2 + # via docker +wheel==0.37.1 + # via flytekit +wrapt==1.14.0 + # via + # deprecated + # flytekit +zipp==3.8.0 + # via importlib-metadata diff --git a/plugins/flytekit-polars/setup.py b/plugins/flytekit-polars/setup.py new file mode 100644 index 0000000000..ea3feb8582 --- /dev/null +++ b/plugins/flytekit-polars/setup.py @@ -0,0 +1,38 @@ +from setuptools import setup + +PLUGIN_NAME = "polars" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flytekit>=1.1.0b0,<1.2.0", + "polars>=0.8.27", +] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="Robin Kahlow", + description="Polars plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-polars/tests/__init__.py b/plugins/flytekit-polars/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py new file mode 100644 index 0000000000..3c9c2613ae --- /dev/null +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -0,0 +1,64 @@ +import flytekitplugins.polars # noqa F401 +import polars as pl + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + +from flytekit import kwtypes, task, workflow +from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset + +subset_schema = Annotated[StructuredDataset, kwtypes(col2=str), PARQUET] +full_schema = Annotated[StructuredDataset, PARQUET] + + +def test_polars_workflow_subset(): + @task + def generate() -> subset_schema: + df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + return StructuredDataset(dataframe=df) + + @task + def consume(df: subset_schema) -> subset_schema: + df = df.open(pl.DataFrame).all() + + assert df["col2"][0] == "a" + assert df["col2"][1] == "b" + assert df["col2"][2] == "c" + + return StructuredDataset(dataframe=df) + + @workflow + def wf() -> subset_schema: + return consume(df=generate()) + + result = wf() + assert result is not None + + +def test_polars_workflow_full(): + @task + def generate() -> full_schema: + df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + return StructuredDataset(dataframe=df) + + @task + def consume(df: full_schema) -> full_schema: + df = df.open(pl.DataFrame).all() + + assert df["col1"][0] == 1 + assert df["col1"][1] == 3 + assert df["col1"][2] == 2 + assert df["col2"][0] == "a" + assert df["col2"][1] == "b" + assert df["col2"][2] == "c" + + return StructuredDataset(dataframe=df.sort("col1")) + + @workflow + def wf() -> full_schema: + return consume(df=generate()) + + result = wf() + assert result is not None