Added AzureBlobFileSystem support for StructuredDatasets
Nick Müller committed Jul 21, 2022
1 parent c6e7237 commit 81b3370
Showing 5 changed files with 14 additions and 4 deletions.
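With the abfs protocol registered, a StructuredDataset can live in Azure Blob Storage the same way it already could in S3 or GCS. A minimal sketch of how a task might use it, assuming adlfs is installed, Azure credentials are available to fsspec, and the container and path below are placeholders:

import pandas as pd

from flytekit import task
from flytekit.types.structured.structured_dataset import StructuredDataset


@task
def write_users() -> StructuredDataset:
    # The pandas encoder writes the frame as Parquet at the abfs:// location.
    df = pd.DataFrame({"name": ["Alice", "Bob"], "age": [30, 31]})
    return StructuredDataset(dataframe=df, uri="abfs://my-container/datasets/users")


@task
def read_users(sd: StructuredDataset) -> pd.DataFrame:
    # The matching decoder loads the Parquet data back into a pandas DataFrame.
    return sd.open(pd.DataFrame).all()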
3 changes: 2 additions & 1 deletion flytekit/types/structured/basic_dfs.py
@@ -12,6 +12,7 @@
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
 from flytekit.types.structured.structured_dataset import (
+    ABFS,
     GCS,
     LOCAL,
     PARQUET,
@@ -106,7 +107,7 @@ def decode(


 # Don't override default protocol
-for protocol in [LOCAL, S3, GCS]:
+for protocol in [LOCAL, S3, GCS, ABFS]:
     StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False)
     StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False)
     StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False)
1 change: 1 addition & 0 deletions flytekit/types/structured/structured_dataset.py
@@ -39,6 +39,7 @@
 # Protocols
 BIGQUERY = "bq"
 S3 = "s3"
+ABFS = "abfs"
 GCS = "gs"
 LOCAL = "/"

@@ -26,7 +26,7 @@

 from flytekit import StructuredDatasetTransformerEngine, logger
 from flytekit.configuration import internal
-from flytekit.types.structured.structured_dataset import GCS, S3
+from flytekit.types.structured.structured_dataset import ABFS, GCS, S3

 from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler
 from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler
@@ -41,6 +41,9 @@ def _register(protocol: str):
     StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), True, True)


+if importlib.util.find_spec("adlfs"):
+    _register(ABFS)
+
 if importlib.util.find_spec("s3fs"):
     _register(S3)

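The registrations above are guarded by importlib.util.find_spec, which returns a ModuleSpec when a package is importable and None otherwise, so the abfs handlers only come into play once adlfs is installed (for example with pip install adlfs). A standalone illustration of that guard, with the package names used purely as examples:

import importlib.util

for pkg in ("adlfs", "s3fs", "gcsfs"):
    if importlib.util.find_spec(pkg):
        print(f"{pkg} is installed; its protocol handlers would be registered")
    else:
        print(f"{pkg} is not installed; its protocol stays unregistered")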
@@ -7,6 +7,7 @@
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
 from flytekit.types.structured.structured_dataset import (
+    ABFS,
     GCS,
     LOCAL,
     PARQUET,
@@ -62,7 +63,7 @@ def decode(
         return pl.read_parquet(path)


-for protocol in [LOCAL, S3, GCS]:
+for protocol in [LOCAL, S3, GCS, ABFS]:
     StructuredDatasetTransformerEngine.register(
         PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=False
     )
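Since the polars handlers are now registered for abfs as well, the same round trip works with polars DataFrames. A sketch under the assumption that flytekitplugins-polars and adlfs are installed and the URI is again a placeholder:

import polars as pl

from flytekit import task
from flytekit.types.structured.structured_dataset import StructuredDataset


@task
def write_users_polars() -> StructuredDataset:
    df = pl.DataFrame({"name": ["Alice", "Bob"], "age": [30, 31]})
    return StructuredDataset(dataframe=df, uri="abfs://my-container/datasets/users")


@task
def read_users_polars(sd: StructuredDataset) -> pl.DataFrame:
    # Decoded by ParquetToPolarsDecodingHandler when a polars DataFrame is requested.
    return sd.open(pl.DataFrame).all()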
@@ -7,7 +7,11 @@
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import StructuredDatasetType
 from flytekit.types.structured.structured_dataset import (
+    ABFS,
+    GCS,
+    LOCAL,
     PARQUET,
+    S3,
     StructuredDataset,
     StructuredDatasetDecoder,
     StructuredDatasetEncoder,
@@ -48,6 +52,6 @@ def decode(
         return user_ctx.spark_session.read.parquet(flyte_value.uri)


-for protocol in ["/", "s3"]:
+for protocol in [LOCAL, S3, GCS, ABFS]:
     StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=False)
     StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=False)
