Skip to content

Commit

Permalink
[Java Feature Server] Use hgetall in redis connector when number of retrieved fields is big enough (#2159)
Browse files Browse the repository at this point in the history

* hgetall

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>

* clean up Redis Hash Decoder

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>

* expected size for collections

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>

* more cleanup

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>

* format

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>

* do not use streams in critical parts

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>

* enable autoflush in redis cluster client

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>

* hgetall threshold as constant

Signed-off-by: pyalex <moskalenko.alexey@gmail.com>
  • Loading branch information
pyalex authored Dec 20, 2021
1 parent 7d4369f commit 033659e
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 126 deletions.
2 changes: 2 additions & 0 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>11</source>
<target>11</target>
<release>11</release>
<annotationProcessorPaths>
<path>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re
storageRetrievalSpan.setTag("entities", entityRows.size());
storageRetrievalSpan.setTag("features", featureReferences.size());
}
List<List<Feature>> entityRowsFeatures =
List<Map<FeatureReferenceV2, Feature>> features =
retriever.getOnlineFeatures(projectName, entityRows, featureReferences, entityNames);
if (storageRetrievalSpan != null) {
storageRetrievalSpan.finish();
}

if (entityRowsFeatures.size() != entityRows.size()) {
if (features.size() != entityRows.size()) {
throw Status.INTERNAL
.withDescription(
"The no. of FeatureRow obtained from OnlineRetriever"
Expand All @@ -184,36 +184,30 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re

for (int i = 0; i < entityRows.size(); i++) {
GetOnlineFeaturesRequestV2.EntityRow entityRow = entityRows.get(i);
List<Feature> curEntityRowFeatures = entityRowsFeatures.get(i);

Map<FeatureReferenceV2, Feature> featureReferenceFeatureMap =
getFeatureRefFeatureMap(curEntityRowFeatures);
Map<FeatureReferenceV2, Feature> featureRow = features.get(i);

Map<String, ValueProto.Value> rowValues = values.get(i);
Map<String, GetOnlineFeaturesResponse.FieldStatus> rowStatuses = statuses.get(i);

for (FeatureReferenceV2 featureReference : featureReferences) {
if (featureReferenceFeatureMap.containsKey(featureReference)) {
Feature feature = featureReferenceFeatureMap.get(featureReference);
if (featureRow.containsKey(featureReference)) {
Feature feature = featureRow.get(featureReference);

ValueProto.Value value =
feature.getFeatureValue(featureValueTypes.get(feature.getFeatureReference()));
ValueProto.Value value = feature.getFeatureValue(featureValueTypes.get(featureReference));

Boolean isOutsideMaxAge =
checkOutsideMaxAge(
feature, entityRow, featureMaxAges.get(feature.getFeatureReference()));
checkOutsideMaxAge(feature, entityRow, featureMaxAges.get(featureReference));

if (value != null) {
rowValues.put(FeatureV2.getFeatureStringRef(feature.getFeatureReference()), value);
rowValues.put(FeatureV2.getFeatureStringRef(featureReference), value);
} else {
rowValues.put(
FeatureV2.getFeatureStringRef(feature.getFeatureReference()),
FeatureV2.getFeatureStringRef(featureReference),
ValueProto.Value.newBuilder().build());
}

rowStatuses.put(
FeatureV2.getFeatureStringRef(feature.getFeatureReference()),
getMetadata(value, isOutsideMaxAge));
FeatureV2.getFeatureStringRef(featureReference), getMetadata(value, isOutsideMaxAge));
} else {
rowValues.put(
FeatureV2.getFeatureStringRef(featureReference),
Expand Down Expand Up @@ -314,11 +308,6 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re
return GetOnlineFeaturesResponse.newBuilder().addAllFieldValues(fieldValuesList).build();
}

/**
 * Indexes the given features by their feature reference.
 *
 * <p>NOTE(review): {@code Collectors.toMap} with no merge function throws
 * {@code IllegalStateException} if two features share the same reference —
 * presumably the retriever never returns duplicates; verify against callers.
 *
 * @param features features to index
 * @return map from each feature's reference to the {@link Feature} itself
 */
private static Map<FeatureReferenceV2, Feature> getFeatureRefFeatureMap(List<Feature> features) {
return features.stream()
.collect(Collectors.toMap(Feature::getFeatureReference, Function.identity()));
}

/**
* Generate Field level Status metadata for the given valueMap.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.initMocks;

import com.google.common.collect.ImmutableMap;
import com.google.protobuf.Duration;
import com.google.protobuf.Timestamp;
import feast.proto.core.FeatureProto;
Expand All @@ -42,6 +43,7 @@
import io.opentracing.Tracer.SpanBuilder;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
Expand Down Expand Up @@ -151,14 +153,14 @@ public void shouldReturnResponseWithValuesAndMetadataIfKeysPresent() {
List.of(featureReference1, featureReference2);
GetOnlineFeaturesRequestV2 request = getOnlineFeaturesRequestV2(projectName, featureReferences);

List<Feature> entityKeyList1 = new ArrayList<>();
List<Feature> entityKeyList2 = new ArrayList<>();
entityKeyList1.add(mockedFeatureRows.get(0));
entityKeyList1.add(mockedFeatureRows.get(1));
entityKeyList2.add(mockedFeatureRows.get(2));
entityKeyList2.add(mockedFeatureRows.get(3));

List<List<Feature>> featureRows = List.of(entityKeyList1, entityKeyList2);
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> featureRows =
List.of(
ImmutableMap.of(
mockedFeatureRows.get(0).getFeatureReference(), mockedFeatureRows.get(0),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)),
ImmutableMap.of(
mockedFeatureRows.get(2).getFeatureReference(), mockedFeatureRows.get(2),
mockedFeatureRows.get(3).getFeatureReference(), mockedFeatureRows.get(3)));

when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows);
when(registry.getFeatureViewSpec(any(), any())).thenReturn(getFeatureViewSpec());
Expand Down Expand Up @@ -225,7 +227,13 @@ public void shouldReturnResponseWithUnsetValuesAndMetadataIfKeysNotPresent() {
entityKeyList1.add(mockedFeatureRows.get(1));
entityKeyList2.add(mockedFeatureRows.get(4));

List<List<Feature>> featureRows = List.of(entityKeyList1, entityKeyList2);
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> featureRows =
List.of(
ImmutableMap.of(
mockedFeatureRows.get(0).getFeatureReference(), mockedFeatureRows.get(0),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)),
ImmutableMap.of(
mockedFeatureRows.get(4).getFeatureReference(), mockedFeatureRows.get(4)));

when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows);
when(registry.getFeatureViewSpec(any(), any())).thenReturn(getFeatureViewSpec());
Expand Down Expand Up @@ -282,14 +290,14 @@ public void shouldReturnResponseWithValuesAndMetadataIfMaxAgeIsExceeded() {
List.of(featureReference1, featureReference2);
GetOnlineFeaturesRequestV2 request = getOnlineFeaturesRequestV2(projectName, featureReferences);

List<Feature> entityKeyList1 = new ArrayList<>();
List<Feature> entityKeyList2 = new ArrayList<>();
entityKeyList1.add(mockedFeatureRows.get(5));
entityKeyList1.add(mockedFeatureRows.get(1));
entityKeyList2.add(mockedFeatureRows.get(5));
entityKeyList2.add(mockedFeatureRows.get(1));

List<List<Feature>> featureRows = List.of(entityKeyList1, entityKeyList2);
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> featureRows =
List.of(
ImmutableMap.of(
mockedFeatureRows.get(5).getFeatureReference(), mockedFeatureRows.get(5),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)),
ImmutableMap.of(
mockedFeatureRows.get(5).getFeatureReference(), mockedFeatureRows.get(5),
mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)));

when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows);
when(registry.getFeatureViewSpec(any(), any()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import feast.proto.serving.ServingAPIProto;
import java.util.List;
import java.util.Map;

public interface OnlineRetrieverV2 {
/**
Expand All @@ -37,7 +38,7 @@ public interface OnlineRetrieverV2 {
* @return list of {@link Feature}s corresponding to data retrieved for each entity row from
* FeatureTable specified in FeatureTable request.
*/
List<List<Feature>> getOnlineFeatures(
List<Map<ServingAPIProto.FeatureReferenceV2, Feature>> getOnlineFeatures(
String project,
List<ServingAPIProto.GetOnlineFeaturesRequestV2.EntityRow> entityRows,
List<ServingAPIProto.FeatureReferenceV2> featureReferences,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@
*/
package feast.storage.connectors.redis.common;

import com.google.common.collect.Maps;
import com.google.common.hash.Hashing;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Timestamp;
import feast.proto.serving.ServingAPIProto;
import feast.proto.types.ValueProto;
import feast.storage.api.retriever.Feature;
import feast.storage.api.retriever.ProtoFeature;
import io.lettuce.core.KeyValue;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;

public class RedisHashDecoder {

Expand All @@ -35,54 +37,57 @@ public class RedisHashDecoder {
* @param redisHashValues retrieved Redis Hash values based on EntityRows
* @param byteToFeatureReferenceMap map to decode bytes back to FeatureReference
* @param timestampPrefix timestamp prefix
* @return List of {@link Feature}
* @throws InvalidProtocolBufferException if a protocol buffer exception occurs
* @return Map of {@link ServingAPIProto.FeatureReferenceV2} to {@link Feature}
*/
public static List<Feature> retrieveFeature(
List<KeyValue<byte[], byte[]>> redisHashValues,
Map<byte[], ServingAPIProto.FeatureReferenceV2> byteToFeatureReferenceMap,
String timestampPrefix)
throws InvalidProtocolBufferException {
List<Feature> allFeatures = new ArrayList<>();
HashMap<ServingAPIProto.FeatureReferenceV2, ValueProto.Value> featureMap = new HashMap<>();
Map<String, Timestamp> featureTableTimestampMap = new HashMap<>();
public static Map<ServingAPIProto.FeatureReferenceV2, Feature> retrieveFeature(
Map<byte[], byte[]> redisHashValues,
Map<ByteBuffer, ServingAPIProto.FeatureReferenceV2> byteToFeatureReferenceMap,
String timestampPrefix) {
Map<String, Timestamp> featureTableTimestampMap =
redisHashValues.entrySet().stream()
.filter(e -> new String(e.getKey()).startsWith(timestampPrefix))
.collect(
Collectors.toMap(
e -> new String(e.getKey()).substring(timestampPrefix.length() + 1),
e -> {
try {
return Timestamp.parseFrom(e.getValue());
} catch (InvalidProtocolBufferException ex) {
throw new RuntimeException(
"Couldn't parse timestamp proto while pulling data from Redis");
}
}));
Map<ServingAPIProto.FeatureReferenceV2, Feature> results =
Maps.newHashMapWithExpectedSize(byteToFeatureReferenceMap.size());

for (KeyValue<byte[], byte[]> entity : redisHashValues) {
if (entity.hasValue()) {
byte[] redisValueK = entity.getKey();
byte[] redisValueV = entity.getValue();
for (Map.Entry<byte[], byte[]> entry : redisHashValues.entrySet()) {
ServingAPIProto.FeatureReferenceV2 featureReference =
byteToFeatureReferenceMap.get(ByteBuffer.wrap(entry.getKey()));

// Decode data from Redis into Feature object fields
if (new String(redisValueK).startsWith(timestampPrefix)) {
Timestamp eventTimestamp = Timestamp.parseFrom(redisValueV);
featureTableTimestampMap.put(new String(redisValueK), eventTimestamp);
} else {
ServingAPIProto.FeatureReferenceV2 featureReference =
byteToFeatureReferenceMap.get(redisValueK);
ValueProto.Value featureValue = ValueProto.Value.parseFrom(redisValueV);

featureMap.put(featureReference, featureValue);
}
if (featureReference == null) {
continue;
}
}

// Add timestamp to features
for (Map.Entry<ServingAPIProto.FeatureReferenceV2, ValueProto.Value> entry :
featureMap.entrySet()) {
String timestampRedisHashKeyStr = timestampPrefix + ":" + entry.getKey().getFeatureTable();
Timestamp curFeatureTimestamp = featureTableTimestampMap.get(timestampRedisHashKeyStr);

ProtoFeature curFeature =
new ProtoFeature(entry.getKey(), curFeatureTimestamp, entry.getValue());
allFeatures.add(curFeature);
ValueProto.Value v;
try {
v = ValueProto.Value.parseFrom(entry.getValue());
} catch (InvalidProtocolBufferException ex) {
throw new RuntimeException(
"Couldn't parse feature value proto while pulling data from Redis");
}
results.put(
featureReference,
new ProtoFeature(
featureReference,
featureTableTimestampMap.get(featureReference.getFeatureTable()),
v));
}

return allFeatures;
return results;
}

public static byte[] getTimestampRedisHashKeyBytes(
ServingAPIProto.FeatureReferenceV2 featureReference, String timestampPrefix) {
String timestampRedisHashKeyStr = timestampPrefix + ":" + featureReference.getFeatureTable();
public static byte[] getTimestampRedisHashKeyBytes(String featureTable, String timestampPrefix) {
String timestampRedisHashKeyStr = timestampPrefix + ":" + featureTable;
return timestampRedisHashKeyStr.getBytes();
}

Expand Down
Loading

0 comments on commit 033659e

Please sign in to comment.