diff --git a/posthog/temporal/data_imports/pipelines/chargebee/__init__.py b/posthog/temporal/data_imports/pipelines/chargebee/__init__.py index 245afb6e5d880..7a093e65f7364 100644 --- a/posthog/temporal/data_imports/pipelines/chargebee/__init__.py +++ b/posthog/temporal/data_imports/pipelines/chargebee/__init__.py @@ -218,7 +218,13 @@ def update_request(self, request: Request) -> None: @dlt.source(max_table_nesting=0) def chargebee_source( - api_key: str, site_name: str, endpoint: str, team_id: int, job_id: str, is_incremental: bool = False + api_key: str, + site_name: str, + endpoint: str, + team_id: int, + job_id: str, + db_incremental_field_last_value: Optional[Any], + is_incremental: bool = False, ): config: RESTAPIConfig = { "client": { @@ -242,7 +248,7 @@ def chargebee_source( "resources": [get_resource(endpoint, is_incremental)], } - yield from rest_api_resources(config, team_id, job_id) + yield from rest_api_resources(config, team_id, job_id, db_incremental_field_last_value) def validate_credentials(api_key: str, site_name: str) -> bool: diff --git a/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py b/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py new file mode 100644 index 0000000000000..30e3cf0e466d5 --- /dev/null +++ b/posthog/temporal/data_imports/pipelines/pipeline/delta_table_helper.py @@ -0,0 +1,111 @@ +from collections.abc import Sequence +from conditional_cache import lru_cache +from typing import Any +import pyarrow as pa +from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_schema +from dlt.common.normalizers.naming.snake_case import NamingConvention +import deltalake as deltalake +from django.conf import settings +from posthog.settings.base_variables import TEST +from posthog.warehouse.models import ExternalDataJob + + +class DeltaTableHelper: + _resource_name: str + _job: ExternalDataJob + + def __init__(self, resource_name: str, job: ExternalDataJob) -> None: + self._resource_name = resource_name + self._job = job + + def _get_credentials(self): + if TEST: + return { + "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, + "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, + "endpoint_url": settings.OBJECT_STORAGE_ENDPOINT, + "region_name": settings.AIRBYTE_BUCKET_REGION, + "AWS_ALLOW_HTTP": "true", + "AWS_S3_ALLOW_UNSAFE_RENAME": "true", + } + + return { + "aws_access_key_id": settings.AIRBYTE_BUCKET_KEY, + "aws_secret_access_key": settings.AIRBYTE_BUCKET_SECRET, + "region_name": settings.AIRBYTE_BUCKET_REGION, + "AWS_DEFAULT_REGION": settings.AIRBYTE_BUCKET_REGION, + "AWS_S3_ALLOW_UNSAFE_RENAME": "true", + } + + def _get_delta_table_uri(self) -> str: + normalized_resource_name = NamingConvention().normalize_identifier(self._resource_name) + return f"{settings.BUCKET_URL}/{self._job.folder_path()}/{normalized_resource_name}" + + def _evolve_delta_schema(self, schema: pa.Schema) -> deltalake.DeltaTable: + delta_table = self.get_delta_table() + if delta_table is None: + raise Exception("Deltalake table not found") + + delta_table_schema = delta_table.schema().to_pyarrow() + + new_fields = [ + deltalake.Field.from_pyarrow(field) + for field in ensure_delta_compatible_arrow_schema(schema) + if field.name not in delta_table_schema.names + ] + if new_fields: + delta_table.alter.add_columns(new_fields) + + return delta_table + + @lru_cache(maxsize=1, condition=lambda result: result is not None) + def get_delta_table(self) -> deltalake.DeltaTable | None: + delta_uri = self._get_delta_table_uri() + 
storage_options = self._get_credentials() + + if deltalake.DeltaTable.is_deltatable(table_uri=delta_uri, storage_options=storage_options): + return deltalake.DeltaTable(table_uri=delta_uri, storage_options=storage_options) + + return None + + def write_to_deltalake( + self, data: pa.Table, is_incremental: bool, chunk_index: int, primary_keys: Sequence[Any] | None + ) -> deltalake.DeltaTable: + delta_table = self.get_delta_table() + + if delta_table: + delta_table = self._evolve_delta_schema(data.schema) + + if is_incremental and delta_table is not None: + if not primary_keys or len(primary_keys) == 0: + raise Exception("Primary key required for incremental syncs") + + delta_table.merge( + source=data, + source_alias="source", + target_alias="target", + predicate=" AND ".join([f"source.{c} = target.{c}" for c in primary_keys]), + ).when_matched_update_all().when_not_matched_insert_all().execute() + else: + mode = "append" + schema_mode = "merge" + if chunk_index == 0 or delta_table is None: + mode = "overwrite" + schema_mode = "overwrite" + + if delta_table is None: + delta_table = deltalake.DeltaTable.create(table_uri=self._get_delta_table_uri(), schema=data.schema) + + deltalake.write_deltalake( + table_or_uri=delta_table, + data=data, + partition_by=None, + mode=mode, + schema_mode=schema_mode, + engine="rust", + ) # type: ignore + + delta_table = self.get_delta_table() + assert delta_table is not None + + return delta_table diff --git a/posthog/temporal/data_imports/pipelines/pipeline/hogql_schema.py b/posthog/temporal/data_imports/pipelines/pipeline/hogql_schema.py new file mode 100644 index 0000000000000..383a3296f0435 --- /dev/null +++ b/posthog/temporal/data_imports/pipelines/pipeline/hogql_schema.py @@ -0,0 +1,63 @@ +import pyarrow as pa +import deltalake as deltalake +from posthog.hogql.database.models import ( + BooleanDatabaseField, + DatabaseField, + DateDatabaseField, + DateTimeDatabaseField, + FloatDatabaseField, + IntegerDatabaseField, + StringDatabaseField, + StringJSONDatabaseField, +) + + +class HogQLSchema: + schema: dict[str, str] + + def __init__(self): + self.schema = {} + + def add_pyarrow_table(self, table: pa.Table) -> None: + for field in table.schema: + self.add_field(field, table.column(field.name)) + + def add_field(self, field: pa.Field, column: pa.ChunkedArray) -> None: + existing_type = self.schema.get(field.name) + if existing_type is not None and existing_type != StringDatabaseField.__name__: + return + + hogql_type: type[DatabaseField] = DatabaseField + + if pa.types.is_time(field.type): + hogql_type = DateTimeDatabaseField + elif pa.types.is_timestamp(field.type): + hogql_type = DateTimeDatabaseField + elif pa.types.is_date(field.type): + hogql_type = DateDatabaseField + elif pa.types.is_decimal(field.type): + hogql_type = FloatDatabaseField + elif pa.types.is_floating(field.type): + hogql_type = FloatDatabaseField + elif pa.types.is_boolean(field.type): + hogql_type = BooleanDatabaseField + elif pa.types.is_integer(field.type): + hogql_type = IntegerDatabaseField + elif pa.types.is_binary(field.type): + raise Exception("Type 'binary' is not a supported column type") + elif pa.types.is_string(field.type): + hogql_type = StringDatabaseField + + # Checking for JSON string columns with the first non-null value in the column + for value in column: + value_str = value.as_py() + if value_str is not None: + assert isinstance(value_str, str) + if value_str.startswith("{") or value_str.startswith("["): + hogql_type = StringJSONDatabaseField + break + + 
self.schema[field.name] = hogql_type.__name__ + + def to_hogql_types(self) -> dict[str, str]: + return self.schema diff --git a/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py b/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py new file mode 100644 index 0000000000000..96f938a32e55f --- /dev/null +++ b/posthog/temporal/data_imports/pipelines/pipeline/pipeline.py @@ -0,0 +1,137 @@ +import time +from typing import Any +import pyarrow as pa +from dlt.sources import DltSource, DltResource +import deltalake as deltalake +from posthog.temporal.common.logger import FilteringBoundLogger +from posthog.temporal.data_imports.pipelines.pipeline.utils import ( + _update_incremental_state, + _get_primary_keys, + _evolve_pyarrow_schema, + _append_debug_column_to_pyarrows_table, + _update_job_row_count, +) +from posthog.temporal.data_imports.pipelines.pipeline.delta_table_helper import DeltaTableHelper +from posthog.temporal.data_imports.pipelines.pipeline.hogql_schema import HogQLSchema +from posthog.temporal.data_imports.pipelines.pipeline_sync import validate_schema_and_update_table_sync +from posthog.temporal.data_imports.util import prepare_s3_files_for_querying +from posthog.warehouse.models import DataWarehouseTable, ExternalDataJob, ExternalDataSchema + + +class PipelineNonDLT: + _resource: DltResource + _resource_name: str + _job: ExternalDataJob + _schema: ExternalDataSchema + _logger: FilteringBoundLogger + _is_incremental: bool + _delta_table_helper: DeltaTableHelper + _internal_schema = HogQLSchema() + _load_id: int + + def __init__(self, source: DltSource, logger: FilteringBoundLogger, job_id: str, is_incremental: bool) -> None: + resources = list(source.resources.items()) + assert len(resources) == 1 + resource_name, resource = resources[0] + + self._resource = resource + self._resource_name = resource_name + self._job = ExternalDataJob.objects.prefetch_related("schema").get(id=job_id) + self._is_incremental = is_incremental + self._logger = logger + self._load_id = time.time_ns() + + schema: ExternalDataSchema | None = self._job.schema + assert schema is not None + self._schema = schema + + self._delta_table_helper = DeltaTableHelper(resource_name, self._job) + self._internal_schema = HogQLSchema() + + def run(self): + buffer: list[Any] = [] + chunk_size = 5000 + row_count = 0 + chunk_index = 0 + + for item in self._resource: + py_table = None + + if isinstance(item, list): + if len(buffer) > 0: + buffer.extend(item) + if len(buffer) >= chunk_size: + py_table = pa.Table.from_pylist(buffer) + buffer = [] + else: + if len(item) >= chunk_size: + py_table = pa.Table.from_pylist(item) + else: + buffer.extend(item) + continue + elif isinstance(item, dict): + buffer.append(item) + if len(buffer) < chunk_size: + continue + + py_table = pa.Table.from_pylist(buffer) + buffer = [] + elif isinstance(item, pa.Table): + py_table = item + else: + raise Exception(f"Unhandled item type: {item.__class__.__name__}") + + assert py_table is not None + + self._process_pa_table(pa_table=py_table, index=chunk_index) + + row_count += py_table.num_rows + chunk_index += 1 + + if len(buffer) > 0: + py_table = pa.Table.from_pylist(buffer) + self._process_pa_table(pa_table=py_table, index=chunk_index) + row_count += py_table.num_rows + + self._post_run_operations(row_count=row_count) + + def _process_pa_table(self, pa_table: pa.Table, index: int): + delta_table = self._delta_table_helper.get_delta_table() + + pa_table = _append_debug_column_to_pyarrows_table(pa_table, self._load_id) + pa_table = 
_evolve_pyarrow_schema(pa_table, delta_table.schema() if delta_table is not None else None) + + table_primary_keys = _get_primary_keys(self._resource) + delta_table = self._delta_table_helper.write_to_deltalake( + pa_table, self._is_incremental, index, table_primary_keys + ) + + self._internal_schema.add_pyarrow_table(pa_table) + + _update_incremental_state(self._schema, pa_table, self._logger) + _update_job_row_count(self._job.id, pa_table.num_rows, self._logger) + + def _post_run_operations(self, row_count: int): + delta_table = self._delta_table_helper.get_delta_table() + + assert delta_table is not None + + self._logger.info("Compacting delta table") + delta_table.optimize.compact() + delta_table.vacuum(retention_hours=24, enforce_retention_duration=False, dry_run=False) + + file_uris = delta_table.file_uris() + self._logger.info(f"Preparing S3 files - total parquet files: {len(file_uris)}") + prepare_s3_files_for_querying(self._job.folder_path(), self._resource_name, file_uris) + + self._logger.debug("Validating schema and updating table") + + validate_schema_and_update_table_sync( + run_id=str(self._job.id), + team_id=self._job.team_id, + schema_id=self._schema.id, + table_schema={}, + table_schema_dict=self._internal_schema.to_hogql_types(), + row_count=row_count, + table_format=DataWarehouseTable.TableFormat.DeltaS3Wrapper, + ) diff --git a/posthog/temporal/data_imports/pipelines/pipeline/utils.py b/posthog/temporal/data_imports/pipelines/pipeline/utils.py new file mode 100644 index 0000000000000..d07a697b9b4ea --- /dev/null +++ b/posthog/temporal/data_imports/pipelines/pipeline/utils.py @@ -0,0 +1,105 @@ +import json +from collections.abc import Sequence +from typing import Any +import pyarrow as pa +from dlt.common.libs.deltalake import ensure_delta_compatible_arrow_schema +from dlt.sources import DltResource +import deltalake as deltalake +from django.db.models import F +from posthog.temporal.common.logger import FilteringBoundLogger +from posthog.warehouse.models import ExternalDataJob, ExternalDataSchema + + +def _get_primary_keys(resource: DltResource) -> list[Any] | None: + primary_keys = resource._hints.get("primary_key") + + if primary_keys is None: + return None + + if isinstance(primary_keys, list): + return primary_keys + + if isinstance(primary_keys, Sequence): + return list(primary_keys) + + raise Exception(f"primary_keys of type {primary_keys.__class__.__name__} are not supported") + + +def _evolve_pyarrow_schema(table: pa.Table, delta_schema: deltalake.Schema | None) -> pa.Table: + py_table_field_names = table.schema.names + + # Change pa.structs to JSON string + for column_name in table.column_names: + column = table.column(column_name) + if pa.types.is_struct(column.type) or pa.types.is_list(column.type): + json_column = pa.array([json.dumps(row.as_py()) if row.as_py() is not None else None for row in column]) + table = table.set_column(table.schema.get_field_index(column_name), column_name, json_column) + + if delta_schema: + for field in delta_schema.to_pyarrow(): + if field.name not in py_table_field_names: + if field.nullable: + new_column_data = pa.array([None] * table.num_rows, type=field.type) + else: + new_column_data = pa.array( + [_get_default_value_from_pyarrow_type(field.type)] * table.num_rows, type=field.type + ) + table = table.append_column(field, new_column_data) + + # Change types based on what deltalake tables support + return table.cast(ensure_delta_compatible_arrow_schema(table.schema)) + + +def _append_debug_column_to_pyarrows_table(table: 
pa.Table, load_id: int) -> pa.Table: + debug_info = f'{{"load_id": {load_id}}}' + + column = pa.array([debug_info] * table.num_rows, type=pa.string()) + return table.append_column("_ph_debug", column) + + +def _get_default_value_from_pyarrow_type(pyarrow_type: pa.DataType): + """ + Returns a default value for the given PyArrow type. + """ + if pa.types.is_integer(pyarrow_type): + return 0 + elif pa.types.is_floating(pyarrow_type): + return 0.0 + elif pa.types.is_string(pyarrow_type): + return "" + elif pa.types.is_boolean(pyarrow_type): + return False + elif pa.types.is_binary(pyarrow_type): + return b"" + elif pa.types.is_timestamp(pyarrow_type): + return pa.scalar(0, type=pyarrow_type).as_py() + elif pa.types.is_date(pyarrow_type): + return pa.scalar(0, type=pyarrow_type).as_py() + elif pa.types.is_time(pyarrow_type): + return pa.scalar(0, type=pyarrow_type).as_py() + else: + raise ValueError(f"No default value defined for type: {pyarrow_type}") + + +def _update_incremental_state(schema: ExternalDataSchema | None, table: pa.Table, logger: FilteringBoundLogger) -> None: + if schema is None or schema.sync_type != ExternalDataSchema.SyncType.INCREMENTAL: + return + + incremental_field_name: str | None = schema.sync_type_config.get("incremental_field") + if incremental_field_name is None: + return + + column = table[incremental_field_name] + numpy_arr = column.combine_chunks().to_pandas().to_numpy() + + # TODO(@Gilbert09): support different operations here (e.g. min) + last_value = numpy_arr.max() + + logger.debug(f"Updating incremental_field_last_value with {last_value}") + + schema.update_incremental_field_last_value(last_value) + + +def _update_job_row_count(job_id: str, count: int, logger: FilteringBoundLogger) -> None: + logger.debug(f"Updating rows_synced with +{count}") + ExternalDataJob.objects.filter(id=job_id).update(rows_synced=F("rows_synced") + count) diff --git a/posthog/temporal/data_imports/pipelines/pipeline_sync.py b/posthog/temporal/data_imports/pipelines/pipeline_sync.py index bd48d9a53ec0e..581e84f2e476e 100644 --- a/posthog/temporal/data_imports/pipelines/pipeline_sync.py +++ b/posthog/temporal/data_imports/pipelines/pipeline_sync.py @@ -6,6 +6,7 @@ import dlt from django.conf import settings from django.db.models import Prefetch +import dlt.common from dlt.pipeline.exceptions import PipelineStepFailed from deltalake import DeltaTable @@ -345,6 +346,10 @@ def _run(self) -> dict[str, int]: job_id=self.inputs.run_id, schema_id=str(self.inputs.schema_id), team_id=self.inputs.team_id ) + if self._incremental: + self.logger.debug("Saving last incremental value...") + save_last_incremental_value(str(self.inputs.schema_id), str(self.inputs.team_id), self.source, self.logger) + # Cleanup: delete local state from the file system pipeline.drop() @@ -371,6 +376,28 @@ def update_last_synced_at_sync(job_id: str, schema_id: str, team_id: int) -> Non schema.save() +def save_last_incremental_value(schema_id: str, team_id: str, source: DltSource, logger: FilteringBoundLogger) -> None: + schema = ExternalDataSchema.objects.exclude(deleted=True).get(id=schema_id, team_id=team_id) + + incremental_field = schema.sync_type_config.get("incremental_field") + resource = next(iter(source.resources.values())) + + incremental: dict | None = resource.state.get("incremental") + + if incremental is None: + return + + incremental_object: dict | None = incremental.get(incremental_field) + if incremental_object is None: + return + + last_value = incremental_object.get("last_value") + + 
logger.debug(f"Updating incremental_field_last_value with {last_value}") + + schema.update_incremental_field_last_value(last_value) + + def validate_schema_and_update_table_sync( run_id: str, team_id: int, @@ -378,6 +405,7 @@ def validate_schema_and_update_table_sync( table_schema: TSchemaTables, row_count: int, table_format: DataWarehouseTable.TableFormat, + table_schema_dict: Optional[dict[str, str]] = None, ) -> None: """ @@ -465,27 +493,46 @@ def validate_schema_and_update_table_sync( else: raise - for schema in table_schema.values(): - if schema.get("resource") == _schema_name: - schema_columns = schema.get("columns") or {} - raw_db_columns: dict[str, dict[str, str]] = table_created.get_columns() - db_columns = {key: column.get("clickhouse", "") for key, column in raw_db_columns.items()} - - columns = {} - for column_name, db_column_type in db_columns.items(): - dlt_column = schema_columns.get(column_name) - if dlt_column is not None: - dlt_data_type = dlt_column.get("data_type") - hogql_type = dlt_to_hogql_type(dlt_data_type) - else: - hogql_type = dlt_to_hogql_type(None) - - columns[column_name] = { - "clickhouse": db_column_type, - "hogql": hogql_type, - } - table_created.columns = columns - break + # If using new non-DLT pipeline + if table_schema_dict is not None: + raw_db_columns: dict[str, dict[str, str]] = table_created.get_columns() + db_columns = {key: column.get("clickhouse", "") for key, column in raw_db_columns.items()} + + columns = {} + for column_name, db_column_type in db_columns.items(): + hogql_type = table_schema_dict.get(column_name) + + if hogql_type is None: + raise Exception(f"HogQL type not found for column: {column_name}") + + columns[column_name] = { + "clickhouse": db_column_type, + "hogql": hogql_type, + } + table_created.columns = columns + else: + # If using DLT pipeline + for schema in table_schema.values(): + if schema.get("resource") == _schema_name: + schema_columns = schema.get("columns") or {} + raw_db_columns: dict[str, dict[str, str]] = table_created.get_columns() + db_columns = {key: column.get("clickhouse", "") for key, column in raw_db_columns.items()} + + columns = {} + for column_name, db_column_type in db_columns.items(): + dlt_column = schema_columns.get(column_name) + if dlt_column is not None: + dlt_data_type = dlt_column.get("data_type") + hogql_type = dlt_to_hogql_type(dlt_data_type) + else: + hogql_type = dlt_to_hogql_type(None) + + columns[column_name] = { + "clickhouse": db_column_type, + "hogql": hogql_type, + } + table_created.columns = columns + break table_created.save() diff --git a/posthog/temporal/data_imports/pipelines/rest_source/__init__.py b/posthog/temporal/data_imports/pipelines/rest_source/__init__.py index 4fd019ce76753..9a8599882c652 100644 --- a/posthog/temporal/data_imports/pipelines/rest_source/__init__.py +++ b/posthog/temporal/data_imports/pipelines/rest_source/__init__.py @@ -46,6 +46,7 @@ def rest_api_source( config: RESTAPIConfig, team_id: int, job_id: str, + db_incremental_field_last_value: Optional[Any] = None, name: Optional[str] = None, section: Optional[str] = None, max_table_nesting: Optional[int] = None, @@ -108,10 +109,12 @@ def rest_api_source( spec, ) - return decorated(config, team_id, job_id) + return decorated(config, team_id, job_id, db_incremental_field_last_value) -def rest_api_resources(config: RESTAPIConfig, team_id: int, job_id: str) -> list[DltResource]: +def rest_api_resources( + config: RESTAPIConfig, team_id: int, job_id: str, db_incremental_field_last_value: Optional[Any] +) -> 
list[DltResource]: """Creates a list of resources from a REST API configuration. Args: @@ -193,6 +196,7 @@ def rest_api_resources(config: RESTAPIConfig, team_id: int, job_id: str) -> list resolved_param_map, team_id=team_id, job_id=job_id, + db_incremental_field_last_value=db_incremental_field_last_value, ) return list(resources.values()) @@ -205,6 +209,7 @@ def create_resources( resolved_param_map: dict[str, Optional[ResolvedParam]], team_id: int, job_id: str, + db_incremental_field_last_value: Optional[Any] = None, ) -> dict[str, DltResource]: resources = {} @@ -264,6 +269,7 @@ async def paginate_resource( incremental_object, incremental_param, incremental_cursor_transform, + db_incremental_field_last_value, ) yield client.paginate( @@ -317,6 +323,7 @@ async def paginate_dependent_resource( incremental_object, incremental_param, incremental_cursor_transform, + db_incremental_field_last_value, ) for item in items: @@ -358,6 +365,7 @@ def _set_incremental_params( incremental_object: Incremental[Any], incremental_param: Optional[IncrementalParam], transform: Optional[Callable[..., Any]], + db_incremental_field_last_value: Optional[Any] = None, ) -> dict[str, Any]: def identity_func(x: Any) -> Any: return x @@ -368,7 +376,13 @@ def identity_func(x: Any) -> Any: if incremental_param is None: return params - params[incremental_param.start] = transform(incremental_object.last_value) + last_value = ( + db_incremental_field_last_value + if db_incremental_field_last_value is not None + else incremental_object.last_value + ) + + params[incremental_param.start] = transform(last_value) if incremental_param.end: params[incremental_param.end] = transform(incremental_object.end_value) return params diff --git a/posthog/temporal/data_imports/pipelines/salesforce/__init__.py b/posthog/temporal/data_imports/pipelines/salesforce/__init__.py index cd206b6adcd4f..f01e17197e65f 100644 --- a/posthog/temporal/data_imports/pipelines/salesforce/__init__.py +++ b/posthog/temporal/data_imports/pipelines/salesforce/__init__.py @@ -6,7 +6,6 @@ from posthog.temporal.data_imports.pipelines.rest_source import RESTAPIConfig, rest_api_resources from posthog.temporal.data_imports.pipelines.rest_source.typing import EndpointResource from posthog.temporal.data_imports.pipelines.salesforce.auth import SalseforceAuth -import pendulum import re @@ -326,6 +325,7 @@ def salesforce_source( endpoint: str, team_id: int, job_id: str, + db_incremental_field_last_value: Optional[Any], is_incremental: bool = False, ): config: RESTAPIConfig = { @@ -340,4 +340,4 @@ def salesforce_source( "resources": [get_resource(endpoint, is_incremental)], } - yield from rest_api_resources(config, team_id, job_id) + yield from rest_api_resources(config, team_id, job_id, db_incremental_field_last_value) diff --git a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py index 2d826b8ed71f6..7593332f2d20a 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database/__init__.py +++ b/posthog/temporal/data_imports/pipelines/sql_database/__init__.py @@ -48,6 +48,7 @@ def sql_source_for_type( sslmode: str, schema: str, table_names: list[str], + db_incremental_field_last_value: Optional[Any], team_id: Optional[int] = None, incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, @@ -90,12 +91,13 @@ def sql_source_for_type( raise Exception("Unsupported source_type") db_source = sql_database( - credentials, + credentials=credentials, 
schema=schema, table_names=table_names, incremental=incremental, team_id=team_id, connect_args=connect_args, + db_incremental_field_last_value=db_incremental_field_last_value, ) return db_source @@ -109,6 +111,7 @@ def snowflake_source( warehouse: str, schema: str, table_names: list[str], + db_incremental_field_last_value: Optional[Any], role: Optional[str] = None, incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, @@ -130,7 +133,13 @@ def snowflake_source( credentials = ConnectionStringCredentials( f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}" ) - db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental) + db_source = sql_database( + credentials=credentials, + schema=schema, + table_names=table_names, + incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, + ) return db_source @@ -144,6 +153,7 @@ def bigquery_source( token_uri: str, table_name: str, bq_destination_table_id: str, + db_incremental_field_last_value: Optional[Any], incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, ) -> DltSource: @@ -168,7 +178,13 @@ def bigquery_source( credentials_info=credentials_info, ) - return sql_database(engine, schema=None, table_names=[table_name], incremental=incremental) + return sql_database( + credentials=engine, + schema=None, + table_names=[table_name], + incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, + ) # Temp while DLT doesn't support `interval` columns @@ -189,6 +205,7 @@ def internal_remove(doc: dict) -> dict: @dlt.source(max_table_nesting=0) def sql_database( + db_incremental_field_last_value: Optional[Any], credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, @@ -248,6 +265,7 @@ def sql_database( table=table, incremental=incremental, connect_args=connect_args, + db_incremental_field_last_value=db_incremental_field_last_value, ) ) diff --git a/posthog/temporal/data_imports/pipelines/sql_database/helpers.py b/posthog/temporal/data_imports/pipelines/sql_database/helpers.py index 50577b6b04d17..0400a60b32fd5 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database/helpers.py +++ b/posthog/temporal/data_imports/pipelines/sql_database/helpers.py @@ -27,6 +27,7 @@ def __init__( chunk_size: int = 1000, incremental: Optional[dlt.sources.incremental[Any]] = None, connect_args: Optional[list[str]] = None, + db_incremental_field_last_value: Optional[Any] = None, ) -> None: self.engine = engine self.table = table @@ -43,7 +44,11 @@ def __init__( raise KeyError( f"Cursor column '{incremental.cursor_path}' does not exist in table '{table.name}'" ) from e - self.last_value = incremental.last_value + self.last_value = ( + db_incremental_field_last_value + if db_incremental_field_last_value is not None + else incremental.last_value + ) else: self.cursor_column = None self.last_value = None @@ -90,6 +95,7 @@ def table_rows( chunk_size: int = DEFAULT_CHUNK_SIZE, incremental: Optional[dlt.sources.incremental[Any]] = None, connect_args: Optional[list[str]] = None, + db_incremental_field_last_value: Optional[Any] = None, ) -> Iterator[TDataItem]: """ A DLT source which loads data from an SQL database using SQLAlchemy. 
@@ -106,7 +112,14 @@ def table_rows( """ yield dlt.mark.materialize_table_schema() # type: ignore - loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size, connect_args=connect_args) + loader = TableLoader( + engine, + table, + incremental=incremental, + chunk_size=chunk_size, + connect_args=connect_args, + db_incremental_field_last_value=db_incremental_field_last_value, + ) yield from loader.load_rows() engine.dispose() diff --git a/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py b/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py index 33c150e79998f..0ec4abdc202c9 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py +++ b/posthog/temporal/data_imports/pipelines/sql_database_v2/__init__.py @@ -64,6 +64,7 @@ def sql_source_for_type( sslmode: str, schema: str, table_names: list[str], + db_incremental_field_last_value: Optional[Any], team_id: Optional[int] = None, incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, @@ -106,10 +107,11 @@ def sql_source_for_type( raise Exception("Unsupported source_type") db_source = sql_database( - credentials, + credentials=credentials, schema=schema, table_names=table_names, incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, team_id=team_id, connect_args=connect_args, ) @@ -125,6 +127,7 @@ def snowflake_source( warehouse: str, schema: str, table_names: list[str], + db_incremental_field_last_value: Optional[Any], role: Optional[str] = None, incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, @@ -146,7 +149,13 @@ def snowflake_source( credentials = ConnectionStringCredentials( f"snowflake://{user}:{password}@{account_id}/{database}/{schema}?warehouse={warehouse}{f'&role={role}' if role else ''}" ) - db_source = sql_database(credentials, schema=schema, table_names=table_names, incremental=incremental) + db_source = sql_database( + credentials=credentials, + schema=schema, + table_names=table_names, + incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, + ) return db_source @@ -160,6 +169,7 @@ def bigquery_source( token_uri: str, table_name: str, bq_destination_table_id: str, + db_incremental_field_last_value: Optional[Any], incremental_field: Optional[str] = None, incremental_field_type: Optional[IncrementalFieldType] = None, ) -> DltSource: @@ -184,11 +194,18 @@ def bigquery_source( credentials_info=credentials_info, ) - return sql_database(engine, schema=None, table_names=[table_name], incremental=incremental) + return sql_database( + credentials=engine, + schema=None, + table_names=[table_name], + incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, + ) @dlt.source(max_table_nesting=0) def sql_database( + db_incremental_field_last_value: Optional[Any], credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, @@ -275,6 +292,7 @@ def sql_database( backend_kwargs=backend_kwargs, type_adapter_callback=type_adapter_callback, incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, team_id=team_id, connect_args=connect_args, ) @@ -299,6 +317,7 @@ def internal_remove(table: pa.Table) -> pa.Table: @dlt.resource(name=lambda args: args["table"], standalone=True, spec=SqlTableResourceConfiguration) def sql_table( + 
db_incremental_field_last_value: Optional[Any], credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, table: str = dlt.config.value, schema: Optional[str] = dlt.config.value, @@ -396,6 +415,7 @@ def query_adapter_callback(query: SelectAny, table: Table): chunk_size=chunk_size, backend=backend, incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, reflection_level=reflection_level, defer_table_reflect=defer_table_reflect, table_adapter_callback=table_adapter_callback, diff --git a/posthog/temporal/data_imports/pipelines/sql_database_v2/helpers.py b/posthog/temporal/data_imports/pipelines/sql_database_v2/helpers.py index 46f59929beb47..acd64c97aae99 100644 --- a/posthog/temporal/data_imports/pipelines/sql_database_v2/helpers.py +++ b/posthog/temporal/data_imports/pipelines/sql_database_v2/helpers.py @@ -46,6 +46,7 @@ def __init__( columns: TTableSchemaColumns, chunk_size: int = 1000, incremental: Optional[dlt.sources.incremental[Any]] = None, + db_incremental_field_last_value: Optional[Any] = None, query_adapter_callback: Optional[TQueryAdapter] = None, connect_args: Optional[list[str]] = None, ) -> None: @@ -64,7 +65,11 @@ def __init__( raise KeyError( f"Cursor column '{incremental.cursor_path}' does not exist in table '{table.name}'" ) from e - self.last_value = incremental.last_value + self.last_value = ( + db_incremental_field_last_value + if db_incremental_field_last_value is not None + else incremental.last_value + ) self.end_value = incremental.end_value self.row_order: TSortOrder = self.incremental.row_order else: @@ -183,6 +188,7 @@ def table_rows( chunk_size: int, backend: TableBackend, incremental: Optional[dlt.sources.incremental[Any]] = None, + db_incremental_field_last_value: Optional[Any] = None, defer_table_reflect: bool = False, table_adapter_callback: Optional[Callable[[Table], None]] = None, reflection_level: ReflectionLevel = "minimal", @@ -226,6 +232,7 @@ def table_rows( table, columns, incremental=incremental, + db_incremental_field_last_value=db_incremental_field_last_value, chunk_size=chunk_size, query_adapter_callback=query_adapter_callback, connect_args=connect_args, diff --git a/posthog/temporal/data_imports/pipelines/stripe/__init__.py b/posthog/temporal/data_imports/pipelines/stripe/__init__.py index 5b386aa10adba..da9af92c191dc 100644 --- a/posthog/temporal/data_imports/pipelines/stripe/__init__.py +++ b/posthog/temporal/data_imports/pipelines/stripe/__init__.py @@ -325,7 +325,13 @@ def update_request(self, request: Request) -> None: @dlt.source(max_table_nesting=0) def stripe_source( - api_key: str, account_id: Optional[str], endpoint: str, team_id: int, job_id: str, is_incremental: bool = False + api_key: str, + account_id: Optional[str], + endpoint: str, + team_id: int, + job_id: str, + db_incremental_field_last_value: Optional[Any], + is_incremental: bool = False, ): config: RESTAPIConfig = { "client": { @@ -355,7 +361,7 @@ def stripe_source( "resources": [get_resource(endpoint, is_incremental)], } - yield from rest_api_resources(config, team_id, job_id) + yield from rest_api_resources(config, team_id, job_id, db_incremental_field_last_value) def validate_credentials(api_key: str) -> bool: diff --git a/posthog/temporal/data_imports/pipelines/vitally/__init__.py b/posthog/temporal/data_imports/pipelines/vitally/__init__.py index 223513d439d7c..86ca0bfdf7ff4 100644 --- a/posthog/temporal/data_imports/pipelines/vitally/__init__.py +++ b/posthog/temporal/data_imports/pipelines/vitally/__init__.py 
@@ -323,6 +323,7 @@ def vitally_source( endpoint: str, team_id: int, job_id: str, + db_incremental_field_last_value: Optional[Any], is_incremental: bool = False, ): config: RESTAPIConfig = { @@ -347,7 +348,7 @@ def vitally_source( "resources": [get_resource(endpoint, is_incremental)], } - yield from rest_api_resources(config, team_id, job_id) + yield from rest_api_resources(config, team_id, job_id, db_incremental_field_last_value) def validate_credentials(secret_token: str, region: str, subdomain: Optional[str]) -> bool: diff --git a/posthog/temporal/data_imports/pipelines/zendesk/__init__.py b/posthog/temporal/data_imports/pipelines/zendesk/__init__.py index 36d842e4d3889..55b6be994f006 100644 --- a/posthog/temporal/data_imports/pipelines/zendesk/__init__.py +++ b/posthog/temporal/data_imports/pipelines/zendesk/__init__.py @@ -289,6 +289,7 @@ def zendesk_source( endpoint: str, team_id: int, job_id: str, + db_incremental_field_last_value: Optional[Any], is_incremental: bool = False, ): config: RESTAPIConfig = { @@ -312,7 +313,7 @@ def zendesk_source( "resources": [get_resource(endpoint, is_incremental)], } - yield from rest_api_resources(config, team_id, job_id) + yield from rest_api_resources(config, team_id, job_id, db_incremental_field_last_value) def validate_credentials(subdomain: str, api_key: str, email_address: str) -> bool: diff --git a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py index 74244dcded195..85e21351a50b4 100644 --- a/posthog/temporal/data_imports/workflow_activities/import_data_sync.py +++ b/posthog/temporal/data_imports/workflow_activities/import_data_sync.py @@ -1,8 +1,10 @@ import dataclasses import uuid from datetime import datetime +from dateutil import parser from typing import Any +from django.conf import settings from django.db import close_old_connections from django.db.models import Prefetch, F @@ -12,6 +14,7 @@ from posthog.temporal.common.heartbeat_sync import HeartbeaterSync from posthog.temporal.data_imports.pipelines.bigquery import delete_table +from posthog.temporal.data_imports.pipelines.pipeline.pipeline import PipelineNonDLT from posthog.temporal.data_imports.pipelines.pipeline_sync import DataImportPipelineSync, PipelineInputs from posthog.temporal.data_imports.util import is_posthog_team from posthog.warehouse.models import ( @@ -22,6 +25,7 @@ from structlog.typing import FilteringBoundLogger from posthog.warehouse.models.external_data_schema import ExternalDataSchema from posthog.warehouse.models.ssh_tunnel import SSHTunnel +from posthog.warehouse.types import IncrementalFieldType @dataclasses.dataclass @@ -32,6 +36,20 @@ class ImportDataActivityInputs: run_id: str +def process_incremental_last_value(value: Any | None, field_type: IncrementalFieldType | None) -> Any | None: + if value is None or field_type is None: + return None + + if field_type == IncrementalFieldType.Integer or field_type == IncrementalFieldType.Numeric: + return value + + if field_type == IncrementalFieldType.DateTime or field_type == IncrementalFieldType.Timestamp: + return parser.parse(value) + + if field_type == IncrementalFieldType.Date: + return parser.parse(value).date() + + @activity.defn def import_data_activity_sync(inputs: ImportDataActivityInputs): logger = bind_temporal_worker_logger_sync(team_id=inputs.team_id) @@ -64,6 +82,11 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): endpoints = [schema.name] + processed_incremental_last_value = 
process_incremental_last_value( + schema.sync_type_config.get("incremental_field_last_value"), + schema.sync_type_config.get("incremental_field_type"), + ) + source = None if model.pipeline.source_type == ExternalDataSource.Type.STRIPE: from posthog.temporal.data_imports.pipelines.stripe import stripe_source @@ -80,6 +103,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): team_id=inputs.team_id, job_id=inputs.run_id, is_incremental=schema.is_incremental, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, ) return _run( @@ -176,6 +200,9 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): incremental_field_type=schema.sync_type_config.get("incremental_field_type") if schema.is_incremental else None, + db_incremental_field_last_value=processed_incremental_last_value + if schema.is_incremental + else None, team_id=inputs.team_id, ) @@ -202,6 +229,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): incremental_field_type=schema.sync_type_config.get("incremental_field_type") if schema.is_incremental else None, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, team_id=inputs.team_id, ) @@ -244,6 +272,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): incremental_field_type=schema.sync_type_config.get("incremental_field_type") if schema.is_incremental else None, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, ) return _run( @@ -288,6 +317,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): team_id=inputs.team_id, job_id=inputs.run_id, is_incremental=schema.is_incremental, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, ) return _run( @@ -310,6 +340,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): team_id=inputs.team_id, job_id=inputs.run_id, is_incremental=schema.is_incremental, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, ) return _run( @@ -331,6 +362,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): team_id=inputs.team_id, job_id=inputs.run_id, is_incremental=schema.is_incremental, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, ) return _run( @@ -368,6 +400,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): incremental_field_type=schema.sync_type_config.get("incremental_field_type") if schema.is_incremental else None, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, ) _run( @@ -403,6 +436,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs): team_id=inputs.team_id, job_id=inputs.run_id, is_incremental=schema.is_incremental, + db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None, ) return _run( @@ -425,12 +459,18 @@ def _run( schema: ExternalDataSchema, reset_pipeline: bool, ): - table_row_counts = DataImportPipelineSync(job_inputs, source, logger, reset_pipeline, schema.is_incremental).run() - total_rows_synced = sum(table_row_counts.values()) + if settings.DEBUG: + PipelineNonDLT(source, logger, job_inputs.run_id, schema.is_incremental).run() + else: + table_row_counts = DataImportPipelineSync( + job_inputs, source, logger, reset_pipeline, schema.is_incremental + ).run() + total_rows_synced = 
sum(table_row_counts.values()) + + ExternalDataJob.objects.filter(id=inputs.run_id, team_id=inputs.team_id).update( + rows_synced=F("rows_synced") + total_rows_synced + ) - ExternalDataJob.objects.filter(id=inputs.run_id, team_id=inputs.team_id).update( - rows_synced=F("rows_synced") + total_rows_synced - ) source = ExternalDataSource.objects.get(id=inputs.source_id) source.job_inputs.pop("reset_pipeline", None) source.save() diff --git a/posthog/warehouse/models/external_data_schema.py b/posthog/warehouse/models/external_data_schema.py index 3bcbc6c658f7f..beaad6ba8c408 100644 --- a/posthog/warehouse/models/external_data_schema.py +++ b/posthog/warehouse/models/external_data_schema.py @@ -1,6 +1,6 @@ from collections import defaultdict from datetime import datetime, timedelta -from typing import Optional +from typing import Any, Optional from django.db import models from django_deprecate_fields import deprecate_field import snowflake.connector @@ -48,6 +48,8 @@ class SyncFrequency(models.TextChoices): status = models.CharField(max_length=400, null=True, blank=True) last_synced_at = models.DateTimeField(null=True, blank=True) sync_type = models.CharField(max_length=128, choices=SyncType.choices, null=True, blank=True) + + # { "incremental_field": string, "incremental_field_type": string, "incremental_field_last_value": any } sync_type_config = models.JSONField( default=dict, blank=True, @@ -67,6 +69,20 @@ def folder_path(self) -> str: def is_incremental(self): return self.sync_type == self.SyncType.INCREMENTAL + def update_incremental_field_last_value(self, last_value: Any) -> None: + incremental_field_type = self.sync_type_config.get("incremental_field_type") + + if ( + incremental_field_type == IncrementalFieldType.Integer + or incremental_field_type == IncrementalFieldType.Numeric + ): + last_value_json = last_value + else: + last_value_json = str(last_value) + + self.sync_type_config["incremental_field_last_value"] = last_value_json + self.save() + def soft_delete(self): self.deleted = True self.deleted_at = datetime.now() diff --git a/requirements.in b/requirements.in index 3696df35d43d1..e1afbf34b108f 100644 --- a/requirements.in +++ b/requirements.in @@ -14,6 +14,7 @@ celery==5.3.4 celery-redbeat==2.1.1 clickhouse-driver==0.2.7 clickhouse-pool==0.5.3 +conditional-cache==1.2 cryptography==39.0.2 dj-database-url==0.5.0 Django~=4.2.15 diff --git a/requirements.txt b/requirements.txt index c276d7a792904..d5cac17a5ce4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -93,6 +93,8 @@ charset-normalizer==2.1.0 # via # requests # snowflake-connector-python +circular-dict==1.9 + # via conditional-cache click==8.1.7 # via # celery @@ -113,6 +115,8 @@ clickhouse-driver==0.2.7 # sentry-sdk clickhouse-pool==0.5.3 # via -r requirements.in +conditional-cache==1.2 + # via -r requirements.in cryptography==39.0.2 # via # -r requirements.in
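Taken together, the REST-source and SQL-helper changes above implement one precedence rule: an incremental cursor persisted on the ExternalDataSchema row (db_incremental_field_last_value) wins over whatever dlt kept in pipeline state, which only acts as a fallback. A minimal, standalone Python sketch of that rule follows; resolve_last_value and the literal values are illustrative, not names from the PR.

# Standalone sketch (assumed names, no dlt dependency) of the cursor-precedence
# rule threaded through rest_source and the sql_database TableLoader: a value
# persisted on the schema row overrides the cursor held in dlt pipeline state.
from typing import Any, Optional


def resolve_last_value(db_last_value: Optional[Any], state_last_value: Optional[Any]) -> Optional[Any]:
    # Mirrors: db_incremental_field_last_value if ... is not None else incremental.last_value
    return db_last_value if db_last_value is not None else state_last_value


if __name__ == "__main__":
    # DB value present -> it is used, even if dlt state lags behind.
    assert resolve_last_value("2024-06-01", "2024-01-01") == "2024-06-01"
    # No DB value yet (e.g. first run after this change ships) -> fall back to dlt state.
    assert resolve_last_value(None, "2024-01-01") == "2024-01-01"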
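The cursor survives between runs as JSON on sync_type_config: update_incremental_field_last_value stores integer/numeric values as-is and stringifies everything else, and process_incremental_last_value parses the string back with dateutil before handing it to the source. The sketch below mirrors that round-trip under assumed names (FieldType stands in for IncrementalFieldType; to_json_value and from_json_value are illustrative helpers, not the PR's functions).

# Round-trip sketch of how the incremental cursor survives JSON storage.
from datetime import datetime
from enum import Enum
from typing import Any, Optional

from dateutil import parser


class FieldType(str, Enum):  # stand-in for posthog.warehouse.types.IncrementalFieldType
    Integer = "integer"
    Numeric = "numeric"
    DateTime = "datetime"
    Timestamp = "timestamp"
    Date = "date"


def to_json_value(value: Any, field_type: FieldType) -> Any:
    # Numbers are stored untouched; everything else is stringified before hitting JSONField.
    if field_type in (FieldType.Integer, FieldType.Numeric):
        return value
    return str(value)


def from_json_value(value: Optional[Any], field_type: Optional[FieldType]) -> Optional[Any]:
    if value is None or field_type is None:
        return None
    if field_type in (FieldType.Integer, FieldType.Numeric):
        return value
    if field_type in (FieldType.DateTime, FieldType.Timestamp):
        return parser.parse(value)
    if field_type == FieldType.Date:
        return parser.parse(value).date()
    return None


if __name__ == "__main__":
    stored = to_json_value(datetime(2024, 6, 1, 12, 30), FieldType.DateTime)  # "2024-06-01 12:30:00"
    assert from_json_value(stored, FieldType.DateTime) == datetime(2024, 6, 1, 12, 30)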
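On the write side, PipelineNonDLT feeds pyarrow chunks (roughly 5,000 rows each) into DeltaTableHelper.write_to_deltalake, which picks between an upsert and a plain write. A simplified decision sketch follows, under assumed names (choose_write_mode is illustrative; the real method also evolves the Delta schema and builds the merge predicate from the primary keys).

# Decision-table sketch of the write path in DeltaTableHelper.write_to_deltalake.
from typing import Optional, Sequence


def choose_write_mode(
    chunk_index: int,
    table_exists: bool,
    is_incremental: bool,
    primary_keys: Optional[Sequence[str]],
) -> str:
    if is_incremental and table_exists:
        if not primary_keys:
            raise ValueError("Primary key required for incremental syncs")
        return "merge"      # upsert via when_matched_update_all / when_not_matched_insert_all
    if chunk_index == 0 or not table_exists:
        return "overwrite"  # first chunk of a full refresh rewrites the table
    return "append"         # later chunks append, with schema_mode="merge"


if __name__ == "__main__":
    assert choose_write_mode(0, table_exists=False, is_incremental=False, primary_keys=None) == "overwrite"
    assert choose_write_mode(3, table_exists=True, is_incremental=False, primary_keys=None) == "append"
    assert choose_write_mode(1, table_exists=True, is_incremental=True, primary_keys=["id"]) == "merge"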