diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java index 1118dd4cd6..594d306037 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -9,14 +9,17 @@ import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verify; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; import static org.opensearch.sql.util.MatcherUtils.verifySome; import static org.opensearch.sql.util.TestUtils.getResponseBody; +import static org.opensearch.sql.util.TestUtils.roundOfResponse; import java.io.IOException; import java.util.List; import java.util.Locale; +import org.json.JSONArray; import org.json.JSONObject; import org.junit.jupiter.api.Test; import org.opensearch.client.Request; @@ -395,8 +398,9 @@ public void testMaxDoublePushedDown() throws IOException { @Test public void testAvgDoublePushedDown() throws IOException { var response = executeQuery(String.format("SELECT avg(num3)" + " from %s", TEST_INDEX_CALCS)); + JSONArray responseJSON = roundOfResponse(response.getJSONArray("datarows")); verifySchema(response, schema("avg(num3)", null, "double")); - verifyDataRows(response, rows(-6.12D)); + verify(responseJSON, rows(-6.12D)); } @Test @@ -455,8 +459,9 @@ public void testAvgDoubleInMemory() throws IOException { executeQuery( String.format( "SELECT avg(num3)" + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + JSONArray roundOfResponse = roundOfResponse(response.getJSONArray("datarows")); verifySchema(response, schema("avg(num3) OVER(PARTITION BY datetime1)", null, "double")); - verifySome(response.getJSONArray("datarows"), rows(-6.12D)); + verifySome(roundOfResponse, rows(-6.12D)); } @Test diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java index 783fa2db2c..fc391aa057 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/ScoreQueryIT.java @@ -8,7 +8,7 @@ import static org.hamcrest.Matchers.containsString; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; -import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataAddressRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; import java.io.IOException; @@ -71,8 +71,7 @@ public void scoreQueryTest() throws IOException { TestsConstants.TEST_INDEX_ACCOUNT), "jdbc")); verifySchema(result, schema("address", null, "text"), schema("_score", null, "float")); - verifyDataRows( - result, rows("154 Douglass Street", 650.1515), rows("565 Hall Street", 3.2507575)); + verifyDataAddressRows(result, rows("154 Douglass Street"), rows("565 Hall Street")); } @Test @@ -102,7 +101,8 @@ public void scoreQueryDefaultBoostQueryTest() throws IOException { + "where score(matchQuery(address, 'Powell')) order by _score desc limit 2", TestsConstants.TEST_INDEX_ACCOUNT), "jdbc")); + verifySchema(result, schema("address", null, "text"), schema("_score", null, "float")); - verifyDataRows(result, rows("305 Powell Street", 6.501515)); + verifyDataAddressRows(result, rows("305 Powell Street")); } } diff --git a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java index 26a60cb4e5..d4db502407 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java @@ -159,6 +159,11 @@ public static void verifyDataRows(JSONObject response, Matcher... mat verify(response.getJSONArray("datarows"), matchers); } + @SafeVarargs + public static void verifyDataAddressRows(JSONObject response, Matcher... matchers) { + verifyAddressRow(response.getJSONArray("datarows"), matchers); + } + @SafeVarargs public static void verifyColumn(JSONObject response, Matcher... matchers) { verify(response.getJSONArray("schema"), matchers); @@ -183,6 +188,32 @@ public static void verify(JSONArray array, Matcher... matchers) { assertThat(objects, containsInAnyOrder(matchers)); } + // TODO: this is temporary fix for fixing serverless tests to pass as it creates multiple shards + // leading to score differences. + public static void verifyAddressRow(JSONArray array, Matcher... matchers) { + // List to store the processed elements from the JSONArray + List objects = new ArrayList<>(); + + // Iterate through each element in the JSONArray + array + .iterator() + .forEachRemaining( + o -> { + // Check if o is a JSONArray with exactly 2 elements + if (o instanceof JSONArray && ((JSONArray) o).length() == 2) { + // Check if the second element is a BigDecimal/_score value + if (((JSONArray) o).get(1) instanceof BigDecimal) { + // Remove the _score element from response data rows to skip the assertion as it + // will be different when compared against multiple shards + ((JSONArray) o).remove(1); + } + } + objects.add((T) o); + }); + assertEquals(matchers.length, objects.size()); + assertThat(objects, containsInAnyOrder(matchers)); + } + @SafeVarargs @SuppressWarnings("unchecked") public static void verifyInOrder(JSONArray array, Matcher... matchers) { diff --git a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java index a2f4021c1d..027b55e831 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java @@ -17,6 +17,8 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; @@ -27,6 +29,7 @@ import java.util.List; import java.util.Locale; import java.util.stream.Collectors; +import org.json.JSONArray; import org.json.JSONObject; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; @@ -34,6 +37,7 @@ import org.opensearch.client.Client; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.client.RestClient; import org.opensearch.common.xcontent.XContentType; import org.opensearch.sql.legacy.cursor.CursorType; @@ -121,10 +125,45 @@ public static Response performRequest(RestClient client, Request request) { } return response; } catch (IOException e) { + if (isRefreshPolicyError(e)) { + try { + return retryWithoutRefreshPolicy(request, client); + } catch (IOException ex) { + throw new IllegalStateException("Failed to perform request without refresh policy.", ex); + } + } throw new IllegalStateException("Failed to perform request", e); } } + /** + * Checks if the IOException is due to an unsupported refresh policy. + * + * @param e The IOException to check. + * @return true if the exception is due to a refresh policy error, false otherwise. + */ + private static boolean isRefreshPolicyError(IOException e) { + return e instanceof ResponseException + && ((ResponseException) e).getResponse().getStatusLine().getStatusCode() == 400 + && e.getMessage().contains("true refresh policy is not supported."); + } + + /** + * Attempts to perform the request without the refresh policy. + * + * @param request The original request. + * @param client client connection + * @return The response after retrying the request. + * @throws IOException If the request fails. + */ + private static Response retryWithoutRefreshPolicy(Request request, RestClient client) + throws IOException { + Request req = + new Request(request.getMethod(), request.getEndpoint().replaceAll("refresh=true", "")); + req.setEntity(request.getEntity()); + return client.performRequest(req); + } + public static String getAccountIndexMapping() { return "{ \"mappings\": {" + " \"properties\": {\n" @@ -770,6 +809,29 @@ public static String getResponseBody(Response response, boolean retainNewLines) return sb.toString(); } + // TODO: this is temporary fix for fixing serverless tests to pass with 2 digit precision value + public static JSONArray roundOfResponse(JSONArray array) { + JSONArray responseJSON = new JSONArray(); + array + .iterator() + .forEachRemaining( + o -> { + JSONArray jsonArray = new JSONArray(); + ((JSONArray) o) + .iterator() + .forEachRemaining( + i -> { + if (i instanceof BigDecimal) { + jsonArray.put(((BigDecimal) i).setScale(2, RoundingMode.HALF_UP)); + } else { + jsonArray.put(i); + } + }); + responseJSON.put(jsonArray); + }); + return responseJSON; + } + public static String fileToString( final String filePathFromProjectRoot, final boolean removeNewLines) throws IOException {