Python FeatureServer optimization (#2202)
* Optimize Python FeatureServer

Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>

* Handle `RepeatedValue` proto in `_get_online_features`

Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>

* Only initialize `Timestamp` once

Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>

* Don't use `defaultdict`

Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>
judahrand authored Jan 18, 2022
1 parent f32b4f4 commit 05f4e8f
Showing 2 changed files with 163 additions and 110 deletions.
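The core of this commit is that `get_online_features` in `feature_store.py` now transposes the caller's row-wise `entity_rows` into a columnar mapping before doing any Protobuf work, so each entity column is converted to proto `Value`s once per column rather than once per row. Below is a minimal standalone sketch of that transposition step only, not the merged Feast code; the `driver_id` and `trip_id` values are invented example data.

from typing import Any, Dict, List


def rows_to_columns(entity_rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Transpose row-wise entity dicts into one list of values per key."""
    columnar: Dict[str, List[Any]] = {k: [] for k in entity_rows[0].keys()}
    for entity_row in entity_rows:
        for key, value in entity_row.items():
            try:
                columnar[key].append(value)
            except KeyError as e:
                raise ValueError("All entity_rows must have the same keys.") from e
    return columnar


# Hypothetical input; real callers pass their own entity join keys.
rows = [{"driver_id": 1001, "trip_id": 5}, {"driver_id": 1002, "trip_id": 6}]
print(rows_to_columns(rows))  # {'driver_id': [1001, 1002], 'trip_id': [5, 6]}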
16 changes: 5 additions & 11 deletions sdk/python/feast/feature_server.py
@@ -10,7 +10,6 @@
import feast
from feast import proto_json
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest
from feast.type_map import feast_value_type_to_python_type


def get_app(store: "feast.FeatureStore"):
@@ -43,16 +42,11 @@ def get_online_features(body=Depends(get_body)):
if any(batch_size != num_entities for batch_size in batch_sizes):
raise HTTPException(status_code=500, detail="Uneven number of columns")

entity_rows = [
{
k: feast_value_type_to_python_type(v.val[idx])
for k, v in request_proto.entities.items()
}
for idx in range(num_entities)
]

response_proto = store.get_online_features(
features, entity_rows, full_feature_names=full_feature_names
response_proto = store._get_online_features(
features,
request_proto.entities,
full_feature_names=full_feature_names,
native_entity_values=False,
).proto

# Convert the Protobuf object to JSON and return it
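With this change, the feature server no longer converts request entities to native Python values and back; it forwards the request's `RepeatedValue` protos to the new `_get_online_features` with `native_entity_values=False`. The sketch below mirrors how `_get_online_features` then unwraps those protos into plain lists; it assumes the generated proto classes bundled with a Feast install, and the `driver_id` payload is invented.

from typing import Any, Dict, List, Sequence, Union

from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value


def unwrap_entity_values(
    entity_values: Dict[str, Union[Sequence[Any], RepeatedValue]],
) -> Dict[str, List[Any]]:
    # Plain sequences (native values or Value protos) pass through unchanged;
    # a RepeatedValue message keeps its payload in its repeated `val` field.
    return {
        k: list(v) if isinstance(v, Sequence) else list(v.val)
        for k, v in entity_values.items()
    }


proto_entities = {"driver_id": RepeatedValue(val=[Value(int64_val=1001)])}
columns = unwrap_entity_values(proto_entities)
print(columns["driver_id"])  # one Value proto with int64_val set to 1001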
257 changes: 158 additions & 99 deletions sdk/python/feast/feature_store.py
@@ -23,8 +23,10 @@
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
@@ -72,7 +74,7 @@
GetOnlineFeaturesResponse,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value
from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value
from feast.registry import Registry
from feast.repo_config import RepoConfig, load_repo_config
from feast.request_feature_view import RequestFeatureView
@@ -267,14 +269,18 @@ def _list_feature_views(
return feature_views

@log_exceptions_and_usage
def list_on_demand_feature_views(self) -> List[OnDemandFeatureView]:
def list_on_demand_feature_views(
self, allow_cache: bool = False
) -> List[OnDemandFeatureView]:
"""
Retrieves the list of on demand feature views from the registry.
Returns:
A list of on demand feature views.
"""
return self._registry.list_on_demand_feature_views(self.project)
return self._registry.list_on_demand_feature_views(
self.project, allow_cache=allow_cache
)

@log_exceptions_and_usage
def get_entity(self, name: str) -> Entity:
@@ -1067,6 +1073,30 @@ def get_online_features(
... )
>>> online_response_dict = online_response.to_dict()
"""
columnar: Dict[str, List[Any]] = {k: [] for k in entity_rows[0].keys()}
for entity_row in entity_rows:
for key, value in entity_row.items():
try:
columnar[key].append(value)
except KeyError as e:
raise ValueError("All entity_rows must have the same keys.") from e

return self._get_online_features(
features=features,
entity_values=columnar,
full_feature_names=full_feature_names,
native_entity_values=True,
)

def _get_online_features(
self,
features: Union[List[str], FeatureService],
entity_values: Mapping[
str, Union[Sequence[Any], Sequence[Value], RepeatedValue]
],
full_feature_names: bool = False,
native_entity_values: bool = True,
):
_feature_refs = self._get_features(features, allow_cache=True)
(
requested_feature_views,
@@ -1076,6 +1106,29 @@
features=features, allow_cache=True, hide_dummy_entity=False
)

entity_name_to_join_key_map, entity_type_map = self._get_entity_maps(
requested_feature_views
)

# Extract Sequence from RepeatedValue Protobuf.
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
k: list(v) if isinstance(v, Sequence) else list(v.val)
for k, v in entity_values.items()
}

entity_proto_values: Dict[str, List[Value]]
if native_entity_values:
# Convert values to Protobuf once.
entity_proto_values = {
k: python_values_to_proto_values(
v, entity_type_map.get(k, ValueType.UNKNOWN)
)
for k, v in entity_value_lists.items()
}
else:
entity_proto_values = entity_value_lists

num_rows = _validate_entity_values(entity_proto_values)
_validate_feature_refs(_feature_refs, full_feature_names)
(
grouped_refs,
@@ -1101,111 +1154,72 @@
}

feature_views = list(view for view, _ in grouped_refs)
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]

provider = self._get_provider()
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
entity_name_to_join_key_map: Dict[str, str] = {}
join_key_to_entity_type_map: Dict[str, ValueType] = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
join_key_to_entity_type_map[entity.join_key] = entity.value_type
for feature_view in requested_feature_views:
for entity_name in feature_view.entities:
entity = self._registry.get_entity(
entity_name, self.project, allow_cache=True
)
# User directly uses join_key as the entity reference in the entity_rows for the
# entity mapping case.
entity_name = feature_view.projection.join_key_map.get(
entity.join_key, entity.name
)
join_key = feature_view.projection.join_key_map.get(
entity.join_key, entity.join_key
)
entity_name_to_join_key_map[entity_name] = join_key
join_key_to_entity_type_map[join_key] = entity.value_type

needed_request_data, needed_request_fv_features = self.get_needed_request_data(
grouped_odfv_refs, grouped_request_fv_refs
)

join_key_rows = []
request_data_features: Dict[str, List[Any]] = defaultdict(list)
join_key_values: Dict[str, List[Value]] = {}
request_data_features: Dict[str, List[Value]] = {}
# Entity rows may be either entities or request data.
for row in entity_rows:
join_key_row = {}
for entity_name, entity_value in row.items():
# Found request data
if (
entity_name in needed_request_data
or entity_name in needed_request_fv_features
):
if entity_name in needed_request_fv_features:
# If the data was requested as a feature then
# make sure it appears in the result.
requested_result_row_names.add(entity_name)
request_data_features[entity_name].append(entity_value)
else:
try:
join_key = entity_name_to_join_key_map[entity_name]
except KeyError:
raise EntityNotFoundException(entity_name, self.project)
# All join keys should be returned in the result.
requested_result_row_names.add(join_key)
join_key_row[join_key] = entity_value
if entityless_case:
join_key_row[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL
if len(join_key_row) > 0:
# May be empty if this entity row was request data
join_key_rows.append(join_key_row)
for entity_name, values in entity_proto_values.items():
# Found request data
if (
entity_name in needed_request_data
or entity_name in needed_request_fv_features
):
if entity_name in needed_request_fv_features:
# If the data was requested as a feature then
# make sure it appears in the result.
requested_result_row_names.add(entity_name)
request_data_features[entity_name] = values
else:
try:
join_key = entity_name_to_join_key_map[entity_name]
except KeyError:
raise EntityNotFoundException(entity_name, self.project)
# All join keys should be returned in the result.
requested_result_row_names.add(join_key)
join_key_values[join_key] = values

self.ensure_request_data_values_exist(
needed_request_data, needed_request_fv_features, request_data_features
)

# Convert join_key_rows from rowise to columnar.
join_key_python_values: Dict[str, List[Value]] = defaultdict(list)
for join_key_row in join_key_rows:
for join_key, value in join_key_row.items():
join_key_python_values[join_key].append(value)

# Convert all join key values to Protobuf Values
join_key_proto_values = {
k: python_values_to_proto_values(v, join_key_to_entity_type_map[k])
for k, v in join_key_python_values.items()
}

# Populate online features response proto with join keys
# Populate online features response proto with join keys and request data features
online_features_response = GetOnlineFeaturesResponse(
results=[
GetOnlineFeaturesResponse.FeatureVector()
for _ in range(len(entity_rows))
]
results=[GetOnlineFeaturesResponse.FeatureVector() for _ in range(num_rows)]
)
for key, values in join_key_proto_values.items():
online_features_response.metadata.feature_names.val.append(key)
for row_idx, result_row in enumerate(online_features_response.results):
result_row.values.append(values[row_idx])
result_row.statuses.append(FieldStatus.PRESENT)
result_row.event_timestamps.append(Timestamp())
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data=dict(**join_key_values, **request_data_features),
)

# Add the Entityless case after populating result rows to avoid having to remove
# it later.
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]
if entityless_case:
join_key_values[DUMMY_ENTITY_ID] = python_values_to_proto_values(
[DUMMY_ENTITY_VAL] * num_rows, DUMMY_ENTITY.value_type
)

# Initialize the set of EntityKeyProtos once and reuse them for each FeatureView
# to avoid initialization overhead.
entity_keys = [EntityKeyProto() for _ in range(len(join_key_rows))]
entity_keys = [EntityKeyProto() for _ in range(num_rows)]
provider = self._get_provider()
for table, requested_features in grouped_refs:
# Get the correct set of entity values with the correct join keys.
entity_values = self._get_table_entity_values(
table, entity_name_to_join_key_map, join_key_proto_values,
table_entity_values = self._get_table_entity_values(
table, entity_name_to_join_key_map, join_key_values,
)

# Set the EntityKeyProtos inplace.
self._set_table_entity_keys(
entity_values, entity_keys,
table_entity_values, entity_keys,
)

# Populate the result_rows with the Features from the OnlineStore inplace.
@@ -1218,10 +1232,6 @@ def get_online_features(
table,
)

self._populate_request_data_features(
online_features_response, request_data_features
)

if grouped_odfv_refs:
self._augment_response_with_on_demand_transforms(
online_features_response,
@@ -1235,6 +1245,50 @@ def get_online_features(
)
return OnlineResponse(online_features_response)

@staticmethod
def _get_columnar_entity_values(
rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]]
) -> Dict[str, List[Any]]:
if (rowise is None and columnar is None) or (
rowise is not None and columnar is not None
):
raise ValueError(
"Exactly one of `columnar_entity_values` and `rowise_entity_values` must be set."
)

if rowise is not None:
# Convert entity_rows from rowise to columnar.
res = defaultdict(list)
for entity_row in rowise:
for key, value in entity_row.items():
res[key].append(value)
return res
return cast(Dict[str, List[Any]], columnar)

def _get_entity_maps(self, feature_views):
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
entity_name_to_join_key_map: Dict[str, str] = {}
entity_type_map: Dict[str, ValueType] = {}
for entity in entities:
entity_name_to_join_key_map[entity.name] = entity.join_key
entity_type_map[entity.name] = entity.value_type
for feature_view in feature_views:
for entity_name in feature_view.entities:
entity = self._registry.get_entity(
entity_name, self.project, allow_cache=True
)
# User directly uses join_key as the entity reference in the entity_rows for the
# entity mapping case.
entity_name = feature_view.projection.join_key_map.get(
entity.join_key, entity.name
)
join_key = feature_view.projection.join_key_map.get(
entity.join_key, entity.join_key
)
entity_name_to_join_key_map[entity_name] = join_key
entity_type_map[join_key] = entity.value_type
return entity_name_to_join_key_map, entity_type_map

@staticmethod
def _get_table_entity_values(
table: FeatureView,
@@ -1275,23 +1329,21 @@ def _set_table_entity_keys(
entity_key.entity_values.extend(next(rowise_values))

@staticmethod
def _populate_request_data_features(
def _populate_result_rows_from_columnar(
online_features_response: GetOnlineFeaturesResponse,
request_data_features: Dict[str, List[Any]],
data: Dict[str, List[Value]],
):
# Add more feature values to the existing result rows for the request data features
for feature_name, feature_values in request_data_features.items():
proto_values = python_values_to_proto_values(
feature_values, ValueType.UNKNOWN
)
timestamp = Timestamp() # Only initialize this timestamp once.
# Add more values to the existing result rows
for feature_name, feature_values in data.items():

online_features_response.metadata.feature_names.val.append(feature_name)

for row_idx, proto_value in enumerate(proto_values):
for row_idx, proto_value in enumerate(feature_values):
result_row = online_features_response.results[row_idx]
result_row.values.append(proto_value)
result_row.statuses.append(FieldStatus.PRESENT)
result_row.event_timestamps.append(Timestamp())
result_row.event_timestamps.append(timestamp)

@staticmethod
def get_needed_request_data(
@@ -1567,6 +1619,13 @@ def serve_transformations(self, port: int) -> None:
transformation_server.start_server(self, port)


def _validate_entity_values(join_key_values: Dict[str, List[Value]]):
set_of_row_lengths = {len(v) for v in join_key_values.values()}
if len(set_of_row_lengths) > 1:
raise ValueError("All entity rows must have the same columns.")
return set_of_row_lengths.pop()


def _validate_feature_refs(feature_refs: List[str], full_feature_names: bool = False):
collided_feature_refs = []

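The new module-level `_validate_entity_values` helper replaces the old per-row shape check: it derives the row count from the columnar input and rejects uneven columns up front. Here is a small standalone sketch of the same check with invented values, not the merged helper itself.

from typing import Any, Dict, List


def validate_entity_values(join_key_values: Dict[str, List[Any]]) -> int:
    # Every entity column must contribute exactly one value per requested row.
    set_of_row_lengths = {len(v) for v in join_key_values.values()}
    if len(set_of_row_lengths) > 1:
        raise ValueError("All entity rows must have the same columns.")
    return set_of_row_lengths.pop()


print(validate_entity_values({"driver_id": [1001, 1002], "trip_id": [5, 6]}))  # 2
# validate_entity_values({"driver_id": [1001], "trip_id": [5, 6]})  # raises ValueError

Elsewhere in the commit, `_populate_result_rows_from_columnar` creates a single `Timestamp()` and appends that same instance to every result row, which is what the "Only initialize `Timestamp` once" commit message refers to.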