Commit: fmt

ion-elgreco committed Jan 20, 2024
1 parent 9eb0a94 commit 4302375
Showing 4 changed files with 27 additions and 16 deletions.
File 1 of 4
@@ -1,7 +1,8 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

 import pandas as pd
 import pyarrow as pa
+import pyarrow.dataset as ds
 from dagster._core.storage.db_io_manager import (
     DbTypeHandler,
 )
@@ -14,8 +15,10 @@

 class DeltaLakePandasTypeHandler(DeltalakeBaseArrowTypeHandler[pd.DataFrame]):
     def from_arrow(
-        self, obj: pa.RecordBatchReader, target_type: Type[pd.DataFrame]
+        self, obj: Union[ds.dataset, pa.RecordBatchReader], target_type: Type[pd.DataFrame]
     ) -> pd.DataFrame:
+        if isinstance(obj, ds.Dataset):
+            obj = obj.scanner().to_reader()
         return obj.read_pandas()

     def to_arrow(self, obj: pd.DataFrame) -> Tuple[pa.RecordBatchReader, Dict[str, Any]]:
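The substantive change here, beyond formatting, is that the pandas handler now accepts a pyarrow Dataset as well as a RecordBatchReader. A minimal standalone sketch of that conversion path, using an in-memory dataset as a stand-in for whatever the I/O manager actually loads:

import pyarrow as pa
import pyarrow.dataset as ds

# In-memory stand-in for the dataset the handler would receive.
dataset = ds.dataset(pa.table({"a": [1, 2, 3]}))

# The same conversion the new from_arrow branch performs:
reader = dataset.scanner().to_reader()  # ds.Dataset -> pa.RecordBatchReader
df = reader.read_pandas()               # pandas DataFrame with column "a"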
File 2 of 4
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

 import polars as pl
 import pyarrow as pa
@@ -17,16 +17,16 @@

 class DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[PolarsTypes]):
     def from_arrow(
-        self, obj: Union[ds.dataset, pa.RecordBatchReader], target_type: Type[PolarsTypes]
+        self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: Type[PolarsTypes]
     ) -> PolarsTypes:
         if isinstance(obj, pa.RecordBatchReader):
-            df = pl.DataFrame(obj.read_all())
+            df = pl.DataFrame(obj.read_all())
             if target_type == pl.LazyFrame:
                 ## Maybe allow this but raise a warning that the data has been sliced earlier otherwise it
                 ## would have received a ds.dataset
                 return df.lazy()
             else:
-                return df
+                return df
         df = pl.scan_pyarrow_dataset(obj)
         if target_type == pl.DataFrame:
             return df.collect()
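The polars handler keeps two paths: an eager one that drains a RecordBatchReader into a DataFrame, and a lazy one that wraps a pyarrow Dataset via pl.scan_pyarrow_dataset. A minimal standalone sketch of both paths, using an in-memory dataset as a hypothetical stand-in for the Delta table scan:

import polars as pl
import pyarrow as pa
import pyarrow.dataset as ds

# In-memory stand-in for the data the I/O manager would hand to the handler.
dataset = ds.dataset(pa.table({"a": [1, 2, 3]}))

# Lazy path: wrap the dataset without materializing it.
lazy = pl.scan_pyarrow_dataset(dataset)

# Eager path: read everything through a RecordBatchReader, as the handler's
# RecordBatchReader branch does.
eager = pl.DataFrame(dataset.scanner().to_reader().read_all())

assert lazy.collect().to_dicts() == eager.to_dicts()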
File 3 of 4
@@ -80,26 +80,31 @@ def b_df_lazy() -> pl.LazyFrame:

 @asset(key_prefix=["my_schema"])
 def b_plus_one_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame:
-    return b_df_lazy.select(pl.all()+1)
+    return b_df_lazy.select(pl.all() + 1)


-@pytest.mark.parametrize("asset1,asset2,asset1_path,asset2_path", [
-    (b_df, b_plus_one, "b_df", 'b_plus_one'),
-    (b_df_lazy, b_plus_one_lazy, "b_df_lazy", "b_plus_one_lazy")
-])
-def test_deltalake_io_manager_with_assets(tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path):
+@pytest.mark.parametrize(
+    "asset1,asset2,asset1_path,asset2_path",
+    [
+        (b_df, b_plus_one, "b_df", "b_plus_one"),
+        (b_df_lazy, b_plus_one_lazy, "b_df_lazy", "b_plus_one_lazy"),
+    ],
+)
+def test_deltalake_io_manager_with_assets(
+    tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path
+):
     resource_defs = {"io_manager": io_manager}

     # materialize asset twice to ensure that tables get properly deleted
     for _ in range(2):
         res = materialize([asset1, asset2], resources=resource_defs)
         assert res.success

-    dt = DeltaTable(os.path.join(tmp_path, "my_schema/"+ asset1_path ))
+    dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path))
     out_df = dt.to_pyarrow_table()
     assert out_df["a"].to_pylist() == [1, 2, 3]

-    dt = DeltaTable(os.path.join(tmp_path, "my_schema/"+ asset2_path))
+    dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path))
     out_df = dt.to_pyarrow_table()
     assert out_df["a"].to_pylist() == [2, 3, 4]

@@ -156,6 +161,7 @@ def test_loading_columns(tmp_path, io_manager):

     assert out_df.shape[1] == 1

+
 @op
 def non_supported_type() -> int:
     return 1
File 4 of 4
@@ -40,7 +40,7 @@

 class DeltalakeBaseArrowTypeHandler(DbTypeHandler[T], Generic[T]):
     @abstractmethod
-    def from_arrow(self, obj: Union[ds.dataset, pa.RecordBatchReader], target_type: type) -> T:
+    def from_arrow(self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: type) -> T:
         pass

     @abstractmethod
@@ -140,7 +140,9 @@ def load_input(


 class DeltaLakePyArrowTypeHandler(DeltalakeBaseArrowTypeHandler[ArrowTypes]):
-    def from_arrow(self, obj: Union[ds.dataset, pa.RecordBatchReader], target_type: Type[ArrowTypes]) -> ArrowTypes:
+    def from_arrow(
+        self, obj: Union[ds.Dataset, pa.RecordBatchReader], target_type: Type[ArrowTypes]
+    ) -> ArrowTypes:
         if isinstance(obj, ds.Dataset):
             obj = obj.scanner().to_reader()
         if target_type == pa.Table:
