Skip to content

Commit

Permalink
TRY-410: add possibility to initialize caches by api-keys on app star…
Browse files Browse the repository at this point in the history
…t up (#6)
  • Loading branch information
ivan-kripakov-m10 authored May 16, 2024
1 parent e71d829 commit e15c749
Show file tree
Hide file tree
Showing 15 changed files with 248 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
import com.exadel.frs.core.trainservice.dto.CacheActionDto.RenameSubjects;
import com.exadel.frs.core.trainservice.service.EmbeddingService;
import com.exadel.frs.core.trainservice.service.NotificationSenderService;
import com.exadel.frs.core.trainservice.util.MaskUtils;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -23,6 +27,7 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.math3.linear.RealVector;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;
Expand All @@ -41,22 +46,48 @@ public class EmbeddingCacheProvider {

private final Lock lock = new ReentrantLock();

@Value("${embeddings.cache.initialization.page-size}")
private int pageSize;

private static final Cache<String, EmbeddingCollection> cache =
CacheBuilder.newBuilder()
.expireAfterAccess(CACHE_EXPIRATION, TimeUnit.SECONDS)
.maximumSize(CACHE_MAXIMUM_SIZE)
.build();

void fillInCache(Collection<String> apiKeys) {
if (apiKeys.size() > CACHE_MAXIMUM_SIZE) {
log.warn("Number of api keys to initialize cache is greater than cache maximum size");
}
apiKeys.stream()
.limit(CACHE_MAXIMUM_SIZE)
.forEach(k -> {
var masked = MaskUtils.maskApiKey(k);
log.debug("Initializing cache for api key: {}", masked);
var start = Instant.now();
getOrLoad(k);
var initDuration = Duration.between(start, Instant.now());
log.info("Cache for api key {} initialized in {}", masked, initDuration);
});
}

public EmbeddingCollection getOrLoad(final String apiKey) {
var result = cache.getIfPresent(apiKey);
if (result == null) {
try {
lock.lock();
result = cache.getIfPresent(apiKey);
if (result == null) {
result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from);
cache.put(apiKey, result);
var embeddingCollection = new EmbeddingCollection();
embeddingService.doWithEnhancedEmbeddingProjections(
apiKey,
embeddingCollection::addEmbedding,
pageSize
);
cache.put(apiKey, embeddingCollection);
return embeddingCollection;
}
return result;
} finally {
lock.unlock();
}
Expand Down Expand Up @@ -110,13 +141,13 @@ public void invalidate(final String apiKey) {
}

/**
* @deprecated
* See {@link com.exadel.frs.core.trainservice.service.NotificationHandler#handleUpdate(CacheActionDto)}
* @deprecated See {@link com.exadel.frs.core.trainservice.service.NotificationHandler#handleUpdate(CacheActionDto)}
*/
@Deprecated(forRemoval = true)
public void receivePutOnCache(String apiKey) {
var result = embeddingService.doWithEnhancedEmbeddingProjectionStream(apiKey, EmbeddingCollection::from);
cache.put(apiKey, result);
var newEmbeddingCollection = new EmbeddingCollection();
embeddingService.doWithEnhancedEmbeddingProjections(apiKey, newEmbeddingCollection::addEmbedding, pageSize);
cache.put(apiKey, newEmbeddingCollection);
}

public void receiveInvalidateCache(final String apiKey) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.exadel.frs.core.trainservice.cache;

import java.util.Arrays;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

@Slf4j
@Component
@RequiredArgsConstructor
public class EmbeddingCacheProviderInitializer implements ApplicationRunner {
private final EmbeddingCacheProvider embeddingCacheProvider;
@Value("${embeddings.cache.initialization.api-keys}")
private String initializationApiKeys;

@Override
public void run(ApplicationArguments args) {
var apiKeys = Optional.ofNullable(initializationApiKeys)
.map(keys -> keys.split(","))
.stream()
.flatMap(Arrays::stream)
.filter(s -> !s.isBlank())
.filter(EmbeddingCacheProviderInitializer::isUuid)
.collect(Collectors.toSet());
embeddingCacheProvider.fillInCache(apiKeys);
}

private static boolean isUuid(String s) {
try {
UUID.fromString(s);
return true;
} catch (IllegalArgumentException e) {
log.warn("Invalid UUID: {}", s, e);
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.exadel.frs.commonservice.entity.EmbeddingProjection;
import com.exadel.frs.commonservice.entity.EnhancedEmbeddingProjection;
import com.exadel.frs.commonservice.exception.IncorrectImageIdException;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Map.Entry;
Expand All @@ -14,22 +15,25 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealVector;
import org.springframework.data.util.Pair;
import org.springframework.lang.NonNull;

@AllArgsConstructor(access = AccessLevel.PRIVATE)
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public class EmbeddingCollection {

private final ConcurrentMap<String, Map<UUID, RealVector>> mapping;

public static EmbeddingCollection from(final Stream<EnhancedEmbeddingProjection> stream) {
public EmbeddingCollection() {
this.mapping = new ConcurrentHashMap<>();
}

public static EmbeddingCollection from(final Collection<EnhancedEmbeddingProjection> projections) {
// we copy vector here just in case
var newMap = stream.map(e -> Map.entry(e.getSubjectName(), Pair.of(e.getEmbeddingId(), MatrixUtils.createRealVector(e.getEmbeddingData()))))
var newMap = projections.stream().map(e -> Map.entry(e.getSubjectName(), Pair.of(e.getEmbeddingId(), MatrixUtils.createRealVector(e.getEmbeddingData()))))
.collect(
Collectors.toConcurrentMap(
Entry::getKey,
Expand Down Expand Up @@ -68,6 +72,11 @@ public EmbeddingProjection addEmbedding(final Embedding embedding) {
return new EmbeddingProjection(id, embedding.getSubject().getSubjectName());
}

public void addEmbedding(final EnhancedEmbeddingProjection projection) {
mapping.computeIfAbsent(projection.getSubjectName(), k -> new ConcurrentHashMap<>())
.put(projection.getEmbeddingId(), MatrixUtils.createRealVector(projection.getEmbeddingData()));
}

public void removeEmbeddingsBySubjectName(String subjectName) {
mapping.remove(subjectName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import com.exadel.frs.commonservice.repository.ImgRepository;
import com.exadel.frs.core.trainservice.system.global.Constants;
import java.util.Map;
import java.util.stream.Stream;
import java.util.function.Consumer;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;
Expand All @@ -18,8 +19,8 @@
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;

@Slf4j
@Service
@RequiredArgsConstructor
public class EmbeddingService {
Expand Down Expand Up @@ -47,10 +48,17 @@ public void updateEmbedding(UUID embeddingId, Map<String, String> imageAttribute
});
}

@org.springframework.transaction.annotation.Transactional(readOnly = true)
public <T> T doWithEnhancedEmbeddingProjectionStream(String apiKey, Function<Stream<EnhancedEmbeddingProjection>, T> func) {
try (var stream = embeddingRepository.findBySubjectApiKey(apiKey)) {
return func.apply(stream);

public void doWithEnhancedEmbeddingProjections(
String apiKey,
Consumer<EnhancedEmbeddingProjection> func,
int pageSize
) {
var page = embeddingRepository.findEnhancedBySubjectApiKey(apiKey, Pageable.ofSize(pageSize));
page.forEach(func);
while (page.hasNext()) {
page = embeddingRepository.findEnhancedBySubjectApiKey(apiKey, page.nextPageable());
page.forEach(func);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.exadel.frs.core.trainservice.util;

import org.springframework.lang.NonNull;

public final class MaskUtils {
private MaskUtils() {
// NOOP
}

public static String maskApiKey(@NonNull String apiKey) {
var maskedApiKey = new StringBuilder();
for (int i = 0; i < apiKey.length(); i++) {
if (i < 4 || apiKey.charAt(i) == '-') {
maskedApiKey.append(apiKey.charAt(i));
} else {
maskedApiKey.append('*');
}
}
return maskedApiKey.toString();
}
}
6 changes: 6 additions & 0 deletions java/api/src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ app:
retryer:
max-attempts: ${MAX_ATTEMPTS:1}

embeddings:
cache:
initialization:
api-keys: ${EMBEDDINGS_CACHE_INITIALIZATION_API_KEYS:}
page-size: ${EMBEDDINGS_CACHE_INITIALIZATION_PAGE_SIZE:100000}

---

spring:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,8 @@
import io.zonky.test.db.AutoConfigureEmbeddedDatabase;
import io.zonky.test.db.AutoConfigureEmbeddedDatabase.DatabaseProvider;
import io.zonky.test.db.AutoConfigureEmbeddedDatabase.DatabaseType;
import javax.annotation.PostConstruct;
import javax.sql.DataSource;
import liquibase.Contexts;
import liquibase.LabelExpression;
import liquibase.Liquibase;
import liquibase.database.DatabaseFactory;
import liquibase.database.jvm.JdbcConnection;
import liquibase.integration.spring.SpringResourceAccessor;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ResourceLoader;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit.jupiter.SpringExtension;

Expand All @@ -33,41 +22,4 @@ public class EmbeddedPostgreSQLTest {

@MockBean
NotificationReceiverService notificationReceiverService;

@Autowired
DataSource dataSource;

@Autowired
ResourceLoader resourceLoader;

@Autowired
private Environment env;

@PostConstruct
public void initDatabase() {
try {
Liquibase liquibase = new Liquibase(
"db/changelog/db.changelog-master.yaml",
new SpringResourceAccessor(resourceLoader),
DatabaseFactory.getInstance().findCorrectDatabaseImplementation(new JdbcConnection(dataSource.getConnection()))
);
setLiquibaseChangeLogParams(liquibase);
liquibase.update(new Contexts(), new LabelExpression());
} catch (Exception e) {
//manage exception
e.printStackTrace();
}
}

private void setLiquibaseChangeLogParams(final Liquibase liquibase) {
String clientId = env.getProperty("spring.liquibase.parameters.common-client.client-id", "CommonClientId");
String accessTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.access-token-validity", "2400");
String refreshTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.refresh-token-validity", "1209600");
String authorizedGrantTypes = env.getProperty("spring.liquibase.parameters.common-client.authorized-grant-types", "password,refresh_token");

liquibase.setChangeLogParameter("common-client.client-id", clientId);
liquibase.setChangeLogParameter("common-client.access-token-validity", accessTokenValidity);
liquibase.setChangeLogParameter("common-client.refresh-token-validity", refreshTokenValidity);
liquibase.setChangeLogParameter("common-client.authorized-grant-types", authorizedGrantTypes);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.exadel.frs.core.trainservice;

import javax.annotation.PostConstruct;
import javax.sql.DataSource;
import liquibase.Contexts;
import liquibase.LabelExpression;
import liquibase.Liquibase;
import liquibase.database.DatabaseFactory;
import liquibase.database.jvm.JdbcConnection;
import liquibase.integration.spring.SpringResourceAccessor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ResourceLoader;
import org.springframework.stereotype.Service;

@Service
public class TestLiquibase {
@Autowired
ResourceLoader resourceLoader;

@Autowired
private Environment env;

@Autowired
DataSource dataSource;

@PostConstruct
public void initDatabase() {
try {
Liquibase liquibase = new Liquibase(
"db/changelog/db.changelog-master.yaml",
new SpringResourceAccessor(resourceLoader),
DatabaseFactory.getInstance().findCorrectDatabaseImplementation(new JdbcConnection(dataSource.getConnection()))
);
setLiquibaseChangeLogParams(liquibase);
liquibase.update(new Contexts(), new LabelExpression());
} catch (Exception e) {
//manage exception
e.printStackTrace();
}
}

private void setLiquibaseChangeLogParams(final Liquibase liquibase) {
String clientId = env.getProperty("spring.liquibase.parameters.common-client.client-id", "CommonClientId");
String accessTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.access-token-validity", "2400");
String refreshTokenValidity = env.getProperty("spring.liquibase.parameters.common-client.refresh-token-validity", "1209600");
String authorizedGrantTypes = env.getProperty("spring.liquibase.parameters.common-client.authorized-grant-types", "password,refresh_token");

liquibase.setChangeLogParameter("common-client.client-id", clientId);
liquibase.setChangeLogParameter("common-client.access-token-validity", accessTokenValidity);
liquibase.setChangeLogParameter("common-client.refresh-token-validity", refreshTokenValidity);
liquibase.setChangeLogParameter("common-client.authorized-grant-types", authorizedGrantTypes);
}
}
Loading

0 comments on commit e15c749

Please sign in to comment.