From e15c749b1f3952c4a6bc96e653d62683ae093504 Mon Sep 17 00:00:00 2001 From: Ivan Kripakov <108407979+ivan-kripakov-m10@users.noreply.github.com> Date: Thu, 16 May 2024 15:37:37 +0400 Subject: [PATCH] TRY-410: add possibility to initialize caches by api-keys on app start up (#6) --- .../cache/EmbeddingCacheProvider.java | 43 ++++++++++++--- .../EmbeddingCacheProviderInitializer.java | 43 +++++++++++++++ .../cache/EmbeddingCollection.java | 19 +++++-- .../service/EmbeddingService.java | 20 ++++--- .../frs/core/trainservice/util/MaskUtils.java | 21 ++++++++ java/api/src/main/resources/application.yml | 6 +++ .../trainservice/EmbeddedPostgreSQLTest.java | 48 ----------------- .../frs/core/trainservice/TestLiquibase.java | 54 +++++++++++++++++++ .../cache/EmbeddingCacheProviderTest.java | 28 +++++----- .../cache/EmbeddingCollectionTest.java | 12 ++--- .../EuclideanDistanceClassifierTest.java | 5 +- .../service/SubjectServiceTest.java | 2 +- .../core/trainservice/util/MaskUtilsTest.java | 20 +++++++ java/api/src/test/resources/application.yml | 6 +++ .../repository/EmbeddingRepository.java | 27 ++++------ 15 files changed, 248 insertions(+), 106 deletions(-) create mode 100644 java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderInitializer.java create mode 100644 java/api/src/main/java/com/exadel/frs/core/trainservice/util/MaskUtils.java create mode 100644 java/api/src/test/java/com/exadel/frs/core/trainservice/TestLiquibase.java create mode 100644 java/api/src/test/java/com/exadel/frs/core/trainservice/util/MaskUtilsTest.java diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java index cd339a54e2..41dddf68fe 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProvider.java @@ -10,8 +10,12 @@ import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects; import com.exadel.frs.core.trainservice.service.EmbeddingService; import com.exadel.frs.core.trainservice.service.NotificationSenderService; +import com.exadel.frs.core.trainservice.util.MaskUtils; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; +import java.time.Duration; +import java.time.Instant; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; @@ -23,6 +27,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.math3.linear.RealVector; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID; @@ -41,12 +46,31 @@ public class EmbeddingCacheProvider { private final Lock lock = new ReentrantLock(); + @Value("${embeddings.cache.initialization.page-size}") + private int pageSize; + private static final Cache cache = CacheBuilder.newBuilder() .expireAfterAccess(CACHE_EXPIRATION, TimeUnit.SECONDS) .maximumSize(CACHE_MAXIMUM_SIZE) .build(); + void fillInCache(Collection apiKeys) { + if (apiKeys.size() > CACHE_MAXIMUM_SIZE) { + log.warn("Number of api keys to initialize cache is greater than cache maximum size"); + } + apiKeys.stream() + .limit(CACHE_MAXIMUM_SIZE) + .forEach(k -> { + var masked = MaskUtils.maskApiKey(k); + log.debug("Initializing cache for api key: {}", masked); + var start = Instant.now(); + getOrLoad(k); + var initDuration = Duration.between(start, Instant.now()); + log.info("Cache for api key {} initialized in {}", masked, initDuration); + }); + } + public EmbeddingCollection getOrLoad(final String apiKey) { var result = cache.getIfPresent(apiKey); if (result == null) { @@ -54,9 +78,16 @@ public EmbeddingCollection getOrLoad(final String apiKey) { lock.lock(); result = cache.getIfPresent(apiKey); if (result == null) { - result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from); - cache.put(apiKey, result); + var embeddingCollection = new EmbeddingCollection(); + embeddingService.doWithEnhancedEmbeddingProjections( + apiKey, + embeddingCollection::addEmbedding, + pageSize + ); + cache.put(apiKey, embeddingCollection); + return embeddingCollection; } + return result; } finally { lock.unlock(); } @@ -110,13 +141,13 @@ public void invalidate(final String apiKey) { } /** - * @deprecated - * See {@link com.exadel.frs.core.trainservice.service.NotificationHandler#handleUpdate(CacheActionDto)} + * @deprecated See {@link com.exadel.frs.core.trainservice.service.NotificationHandler#handleUpdate(CacheActionDto)} */ @Deprecated(forRemoval = true) public void receivePutOnCache(String apiKey) { - var result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from); - cache.put(apiKey, result); + var newEmbeddingCollection = new EmbeddingCollection(); + embeddingService.doWithEnhancedEmbeddingProjections(apiKey, newEmbeddingCollection::addEmbedding, pageSize); + cache.put(apiKey, newEmbeddingCollection); } public void receiveInvalidateCache(final String apiKey) { diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderInitializer.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderInitializer.java new file mode 100644 index 0000000000..2e233388b5 --- /dev/null +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderInitializer.java @@ -0,0 +1,43 @@ +package com.exadel.frs.core.trainservice.cache; + +import java.util.Arrays; +import java.util.Optional; +import java.util.UUID; +import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.ApplicationArguments; +import org.springframework.boot.ApplicationRunner; +import org.springframework.stereotype.Component; + +@Slf4j +@Component +@RequiredArgsConstructor +public class EmbeddingCacheProviderInitializer implements ApplicationRunner { + private final EmbeddingCacheProvider embeddingCacheProvider; + @Value("${embeddings.cache.initialization.api-keys}") + private String initializationApiKeys; + + @Override + public void run(ApplicationArguments args) { + var apiKeys = Optional.ofNullable(initializationApiKeys) + .map(keys -> keys.split(",")) + .stream() + .flatMap(Arrays::stream) + .filter(s -> !s.isBlank()) + .filter(EmbeddingCacheProviderInitializer::isUuid) + .collect(Collectors.toSet()); + embeddingCacheProvider.fillInCache(apiKeys); + } + + private static boolean isUuid(String s) { + try { + UUID.fromString(s); + return true; + } catch (IllegalArgumentException e) { + log.warn("Invalid UUID: {}", s, e); + return false; + } + } +} diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollection.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollection.java index 7d678675ba..6f061a22f7 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollection.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollection.java @@ -4,6 +4,7 @@ import com.exadel.frs.commonservice.entity.EmbeddingProjection; import com.exadel.frs.commonservice.entity.EnhancedEmbeddingProjection; import com.exadel.frs.commonservice.exception.IncorrectImageIdException; +import java.util.Collection; import java.util.Collections; import java.util.Map; import java.util.Map.Entry; @@ -14,22 +15,25 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.stream.Collectors; -import java.util.stream.Stream; import lombok.AccessLevel; -import lombok.AllArgsConstructor; +import lombok.RequiredArgsConstructor; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealVector; import org.springframework.data.util.Pair; import org.springframework.lang.NonNull; -@AllArgsConstructor(access = AccessLevel.PRIVATE) +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) public class EmbeddingCollection { private final ConcurrentMap> mapping; - public static EmbeddingCollection from(final Stream stream) { + public EmbeddingCollection() { + this.mapping = new ConcurrentHashMap<>(); + } + + public static EmbeddingCollection from(final Collection projections) { // we copy vector here just in case - var newMap = stream.map(e -> Map.entry(e.getSubjectName(), Pair.of(e.getEmbeddingId(), MatrixUtils.createRealVector(e.getEmbeddingData())))) + var newMap = projections.stream().map(e -> Map.entry(e.getSubjectName(), Pair.of(e.getEmbeddingId(), MatrixUtils.createRealVector(e.getEmbeddingData())))) .collect( Collectors.toConcurrentMap( Entry::getKey, @@ -68,6 +72,11 @@ public EmbeddingProjection addEmbedding(final Embedding embedding) { return new EmbeddingProjection(id, embedding.getSubject().getSubjectName()); } + public void addEmbedding(final EnhancedEmbeddingProjection projection) { + mapping.computeIfAbsent(projection.getSubjectName(), k -> new ConcurrentHashMap<>()) + .put(projection.getEmbeddingId(), MatrixUtils.createRealVector(projection.getEmbeddingData())); + } + public void removeEmbeddingsBySubjectName(String subjectName) { mapping.remove(subjectName); } diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java index ff2e36b18b..a2aad04941 100644 --- a/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/service/EmbeddingService.java @@ -8,8 +8,9 @@ import com.exadel.frs.commonservice.repository.ImgRepository; import com.exadel.frs.core.trainservice.system.global.Constants; import java.util.Map; -import java.util.stream.Stream; +import java.util.function.Consumer; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.stereotype.Service; @@ -18,8 +19,8 @@ import java.util.List; import java.util.Optional; import java.util.UUID; -import java.util.function.Function; +@Slf4j @Service @RequiredArgsConstructor public class EmbeddingService { @@ -47,10 +48,17 @@ public void updateEmbedding(UUID embeddingId, Map imageAttribute }); } - @org.springframework.transaction.annotation.Transactional(readOnly = true) - public T doWithEnhancedEmbeddingProjectionStream(String apiKey, Function, T> func) { - try (var stream = embeddingRepository.findBySubjectApiKey(apiKey)) { - return func.apply(stream); + + public void doWithEnhancedEmbeddingProjections( + String apiKey, + Consumer func, + int pageSize + ) { + var page = embeddingRepository.findEnhancedBySubjectApiKey(apiKey, Pageable.ofSize(pageSize)); + page.forEach(func); + while (page.hasNext()) { + page = embeddingRepository.findEnhancedBySubjectApiKey(apiKey, page.nextPageable()); + page.forEach(func); } } diff --git a/java/api/src/main/java/com/exadel/frs/core/trainservice/util/MaskUtils.java b/java/api/src/main/java/com/exadel/frs/core/trainservice/util/MaskUtils.java new file mode 100644 index 0000000000..62d470f6e6 --- /dev/null +++ b/java/api/src/main/java/com/exadel/frs/core/trainservice/util/MaskUtils.java @@ -0,0 +1,21 @@ +package com.exadel.frs.core.trainservice.util; + +import org.springframework.lang.NonNull; + +public final class MaskUtils { + private MaskUtils() { + // NOOP + } + + public static String maskApiKey(@NonNull String apiKey) { + var maskedApiKey = new StringBuilder(); + for (int i = 0; i < apiKey.length(); i++) { + if (i < 4 || apiKey.charAt(i) == '-') { + maskedApiKey.append(apiKey.charAt(i)); + } else { + maskedApiKey.append('*'); + } + } + return maskedApiKey.toString(); + } +} diff --git a/java/api/src/main/resources/application.yml b/java/api/src/main/resources/application.yml index 7c282968b3..d5a83df766 100644 --- a/java/api/src/main/resources/application.yml +++ b/java/api/src/main/resources/application.yml @@ -88,6 +88,12 @@ app: retryer: max-attempts: ${MAX_ATTEMPTS:1} +embeddings: + cache: + initialization: + api-keys: ${EMBEDDINGS_CACHE_INITIALIZATION_API_KEYS:} + page-size: ${EMBEDDINGS_CACHE_INITIALIZATION_PAGE_SIZE:100000} + --- spring: diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/EmbeddedPostgreSQLTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/EmbeddedPostgreSQLTest.java index e1287666f9..21ca7b2fc3 100644 --- a/java/api/src/test/java/com/exadel/frs/core/trainservice/EmbeddedPostgreSQLTest.java +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/EmbeddedPostgreSQLTest.java @@ -6,19 +6,8 @@ import io.zonky.test.db.AutoConfigureEmbeddedDatabase; import io.zonky.test.db.AutoConfigureEmbeddedDatabase.DatabaseProvider; import io.zonky.test.db.AutoConfigureEmbeddedDatabase.DatabaseType; -import javax.annotation.PostConstruct; -import javax.sql.DataSource; -import liquibase.Contexts; -import liquibase.LabelExpression; -import liquibase.Liquibase; -import liquibase.database.DatabaseFactory; -import liquibase.database.jvm.JdbcConnection; -import liquibase.integration.spring.SpringResourceAccessor; import org.junit.jupiter.api.extension.ExtendWith; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.mock.mockito.MockBean; -import org.springframework.core.env.Environment; -import org.springframework.core.io.ResourceLoader; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -33,41 +22,4 @@ public class EmbeddedPostgreSQLTest { @MockBean NotificationReceiverService notificationReceiverService; - - @Autowired - DataSource dataSource; - - @Autowired - ResourceLoader resourceLoader; - - @Autowired - private Environment env; - - @PostConstruct - public void initDatabase() { - try { - Liquibase liquibase = new Liquibase( - "db/changelog/db.changelog-master.yaml", - new SpringResourceAccessor(resourceLoader), - DatabaseFactory.getInstance().findCorrectDatabaseImplementation(new JdbcConnection(dataSource.getConnection())) - ); - setLiquibaseChangeLogParams(liquibase); - liquibase.update(new Contexts(), new LabelExpression()); - } catch (Exception e) { - //manage exception - e.printStackTrace(); - } - } - - private void setLiquibaseChangeLogParams(final Liquibase liquibase) { - String clientId = env.getProperty("spring.liquibase.parameters.common-client.client-id", "CommonClientId"); - String accessTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.access-token-validity", "2400"); - String refreshTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.refresh-token-validity", "1209600"); - String authorizedGrantTypes = env.getProperty("spring.liquibase.parameters.common-client.authorized-grant-types", "password,refresh_token"); - - liquibase.setChangeLogParameter("common-client.client-id", clientId); - liquibase.setChangeLogParameter("common-client.access-token-validity", accessTokenValidity); - liquibase.setChangeLogParameter("common-client.refresh-token-validity", refreshTokenValidity); - liquibase.setChangeLogParameter("common-client.authorized-grant-types", authorizedGrantTypes); - } } diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/TestLiquibase.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/TestLiquibase.java new file mode 100644 index 0000000000..6bf6756210 --- /dev/null +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/TestLiquibase.java @@ -0,0 +1,54 @@ +package com.exadel.frs.core.trainservice; + +import javax.annotation.PostConstruct; +import javax.sql.DataSource; +import liquibase.Contexts; +import liquibase.LabelExpression; +import liquibase.Liquibase; +import liquibase.database.DatabaseFactory; +import liquibase.database.jvm.JdbcConnection; +import liquibase.integration.spring.SpringResourceAccessor; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.env.Environment; +import org.springframework.core.io.ResourceLoader; +import org.springframework.stereotype.Service; + +@Service +public class TestLiquibase { + @Autowired + ResourceLoader resourceLoader; + + @Autowired + private Environment env; + + @Autowired + DataSource dataSource; + + @PostConstruct + public void initDatabase() { + try { + Liquibase liquibase = new Liquibase( + "db/changelog/db.changelog-master.yaml", + new SpringResourceAccessor(resourceLoader), + DatabaseFactory.getInstance().findCorrectDatabaseImplementation(new JdbcConnection(dataSource.getConnection())) + ); + setLiquibaseChangeLogParams(liquibase); + liquibase.update(new Contexts(), new LabelExpression()); + } catch (Exception e) { + //manage exception + e.printStackTrace(); + } + } + + private void setLiquibaseChangeLogParams(final Liquibase liquibase) { + String clientId = env.getProperty("spring.liquibase.parameters.common-client.client-id", "CommonClientId"); + String accessTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.access-token-validity", "2400"); + String refreshTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.refresh-token-validity", "1209600"); + String authorizedGrantTypes = env.getProperty("spring.liquibase.parameters.common-client.authorized-grant-types", "password,refresh_token"); + + liquibase.setChangeLogParameter("common-client.client-id", clientId); + liquibase.setChangeLogParameter("common-client.access-token-validity", accessTokenValidity); + liquibase.setChangeLogParameter("common-client.refresh-token-validity", refreshTokenValidity); + liquibase.setChangeLogParameter("common-client.authorized-grant-types", authorizedGrantTypes); + } +} diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java index 4d390e4bc6..037e706df6 100644 --- a/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCacheProviderTest.java @@ -30,18 +30,20 @@ import com.exadel.frs.core.trainservice.service.NotificationReceiverService; import com.exadel.frs.core.trainservice.service.NotificationSenderService; import com.exadel.frs.core.trainservice.system.global.Constants; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.UUID; -import java.util.function.Function; -import java.util.stream.Stream; +import java.util.function.Consumer; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.testcontainers.shaded.org.apache.commons.lang3.reflect.FieldUtils; import static com.exadel.frs.core.trainservice.ItemsBuilder.makeEnhancedEmbeddingProjection; import static org.hamcrest.MatcherAssert.assertThat; @@ -52,7 +54,6 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class EmbeddingCacheProviderTest { @@ -78,13 +79,10 @@ class EmbeddingCacheProviderTest { @BeforeEach @SuppressWarnings("unchecked") - public void resetStaticCache() { + public void resetStaticCache() throws IllegalAccessException { + FieldUtils.writeField(embeddingCacheProvider, "pageSize", 10, true); embeddingCacheProvider.invalidate(API_KEY); - when(embeddingService.doWithEnhancedEmbeddingProjectionStream(eq(API_KEY), any())) - .thenAnswer(invocation -> { - var function = (Function, ?>) invocation.getArgument(1); - return function.apply(Stream.of()); - }); + Mockito.doNothing().when(embeddingService).doWithEnhancedEmbeddingProjections(eq(API_KEY), any(), eq(10)); } @Test @@ -97,11 +95,13 @@ void getOrLoad() { makeEnhancedEmbeddingProjection("C") }; - when(embeddingService.doWithEnhancedEmbeddingProjectionStream(eq(API_KEY), any())) - .thenAnswer(invocation -> { - var function = (Function, ?>) invocation.getArgument(1); - return function.apply(Stream.of(projections)); - }); + Mockito.doAnswer( + invocation -> { + var function = (Consumer) invocation.getArgument(1); + Arrays.stream(projections).forEach(function); + return null; + } + ).when(embeddingService).doWithEnhancedEmbeddingProjections(eq(API_KEY), any(), eq(10)); var actual = embeddingCacheProvider.getOrLoad(API_KEY); diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollectionTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollectionTest.java index 58d8ef6512..84f0f36409 100644 --- a/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollectionTest.java +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/cache/EmbeddingCollectionTest.java @@ -18,8 +18,8 @@ import com.exadel.frs.commonservice.entity.EmbeddingProjection; import com.exadel.frs.commonservice.entity.EnhancedEmbeddingProjection; +import java.util.Arrays; import java.util.UUID; -import java.util.stream.Stream; import org.apache.commons.math3.linear.MatrixUtils; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; @@ -34,7 +34,7 @@ class EmbeddingCollectionTest { @Test void testRemoveFromEmpty() { - var embeddingCollection = EmbeddingCollection.from(Stream.of()); + var embeddingCollection = new EmbeddingCollection(); final EmbeddingProjection removed = embeddingCollection.removeEmbedding(new EmbeddingProjection(UUID.randomUUID(), "subject_name")); assertThat(removed).isNull(); @@ -42,7 +42,7 @@ void testRemoveFromEmpty() { @Test void testAddToEmpty() { - var embeddingCollection = EmbeddingCollection.from(Stream.of()); + var embeddingCollection = new EmbeddingCollection(); Assertions.assertThat(embeddingCollection.exposeMap()).isEmpty(); var embedding = makeEmbedding("A", API_KEY); @@ -60,7 +60,7 @@ void testCreate() { var projection2 = makeEnhancedEmbeddingProjection("B"); var projection3 = makeEnhancedEmbeddingProjection("C"); var projections = new EnhancedEmbeddingProjection[]{projection1, projection2, projection3}; - var embeddingCollection = EmbeddingCollection.from(Stream.of(projections)); + var embeddingCollection = EmbeddingCollection.from(Arrays.asList(projections)); assertThat(embeddingCollection.exposeMap()).hasSize(projections.length); assertThat(embeddingCollection.exposeMap().get("A")) @@ -78,7 +78,7 @@ void testAdd() { makeEnhancedEmbeddingProjection("B"), makeEnhancedEmbeddingProjection("C") }; - var embeddingCollection = EmbeddingCollection.from(Stream.of(projections)); + var embeddingCollection = EmbeddingCollection.from(Arrays.asList(projections)); var newEmbedding = makeEmbedding("D", API_KEY); newEmbedding.setId(UUID.randomUUID()); @@ -96,7 +96,7 @@ void testRemove() { var projection2 = makeEnhancedEmbeddingProjection("B"); var projection3 = makeEnhancedEmbeddingProjection("C"); var projections = new EnhancedEmbeddingProjection[]{projection1, projection2, projection3}; - var embeddingCollection = EmbeddingCollection.from(Stream.of(projections)); + var embeddingCollection = EmbeddingCollection.from(Arrays.asList(projections)); embeddingCollection.removeEmbedding(EmbeddingProjection.from(projection1)); diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/component/classifiers/EuclideanDistanceClassifierTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/component/classifiers/EuclideanDistanceClassifierTest.java index 87f6c39bd6..a8be35325e 100644 --- a/java/api/src/test/java/com/exadel/frs/core/trainservice/component/classifiers/EuclideanDistanceClassifierTest.java +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/component/classifiers/EuclideanDistanceClassifierTest.java @@ -9,7 +9,6 @@ import java.util.Map; import java.util.Optional; import java.util.UUID; -import java.util.stream.Stream; import org.apache.commons.math3.linear.MatrixUtils; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; @@ -42,7 +41,7 @@ void predict() { // arrange Mockito.when(provider.getOrLoad(Mockito.anyString())).thenReturn( EmbeddingCollection.from( - Stream.of( + List.of( EMBEDDING_1, EMBEDDING_2, EMBEDDING_3 @@ -67,7 +66,7 @@ void verify() { // arrange Mockito.when(provider.getOrLoad(Mockito.anyString())).thenReturn( EmbeddingCollection.from( - Stream.of( + List.of( EMBEDDING_1, EMBEDDING_2, EMBEDDING_3 diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java index 069a1aa892..677166d7ab 100644 --- a/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java @@ -260,7 +260,7 @@ void testInvalidImageIdException(boolean status){ var detProbThreshold = 0.7; var randomUUId = UUID.randomUUID(); var file = new MockMultipartFile("anyname", new byte[]{0xA}); - var embeddingCollection = EmbeddingCollection.from(Stream.of( + var embeddingCollection = EmbeddingCollection.from(List.of( makeEnhancedEmbeddingProjection("A"), makeEnhancedEmbeddingProjection("B"))); diff --git a/java/api/src/test/java/com/exadel/frs/core/trainservice/util/MaskUtilsTest.java b/java/api/src/test/java/com/exadel/frs/core/trainservice/util/MaskUtilsTest.java new file mode 100644 index 0000000000..c7ec9159bb --- /dev/null +++ b/java/api/src/test/java/com/exadel/frs/core/trainservice/util/MaskUtilsTest.java @@ -0,0 +1,20 @@ +package com.exadel.frs.core.trainservice.util; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class MaskUtilsTest { + + @Test + void maskApiKey() { + // arrange + var apiKey = "2abca51f-af81-43a8-81de-0dfc75e94b23"; + + // act + var masked = MaskUtils.maskApiKey(apiKey); + + // assert + assertEquals("2abc****-****-****-****-************", masked); + } +} diff --git a/java/api/src/test/resources/application.yml b/java/api/src/test/resources/application.yml index 2920b20c4a..ca6a97f498 100644 --- a/java/api/src/test/resources/application.yml +++ b/java/api/src/test/resources/application.yml @@ -48,6 +48,12 @@ app: retryer: max-attempts: ${MAX_ATTEMPTS:1} +embeddings: + cache: + initialization: + api-keys: ${EMBEDDINGS_CACHE_INITIALIZATION_API_KEYS:123} + page-size: ${EMBEDDINGS_CACHE_INITIALIZATION_PAGE_SIZE:100000} + statistic: model: cron-expression: ${MODEL_STATISTIC_CRON_EXPRESSION:0 0 * ? * *} diff --git a/java/common/src/main/java/com/exadel/frs/commonservice/repository/EmbeddingRepository.java b/java/common/src/main/java/com/exadel/frs/commonservice/repository/EmbeddingRepository.java index ca0b7f81fe..90bc5a304f 100644 --- a/java/common/src/main/java/com/exadel/frs/commonservice/repository/EmbeddingRepository.java +++ b/java/common/src/main/java/com/exadel/frs/commonservice/repository/EmbeddingRepository.java @@ -6,15 +6,12 @@ import com.exadel.frs.commonservice.entity.Subject; import java.util.List; import java.util.UUID; -import java.util.stream.Stream; -import javax.persistence.QueryHint; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.jpa.repository.EntityGraph; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Modifying; import org.springframework.data.jpa.repository.Query; -import org.springframework.data.jpa.repository.QueryHints; import org.springframework.data.repository.query.Param; import org.springframework.lang.NonNull; @@ -23,20 +20,6 @@ public interface EmbeddingRepository extends JpaRepository { @Query("SELECT e FROM Embedding e JOIN FETCH e.subject WHERE e.id in :ids") List findByIdIn(@Param("ids") @NonNull Iterable ids); - // Note: consumer should consume in transaction - @Query("select " + - " new com.exadel.frs.commonservice.entity.EnhancedEmbeddingProjection(e.id, e.embedding, s.subjectName)" + - " from " + - " Embedding e " + - " left join " + - " e.subject s " + - " where " + - " s.apiKey = :apiKey") - @QueryHints( - @QueryHint(name = org.hibernate.jpa.QueryHints.HINT_FETCH_SIZE, value = "10000") - ) - Stream findBySubjectApiKey(@Param("apiKey") String apiKey); - @EntityGraph("embedding-with-subject") List findBySubjectId(UUID subjectId); @@ -69,6 +52,16 @@ int updateEmbedding(@Param("embeddingId") UUID embeddingId, " e.subject.apiKey = :apiKey") Page findBySubjectApiKey(String apiKey, Pageable pageable); + @Query("select " + + " new com.exadel.frs.commonservice.entity.EnhancedEmbeddingProjection(e.id, e.embedding, s.subjectName)" + + " from " + + " Embedding e " + + " left join " + + " e.subject s " + + " where " + + " s.apiKey = :apiKey") + Page findEnhancedBySubjectApiKey(String apiKey, Pageable pageable); + @Query("select " + " new com.exadel.frs.commonservice.entity.EmbeddingProjection(e.id, e.subject.subjectName)" + " from " +