[FEATURE] Add SCD2 delta writer #4

Merged (18 commits) on May 23, 2024
Changes from 17 commits
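For readers unfamiliar with the pattern, SCD2 (slowly changing dimension, type 2) keeps full history: instead of updating a row in place, the current version is expired and a new version is inserted. The sketch below is purely illustrative and uses the open-source Delta Lake merge API with hypothetical table and column names; it is not the koheesio writer added in this PR.

# Illustrative SCD2 "expire the current version" merge; not the koheesio API.
from delta.tables import DeltaTable
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
source = spark.table("staging.customers")            # hypothetical incoming snapshot
target = DeltaTable.forName(spark, "dim.customers")  # hypothetical SCD2 dimension table

(
    target.alias("t")
    .merge(source.alias("s"), "t.customer_id = s.customer_id AND t.is_current = true")
    # Close out the current version when a tracked attribute changed.
    .whenMatchedUpdate(
        condition="t.address <> s.address",
        set={"is_current": "false", "valid_to": "current_timestamp()"},
    )
    # Brand-new keys enter as the first (current) version.
    .whenNotMatchedInsert(
        values={
            "customer_id": "s.customer_id",
            "address": "s.address",
            "is_current": "true",
            "valid_from": "current_timestamp()",
            "valid_to": "CAST(null AS TIMESTAMP)",
        }
    )
    .execute()
)

A complete SCD2 flow also re-inserts a fresh current row for each changed key (commonly by unioning the changed rows back into the source under a null merge key); that bookkeeping is what a dedicated SCD2 writer is meant to encapsulate.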
1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -0,0 +1 @@
* @Nike-Inc/koheesio-dev
73 changes: 73 additions & 0 deletions .github/workflows/test.yaml
@@ -0,0 +1,73 @@
name: test

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
  cancel-in-progress: true

env:
  STABLE_PYTHON_VERSION: '3.11'
  PYTHONUNBUFFERED: "1"
  FORCE_COLOR: "1"

jobs:
  tests:
    name: Python ${{ matrix.python-version }} with PySpark ${{ matrix.pyspark-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        # os: [ubuntu-latest, windows-latest, macos-latest]
        python-version: ['3.9', '3.10', '3.11', '3.12']
        pyspark-version: ['33', '34', '35']
        exclude:
          - python-version: '3.9'
            pyspark-version: '35'
          - python-version: '3.11'
            pyspark-version: '33'
          - python-version: '3.11'
            pyspark-version: '34'
          - python-version: '3.12'
            pyspark-version: '33'
          - python-version: '3.12'
            pyspark-version: '34'

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Ensure latest pip
        run: python -m pip install --upgrade pip

      - name: Install hatch
        run: pip install hatch

      - name: Run tests
        run: hatch run test.py${{ matrix.python-version }}-pyspark${{ matrix.pyspark-version }}:all-tests

  # https://github.com/marketplace/actions/alls-green#why
  final_check: # This job does nothing and is only used for the branch protection
    if: always()

    needs:
      - tests

    runs-on: ubuntu-latest

    steps:
      - name: Decide whether the needed jobs succeeded or failed
        uses: re-actors/alls-green@release/v1
        with:
          jobs: ${{ toJSON(needs) }}
1 change: 1 addition & 0 deletions .gitignore
@@ -128,6 +128,7 @@ out/**

/test/integration/**/task_definition/*.yaml
/.vscode/settings.json
/.vscode/launch.json

.databricks

3 changes: 1 addition & 2 deletions koheesio/steps/delta.py
@@ -8,10 +8,9 @@
from py4j.protocol import Py4JJavaError # type: ignore
from pyspark.sql import DataFrame
from pyspark.sql.types import DataType
from pyspark.sql.utils import AnalysisException

from koheesio.models import Field, field_validator, model_validator
from koheesio.steps.spark import SparkStep
from koheesio.steps.spark import AnalysisException, SparkStep
from koheesio.utils import on_databricks


22 changes: 21 additions & 1 deletion koheesio/steps/spark.py
@@ -8,10 +8,24 @@
from typing import Optional

from pydantic import Field
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import Column
from pyspark.sql import DataFrame as PySparkSQLDataFrame
from pyspark.sql import SparkSession as OriginalSparkSession
from pyspark.sql import functions as F

try:
    from pyspark.sql.utils import AnalysisException as SparkAnalysisException
except ImportError:
    from pyspark.errors.exceptions.base import AnalysisException as SparkAnalysisException

from koheesio.steps.step import Step, StepOutput

# TODO: Move to spark/__init__.py after reorganizing the code
# Will be used for typing checks and consistency, specifically for PySpark >=3.5
DataFrame = PySparkSQLDataFrame
SparkSession = OriginalSparkSession
AnalysisException = SparkAnalysisException


class SparkStep(Step, ABC):
    """Base class for a Spark step
@@ -30,3 +44,9 @@ class Output(StepOutput):
    def spark(self) -> Optional[SparkSession]:
        """Get active SparkSession instance"""
        return SparkSession.getActiveSession()


# TODO: Move to spark/functions/__init__.py after reorganizing the code
def current_timestamp_utc(spark: SparkSession) -> Column:
"""Get the current timestamp in UTC"""
return F.to_utc_timestamp(F.current_timestamp(), spark.conf.get("spark.sql.session.timeZone"))
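The helper above is the sort of utility an SCD2 writer needs for effective-dating columns, since its result does not depend on the session time zone. A minimal usage sketch (the DataFrame and column name are hypothetical):

from pyspark.sql import SparkSession

from koheesio.steps.spark import current_timestamp_utc

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a")], ["id", "value"])

# Stamp every row with the current UTC timestamp, regardless of spark.sql.session.timeZone.
df_with_ts = df.withColumn("valid_from_timestamp", current_timestamp_utc(spark))
df_with_ts.show(truncate=False)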
13 changes: 12 additions & 1 deletion koheesio/steps/step.py
@@ -6,6 +6,7 @@

import inspect
import json
import sys
import warnings
from abc import abstractmethod
from functools import partialmethod, wraps
@@ -50,6 +51,14 @@ class StepMetaClass(ModelMetaclass):
    allowing for the execute method to be auto-decorated with do_execute
    """

    # Workaround for an issue on Python >= 3.11, where partialmethod forgets that
    # _execute_wrapper is a method of the wrapper and no longer passes the instance
    # in as the first argument.
    # https://github.com/python/cpython/issues/99152
    class _partialmethod_with_self(partialmethod):
        def __get__(self, obj, cls=None):
            return self._make_unbound_method().__get__(obj, cls)

    # Unique object to mark a function as wrapped
    _step_execute_wrapper_sentinel = object()

@@ -123,7 +132,9 @@ def __new__(
        if not is_already_wrapped:
            # Create a partial method with the execute_method as one of the arguments.
            # This is the new function that will be called instead of the original execute_method.
            wrapper = partialmethod(cls._execute_wrapper, execute_method=execute_method)

            partial_impl = partialmethod if sys.version_info < (3, 11) else mcs._partialmethod_with_self
            wrapper = partial_impl(cls._execute_wrapper, execute_method=execute_method)

            # Updating the attributes of the wrapping function to those of the original function.
            wraps(execute_method)(wrapper)  # type: ignore
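To make the workaround above concrete, here is a self-contained sketch of the same version-gated pattern with hypothetical names (Demo, _wrapper); like the class in the diff, it leans on partialmethod._make_unbound_method, a private CPython helper.

import sys
from functools import partialmethod

class _PartialMethodWithSelf(partialmethod):
    # Re-bind the generated function so the instance is passed as the first argument.
    def __get__(self, obj, cls=None):
        return self._make_unbound_method().__get__(obj, cls)

def _wrapper(self, *args, note, **kwargs):
    return f"{type(self).__name__} ran with note={note!r}"

# Same gate as in the diff: plain partialmethod before 3.11, the subclass from 3.11 on.
partial_impl = partialmethod if sys.version_info < (3, 11) else _PartialMethodWithSelf

class Demo:
    run = partial_impl(_wrapper, note="wrapped")

print(Demo().run())  # Demo ran with note='wrapped'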
5 changes: 4 additions & 1 deletion koheesio/steps/writers/buffer.py
@@ -21,6 +21,7 @@
from typing import Literal, Optional

from pandas._typing import CompressionOptions as PandasCompressionOptions
from pydantic import InstanceOf
from pyspark import pandas

from koheesio.models import ExtraParamsMixin, Field, constr
@@ -46,7 +47,9 @@ class BufferWriter(Writer, ABC):
    class Output(Writer.Output, ABC):
        """Output class for BufferWriter"""

        buffer: SpooledTemporaryFile = Field(default_factory=partial(SpooledTemporaryFile, mode="w+b", max_size=0))
        buffer: InstanceOf[SpooledTemporaryFile] = Field(
            default_factory=partial(SpooledTemporaryFile, mode="w+b", max_size=0), exclude=True
        )

        def read(self):
            """Read the buffer"""
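For context on the buffer change: InstanceOf[...] makes pydantic validate the field with a plain isinstance() check (so it never needs a full schema for SpooledTemporaryFile), and exclude=True keeps the raw buffer out of serialized output. A standalone sketch under those assumptions, with a hypothetical BufferOutput model:

from functools import partial
from tempfile import SpooledTemporaryFile

from pydantic import BaseModel, Field, InstanceOf

class BufferOutput(BaseModel):
    # Validate by isinstance() only and never serialize the buffer itself.
    buffer: InstanceOf[SpooledTemporaryFile] = Field(
        default_factory=partial(SpooledTemporaryFile, mode="w+b", max_size=0),
        exclude=True,
    )

out = BufferOutput()
out.buffer.write(b"some bytes")
print(out.model_dump())  # {} - the excluded buffer never appears in the dump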
14 changes: 14 additions & 0 deletions koheesio/steps/writers/delta/__init__.py
@@ -0,0 +1,14 @@
"""
This module is the entry point for the koheesio.steps.writers.delta package.

It imports and exposes the DeltaTableWriter and DeltaTableStreamWriter classes for external use.

Classes:
DeltaTableWriter: Class to write data in batch mode to a Delta table.
DeltaTableStreamWriter: Class to write data in streaming mode to a Delta table.
"""

from koheesio.steps.writers.delta.batch import DeltaTableWriter
from koheesio.steps.writers.delta.stream import DeltaTableStreamWriter

__all__ = ["DeltaTableWriter", "DeltaTableStreamWriter"]
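The re-export means downstream code can import either writer from the package root instead of reaching into the individual modules:

from koheesio.steps.writers.delta import DeltaTableStreamWriter, DeltaTableWriter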
@@ -1,6 +1,6 @@
"""
This module defines the DeltaTableWriter and DeltaTableStreamWriter class, which is used to write both batch and
streaming dataframes to Delta tables.
This module defines the DeltaTableWriter class, which is used to write both batch and streaming dataframes
to Delta tables.

DeltaTableWriter supports two output modes: `MERGEALL` and `MERGE`.

@@ -38,14 +38,13 @@
from typing import List, Optional, Set, Type, Union

from delta.tables import DeltaMergeBuilder, DeltaTable
from py4j.java_gateway import JavaObject
from py4j.protocol import Py4JError
from pyspark.sql import DataFrameWriter

from koheesio.models import ExtraParamsMixin, Field, field_validator
from koheesio.steps.delta import DeltaTableStep
from koheesio.steps.writers import BatchOutputMode, StreamingOutputMode, Writer
from koheesio.steps.writers.stream import StreamWriter
from koheesio.steps.writers.delta.utils import log_clauses
from koheesio.utils import on_databricks


@@ -166,7 +165,7 @@ def __merge(self, merge_builder: Optional[DeltaMergeBuilder] = None) -> Union[De
            self.log.debug(
                f"The following aliases are used during Merge operation: source={source_alias}, target={target_alias}"
            )
            patched__log_clauses = partial(_log_clauses, source_alias=source_alias, target_alias=target_alias)
            patched__log_clauses = partial(log_clauses, source_alias=source_alias, target_alias=target_alias)
            self.log.debug(
                patched__log_clauses(clauses=merge_builder._jbuilder.getMergePlan().matchedClauses())
            )
@@ -252,7 +251,7 @@ def _get_merge_builder(self, provided_merge_builder=None):
        if isinstance(merge_builder, DeltaMergeBuilder):
            return merge_builder

        if isinstance(merge_builder, list) and "merge_cond" in self.params:
        if isinstance(merge_builder, list) and "merge_cond" in self.params:  # type: ignore
            return self._merge_builder_from_args()

        raise ValueError(
@@ -364,81 +363,3 @@ def execute(self):
            # should we add options only if mode is not merge?
            _writer = _writer.options(**options)
        _writer.saveAsTable(self.table.table_name)


class DeltaTableStreamWriter(StreamWriter, DeltaTableWriter):
    """Delta table stream writer"""

    class Options:
        """Options for DeltaTableStreamWriter"""

        allow_population_by_field_name = True  # To do convert to Field and pass as .options(**config)
        maxBytesPerTrigger = None  # How much data to be processed per trigger. The default is 1GB
        maxFilesPerTrigger = 1000  # How many new files to be considered in every micro-batch. The default is 1000

    def execute(self):
        if self.batch_function:
            self.streaming_query = self.writer.start()
        else:
            self.streaming_query = self.writer.toTable(tableName=self.table.table_name)


def _log_clauses(clauses: JavaObject, source_alias: str, target_alias: str) -> Optional[str]:
    """
    Prepare log message for clauses of DeltaMergePlan statement.

    Parameters
    ----------
    clauses : JavaObject
        The clauses of the DeltaMergePlan statement.
    source_alias : str
        The source alias.
    target_alias : str
        The target alias.

    Returns
    -------
    Optional[str]
        The log message if there are clauses, otherwise None.

    Notes
    -----
    This function prepares a log message for the clauses of a DeltaMergePlan statement. It iterates over the clauses,
    processes the conditions, and constructs the log message based on the clause type and columns.

    If the condition is a value, it replaces the source and target aliases in the condition string. If the condition is
    None, it sets the condition_clause to "No conditions required".

    The log message includes the clauses type, the clause type, the columns, and the condition.
    """
    log_message = None

    if not clauses.isEmpty():
        clauses_type = clauses.last().nodeName().replace("DeltaMergeInto", "")
        _processed_clauses = {}

        for i in range(0, clauses.length()):
            clause = clauses.apply(i)
            condition = clause.condition()

            if "value" in dir(condition):
                condition_clause = (
                    condition.value()
                    .toString()
                    .replace(f"'{source_alias}", source_alias)
                    .replace(f"'{target_alias}", target_alias)
                )
            elif condition.toString() == "None":
                condition_clause = "No conditions required"

            clause_type: str = clause.clauseType().capitalize()
            columns = "ALL" if clause_type == "Delete" else clause.actions().toList().apply(0).toString()

            if clause_type.lower() not in _processed_clauses:
                _processed_clauses[clause_type.lower()] = []

            log_message = (
                f"{clauses_type} will perform action:{clause_type} columns ({columns}) if `{condition_clause}`"
            )

    return log_message