Skip to content

Commit

Permalink
Implemented PR change proposal
Browse files Browse the repository at this point in the history
Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
Signed-off-by: Abdul Hameed <ahameed@redhat.com>
  • Loading branch information
tmihalac authored and redhatHameed committed Jun 10, 2024
1 parent a585d16 commit 7f64708
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 50 deletions.
28 changes: 8 additions & 20 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,15 @@ def persist(
for key, value in self.api_parameters.items():
api_parameters[key] = value

command_descriptor = _call_put(
api=self.api,
api_parameters["retrieve_func"] = self.api

_call_put(
api=RemoteRetrievalJob.persist.__name__,
api_parameters=api_parameters,
client=self.client,
table=self.table,
entity_df=self.entity_df,
)
bytes = command_descriptor.serialize()

self.client.do_action(
pa.flight.Action(RemoteRetrievalJob.persist.__name__, bytes)
)


class RemoteOfflineStore(OfflineStore):
Expand Down Expand Up @@ -236,18 +233,13 @@ def write_logged_features(
"feature_service_name": source._feature_service.name,
}

api_name = OfflineStore.write_logged_features.__name__

command_descriptor = _call_put(
api=api_name,
_call_put(
api=OfflineStore.write_logged_features.__name__,
api_parameters=api_parameters,
client=client,
table=data,
entity_df=None,
)
bytes = command_descriptor.serialize()

client.do_action(pa.flight.Action(api_name, bytes))

@staticmethod
def offline_write_batch(
Expand All @@ -270,17 +262,13 @@ def offline_write_batch(
"name_aliases": name_aliases,
}

api_name = OfflineStore.offline_write_batch.__name__
command_descriptor = _call_put(
api=api_name,
_call_put(
api=OfflineStore.offline_write_batch.__name__,
api_parameters=api_parameters,
client=client,
table=table,
entity_df=None,
)
bytes = command_descriptor.serialize()

client.do_action(pa.flight.Action(api_name, bytes))

@staticmethod
def init_client(config):
Expand Down
65 changes: 35 additions & 30 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,39 @@ def list_flights(self, context, criteria):
# Indexed by the unique command
def do_put(self, context, descriptor, reader, writer):
key = OfflineServer.descriptor_to_key(descriptor)

command = json.loads(key[1])
if "api" in command:
data = reader.read_all()
logger.debug(f"do_put: command is{command}, data is {data}")
self.flights[key] = data

self._call_api(command, key)
else:
logger.warning(f"No 'api' field in command: {command}")

def _call_api(self, command, key):
remove_data = False
try:
api = command["api"]
if api == OfflineServer.offline_write_batch.__name__:
self.offline_write_batch(command, key)
remove_data = True
elif api == OfflineServer.write_logged_features.__name__:
self.write_logged_features(command, key)
remove_data = True
elif api == OfflineServer.persist.__name__:
self.persist(command["retrieve_func"], command, key)
remove_data = True
except Exception as e:
remove_data = True
logger.exception(e)
traceback.print_exc()
raise e
finally:
if remove_data:
# Get service is consumed, so we clear the corresponding flight and data
del self.flights[key]

def get_feature_view_by_name(
self, fv_name: str, name_alias: str, project: str
) -> FeatureView:
Expand Down Expand Up @@ -133,20 +157,18 @@ def do_get(self, context, ticket):
logger.debug(f"requested api is {api}")
try:
if api == OfflineServer.get_historical_features.__name__:
df = self.get_historical_features(command, key).to_df()
table = self.get_historical_features(command, key).to_arrow()
elif api == OfflineServer.pull_all_from_table_or_query.__name__:
df = self.pull_all_from_table_or_query(command).to_df()
table = self.pull_all_from_table_or_query(command).to_arrow()
elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
df = self.pull_latest_from_table_or_query(command).to_df()
table = self.pull_latest_from_table_or_query(command).to_arrow()
else:
raise NotImplementedError
except Exception as e:
logger.exception(e)
traceback.print_exc()
raise e

table = pa.Table.from_pandas(df)

# Get service is consumed, so we clear the corresponding flight and data
del self.flights[key]
return fl.RecordBatchStream(table)
Expand Down Expand Up @@ -252,14 +274,15 @@ def get_historical_features(self, command, key):
)
return retJob

def persist(self, command, key):
def persist(self, retrieve_func, command, key):
try:
api = command["api"]
if api == OfflineServer.get_historical_features.__name__:
if retrieve_func == OfflineServer.get_historical_features.__name__:
ret_job = self.get_historical_features(command, key)
elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
elif (
retrieve_func == OfflineServer.pull_latest_from_table_or_query.__name__
):
ret_job = self.pull_latest_from_table_or_query(command)
elif api == OfflineServer.pull_all_from_table_or_query.__name__:
elif retrieve_func == OfflineServer.pull_all_from_table_or_query.__name__:
ret_job = self.pull_all_from_table_or_query(command)
else:
raise NotImplementedError
Expand All @@ -273,25 +296,7 @@ def persist(self, command, key):
raise e

def do_action(self, context, action):
command_descriptor = fl.FlightDescriptor.deserialize(action.body.to_pybytes())

key = OfflineServer.descriptor_to_key(command_descriptor)
command = json.loads(key[1])
logger.info(f"do_action command is {command}")

try:
if action.type == OfflineServer.offline_write_batch.__name__:
self.offline_write_batch(command, key)
elif action.type == OfflineServer.write_logged_features.__name__:
self.write_logged_features(command, key)
elif action.type == OfflineServer.persist.__name__:
self.persist(command, key)
else:
raise NotImplementedError
except Exception as e:
logger.exception(e)
traceback.print_exc()
raise e
pass

def do_drop_dataset(self, dataset):
pass
Expand Down

0 comments on commit 7f64708

Please sign in to comment.