Skip to content

Commit

Permalink
Add "filter" to neural query (opensearch-project#932)
Browse files Browse the repository at this point in the history
* 🩹 add filter to neural query

Signed-off-by: Lorenzo Caenazzo <lorenzo.caenazzo@optionfactory.net>

* CHANGELOG.md

Signed-off-by: Lorenzo Caenazzo <lorenzo.caenazzo@optionfactory.net>

* 💚 spotless fix

Signed-off-by: Lorenzo Caenazzo <lorenzo.caenazzo@optionfactory.net>

* 💚 spotless fix, jdk8 check

Signed-off-by: Lorenzo Caenazzo <lorenzo.caenazzo@optionfactory.net>

* 💚 spotless fix

Signed-off-by: Lorenzo Caenazzo <lorenzo.caenazzo@optionfactory.net>

* Update CHANGELOG.md

Co-authored-by: Andriy Redko <drreta@gmail.com>
Signed-off-by: Grogdunn <frzlollo@gmail.com>

---------

Signed-off-by: Lorenzo Caenazzo <lorenzo.caenazzo@optionfactory.net>
Signed-off-by: Grogdunn <frzlollo@gmail.com>
Co-authored-by: Andriy Redko <drreta@gmail.com>
  • Loading branch information
Grogdunn and reta authored Apr 10, 2024
1 parent 5ad54c6 commit bc78613
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ This section is for maintaining a changelog for all breaking changes for the cli
- Support weight function in function score query ([#880](https://github.com/opensearch-project/opensearch-java/pull/880))
- Fix pattern replace by making flag and replacement optional as on api ([#895](https://github.com/opensearch-project/opensearch-java/pull/895))
- Client with Java 8 runtime and Apache HttpClient 5 Transport fails with java.lang.NoSuchMethodError: java.nio.ByteBuffer.flip()Ljava/nio/ByteBuffer ([#920](https://github.com/opensearch-project/opensearch-java/pull/920))
- Add missed field "filter" to NeuralQuery model class

### Security

Expand Down Expand Up @@ -364,4 +365,4 @@ This section is for maintaining a changelog for all breaking changes for the cli
[2.5.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.4.0...v2.5.0
[2.4.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.3.0...v2.4.0
[2.3.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.2.0...v2.3.0
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0
[2.2.0]: https://github.com/opensearch-project/opensearch-java/compare/v2.1.0...v2.2.0
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public class NeuralQuery extends QueryBase implements QueryVariant {
private final int k;
@Nullable
private final String modelId;
@Nullable
private final Query filter;

private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);
Expand All @@ -35,6 +37,7 @@ private NeuralQuery(NeuralQuery.Builder builder) {
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
this.filter = builder.filter;
}

public static NeuralQuery of(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
Expand Down Expand Up @@ -93,6 +96,16 @@ public final String modelId() {
return this.modelId;
}

/**
* Optional - A query to filter the results of the query.
*
* @return The filter query.
*/
@Nullable
public final Query filter() {
return this.filter;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);
Expand All @@ -107,11 +120,16 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

generator.write("k", this.k);

if (this.filter != null) {
generator.writeKey("filter");
this.filter.serialize(generator, mapper);
}

generator.writeEnd();
}

public Builder toBuilder() {
return new Builder().field(field).queryText(queryText).k(k).modelId(modelId);
return new Builder().field(field).queryText(queryText).k(k).modelId(modelId).filter(filter);
}

/**
Expand All @@ -123,6 +141,8 @@ public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builde
private Integer k;
@Nullable
private String modelId;
@Nullable
private Query filter;

/**
* Required - The target field.
Expand Down Expand Up @@ -169,6 +189,17 @@ public NeuralQuery.Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
* @param filter The filter query.
* @return This builder.
*/
public NeuralQuery.Builder filter(@Nullable Query filter) {
this.filter = filter;
return this;
}

@Override
protected NeuralQuery.Builder self() {
return this;
Expand Down Expand Up @@ -198,6 +229,7 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
public class NeuralQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
NeuralQuery origin = new NeuralQuery.Builder().field("field").queryText("queryText").k(1).build();
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
Expand Down

0 comments on commit bc78613

Please sign in to comment.