Skip to content

Commit

Permalink
Merge pull request #289 from weaviate/add-support-for-multi-target-se…
Browse files Browse the repository at this point in the history
…arch

Add support for multi target vector search
  • Loading branch information
antas-marcin authored Jul 19, 2024
2 parents 05db887 + 6b800ba commit e8e7c06
Show file tree
Hide file tree
Showing 29 changed files with 722 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class HybridArgument implements Argument {
String[] properties;
String[] targetVectors;
Searches searches;

Targets targets;

@Override
public String build() {
Expand Down Expand Up @@ -59,6 +59,9 @@ public String build() {
}
arg.add(String.format("searches:{%s}", String.join(" ", searchesArgs)));
}
if (targets != null) {
arg.add(String.format("%s", targets.build()));
}

return String.format("hybrid:{%s}", String.join(" ", arg));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class NearAudioArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;

@Override
public String build() {
Expand All @@ -30,6 +31,7 @@ public String build() {
.targetVectors(targetVectors)
.data(audio)
.dataFile(audioFile)
.targets(targets)
.mediaField("audio")
.mediaName("nearAudio")
.build().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class NearDepthArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;

@Override
public String build() {
Expand All @@ -30,6 +31,7 @@ public String build() {
.targetVectors(targetVectors)
.data(depth)
.dataFile(depthFile)
.targets(targets)
.mediaField("depth")
.mediaName("nearDepth")
.build().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class NearImageArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;

@Override
public String build() {
Expand All @@ -30,6 +31,7 @@ public String build() {
.targetVectors(targetVectors)
.data(image)
.dataFile(imageFile)
.targets(targets)
.mediaField("image")
.mediaName("nearImage")
.build().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class NearImuArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;

@Override
public String build() {
Expand All @@ -30,6 +31,7 @@ public String build() {
.targetVectors(targetVectors)
.data(imu)
.dataFile(imuFile)
.targets(targets)
.mediaField("imu")
.mediaName("nearIMU")
.build().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class NearMediaArgumentHelper {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;


public String build() {
Expand All @@ -44,6 +45,9 @@ public String build() {
if (ArrayUtils.isNotEmpty(targetVectors)) {
fields.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors)));
}
if (targets != null) {
fields.add(String.format("%s", targets.build()));
}

return String.format("%s:{%s}", mediaName, String.join(" ", fields));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class NearObjectArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;

@Override
public String build() {
Expand All @@ -44,6 +45,9 @@ public String build() {
if (ArrayUtils.isNotEmpty(targetVectors)) {
arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors)));
}
if (targets != null) {
arg.add(String.format("%s", targets.build()));
}

return String.format("nearObject:{%s}", String.join(" ", arg));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class NearTextArgument implements Argument {
NearTextMoveParameters moveAwayFrom;
Boolean autocorrect;
String[] targetVectors;
Targets targets;

private String buildMoveParam(String name, NearTextMoveParameters moveParam) {
Set<String> arg = new LinkedHashSet<>();
Expand Down Expand Up @@ -81,6 +82,9 @@ public String build() {
if (ArrayUtils.isNotEmpty(targetVectors)) {
arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors)));
}
if (targets != null) {
arg.add(String.format("%s", targets.build()));
}

return String.format("nearText:{%s}", String.join(" ", arg));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class NearThermalArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;

@Override
public String build() {
Expand All @@ -30,6 +31,7 @@ public String build() {
.targetVectors(targetVectors)
.data(thermal)
.dataFile(thermalFile)
.targets(targets)
.mediaField("thermal")
.mediaName("nearThermal")
.build().build();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.weaviate.client.v1.graphql.query.argument;

import io.weaviate.client.v1.graphql.query.util.Serializer;
import java.util.Map;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
Expand All @@ -23,6 +24,8 @@ public class NearVectorArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Map<String, Float[]> vectorPerTarget;
Targets targets;

@Override
public String build() {
Expand All @@ -40,6 +43,16 @@ public String build() {
if (ArrayUtils.isNotEmpty(targetVectors)) {
arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors)));
}
if (vectorPerTarget != null && !vectorPerTarget.isEmpty()) {
Set<String> vectorPerTargetArg = new LinkedHashSet<>();
for (Map.Entry<String, Float[]> entry : vectorPerTarget.entrySet()) {
vectorPerTargetArg.add(String.format("%s:%s", entry.getKey(), Serializer.array(entry.getValue())));
}
arg.add(String.format("vectorPerTarget:{%s}", String.join(" ", vectorPerTargetArg)));
}
if (targets != null) {
arg.add(String.format("%s", targets.build()));
}

return String.format("nearVector:{%s}", String.join(" ", arg));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class NearVideoArgument implements Argument {
Float certainty;
Float distance;
String[] targetVectors;
Targets targets;

@Override
public String build() {
Expand All @@ -30,6 +31,7 @@ public String build() {
.targetVectors(targetVectors)
.data(video)
.dataFile(videoFile)
.targets(targets)
.mediaField("video")
.mediaName("nearVideo")
.build().build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.weaviate.client.v1.graphql.query.argument;

import io.weaviate.client.v1.graphql.query.util.Serializer;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.ArrayUtils;

@Getter
@Builder
@ToString
@EqualsAndHashCode
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class Targets {
CombinationMethod combinationMethod;
String[] targetVectors;
Map<String, Float> weights;

public enum CombinationMethod {
minimum("minimum"),
average("average"),
sum("sum"),
manualWeights("manualWeights"),
relativeScore("relativeScore");

private final String type;

CombinationMethod(String type) {
this.type = type;
}

@Override
public String toString() {
return type;
}
}

String build() {
Set<String> arg = new LinkedHashSet<>();

if (combinationMethod != null) {
arg.add(String.format("combinationMethod:%s", combinationMethod.name()));
}
if (ArrayUtils.isNotEmpty(targetVectors)) {
arg.add(String.format("targetVectors:%s", Serializer.arrayWithQuotes(targetVectors)));
}
if (weights != null && !weights.isEmpty()) {
Set<String> weightsArg = new LinkedHashSet<>();
for (Map.Entry<String, Float> entry : weights.entrySet()) {
weightsArg.add(String.format("%s:%s", entry.getKey(), entry.getValue()));
}
arg.add(String.format("weights:{%s}", String.join(" ", weightsArg)));
}

return String.format("targets:{%s}", String.join(" ", arg));
}
}

Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package io.weaviate.client.v1.graphql.query.argument;

import java.util.LinkedHashMap;
import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThat;

public class HybridArgumentTest {

Expand Down Expand Up @@ -130,4 +131,24 @@ public void shouldCreateArgumentWithNearTextSearches() {

assertThat(str).isEqualTo("hybrid:{query:\"I'm a simple string\" searches:{nearText:{concepts:[\"concept\"] certainty:0.9}}}");
}

@Test
public void shouldCreateArgumentWithTargets() {
// given
LinkedHashMap<String, Float> weights = new LinkedHashMap<>();
weights.put("t1", 0.8f);
weights.put("t2", 0.2f);
Targets targets = Targets.builder()
.targetVectors(new String[]{ "t1", "t2" })
.combinationMethod(Targets.CombinationMethod.minimum)
.weights(weights)
.build();
HybridArgument hybrid = HybridArgument.builder()
.query("I'm a simple string").targets(targets)
.build();
// when
String str = hybrid.build();
// then
assertThat(str).isEqualTo("hybrid:{query:\"I'm a simple string\" targets:{combinationMethod:minimum targetVectors:[\"t1\",\"t2\"] weights:{t1:0.8 t2:0.2}}}");
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.weaviate.client.v1.graphql.query.argument;

import java.util.LinkedHashMap;
import org.junit.Test;

import java.io.File;
Expand Down Expand Up @@ -123,4 +124,28 @@ public void shouldBuildEmptyDueToNotSet() {

assertThat(nearAudio).isEqualTo("nearAudio:{}");
}

@Test
public void shouldBuildFromBase64WithTargets() {
// given
LinkedHashMap<String, Float> weights = new LinkedHashMap<>();
weights.put("t1", 0.8f);
weights.put("t2", 0.2f);
Targets targets = Targets.builder()
.targetVectors(new String[]{ "t1", "t2" })
.combinationMethod(Targets.CombinationMethod.minimum)
.weights(weights)
.build();
NearTextArgument nearText = NearTextArgument.builder()
.concepts(new String[]{"concept"}).targets(targets).build();

String audioBase64 = "iVBORw0KGgoAAAANS";

String nearAudio = NearAudioArgument.builder()
.audio(audioBase64)
.targets(targets)
.build().build();

assertThat(nearAudio).isEqualTo(String.format("nearAudio:{audio:\"%s\" targets:{combinationMethod:minimum targetVectors:[\"t1\",\"t2\"] weights:{t1:0.8 t2:0.2}}}", audioBase64));
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.weaviate.client.v1.graphql.query.argument;

import java.util.LinkedHashMap;
import org.junit.Test;

import java.io.File;
Expand Down Expand Up @@ -123,4 +124,25 @@ public void shouldBuildEmptyDueToNotSet() {

assertThat(nearDepth).isEqualTo("nearDepth:{}");
}

@Test
public void shouldBuildFromBase64WithTargets() {
LinkedHashMap<String, Float> weights = new LinkedHashMap<>();
weights.put("t1", 0.8f);
weights.put("t2", 0.2f);
Targets targets = Targets.builder()
.targetVectors(new String[]{ "t1", "t2" })
.combinationMethod(Targets.CombinationMethod.minimum)
.weights(weights)
.build();

String depthBase64 = "iVBORw0KGgoAAAANS";

String nearDepth = NearDepthArgument.builder()
.depth(depthBase64)
.targets(targets)
.build().build();

assertThat(nearDepth).isEqualTo(String.format("nearDepth:{depth:\"%s\" targets:{combinationMethod:minimum targetVectors:[\"t1\",\"t2\"] weights:{t1:0.8 t2:0.2}}}", depthBase64));
}
}
Loading

0 comments on commit e8e7c06

Please sign in to comment.