
add lazyframe support
ion-elgreco committed Jan 20, 2024
1 parent 4a7b917 commit 9eb0a94
Showing 3 changed files with 52 additions and 17 deletions.
@@ -1,7 +1,8 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union

import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds
from dagster._core.storage.db_io_manager import (
DbTypeHandler,
)
@@ -11,19 +12,35 @@
)
from dagster_deltalake.io_manager import DeltaLakeIOManager

PolarsTypes = Union[pl.DataFrame, pl.LazyFrame]

class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[pl.DataFrame]):
def from_arrow(
self, obj: pa.RecordBatchReader, target_type: Type[pl.DataFrame]
) -> pl.DataFrame:
return pl.from_arrow(obj) # type: ignore

def to_arrow(self, obj: pl.DataFrame) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[PolarsTypes]):
def from_arrow(
self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: Type[PolarsTypes]
) -> PolarsTypes:
if isinstance(obj, pa.RecordBatchReader):
df = pl.DataFrame(obj.read_all())
if target_type == pl.LazyFrame:
## Maybe allow this, but raise a warning that the data was already sliced upstream;
## otherwise this branch would have received a ds.Dataset instead of a reader.
return df.lazy()
else:
return df
df = pl.scan_pyarrow_dataset(obj)
if target_type == pl.DataFrame:
return df.collect()
else:
return df

def to_arrow(self, obj: PolarsTypes) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
if isinstance(obj, pl.LazyFrame):
obj = obj.collect()
return obj.to_arrow().to_reader(), {"large_dtypes": True}

@property
def supported_types(self) -> Sequence[Type[object]]:
return [pl.DataFrame]
return [pl.DataFrame, pl.LazyFrame]


class DeltaLakePolarsIOManager(DeltaLakeIOManager):
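
The new handler accepts either a pa.RecordBatchReader (handed over when columns were pruned upstream) or a pyarrow dataset, and returns a DataFrame or a scan-backed LazyFrame depending on the annotated type. A minimal sketch of that round trip, using an in-memory dataset in place of the Delta table scan; the direct handler instantiation and the dagster_deltalake_polars import path are assumptions for illustration, not part of this commit:

import polars as pl
import pyarrow.dataset as ds
from dagster_deltalake_polars import DeltaLakePolarsTypeHandler  # assumed import path

handler = DeltaLakePolarsTypeHandler()

# Writing: a LazyFrame is collected before being handed to the Arrow writer.
lazy = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
reader, metadata = handler.to_arrow(lazy)  # (pa.RecordBatchReader, {"large_dtypes": True})
table = reader.read_all()

# Reading: with a dataset input and a LazyFrame target the handler returns a
# lazy scan; with a DataFrame target it collects eagerly.
dataset = ds.dataset(table)  # stands in for the dataset built from the Delta table
lazy_again = handler.from_arrow(dataset, pl.LazyFrame)
assert lazy_again.collect()["a"].to_list() == [1, 2, 3]

When a reader (rather than a dataset) arrives with a LazyFrame target, the data is already materialized, which is what the comment in from_arrow flags as a candidate for a warning.
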
@@ -73,19 +73,33 @@ def b_plus_one(b_df: pl.DataFrame) -> pl.DataFrame:
return b_df + 1


def test_deltalake_io_manager_with_assets(tmp_path, io_manager):
@asset(key_prefix=["my_schema"])
def b_df_lazy() -> pl.LazyFrame:
return pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@asset(key_prefix=["my_schema"])
def b_plus_one_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame:
return b_df_lazy.select(pl.all() + 1)


@pytest.mark.parametrize("asset1,asset2,asset1_path,asset2_path", [
(b_df, b_plus_one, "b_df", "b_plus_one"),
(b_df_lazy, b_plus_one_lazy, "b_df_lazy", "b_plus_one_lazy")
])
def test_deltalake_io_manager_with_assets(tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path):
resource_defs = {"io_manager": io_manager}

# materialize asset twice to ensure that tables get properly deleted
for _ in range(2):
res = materialize([b_df, b_plus_one], resources=resource_defs)
res = materialize([asset1, asset2], resources=resource_defs)
assert res.success

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_df"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/"+ asset1_path ))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [1, 2, 3]

dt = DeltaTable(os.path.join(tmp_path, "my_schema/b_plus_one"))
dt = DeltaTable(os.path.join(tmp_path, "my_schema/"+ asset2_path))
out_df = dt.to_pyarrow_table()
assert out_df["a"].to_pylist() == [2, 3, 4]

@@ -142,7 +156,6 @@ def test_loading_columns(tmp_path, io_manager):

assert out_df.shape[1] == 1


@op
def non_supported_type() -> int:
return 1
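
The parametrized test exercises both the eager and the lazy asset pair against the same IO manager. For completeness, a hedged sketch of the user-facing wiring for a LazyFrame asset; the asset names, root_uri value, and LocalConfig storage options are illustrative, not taken from this commit:

import polars as pl
from dagster import Definitions, asset
from dagster_deltalake import LocalConfig
from dagster_deltalake_polars import DeltaLakePolarsIOManager


@asset(key_prefix=["my_schema"])
def events() -> pl.LazyFrame:
    # Returned lazily; the type handler collects it only when writing the Delta table.
    return pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@asset(key_prefix=["my_schema"])
def events_plus_one(events: pl.LazyFrame) -> pl.LazyFrame:
    # Loaded back as a scan-backed LazyFrame when no column selection is applied.
    return events.select(pl.all() + 1)


defs = Definitions(
    assets=[events, events_plus_one],
    resources={
        "io_manager": DeltaLakePolarsIOManager(
            root_uri="/tmp/delta_lake", storage_options=LocalConfig()
        )
    },
)
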
@@ -40,7 +40,7 @@

class DeltalakeBaseArrowTypeHandler(DbTypeHandler[T], Generic[T]):
@abstractmethod
def from_arrow(self, obj: pa.RecordBatchReader, target_type: type) -> T:
def from_arrow(self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: type) -> T:
pass

@abstractmethod
@@ -132,12 +132,17 @@ def load_input(
raise ValueError("Cannot select columns when loading as Dataset.")
return dataset

scanner = dataset.scanner(columns=table_slice.columns)
return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)
if table_slice.columns is None:
return self.from_arrow(dataset, context.dagster_type.typing_type)
else:
scanner = dataset.scanner(columns=table_slice.columns)
return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type)


class DeltaLakePyArrowTypeHandler(DeltalakeBaseArrowTypeHandler[ArrowTypes]):
def from_arrow(self, obj: pa.RecordBatchReader, target_type: Type[ArrowTypes]) -> ArrowTypes:
def from_arrow(self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: Type[ArrowTypes]) -> ArrowTypes:
if isinstance(obj, ds.Dataset):
obj = obj.scanner().to_reader()
if target_type == pa.Table:
return obj.read_all()
return obj
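
With this change, load_input passes the raw pyarrow dataset to from_arrow when no columns are requested (so a lazy consumer can defer the scan) and a column-pruned RecordBatchReader otherwise, so every type handler has to accept both shapes. A small sketch exercising the PyArrow handler on both paths; the dagster_deltalake.handler import path and the in-memory dataset are assumptions for illustration:

import pyarrow as pa
import pyarrow.dataset as ds
from dagster_deltalake.handler import DeltaLakePyArrowTypeHandler  # assumed import path

handler = DeltaLakePyArrowTypeHandler()
table = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
dataset = ds.dataset(table)  # stands in for the dataset built from the Delta table

# No column selection: the handler receives the dataset and scans it fully.
assert handler.from_arrow(dataset, pa.Table).column_names == ["a", "b"]

# Column selection: the handler receives a reader over the pruned scan.
reader = dataset.scanner(columns=["a"]).to_reader()
assert handler.from_arrow(reader, pa.Table).column_names == ["a"]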
