-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Backport-2.x] Backport fix bedrock preprocess func (#2537)
* Fix bedrock connector embedding generation issue Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> * add IT Signed-off-by: zane-neo <zaniu@amazon.com> * add ITs Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> * change input to fix number format exception in local Signed-off-by: zane-neo <zaniu@amazon.com> * Add log to identify the failure IT root cause Signed-off-by: zane-neo <zaniu@amazon.com> * Update plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java Co-authored-by: Yaliang Wu <ylwu@amazon.com> Signed-off-by: zane-neo <zaniu@amazon.com> * address comments Signed-off-by: zane-neo <zaniu@amazon.com> * fix backport incompatibility Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com> Co-authored-by: Yaliang Wu <ylwu@amazon.com> (cherry picked from commit 210903d)
- Loading branch information
1 parent
f54abcf
commit 243ecb6
Showing
5 changed files
with
248 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.rest; | ||
|
||
import java.io.IOException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Path; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.utils.StringUtils; | ||
|
||
import lombok.SneakyThrows; | ||
|
||
public class RestBedRockInferenceIT extends MLCommonsRestTestCase { | ||
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); | ||
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); | ||
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); | ||
private static final String GITHUB_CI_AWS_REGION = "us-west-2"; | ||
|
||
@SneakyThrows | ||
@Before | ||
public void setup() throws IOException, InterruptedException { | ||
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); | ||
Thread.sleep(20000); | ||
} | ||
|
||
public void test_bedrock_embedding_model() throws Exception { | ||
// Skip test if key is null | ||
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { | ||
return; | ||
} | ||
String templates = Files | ||
.readString( | ||
Path | ||
.of( | ||
RestMLPredictionAction.class | ||
.getClassLoader() | ||
.getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") | ||
.toURI() | ||
) | ||
); | ||
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class); | ||
for (Map.Entry<String, Object> templateEntry : templateMap.entrySet()) { | ||
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); | ||
String testCaseName = templateEntry.getKey(); | ||
String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); | ||
String modelId = registerRemoteModel( | ||
String | ||
.format( | ||
StringUtils.gson.toJson(templateEntry.getValue()), | ||
GITHUB_CI_AWS_REGION, | ||
AWS_ACCESS_KEY_ID, | ||
AWS_SECRET_ACCESS_KEY, | ||
AWS_SESSION_TOKEN | ||
), | ||
bedrockEmbeddingModelName, | ||
true | ||
); | ||
|
||
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); | ||
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); | ||
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); | ||
assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); | ||
List output = (List) inferenceResult.get("inference_results"); | ||
assertEquals(errorMsg, 2, output.size()); | ||
assertTrue(errorMsg, output.get(0) instanceof Map); | ||
assertTrue(errorMsg, output.get(1) instanceof Map); | ||
validateOutput(errorMsg, (Map) output.get(0)); | ||
validateOutput(errorMsg, (Map) output.get(1)); | ||
} | ||
} | ||
|
||
private void validateOutput(String errorMsg, Map<String, Object> output) { | ||
assertTrue(errorMsg, output.containsKey("output")); | ||
assertTrue(errorMsg, output.get("output") instanceof List); | ||
List outputList = (List) output.get("output"); | ||
assertEquals(errorMsg, 1, outputList.size()); | ||
assertTrue(errorMsg, outputList.get(0) instanceof Map); | ||
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List); | ||
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size()); | ||
} | ||
} |
63 changes: 63 additions & 0 deletions
63
plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
{ | ||
"without_step_size": { | ||
"name": "Amazon Bedrock Connector: embedding", | ||
"description": "The connector to bedrock Titan embedding model", | ||
"version": 1, | ||
"protocol": "aws_sigv4", | ||
"parameters": { | ||
"region": "%s", | ||
"service_name": "bedrock", | ||
"model_name": "amazon.titan-embed-text-v1" | ||
}, | ||
"credential": { | ||
"access_key": "%s", | ||
"secret_key": "%s", | ||
"session_token": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", | ||
"headers": { | ||
"content-type": "application/json", | ||
"x-amz-content-sha256": "required" | ||
}, | ||
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }", | ||
"pre_process_function": "connector.pre_process.bedrock.embedding", | ||
"post_process_function": "connector.post_process.bedrock.embedding" | ||
} | ||
] | ||
}, | ||
"with_step_size": { | ||
"name": "Amazon Bedrock Connector: embedding", | ||
"description": "The connector to bedrock Titan embedding model", | ||
"version": 1, | ||
"protocol": "aws_sigv4", | ||
"parameters": { | ||
"region": "%s", | ||
"service_name": "bedrock", | ||
"model_name": "amazon.titan-embed-text-v1", | ||
"input_docs_processed_step_size": "1" | ||
}, | ||
"credential": { | ||
"access_key": "%s", | ||
"secret_key": "%s", | ||
"session_token": "%s" | ||
}, | ||
"actions": [ | ||
{ | ||
"action_type": "predict", | ||
"method": "POST", | ||
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", | ||
"headers": { | ||
"content-type": "application/json", | ||
"x-amz-content-sha256": "required" | ||
}, | ||
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }", | ||
"pre_process_function": "connector.pre_process.bedrock.embedding", | ||
"post_process_function": "connector.post_process.bedrock.embedding" | ||
} | ||
] | ||
} | ||
} |