Skip to content

Commit

Permalink
TRY-198: use apache-commons-math3 lib & new API for subject-based fac…
Browse files Browse the repository at this point in the history
…e verification (#3)
  • Loading branch information
ivan-kripakov-m10 authored Jan 23, 2024
1 parent 98e6177 commit 6c1b106
Show file tree
Hide file tree
Showing 21 changed files with 778 additions and 273 deletions.
8 changes: 4 additions & 4 deletions dev/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM maven:3.6.3-jdk-11-slim as build
FROM maven:3.8.2-eclipse-temurin-17 AS build
ARG ND4J_CLASSIFIER
WORKDIR /workspace/compreface
LABEL intermidiate_frs=true
Expand All @@ -12,14 +12,14 @@ COPY admin admin
COPY common common
RUN mvn package -Dmaven.test.skip=true -Dmaven.site.skip=true -Dmaven.javadoc.skip=true -Dnd4j.classifier=$ND4J_CLASSIFIER

FROM openjdk:11.0.8-jre-slim as frs_core
FROM eclipse-temurin:17-jre-focal AS frs_core
ARG DIR=/workspace/compreface
COPY --from=build ${DIR}/api/target/*.jar /home/app.jar
ENTRYPOINT ["sh","-c","java $API_JAVA_OPTS -agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:5005 -jar /home/app.jar"]

FROM openjdk:11.0.8-jre-slim as frs_crud
FROM eclipse-temurin:17-jre-focal AS frs_crud
ARG DIR=/workspace/compreface
COPY --from=build ${DIR}/admin/target/*.jar /home/app.jar
ARG APPERY_API_KEY
ENV APPERY_API_KEY ${APPERY_API_KEY}
ENTRYPOINT ["sh","-c","java $ADMIN_JAVA_OPTS -agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:5005 -jar /home/app.jar"]
ENTRYPOINT ["sh","-c","java $ADMIN_JAVA_OPTS -agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:5005 -jar /home/app.jar"]
14 changes: 9 additions & 5 deletions java/api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,12 @@
<artifactId>hibernate-types-52</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<classifier>${nd4j.classifier}</classifier>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
</dependency>
<dependency>
<groupId>com.impossibl.pgjdbc-ng</groupId>
Expand All @@ -163,6 +162,11 @@
<artifactId>embedded-database-spring-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.zonky.test</groupId>
<artifactId>embedded-postgres</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
import com.google.common.cache.CacheBuilder;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.math3.linear.RealVector;
import org.springframework.stereotype.Component;

import static com.exadel.frs.core.trainservice.system.global.Constants.SERVER_UUID;
Expand Down Expand Up @@ -86,6 +89,10 @@ public void addEmbedding(String apiKey, Embedding embedding) {
notifyCacheEvent(CacheAction.ADD_EMBEDDINGS, apiKey, new AddEmbeddings(List.of(embedding.getId())));
}

public Optional<Map<UUID, RealVector>> getEmbeddings(String apiKey, String subjectName) {
return getOrLoad(apiKey).getEmbeddingsBySubjectName(subjectName);
}

/**
* Method can be used to make changes in cache without sending notification.
* Use it carefully, because changes you do will not be visible for other compreface-api instances
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,162 +4,119 @@
import com.exadel.frs.commonservice.entity.EmbeddingProjection;
import com.exadel.frs.commonservice.entity.EnhancedEmbeddingProjection;
import com.exadel.frs.commonservice.exception.IncorrectImageIdException;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.val;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.Collections;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealVector;
import org.springframework.data.util.Pair;
import org.springframework.lang.NonNull;

@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class EmbeddingCollection {

private final BiMap<EmbeddingProjection, Integer> projection2Index;
private INDArray embeddings;
private final ConcurrentMap<String, Map<UUID, RealVector>> mapping;

public static EmbeddingCollection from(final Stream<EnhancedEmbeddingProjection> stream) {
val rawEmbeddings = new LinkedList<double[]>();
val projections2Index = new HashMap<EmbeddingProjection, Integer>();
val index = new AtomicInteger(); // just to bypass 'final' variables restriction inside lambdas

stream.forEach(projection -> {
projections2Index.put(EmbeddingProjection.from(projection), index.getAndIncrement());
rawEmbeddings.add(projection.getEmbeddingData());
});

return new EmbeddingCollection(
HashBiMap.create(projections2Index),
rawEmbeddings.isEmpty()
? Nd4j.empty()
: Nd4j.create(rawEmbeddings.toArray(double[][]::new))
);
// we copy vector here just in case
var newMap = stream.map(e -> Map.entry(e.getSubjectName(), Pair.of(e.getEmbeddingId(), MatrixUtils.createRealVector(e.getEmbeddingData()))))
.collect(
Collectors.toConcurrentMap(
Entry::getKey,
entry -> {
Map<UUID, RealVector> map = new ConcurrentHashMap<>();
map.put(entry.getValue().getFirst(), entry.getValue().getSecond());
return map;
},
(map1, map2) -> {
map1.putAll(map2);
return map1;
}
)
);
return new EmbeddingCollection(newMap);
}

public Map<Integer, EmbeddingProjection> getIndexMap() {
// returns index to projection map
return Collections.unmodifiableMap(projection2Index.inverse());
public <T> T visit(Function<Map<String, Map<UUID, RealVector>>, T> readAndDo) {
return readAndDo.apply(exposeMap());
}

public Set<EmbeddingProjection> getProjections() {
return Collections.unmodifiableSet(projection2Index.keySet());
// package private for test purposes
Map<String, Map<UUID, RealVector>> exposeMap() {
return Collections.unmodifiableMap(mapping);
}

private int getSize() {
// should be invoked only if underlying array is not empty!
return (int) embeddings.size(0);
}

/**
* NOTE: current method returns COPY! Each time you invoke it, memory consumed, be careful!
*
* @return copy of underlying embeddings array.
*/
public INDArray getEmbeddings() {
return embeddings.dup();
}

public synchronized void updateSubjectName(String oldSubjectName, String newSubjectName) {
final List<EmbeddingProjection> projections = projection2Index.keySet()
.stream()
.filter(projection -> projection.getSubjectName().equals(oldSubjectName))
.collect(Collectors.toList());

projections.forEach(projection -> projection2Index.put(
projection.withNewSubjectName(newSubjectName),
projection2Index.remove(projection)
));
}

public synchronized EmbeddingProjection addEmbedding(final Embedding embedding) {
final var projection = EmbeddingProjection.from(embedding);

final INDArray array = Nd4j.create(new double[][]{embedding.getEmbedding()});

embeddings = embeddings.isEmpty()
? array
: Nd4j.concat(0, embeddings, array);

projection2Index.put(
projection,
getSize() - 1
);

return projection;
public void updateSubjectName(String oldSubjectName, String newSubjectName) {
mapping.put(newSubjectName, mapping.remove(oldSubjectName));
}

public synchronized Collection<EmbeddingProjection> removeEmbeddingsBySubjectName(String subjectName) {
// not efficient at ALL! review current approach!

final List<EmbeddingProjection> toRemove = projection2Index.keySet().stream()
.filter(projection -> projection.getSubjectName().equals(subjectName))
.collect(Collectors.toList());

toRemove.forEach(this::removeEmbedding); // <- rethink

return toRemove;
public EmbeddingProjection addEmbedding(final Embedding embedding) {
var id = embedding.getId();
var realVector = MatrixUtils.createRealVector(embedding.getEmbedding());
mapping.computeIfAbsent(embedding.getSubject().getSubjectName(), k -> new ConcurrentHashMap<>())
.put(id, realVector);
return new EmbeddingProjection(id, embedding.getSubject().getSubjectName());
}

public synchronized EmbeddingProjection removeEmbedding(Embedding embedding) {
return removeEmbedding(EmbeddingProjection.from(embedding));
public void removeEmbeddingsBySubjectName(String subjectName) {
mapping.remove(subjectName);
}

public synchronized EmbeddingProjection removeEmbedding(EmbeddingProjection projection) {
if (projection2Index.isEmpty()) {
return null;
}

var index = projection2Index.remove(projection);

// remove embedding by concatenating sub lists [0, index) + [index + 1, size),
// thus size of resulting array is decreased by one
embeddings = Nd4j.concat(
0,
embeddings.get(NDArrayIndex.interval(0, index), NDArrayIndex.all()),
embeddings.get(NDArrayIndex.interval(index + 1, getSize()), NDArrayIndex.all())
);

// shifting (-1) all indexes, greater than current one
projection2Index.entrySet()
.stream()
.filter(entry -> entry.getValue() > index)
.sorted(Map.Entry.comparingByValue())
.forEach(e -> projection2Index.replace(e.getKey(), e.getValue(), e.getValue() - 1));

return projection;
public EmbeddingProjection removeEmbedding(EmbeddingProjection projection) {
var wasRemoved = new AtomicBoolean(false);
mapping.compute(
projection.getSubjectName(),
(k, v) -> {
if (v == null) {
return null;
}
if (v.remove(projection.getEmbeddingId()) != null) {
wasRemoved.set(true);
if (v.isEmpty()) {
return null;
}
}
return v;
});
return wasRemoved.get() ? projection : null;
}

public synchronized Optional<INDArray> getRawEmbeddingById(UUID embeddingId) {
public Optional<double[]> getRawEmbeddingById(UUID embeddingId) {
return findByEmbeddingId(
embeddingId,
// return duplicated row
entry -> embeddings.getRow(entry.getValue(), true).dup()
entry -> entry.getValue().getValue().toArray()
);
}

public synchronized Optional<String> getSubjectNameByEmbeddingId(UUID embeddingId) {
public Optional<String> getSubjectNameByEmbeddingId(UUID embeddingId) {
return findByEmbeddingId(
embeddingId,
entry -> entry.getKey().getSubjectName()
Entry::getKey
);
}

private <T> Optional<T> findByEmbeddingId(UUID embeddingId, Function<Map.Entry<EmbeddingProjection, Integer>, T> func) {
validImageId(embeddingId);
public Optional<Map<UUID, RealVector>> getEmbeddingsBySubjectName(@NonNull String subjectName) {
return Optional.ofNullable(mapping.get(subjectName));
}

return Optional.ofNullable(projection2Index.entrySet()
.stream()
.filter(entry -> embeddingId.equals(entry.getKey().getEmbeddingId()))
private <T> Optional<T> findByEmbeddingId(UUID embeddingId, Function<Map.Entry<String, Map.Entry<UUID, RealVector>>, T> func) {
validImageId(embeddingId);
return mapping.entrySet().stream()
.filter(entry -> entry.getValue().containsKey(embeddingId))
.findFirst()
.map(func)
.orElseThrow(IncorrectImageIdException::new));
.map(entry -> Map.entry(entry.getKey(), Map.entry(embeddingId, entry.getValue().get(embeddingId))))
.map(func);
}

private void validImageId(UUID embeddingId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.exadel.frs.core.trainservice.component.classifiers.Classifier;
import java.util.List;
import java.util.UUID;

import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
Expand All @@ -38,6 +37,15 @@ public Double verify(final String modelKey, final double[] input, final UUID emb
return classifier.verify(input, modelKey, embeddingId);
}

public List<Pair<UUID, Double>> verifySubject(
final String modelKey,
final double[] input,
String subjectName,
final int resultCount
) {
return classifier.verifySubject(modelKey, input, subjectName, resultCount);
}

public double[] verify(final double[] sourceImageEmbedding, final double[][] targetImageEmbedding) {
return classifier.verify(sourceImageEmbedding, targetImageEmbedding);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,22 @@
import java.io.Serializable;
import java.util.List;
import java.util.UUID;

import org.apache.commons.lang3.tuple.Pair;
import org.springframework.lang.NonNull;

public interface Classifier extends Serializable {

List<Pair<Double, String>> predict(double[] input, String apiKey, int resultCount);

Double verify(double[] input, String apiKey, UUID embeddingId);

@NonNull
List<Pair<UUID, Double>> verifySubject(
@NonNull final String apiKey,
@NonNull final double[] input,
@NonNull final String subjectName,
final int resultCount
);

double[] verify(double[] sourceImageEmbedding, double[][] targetImageEmbedding);
}
Loading

0 comments on commit 6c1b106

Please sign in to comment.