-
Notifications
You must be signed in to change notification settings - Fork 998
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add snowflake online store (#2902)
* feat: Add snowflake online store
  Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
* lint/format
  Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
* removing missing testing env variables
  Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
* test offline store first
  Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
* snowflake online test fixes
  Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
* format
  Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
* fix snowflake testing (#2903)
  Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
- Loading branch information
1 parent
0ceb39c
commit f758f9e
Showing
3 changed files
with
423 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
import itertools | ||
import os | ||
from binascii import hexlify | ||
from datetime import datetime | ||
from pathlib import Path | ||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple | ||
|
||
import pandas as pd | ||
import pytz | ||
from pydantic import Field | ||
from pydantic.schema import Literal | ||
|
||
from feast import Entity, FeatureView | ||
from feast.infra.key_encoding_utils import serialize_entity_key | ||
from feast.infra.online_stores.online_store import OnlineStore | ||
from feast.infra.utils.snowflake_utils import get_snowflake_conn, write_pandas_binary | ||
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto | ||
from feast.protos.feast.types.Value_pb2 import Value as ValueProto | ||
from feast.repo_config import FeastConfigBaseModel, RepoConfig | ||
from feast.usage import log_exceptions_and_usage | ||
|
||
|
||
class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
    """Online store config for Snowflake.

    Every connection field is optional and defaults to ``None`` except
    ``config_path`` (a snowsql config file under the user's home directory)
    and ``schema_`` (``"PUBLIC"``).
    """

    type: Literal["snowflake.online"] = "snowflake.online"
    """ Online store type selector"""

    # The default is evaluated once, when the class body runs.
    # ``os.path.expanduser`` is used instead of ``os.environ["HOME"]`` so that
    # importing this module does not raise ``KeyError`` on platforms where
    # ``HOME`` is not set (e.g. Windows); the result is identical when it is.
    config_path: Optional[str] = os.path.join(
        os.path.expanduser("~"), ".snowsql", "config"
    )
    """ Snowflake config path -- absolute path required (Can't use ~)"""

    account: Optional[str] = None
    """ Snowflake deployment identifier -- drop .snowflakecomputing.com"""

    user: Optional[str] = None
    """ Snowflake user name """

    password: Optional[str] = None
    """ Snowflake password """

    role: Optional[str] = None
    """ Snowflake role name"""

    warehouse: Optional[str] = None
    """ Snowflake warehouse name """

    database: Optional[str] = None
    """ Snowflake database name """

    # Declared as ``schema_`` with alias ``schema``; ``Config`` below lets
    # callers populate it by either name.
    schema_: Optional[str] = Field("PUBLIC", alias="schema")
    """ Snowflake schema name """

    class Config:
        allow_population_by_field_name = True
|
||
|
||
class SnowflakeOnlineStore(OnlineStore):
    """Online store that keeps feature values in Snowflake tables.

    Each feature view gets its own table, addressed as
    ``"<database>"."<schema>"."<project>_<feature view name>"``, holding one
    row per (entity key, feature name) pair with the serialized value and its
    timestamps.
    """

    @staticmethod
    def _table_identifier(config: RepoConfig, table: FeatureView) -> str:
        """Return the quoted, fully qualified name of ``table``'s online table."""
        online_config = config.online_store
        return (
            f'"{online_config.database}"."{online_config.schema_}"'
            f'."{config.project}_{table.name}"'
        )

    @log_exceptions_and_usage(online_store="snowflake")
    def online_write_batch(
        self,
        config: RepoConfig,
        table: FeatureView,
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """Write a batch of feature rows to the feature view's online table.

        Args:
            config: Repo config; ``config.online_store`` must be a
                ``SnowflakeOnlineStoreConfig``.
            table: Feature view whose online table is written.
            data: One tuple per entity: entity key, mapping of feature name to
                value, event timestamp and optional created timestamp.
            progress: Optional callback, invoked with ``1`` after each entity
                has been staged.

        Nothing is sent to Snowflake when ``data`` is empty.
        """
        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

        dfs = []
        for entity_key, values, timestamp, created_ts in data:

            df = pd.DataFrame(
                columns=[
                    "entity_feature_key",
                    "entity_key",
                    "feature_name",
                    "value",
                    "event_ts",
                    "created_ts",
                ],
                index=range(0, len(values)),
            )

            # Timestamps are stored timezone-naive, normalized to UTC.
            timestamp = _to_naive_utc(timestamp)
            if created_ts is not None:
                created_ts = _to_naive_utc(created_ts)

            # Serialize the entity key once per entity instead of twice per
            # feature; the bytes are identical for every feature of the row.
            entity_key_bin = serialize_entity_key(entity_key)

            for j, (feature_name, val) in enumerate(values.items()):
                # Lookup key used by online_read: entity key bytes followed by
                # the UTF-8 feature name.
                df.loc[j, "entity_feature_key"] = entity_key_bin + bytes(
                    feature_name, encoding="utf-8"
                )
                df.loc[j, "entity_key"] = entity_key_bin
                df.loc[j, "feature_name"] = feature_name
                df.loc[j, "value"] = val.SerializeToString()
                df.loc[j, "event_ts"] = timestamp
                df.loc[j, "created_ts"] = created_ts

            dfs.append(df)
            if progress:
                progress(1)

        if not dfs:
            return None

        agg_df = pd.concat(dfs)
        table_path = self._table_identifier(config, table)

        # NOTE(review): autocommit is disabled and there is no explicit commit
        # here -- presumably the connection's context manager commits on a
        # clean exit; confirm against get_snowflake_conn.
        with get_snowflake_conn(config.online_store, autocommit=False) as conn:

            write_pandas_binary(conn, agg_df, f"{config.project}_{table.name}")

            # Rewrite the table keeping only the newest row for every
            # (entity_key, feature_name) pair, newest by event_ts and then
            # created_ts.
            query = f"""
                INSERT OVERWRITE INTO {table_path}
                    SELECT
                        "entity_feature_key",
                        "entity_key",
                        "feature_name",
                        "value",
                        "event_ts",
                        "created_ts"
                    FROM
                      (SELECT
                          *,
                          ROW_NUMBER() OVER(PARTITION BY "entity_key","feature_name" ORDER BY "event_ts" DESC, "created_ts" DESC) AS "_feast_row"
                      FROM
                          {table_path})
                    WHERE
                        "_feast_row" = 1;
            """

            conn.cursor().execute(query)

        return None

    @log_exceptions_and_usage(online_store="snowflake")
    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: List[str],
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read the requested features for each entity key.

        Args:
            config: Repo config; ``config.online_store`` must be a
                ``SnowflakeOnlineStoreConfig``.
            table: Feature view whose online table is queried.
            entity_keys: Entity keys to look up.
            requested_features: Names of the features to fetch.

        Returns:
            One ``(event timestamp, {feature name: value})`` tuple per entry of
            ``entity_keys``, in the same order; ``(None, None)`` for an entity
            with no matching rows.
        """
        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

        # With nothing to look up the IN (...) list would be empty, which is
        # not valid SQL; no row could match anyway, so skip the round trip.
        if not entity_keys or not requested_features:
            return [(None, None) for _ in entity_keys]

        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

        # Serialize each key once; reused for the query and for matching rows.
        entity_key_bins = [serialize_entity_key(key) for key in entity_keys]

        # One hex-encoded TO_BINARY('...') literal per (entity, feature) pair,
        # mirroring how "entity_feature_key" is built on write.
        lookup_keys = ",".join(
            "TO_BINARY('"
            + hexlify(key_bin + bytes(feature_name, encoding="utf-8")).decode("ascii")
            + "')"
            for key_bin, feature_name in itertools.product(
                entity_key_bins, requested_features
            )
        )

        with get_snowflake_conn(config.online_store) as conn:

            df = (
                conn.cursor()
                .execute(
                    f"""
                    SELECT
                        "entity_key", "feature_name", "value", "event_ts"
                    FROM
                        {self._table_identifier(config, table)}
                    WHERE
                        "entity_feature_key" IN ({lookup_keys})
                    """,
                )
                .fetch_pandas_all()
            )

        for entity_key_bin in entity_key_bins:
            res = {}
            res_ts = None
            for _, row in df[df["entity_key"] == entity_key_bin].iterrows():
                val = ValueProto()
                val.ParseFromString(row["value"])
                res[row["feature_name"]] = val
                # The timestamp of the last row iterated for this entity wins.
                res_ts = row["event_ts"].to_pydatetime()

            if not res:
                result.append((None, None))
            else:
                result.append((res_ts, res))
        return result

    @log_exceptions_and_usage(online_store="snowflake")
    def update(
        self,
        config: RepoConfig,
        tables_to_delete: Sequence[FeatureView],
        tables_to_keep: Sequence[FeatureView],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
        partial: bool,
    ):
        """Create missing online tables and drop the ones no longer needed.

        Args:
            config: Repo config; ``config.online_store`` must be a
                ``SnowflakeOnlineStoreConfig``.
            tables_to_delete: Feature views whose tables are dropped.
            tables_to_keep: Feature views whose tables are created if absent.
            entities_to_delete: Not used by this store.
            entities_to_keep: Not used by this store.
            partial: Not used by this store.
        """
        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

        with get_snowflake_conn(config.online_store) as conn:

            for table in tables_to_keep:

                conn.cursor().execute(
                    f"""CREATE TABLE IF NOT EXISTS {self._table_identifier(config, table)} (
                        "entity_feature_key" BINARY,
                        "entity_key" BINARY,
                        "feature_name" VARCHAR,
                        "value" BINARY,
                        "event_ts" TIMESTAMP,
                        "created_ts" TIMESTAMP
                        )"""
                )

            for table in tables_to_delete:

                conn.cursor().execute(
                    f"DROP TABLE IF EXISTS {self._table_identifier(config, table)}"
                )

    def teardown(
        self,
        config: RepoConfig,
        tables: Sequence[FeatureView],
        entities: Sequence[Entity],
    ):
        """Drop the online table of every feature view in ``tables``.

        Args:
            config: Repo config; ``config.online_store`` must be a
                ``SnowflakeOnlineStoreConfig``.
            tables: Feature views whose tables are dropped.
            entities: Not used by this store.
        """
        assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

        with get_snowflake_conn(config.online_store) as conn:

            for table in tables:
                query = f"DROP TABLE IF EXISTS {self._table_identifier(config, table)}"
                conn.cursor().execute(query)
|
||
|
||
def _to_naive_utc(ts: datetime): | ||
if ts.tzinfo is None: | ||
return ts | ||
else: | ||
return ts.astimezone(pytz.utc).replace(tzinfo=None) |
Oops, something went wrong.