Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java Feature Server] Use hgetall in redis connector when number of retrieved fields is big enough #2159

Merged
merged 8 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>11</source>
<target>11</target>
<release>11</release>
<annotationProcessorPaths>
<path>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re
storageRetrievalSpan.setTag("entities", entityRows.size());
storageRetrievalSpan.setTag("features", featureReferences.size());
}
List<List<Feature>> entityRowsFeatures =
List<Map<FeatureReferenceV2, Feature>> features =
retriever.getOnlineFeatures(projectName, entityRows, featureReferences, entityNames);
if (storageRetrievalSpan != null) {
storageRetrievalSpan.finish();
}

if (entityRowsFeatures.size() != entityRows.size()) {
if (features.size() != entityRows.size()) {
throw Status.INTERNAL
.withDescription(
"The no. of FeatureRow obtained from OnlineRetriever"
Expand All @@ -184,36 +184,30 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re

for (int i = 0; i < entityRows.size(); i++) {
GetOnlineFeaturesRequestV2.EntityRow entityRow = entityRows.get(i);
List<Feature> curEntityRowFeatures = entityRowsFeatures.get(i);

Map<FeatureReferenceV2, Feature> featureReferenceFeatureMap =
getFeatureRefFeatureMap(curEntityRowFeatures);
Map<FeatureReferenceV2, Feature> featureRow = features.get(i);

Map<String, ValueProto.Value> rowValues = values.get(i);
Map<String, GetOnlineFeaturesResponse.FieldStatus> rowStatuses = statuses.get(i);

for (FeatureReferenceV2 featureReference : featureReferences) {
if (featureReferenceFeatureMap.containsKey(featureReference)) {
Feature feature = featureReferenceFeatureMap.get(featureReference);
if (featureRow.containsKey(featureReference)) {
Feature feature = featureRow.get(featureReference);

ValueProto.Value value =
feature.getFeatureValue(featureValueTypes.get(feature.getFeatureReference()));
ValueProto.Value value = feature.getFeatureValue(featureValueTypes.get(featureReference));

Boolean isOutsideMaxAge =
checkOutsideMaxAge(
feature, entityRow, featureMaxAges.get(feature.getFeatureReference()));
checkOutsideMaxAge(feature, entityRow, featureMaxAges.get(featureReference));

if (value != null) {
rowValues.put(FeatureV2.getFeatureStringRef(feature.getFeatureReference()), value);
rowValues.put(FeatureV2.getFeatureStringRef(featureReference), value);
} else {
rowValues.put(
FeatureV2.getFeatureStringRef(feature.getFeatureReference()),
FeatureV2.getFeatureStringRef(featureReference),
ValueProto.Value.newBuilder().build());
}

rowStatuses.put(
FeatureV2.getFeatureStringRef(feature.getFeatureReference()),
getMetadata(value, isOutsideMaxAge));
FeatureV2.getFeatureStringRef(featureReference), getMetadata(value, isOutsideMaxAge));
} else {
rowValues.put(
FeatureV2.getFeatureStringRef(featureReference),
Expand Down Expand Up @@ -314,11 +308,6 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re
return GetOnlineFeaturesResponse.newBuilder().addAllFieldValues(fieldValuesList).build();
}

private static Map<FeatureReferenceV2, Feature> getFeatureRefFeatureMap(List<Feature> features) {
return features.stream()
.collect(Collectors.toMap(Feature::getFeatureReference, Function.identity()));
}

/**
* Generate Field level Status metadata for the given valueMap.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;

import com.google.common.collect.ImmutableMap;
import com.google.protobuf.Duration;
import com.google.protobuf.Timestamp;
import feast.proto.core.FeatureProto;
Expand All @@ -42,6 +43,7 @@
import io.opentracing.Tracer.SpanBuilder;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
Expand Down Expand Up @@ -151,14 +153,14 @@ public void shouldReturnResponseWithValuesAndMetadataIfKeysPresent() {
List.of(featureReference1, featureReference2);
GetOnlineFeaturesRequestV2 request = getOnlineFeaturesRequestV2(projectName, featureReferences);

List<Feature> entityKeyList1 = new ArrayList<>();
List<Feature> entityKeyList2 = new ArrayList<>();
entityKeyList1.add(mockedFeatureRows.get(0));
entityKeyList1.add(mockedFeatureRows.get(1));
entityKeyList2.add(mockedFeatureRows.get(2));
entityKeyList2.add(mockedFeatureRows.get(3));

List<List<Feature>> featureRows = List.of(entityKeyList1, entityKeyList2);
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> featureRows =
List.of(
ImmutableMap.of(
mockedFeatureRows.get(0).getFeatureReference(), mockedFeatureRows.get(0),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)),
ImmutableMap.of(
mockedFeatureRows.get(2).getFeatureReference(), mockedFeatureRows.get(2),
mockedFeatureRows.get(3).getFeatureReference(), mockedFeatureRows.get(3)));

when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows);
when(registry.getFeatureViewSpec(any(), any())).thenReturn(getFeatureViewSpec());
Expand Down Expand Up @@ -225,7 +227,13 @@ public void shouldReturnResponseWithUnsetValuesAndMetadataIfKeysNotPresent() {
entityKeyList1.add(mockedFeatureRows.get(1));
entityKeyList2.add(mockedFeatureRows.get(4));

List<List<Feature>> featureRows = List.of(entityKeyList1, entityKeyList2);
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> featureRows =
List.of(
ImmutableMap.of(
mockedFeatureRows.get(0).getFeatureReference(), mockedFeatureRows.get(0),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)),
ImmutableMap.of(
mockedFeatureRows.get(4).getFeatureReference(), mockedFeatureRows.get(4)));

when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows);
when(registry.getFeatureViewSpec(any(), any())).thenReturn(getFeatureViewSpec());
Expand Down Expand Up @@ -282,14 +290,14 @@ public void shouldReturnResponseWithValuesAndMetadataIfMaxAgeIsExceeded() {
List.of(featureReference1, featureReference2);
GetOnlineFeaturesRequestV2 request = getOnlineFeaturesRequestV2(projectName, featureReferences);

List<Feature> entityKeyList1 = new ArrayList<>();
List<Feature> entityKeyList2 = new ArrayList<>();
entityKeyList1.add(mockedFeatureRows.get(5));
entityKeyList1.add(mockedFeatureRows.get(1));
entityKeyList2.add(mockedFeatureRows.get(5));
entityKeyList2.add(mockedFeatureRows.get(1));

List<List<Feature>> featureRows = List.of(entityKeyList1, entityKeyList2);
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> featureRows =
List.of(
ImmutableMap.of(
mockedFeatureRows.get(5).getFeatureReference(), mockedFeatureRows.get(5),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)),
ImmutableMap.of(
mockedFeatureRows.get(5).getFeatureReference(), mockedFeatureRows.get(5),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)));

when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows);
when(registry.getFeatureViewSpec(any(), any()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import feast.proto.serving.ServingAPIProto;
import java.util.List;
import java.util.Map;

public interface OnlineRetrieverV2 {
/**
Expand All @@ -37,7 +38,7 @@ public interface OnlineRetrieverV2 {
* @return list of {@link Feature}s corresponding to data retrieved for each entity row from
* FeatureTable specified in FeatureTable request.
*/
List<List<Feature>> getOnlineFeatures(
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> getOnlineFeatures(
String project,
List<ServingAPIProto.GetOnlineFeaturesRequestV2.EntityRow> entityRows,
List<ServingAPIProto.FeatureReferenceV2> featureReferences,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@
*/
package feast.storage.connectors.redis.common;

import com.google.common.collect.Maps;
import com.google.common.hash.Hashing;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Timestamp;
import feast.proto.serving.ServingAPIProto;
import feast.proto.types.ValueProto;
import feast.storage.api.retriever.Feature;
import feast.storage.api.retriever.ProtoFeature;
import io.lettuce.core.KeyValue;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;

public class RedisHashDecoder {

Expand All @@ -35,54 +37,57 @@ public class RedisHashDecoder {
* @param redisHashValues retrieved Redis Hash values based on EntityRows
* @param byteToFeatureReferenceMap map to decode bytes back to FeatureReference
* @param timestampPrefix timestamp prefix
* @return List of {@link Feature}
* @throws InvalidProtocolBufferException if a protocol buffer exception occurs
* @return Map of {@link ServingAPIProto.FeatureReferenceV2} to {@link Feature}
*/
public static List<Feature> retrieveFeature(
List<KeyValue<byte[], byte[]>> redisHashValues,
Map<byte[], ServingAPIProto.FeatureReferenceV2> byteToFeatureReferenceMap,
String timestampPrefix)
throws InvalidProtocolBufferException {
List<Feature> allFeatures = new ArrayList<>();
HashMap<ServingAPIProto.FeatureReferenceV2, ValueProto.Value> featureMap = new HashMap<>();
Map<String, Timestamp> featureTableTimestampMap = new HashMap<>();
public static Map<ServingAPIProto.FeatureReferenceV2, Feature> retrieveFeature(
Map<byte[], byte[]> redisHashValues,
Map<ByteBuffer, ServingAPIProto.FeatureReferenceV2> byteToFeatureReferenceMap,
String timestampPrefix) {
Map<String, Timestamp> featureTableTimestampMap =
redisHashValues.entrySet().stream()
.filter(e -> new String(e.getKey()).startsWith(timestampPrefix))
.collect(
Collectors.toMap(
e -> new String(e.getKey()).substring(timestampPrefix.length() + 1),
e -> {
try {
return Timestamp.parseFrom(e.getValue());
} catch (InvalidProtocolBufferException ex) {
throw new RuntimeException(
"Couldn't parse timestamp proto while pulling data from Redis");
}
}));
Map<ServingAPIProto.FeatureReferenceV2, Feature> results =
Maps.newHashMapWithExpectedSize(byteToFeatureReferenceMap.size());

for (KeyValue<byte[], byte[]> entity : redisHashValues) {
if (entity.hasValue()) {
byte[] redisValueK = entity.getKey();
byte[] redisValueV = entity.getValue();
for (Map.Entry<byte[], byte[]> entry : redisHashValues.entrySet()) {
ServingAPIProto.FeatureReferenceV2 featureReference =
byteToFeatureReferenceMap.get(ByteBuffer.wrap(entry.getKey()));

// Decode data from Redis into Feature object fields
if (new String(redisValueK).startsWith(timestampPrefix)) {
Timestamp eventTimestamp = Timestamp.parseFrom(redisValueV);
featureTableTimestampMap.put(new String(redisValueK), eventTimestamp);
} else {
ServingAPIProto.FeatureReferenceV2 featureReference =
byteToFeatureReferenceMap.get(redisValueK);
ValueProto.Value featureValue = ValueProto.Value.parseFrom(redisValueV);

featureMap.put(featureReference, featureValue);
}
if (featureReference == null) {
continue;
}
}

// Add timestamp to features
for (Map.Entry<ServingAPIProto.FeatureReferenceV2, ValueProto.Value> entry :
featureMap.entrySet()) {
String timestampRedisHashKeyStr = timestampPrefix + ":" + entry.getKey().getFeatureTable();
Timestamp curFeatureTimestamp = featureTableTimestampMap.get(timestampRedisHashKeyStr);

ProtoFeature curFeature =
new ProtoFeature(entry.getKey(), curFeatureTimestamp, entry.getValue());
allFeatures.add(curFeature);
ValueProto.Value v;
try {
v = ValueProto.Value.parseFrom(entry.getValue());
} catch (InvalidProtocolBufferException ex) {
throw new RuntimeException(
"Couldn't parse feature value proto while pulling data from Redis");
}
results.put(
featureReference,
new ProtoFeature(
featureReference,
featureTableTimestampMap.get(featureReference.getFeatureTable()),
v));
}

return allFeatures;
return results;
}

public static byte[] getTimestampRedisHashKeyBytes(
ServingAPIProto.FeatureReferenceV2 featureReference, String timestampPrefix) {
String timestampRedisHashKeyStr = timestampPrefix + ":" + featureReference.getFeatureTable();
public static byte[] getTimestampRedisHashKeyBytes(String featureTable, String timestampPrefix) {
String timestampRedisHashKeyStr = timestampPrefix + ":" + featureTable;
return timestampRedisHashKeyStr.getBytes();
}

Expand Down
Loading