diff --git a/java/pom.xml b/java/pom.xml index 4b63f5389b..6f89736c01 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -565,6 +565,8 @@ org.apache.maven.plugins maven-compiler-plugin + 11 + 11 11 diff --git a/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java b/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java index e9a8019a7f..2d5621e4b4 100644 --- a/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java +++ b/java/serving/src/main/java/feast/serving/service/OnlineServingServiceV2.java @@ -166,13 +166,13 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re storageRetrievalSpan.setTag("entities", entityRows.size()); storageRetrievalSpan.setTag("features", featureReferences.size()); } - List> entityRowsFeatures = + List> features = retriever.getOnlineFeatures(projectName, entityRows, featureReferences, entityNames); if (storageRetrievalSpan != null) { storageRetrievalSpan.finish(); } - if (entityRowsFeatures.size() != entityRows.size()) { + if (features.size() != entityRows.size()) { throw Status.INTERNAL .withDescription( "The no. of FeatureRow obtained from OnlineRetriever" @@ -184,36 +184,30 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re for (int i = 0; i < entityRows.size(); i++) { GetOnlineFeaturesRequestV2.EntityRow entityRow = entityRows.get(i); - List curEntityRowFeatures = entityRowsFeatures.get(i); - - Map featureReferenceFeatureMap = - getFeatureRefFeatureMap(curEntityRowFeatures); + Map featureRow = features.get(i); Map rowValues = values.get(i); Map rowStatuses = statuses.get(i); for (FeatureReferenceV2 featureReference : featureReferences) { - if (featureReferenceFeatureMap.containsKey(featureReference)) { - Feature feature = featureReferenceFeatureMap.get(featureReference); + if (featureRow.containsKey(featureReference)) { + Feature feature = featureRow.get(featureReference); - ValueProto.Value value = - feature.getFeatureValue(featureValueTypes.get(feature.getFeatureReference())); + ValueProto.Value value = feature.getFeatureValue(featureValueTypes.get(featureReference)); Boolean isOutsideMaxAge = - checkOutsideMaxAge( - feature, entityRow, featureMaxAges.get(feature.getFeatureReference())); + checkOutsideMaxAge(feature, entityRow, featureMaxAges.get(featureReference)); if (value != null) { - rowValues.put(FeatureV2.getFeatureStringRef(feature.getFeatureReference()), value); + rowValues.put(FeatureV2.getFeatureStringRef(featureReference), value); } else { rowValues.put( - FeatureV2.getFeatureStringRef(feature.getFeatureReference()), + FeatureV2.getFeatureStringRef(featureReference), ValueProto.Value.newBuilder().build()); } rowStatuses.put( - FeatureV2.getFeatureStringRef(feature.getFeatureReference()), - getMetadata(value, isOutsideMaxAge)); + FeatureV2.getFeatureStringRef(featureReference), getMetadata(value, isOutsideMaxAge)); } else { rowValues.put( FeatureV2.getFeatureStringRef(featureReference), @@ -314,11 +308,6 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequestV2 re return GetOnlineFeaturesResponse.newBuilder().addAllFieldValues(fieldValuesList).build(); } - private static Map getFeatureRefFeatureMap(List features) { - return features.stream() - .collect(Collectors.toMap(Feature::getFeatureReference, Function.identity())); - } - /** * Generate Field level Status metadata for the given valueMap. * diff --git a/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java b/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java index b9d270cc10..c43e3218c7 100644 --- a/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java +++ b/java/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.initMocks; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Duration; import com.google.protobuf.Timestamp; import feast.proto.core.FeatureProto; @@ -42,6 +43,7 @@ import io.opentracing.Tracer.SpanBuilder; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentMatchers; @@ -151,14 +153,14 @@ public void shouldReturnResponseWithValuesAndMetadataIfKeysPresent() { List.of(featureReference1, featureReference2); GetOnlineFeaturesRequestV2 request = getOnlineFeaturesRequestV2(projectName, featureReferences); - List entityKeyList1 = new ArrayList<>(); - List entityKeyList2 = new ArrayList<>(); - entityKeyList1.add(mockedFeatureRows.get(0)); - entityKeyList1.add(mockedFeatureRows.get(1)); - entityKeyList2.add(mockedFeatureRows.get(2)); - entityKeyList2.add(mockedFeatureRows.get(3)); - - List> featureRows = List.of(entityKeyList1, entityKeyList2); + List> featureRows = + List.of( + ImmutableMap.of( + mockedFeatureRows.get(0).getFeatureReference(), mockedFeatureRows.get(0), + mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)), + ImmutableMap.of( + mockedFeatureRows.get(2).getFeatureReference(), mockedFeatureRows.get(2), + mockedFeatureRows.get(3).getFeatureReference(), mockedFeatureRows.get(3))); when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows); when(registry.getFeatureViewSpec(any(), any())).thenReturn(getFeatureViewSpec()); @@ -225,7 +227,13 @@ public void shouldReturnResponseWithUnsetValuesAndMetadataIfKeysNotPresent() { entityKeyList1.add(mockedFeatureRows.get(1)); entityKeyList2.add(mockedFeatureRows.get(4)); - List> featureRows = List.of(entityKeyList1, entityKeyList2); + List> featureRows = + List.of( + ImmutableMap.of( + mockedFeatureRows.get(0).getFeatureReference(), mockedFeatureRows.get(0), + mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)), + ImmutableMap.of( + mockedFeatureRows.get(4).getFeatureReference(), mockedFeatureRows.get(4))); when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows); when(registry.getFeatureViewSpec(any(), any())).thenReturn(getFeatureViewSpec()); @@ -282,14 +290,14 @@ public void shouldReturnResponseWithValuesAndMetadataIfMaxAgeIsExceeded() { List.of(featureReference1, featureReference2); GetOnlineFeaturesRequestV2 request = getOnlineFeaturesRequestV2(projectName, featureReferences); - List entityKeyList1 = new ArrayList<>(); - List entityKeyList2 = new ArrayList<>(); - entityKeyList1.add(mockedFeatureRows.get(5)); - entityKeyList1.add(mockedFeatureRows.get(1)); - entityKeyList2.add(mockedFeatureRows.get(5)); - entityKeyList2.add(mockedFeatureRows.get(1)); - - List> featureRows = List.of(entityKeyList1, entityKeyList2); + List> featureRows = + List.of( + ImmutableMap.of( + mockedFeatureRows.get(5).getFeatureReference(), mockedFeatureRows.get(5), + mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1)), + ImmutableMap.of( + mockedFeatureRows.get(5).getFeatureReference(), mockedFeatureRows.get(5), + mockedFeatureRows.get(1).getFeatureReference(), mockedFeatureRows.get(1))); when(retrieverV2.getOnlineFeatures(any(), any(), any(), any())).thenReturn(featureRows); when(registry.getFeatureViewSpec(any(), any())) diff --git a/java/storage/api/src/main/java/feast/storage/api/retriever/OnlineRetrieverV2.java b/java/storage/api/src/main/java/feast/storage/api/retriever/OnlineRetrieverV2.java index a49ab3fccd..db5db8b63c 100644 --- a/java/storage/api/src/main/java/feast/storage/api/retriever/OnlineRetrieverV2.java +++ b/java/storage/api/src/main/java/feast/storage/api/retriever/OnlineRetrieverV2.java @@ -18,6 +18,7 @@ import feast.proto.serving.ServingAPIProto; import java.util.List; +import java.util.Map; public interface OnlineRetrieverV2 { /** @@ -37,7 +38,7 @@ public interface OnlineRetrieverV2 { * @return list of {@link Feature}s corresponding to data retrieved for each entity row from * FeatureTable specified in FeatureTable request. */ - List> getOnlineFeatures( + List> getOnlineFeatures( String project, List entityRows, List featureReferences, diff --git a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/common/RedisHashDecoder.java b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/common/RedisHashDecoder.java index e24b3bd5ab..fd0f0a56dc 100644 --- a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/common/RedisHashDecoder.java +++ b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/common/RedisHashDecoder.java @@ -16,6 +16,7 @@ */ package feast.storage.connectors.redis.common; +import com.google.common.collect.Maps; import com.google.common.hash.Hashing; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Timestamp; @@ -23,9 +24,10 @@ import feast.proto.types.ValueProto; import feast.storage.api.retriever.Feature; import feast.storage.api.retriever.ProtoFeature; -import io.lettuce.core.KeyValue; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.stream.Collectors; public class RedisHashDecoder { @@ -35,54 +37,57 @@ public class RedisHashDecoder { * @param redisHashValues retrieved Redis Hash values based on EntityRows * @param byteToFeatureReferenceMap map to decode bytes back to FeatureReference * @param timestampPrefix timestamp prefix - * @return List of {@link Feature} - * @throws InvalidProtocolBufferException if a protocol buffer exception occurs + * @return Map of {@link ServingAPIProto.FeatureReferenceV2} to {@link Feature} */ - public static List retrieveFeature( - List> redisHashValues, - Map byteToFeatureReferenceMap, - String timestampPrefix) - throws InvalidProtocolBufferException { - List allFeatures = new ArrayList<>(); - HashMap featureMap = new HashMap<>(); - Map featureTableTimestampMap = new HashMap<>(); + public static Map retrieveFeature( + Map redisHashValues, + Map byteToFeatureReferenceMap, + String timestampPrefix) { + Map featureTableTimestampMap = + redisHashValues.entrySet().stream() + .filter(e -> new String(e.getKey()).startsWith(timestampPrefix)) + .collect( + Collectors.toMap( + e -> new String(e.getKey()).substring(timestampPrefix.length() + 1), + e -> { + try { + return Timestamp.parseFrom(e.getValue()); + } catch (InvalidProtocolBufferException ex) { + throw new RuntimeException( + "Couldn't parse timestamp proto while pulling data from Redis"); + } + })); + Map results = + Maps.newHashMapWithExpectedSize(byteToFeatureReferenceMap.size()); - for (KeyValue entity : redisHashValues) { - if (entity.hasValue()) { - byte[] redisValueK = entity.getKey(); - byte[] redisValueV = entity.getValue(); + for (Map.Entry entry : redisHashValues.entrySet()) { + ServingAPIProto.FeatureReferenceV2 featureReference = + byteToFeatureReferenceMap.get(ByteBuffer.wrap(entry.getKey())); - // Decode data from Redis into Feature object fields - if (new String(redisValueK).startsWith(timestampPrefix)) { - Timestamp eventTimestamp = Timestamp.parseFrom(redisValueV); - featureTableTimestampMap.put(new String(redisValueK), eventTimestamp); - } else { - ServingAPIProto.FeatureReferenceV2 featureReference = - byteToFeatureReferenceMap.get(redisValueK); - ValueProto.Value featureValue = ValueProto.Value.parseFrom(redisValueV); - - featureMap.put(featureReference, featureValue); - } + if (featureReference == null) { + continue; } - } - // Add timestamp to features - for (Map.Entry entry : - featureMap.entrySet()) { - String timestampRedisHashKeyStr = timestampPrefix + ":" + entry.getKey().getFeatureTable(); - Timestamp curFeatureTimestamp = featureTableTimestampMap.get(timestampRedisHashKeyStr); - - ProtoFeature curFeature = - new ProtoFeature(entry.getKey(), curFeatureTimestamp, entry.getValue()); - allFeatures.add(curFeature); + ValueProto.Value v; + try { + v = ValueProto.Value.parseFrom(entry.getValue()); + } catch (InvalidProtocolBufferException ex) { + throw new RuntimeException( + "Couldn't parse feature value proto while pulling data from Redis"); + } + results.put( + featureReference, + new ProtoFeature( + featureReference, + featureTableTimestampMap.get(featureReference.getFeatureTable()), + v)); } - return allFeatures; + return results; } - public static byte[] getTimestampRedisHashKeyBytes( - ServingAPIProto.FeatureReferenceV2 featureReference, String timestampPrefix) { - String timestampRedisHashKeyStr = timestampPrefix + ":" + featureReference.getFeatureTable(); + public static byte[] getTimestampRedisHashKeyBytes(String featureTable, String timestampPrefix) { + String timestampRedisHashKeyStr = timestampPrefix + ":" + featureTable; return timestampRedisHashKeyStr.getBytes(); } diff --git a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/OnlineRetriever.java b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/OnlineRetriever.java index 4898bcfdab..ab03049b9f 100644 --- a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/OnlineRetriever.java +++ b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/OnlineRetriever.java @@ -17,18 +17,17 @@ package feast.storage.connectors.redis.retriever; import com.google.common.collect.Lists; -import com.google.protobuf.InvalidProtocolBufferException; import feast.proto.serving.ServingAPIProto; import feast.proto.storage.RedisProto; import feast.storage.api.retriever.Feature; import feast.storage.api.retriever.OnlineRetrieverV2; import feast.storage.connectors.redis.common.RedisHashDecoder; import feast.storage.connectors.redis.common.RedisKeyGenerator; -import io.grpc.Status; import io.lettuce.core.KeyValue; -import io.lettuce.core.RedisFuture; +import java.nio.ByteBuffer; import java.util.*; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -40,36 +39,37 @@ public class OnlineRetriever implements OnlineRetrieverV2 { private final RedisClientAdapter redisClientAdapter; private final EntityKeySerializer keySerializer; + // Number of fields in request to Redis which requires using HGETALL instead of HMGET + public static final int HGETALL_NUMBER_OF_FIELDS_THRESHOLD = 50; + public OnlineRetriever(RedisClientAdapter redisClientAdapter, EntityKeySerializer keySerializer) { this.redisClientAdapter = redisClientAdapter; this.keySerializer = keySerializer; } @Override - public List> getOnlineFeatures( + public List> getOnlineFeatures( String project, List entityRows, List featureReferences, List entityNames) { List redisKeys = RedisKeyGenerator.buildRedisKeys(project, entityRows); - List> features = getFeaturesFromRedis(redisKeys, featureReferences); - - return features; + return getFeaturesFromRedis(redisKeys, featureReferences); } - private List> getFeaturesFromRedis( + private List> getFeaturesFromRedis( List redisKeys, List featureReferences) { List> features = new ArrayList<>(); // To decode bytes back to Feature Reference - Map byteToFeatureReferenceMap = new HashMap<>(); + Map byteToFeatureReferenceMap = new HashMap<>(); // Serialize using proto List binaryRedisKeys = redisKeys.stream().map(this.keySerializer::serialize).collect(Collectors.toList()); - List featureReferenceWithTsByteList = new ArrayList<>(); + List retrieveFields = new ArrayList<>(); featureReferences.stream() .forEach( featureReference -> { @@ -77,43 +77,61 @@ private List> getFeaturesFromRedis( // eg. murmur() byte[] featureReferenceBytes = RedisHashDecoder.getFeatureReferenceRedisHashKeyBytes(featureReference); - featureReferenceWithTsByteList.add(featureReferenceBytes); - byteToFeatureReferenceMap.put(featureReferenceBytes, featureReference); + retrieveFields.add(featureReferenceBytes); + byteToFeatureReferenceMap.put( + ByteBuffer.wrap(featureReferenceBytes), featureReference); + }); + featureReferences.stream() + .map(ServingAPIProto.FeatureReferenceV2::getFeatureTable) + .distinct() + .forEach( + table -> { // eg. <_ts:featuretable_name> byte[] featureTableTsBytes = - RedisHashDecoder.getTimestampRedisHashKeyBytes(featureReference, timestampPrefix); - featureReferenceWithTsByteList.add(featureTableTsBytes); + RedisHashDecoder.getTimestampRedisHashKeyBytes(table, timestampPrefix); + + retrieveFields.add(featureTableTsBytes); }); - // Perform a series of independent calls - List>>> futures = Lists.newArrayList(); - for (byte[] binaryRedisKey : binaryRedisKeys) { - byte[][] featureReferenceWithTsByteArrays = - featureReferenceWithTsByteList.toArray(new byte[0][]); - // Access redis keys and extract features - futures.add(redisClientAdapter.hmget(binaryRedisKey, featureReferenceWithTsByteArrays)); + List>> futures = + Lists.newArrayListWithExpectedSize(binaryRedisKeys.size()); + + // Number of fields that controls whether to use hmget or hgetall was discovered empirically + // Could be potentially tuned further + if (retrieveFields.size() < HGETALL_NUMBER_OF_FIELDS_THRESHOLD) { + byte[][] retrieveFieldsByteArray = retrieveFields.toArray(new byte[0][]); + + for (byte[] binaryRedisKey : binaryRedisKeys) { + // Access redis keys and extract features + futures.add( + redisClientAdapter + .hmget(binaryRedisKey, retrieveFieldsByteArray) + .thenApply( + list -> + list.stream() + .filter(KeyValue::hasValue) + .collect(Collectors.toMap(KeyValue::getKey, KeyValue::getValue))) + .toCompletableFuture()); + } + + } else { + for (byte[] binaryRedisKey : binaryRedisKeys) { + futures.add(redisClientAdapter.hgetall(binaryRedisKey)); + } + } + + List> results = + Lists.newArrayListWithExpectedSize(futures.size()); + for (Future> f : futures) { + try { + results.add( + RedisHashDecoder.retrieveFeature(f.get(), byteToFeatureReferenceMap, timestampPrefix)); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Unexpected error when pulling data from Redis"); + } } - // Write all commands to the transport layer - redisClientAdapter.flushCommands(); - - futures.forEach( - future -> { - try { - List> redisValuesList = future.get(); - - List curRedisKeyFeatures = - RedisHashDecoder.retrieveFeature( - redisValuesList, byteToFeatureReferenceMap, timestampPrefix); - features.add(curRedisKeyFeatures); - } catch (InterruptedException | ExecutionException | InvalidProtocolBufferException e) { - throw Status.UNKNOWN - .withDescription("Unexpected error when pulling data from from Redis.") - .withCause(e) - .asRuntimeException(); - } - }); - return features; + return results; } } diff --git a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClient.java b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClient.java index e1699bbde0..ea95ca9ace 100644 --- a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClient.java +++ b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClient.java @@ -23,6 +23,7 @@ import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.codec.ByteArrayCodec; import java.util.List; +import java.util.Map; public class RedisClient implements RedisClientAdapter { @@ -33,6 +34,11 @@ public RedisFuture>> hmget(byte[] key, byte[]... f return asyncCommands.hmget(key, fields); } + @Override + public RedisFuture> hgetall(byte[] key) { + return asyncCommands.hgetall(key); + } + @Override public void flushCommands() { asyncCommands.flushCommands(); @@ -40,9 +46,6 @@ public void flushCommands() { private RedisClient(StatefulRedisConnection connection) { this.asyncCommands = connection.async(); - - // Disable auto-flushing - this.asyncCommands.setAutoFlushCommands(false); } public static RedisClientAdapter create(RedisStoreConfig config) { diff --git a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClientAdapter.java b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClientAdapter.java index 65e730ae93..a2b870af6c 100644 --- a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClientAdapter.java +++ b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClientAdapter.java @@ -18,9 +18,12 @@ import io.lettuce.core.*; import java.util.List; +import java.util.Map; public interface RedisClientAdapter { RedisFuture>> hmget(byte[] key, byte[]... fields); + RedisFuture> hgetall(byte[] key); + void flushCommands(); } diff --git a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterClient.java b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterClient.java index 5395b72e5f..6574962bbb 100644 --- a/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterClient.java +++ b/java/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisClusterClient.java @@ -24,6 +24,7 @@ import io.lettuce.core.codec.ByteArrayCodec; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; public class RedisClusterClient implements RedisClientAdapter { @@ -35,6 +36,11 @@ public RedisFuture>> hmget(byte[] key, byte[]... f return asyncCommands.hmget(key, fields); } + @Override + public RedisFuture> hgetall(byte[] key) { + return asyncCommands.hgetall(key); + } + @Override public void flushCommands() { asyncCommands.flushCommands(); @@ -57,9 +63,6 @@ private RedisClusterClient(Builder builder) { // allows reading from replicas this.asyncCommands.readOnly(); - - // Disable auto-flushing - this.asyncCommands.setAutoFlushCommands(false); } public static RedisClientAdapter create(RedisClusterStoreConfig config) {