Merge pull request feast-dev#4 from dmartinol/remote_offline
Integrating comments
redhatHameed committed May 13, 2024
2 parents 01fa2f6 + 31d1fe8 commit ec763de
Showing 2 changed files with 45 additions and 71 deletions.
58 changes: 25 additions & 33 deletions sdk/python/feast/infra/offline_stores/remote.py
@@ -1,3 +1,4 @@
+import json
 import uuid
 from datetime import datetime
 from pathlib import Path
@@ -39,45 +40,24 @@ def __init__(
         entity_df: Union[pd.DataFrame, str],
         # TODO add missing parameters from the OfflineStore API
     ):
-        # Generate unique command identifier
-        self.command = str(uuid.uuid4())
         # Initialize the client connection
         self.client = fl.connect(
             f"grpc://{config.offline_store.host}:{config.offline_store.port}"
         )
-        # Put API parameters
-        self._put_parameters(feature_refs, entity_df)
+        self.feature_refs = feature_refs
+        self.entity_df = entity_df

-    def _put_parameters(self, feature_refs, entity_df):
-        historical_flight_descriptor = fl.FlightDescriptor.for_command(self.command)
+    # TODO add one specialized implementation for each OfflineStore API
+    # This can result in a dictionary of functions indexed by api (e.g., "get_historical_features")
+    def _put_parameters(self, command_descriptor):
+        entity_df_table = pa.Table.from_pandas(self.entity_df)

-        entity_df_table = pa.Table.from_pandas(entity_df)
         writer, _ = self.client.do_put(
-            historical_flight_descriptor,
-            entity_df_table.schema.with_metadata(
-                {
-                    "command": self.command,
-                    "api": "get_historical_features",
-                    "param": "entity_df",
-                }
-            ),
+            command_descriptor,
+            entity_df_table.schema,
         )
-        writer.write_table(entity_df_table)
-        writer.close()
-
-        features_array = pa.array(feature_refs)
-        features_batch = pa.RecordBatch.from_arrays([features_array], ["features"])
-        writer, _ = self.client.do_put(
-            historical_flight_descriptor,
-            features_batch.schema.with_metadata(
-                {
-                    "command": self.command,
-                    "api": "get_historical_features",
-                    "param": "features",
-                }
-            ),
-        )
-        writer.write_batch(features_batch)
+        writer.write_table(entity_df_table)
         writer.close()

     # Invoked to realize the Pandas DataFrame
@@ -88,8 +68,21 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
     # Invoked to synchronously execute the underlying query and return the result as an arrow table
     # This is where do_get service is invoked
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
-        upload_descriptor = fl.FlightDescriptor.for_command(self.command)
-        flight = self.client.get_flight_info(upload_descriptor)
+        # Generate unique command identifier
+        command_id = str(uuid.uuid4())
+        command = {
+            "command_id": command_id,
+            "api": "get_historical_features",
+            "features": self.feature_refs,
+        }
+        command_descriptor = fl.FlightDescriptor.for_command(
+            json.dumps(
+                command,
+            )
+        )
+
+        self._put_parameters(command_descriptor)
+        flight = self.client.get_flight_info(command_descriptor)
         ticket = flight.endpoints[0].ticket

         reader = self.client.do_get(ticket)
@@ -112,7 +105,6 @@ def get_historical_features(
         project: str,
         full_feature_names: bool = False,
     ) -> RemoteRetrievalJob:
-        print(f"config.offline_store is {type(config.offline_store)}")
         assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

         # TODO: extend RemoteRetrievalJob API with all method parameters
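Taken together, the client-side change replaces per-parameter schema metadata with a single JSON command embedded in the Flight descriptor: the descriptor now carries the command id, the API name, and the feature references, and do_put uploads only the entity DataFrame. Below is a minimal standalone sketch of that exchange; the server address, feature references, and entity DataFrame are illustrative placeholders, not values from this commit.

# Sketch of the client-side exchange introduced in this diff.
# The endpoint, feature refs, and entity rows are hypothetical examples.
import json
import uuid

import pandas as pd
import pyarrow as pa
import pyarrow.flight as fl

client = fl.connect("grpc://localhost:8815")  # hypothetical server address

# Pack the whole API call into one JSON command, as the new RemoteRetrievalJob does
command = {
    "command_id": str(uuid.uuid4()),
    "api": "get_historical_features",
    "features": ["driver_hourly_stats:conv_rate"],  # illustrative feature refs
}
descriptor = fl.FlightDescriptor.for_command(json.dumps(command))

# Upload only the entity DataFrame; no schema metadata is needed any more
entity_df = pd.DataFrame(
    {"driver_id": [1001], "event_timestamp": [pd.Timestamp.now()]}
)
table = pa.Table.from_pandas(entity_df)
writer, _ = client.do_put(descriptor, table.schema)
writer.write_table(table)
writer.close()

# Retrieve the result: get_flight_info returns a ticket, do_get streams the table
flight_info = client.get_flight_info(descriptor)
ticket = flight_info.endpoints[0].ticket
result = client.do_get(ticket).read_all()
print(result.to_pandas())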
58 changes: 20 additions & 38 deletions sdk/python/feast/offline_server.py
@@ -1,6 +1,7 @@
 import ast
+import json
 import traceback
-from typing import Dict
+from typing import Any, Dict

 import pyarrow as pa
 import pyarrow.flight as fl
@@ -12,7 +13,8 @@ class OfflineServer(fl.FlightServerBase):
     def __init__(self, store: FeatureStore, location: str, **kwargs):
         super(OfflineServer, self).__init__(location, **kwargs)
         self._location = location
-        self.flights: Dict[str, Dict[str, str]] = {}
+        # A dictionary of configured flights, e.g. API calls received and not yet served
+        self.flights: Dict[str, Any] = {}
         self.store = store

     @classmethod
@@ -23,20 +25,12 @@ def descriptor_to_key(self, descriptor):
             tuple(descriptor.path or tuple()),
         )

-    # TODO: since we cannot anticipate here the call to get_historical_features call, what data should we return?
-    # ATM it returns the metadata of the "entity_df" table
     def _make_flight_info(self, key, descriptor, params):
-        table = params["entity_df"]
         endpoints = [fl.FlightEndpoint(repr(key), [self._location])]
-        mock_sink = pa.MockOutputStream()
-        stream_writer = pa.RecordBatchStreamWriter(mock_sink, table.schema)
-        stream_writer.write_table(table)
-        stream_writer.close()
-        data_size = mock_sink.size()
-
-        return fl.FlightInfo(
-            table.schema, descriptor, endpoints, table.num_rows, data_size
-        )
+        # TODO calculate actual schema from the given features
+        schema = pa.schema([])
+
+        return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)

     def get_flight_info(self, context, descriptor):
         key = OfflineServer.descriptor_to_key(descriptor)
@@ -59,23 +53,12 @@ def list_flights(self, context, criteria):
     def do_put(self, context, descriptor, reader, writer):
         key = OfflineServer.descriptor_to_key(descriptor)

-        if key in self.flights:
-            params = self.flights[key]
+        command = json.loads(key[1])
+        if "api" in command:
+            data = reader.read_all()
+            self.flights[key] = data
         else:
-            params = {}
-        decoded_metadata = {
-            key.decode(): value.decode()
-            for key, value in reader.schema.metadata.items()
-        }
-        if "command" in decoded_metadata:
-            command = decoded_metadata["command"]
-            api = decoded_metadata["api"]
-            param = decoded_metadata["param"]
-            value = reader.read_all()
-            # Merge the existing dictionary for the same key, as we have multiple calls to do_put for the same key
-            params.update({"command": command, "api": api, param: value})
-
-        self.flights[key] = params
+            print(f"No 'api' field in command: {command}")

     # Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
     # and returns the stream of data
@@ -85,18 +68,17 @@ def do_get(self, context, ticket):
             print(f"Unknown key {key}")
             return None

-        api = self.flights[key]["api"]
-        # print(f"get key is {key}")
+        command = json.loads(key[1])
+        api = command["api"]
+        # print(f"get command is {command}")
         # print(f"requested api is {api}")
         if api == "get_historical_features":
-            # Extract parameters from the internal flight descriptor
-            entity_df_value = self.flights[key]["entity_df"]
+            # Extract parameters from the internal flights dictionary
+            entity_df_value = self.flights[key]
             entity_df = pa.Table.to_pandas(entity_df_value)
             # print(f"entity_df is {entity_df}")

-            features_value = self.flights[key]["features"]
-            features = pa.RecordBatch.to_pylist(features_value)
-            features = [item["features"] for item in features]
+            features = command["features"]
             # print(f"features is {features}")

             print(
@@ -113,7 +95,7 @@ def do_get(self, context, ticket):
                 traceback.print_exc()
             table = pa.Table.from_pandas(training_df)

-            # Get service is consumed, so we clear the corresponding flight
+            # Get service is consumed, so we clear the corresponding flight and data
            del self.flights[key]

            return fl.RecordBatchStream(table)
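On the server side, the same JSON command travels back inside the descriptor-derived key (as key[1]) and inside the ticket built from repr(key), so do_put and do_get can recover the API name and feature list without the old metadata dictionary; only the uploaded entity table is cached in flights. The sketch below is a stripped-down, self-contained server following the same pattern, assuming an illustrative class name and port; it echoes the uploaded table back where the real OfflineServer delegates to FeatureStore.get_historical_features.

# Minimal sketch of the server-side pattern: parameters ride in the JSON
# command, only the uploaded table is stored. Echoes data instead of
# calling Feast; class name and port are hypothetical.
import ast
import json
from typing import Any, Dict

import pyarrow as pa
import pyarrow.flight as fl


class MiniOfflineServer(fl.FlightServerBase):
    def __init__(self, location: str = "grpc://0.0.0.0:8815", **kwargs):
        super().__init__(location, **kwargs)
        self._location = location
        # Uploaded tables, keyed by the descriptor-derived key
        self.flights: Dict[Any, pa.Table] = {}

    @classmethod
    def descriptor_to_key(cls, descriptor):
        # key[1] holds the raw JSON command bytes from FlightDescriptor.for_command
        return (
            descriptor.descriptor_type.value,
            descriptor.command,
            tuple(descriptor.path or tuple()),
        )

    def get_flight_info(self, context, descriptor):
        key = MiniOfflineServer.descriptor_to_key(descriptor)
        # The ticket is repr(key), so the key round-trips through do_get
        endpoints = [fl.FlightEndpoint(repr(key), [self._location])]
        return fl.FlightInfo(pa.schema([]), descriptor, endpoints, -1, -1)

    def do_put(self, context, descriptor, reader, writer):
        key = MiniOfflineServer.descriptor_to_key(descriptor)
        command = json.loads(key[1])
        if "api" in command:
            # Store only the uploaded data; the parameters live in the key itself
            self.flights[key] = reader.read_all()

    def do_get(self, context, ticket):
        # Parse the repr(key) ticket back into a tuple
        key = ast.literal_eval(ticket.ticket.decode())
        if key not in self.flights:
            return None
        command = json.loads(key[1])
        table = self.flights[key]
        if command["api"] == "get_historical_features":
            del self.flights[key]  # the flight is consumed once served
            # A real server would call the FeatureStore here instead of echoing
            return fl.RecordBatchStream(table)
        return None


if __name__ == "__main__":
    MiniOfflineServer().serve()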
