Skip to content

Commit

Permalink
Fix inference of BigQuery ARRAY types. (#2245)
Browse files Browse the repository at this point in the history
* Support more BigQuery ARRAY types

Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>

* Correctly infer BigQuery ARRAY types

Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>
  • Loading branch information
judahrand committed Feb 1, 2022
1 parent 2080fa3 commit 7c53177
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
20 changes: 11 additions & 9 deletions sdk/python/feast/infra/offline_stores/bigquery_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Iterable, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple

from feast import type_map
from feast.data_source import DataSource
Expand Down Expand Up @@ -123,18 +123,20 @@ def get_table_column_names_and_types(

client = bigquery.Client()
if self.table_ref is not None:
table_schema = client.get_table(self.table_ref).schema
if not isinstance(table_schema[0], bigquery.schema.SchemaField):
schema = client.get_table(self.table_ref).schema
if not isinstance(schema[0], bigquery.schema.SchemaField):
raise TypeError("Could not parse BigQuery table schema.")

name_type_pairs = [(field.name, field.field_type) for field in table_schema]
else:
bq_columns_query = f"SELECT * FROM ({self.query}) LIMIT 1"
queryRes = client.query(bq_columns_query).result()
name_type_pairs = [
(schema_field.name, schema_field.field_type)
for schema_field in queryRes.schema
]
schema = queryRes.schema

name_type_pairs: List[Tuple[str, str]] = []
for field in schema:
bq_type_as_str = field.field_type
if field.mode == "REPEATED":
bq_type_as_str = "ARRAY<" + bq_type_as_str + ">"
name_type_pairs.append((field.name, bq_type_as_str))

return name_type_pairs

Expand Down
16 changes: 10 additions & 6 deletions sdk/python/feast/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,11 @@ def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType:


def bq_to_feast_value_type(bq_type_as_str: str) -> ValueType:
is_list = False
if bq_type_as_str.startswith("ARRAY<"):
is_list = True
bq_type_as_str = bq_type_as_str[6:-1]

type_map: Dict[str, ValueType] = {
"DATETIME": ValueType.UNIX_TIMESTAMP,
"TIMESTAMP": ValueType.UNIX_TIMESTAMP,
Expand All @@ -453,15 +458,14 @@ def bq_to_feast_value_type(bq_type_as_str: str) -> ValueType:
"BYTES": ValueType.BYTES,
"BOOL": ValueType.BOOL,
"BOOLEAN": ValueType.BOOL, # legacy sql data type
"ARRAY<INT64>": ValueType.INT64_LIST,
"ARRAY<FLOAT64>": ValueType.DOUBLE_LIST,
"ARRAY<STRING>": ValueType.STRING_LIST,
"ARRAY<BYTES>": ValueType.BYTES_LIST,
"ARRAY<BOOL>": ValueType.BOOL_LIST,
"NULL": ValueType.NULL,
}

return type_map[bq_type_as_str]
value_type = type_map[bq_type_as_str]
if is_list:
value_type = ValueType[value_type.name + "_LIST"]

return value_type


def redshift_to_feast_value_type(redshift_type_as_str: str) -> ValueType:
Expand Down

0 comments on commit 7c53177

Please sign in to comment.