diff --git a/universal_transfer_operator/example_dags/example_universal_transfer_operator.py b/universal_transfer_operator/example_dags/example_universal_transfer_operator.py index 8a3f40192..daa2a3b26 100644 --- a/universal_transfer_operator/example_dags/example_universal_transfer_operator.py +++ b/universal_transfer_operator/example_dags/example_universal_transfer_operator.py @@ -5,7 +5,7 @@ from universal_transfer_operator.constants import FileType from universal_transfer_operator.datasets.file.base import File -from universal_transfer_operator.datasets.table import Table +from universal_transfer_operator.datasets.table import Metadata, Table from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator s3_bucket = os.getenv("S3_BUCKET", "s3://astro-sdk-test") @@ -35,7 +35,7 @@ transfer_non_native_s3_to_sqlite = UniversalTransferOperator( task_id="transfer_non_native_s3_to_sqlite", source_dataset=File(path=f"{s3_bucket}/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV), - destination_dataset=Table(name="uto_s3_table_1", conn_id="sqlite_default"), + destination_dataset=Table(name="uto_s3_to_sqlite_table", conn_id="sqlite_default"), ) transfer_non_native_gs_to_sqlite = UniversalTransferOperator( @@ -43,7 +43,7 @@ source_dataset=File( path=f"{gcs_bucket}/uto/csv_files/", conn_id="google_cloud_default", filetype=FileType.CSV ), - destination_dataset=Table(name="uto_gs_table_1", conn_id="sqlite_default"), + destination_dataset=Table(name="uto_gs_to_sqlite_table", conn_id="sqlite_default"), ) transfer_non_native_s3_to_snowflake = UniversalTransferOperator( @@ -51,7 +51,7 @@ source_dataset=File( path="s3://astro-sdk-test/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV ), - destination_dataset=Table(name="uto_s3_table_2", conn_id="snowflake_conn"), + destination_dataset=Table(name="uto_s3_table_to_snowflake", conn_id="snowflake_conn"), ) transfer_non_native_gs_to_snowflake = UniversalTransferOperator( @@ -59,5 +59,48 @@ source_dataset=File( path="gs://uto-test/uto/csv_files/", conn_id="google_cloud_default", filetype=FileType.CSV ), - destination_dataset=Table(name="uto_gs_table_2", conn_id="snowflake_conn"), + destination_dataset=Table(name="uto_gs_to_snowflake_table", conn_id="snowflake_conn"), + ) + + transfer_non_native_gs_to_bigquery = UniversalTransferOperator( + task_id="transfer_non_native_gs_to_bigquery", + source_dataset=File(path="gs://uto-test/uto/homes_main.csv", conn_id="google_cloud_default"), + destination_dataset=Table( + name="uto_gs_to_bigquery_table", + conn_id="google_cloud_default", + metadata=Metadata(schema="astro"), + ), + ) + + transfer_non_native_s3_to_bigquery = UniversalTransferOperator( + task_id="transfer_non_native_s3_to_bigquery", + source_dataset=File( + path="s3://astro-sdk-test/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV + ), + destination_dataset=Table( + name="uto_s3_to_bigquery_destination_table", + conn_id="google_cloud_default", + metadata=Metadata(schema="astro"), + ), + ) + + transfer_non_native_bigquery_to_snowflake = UniversalTransferOperator( + task_id="transfer_non_native_bigquery_to_snowflake", + source_dataset=Table( + name="uto_s3_to_bigquery_table", + conn_id="google_cloud_default", + metadata=Metadata(schema="astro"), + ), + destination_dataset=Table( + name="uto_bigquery_to_snowflake_table", + conn_id="snowflake_conn", + ), + ) + + transfer_non_native_bigquery_to_sqlite = UniversalTransferOperator( + task_id="transfer_non_native_bigquery_to_sqlite", + 
source_dataset=Table( + name="uto_s3_to_bigquery_table", conn_id="google_cloud_default", metadata=Metadata(schema="astro") + ), + destination_dataset=Table(name="uto_bigquery_to_sqlite_table", conn_id="sqlite_default"), ) diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py index 97c43979f..5980a6c6b 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/__init__.py @@ -14,8 +14,9 @@ ("aws", File): "universal_transfer_operator.data_providers.filesystem.aws.s3", ("gs", File): "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs", ("google_cloud_platform", File): "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs", - ("sqlite", Table): "universal_transfer_operator.data_providers.database.sqlite", ("sftp", File): "universal_transfer_operator.data_providers.filesystem.sftp", + ("google_cloud_platform", Table): "universal_transfer_operator.data_providers.database.google.bigquery", + ("gs", Table): "universal_transfer_operator.data_providers.database.google.bigquery", ("sqlite", Table): "universal_transfer_operator.data_providers.database.sqlite", ("snowflake", Table): "universal_transfer_operator.data_providers.database.snowflake", (None, File): "universal_transfer_operator.data_providers.filesystem.local", diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py index c7e95d1be..d26841042 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/base.py @@ -104,4 +104,10 @@ def openlineage_dataset_uri(self) -> str: Returns the open lineage dataset uri as per https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ + return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}" + + def populate_metadata(self): + """ + Given a dataset, check if the dataset has metadata. 
+ """ raise NotImplementedError diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/base.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/base.py index 6b7f5a325..42857d81d 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/base.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Any, Callable import pandas as pd @@ -27,6 +28,7 @@ from universal_transfer_operator.data_providers.filesystem import resolve_file_path_pattern from universal_transfer_operator.data_providers.filesystem.base import FileStream from universal_transfer_operator.datasets.base import Dataset +from universal_transfer_operator.datasets.dataframe.pandas import PandasDataframe from universal_transfer_operator.datasets.file.base import File from universal_transfer_operator.datasets.table import Metadata, Table from universal_transfer_operator.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SCHEMA @@ -186,13 +188,20 @@ def check_if_transfer_supported(self, source_dataset: Table) -> bool: return Location(source_connection_type) in self.transfer_mapping def read(self): - """ ""Read the dataset and write to local reference location""" - raise NotImplementedError + """Read the dataset and write to local reference location""" + with NamedTemporaryFile(mode="wb+", suffix=".parquet", delete=False) as tmp_file: + df = self.export_table_to_pandas_dataframe() + df.to_parquet(tmp_file.name) + local_temp_file = FileStream( + remote_obj_buffer=tmp_file.file, + actual_filename=tmp_file.name, + actual_file=File(path=tmp_file.name), + ) + yield local_temp_file def write(self, source_ref: FileStream): """ Write the data from local reference location to the dataset. - :param source_ref: Stream of data to be loaded into output table. """ return self.load_file_to_table(input_file=source_ref.actual_file, output_table=self.dataset) @@ -269,7 +278,7 @@ def default_metadata(self) -> Metadata: """ raise NotImplementedError - def populate_table_metadata(self, table: Table) -> Table: + def populate_metadata(self): """ Given a table, check if the table has metadata. If the metadata is missing, and the database has metadata, assign it to the table. 
@@ -279,11 +288,11 @@ def populate_table_metadata(self, table: Table) -> Table: :param table: Table to potentially have their metadata changed :return table: Return the modified table """ - if table.metadata and table.metadata.is_empty() and self.default_metadata: - table.metadata = self.default_metadata - if not table.metadata.schema: - table.metadata.schema = self.DEFAULT_SCHEMA - return table + + if self.dataset.metadata and self.dataset.metadata.is_empty() and self.default_metadata: + self.dataset.metadata = self.default_metadata + if not self.dataset.metadata.schema: + self.dataset.metadata.schema = self.DEFAULT_SCHEMA # --------------------------------------------------------- # Table creation & deletion methods @@ -659,3 +668,19 @@ def schema_exists(self, schema: str) -> bool: :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc) """ raise NotImplementedError + + # --------------------------------------------------------- + # Extract methods + # --------------------------------------------------------- + + def export_table_to_pandas_dataframe(self) -> pd.DataFrame: + """ + Copy the content of a table to an in-memory Pandas dataframe. + """ + + if not self.table_exists(self.dataset): + raise ValueError(f"The table {self.dataset.name} does not exist") + + sqla_table = self.get_sqla_table(self.dataset) + df = pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine) + return PandasDataframe.from_pandas_df(df) diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/__init__.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/bigquery.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/bigquery.py new file mode 100644 index 000000000..c0bb1f24a --- /dev/null +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/google/bigquery.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import attr +import pandas as pd +from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from google.api_core.exceptions import ( + NotFound as GoogleNotFound, +) +from sqlalchemy import create_engine +from sqlalchemy.engine.base import Engine + +from universal_transfer_operator.constants import DEFAULT_CHUNK_SIZE, LoadExistStrategy +from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider +from universal_transfer_operator.datasets.table import Metadata, Table +from universal_transfer_operator.settings import BIGQUERY_SCHEMA, BIGQUERY_SCHEMA_LOCATION +from universal_transfer_operator.universal_transfer_operator import TransferParameters + + +class BigqueryDataProvider(DatabaseDataProvider): + """BigqueryDataProvider represent all the DataProviders interactions with Bigquery Databases.""" + + DEFAULT_SCHEMA = BIGQUERY_SCHEMA + + illegal_column_name_chars: list[str] = ["."] + illegal_column_name_chars_replacement: list[str] = ["_"] + + _create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {} OPTIONS (location='{}')" + + def __init__( + self, + dataset: Table, + transfer_mode, + transfer_params: TransferParameters = attr.field( + factory=TransferParameters, + converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val, + ), + ): + self.dataset = dataset 
+ self.transfer_params = transfer_params + self.transfer_mode = transfer_mode + self.transfer_mapping = {} + self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {} + super().__init__( + dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params + ) + + def __repr__(self): + return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id}")' + + @property + def sql_type(self) -> str: + return "bigquery" + + @property + def hook(self) -> BigQueryHook: + """Retrieve Airflow hook to interface with Bigquery.""" + return BigQueryHook( + gcp_conn_id=self.dataset.conn_id, use_legacy_sql=False, location=BIGQUERY_SCHEMA_LOCATION + ) + + @property + def sqlalchemy_engine(self) -> Engine: + """Return SQLAlchemy engine.""" + uri = self.hook.get_uri() + with self.hook.provide_gcp_credential_file_as_context(): + return create_engine(uri) + + @property + def default_metadata(self) -> Metadata: + """ + Fill in default metadata values for table objects addressing BigQuery databases + """ + return Metadata( + schema=self.DEFAULT_SCHEMA, + database=self.hook.project_id, + ) # type: ignore + + # --------------------------------------------------------- + # Table metadata + # --------------------------------------------------------- + + def schema_exists(self, schema: str) -> bool: + """ + Check if a dataset exists in BigQuery + + :param schema: Bigquery namespace + """ + try: + self.hook.get_dataset(dataset_id=schema) + except GoogleNotFound: + # google.api_core.exceptions throws when a resource is not found + return False + return True + + def _get_schema_location(self, schema: str | None = None) -> str: + """ + Get the region where the schema is created + + :param schema: Bigquery namespace + """ + if schema is None: + return "" + try: + dataset = self.hook.get_dataset(dataset_id=schema) + return str(dataset.location) + except GoogleNotFound: + # google.api_core.exceptions throws when a resource is not found + return "" + + def load_pandas_dataframe_to_table( + self, + source_dataframe: pd.DataFrame, + target_table: Table, + if_exists: LoadExistStrategy = "replace", + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> None: + """ + Create a table with the dataframe's contents. + If the table already exists, append or replace the content, depending on the value of `if_exists`. + + :param source_dataframe: Pandas dataframe to be loaded into the target table + :param target_table: Table in which the dataframe will be loaded + :param if_exists: Strategy to be used in case the target table already exists. + :param chunk_size: Specify the number of rows in each batch to be written at a time. + """ + self._assert_not_empty_df(source_dataframe) + + try: + creds = self.hook._get_credentials() # skipcq PYL-W021 + except AttributeError: + # Details: https://github.com/astronomer/astro-sdk/issues/703 + creds = self.hook.get_credentials() + source_dataframe.to_gbq( + self.get_table_qualified_name(target_table), + if_exists=if_exists, + chunksize=chunk_size, + project_id=self.hook.project_id, + credentials=creds, + ) + + def create_schema_if_needed(self, schema: str | None) -> None: + """ + This function checks if the expected schema exists in the database. If the schema does not exist, + it will attempt to create it. + + :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc) + """ + # We check if the schema exists first because BigQuery will fail on a create schema query even if it + # doesn't actually create a schema. 
+ if schema and not self.schema_exists(schema): + table_schema = self.dataset.metadata.schema if self.dataset and self.dataset.metadata else None + table_location = self._get_schema_location(table_schema) + + location = table_location or BIGQUERY_SCHEMA_LOCATION + statement = self._create_schema_statement.format(schema, location) + self.run_sql(statement) + + def truncate_table(self, table): + """Truncate table""" + self.run_sql(f"TRUNCATE {self.get_table_qualified_name(table)}") + + @property + def openlineage_dataset_name(self) -> str: + """ + Returns the open lineage dataset name as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: db_name.schema_name.table_name + """ + dataset = self.dataset.metadata.database or self.dataset.metadata.schema + return f"{self.hook.project_id}.{dataset}.{self.dataset.name}" + + @property + def openlineage_dataset_namespace(self) -> str: + """ + Returns the open lineage dataset namespace as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + Example: bigquery + """ + return self.sql_type + + @property + def openlineage_dataset_uri(self) -> str: + """ + Returns the open lineage dataset uri as per + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md + """ + return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/snowflake.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/snowflake.py index 2307b89c9..5ed028187 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/snowflake.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/snowflake.py @@ -68,10 +68,6 @@ def default_metadata(self) -> Metadata: database=connection.database, ) - def read(self): - """ ""Read the dataset and write to local reference location""" - raise NotImplementedError - # --------------------------------------------------------- # Table metadata # --------------------------------------------------------- @@ -256,4 +252,4 @@ def openlineage_dataset_uri(self) -> str: Returns the open lineage dataset uri as per https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ - return f"{self.openlineage_dataset_namespace()}{self.openlineage_dataset_name()}" + return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/sqlite.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/sqlite.py index 4470b496a..bbdaba556 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/sqlite.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/database/sqlite.py @@ -8,7 +8,7 @@ from sqlalchemy.engine.base import Engine from sqlalchemy.sql.schema import Table as SqlaTable -from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider, FileStream +from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider from universal_transfer_operator.datasets.table import Metadata, Table from universal_transfer_operator.universal_transfer_operator import TransferParameters @@ -59,17 +59,6 @@ def default_metadata(self) -> Metadata: """Since Sqlite does not use Metadata, we return an empty Metadata instances.""" return 
Metadata() - def read(self): - """ ""Read the dataset and write to local reference location""" - raise NotImplementedError - - def write(self, source_ref: FileStream): - """Write the data from local reference location to the dataset""" - return self.load_file_to_table( - input_file=source_ref.actual_file, - output_table=self.dataset, - ) - # --------------------------------------------------------- # Table metadata # --------------------------------------------------------- @@ -82,13 +71,12 @@ def get_table_qualified_name(table: Table) -> str: """ return str(table.name) - def populate_table_metadata(self, table: Table) -> Table: + def populate_metadata(self): # skipcq: PTC-W0049 """ Since SQLite does not have a concept of databases or schemas, we just return the table as is, without any modifications. """ - table.conn_id = table.conn_id or self.dataset.conn_id - return table + pass def create_schema_if_needed(self, schema: str | None) -> None: """ @@ -136,4 +124,4 @@ def openlineage_dataset_uri(self) -> str: Returns the open lineage dataset uri as per https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ - return f"{self.openlineage_dataset_namespace()}{self.openlineage_dataset_name()}" + return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py index 3d4caf7af..7b846c53c 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/aws/s3.py @@ -49,6 +49,13 @@ def hook(self) -> S3Hook: extra_args=self.s3_extra_args, ) + def delete(self): + """ + Delete a file/object if they exists + """ + url = urlparse(self.dataset.path) + self.hook.delete_objects(bucket=url.netloc, keys=url.path.lstrip("/")) + @property def transport_params(self) -> dict: """Structure s3fs credentials from Airflow connection. 
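For reference, a minimal usage sketch (not part of this diff) of the new delete()/check_if_exists() pair on the filesystem providers, mirroring the test_delete_s3_object integration test added below; the bucket, key, and connection id are illustrative placeholders and assume a configured aws_default Airflow connection:

from universal_transfer_operator.constants import TransferMode
from universal_transfer_operator.data_providers import create_dataprovider
from universal_transfer_operator.datasets.file.base import File

# Hypothetical object path; the integration tests generate the key with create_unique_str().
dataset = File(path="s3://my-bucket/uto/sample.csv", conn_id="aws_default")
provider = create_dataprovider(dataset=dataset, transfer_mode=TransferMode.NONNATIVE)

if provider.check_if_exists():  # True while the object is present
    provider.delete()  # urlparse() splits the dataset path into bucket (netloc) and key (path)
assert not provider.check_if_exists()
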
diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py index 8309a48c0..532c00eb1 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/base.py @@ -66,6 +66,12 @@ def paths(self) -> list[str]: """Resolve patterns in path""" raise NotImplementedError + def delete(self): + """ + Delete a file/object if they exists + """ + raise NotImplementedError + @property def transport_params(self) -> dict | None: # skipcq: PYL-R0201 """Get credentials required by smart open to access files""" @@ -120,7 +126,7 @@ def write(self, source_ref: FileStream): def write_using_smart_open(self, source_ref: FileStream): """Write the source data from remote object i/o buffer to the dataset using smart open""" mode = "wb" if self.read_as_binary(source_ref.actual_filename) else "w" - destination_file = os.path.join(self.dataset.path, os.path.basename(source_ref.actual_filename)) + destination_file = self.dataset.path with smart_open.open(destination_file, mode=mode, transport_params=self.transport_params) as stream: stream.write(source_ref.remote_obj_buffer.read()) return destination_file @@ -195,3 +201,9 @@ def openlineage_dataset_uri(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ raise NotImplementedError + + def populate_metadata(self): # skipcq: PTC-W0049 + """ + Given a dataset, check if the dataset has metadata. + """ + pass diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py index 709baa504..5dba448b2 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/google/cloud/gcs.py @@ -49,6 +49,13 @@ def hook(self) -> GCSHook: impersonation_chain=self.google_impersonation_chain, ) + def delete(self): + """ + Delete a file/object if they exists + """ + url = urlparse(self.dataset.path) + self.hook.delete(bucket_name=url.netloc, object_name=url.path.lstrip("/")) + @property def transport_params(self) -> dict: """get GCS credentials for storage""" diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/local.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/local.py index e9c19fdf1..d15ce41ea 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/local.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/local.py @@ -3,9 +3,13 @@ import glob import os import pathlib +from os.path import exists from urllib.parse import urlparse -from universal_transfer_operator.data_providers.filesystem.base import BaseFilesystemProviders +import smart_open +from airflow.hooks.base import BaseHook + +from universal_transfer_operator.data_providers.filesystem.base import BaseFilesystemProviders, FileStream class LocalDataProvider(BaseFilesystemProviders): @@ -48,3 +52,26 @@ def openlineage_dataset_name(self) -> str: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md """ return urlparse(self.path).path + + 
def delete(self): + """ + Delete a file/object if it exists + """ + os.remove(self.dataset.path) + + def check_if_exists(self) -> bool: + """Return True if the dataset exists""" + return exists(self.dataset.path) + + def write_using_smart_open(self, source_ref: FileStream): + """Write the source data from remote object i/o buffer to the dataset using smart open""" + mode = "wb" if self.read_as_binary(source_ref.actual_filename) else "w" + # The local provider writes directly to self.dataset.path; no basename join is needed here. + with smart_open.open(self.dataset.path, mode=mode, transport_params=self.transport_params) as stream: + stream.write(source_ref.remote_obj_buffer.read()) + return self.dataset.path + + @property + def hook(self) -> BaseHook: + """Return an instance of the Airflow hook.""" + raise NotImplementedError diff --git a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/sftp.py b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/sftp.py index dfac7ae0e..ff0fbfb07 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/sftp.py +++ b/universal_transfer_operator/src/universal_transfer_operator/data_providers/filesystem/sftp.py @@ -45,6 +45,16 @@ def hook(self) -> SFTPHook: """Return an instance of the SFTPHook Airflow hook.""" return SFTPHook(ssh_conn_id=self.dataset.conn_id) + + def delete(self): + """ + Delete a file/object if it exists + """ + self.hook.delete_file(path=self.dataset.path.replace("sftp://", "/")) + + def check_if_exists(self): + """Return True if the dataset exists""" + return self.hook.path_exists(self.dataset.path.replace("sftp://", "/")) + + @property + def paths(self) -> list[str]: + """Resolve SFTP file paths with netloc of self.dataset.path as prefix. 
Paths are added if they start with prefix diff --git a/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py b/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py index a3e850ef7..bc7fae866 100644 --- a/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py +++ b/universal_transfer_operator/src/universal_transfer_operator/datasets/file/base.py @@ -131,3 +131,24 @@ def __eq__(self, other) -> bool: def __hash__(self) -> int: return hash((self.path, self.conn_id, self.filetype)) + + @uri.default + def _path_to_dataset_uri(self) -> str: + """Build a URI to be passed to Dataset obj introduced in Airflow 2.4""" + from urllib.parse import urlencode, urlparse + + parsed_url = urlparse(url=self.path) + netloc = parsed_url.netloc + # Local filepaths do not have scheme + parsed_scheme = parsed_url.scheme or "file" + scheme = f"astro+{parsed_scheme}" + extra = {} + if self.filetype: + extra["filetype"] = str(self.filetype) + + new_parsed_url = parsed_url._replace( + netloc=f"{self.conn_id}@{netloc}" if self.conn_id else netloc, + scheme=scheme, + query=urlencode(extra), + ) + return new_parsed_url.geturl() diff --git a/universal_transfer_operator/tests/conftest.py b/universal_transfer_operator/tests/conftest.py index 898bdbc5e..ce2188222 100644 --- a/universal_transfer_operator/tests/conftest.py +++ b/universal_transfer_operator/tests/conftest.py @@ -1,23 +1,45 @@ import logging import os from copy import deepcopy +from urllib.parse import urlparse, urlunparse import pytest +import smart_open import yaml from airflow.models import DAG, Connection, DagRun, TaskInstance as TI from airflow.utils import timezone from airflow.utils.db import create_default_connections from airflow.utils.session import create_session +from google.api_core.exceptions import NotFound from utils.test_utils import create_unique_str from universal_transfer_operator.constants import TransferMode -from universal_transfer_operator.data_providers import create_dataprovider +from universal_transfer_operator.data_providers import DataProviders, create_dataprovider +from universal_transfer_operator.data_providers.filesystem.base import BaseFilesystemProviders +from universal_transfer_operator.datasets.file.base import File from universal_transfer_operator.datasets.table import Table DEFAULT_DATE = timezone.datetime(2016, 1, 1) UNIQUE_HASH_SIZE = 16 -DATASET_NAME_TO_CONN_ID = {"SqliteDataProvider": "sqlite_default", "SnowflakeDataProvider": "snowflake_conn"} +DATASET_NAME_TO_CONN_ID = { + "SqliteDataProvider": "sqlite_default", + "SnowflakeDataProvider": "snowflake_conn", + "BigqueryDataProvider": "google_cloud_default", + "S3DataProvider": "aws_default", + "GCSDataProvider": "google_cloud_default", + "LocalDataProvider": None, + "SFTPDataProvider": "sftp_conn", +} +DATASET_NAME_TO_PROVIDER_TYPE = { + "SqliteDataProvider": "database", + "SnowflakeDataProvider": "database", + "BigqueryDataProvider": "database", + "S3DataProvider": "file", + "GCSDataProvider": "file", + "LocalDataProvider": "file", + "SFTPDataProvider": "file", +} @pytest.fixture @@ -91,10 +113,165 @@ def dataset_table_fixture(request): table.name = create_unique_str(UNIQUE_HASH_SIZE) file = params.get("file") - dp.populate_table_metadata(table) + dp.populate_metadata() # dp.create_schema_if_needed(table.metadata.schema) if file: dp.load_file_to_table(file, table) yield dp, table dp.drop_table(table) + + +def set_table_missing_values(table: Table, dataset_name: str) -> Table: + """ + Set 
missing values of table dataset + """ + conn_id = DATASET_NAME_TO_CONN_ID[dataset_name] + table = table or Table(conn_id=conn_id) + + if not table.conn_id: + table.conn_id = conn_id + + if not table.name: + # We create a unique table name to make the name unique across runs + table.name = create_unique_str(UNIQUE_HASH_SIZE) + + return table + + +def set_file_missing_values(file: File, dataset_name: str): + """ + Set missing values of file dataset + """ + conn_id = DATASET_NAME_TO_CONN_ID[dataset_name] + if not file.conn_id: + file.conn_id = conn_id + return file + + +def populate_file(src_file_path: str, dataset_provider: BaseFilesystemProviders, dp_name: str): + """ + Populate file with local file data + :param src_file_path: source path of the content that will be populated + :param dataset_provider: dataset provider object, that will be populated with content in src_file_path + :param dp_name: name of data provider + :return: + """ + src_file_object = dataset_provider._convert_remote_file_to_byte_stream(src_file_path) + mode = "wb" if dataset_provider.read_as_binary(src_file_path) else "w" + + # Currently, we are passing the credentials to sftp server via URL - sftp://username:password@localhost, we are + # populating the credentials in the URL if the server destination is SFTP. + path = dataset_provider.dataset.path + if dp_name == "SFTPDataProvider": + original_url = urlparse(path) + cred_url = urlparse(dataset_provider.get_uri()) + url_netloc = f"{cred_url.netloc}/{original_url.netloc}" + url_path = original_url.path + cred_url = cred_url._replace(netloc=url_netloc, path=url_path) + path = urlunparse(cred_url) + + with smart_open.open(path, mode=mode, transport_params=dataset_provider.transport_params) as stream: + stream.write(src_file_object.read()) + stream.flush() + + +def set_missing_values(dataset_object: [File, Table], dp_name: str) -> [File, Table]: + """Set missing values for datasets""" + dataset_type = DATASET_NAME_TO_PROVIDER_TYPE[dp_name] + if dataset_type == "database": + dataset_object = set_table_missing_values(table=dataset_object, dataset_name=dp_name) + elif dataset_type == "file": + dataset_object = set_file_missing_values(file=dataset_object, dataset_name=dp_name) + return dataset_object + + +def load_data_in_datasets( + dataset_object: [File, Table], dp: DataProviders, dp_name: str, local_file_path: str +): + """ + Load data in datasets + :param dataset_object: user passed Dataset object + :param dp: DataProviders object created using dataset_object + :param dp_name: name of data_provider class + :param local_file_path: data that needs to be loaded in dataset + """ + dataset_type = DATASET_NAME_TO_PROVIDER_TYPE[dp_name] + if dataset_type == "database": + dp.create_schema_if_needed(dataset_object.metadata.schema) + if local_file_path: + dp.load_file_to_table(File(local_file_path), dataset_object) + elif dataset_type == "file": + if local_file_path: + populate_file(src_file_path=local_file_path, dataset_provider=dp, dp_name=dp_name) + + +def delete_dataset(dataset_object: [File, Table], dp: DataProviders, dp_name: str, local_file_path: str): + """ + Delete dataset + + :param dataset_object: user passed Dataset object + :param dp: DataProviders object created using dataset_object + :param dp_name: name of data_provider class + :param local_file_path: data that needs to be loaded in dataset + """ + dataset_type = DATASET_NAME_TO_PROVIDER_TYPE[dp_name] + if dataset_type == "database": + dp.drop_table(dataset_object) + elif dataset_type == "file" and local_file_path: 
+ try: + dp.delete() + except (FileNotFoundError, NotFound): + pass + + +def dataset_fixture(request): + """ + Fixture to populate data in datasets and fill in missing values. For file dataset we need to pass an "object" + parameter with absolute path of the file. + Example: + @pytest.mark.parametrize( + "src_dataset_fixture", + [ + {"name": "SqliteDataProvider", "local_file_path": f"{str(CWD)}/../../data/sample.csv"}, + { + "name": "S3DataProvider", + "object": File(path=f"s3://tmp9/{create_unique_str(10)}.csv"), + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + } + ], + indirect=True, + ids=lambda dp: dp["name"], + ) + + """ + # We deepcopy the request param dictionary as we modify the table item directly. + params = deepcopy(request.param) + + dp_name = params["name"] + dataset_object = params.get("object", None) + transfer_mode = params.get("transfer_mode", TransferMode.NONNATIVE) + local_file_path = params.get("local_file_path") + + dataset_object = set_missing_values(dataset_object=dataset_object, dp_name=dp_name) + + dp = create_dataprovider(dataset=dataset_object, transfer_mode=transfer_mode) + dp.populate_metadata() + + load_data_in_datasets( + dataset_object=dataset_object, dp=dp, dp_name=dp_name, local_file_path=local_file_path + ) + + yield dp, dataset_object + + delete_dataset(dataset_object=dataset_object, dp=dp, dp_name=dp_name, local_file_path=local_file_path) + + +@pytest.fixture +def src_dataset_fixture(request): + yield from dataset_fixture(request) + + +@pytest.fixture +def dst_dataset_fixture(request): + yield from dataset_fixture(request) diff --git a/universal_transfer_operator/tests/test_data_provider/test_data_provider.py b/universal_transfer_operator/tests/test_data_provider/test_data_provider.py index 4c990afa2..3b107f9bd 100644 --- a/universal_transfer_operator/tests/test_data_provider/test_data_provider.py +++ b/universal_transfer_operator/tests/test_data_provider/test_data_provider.py @@ -1,6 +1,8 @@ import pytest +from universal_transfer_operator.constants import TransferMode from universal_transfer_operator.data_providers import create_dataprovider +from universal_transfer_operator.data_providers.base import DataProviders from universal_transfer_operator.data_providers.database.snowflake import SnowflakeDataProvider from universal_transfer_operator.data_providers.database.sqlite import SqliteDataProvider from universal_transfer_operator.data_providers.filesystem.aws.s3 import S3DataProvider @@ -25,3 +27,47 @@ def test_create_dataprovider(datasets): """Test that the correct data-provider is created for a dataset""" data_provider = create_dataprovider(dataset=datasets["dataset"]) assert isinstance(data_provider, datasets["expected"]) + + +def test_raising_of_NotImplementedError(): + """ + Test that the class inheriting from DataProviders should implement methods + """ + + class Test(DataProviders): + pass + + methods = [ + "check_if_exists", + "read", + "openlineage_dataset_namespace", + "openlineage_dataset_name", + "populate_metadata", + ] + + test = Test(dataset=File("/tmp/test.csv"), transfer_mode=TransferMode.NONNATIVE) + for method in methods: + with pytest.raises(NotImplementedError): + m = test.__getattribute__(method) + m() + + with pytest.raises(NotImplementedError): + test.hook + + +def test_openlineage_dataset_uri(): + """ + Test openlineage_dataset_uri is creating the correct uri + """ + + class Test(DataProviders): + @property + def openlineage_dataset_name(self): + return "/rest/v1/get_user" + + @property + def 
openlineage_dataset_namespace(self): + return "http://localhost:9900" + + test = Test(dataset=File("/tmp/test.csv"), transfer_mode=TransferMode.NONNATIVE) + assert test.openlineage_dataset_uri == "http://localhost:9900/rest/v1/get_user" diff --git a/universal_transfer_operator/tests_integration/test_data_provider/test_databases/test_base.py b/universal_transfer_operator/tests_integration/test_data_provider/test_databases/test_base.py new file mode 100644 index 000000000..a27e15ac3 --- /dev/null +++ b/universal_transfer_operator/tests_integration/test_data_provider/test_databases/test_base.py @@ -0,0 +1,123 @@ +import pathlib +from urllib.parse import urlparse, urlunparse + +import pandas as pd +import pytest +import smart_open +from pyarrow.lib import ArrowInvalid +from utils.test_utils import create_unique_str + +from universal_transfer_operator.datasets.file.base import File +from universal_transfer_operator.datasets.table import Table + +CWD = pathlib.Path(__file__).parent + + +@pytest.mark.parametrize( + "src_dataset_fixture", + [ + {"name": "SqliteDataProvider", "local_file_path": f"{str(CWD)}/../../data/sample.csv"}, + { + "name": "SnowflakeDataProvider", + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + }, + { + "name": "BigqueryDataProvider", + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + }, + { + "name": "S3DataProvider", + "object": File(path=f"s3://tmp9/{create_unique_str(10)}.csv"), + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + }, + { + "name": "GCSDataProvider", + "object": File(path=f"gs://uto-test/{create_unique_str(10)}.csv"), + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + }, + { + "name": "LocalDataProvider", + "object": File(path=f"/tmp/{create_unique_str(10)}.csv"), + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + }, + { + "name": "SFTPDataProvider", + "object": File(path=f"sftp://upload/{create_unique_str(10)}.csv"), + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + }, + ], + indirect=True, + ids=lambda dp: dp["name"], +) +@pytest.mark.parametrize( + "dst_dataset_fixture", + [ + { + "name": "SqliteDataProvider", + }, + { + "name": "BigqueryDataProvider", + }, + { + "name": "SnowflakeDataProvider", + }, + {"name": "S3DataProvider", "object": File(path=f"s3://tmp9/{create_unique_str(10)}")}, + { + "name": "GCSDataProvider", + "object": File(path=f"gs://uto-test/{create_unique_str(10)}"), + }, + {"name": "LocalDataProvider", "object": File(path=f"/tmp/{create_unique_str(10)}")}, + {"name": "SFTPDataProvider", "object": File(path=f"sftp://upload/{create_unique_str(10)}")}, + ], + indirect=True, + ids=lambda dp: dp["name"], +) +def test_read_write_methods_of_datasets(src_dataset_fixture, dst_dataset_fixture): + """ + Test datasets read and write methods of all datasets + """ + src_dp, _ = src_dataset_fixture + dst_dp, _ = dst_dataset_fixture + for source_data in src_dp.read(): + dst_dp.write(source_data) + output_df = export_to_dataframe(dst_dp) + input_df = pd.read_csv(f"{str(CWD)}/../../data/sample.csv") + + assert output_df.equals(input_df) + + +def export_to_dataframe(data_provider) -> pd.DataFrame: + """Read file from all supported location and convert them into dataframes.""" + if isinstance(data_provider.dataset, File): + path = data_provider.dataset.path + # Currently, we are passing the credentials to sftp server via URL - sftp://username:password@localhost, we are + # populating the credentials in the URL if the server destination is SFTP. 
+ if data_provider.dataset.path.startswith("sftp://"): + path = get_complete_url(data_provider) + try: + # Currently, there is a limitation, when we are saving data of a table in a file, we choose rhe parquet + # format, when moving this saved file to another filetype location(like s3/gcs/local) we are not able to + # change the data format, because of this case when validating if the source is a database and + # destination is a filetype, we need to check for parquet format, for other cases like - + # database -> database, filesystem -> database and filesystem -> filesystem it works as expected. + with smart_open.open(path, mode="rb", transport_params=data_provider.transport_params) as stream: + return pd.read_parquet(stream) + except ArrowInvalid: + with smart_open.open(path, mode="r", transport_params=data_provider.transport_params) as stream: + return pd.read_csv(stream) + elif isinstance(data_provider.dataset, Table): + return data_provider.export_table_to_pandas_dataframe() + + +def get_complete_url(dataset_provider): + """ + Add sftp credential to url + """ + path = dataset_provider.dataset.path + original_url = urlparse(path) + cred_url = urlparse(dataset_provider.get_uri()) + url_netloc = f"{cred_url.netloc}/{original_url.netloc}" + url_path = original_url.path + cred_url = cred_url._replace(netloc=url_netloc, path=url_path) + path = urlunparse(cred_url) + return path diff --git a/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_aws/test_s3.py b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_aws/test_s3.py new file mode 100644 index 000000000..041016e15 --- /dev/null +++ b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_aws/test_s3.py @@ -0,0 +1,27 @@ +import pathlib + +import pytest +from utils.test_utils import create_unique_str + +from universal_transfer_operator.datasets.file.base import File + +CWD = pathlib.Path(__file__).parent + + +@pytest.mark.parametrize( + "src_dataset_fixture", + [ + { + "name": "S3DataProvider", + "local_file_path": f"{str(CWD)}/../../../data/sample.csv", + "object": File(path=f"s3://tmp9/{create_unique_str(10)}.csv"), + } + ], + indirect=True, + ids=lambda dp: dp["name"], +) +def test_delete_s3_object(src_dataset_fixture): + dp, _ = src_dataset_fixture + assert dp.check_if_exists() + dp.delete() + assert not dp.check_if_exists() diff --git a/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_google/test_cloud/test_gcs.py b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_google/test_cloud/test_gcs.py new file mode 100644 index 000000000..12fae6085 --- /dev/null +++ b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_google/test_cloud/test_gcs.py @@ -0,0 +1,27 @@ +import pathlib + +import pytest +from utils.test_utils import create_unique_str + +from universal_transfer_operator.datasets.file.base import File + +CWD = pathlib.Path(__file__).parent + + +@pytest.mark.parametrize( + "src_dataset_fixture", + [ + { + "name": "GCSDataProvider", + "local_file_path": f"{str(CWD)}/../../../../data/sample.csv", + "object": File(path=f"gs://uto-test/{create_unique_str(10)}.csv"), + } + ], + indirect=True, + ids=lambda dp: dp["name"], +) +def test_delete_gcs_object(src_dataset_fixture): + dp, _ = src_dataset_fixture + assert dp.check_if_exists() + dp.delete() + assert not dp.check_if_exists() diff --git 
a/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_local.py b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_local.py new file mode 100644 index 000000000..60400990e --- /dev/null +++ b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_local.py @@ -0,0 +1,27 @@ +import pathlib + +import pytest +from utils.test_utils import create_unique_str + +from universal_transfer_operator.datasets.file.base import File + +CWD = pathlib.Path(__file__).parent + + +@pytest.mark.parametrize( + "src_dataset_fixture", + [ + { + "name": "LocalDataProvider", + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + "object": File(path=f"/tmp/{create_unique_str(10)}.csv"), + } + ], + indirect=True, + ids=lambda dp: dp["name"], +) +def test_delete_local_object(src_dataset_fixture): + dp, _ = src_dataset_fixture + assert dp.check_if_exists() + dp.delete() + assert not dp.check_if_exists() diff --git a/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_sftp.py b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_sftp.py new file mode 100644 index 000000000..2a16d0dc2 --- /dev/null +++ b/universal_transfer_operator/tests_integration/test_data_provider/test_filesystem/test_sftp.py @@ -0,0 +1,27 @@ +import pathlib + +import pytest +from utils.test_utils import create_unique_str + +from universal_transfer_operator.datasets.file.base import File + +CWD = pathlib.Path(__file__).parent + + +@pytest.mark.parametrize( + "src_dataset_fixture", + [ + { + "name": "SFTPDataProvider", + "local_file_path": f"{str(CWD)}/../../data/sample.csv", + "object": File(path=f"sftp://upload/{create_unique_str(10)}.csv"), + } + ], + indirect=True, + ids=lambda dp: dp["name"], +) +def test_delete_sftp_object(src_dataset_fixture): + dp, _ = src_dataset_fixture + assert dp.check_if_exists() + dp.delete() + assert not dp.check_if_exists()
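
Taken together, these changes let a transfer be driven end to end through the provider interfaces: populate_metadata() fills in schema/database defaults, read() yields FileStream chunks, and write() loads each chunk into the destination. A minimal sketch of that flow, mirroring test_read_write_methods_of_datasets above; the bucket, table name, and connection ids are illustrative placeholders and assume configured aws_default and sqlite_default connections:

from universal_transfer_operator.constants import FileType, TransferMode
from universal_transfer_operator.data_providers import create_dataprovider
from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.datasets.table import Table

source = File(path="s3://my-bucket/uto/sample.csv", conn_id="aws_default", filetype=FileType.CSV)
destination = Table(name="uto_sample_table", conn_id="sqlite_default")

src_dp = create_dataprovider(dataset=source, transfer_mode=TransferMode.NONNATIVE)
dst_dp = create_dataprovider(dataset=destination, transfer_mode=TransferMode.NONNATIVE)

# Fill in default schema/database values on the destination before loading.
dst_dp.populate_metadata()
dst_dp.create_schema_if_needed(destination.metadata.schema)

# read() yields FileStream objects; write() loads each one into the destination table.
for source_data in src_dp.read():
    dst_dp.write(source_data)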