From 95524222b28c6166e7f0aa627e71f580c0da5565 Mon Sep 17 00:00:00 2001 From: Ying Mao Date: Fri, 22 Nov 2024 12:18:27 -0500 Subject: [PATCH] Fixing bug setting index when parsing Google Vertex AI results (#117287) (#117358) * Using record ID as index value when parsing Google Vertex AI rerank results * Update docs/changelog/117287.yaml * PR feedback --- docs/changelog/117287.yaml | 5 +++ .../GoogleVertexAiRerankResponseEntity.java | 28 ++++++++++++-- ...ogleVertexAiRerankResponseEntityTests.java | 37 ++++++++++++++++++- 3 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 docs/changelog/117287.yaml diff --git a/docs/changelog/117287.yaml b/docs/changelog/117287.yaml new file mode 100644 index 0000000000000..08da9dd8087b2 --- /dev/null +++ b/docs/changelog/117287.yaml @@ -0,0 +1,5 @@ +pr: 117287 +summary: Fixing bug setting index when parsing Google Vertex AI results +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java index 24946ee5875a5..78673277797d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java @@ -30,6 +30,8 @@ public class GoogleVertexAiRerankResponseEntity { private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Google Vertex AI rerank response"; + private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Google Vertex AI rerank " + + "response but received [%s]"; /** * Parses the Google Vertex AI rerank response. @@ -109,14 +111,27 @@ private static List doParse(XContentParser parser) throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName())); } - return new RankedDocsResults.RankedDoc(index, parsedRankedDoc.score, parsedRankedDoc.content); + if (parsedRankedDoc.id == null) { + throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.ID.getPreferredName())); + } + + try { + return new RankedDocsResults.RankedDoc( + Integer.parseInt(parsedRankedDoc.id), + parsedRankedDoc.score, + parsedRankedDoc.content + ); + } catch (NumberFormatException e) { + throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id)); + } }); } - private record RankedDoc(@Nullable Float score, @Nullable String content) { + private record RankedDoc(@Nullable Float score, @Nullable String content, @Nullable String id) { private static final ParseField CONTENT = new ParseField("content"); private static final ParseField SCORE = new ParseField("score"); + private static final ParseField ID = new ParseField("id"); private static final ObjectParser PARSER = new ObjectParser<>( "google_vertex_ai_rerank_response", true, @@ -126,6 +141,7 @@ private record RankedDoc(@Nullable Float score, @Nullable String content) { static { PARSER.declareString(Builder::setContent, CONTENT); PARSER.declareFloat(Builder::setScore, SCORE); + PARSER.declareString(Builder::setId, ID); } public static RankedDoc parse(XContentParser parser) { @@ -137,6 +153,7 @@ private static final class Builder { private String content; private Float score; + private String id; private Builder() {} @@ -150,8 +167,13 @@ public Builder setContent(String content) { return this; } + public Builder setId(String id) { + this.id = id; + return this; + } + public RankedDoc build() { - return new RankedDoc(score, content); + return new RankedDoc(score, content, id); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java index 32450e3facfd0..7ff79e2618425 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java @@ -39,7 +39,7 @@ public void testFromResponse_CreatesResultsForASingleItem() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2")))); + assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2")))); } public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { @@ -68,7 +68,7 @@ public void testFromResponse_CreatesResultsForMultipleItems() throws IOException assertThat( parsedResults.getRankedDocs(), - is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1"))) + is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1"))) ); } @@ -161,4 +161,37 @@ public void testFromResponse_FailsWhenScoreFieldIsNotPresent() { assertThat(thrownException.getMessage(), is("Failed to find required field [score] in Google Vertex AI rerank response")); } + + public void testFromResponse_FailsWhenIDFieldIsNotInteger() { + String responseJson = """ + { + "records": [ + { + "id": "abcd", + "title": "title 2", + "content": "content 2", + "score": 0.97 + }, + { + "id": "1", + "title": "title 1", + "content": "content 1", + "score": 0.96 + } + ] + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> GoogleVertexAiRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Expected numeric value for record ID field in Google Vertex AI rerank response but received [abcd]") + ); + } }