Add empty source handling for delta table format on filesystem destination #1617

Merged 14 commits on Aug 1, 2024
77 changes: 77 additions & 0 deletions .github/workflows/test_pyarrow17.yml
@@ -0,0 +1,77 @@

name: tests marked as needspyarrow17

on:
pull_request:
branches:
- master
- devel
workflow_dispatch:
schedule:
- cron: '0 2 * * *'

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

env:

DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }}

RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752
RUNTIME__LOG_LEVEL: ERROR
RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }}

ACTIVE_DESTINATIONS: "[\"filesystem\"]"

jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}}

run_pyarrow17:
name: needspyarrow17 tests
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
defaults:
run:
shell: bash
runs-on: "ubuntu-latest"

steps:

- name: Check out
uses: actions/checkout@master

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10.x"

- name: Install Poetry
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true

- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v3
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-pyarrow17

- name: Install dependencies
run: poetry install --no-interaction --with sentry-sdk --with pipeline -E deltalake -E gs -E s3 -E az

- name: Upgrade pyarrow
run: poetry run pip install pyarrow==17.0.0

- name: create secrets.toml
run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml

- run: |
poetry run pytest tests/libs tests/load -m needspyarrow17
name: Run needspyarrow17 tests Linux
19 changes: 19 additions & 0 deletions dlt/common/exceptions.py
@@ -143,6 +143,25 @@ def _to_pip_install(self) -> str:
return "\n".join([f'pip install "{d}"' for d in self.dependencies])


class DependencyVersionException(DltException):
def __init__(
self, pkg_name: str, version_found: str, version_required: str, appendix: str = ""
) -> None:
self.pkg_name = pkg_name
self.version_found = version_found
self.version_required = version_required
super().__init__(self._get_msg(appendix))

def _get_msg(self, appendix: str) -> str:
msg = (
f"Found `{self.pkg_name}=={self.version_found}`, while"
f" `{self.pkg_name}{self.version_required}` is required."
)
if appendix:
msg = msg + "\n" + appendix
return msg


class SystemConfigurationException(DltException):
pass

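For reference, a minimal sketch of how the new `DependencyVersionException` composes its message; the argument values here are made up for illustration:

```python
from dlt.common.exceptions import DependencyVersionException

exc = DependencyVersionException(
    pkg_name="pyarrow",
    version_found="16.1.0",  # hypothetical installed version
    version_required=">=17.0.0",
    appendix="Upgrade with: pip install 'pyarrow>=17.0.0'",
)
print(exc)
# Found `pyarrow==16.1.0`, while `pyarrow>=17.0.0` is required.
# Upgrade with: pip install 'pyarrow>=17.0.0'
```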
37 changes: 22 additions & 15 deletions dlt/common/libs/deltalake.py
@@ -1,15 +1,16 @@
from typing import Optional, Dict, Union
from pathlib import Path

from dlt import version
from dlt.common import logger
from dlt.common.libs.pyarrow import pyarrow as pa
from dlt.common.libs.pyarrow import dataset_to_table, cast_arrow_schema_types
from dlt.common.libs.pyarrow import cast_arrow_schema_types
from dlt.common.schema.typing import TWriteDisposition
from dlt.common.exceptions import MissingDependencyException
from dlt.common.storages import FilesystemConfiguration

try:
from deltalake import write_deltalake
from deltalake import write_deltalake, DeltaTable
from deltalake.writer import try_get_deltatable
except ModuleNotFoundError:
raise MissingDependencyException(
@@ -19,21 +20,29 @@
)


def ensure_delta_compatible_arrow_table(table: pa.table) -> pa.Table:
"""Returns Arrow table compatible with Delta table format.
def ensure_delta_compatible_arrow_schema(schema: pa.Schema) -> pa.Schema:
"""Returns Arrow schema compatible with Delta table format.

Casts table schema to replace data types not supported by Delta.
Casts schema to replace data types not supported by Delta.
"""
ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP = {
# maps type check function to type factory function
pa.types.is_null: pa.string(),
pa.types.is_time: pa.string(),
pa.types.is_decimal256: pa.string(), # pyarrow does not allow downcasting to decimal128
}
adjusted_schema = cast_arrow_schema_types(
table.schema, ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP
)
return table.cast(adjusted_schema)
return cast_arrow_schema_types(schema, ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP)


def ensure_delta_compatible_arrow_data(
data: Union[pa.Table, pa.RecordBatchReader]
) -> Union[pa.Table, pa.RecordBatchReader]:
"""Returns Arrow data compatible with Delta table format.

Casts `data` schema to replace data types not supported by Delta.
"""
schema = ensure_delta_compatible_arrow_schema(data.schema)
return data.cast(schema)


def get_delta_write_mode(write_disposition: TWriteDisposition) -> str:
@@ -50,21 +59,19 @@ def get_delta_write_mode(write_disposition: TWriteDisposition) -> str:


def write_delta_table(
path: str,
data: Union[pa.Table, pa.dataset.Dataset],
table_or_uri: Union[str, Path, DeltaTable],
data: Union[pa.Table, pa.RecordBatchReader],
write_disposition: TWriteDisposition,
storage_options: Optional[Dict[str, str]] = None,
) -> None:
"""Writes in-memory Arrow table to on-disk Delta table."""

table = dataset_to_table(data)

# throws warning for `s3` protocol: https://github.com/delta-io/delta-rs/issues/2460
# TODO: upgrade `deltalake` lib after https://github.com/delta-io/delta-rs/pull/2500
# is released
write_deltalake( # type: ignore[call-overload]
table_or_uri=path,
data=ensure_delta_compatible_arrow_table(table),
table_or_uri=table_or_uri,
data=ensure_delta_compatible_arrow_data(data),
mode=get_delta_write_mode(write_disposition),
schema_mode="merge", # enable schema evolution (adding new columns)
storage_options=storage_options,
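A quick sketch of what the refactored helpers do, assuming `cast_arrow_schema_types` replaces types matched by the check functions as its docstring suggests (the sample table is made up):

```python
import pyarrow as pa
from dlt.common.libs.deltalake import (
    ensure_delta_compatible_arrow_data,
    ensure_delta_compatible_arrow_schema,
)

# a table with an all-null column, which Delta cannot store as the `null` type
table = pa.table({"id": [1, 2], "empty": pa.array([None, None], type=pa.null())})

adjusted = ensure_delta_compatible_arrow_schema(table.schema)
assert adjusted.field("empty").type == pa.string()  # null -> string

# the data variant accepts both pa.Table and pa.RecordBatchReader, since
# both expose `.schema` and `.cast()` (readers gained `.cast()` in recent pyarrow)
compatible_table = ensure_delta_compatible_arrow_data(table)
```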
4 changes: 0 additions & 4 deletions dlt/common/libs/pyarrow.py
@@ -474,10 +474,6 @@ def pq_stream_with_new_columns(
yield tbl


def dataset_to_table(data: Union[pyarrow.Table, pyarrow.dataset.Dataset]) -> pyarrow.Table:
return data.to_table() if isinstance(data, pyarrow.dataset.Dataset) else data


def cast_arrow_schema_types(
schema: pyarrow.Schema,
type_map: Dict[Callable[[pyarrow.DataType], bool], Callable[..., pyarrow.DataType]],
Expand Down
20 changes: 19 additions & 1 deletion dlt/common/utils.py
@@ -10,6 +10,8 @@
from types import ModuleType
import traceback
import zlib
from importlib.metadata import version as pkg_version
from packaging.version import Version

from typing import (
Any,
@@ -29,7 +31,12 @@
Iterable,
)

from dlt.common.exceptions import DltException, ExceptionTrace, TerminalException
from dlt.common.exceptions import (
DltException,
ExceptionTrace,
TerminalException,
DependencyVersionException,
)
from dlt.common.typing import AnyFun, StrAny, DictStrAny, StrStr, TAny, TFun


@@ -565,3 +572,14 @@ def order_deduped(lst: List[Any]) -> List[Any]:
Only works for lists with hashable elements.
"""
return list(dict.fromkeys(lst))


def assert_min_pkg_version(pkg_name: str, version: str, msg: str = "") -> None:
version_found = pkg_version(pkg_name)
if Version(version_found) < Version(version):
raise DependencyVersionException(
pkg_name=pkg_name,
version_found=version_found,
version_required=">=" + version,
appendix=msg,
)
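`assert_min_pkg_version` compares the installed version against a lower bound and raises the new exception on failure; this is exactly how the filesystem destination uses it below:

```python
from dlt.common.utils import assert_min_pkg_version

# passes silently if pyarrow>=17.0.0 is installed,
# raises DependencyVersionException otherwise
assert_min_pkg_version(
    pkg_name="pyarrow",
    version="17.0.0",
    msg="`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination.",
)
```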
49 changes: 35 additions & 14 deletions dlt/destinations/impl/filesystem/filesystem.py
@@ -9,6 +9,7 @@

import dlt
from dlt.common import logger, time, json, pendulum
from dlt.common.utils import assert_min_pkg_version
from dlt.common.storages.fsspec_filesystem import glob_files
from dlt.common.typing import DictStrAny
from dlt.common.schema import Schema, TSchemaTables, TTableSchema
@@ -122,23 +123,43 @@ def __init__(
def write(self) -> None:
from dlt.common.libs.pyarrow import pyarrow as pa
from dlt.common.libs.deltalake import (
DeltaTable,
write_delta_table,
ensure_delta_compatible_arrow_schema,
_deltalake_storage_options,
try_get_deltatable,
)

file_paths = [job.file_path for job in self.table_jobs]
assert_min_pkg_version(
pkg_name="pyarrow",
version="17.0.0",
msg="`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination.",
)

if (
self.table["write_disposition"] == "merge"
and (
dt := try_get_deltatable(
self.client.make_remote_uri(self.make_remote_path()),
storage_options=_deltalake_storage_options(self.client.config),
# create Arrow dataset from Parquet files
file_paths = [job.file_path for job in self.table_jobs]
arrow_ds = pa.dataset.dataset(file_paths)

# create Delta table object
dt_path = self.client.make_remote_uri(self.make_remote_path())
storage_options = _deltalake_storage_options(self.client.config)
dt = try_get_deltatable(dt_path, storage_options=storage_options)

# explicitly check if there is data
# (https://github.com/delta-io/delta-rs/issues/2686)
if arrow_ds.head(1).num_rows == 0:
if dt is None:
# create new empty Delta table with schema from Arrow table
DeltaTable.create(
table_uri=dt_path,
schema=ensure_delta_compatible_arrow_schema(arrow_ds.schema),
mode="overwrite",
)
)
is not None
):
return

arrow_rbr = arrow_ds.scanner().to_reader() # RecordBatchReader

if self.table["write_disposition"] == "merge" and dt is not None:
assert self.table["x-merge-strategy"] in self.client.capabilities.supported_merge_strategies # type: ignore[typeddict-item]

if self.table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item]
@@ -151,7 +172,7 @@

qry = (
dt.merge(
source=pa.dataset.dataset(file_paths),
source=arrow_rbr,
predicate=predicate,
source_alias="source",
target_alias="target",
@@ -164,10 +185,10 @@

else:
write_delta_table(
path=self.client.make_remote_uri(self.make_remote_path()),
data=pa.dataset.dataset(file_paths),
table_or_uri=dt_path if dt is None else dt,
data=arrow_rbr,
write_disposition=self.table["write_disposition"],
storage_options=_deltalake_storage_options(self.client.config),
storage_options=storage_options,
)

def make_remote_path(self) -> str:
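The empty-source branch relies on `DeltaTable.create` because `write_deltalake` cannot materialize a table from zero rows (delta-rs issue #2686). A standalone sketch of that call, with a hypothetical local path:

```python
import pyarrow as pa
from deltalake import DeltaTable

schema = pa.schema([("id", pa.int64()), ("value", pa.string())])
dt = DeltaTable.create(
    table_uri="/tmp/example_delta_table",  # hypothetical path
    schema=schema,
    mode="overwrite",
)
print(dt.to_pyarrow_table().num_rows)  # 0 -- empty table, schema in place
```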
3 changes: 2 additions & 1 deletion pytest.ini
@@ -10,4 +10,5 @@ python_functions = *_test test_* *_snippet
filterwarnings= ignore::DeprecationWarning
markers =
essential: marks all essential tests
no_load: marks tests that do not load anything
needspyarrow17: marks tests that need pyarrow>=17.0.0 (deselected by default)
7 changes: 5 additions & 2 deletions tests/cases.py
@@ -303,6 +303,7 @@ def arrow_table_all_data_types(
include_date: bool = True,
include_not_normalized_name: bool = True,
include_name_clash: bool = False,
include_null: bool = True,
num_rows: int = 3,
tz="UTC",
) -> Tuple[Any, List[Dict[str, Any]], Dict[str, List[Any]]]:
@@ -323,9 +324,11 @@
"float_null": [round(random.uniform(0, 100), 4) for _ in range(num_rows - 1)] + [
None
], # decrease precision
"null": pd.Series([None for _ in range(num_rows)]),
}

if include_null:
data["null"] = pd.Series([None for _ in range(num_rows)])

if include_name_clash:
data["pre Normalized Column"] = [random.choice(ascii_lowercase) for _ in range(num_rows)]
include_not_normalized_name = True
@@ -373,7 +376,7 @@
"Pre Normalized Column": "pre_normalized_column",
}
)
.drop(columns=["null"])
.drop(columns=(["null"] if include_null else []))
.to_dict("records")
)
if object_format == "object":
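With the new flag, delta-related tests can generate fixture data without the all-null column that Delta cannot ingest as-is. A hedged usage sketch (the first argument's accepted values are not shown in this diff; `"table"` is assumed):

```python
from tests.cases import arrow_table_all_data_types

item, records, data = arrow_table_all_data_types("table", include_null=False)
assert "null" not in data  # the all-null column is omitted entirely
```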
14 changes: 13 additions & 1 deletion tests/conftest.py
@@ -1,7 +1,11 @@
import os
import dataclasses
import logging
from typing import List
import sys
import pytest
from typing import List, Iterator
from importlib.metadata import version as pkg_version
from packaging.version import Version

# patch which providers to enable
from dlt.common.configuration.providers import (
@@ -142,3 +146,11 @@ def _create_pipeline_instance_id(self) -> str:

except Exception:
pass


@pytest.fixture(autouse=True)
def pyarrow17_check(request) -> Iterator[None]:
if "needspyarrow17" in request.keywords:
if "pyarrow" not in sys.modules or Version(pkg_version("pyarrow")) < Version("17.0.0"):
pytest.skip("test needs `pyarrow>=17.0.0`")
yield
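Together with the `pytest.ini` marker and the CI workflow above (`pytest ... -m needspyarrow17`), the autouse fixture makes opting in as simple as decorating a test; the test below is hypothetical:

```python
import pytest

@pytest.mark.needspyarrow17
def test_delta_empty_source() -> None:  # hypothetical test
    # skipped automatically by the `pyarrow17_check` fixture
    # unless pyarrow>=17.0.0 is importable
    ...
```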