Skip to content

Commit

Permalink
Implemented PR change proposal
Browse files Browse the repository at this point in the history
Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
  • Loading branch information
tmihalac committed Jun 13, 2024
1 parent bdf0150 commit 17c7e72
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
12 changes: 9 additions & 3 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,20 @@ def _send_retrieve_remote(
return _call_get(client, command_descriptor)


def _call_get(client, command_descriptor):
def _call_get(client: fl.FlightClient, command_descriptor: fl.FlightDescriptor):
flight = client.get_flight_info(command_descriptor)
ticket = flight.endpoints[0].ticket
reader = client.do_get(ticket)
return reader.read_all()


def _call_put(api, api_parameters, client, entity_df, table):
def _call_put(
api: str,
api_parameters: Dict[str, Any],
client: fl.FlightClient,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
):
# Generate unique command identifier
command_id = str(uuid.uuid4())
command = {
Expand All @@ -364,7 +370,7 @@ def _call_put(api, api_parameters, client, entity_df, table):


def _put_parameters(
command_descriptor,
command_descriptor: fl.FlightDescriptor,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
Expand Down
43 changes: 26 additions & 17 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,46 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
self.offline_store = get_offline_store_from_config(store.config.offline_store)

@classmethod
def descriptor_to_key(self, descriptor):
def descriptor_to_key(self, descriptor: fl.FlightDescriptor):
return (
descriptor.descriptor_type.value,
descriptor.command,
tuple(descriptor.path or tuple()),
)

def _make_flight_info(self, key, descriptor, params):
def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor):
endpoints = [fl.FlightEndpoint(repr(key), [self._location])]
# TODO calculate actual schema from the given features
schema = pa.schema([])

return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)

def get_flight_info(self, context, descriptor):
def get_flight_info(
self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
):
key = OfflineServer.descriptor_to_key(descriptor)
if key in self.flights:
params = self.flights[key]
return self._make_flight_info(key, descriptor, params)
return self._make_flight_info(key, descriptor)
raise KeyError("Flight not found.")

def list_flights(self, context, criteria):
def list_flights(self, context: fl.ServerCallContext, criteria: bytes):
for key, table in self.flights.items():
if key[1] is not None:
descriptor = fl.FlightDescriptor.for_command(key[1])
else:
descriptor = fl.FlightDescriptor.for_path(*key[2])

yield self._make_flight_info(key, descriptor, table)
yield self._make_flight_info(key, descriptor)

# Expects to receive request parameters and stores them in the flights dictionary
# Indexed by the unique command
def do_put(self, context, descriptor, reader, writer):
def do_put(
self,
context: fl.ServerCallContext,
descriptor: fl.FlightDescriptor,
reader: fl.MetadataRecordBatchReader,
writer: fl.FlightMetadataWriter,
):
key = OfflineServer.descriptor_to_key(descriptor)
command = json.loads(key[1])
if "api" in command:
Expand All @@ -71,7 +78,7 @@ def do_put(self, context, descriptor, reader, writer):
else:
logger.warning(f"No 'api' field in command: {command}")

def _call_api(self, command, key):
def _call_api(self, command: dict, key: str):
remove_data = False
try:
api = command["api"]
Expand Down Expand Up @@ -145,7 +152,7 @@ def list_feature_views_by_name(

# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
# and returns the stream of data
def do_get(self, context, ticket):
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
key = ast.literal_eval(ticket.ticket.decode())
if key not in self.flights:
logger.error(f"Unknown key {key}")
Expand Down Expand Up @@ -173,7 +180,7 @@ def do_get(self, context, ticket):
del self.flights[key]
return fl.RecordBatchStream(table)

def offline_write_batch(self, command, key):
def offline_write_batch(self, command: dict, key: str):
feature_view_names = command["feature_view_names"]
assert (
len(feature_view_names) == 1
Expand All @@ -193,12 +200,14 @@ def offline_write_batch(self, command, key):
self.store.config, feature_views[0], table, command["progress"]
)

def write_logged_features(self, command, key):
def write_logged_features(self, command: dict, key: str):
table = self.flights[key]
feature_service = self.store.get_feature_service(
command["feature_service_name"]
)

assert feature_service.logging_config is not None

self.offline_store.write_logged_features(
config=self.store.config,
data=table,
Expand All @@ -209,7 +218,7 @@ def write_logged_features(self, command, key):
registry=self.store.registry,
)

def pull_all_from_table_or_query(self, command):
def pull_all_from_table_or_query(self, command: dict):
return self.offline_store.pull_all_from_table_or_query(
self.store.config,
self.store.get_data_source(command["data_source_name"]),
Expand All @@ -220,7 +229,7 @@ def pull_all_from_table_or_query(self, command):
utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
)

def pull_latest_from_table_or_query(self, command):
def pull_latest_from_table_or_query(self, command: dict):
return self.offline_store.pull_latest_from_table_or_query(
self.store.config,
self.store.get_data_source(command["data_source_name"]),
Expand Down Expand Up @@ -249,7 +258,7 @@ def list_actions(self, context):
),
]

def get_historical_features(self, command, key):
def get_historical_features(self, command: dict, key: str):
# Extract parameters from the internal flights dictionary
entity_df_value = self.flights[key]
entity_df = pa.Table.to_pandas(entity_df_value)
Expand All @@ -274,7 +283,7 @@ def get_historical_features(self, command, key):
)
return retJob

def persist(self, retrieve_func, command, key):
def persist(self, retrieve_func: str, command: dict, key: str):
try:
if retrieve_func == OfflineServer.get_historical_features.__name__:
ret_job = self.get_historical_features(command, key)
Expand All @@ -295,7 +304,7 @@ def persist(self, retrieve_func, command, key):
traceback.print_exc()
raise e

def do_action(self, context, action):
def do_action(self, context: fl.ServerCallContext, action: fl.Action):
pass

def do_drop_dataset(self, dataset):
Expand Down

0 comments on commit 17c7e72

Please sign in to comment.