Upload all data to BQ in an ARRAY-safe manner
Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>
judahrand committed Sep 20, 2021
1 parent 81c8abc commit db6f607
Showing 4 changed files with 55 additions and 49 deletions.
75 changes: 41 additions & 34 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -223,23 +223,8 @@ def to_bigquery(
             job_config = bigquery.QueryJobConfig(destination=path)
 
         if not job_config.dry_run and self.on_demand_feature_views is not None:
-            # It is complicated to get BQ to understand that we want an ARRAY<value_type>
-            # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
-            # https://github.com/googleapis/python-bigquery/issues/19
-            writer = pyarrow.BufferOutputStream()
-            pyarrow.parquet.write_table(
-                self.to_arrow(), writer, use_compliant_nested_type=True
-            )
-            reader = pyarrow.BufferReader(writer.getvalue())
-
-            parquet_options = bigquery.format_options.ParquetOptions()
-            parquet_options.enable_list_inference = True
-            job_config = bigquery.LoadJobConfig()
-            job_config.source_format = bigquery.SourceFormat.PARQUET
-            job_config.parquet_options = parquet_options
-
-            job = self.client.load_table_from_file(
-                reader, job_config.destination, job_config=job_config,
+            job = _write_pyarrow_table_to_bq(
+                self.client, self.to_arrow(), job_config.destination
             )
             job.result()
             print(f"Done writing to '{job_config.destination}'.")
@@ -344,23 +329,7 @@ def _upload_entity_df_and_get_entity_schema(
     elif isinstance(entity_df, pd.DataFrame):
         # Drop the index so that we dont have unnecessary columns
         entity_df.reset_index(drop=True, inplace=True)
-
-        # Upload the dataframe into BigQuery, creating a temporary table
-        # It is complicated to get BQ to understand that we want an ARRAY<value_type>
-        # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
-        # https://github.com/googleapis/python-bigquery/issues/19
-        writer = pyarrow.BufferOutputStream()
-        pyarrow.parquet.write_table(
-            pyarrow.Table.from_pandas(entity_df), writer, use_compliant_nested_type=True
-        )
-        reader = pyarrow.BufferReader(writer.getvalue())
-
-        parquet_options = bigquery.format_options.ParquetOptions()
-        parquet_options.enable_list_inference = True
-        job_config = bigquery.LoadJobConfig()
-        job_config.source_format = bigquery.SourceFormat.PARQUET
-        job_config.parquet_options = parquet_options
-        job = client.load_table_from_file(reader, table_name, job_config=job_config)
+        job = _write_df_to_bq(client, entity_df, table_name)
         block_until_done(client, job)
 
         entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))
@@ -395,6 +364,44 @@ def _get_bigquery_client(project: Optional[str] = None):
     return client
 
 
+def _write_df_to_bq(
+    client: bigquery.Client, df: pd.DataFrame, table_name: str
+) -> bigquery.LoadJob:
+    # It is complicated to get BQ to understand that we want an ARRAY<value_type>
+    # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
+    # https://github.com/googleapis/python-bigquery/issues/19
+    writer = pyarrow.BufferOutputStream()
+    pyarrow.parquet.write_table(
+        pyarrow.Table.from_pandas(df), writer, use_compliant_nested_type=True
+    )
+    return _write_pyarrow_buffer_to_bq(client, writer.getvalue(), table_name,)
+
+
+def _write_pyarrow_table_to_bq(
+    client: bigquery.Client, table: pyarrow.Table, table_name: str
+) -> bigquery.LoadJob:
+    # It is complicated to get BQ to understand that we want an ARRAY<value_type>
+    # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
+    # https://github.com/googleapis/python-bigquery/issues/19
+    writer = pyarrow.BufferOutputStream()
+    pyarrow.parquet.write_table(table, writer, use_compliant_nested_type=True)
+    return _write_pyarrow_buffer_to_bq(client, writer.getvalue(), table_name,)
+
+
+def _write_pyarrow_buffer_to_bq(
+    client: bigquery.Client, buf: pyarrow.Buffer, table_name: str
+) -> bigquery.LoadJob:
+    reader = pyarrow.BufferReader(buf)
+
+    parquet_options = bigquery.format_options.ParquetOptions()
+    parquet_options.enable_list_inference = True
+    job_config = bigquery.LoadJobConfig()
+    job_config.source_format = bigquery.SourceFormat.PARQUET
+    job_config.parquet_options = parquet_options
+
+    return client.load_table_from_file(reader, table_name, job_config=job_config,)
+
+
 # TODO: Optimizations
 # * Use GENERATE_UUID() instead of ROW_NUMBER(), or join on entity columns directly
 # * Precompute ROW_NUMBER() so that it doesn't have to be recomputed for every query on entity_dataframe
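For context on the new helpers above: BigQuery only maps a Parquet list column to ARRAY<value_type> when the file uses the compliant nested layout and the load job enables list inference (see the links in the code comments). The standalone snippet below is not part of the commit; it is a minimal sketch of that load path, assuming the google-cloud-bigquery and pyarrow packages, working GCP credentials, and a hypothetical, writable table ID "your-project.your_dataset.array_demo".

# Minimal sketch (not part of the commit) of the ARRAY-safe Parquet load path.
# "your-project.your_dataset.array_demo" is a hypothetical destination table ID.
import pandas as pd
import pyarrow
import pyarrow.parquet
from google.cloud import bigquery

client = bigquery.Client()
table_id = "your-project.your_dataset.array_demo"

# A DataFrame with a list-valued column, which should land in BQ as ARRAY<INT64>.
df = pd.DataFrame({"driver_id": [1001, 1002], "trip_ids": [[1, 2, 3], [4, 5]]})

# Serialize to Parquet in memory using the compliant nested (list) layout.
writer = pyarrow.BufferOutputStream()
pyarrow.parquet.write_table(
    pyarrow.Table.from_pandas(df), writer, use_compliant_nested_type=True
)
reader = pyarrow.BufferReader(writer.getvalue())

# Tell BigQuery to interpret compliant Parquet lists as REPEATED fields.
parquet_options = bigquery.format_options.ParquetOptions()
parquet_options.enable_list_inference = True
job_config = bigquery.LoadJobConfig()
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.parquet_options = parquet_options

client.load_table_from_file(reader, table_id, job_config=job_config).result()

# With list inference enabled, trip_ids should show mode=REPEATED rather than a
# nested STRUCT wrapper around the list items.
for field in client.get_table(table_id).schema:
    print(field.name, field.field_type, field.mode)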
@@ -6,7 +6,10 @@
 
 from feast import BigQuerySource
 from feast.data_source import DataSource
-from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig
+from feast.infra.offline_stores.bigquery import (
+    BigQueryOfflineStoreConfig,
+    _write_df_to_bq,
+)
 from tests.integration.feature_repos.universal.data_source_creator import (
     DataSourceCreator,
 )
@@ -61,15 +64,12 @@ def create_data_source(
 
         self.create_dataset()
 
-        job_config = bigquery.LoadJobConfig()
         if self.gcp_project not in destination_name:
             destination_name = (
                 f"{self.gcp_project}.{self.project_name}.{destination_name}"
             )
 
-        job = self.client.load_table_from_dataframe(
-            df, destination_name, job_config=job_config
-        )
+        job = _write_df_to_bq(self.client, df, destination_name)
         job.result()
 
         self.tables.append(destination_name)
@@ -19,7 +19,10 @@
 from feast.feature import Feature
 from feast.feature_store import FeatureStore, _validate_feature_refs
 from feast.feature_view import FeatureView
-from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig
+from feast.infra.offline_stores.bigquery import (
+    BigQueryOfflineStoreConfig,
+    _write_df_to_bq,
+)
 from feast.infra.offline_stores.offline_utils import (
     DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
 )
@@ -62,9 +65,8 @@ def stage_driver_hourly_stats_parquet_source(directory, df):
 
 def stage_driver_hourly_stats_bigquery_source(df, table_id):
     client = bigquery.Client()
-    job_config = bigquery.LoadJobConfig()
     df.reset_index(drop=True, inplace=True)
-    job = client.load_table_from_dataframe(df, table_id, job_config=job_config)
+    job = _write_df_to_bq(client, df, table_id)
     job.result()
 
 
@@ -99,9 +101,8 @@ def feature_service(name: str, views) -> FeatureService:
 
 def stage_customer_daily_profile_bigquery_source(df, table_id):
     client = bigquery.Client()
-    job_config = bigquery.LoadJobConfig()
    df.reset_index(drop=True, inplace=True)
-    job = client.load_table_from_dataframe(df, table_id, job_config=job_config)
+    job = _write_df_to_bq(client, df, table_id)
     job.result()
 
 
@@ -231,9 +232,8 @@ def get_expected_training_df(
 
 def stage_orders_bigquery(df, table_id):
     client = bigquery.Client()
-    job_config = bigquery.LoadJobConfig()
     df.reset_index(drop=True, inplace=True)
-    job = client.load_table_from_dataframe(df, table_id, job_config=job_config)
+    job = _write_df_to_bq(client, df, table_id)
     job.result()
 
 
5 changes: 2 additions & 3 deletions sdk/python/tests/utils/data_source_utils.py
@@ -7,6 +7,7 @@
 
 from feast import BigQuerySource, FileSource
 from feast.data_format import ParquetFormat
+from feast.infra.offline_stores.bigquery import _write_df_to_bq
 
 
 @contextlib.contextmanager
@@ -38,9 +39,7 @@ def simple_bq_source_using_table_ref_arg(
     client.update_dataset(dataset, ["default_table_expiration_ms"])
     table_ref = f"{gcp_project}.{bigquery_dataset}.table_{random.randrange(100, 999)}"
 
-    job = client.load_table_from_dataframe(
-        df, table_ref, job_config=bigquery.LoadJobConfig()
-    )
+    job = _write_df_to_bq(client, df, table_ref)
     job.result()
 
     return BigQuerySource(
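All of the test changes above funnel through the same helper. The snippet below is a minimal usage sketch, not part of the commit: it assumes the private _write_df_to_bq helper remains importable from feast.infra.offline_stores.bigquery and uses a hypothetical destination table ID.

# Minimal usage sketch of the internal helper the tests now use to stage data.
# "my-project.my_dataset.orders" is a hypothetical destination; _write_df_to_bq
# is a private Feast helper and may change between versions.
import pandas as pd
from google.cloud import bigquery

from feast.infra.offline_stores.bigquery import _write_df_to_bq

client = bigquery.Client()
df = pd.DataFrame({"order_id": [1, 2], "item_ids": [[10, 11], [12]]})
df.reset_index(drop=True, inplace=True)  # callers drop the index before upload

job = _write_df_to_bq(client, df, "my-project.my_dataset.orders")
job.result()  # block until the BigQuery load job completes

Routing every upload through one helper keeps the Parquet/list-inference workaround in a single place, so list-valued columns land in BigQuery as ARRAY columns in the tests as well as in the offline store itself.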
