feat(data-warehouse): New pipeline WIP #26341

Draft · wants to merge 1 commit into base: master
392 changes: 392 additions & 0 deletions posthog/temporal/data_imports/pipelines/pipeline_non_dlt.py
@@ -0,0 +1,392 @@
import json
import time
from collections.abc import Sequence
from conditional_cache import lru_cache
from typing import Any
import pyarrow as pa
from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_schema
from dlt.common.normalizers.naming.snake_case import NamingConvention
from dlt.sources import DltSource, DltResource
import deltalake as deltalake
from django.conf import settings
from django.db.models import F
from posthog.settings.base_variables import TEST
from posthog.temporal.common.logger import FilteringBoundLogger
from posthog.temporal.data_imports.pipelines.pipeline_sync import validate_schema_and_update_table_sync
from posthog.temporal.data_imports.util import prepare_s3_files_for_querying
from posthog.warehouse.models import DataWarehouseTable, ExternalDataJob, ExternalDataSchema
from posthog.hogql.database.models import (
BooleanDatabaseField,
DatabaseField,
DateDatabaseField,
DateTimeDatabaseField,
FloatDatabaseField,
IntegerDatabaseField,
StringDatabaseField,
StringJSONDatabaseField,
)


class HogQLSchema:
schema: dict[str, str]

def __init__(self):
self.schema = {}

def add_pyarrow_table(self, table: pa.Table) -> None:
for field in table.schema:
self.add_field(field, table.column(field.name))

def add_field(self, field: pa.Field, column: pa.ChunkedArray) -> None:
existing_type = self.schema.get(field.name)
if existing_type is not None and existing_type != StringDatabaseField.__name__:
return

hogql_type: type[DatabaseField] = DatabaseField

if pa.types.is_time(field.type):
hogql_type = DateTimeDatabaseField
elif pa.types.is_timestamp(field.type):
hogql_type = DateTimeDatabaseField
elif pa.types.is_date(field.type):
hogql_type = DateDatabaseField
elif pa.types.is_decimal(field.type):
hogql_type = FloatDatabaseField
elif pa.types.is_floating(field.type):
hogql_type = FloatDatabaseField
elif pa.types.is_boolean(field.type):
hogql_type = BooleanDatabaseField
elif pa.types.is_integer(field.type):
hogql_type = IntegerDatabaseField
elif pa.types.is_binary(field.type):
raise Exception("Type 'binary' is not a supported column type")
elif pa.types.is_string(field.type):
hogql_type = StringDatabaseField

        # Check for JSON-encoded strings using the first non-null value in the column
        if hogql_type is StringDatabaseField:
            for value in column:
                value_str = value.as_py()
                if value_str is not None:
                    assert isinstance(value_str, str)
                    if value_str.startswith("{") or value_str.startswith("["):
                        hogql_type = StringJSONDatabaseField
                    break

self.schema[field.name] = hogql_type.__name__

def to_hogql_types(self) -> dict[str, str]:
return self.schema
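
For illustration only, not part of this diff: a minimal sketch of how HogQLSchema maps pyarrow columns to HogQL field type names, reusing the pyarrow import at the top of the module.

# Illustrative only: run a small pyarrow table through HogQLSchema.
example_table = pa.table(
    {
        "id": pa.array([1, 2, 3]),
        "properties": pa.array(['{"plan": "free"}', None, '{"plan": "paid"}']),
        "created_at": pa.array([None, None, None], type=pa.timestamp("us")),
    }
)

hogql_schema = HogQLSchema()
hogql_schema.add_pyarrow_table(example_table)
print(hogql_schema.to_hogql_types())
# {'id': 'IntegerDatabaseField', 'properties': 'StringJSONDatabaseField',
#  'created_at': 'DateTimeDatabaseField'}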


class DeltaTableHelper:
_resource_name: str
_job: ExternalDataJob

def __init__(self, resource_name: str, job: ExternalDataJob) -> None:
self._resource_name = resource_name
self._job = job

def _get_credentials(self):
if TEST:
return {
"aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
"aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
"endpoint_url": settings.OBJECT_STORAGE_ENDPOINT,
"region_name": settings.AIRBYTE_BUCKET_REGION,
"AWS_ALLOW_HTTP": "true",
"AWS_S3_ALLOW_UNSAFE_RENAME": "true",
}

return {
"aws_access_key_id": settings.AIRBYTE_BUCKET_KEY,
"aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET,
"region_name": settings.AIRBYTE_BUCKET_REGION,
"AWS_DEFAULT_REGION": settings.AIRBYTE_BUCKET_REGION,
"AWS_S3_ALLOW_UNSAFE_RENAME": "true",
}

def _get_delta_table_uri(self) -> str:
normalized_resource_name = NamingConvention().normalize_identifier(self._resource_name)
return f"{settings.BUCKET_URL}/{self._job.folder_path()}/{normalized_resource_name}"

def _evolve_delta_schema(self, schema: pa.Schema) -> deltalake.DeltaTable:
delta_table = self.get_delta_table()
if delta_table is None:
raise Exception("Deltalake table not found")

delta_table_schema = delta_table.schema().to_pyarrow()

new_fields = [
deltalake.Field.from_pyarrow(field)
for field in ensure_delta_compatible_arrow_schema(schema)
if field.name not in delta_table_schema.names
]
if new_fields:
delta_table.alter.add_columns(new_fields)

return delta_table

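    # conditional_cache.lru_cache only caches results that pass the condition,
    # so a missing table (None) is re-checked on every call until it exists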
@lru_cache(maxsize=1, condition=lambda result: result is not None)
def get_delta_table(self) -> deltalake.DeltaTable | None:
delta_uri = self._get_delta_table_uri()
storage_options = self._get_credentials()

if deltalake.DeltaTable.is_deltatable(table_uri=delta_uri, storage_options=storage_options):
return deltalake.DeltaTable(table_uri=delta_uri, storage_options=storage_options)

return None

def write_to_deltalake(
self, data: pa.Table, is_incremental: bool, chunk_index: int, primary_keys: Sequence[Any] | None
) -> deltalake.DeltaTable:
delta_table = self.get_delta_table()

if delta_table:
delta_table = self._evolve_delta_schema(data.schema)

if is_incremental and delta_table is not None:
if not primary_keys or len(primary_keys) == 0:
raise Exception("Primary key required for incremental syncs")

            delta_table.merge(
                source=data,
                source_alias="source",
                target_alias="target",
                predicate=" AND ".join([f"source.{c} = target.{c}" for c in primary_keys]),
            ).when_matched_update_all().when_not_matched_insert_all().execute()

Review comment (Member): Will this function work on an empty delta table? Asking because it'd clean up this logic a bunch if we could just handle delta_table is None before this entire if block.
else:
mode = "append"
schema_mode = "merge"
if chunk_index == 0 or delta_table is None:
mode = "overwrite"
schema_mode = "overwrite"

if delta_table is None:
delta_table = deltalake.DeltaTable.create(table_uri=self._get_delta_table_uri(), schema=data.schema)

deltalake.write_deltalake(
table_or_uri=delta_table,
data=data,
partition_by=None,
mode=mode,
schema_mode=schema_mode,
engine="rust",
) # type: ignore

delta_table = self.get_delta_table()
assert delta_table is not None

return delta_table
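
For illustration only, not part of this diff: a minimal sketch of the restructuring the review comment attached to the merge call above asks about, i.e. resolving delta_table is None up front so the incremental branch no longer needs the None check. It assumes deltalake's merge is valid against a freshly created, empty table (the open question in that comment), reuses the imports and helpers already defined in this module, and write_to_deltalake_sketch is a hypothetical name.

    # Hypothetical alternative to write_to_deltalake, per the review suggestion above.
    def write_to_deltalake_sketch(
        self, data: pa.Table, is_incremental: bool, chunk_index: int, primary_keys: Sequence[Any] | None
    ) -> deltalake.DeltaTable:
        delta_table = self.get_delta_table()

        if delta_table is None:
            # Mirrors the create call used in the diff above
            delta_table = deltalake.DeltaTable.create(table_uri=self._get_delta_table_uri(), schema=data.schema)
        else:
            delta_table = self._evolve_delta_schema(data.schema)

        if is_incremental:
            if not primary_keys:
                raise Exception("Primary key required for incremental syncs")

            # Assumes merge() also works when the target table is empty
            delta_table.merge(
                source=data,
                source_alias="source",
                target_alias="target",
                predicate=" AND ".join([f"source.{c} = target.{c}" for c in primary_keys]),
            ).when_matched_update_all().when_not_matched_insert_all().execute()
        else:
            mode = "overwrite" if chunk_index == 0 else "append"
            schema_mode = "overwrite" if chunk_index == 0 else "merge"
            deltalake.write_deltalake(
                table_or_uri=delta_table, data=data, partition_by=None, mode=mode, schema_mode=schema_mode, engine="rust"
            )  # type: ignore

        delta_table = self.get_delta_table()
        assert delta_table is not None
        return delta_table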


class PipelineNonDLT:
_resource: DltResource
_resource_name: str
_job: ExternalDataJob
_schema: ExternalDataSchema
_logger: FilteringBoundLogger
_is_incremental: bool
_delta_table_helper: DeltaTableHelper
    _internal_schema: HogQLSchema
_load_id: int

def __init__(self, source: DltSource, logger: FilteringBoundLogger, job_id: str, is_incremental: bool) -> None:
resources = list(source.resources.items())
assert len(resources) == 1
resource_name, resource = resources[0]

self._resource = resource
self._resource_name = resource_name
self._job = ExternalDataJob.objects.prefetch_related("schema").get(id=job_id)
self._is_incremental = is_incremental
self._logger = logger
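        # Nanosecond timestamp for this run; written into the _ph_debug column on every row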
self._load_id = time.time_ns()

schema: ExternalDataSchema | None = self._job.schema
assert schema is not None
self._schema = schema

self._delta_table_helper = DeltaTableHelper(resource_name, self._job)
self._internal_schema = HogQLSchema()

def run(self):
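        # Buffer rows from the resource into ~chunk_size batches, converting each batch
        # to a pyarrow table before writing it to the delta table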
buffer: list[Any] = []
chunk_size = 5000
row_count = 0
chunk_index = 0

for item in self._resource:
py_table = None

            if isinstance(item, list):
                if len(buffer) > 0:
                    buffer.extend(item)
                    if len(buffer) >= chunk_size:
                        py_table = pa.Table.from_pylist(buffer)
                        buffer = []
                    else:
                        continue
                else:
                    if len(item) >= chunk_size:
                        py_table = pa.Table.from_pylist(item)
                    else:
                        buffer.extend(item)
                        continue
            elif isinstance(item, dict):
                buffer.append(item)
                if len(buffer) < chunk_size:
                    continue

                py_table = pa.Table.from_pylist(buffer)
                buffer = []
            elif isinstance(item, pa.Table):
                py_table = item
            else:
                raise Exception(f"Unhandled item type: {item.__class__.__name__}")

assert py_table is not None

self._process_pa_table(pa_table=py_table, index=chunk_index)

row_count += py_table.num_rows
chunk_index += 1

if len(buffer) > 0:
py_table = pa.Table.from_pylist(buffer)
self._process_pa_table(pa_table=py_table, index=chunk_index)
row_count += py_table.num_rows

self._post_run_operations(row_count=row_count)

def _process_pa_table(self, pa_table: pa.Table, index: int):
delta_table = self._delta_table_helper.get_delta_table()

pa_table = _append_debug_column_to_pyarrows_table(pa_table, self._load_id)
pa_table = _evolve_pyarrow_schema(pa_table, delta_table.schema() if delta_table is not None else None)

table_primary_keys = self._get_primary_keys()
delta_table = self._delta_table_helper.write_to_deltalake(
pa_table, self._is_incremental, index, table_primary_keys
)

self._internal_schema.add_pyarrow_table(pa_table)

_update_incrementality(self._schema, pa_table, self._logger)
_update_job_row_count(self._job.id, pa_table.num_rows, self._logger)

def _post_run_operations(self, row_count: int):
delta_table = self._delta_table_helper.get_delta_table()

assert delta_table is not None

self._logger.info("Compacting delta table")
delta_table.optimize.compact()
delta_table.vacuum(retention_hours=24, enforce_retention_duration=False, dry_run=False)

file_uris = delta_table.file_uris()
self._logger.info(f"Preparing S3 files - total parquet files: {len(file_uris)}")
prepare_s3_files_for_querying(self._job.folder_path(), self._resource_name, file_uris)

self._logger.debug("Validating schema and updating table")

validate_schema_and_update_table_sync(
run_id=str(self._job.id),
team_id=self._job.team_id,
schema_id=self._schema.id,
table_schema={},
table_schema_dict=self._internal_schema.to_hogql_types(),
row_count=row_count,
table_format=DataWarehouseTable.TableFormat.DeltaS3Wrapper,
)

def _get_primary_keys(self) -> list[Any] | None:
primary_keys = self._resource._hints.get("primary_key")

if primary_keys is None:
return None

if isinstance(primary_keys, list):
return primary_keys

if isinstance(primary_keys, Sequence):
return list(primary_keys)

raise Exception(f"primary_keys of type {primary_keys.__class__.__name__} are not supported")


def _evolve_pyarrow_schema(table: pa.Table, delta_schema: deltalake.Schema | None) -> pa.Table:
py_table_field_names = table.schema.names

# Change pa.structs to JSON string
for column_name in table.column_names:
column = table.column(column_name)
if pa.types.is_struct(column.type) or pa.types.is_list(column.type):
json_column = pa.array([json.dumps(row.as_py()) if row.as_py() is not None else None for row in column])

Review comment (Member): Out of scope here, but an issue I just discovered that might be addressable here: clickhouse s3 can't deserialize a list like ["test"].

Review comment (Member Author): Do you have more context on this? A support ticket maybe?

table = table.set_column(table.schema.get_field_index(column_name), column_name, json_column)

if delta_schema:
for field in delta_schema.to_pyarrow():
if field.name not in py_table_field_names:
if field.nullable:
new_column_data = pa.array([None] * table.num_rows, type=field.type)
else:
new_column_data = pa.array(
[_get_default_value_from_pyarrow_type(field.type)] * table.num_rows, type=field.type
)
table = table.append_column(field, new_column_data)

# Change types based on what deltalake tables support
return table.cast(ensure_delta_compatible_arrow_schema(table.schema))
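
For illustration only, not part of this diff: what the struct/list handling in _evolve_pyarrow_schema produces for a nested column such as the ["test"] case mentioned in the review thread above, reusing the json and pyarrow imports at the top of the module.

# Illustrative only: list values are serialised to JSON strings before the delta-compatible cast.
example = pa.table({"tags": pa.array([["test"], None, ["a", "b"]])})
tags = example.column("tags")
json_tags = pa.array([json.dumps(row.as_py()) if row.as_py() is not None else None for row in tags])
example = example.set_column(example.schema.get_field_index("tags"), "tags", json_tags)
print(example.column("tags").to_pylist())  # ['["test"]', None, '["a", "b"]']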


def _append_debug_column_to_pyarrows_table(table: pa.Table, load_id: int) -> pa.Table:
debug_info = f'{{"load_id": {load_id}}}'

column = pa.array([debug_info] * table.num_rows, type=pa.string())
return table.append_column("_ph_debug", column)


def _get_default_value_from_pyarrow_type(pyarrow_type: pa.DataType):
"""
Returns a default value for the given PyArrow type.
"""
if pa.types.is_integer(pyarrow_type):
return 0
elif pa.types.is_floating(pyarrow_type):
return 0.0
elif pa.types.is_string(pyarrow_type):
return ""
elif pa.types.is_boolean(pyarrow_type):
return False
elif pa.types.is_binary(pyarrow_type):
return b""
elif pa.types.is_timestamp(pyarrow_type):
return pa.scalar(0, type=pyarrow_type).as_py()
elif pa.types.is_date(pyarrow_type):
return pa.scalar(0, type=pyarrow_type).as_py()
elif pa.types.is_time(pyarrow_type):
return pa.scalar(0, type=pyarrow_type).as_py()
else:
raise ValueError(f"No default value defined for type: {pyarrow_type}")
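
For illustration only, not part of this diff: the defaults this helper yields for a few common types.

# Illustrative only.
print(_get_default_value_from_pyarrow_type(pa.int64()))          # 0
print(_get_default_value_from_pyarrow_type(pa.string()))         # ''
print(_get_default_value_from_pyarrow_type(pa.timestamp("us")))  # 1970-01-01 00:00:00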


def _update_incrementality(schema: ExternalDataSchema | None, table: pa.Table, logger: FilteringBoundLogger) -> None:
Review comment (Member): NIT: _update_increment_state?

if schema is None or schema.sync_type != ExternalDataSchema.SyncType.INCREMENTAL:
return

incremental_field_name: str | None = schema.sync_type_config.get("incremental_field")
if incremental_field_name is None:
return

column = table[incremental_field_name]
numpy_arr = column.combine_chunks().to_pandas().to_numpy()

# TODO(@Gilbert09): support different operations here (e.g. min)
last_value = numpy_arr.max()

logger.debug(f"Updating incremental_field_last_value with {last_value}")

schema.update_incremental_field_last_value(last_value)


def _update_job_row_count(job_id: str, count: int, logger: FilteringBoundLogger) -> None:
Review comment (Member): Why is this outside of the Pipeline class?

logger.debug(f"Updating rows_synced with +{count}")
ExternalDataJob.objects.filter(id=job_id).update(rows_synced=F("rows_synced") + count)