Skip to content

Commit

Permalink
Add Support for Hybrid Query Type (opensearch-project#850)
Browse files Browse the repository at this point in the history
* Add Support for Hybrid Query Type

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Add samples, guide and integ tests

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Removing wildcard imports

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Adding import

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Adding import

Signed-off-by: Varun Jain <varunudr@amazon.com>

---------

Signed-off-by: Varun Jain <varunudr@amazon.com>
  • Loading branch information
vibrantvarun authored Feb 20, 2024
1 parent c5cbc00 commit 821dae6
Show file tree
Hide file tree
Showing 9 changed files with 324 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ This section is for maintaining a changelog for all breaking changes for the cli

### Added
- Add search role type for nodes in cluster stats ([#848](https://github.com/opensearch-project/opensearch-java/pull/848))
- Add support for Hybrid query type ([#850](https://github.com/opensearch-project/opensearch-java/pull/850))

### Dependencies

Expand Down
19 changes: 19 additions & 0 deletions guides/search.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ for (int i = 0; i < searchResponse.hits().hits().size(); i++) {
}
```

### Search documents using a hybrid query
```java
Query searchQuery = Query.of(
h -> h.hybrid(
q -> q.queries(Arrays.asList(
new MatchQuery.Builder().field("text").query(FieldValue.of("Text for document 2")).build().toQuery(),
new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
new NeuralQuery.Builder().field("passage_embedding").queryText("Hi world").modelId("bQ1J8ooBpBj3wT4HVUsb").k(100).build().toQuery()
)
)
)
);
SearchRequest searchRequest = new SearchRequest.Builder().query(searchQuery).build();
SearchResponse<IndexData> searchResponse = client.search(searchRequest, IndexData.class);
for (var hit : searchResponse.hits().hits()) {
LOGGER.info("Found {} with score {}", hit.source(), hit.score());
}
```

### Search documents using suggesters

[AppData](../samples/src/main/java/org/opensearch/client/samples/util/AppData.java) refers to the sample data class used in the below samples.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import java.util.List;
import java.util.function.Function;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

public class HybridQuery extends QueryBase implements QueryVariant {
private final List<Query> queries;

private HybridQuery(HybridQuery.Builder builder) {
super(builder);
this.queries = ApiTypeHelper.unmodifiable(builder.queries);
}

public static HybridQuery of(Function<HybridQuery.Builder, ObjectBuilder<HybridQuery>> fn) {
return fn.apply(new HybridQuery.Builder()).build();
}

/**
* Required - list of search queries.
*
* @return list of queries provided under hybrid clause.
*/
public final List<Query> queries() {
return this.queries;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
super.serializeInternal(generator, mapper);
generator.writeKey("queries");
generator.writeStartArray();
for (Query item0 : this.queries) {
item0.serialize(generator, mapper);
}
generator.writeEnd();
}

@Override
public Query.Kind _queryKind() {
return Query.Kind.Hybrid;
}

public HybridQuery.Builder toBuilder() {
return new HybridQuery.Builder().queries(queries);
}

public static class Builder extends QueryBase.AbstractBuilder<HybridQuery.Builder> implements ObjectBuilder<HybridQuery> {
private List<Query> queries;

/**
* API name: {@code hybrid}
* <p>
* Adds all elements of <code>list</code> to <code>hybrid</code>.
*/
public final HybridQuery.Builder queries(List<Query> list) {
this.queries = _listAddAll(this.queries, list);
return this;
}

/**
* API name: {@code hybrid}
* <p>
* Adds one or more values to <code>hybrid</code>.
*/
public final HybridQuery.Builder queries(Query value, Query... values) {
this.queries = _listAdd(this.queries, value, values);
return this;
}

/**
* API name: {@code hybrid}
* <p>
* Adds a value to <code>hybrid</code> using a builder lambda.
*/
public final HybridQuery.Builder queries(Function<Query.Builder, ObjectBuilder<Query>> fn) {
return queries(fn.apply(new Query.Builder()).build());
}

@Override
protected Builder self() {
return this;
}

@Override
public HybridQuery build() {
_checkSingleUse();
return new HybridQuery(this);
}
}

public static final JsonpDeserializer<HybridQuery> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
HybridQuery.Builder::new,
HybridQuery::setupHybridQueryDeserializer
);

protected static void setupHybridQueryDeserializer(ObjectDeserializer<HybridQuery.Builder> op) {
setupQueryBaseDeserializer(op);
op.add(HybridQuery.Builder::queries, JsonpDeserializer.arrayDeserializer(Query._DESERIALIZER), "queries");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ public enum Kind implements JsonEnum {

Neural("neural"),

Hybrid("hybrid"),

ParentId("parent_id"),

Percolate("percolate"),
Expand Down Expand Up @@ -725,6 +727,23 @@ public NeuralQuery neural() {
return TaggedUnionUtils.get(this, Kind.Neural);
}

/**
* Is this variant instance of kind {@code hybrid}?
*/
public boolean isHybrid() {
return _kind == Kind.Hybrid;
}

/**
* Get the {@code hybrid} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code hybrid} kind.
*/
public HybridQuery hybrid() {
return TaggedUnionUtils.get(this, Kind.Hybrid);
}

/**
* Is this variant instance of kind {@code parent_id}?
*/
Expand Down Expand Up @@ -1510,6 +1529,16 @@ public ObjectBuilder<Query> neural(Function<NeuralQuery.Builder, ObjectBuilder<N
return this.neural(fn.apply(new NeuralQuery.Builder()).build());
}

public ObjectBuilder<Query> hybrid(HybridQuery v) {
this._kind = Kind.Hybrid;
this._value = v;
return this;
}

public ObjectBuilder<Query> hybrid(Function<HybridQuery.Builder, ObjectBuilder<HybridQuery>> fn) {
return this.hybrid(fn.apply(new HybridQuery.Builder()).build());
}

public ObjectBuilder<Query> parentId(ParentIdQuery v) {
this._kind = Kind.ParentId;
this._value = v;
Expand Down Expand Up @@ -1818,6 +1847,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match");
op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested");
op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural");
op.add(Builder::hybrid, HybridQuery._DESERIALIZER, "hybrid");
op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id");
op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate");
op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ public static NeuralQuery.Builder neural() {
return new NeuralQuery.Builder();
}

/**
* Creates a builder for the {@link HybridQuery nested} {@code Query} variant.
*/
public static HybridQuery.Builder hybrid() {
return new HybridQuery.Builder();
}

/**
* Creates a builder for the {@link ParentIdQuery parent_id} {@code Query}
* variant.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.opensearch.client.opensearch._types.query_dsl;

import java.util.Arrays;
import org.junit.Test;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch.model.ModelTestCase;

public class HybridQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
HybridQuery origin = new HybridQuery.Builder().queries(
Arrays.asList(
new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
new NeuralQuery.Builder().field("passage_embedding")
.queryText("Hi world")
.modelId("bQ1J8ooBpBj3wT4HVUsb")
.k(100)
.build()
.toQuery(),
new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery()
)
).build();
HybridQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,17 @@

package org.opensearch.client.opensearch.model;

import java.util.Arrays;
import org.junit.Test;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.mapping.Property;
import org.opensearch.client.opensearch._types.mapping.TypeMapping;
import org.opensearch.client.opensearch._types.query_dsl.KnnQuery;
import org.opensearch.client.opensearch._types.query_dsl.NeuralQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.QueryBuilders;
import org.opensearch.client.opensearch._types.query_dsl.TermQuery;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.indices.GetMappingResponse;

Expand Down Expand Up @@ -243,4 +248,57 @@ public void testNeuralQueryFromJson() {
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}

@Test
public void testHybridQuery() {

Query query = Query.of(
h -> h.hybrid(
q -> q.queries(
Arrays.asList(
new TermQuery.Builder().field("passage_text").value(FieldValue.of("Foo bar")).build().toQuery(),
new NeuralQuery.Builder().field("passage_embedding")
.queryText("Hi world")
.modelId("bQ1J8ooBpBj3wT4HVUsb")
.k(100)
.build()
.toQuery(),
new KnnQuery.Builder().field("passage_embedding").vector(new float[] { 0.01f, 0.02f }).k(2).build().toQuery()
)
)
)
);
SearchRequest searchRequest = SearchRequest.of(s -> s.query(query));
assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field());
assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
}

@Test
public void testHybridQueryFromJson() {

String json = "{\"query\""
+ ":{\"hybrid\":{\"queries\":[{\"term\":{\"passage_text\":\"Foo bar\"}},"
+ "{\"neural\":{\"passage_embedding\":{\"query_text\":\"Hi world\",\"model_id\":\"bQ1J8ooBpBj3wT4HVUsb\",\"k\":100}}},"
+ "{\"knn\":{\"passage_embedding\":{\"vector\":[0.01,0.02],\"k\":2}}}]}},\"size\":10"
+ "}";

SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper);

assertEquals("passage_text", searchRequest.query().hybrid().queries().get(0).term().field());
assertEquals("Foo bar", searchRequest.query().hybrid().queries().get(0).term().value().stringValue());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(1).neural().field());
assertEquals("Hi world", searchRequest.query().hybrid().queries().get(1).neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().hybrid().queries().get(1).neural().modelId());
assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k());
assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field());
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length);
assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k());
}
}
Loading

0 comments on commit 821dae6

Please sign in to comment.