diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py
index 8aef92abe6..3e1317626a 100644
--- a/sdk/python/feast/infra/offline_stores/bigquery.py
+++ b/sdk/python/feast/infra/offline_stores/bigquery.py
@@ -223,23 +223,8 @@ def to_bigquery(
             job_config = bigquery.QueryJobConfig(destination=path)
 
         if not job_config.dry_run and self.on_demand_feature_views is not None:
-            # It is complicated to get BQ to understand that we want an ARRAY
-            # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
-            # https://github.com/googleapis/python-bigquery/issues/19
-            writer = pyarrow.BufferOutputStream()
-            pyarrow.parquet.write_table(
-                self.to_arrow(), writer, use_compliant_nested_type=True
-            )
-            reader = pyarrow.BufferReader(writer.getvalue())
-
-            parquet_options = bigquery.format_options.ParquetOptions()
-            parquet_options.enable_list_inference = True
-            job_config = bigquery.LoadJobConfig()
-            job_config.source_format = bigquery.SourceFormat.PARQUET
-            job_config.parquet_options = parquet_options
-
-            job = self.client.load_table_from_file(
-                reader, job_config.destination, job_config=job_config,
+            job = _write_pyarrow_table_to_bq(
+                self.client, self.to_arrow(), job_config.destination
             )
             job.result()
             print(f"Done writing to '{job_config.destination}'.")
@@ -344,23 +329,7 @@ def _upload_entity_df_and_get_entity_schema(
     elif isinstance(entity_df, pd.DataFrame):
         # Drop the index so that we dont have unnecessary columns
         entity_df.reset_index(drop=True, inplace=True)
-
-        # Upload the dataframe into BigQuery, creating a temporary table
-        # It is complicated to get BQ to understand that we want an ARRAY
-        # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
-        # https://github.com/googleapis/python-bigquery/issues/19
-        writer = pyarrow.BufferOutputStream()
-        pyarrow.parquet.write_table(
-            pyarrow.Table.from_pandas(entity_df), writer, use_compliant_nested_type=True
-        )
-        reader = pyarrow.BufferReader(writer.getvalue())
-
-        parquet_options = bigquery.format_options.ParquetOptions()
-        parquet_options.enable_list_inference = True
-        job_config = bigquery.LoadJobConfig()
-        job_config.source_format = bigquery.SourceFormat.PARQUET
-        job_config.parquet_options = parquet_options
-        job = client.load_table_from_file(reader, table_name, job_config=job_config)
+        job = _write_df_to_bq(client, entity_df, table_name)
         block_until_done(client, job)
 
     entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))
@@ -395,6 +364,44 @@ def _get_bigquery_client(project: Optional[str] = None):
     return client
 
+
+def _write_df_to_bq(
+    client: bigquery.Client, df: pd.DataFrame, table_name: str
+) -> bigquery.LoadJob:
+    # It is complicated to get BQ to understand that we want an ARRAY
+    # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
+    # https://github.com/googleapis/python-bigquery/issues/19
+    writer = pyarrow.BufferOutputStream()
+    pyarrow.parquet.write_table(
+        pyarrow.Table.from_pandas(df), writer, use_compliant_nested_type=True
+    )
+    return _write_pyarrow_buffer_to_bq(client, writer.getvalue(), table_name,)
+
+
+def _write_pyarrow_table_to_bq(
+    client: bigquery.Client, table: pyarrow.Table, table_name: str
+) -> bigquery.LoadJob:
+    # It is complicated to get BQ to understand that we want an ARRAY
+    # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#parquetoptions
+    # https://github.com/googleapis/python-bigquery/issues/19
+    writer = pyarrow.BufferOutputStream()
+    pyarrow.parquet.write_table(table, writer, use_compliant_nested_type=True)
+    return _write_pyarrow_buffer_to_bq(client, writer.getvalue(), table_name,)
+
+
+def _write_pyarrow_buffer_to_bq(
+    client: bigquery.Client, buf: pyarrow.Buffer, table_name: str
+) -> bigquery.LoadJob:
+    reader = pyarrow.BufferReader(buf)
+
+    parquet_options = bigquery.format_options.ParquetOptions()
+    parquet_options.enable_list_inference = True
+    job_config = bigquery.LoadJobConfig()
+    job_config.source_format = bigquery.SourceFormat.PARQUET
+    job_config.parquet_options = parquet_options
+
+    return client.load_table_from_file(reader, table_name, job_config=job_config,)
+
 
 # TODO: Optimizations
 # * Use GENERATE_UUID() instead of ROW_NUMBER(), or join on entity columns directly
 # * Precompute ROW_NUMBER() so that it doesn't have to be recomputed for every query on entity_dataframe
diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py
index 228e9959d5..46e0535d73 100644
--- a/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py
+++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py
@@ -6,7 +6,10 @@
 
 from feast import BigQuerySource
 from feast.data_source import DataSource
-from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig
+from feast.infra.offline_stores.bigquery import (
+    BigQueryOfflineStoreConfig,
+    _write_df_to_bq,
+)
 from tests.integration.feature_repos.universal.data_source_creator import (
     DataSourceCreator,
 )
@@ -61,15 +64,12 @@ def create_data_source(
 
         self.create_dataset()
 
-        job_config = bigquery.LoadJobConfig()
         if self.gcp_project not in destination_name:
             destination_name = (
                 f"{self.gcp_project}.{self.project_name}.{destination_name}"
             )
 
-        job = self.client.load_table_from_dataframe(
-            df, destination_name, job_config=job_config
-        )
+        job = _write_df_to_bq(self.client, df, destination_name)
         job.result()
 
         self.tables.append(destination_name)
diff --git a/sdk/python/tests/integration/offline_store/test_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_historical_retrieval.py
index 44f9e595e3..2f1377dd7d 100644
--- a/sdk/python/tests/integration/offline_store/test_historical_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_historical_retrieval.py
@@ -19,7 +19,10 @@
 from feast.feature import Feature
 from feast.feature_store import FeatureStore, _validate_feature_refs
 from feast.feature_view import FeatureView
-from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig
+from feast.infra.offline_stores.bigquery import (
+    BigQueryOfflineStoreConfig,
+    _write_df_to_bq,
+)
 from feast.infra.offline_stores.offline_utils import (
     DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
 )
@@ -62,9 +65,8 @@ def stage_driver_hourly_stats_parquet_source(directory, df):
 
 def stage_driver_hourly_stats_bigquery_source(df, table_id):
     client = bigquery.Client()
-    job_config = bigquery.LoadJobConfig()
     df.reset_index(drop=True, inplace=True)
-    job = client.load_table_from_dataframe(df, table_id, job_config=job_config)
+    job = _write_df_to_bq(client, df, table_id)
    job.result()
 
 
@@ -99,9 +101,8 @@ def feature_service(name: str, views) -> FeatureService:
 
 def stage_customer_daily_profile_bigquery_source(df, table_id):
     client = bigquery.Client()
-    job_config = bigquery.LoadJobConfig()
     df.reset_index(drop=True, inplace=True)
-    job = client.load_table_from_dataframe(df, table_id, job_config=job_config)
+    job = _write_df_to_bq(client, df, table_id)
     job.result()
 
 
@@ -231,9 +232,8 @@ def get_expected_training_df(
 
 def stage_orders_bigquery(df, table_id):
     client = bigquery.Client()
-    job_config = bigquery.LoadJobConfig()
     df.reset_index(drop=True, inplace=True)
-    job = client.load_table_from_dataframe(df, table_id, job_config=job_config)
+    job = _write_df_to_bq(client, df, table_id)
     job.result()
 
 
diff --git a/sdk/python/tests/utils/data_source_utils.py b/sdk/python/tests/utils/data_source_utils.py
index 17ab06365e..af9663203a 100644
--- a/sdk/python/tests/utils/data_source_utils.py
+++ b/sdk/python/tests/utils/data_source_utils.py
@@ -7,6 +7,7 @@
 
 from feast import BigQuerySource, FileSource
 from feast.data_format import ParquetFormat
+from feast.infra.offline_stores.bigquery import _write_df_to_bq
 
 
 @contextlib.contextmanager
@@ -38,9 +39,7 @@ def simple_bq_source_using_table_ref_arg(
     client.update_dataset(dataset, ["default_table_expiration_ms"])
 
     table_ref = f"{gcp_project}.{bigquery_dataset}.table_{random.randrange(100, 999)}"
-    job = client.load_table_from_dataframe(
-        df, table_ref, job_config=bigquery.LoadJobConfig()
-    )
+    job = _write_df_to_bq(client, df, table_ref)
     job.result()
 
     return BigQuerySource(
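
Note on the refactor: every DataFrame/Arrow write now goes through one Parquet-buffer load path with compliant nested types and Parquet list inference enabled, which is what lets BigQuery map pyarrow list columns to ARRAY columns, and the test fixtures stage data through the same helper instead of load_table_from_dataframe. A minimal standalone sketch of that load path follows; the function name upload_dataframe and the table id in the usage comment are illustrative only (not part of this patch), and it assumes pyarrow >= 4.0 (use_compliant_nested_type) and a google-cloud-bigquery version that exposes ParquetOptions.

    import pandas as pd
    import pyarrow
    import pyarrow.parquet
    from google.cloud import bigquery


    def upload_dataframe(
        client: bigquery.Client, df: pd.DataFrame, table_name: str
    ) -> None:
        # Serialize to Parquet in memory with compliant nested types so that
        # list-typed columns can round-trip as BigQuery ARRAY columns.
        writer = pyarrow.BufferOutputStream()
        pyarrow.parquet.write_table(
            pyarrow.Table.from_pandas(df), writer, use_compliant_nested_type=True
        )

        # Enable Parquet list inference on the load job, mirroring the helpers above.
        parquet_options = bigquery.format_options.ParquetOptions()
        parquet_options.enable_list_inference = True
        job_config = bigquery.LoadJobConfig()
        job_config.source_format = bigquery.SourceFormat.PARQUET
        job_config.parquet_options = parquet_options

        job = client.load_table_from_file(
            pyarrow.BufferReader(writer.getvalue()), table_name, job_config=job_config
        )
        job.result()  # block until the load job finishes


    # Usage (hypothetical table id):
    # upload_dataframe(bigquery.Client(), df, "my-project.my_dataset.my_table")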