Implement transfer for BigQuery - read/write (#1829)
**Please describe the feature you'd like to see**
- Add `DataProvider` for BigQuery with read/write methods
- Add non-native transfer implementation for GCS to BigQuery (see the usage sketch below)
- Add non-native transfer implementation for S3 to BigQuery
- Add non-native transfer example DAG for BigQuery to SQLite
- Add non-native transfer example DAG for BigQuery to Snowflake
- Add example DAG
- Add tests with 90% coverage
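For reference, a minimal sketch of how the new BigQuery path is exercised, based on the example DAG added in this PR (the DAG wrapper below is illustrative; the conn IDs, bucket path, and table names are the ones used in the example):

```python
from datetime import datetime

from airflow import DAG

from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.datasets.table import Metadata, Table
from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator

# The DAG boilerplate here is illustrative; the example DAG in this PR defines its own id/schedule.
with DAG(dag_id="example_bigquery_transfer", start_date=datetime(2023, 3, 1), schedule=None, catchup=False):
    # Non-native transfer: CSV file in GCS -> BigQuery table in the "astro" dataset.
    UniversalTransferOperator(
        task_id="transfer_non_native_gs_to_bigquery",
        source_dataset=File(path="gs://uto-test/uto/homes_main.csv", conn_id="google_cloud_default"),
        destination_dataset=Table(
            name="uto_gs_to_bigquery_table",
            conn_id="google_cloud_default",
            metadata=Metadata(schema="astro"),
        ),
    )
```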

**Acceptance Criteria**

- [ ] All checks and tests in the CI should pass
- [ ] Unit tests (90% code coverage or more, [once
available](#191))
- [ ] Integration tests (if the feature relates to a new database or
external service)
- [ ] Example DAG
- [ ] Docstrings in [reStructuredText](https://peps.python.org/pep-0287/) for all methods, classes, functions, and module-level attributes (including an example DAG showing how they should be used)
- [ ] Exception handling in case of errors
- [ ] Logging (are we exposing useful information to the user? e.g.
source and destination)
- [ ] Improve the documentation (README, Sphinx, and any other relevant docs)
- [ ] How-to guide for the feature ([example](https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/operators/postgres_operator_howto_guide.html))


closes: #1732 
closes: #1785
closes: #1730

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Utkarsh Sharma <utkarsharma2@gmail.com>
Co-authored-by: Phani Kumar <94376113+phanikumv@users.noreply.github.com>
4 people authored Mar 16, 2023
1 parent bf189de commit d06d3e7
Showing 21 changed files with 825 additions and 41 deletions.
@@ -5,7 +5,7 @@

from universal_transfer_operator.constants import FileType
from universal_transfer_operator.datasets.file.base import File
-from universal_transfer_operator.datasets.table import Table
+from universal_transfer_operator.datasets.table import Metadata, Table
from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator

s3_bucket = os.getenv("S3_BUCKET", "s3://astro-sdk-test")
@@ -35,29 +35,72 @@
transfer_non_native_s3_to_sqlite = UniversalTransferOperator(
    task_id="transfer_non_native_s3_to_sqlite",
    source_dataset=File(path=f"{s3_bucket}/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV),
-    destination_dataset=Table(name="uto_s3_table_1", conn_id="sqlite_default"),
+    destination_dataset=Table(name="uto_s3_to_sqlite_table", conn_id="sqlite_default"),
)

transfer_non_native_gs_to_sqlite = UniversalTransferOperator(
    task_id="transfer_non_native_gs_to_sqlite",
    source_dataset=File(
        path=f"{gcs_bucket}/uto/csv_files/", conn_id="google_cloud_default", filetype=FileType.CSV
    ),
-    destination_dataset=Table(name="uto_gs_table_1", conn_id="sqlite_default"),
+    destination_dataset=Table(name="uto_gs_to_sqlite_table", conn_id="sqlite_default"),
)

transfer_non_native_s3_to_snowflake = UniversalTransferOperator(
    task_id="transfer_non_native_s3_to_snowflake",
    source_dataset=File(
        path="s3://astro-sdk-test/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV
    ),
-    destination_dataset=Table(name="uto_s3_table_2", conn_id="snowflake_conn"),
+    destination_dataset=Table(name="uto_s3_table_to_snowflake", conn_id="snowflake_conn"),
)

transfer_non_native_gs_to_snowflake = UniversalTransferOperator(
    task_id="transfer_non_native_gs_to_snowflake",
    source_dataset=File(
        path="gs://uto-test/uto/csv_files/", conn_id="google_cloud_default", filetype=FileType.CSV
    ),
-    destination_dataset=Table(name="uto_gs_table_2", conn_id="snowflake_conn"),
+    destination_dataset=Table(name="uto_gs_to_snowflake_table", conn_id="snowflake_conn"),
)

+transfer_non_native_gs_to_bigquery = UniversalTransferOperator(
+    task_id="transfer_non_native_gs_to_bigquery",
+    source_dataset=File(path="gs://uto-test/uto/homes_main.csv", conn_id="google_cloud_default"),
+    destination_dataset=Table(
+        name="uto_gs_to_bigquery_table",
+        conn_id="google_cloud_default",
+        metadata=Metadata(schema="astro"),
+    ),
+)
+
+transfer_non_native_s3_to_bigquery = UniversalTransferOperator(
+    task_id="transfer_non_native_s3_to_bigquery",
+    source_dataset=File(
+        path="s3://astro-sdk-test/uto/csv_files/", conn_id="aws_default", filetype=FileType.CSV
+    ),
+    destination_dataset=Table(
+        name="uto_s3_to_bigquery_destination_table",
+        conn_id="google_cloud_default",
+        metadata=Metadata(schema="astro"),
+    ),
+)
+
+transfer_non_native_bigquery_to_snowflake = UniversalTransferOperator(
+    task_id="transfer_non_native_bigquery_to_snowflake",
+    source_dataset=Table(
+        name="uto_s3_to_bigquery_table",
+        conn_id="google_cloud_default",
+        metadata=Metadata(schema="astro"),
+    ),
+    destination_dataset=Table(
+        name="uto_bigquery_to_snowflake_table",
+        conn_id="snowflake_conn",
+    ),
+)
+
+transfer_non_native_bigquery_to_sqlite = UniversalTransferOperator(
+    task_id="transfer_non_native_bigquery_to_sqlite",
+    source_dataset=Table(
+        name="uto_s3_to_bigquery_table", conn_id="google_cloud_default", metadata=Metadata(schema="astro")
+    ),
+    destination_dataset=Table(name="uto_bigquery_to_sqlite_table", conn_id="sqlite_default"),
+)
@@ -14,8 +14,9 @@
("aws", File): "universal_transfer_operator.data_providers.filesystem.aws.s3",
("gs", File): "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs",
("google_cloud_platform", File): "universal_transfer_operator.data_providers.filesystem.google.cloud.gcs",
("sqlite", Table): "universal_transfer_operator.data_providers.database.sqlite",
("sftp", File): "universal_transfer_operator.data_providers.filesystem.sftp",
("google_cloud_platform", Table): "universal_transfer_operator.data_providers.database.google.bigquery",
("gs", Table): "universal_transfer_operator.data_providers.database.google.bigquery",
("sqlite", Table): "universal_transfer_operator.data_providers.database.sqlite",
("snowflake", Table): "universal_transfer_operator.data_providers.database.snowflake",
(None, File): "universal_transfer_operator.data_providers.filesystem.local",
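The dict above maps a `(connection type, dataset class)` pair to the module that hosts the matching `DataProvider`. A hypothetical illustration of the lookup — the real mapping variable and resolver helper are outside this hunk, so the names below are stand-ins:

```python
import importlib

from universal_transfer_operator.datasets.table import Table

# Stand-in for the mapping above, reduced to the new BigQuery entries.
provider_mapping = {
    ("google_cloud_platform", Table): "universal_transfer_operator.data_providers.database.google.bigquery",
    ("gs", Table): "universal_transfer_operator.data_providers.database.google.bigquery",
}

module = importlib.import_module(provider_mapping[("google_cloud_platform", Table)])
print(module.BigqueryDataProvider)  # the provider class added further down in this diff
```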
@@ -104,4 +104,10 @@ def openlineage_dataset_uri(self) -> str:
        Returns the open lineage dataset uri as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        """
        return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}"

+    def populate_metadata(self):
+        """
+        Given a dataset, check if the dataset has metadata.
+        """
+        raise NotImplementedError
@@ -1,5 +1,6 @@
from __future__ import annotations

+from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Callable

import pandas as pd

@@ -27,6 +28,7 @@
from universal_transfer_operator.data_providers.filesystem import resolve_file_path_pattern
from universal_transfer_operator.data_providers.filesystem.base import FileStream
from universal_transfer_operator.datasets.base import Dataset
+from universal_transfer_operator.datasets.dataframe.pandas import PandasDataframe
from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.datasets.table import Metadata, Table
from universal_transfer_operator.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SCHEMA
@@ -186,13 +188,20 @@ def check_if_transfer_supported(self, source_dataset: Table) -> bool:
        return Location(source_connection_type) in self.transfer_mapping

    def read(self):
-        """ ""Read the dataset and write to local reference location"""
-        raise NotImplementedError
+        """Read the dataset and write to local reference location"""
+        with NamedTemporaryFile(mode="wb+", suffix=".parquet", delete=False) as tmp_file:
+            df = self.export_table_to_pandas_dataframe()
+            df.to_parquet(tmp_file.name)
+            local_temp_file = FileStream(
+                remote_obj_buffer=tmp_file.file,
+                actual_filename=tmp_file.name,
+                actual_file=File(path=tmp_file.name),
+            )
+            yield local_temp_file

    def write(self, source_ref: FileStream):
        """
        Write the data from local reference location to the dataset.
        :param source_ref: Stream of data to be loaded into output table.
        """
        return self.load_file_to_table(input_file=source_ref.actual_file, output_table=self.dataset)
@@ -269,7 +278,7 @@ def default_metadata(self) -> Metadata:
        """
        raise NotImplementedError

-    def populate_table_metadata(self, table: Table) -> Table:
+    def populate_metadata(self):
        """
        Given a table, check if the table has metadata.
        If the metadata is missing, and the database has metadata, assign it to the table.
@@ -279,11 +288,11 @@ def populate_table_metadata(self, table: Table) -> Table:
        :param table: Table to potentially have their metadata changed
        :return table: Return the modified table
        """
-        if table.metadata and table.metadata.is_empty() and self.default_metadata:
-            table.metadata = self.default_metadata
-        if not table.metadata.schema:
-            table.metadata.schema = self.DEFAULT_SCHEMA
-        return table

+        if self.dataset.metadata and self.dataset.metadata.is_empty() and self.default_metadata:
+            self.dataset.metadata = self.default_metadata
+        if not self.dataset.metadata.schema:
+            self.dataset.metadata.schema = self.DEFAULT_SCHEMA

    # ---------------------------------------------------------
    # Table creation & deletion methods
@@ -659,3 +668,19 @@ def schema_exists(self, schema: str) -> bool:
        :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc)
        """
        raise NotImplementedError

+    # ---------------------------------------------------------
+    # Extract methods
+    # ---------------------------------------------------------
+
+    def export_table_to_pandas_dataframe(self) -> pd.DataFrame:
+        """
+        Copy the content of a table to an in-memory Pandas dataframe.
+        """
+        if not self.table_exists(self.dataset):
+            raise ValueError(f"The table {self.dataset.name} does not exist")
+
+        sqla_table = self.get_sqla_table(self.dataset)
+        df = pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine)
+        return PandasDataframe.from_pandas_df(df)
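For orientation, the generator-based `read()` above hands the table off as a Parquet-backed `FileStream`, and `write()` loads a `FileStream` into the destination table. A hypothetical sketch of how a non-native transfer could pair the two — the actual orchestration code is not part of this diff:

```python
# Hypothetical composition only; the real transfer integration is implemented elsewhere.
def run_non_native_transfer(source_provider, destination_provider):
    destination_provider.populate_metadata()      # fill in default schema/database if missing
    for file_stream in source_provider.read():    # read() yields FileStream objects (Parquet temp files)
        destination_provider.write(file_stream)   # write() loads FileStream.actual_file into the table
```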
@@ -0,0 +1,187 @@
from __future__ import annotations

import attr
import pandas as pd
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from google.api_core.exceptions import (
    NotFound as GoogleNotFound,
)
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine

from universal_transfer_operator.constants import DEFAULT_CHUNK_SIZE, LoadExistStrategy
from universal_transfer_operator.data_providers.database.base import DatabaseDataProvider
from universal_transfer_operator.datasets.table import Metadata, Table
from universal_transfer_operator.settings import BIGQUERY_SCHEMA, BIGQUERY_SCHEMA_LOCATION
from universal_transfer_operator.universal_transfer_operator import TransferParameters


class BigqueryDataProvider(DatabaseDataProvider):
    """BigqueryDataProvider represents all DataProvider interactions with BigQuery databases."""

    DEFAULT_SCHEMA = BIGQUERY_SCHEMA

    illegal_column_name_chars: list[str] = ["."]
    illegal_column_name_chars_replacement: list[str] = ["_"]

    _create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {} OPTIONS (location='{}')"

    def __init__(
        self,
        dataset: Table,
        transfer_mode,
        transfer_params: TransferParameters = attr.field(
            factory=TransferParameters,
            converter=lambda val: TransferParameters(**val) if isinstance(val, dict) else val,
        ),
    ):
        self.dataset = dataset
        self.transfer_params = transfer_params
        self.transfer_mode = transfer_mode
        self.transfer_mapping = {}
        self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {}
        super().__init__(
            dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params
        )

    def __repr__(self):
        return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id}")'

    @property
    def sql_type(self) -> str:
        return "bigquery"

    @property
    def hook(self) -> BigQueryHook:
        """Retrieve Airflow hook to interface with BigQuery."""
        return BigQueryHook(
            gcp_conn_id=self.dataset.conn_id, use_legacy_sql=False, location=BIGQUERY_SCHEMA_LOCATION
        )

    @property
    def sqlalchemy_engine(self) -> Engine:
        """Return SQLAlchemy engine."""
        uri = self.hook.get_uri()
        with self.hook.provide_gcp_credential_file_as_context():
            return create_engine(uri)

    @property
    def default_metadata(self) -> Metadata:
        """Fill in default metadata values for table objects addressing BigQuery databases."""
        return Metadata(
            schema=self.DEFAULT_SCHEMA,
            database=self.hook.project_id,
        )  # type: ignore

    # ---------------------------------------------------------
    # Table metadata
    # ---------------------------------------------------------

    def schema_exists(self, schema: str) -> bool:
        """
        Check if a dataset exists in BigQuery.
        :param schema: BigQuery namespace
        """
        try:
            self.hook.get_dataset(dataset_id=schema)
        except GoogleNotFound:
            # google.api_core.exceptions throws when a resource is not found
            return False
        return True

    def _get_schema_location(self, schema: str | None = None) -> str:
        """
        Get the region where the schema is created.
        :param schema: BigQuery namespace
        """
        if schema is None:
            return ""
        try:
            dataset = self.hook.get_dataset(dataset_id=schema)
            return str(dataset.location)
        except GoogleNotFound:
            # google.api_core.exceptions throws when a resource is not found
            return ""

    def load_pandas_dataframe_to_table(
        self,
        source_dataframe: pd.DataFrame,
        target_table: Table,
        if_exists: LoadExistStrategy = "replace",
        chunk_size: int = DEFAULT_CHUNK_SIZE,
    ) -> None:
        """
        Create a table with the dataframe's contents.
        If the table already exists, append or replace the content, depending on the value of `if_exists`.
        :param source_dataframe: Local or remote filepath
        :param target_table: Table in which the file will be loaded
        :param if_exists: Strategy to be used in case the target table already exists.
        :param chunk_size: Specify the number of rows in each batch to be written at a time.
        """
        self._assert_not_empty_df(source_dataframe)

        try:
            creds = self.hook._get_credentials()  # skipcq PYL-W021
        except AttributeError:
            # Details: https://github.com/astronomer/astro-sdk/issues/703
            creds = self.hook.get_credentials()
        source_dataframe.to_gbq(
            self.get_table_qualified_name(target_table),
            if_exists=if_exists,
            chunksize=chunk_size,
            project_id=self.hook.project_id,
            credentials=creds,
        )

    def create_schema_if_needed(self, schema: str | None) -> None:
        """
        Check if the expected schema exists in the database; if the schema does not exist,
        attempt to create it.
        :param schema: DB Schema - a namespace that contains named objects like (tables, functions, etc)
        """
        # We check if the schema exists first because BigQuery will fail on a create schema query even if it
        # doesn't actually create a schema.
        if schema and not self.schema_exists(schema):
            table_schema = self.dataset.metadata.schema if self.dataset and self.dataset.metadata else None
            table_location = self._get_schema_location(table_schema)

            location = table_location or BIGQUERY_SCHEMA_LOCATION
            statement = self._create_schema_statement.format(schema, location)
            self.run_sql(statement)

    def truncate_table(self, table):
        """Truncate table."""
        self.run_sql(f"TRUNCATE {self.get_table_qualified_name(table)}")

    @property
    def openlineage_dataset_name(self) -> str:
        """
        Returns the open lineage dataset name as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        Example: db_name.schema_name.table_name
        """
        dataset = self.dataset.metadata.database or self.dataset.metadata.schema
        return f"{self.hook.project_id}.{dataset}.{self.dataset.name}"

    @property
    def openlineage_dataset_namespace(self) -> str:
        """
        Returns the open lineage dataset namespace as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        Example: bigquery
        """
        return self.sql_type

    @property
    def openlineage_dataset_uri(self) -> str:
        """
        Returns the open lineage dataset uri as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        """
        return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}"
@@ -68,10 +68,6 @@ def default_metadata(self) -> Metadata:
            database=connection.database,
        )

-    def read(self):
-        """ ""Read the dataset and write to local reference location"""
-        raise NotImplementedError
-
    # ---------------------------------------------------------
    # Table metadata
    # ---------------------------------------------------------
@@ -256,4 +252,4 @@ def openlineage_dataset_uri(self) -> str:
        Returns the open lineage dataset uri as per
        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
        """
-        return f"{self.openlineage_dataset_namespace()}{self.openlineage_dataset_name()}"
+        return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}"