forked from feast-dev/feast
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request feast-dev#1 from redhatHameed/remote-offline
[WIP] feat: Added offline store Arrow Flight server/client
- Loading branch information
Showing
6 changed files
with
286 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Literal, Optional, Union

import pandas as pd
import pyarrow as pa
import pyarrow.flight
import pyarrow.parquet
from pydantic import StrictStr

from feast import OnDemandFeatureView
from feast.data_source import DataSource
from feast.feature_logging import LoggingConfig, LoggingSource
from feast.feature_view import FeatureView
from feast.infra.offline_stores.offline_store import (
    OfflineStore,
    RetrievalJob,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.registry.registry import Registry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.usage import log_exceptions_and_usage
|
||
|
||
class RemoteOfflineStoreConfig(FeastConfigBaseModel):
    """Configuration for the remote (Arrow Flight) offline store."""

    offline_type: StrictStr = "remote"
    """ str: Provider name or a class name that implements Offline store."""

    path: StrictStr = ""
    """ str: Path to metadata store.
    If offline_type is 'remote', then this is a URL for offline server """

    host: StrictStr = ""
    """ str: host to offline store.
    If offline_type is 'remote', then this is a host URL for offline store of arrow flight server """

    port: StrictStr = ""
    """ str: port of the offline store arrow flight server."""
|
||
|
||
class RemoteRetrievalJob(RetrievalJob):
    """RetrievalJob that delegates ``get_historical_features`` to a remote
    Arrow Flight offline server.

    The job uploads its parameters (entity dataframe and feature refs) via
    ``do_put``, tagged with a unique command id, and later fetches the
    materialized result via ``get_flight_info``/``do_get``.
    """

    def __init__(
        self,
        config: RepoConfig,
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        # TODO add missing parameters from the OfflineStore API
    ):
        # Unique identifier correlating the parameter uploads (do_put)
        # with the later result fetch (do_get) on the server.
        self.command = str(uuid.uuid4())
        # Initialize the client connection to the Arrow Flight server.
        self.client = pa.flight.connect(
            f"grpc://{config.offline_store.host}:{config.offline_store.port}"
        )
        # Put API parameters
        self._put_parameters(feature_refs, entity_df)

    def _upload(self, descriptor, data, param: str):
        """Send one named parameter (a Table or RecordBatch) to the server.

        The parameter name is carried in the schema metadata so the server
        can tell the uploads for the same command apart.
        """
        schema = data.schema.with_metadata(
            {
                "command": self.command,
                "api": "get_historical_features",
                "param": param,
            }
        )
        writer, _ = self.client.do_put(descriptor, schema)
        if isinstance(data, pa.Table):
            writer.write_table(data)
        else:
            writer.write_batch(data)
        writer.close()

    def _put_parameters(self, feature_refs, entity_df):
        # NOTE(review): assumes entity_df is a pandas DataFrame here; the
        # str form of the Union is not yet handled — TODO confirm.
        descriptor = pa.flight.FlightDescriptor.for_command(self.command)
        self._upload(descriptor, pa.Table.from_pandas(entity_df), "entity_df")
        features_batch = pa.RecordBatch.from_arrays(
            [pa.array(feature_refs)], ["features"]
        )
        self._upload(descriptor, features_batch, "features")

    # Invoked to realize the Pandas DataFrame
    def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
        # We use arrow format because it gives better control of the table schema
        return self._to_arrow_internal().to_pandas()

    # Invoked to synchronously execute the underlying query and return the result as an arrow table
    # This is where do_get service is invoked
    def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
        upload_descriptor = pa.flight.FlightDescriptor.for_command(self.command)
        flight = self.client.get_flight_info(upload_descriptor)
        ticket = flight.endpoints[0].ticket

        reader = self.client.do_get(ticket)
        return reader.read_all()

    @property
    def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
        # On-demand feature views are not supported by the remote store yet.
        return []
|
||
|
||
class RemoteOfflineStore(OfflineStore):
    """OfflineStore implementation that forwards queries to a remote
    Arrow Flight offline server (see RemoteRetrievalJob)."""

    def __init__(
        self,
        arrow_host,
        arrow_port,
    ):
        # Host/port of the remote Arrow Flight offline server.
        self.arrow_host = arrow_host
        self.arrow_port = arrow_port

    @log_exceptions_and_usage(offline_store="remote")
    def get_historical_features(
        self,
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: Registry = None,
        project: str = '',
        full_feature_names: bool = False,
    ) -> RemoteRetrievalJob:
        """Build a RemoteRetrievalJob that runs the query on the remote server.

        Fixes over the original: validates `config.offline_store` (the
        attribute `offline_store_config` does not exist on RepoConfig) and
        passes the `config` instance — not the RepoConfig class — to the job.
        """
        assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

        return RemoteRetrievalJob(config, feature_refs, entity_df)

    @log_exceptions_and_usage(offline_store="remote")
    def pull_latest_from_table_or_query(self,
                                        config: RepoConfig,
                                        data_source: DataSource,
                                        join_key_columns: List[str],
                                        feature_name_columns: List[str],
                                        timestamp_field: str,
                                        created_timestamp_column: Optional[str],
                                        start_date: datetime,
                                        end_date: datetime) -> RetrievalJob:
        """ Pulls data from the offline store for use in materialization."""
        print("Pulling latest features from my offline store")
        # TODO: not implemented yet.
        pass

    @staticmethod
    def write_logged_features(
        config: RepoConfig,
        data: Union[pyarrow.Table, Path],
        source: LoggingSource,
        logging_config: LoggingConfig,
        registry: BaseRegistry,
    ):
        """ Optional method to have Feast support logging your online features."""
        # Declared @staticmethod: the signature has no `self`, so as a plain
        # instance method the first positional arg would misbind to `config`.
        # TODO: not implemented yet.
        pass

    @staticmethod
    def offline_write_batch(
        config: RepoConfig,
        feature_view: FeatureView,
        table: pyarrow.Table,
        progress: Optional[Callable[[int], Any]],
    ):
        """Optional method to write a batch of features to the offline store."""
        # TODO: not implemented yet.
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import ast | ||
|
||
import pyarrow as pa | ||
import pyarrow.flight | ||
|
||
from feast import FeatureStore | ||
|
||
|
||
class OfflineServer(pa.flight.FlightServerBase):
    """Arrow Flight server exposing Feast offline-store operations.

    Clients upload query parameters with do_put — distinguished by the
    'param' entry in the uploaded schema metadata — and then retrieve the
    training dataframe via get_flight_info/do_get.
    """

    def __init__(self, location=None, store=None):
        """
        Args:
            location: gRPC URI to bind the server to.
            store: FeatureStore instance used to serve queries. Optional
                (backward-compatible); do_get raises if it is missing.
        """
        super(OfflineServer, self).__init__(location)
        self._location = location
        # Maps descriptor key -> {param name (bytes) -> uploaded table}.
        self.flights = {}
        # Fix: the original stored the FeatureStore *class*, not an instance,
        # so do_get could never run a query.
        self.store = store

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Build a hashable key uniquely identifying a flight descriptor."""
        return (
            descriptor.descriptor_type.value,
            descriptor.command,
            tuple(descriptor.path or tuple()),
        )

    def _make_flight_info(self, key, descriptor, table):
        """Describe `table` as a FlightInfo; ticket is repr(key)."""
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [self._location])]
        # Serialize to a mock sink only to compute the advertised byte size.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(
            table.schema, descriptor, endpoints, table.num_rows, data_size
        )

    def get_flight_info(self, context, descriptor):
        key = OfflineServer.descriptor_to_key(descriptor)
        params = self.flights.get(key)
        if params:
            # Advertise the entity dataframe (or any uploaded table) so the
            # client can size its read.
            table = params.get(b"entity_df")
            if table is None:
                table = next(iter(params.values()))
            return self._make_flight_info(key, descriptor, table)
        raise KeyError("Flight not found.")

    def list_flights(self, context, criteria):
        for key, params in self.flights.items():
            if key[1] is not None:
                descriptor = pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            table = params.get(b"entity_df")
            if table is None:
                table = next(iter(params.values()))
            yield self._make_flight_info(key, descriptor, table)

    def do_put(self, context, descriptor, reader, writer):
        """Store an uploaded table, keyed by descriptor AND the 'param'
        schema-metadata entry.

        Fix: the original keyed only by descriptor, so the second upload for
        the same command ('features') silently overwrote the first
        ('entity_df').
        """
        key = OfflineServer.descriptor_to_key(descriptor)
        metadata = reader.schema.metadata or {}
        param = metadata.get(b"param", b"")
        self.flights.setdefault(key, {})[param] = reader.read_all()

    def do_get(self, context, ticket):
        key = ast.literal_eval(ticket.ticket.decode())
        params = self.flights.get(key)
        if params is None or b"entity_df" not in params:
            return None

        entity_df = params[b"entity_df"].to_pandas()
        # Fix: features were previously looked up under an unrelated
        # hard-coded key that nothing ever wrote; use the 'features' param
        # uploaded alongside the entity dataframe instead.
        features_table = params.get(b"features")
        if features_table is not None:
            features = features_table.column("features").to_pylist()
        else:
            features = None

        if self.store is None:
            raise ValueError(
                "OfflineServer was created without a FeatureStore instance"
            )
        training_df = self.store.get_historical_features(
            entity_df=entity_df, features=features
        ).to_df()
        table = pa.Table.from_pandas(training_df)

        return pa.flight.RecordBatchStream(table)
|
||
|
||
def start_server(
    store: FeatureStore,
    host: str,
    port: int,
):
    """Bind an OfflineServer to host:port and serve requests forever.

    NOTE(review): `store` is currently unused here — presumably it should be
    handed to the OfflineServer; confirm once the server is wired up.
    """
    location = f"grpc+tcp://{host}:{port}"
    flight_server = OfflineServer(location)
    print("Serving on", location)
    flight_server.serve()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters