From b6a5d5af533ed230ee06b96b58f803e5cc6bf541 Mon Sep 17 00:00:00 2001 From: Charles Marion Date: Fri, 18 Oct 2024 15:15:15 -0500 Subject: [PATCH 1/3] bug: Exclude internal calls from WAF rate limiter (#591) --- lib/chatbot-api/index.ts | 46 +++++++++++++++++-- tests/__snapshots__/cdk-app.test.ts.snap | 45 ++++++++++++++++++ .../chatbot-api-construct.test.ts.snap | 45 ++++++++++++++++++ 3 files changed, 133 insertions(+), 3 deletions(-) diff --git a/lib/chatbot-api/index.ts b/lib/chatbot-api/index.ts index 88adda71..3c3dc5bd 100644 --- a/lib/chatbot-api/index.ts +++ b/lib/chatbot-api/index.ts @@ -1,5 +1,6 @@ import * as cognito from "aws-cdk-lib/aws-cognito"; import * as dynamodb from "aws-cdk-lib/aws-dynamodb"; +import * as ec2 from "aws-cdk-lib/aws-ec2"; import * as s3 from "aws-cdk-lib/aws-s3"; import * as sqs from "aws-cdk-lib/aws-sqs"; import * as sns from "aws-cdk-lib/aws-sns"; @@ -110,7 +111,10 @@ export class ChatBotApi extends Construct { name: "WafAppsync", rules: [ ...props.shared.webACLRules, - ...this.createWafRules(props.config.llms.rateLimitPerIP ?? 100), + ...this.createWafRules( + props.config.llms.rateLimitPerIP ?? 100, + props.shared.vpc + ), ], }).attrArn, resourceArn: api.arn, @@ -175,7 +179,10 @@ export class ChatBotApi extends Construct { ]); } - private createWafRules(llmRatePerIP: number): wafv2.CfnWebACL.RuleProperty[] { + private createWafRules( + llmRatePerIP: number, + vpc: ec2.Vpc + ): wafv2.CfnWebACL.RuleProperty[] { /** * The rate limit is the maximum number of requests from a * single IP address that are allowed in a ten-minute period. @@ -242,6 +249,39 @@ export class ChatBotApi extends Construct { metricName: "LimitRequestsPerIP", }, }; - return [ruleLimitRequests]; + + // The following rule is disabling throttling for calls coming from the VPC. 
+ const eips: string[] = []; + vpc.node.findAll().forEach((resource) => { + if (resource instanceof ec2.CfnEIP) { + // NAT Gateways IP + eips.push(resource.attrPublicIp + "/32"); + } + }); + + const vpcnIpSet = new wafv2.CfnIPSet(this, "VPCPublicIPs", { + addresses: eips, + ipAddressVersion: "IPV4", + scope: "REGIONAL", + }); + + const allowInternalCalls: wafv2.CfnWebACL.RuleProperty = { + name: "AllowInternalCalls", + priority: 2, + action: { + allow: {}, + }, + statement: { + ipSetReferenceStatement: { + arn: vpcnIpSet.attrArn, + }, + }, + visibilityConfig: { + sampledRequestsEnabled: false, + cloudWatchMetricsEnabled: false, + metricName: "AllowInternalCalls", + }, + }; + return [ruleLimitRequests, allowInternalCalls]; } } diff --git a/tests/__snapshots__/cdk-app.test.ts.snap b/tests/__snapshots__/cdk-app.test.ts.snap index ba4a1dc4..33aa500c 100644 --- a/tests/__snapshots__/cdk-app.test.ts.snap +++ b/tests/__snapshots__/cdk-app.test.ts.snap @@ -4057,6 +4057,29 @@ schema { }, "Type": "AWS::IAM::Policy", }, + "ChatBotApiVPCPublicIPsAE6206D3": { + "Properties": { + "Addresses": [ + { + "Fn::Join": [ + "", + [ + { + "Fn::GetAtt": [ + "SharedVPCpublicSubnet1EIPA9E2FC1C", + "PublicIp", + ], + }, + "/32", + ], + ], + }, + ], + "IPAddressVersion": "IPV4", + "Scope": "REGIONAL", + }, + "Type": "AWS::WAFv2::IPSet", + }, "ChatBotApiWafAppsync9FEB4E22": { "Properties": { "DefaultAction": { @@ -4145,6 +4168,28 @@ schema { "SampledRequestsEnabled": true, }, }, + { + "Action": { + "Allow": {}, + }, + "Name": "AllowInternalCalls", + "Priority": 2, + "Statement": { + "IPSetReferenceStatement": { + "Arn": { + "Fn::GetAtt": [ + "ChatBotApiVPCPublicIPsAE6206D3", + "Arn", + ], + }, + }, + }, + "VisibilityConfig": { + "CloudWatchMetricsEnabled": false, + "MetricName": "AllowInternalCalls", + "SampledRequestsEnabled": false, + }, + }, ], "Scope": "REGIONAL", "VisibilityConfig": { diff --git a/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap 
b/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap index 7bb87d8f..7695a49a 100644 --- a/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap +++ b/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap @@ -4119,6 +4119,29 @@ schema { }, "Type": "AWS::IAM::Policy", }, + "ChatBotApiConstructVPCPublicIPs5C63CEBB": { + "Properties": { + "Addresses": [ + { + "Fn::Join": [ + "", + [ + { + "Fn::GetAtt": [ + "SharedVPCpublicSubnet1EIPA9E2FC1C", + "PublicIp", + ], + }, + "/32", + ], + ], + }, + ], + "IPAddressVersion": "IPV4", + "Scope": "REGIONAL", + }, + "Type": "AWS::WAFv2::IPSet", + }, "ChatBotApiConstructWafAppsyncE6AACFFE": { "Properties": { "DefaultAction": { @@ -4207,6 +4230,28 @@ schema { "SampledRequestsEnabled": true, }, }, + { + "Action": { + "Allow": {}, + }, + "Name": "AllowInternalCalls", + "Priority": 2, + "Statement": { + "IPSetReferenceStatement": { + "Arn": { + "Fn::GetAtt": [ + "ChatBotApiConstructVPCPublicIPs5C63CEBB", + "Arn", + ], + }, + }, + }, + "VisibilityConfig": { + "CloudWatchMetricsEnabled": false, + "MetricName": "AllowInternalCalls", + "SampledRequestsEnabled": false, + }, + }, ], "Scope": "REGIONAL", "VisibilityConfig": { From a1d2aa9a269ab5fc5276dd23968dbb65ab8d2d1e Mon Sep 17 00:00:00 2001 From: Charles Marion Date: Mon, 21 Oct 2024 09:06:57 -0500 Subject: [PATCH 2/3] feat: Disable Sagemaker endpoint (or cross-encoder per workspace) (#588) --------- Co-authored-by: Ajay Lamba Co-authored-by: Bigad Soleiman --- bin/config.ts | 8 +- cli/magic-config.ts | 73 ++++++++--- .../chatbot-api/aurora_workspace_test.py | 75 +++++++++--- lib/aws-genai-llm-chatbot-stack.ts | 115 ++++++++---------- .../api-handler/routes/workspaces.py | 14 +-- lib/chatbot-api/rest-api.ts | 4 +- lib/chatbot-api/schema/schema.graphql | 8 +- lib/model-interfaces/langchain/index.ts | 4 +- lib/rag-engines/data-import/index.ts | 6 +- lib/rag-engines/index.ts | 13 +- lib/rag-engines/sagemaker-rag-models/index.ts | 47 +++---- 
.../python/genai_core/aurora/query.py | 54 ++++---- .../python/genai_core/opensearch/query.py | 54 ++++---- lib/shared/types.ts | 22 ++-- lib/user-interface/index.ts | 21 ++-- lib/user-interface/private-website.ts | 2 - lib/user-interface/public-website.ts | 2 - .../common/api-client/workspaces-client.ts | 8 +- .../common/helpers/embeddings-model-helper.ts | 12 +- .../src/common/helpers/options-helper.ts | 15 ++- .../components/chatbot/chat-input-panel.tsx | 19 ++- .../rag/create-workspace/aurora-form.tsx | 24 +++- .../create-workspace-aurora.tsx | 26 ++-- .../create-workspace-opensearch.tsx | 26 ++-- .../cross-encoder-selector-field.tsx | 39 +++--- .../embeddings-selector-field.tsx | 6 +- .../create-workspace/hybrid-search-field.tsx | 5 +- .../pages/rag/create-workspace/kb-form.tsx | 1 + .../rag/create-workspace/opensearch-form.tsx | 24 +++- .../src/pages/rag/embeddings/embeddings.tsx | 1 + .../workspace/aurora-workspace-settings.tsx | 4 +- .../open-search-workspace-settings.tsx | 4 +- tests/__snapshots__/cdk-app.test.ts.snap | 10 +- .../chatbot-api-construct.test.ts.snap | 10 +- tests/utils/config-util.ts | 1 + 35 files changed, 468 insertions(+), 289 deletions(-) diff --git a/bin/config.ts b/bin/config.ts index 72496dd2..5d119147 100644 --- a/bin/config.ts +++ b/bin/config.ts @@ -3,7 +3,9 @@ import { existsSync, readFileSync } from "fs"; export function getConfig(): SystemConfig { if (existsSync("./bin/config.json")) { - return JSON.parse(readFileSync("./bin/config.json").toString("utf8")); + return JSON.parse( + readFileSync("./bin/config.json").toString("utf8") + ) as SystemConfig; } // Default config return { @@ -48,11 +50,13 @@ export function getConfig(): SystemConfig { provider: "sagemaker", name: "intfloat/multilingual-e5-large", dimensions: 1024, + default: false, }, { provider: "sagemaker", name: "sentence-transformers/all-MiniLM-L6-v2", dimensions: 384, + default: false, }, { provider: "bedrock", @@ -80,8 +84,10 @@ export function getConfig(): 
SystemConfig { provider: "openai", name: "text-embedding-ada-002", dimensions: 1536, + default: false, }, ], + crossEncodingEnabled: false, crossEncoderModels: [ { provider: "sagemaker", diff --git a/cli/magic-config.ts b/cli/magic-config.ts index 974e7d8d..f414269f 100644 --- a/cli/magic-config.ts +++ b/cli/magic-config.ts @@ -10,6 +10,7 @@ import { SupportedSageMakerModels, SystemConfig, SupportedBedrockRegion, + ModelConfig, } from "../lib/shared/types"; import { LIB_VERSION } from "./version.js"; import * as fs from "fs"; @@ -34,7 +35,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] { function getCountryCodesAndNames(): { message: string; name: string }[] { // Use country-list to get an array of countries with their codes and names const countries = getData(); - // Map the country data to match the desired output structure const countryInfo = countries.map(({ code, name }) => { return { message: `${name} (${code})`, name: code }; @@ -88,21 +88,24 @@ const secretManagerArnRegExp = RegExp( /arn:aws:secretsmanager:[\w-_]+:\d+:secret:[\w-_]+/ ); -const embeddingModels = [ +const embeddingModels: ModelConfig[] = [ { provider: "sagemaker", name: "intfloat/multilingual-e5-large", dimensions: 1024, + default: false, }, { provider: "sagemaker", name: "sentence-transformers/all-MiniLM-L6-v2", dimensions: 384, + default: false, }, { provider: "bedrock", name: "amazon.titan-embed-text-v1", dimensions: 1536, + default: false, }, //Support for inputImage is not yet implemented for amazon.titan-embed-image-v1 { @@ -124,6 +127,7 @@ const embeddingModels = [ provider: "openai", name: "text-embedding-ada-002", dimensions: 1536, + default: false, }, ]; @@ -179,6 +183,8 @@ const embeddingModels = [ options.startScheduleEndDate = config.llms?.sagemakerSchedule?.startScheduleEndDate; options.enableRag = config.rag.enabled; + options.deployDefaultSagemakerModels = + config.rag.deployDefaultSagemakerModels; options.ragsToEnable = 
Object.keys(config.rag.engines ?? {}).filter( (v: string) => ( @@ -608,6 +614,16 @@ async function processCreateOptions(options: any): Promise { message: "Do you want to enable RAG", initial: options.enableRag || false, }, + { + type: "confirm", + name: "deployDefaultSagemakerModels", + message: + "Do you want to deploy the default embedding and cross-encoder models via SageMaker?", + initial: options.deployDefaultSagemakerModels || false, + skip(): boolean { + return !(this as any).state.answers.enableRag; + }, + }, { type: "multiselect", name: "ragsToEnable", @@ -810,10 +826,17 @@ async function processCreateOptions(options: any): Promise { choices: embeddingModels.map((m) => ({ name: m.name, value: m })), initial: options.defaultEmbedding, validate(value: string) { + const embeding = embeddingModels.find((i) => i.name === value); + if ( + embeding && + (this as any).state.answers.deployDefaultSagemakerModels === false && + embeding?.provider === "sagemaker" + ) { + return "SageMaker default models are not enabled. Please select another model."; + } if ((this as any).state.answers.enableRag) { return value ? true : "Select a default embedding model"; } - return true; }, skip() { @@ -1219,6 +1242,7 @@ async function processCreateOptions(options: any): Promise { } : undefined, llms: { + enableSagemakerModels: answers.enableSagemakerModels, rateLimitPerAIP: advancedSettings?.llmRateLimitPerIP ? 
Number(advancedSettings?.llmRateLimitPerIP) : undefined, @@ -1241,6 +1265,7 @@ async function processCreateOptions(options: any): Promise { }, rag: { enabled: answers.enableRag, + deployDefaultSagemakerModels: answers.deployDefaultSagemakerModels, engines: { aurora: { enabled: answers.ragsToEnable.includes("aurora"), @@ -1259,28 +1284,43 @@ async function processCreateOptions(options: any): Promise { external: [{}], }, }, - embeddingsModels: [{}], - crossEncoderModels: [{}], + embeddingsModels: [] as ModelConfig[], + crossEncoderModels: [] as ModelConfig[], }, }; + if (config.rag.enabled && config.rag.deployDefaultSagemakerModels) { + config.rag.crossEncoderModels[0] = { + provider: "sagemaker", + name: "cross-encoder/ms-marco-MiniLM-L-12-v2", + default: true, + }; + config.rag.embeddingsModels = embeddingModels; + } else if (config.rag.enabled) { + config.rag.embeddingsModels = embeddingModels.filter( + (model) => model.provider !== "sagemaker" + ); + for (const model of config.rag.embeddingsModels) { + model.default = model.name === models.defaultEmbedding; + } + } else { + config.rag.embeddingsModels = []; + } + // If we have not enabled rag the default embedding is set to the first model if (!answers.enableRag) { - models.defaultEmbedding = embeddingModels[0].name; + // The list is empty when RAG is disabled: only mark a default if one exists. + if (config.rag.embeddingsModels.length > 0) { + config.rag.embeddingsModels[0].default = true; + } + } else { + config.rag.embeddingsModels.forEach((m: any) => { + if (m.name === models.defaultEmbedding) { + m.default = true; + } + }); } - config.rag.crossEncoderModels[0] = { - provider: "sagemaker", - name: "cross-encoder/ms-marco-MiniLM-L-12-v2", - default: true, - }; - config.rag.embeddingsModels = embeddingModels; - config.rag.embeddingsModels.forEach((m: any) => { - if (m.name === models.defaultEmbedding) { - m.default = true; - } - }); - config.rag.engines.kendra.createIndex = answers.ragsToEnable.includes("kendra"); config.rag.engines.kendra.enabled = diff --git a/integtests/chatbot-api/aurora_workspace_test.py
b/integtests/chatbot-api/aurora_workspace_test.py index df0b1d7b..d5a7ba26 100644 --- a/integtests/chatbot-api/aurora_workspace_test.py +++ b/integtests/chatbot-api/aurora_workspace_test.py @@ -10,8 +10,8 @@ def run_before_and_after_tests(client: AppSyncClient): for workspace in client.list_workspaces(): if ( workspace.get("name") == "INTEG_TEST_AURORA" - and workspace.get("status") == "ready" - ): + or workspace.get("name") == "INTEG_TEST_AURORA_WITHOUT_RERANK" + ) and workspace.get("status") == "ready": client.delete_workspace(workspace.get("id")) @@ -22,23 +22,25 @@ def test_create(client: AppSyncClient, default_embed_model): if engine.get("enabled") == False: pytest.skip_flag = True pytest.skip("Aurora is not enabled.") - pytest.workspace = client.create_aurora_workspace( - input={ - "kind": "auro2", - "name": "INTEG_TEST_AURORA", - "embeddingsModelProvider": "bedrock", - "embeddingsModelName": default_embed_model, - "crossEncoderModelName": "cross-encoder/ms-marco-MiniLM-L-12-v2", - "crossEncoderModelProvider": "sagemaker", - "languages": ["english"], - "index": True, - "hybridSearch": True, - "metric": "inner", - "chunkingStrategy": "recursive", - "chunkSize": 1000, - "chunkOverlap": 200, - } - ) + input = { + "kind": "auro2", + "name": "INTEG_TEST_AURORA_WITHOUT_RERANK", + "embeddingsModelProvider": "bedrock", + "embeddingsModelName": default_embed_model, + "languages": ["english"], + "index": True, + "hybridSearch": True, + "metric": "inner", + "chunkingStrategy": "recursive", + "chunkSize": 1000, + "chunkOverlap": 200, + } + input_with_rerank = input.copy() + input_with_rerank["name"] = "INTEG_TEST_AURORA" + input_with_rerank["crossEncoderModelName"] = "cross-encoder/ms-marco-MiniLM-L-12-v2" + input_with_rerank["crossEncoderModelProvider"] = "sagemaker" + pytest.workspace = client.create_aurora_workspace(input=input_with_rerank) + pytest.workspace_no_re_rank = client.create_aurora_workspace(input=input) ready = False retries = 0 @@ -56,6 +58,7 @@ def 
test_create(client: AppSyncClient, default_embed_model): def test_add_rss(client: AppSyncClient): if pytest.skip_flag == True: pytest.skip("Aurora is not enabled.") + pytest.document = client.add_rss_feed( input={ "workspaceId": pytest.workspace.get("id"), @@ -67,6 +70,17 @@ def test_add_rss(client: AppSyncClient): "limit": 2, } ) + client.add_rss_feed( + input={ + "workspaceId": pytest.workspace_no_re_rank.get("id"), + "title": "INTEG_TEST_AURORA_TITLE", + "address": "https://github.com/aws-samples/aws-genai-llm-chatbot/" + + "releases.atom", + "contentTypes": ["text/html"], + "followLinks": True, + "limit": 2, + } + ) ready = False retries = 0 @@ -137,6 +151,29 @@ def test_search_document(client: AppSyncClient): assert ready == True +def test_search_document_no_reank(client: AppSyncClient): + if pytest.skip_flag == True: + pytest.skip("Aurora is not enabled.") + ready = False + retries = 0 + # Wait for the page to be crawled. This starts on a cron every 5 min. + while not ready and retries < 50: + time.sleep(15) + retries += 1 + result = client.semantic_search( + input={ + "workspaceId": pytest.workspace_no_re_rank.get("id"), + "query": "Release github", + } + ) + if len(result.get("items")) > 1: + ready = True + assert result.get("engine") == "aurora" + # Re-ranking score is not set but the results are ordered by Aurora.
+ assert result.get("items")[0].get("score") is None + assert ready == True + + def test_query_llm(client, default_model, default_provider): if pytest.skip_flag == True: pytest.skip("Aurora is not enabled.") diff --git a/lib/aws-genai-llm-chatbot-stack.ts b/lib/aws-genai-llm-chatbot-stack.ts index c7383789..9478449d 100644 --- a/lib/aws-genai-llm-chatbot-stack.ts +++ b/lib/aws-genai-llm-chatbot-stack.ts @@ -159,10 +159,6 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { api: chatBotApi, chatbotFilesBucket: chatBotApi.filesBucket, uploadBucket: ragEngines?.uploadBucket, - crossEncodersEnabled: - typeof ragEngines?.sageMakerRagModels?.model !== "undefined", - sagemakerEmbeddingsEnabled: - typeof ragEngines?.sageMakerRagModels?.model !== "undefined", }); if (props.config.cognitoFederation?.enabled) { @@ -417,10 +413,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { ] ); - if ( - props.config.rag.engines.aurora.enabled || - props.config.rag.engines.opensearch.enabled - ) { + if (ragEngines?.sageMakerRagModels?.model) { NagSuppressions.addResourceSuppressionsByPath( this, [ @@ -456,59 +449,59 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { }, ] ); - if (props.config.rag.engines.aurora.enabled) { - NagSuppressions.addResourceSuppressionsByPath( - this, - `/${this.stackName}/RagEngines/AuroraPgVector/AuroraDatabase/Secret/Resource`, - [ - { - id: "AwsSolutions-SMG4", - reason: "Secret created implicitly by CDK.", - }, - ] - ); - NagSuppressions.addResourceSuppressionsByPath( - this, - [ - `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupFunction/ServiceRole/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/DefaultPolicy/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/Resource`, - 
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, - `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspace/Role/DefaultPolicy/Resource`, - ], - [ - { - id: "AwsSolutions-IAM4", - reason: "IAM role implicitly created by CDK.", - }, - { - id: "AwsSolutions-IAM5", - reason: "IAM role implicitly created by CDK.", - }, - ] - ); - } - if (props.config.rag.engines.opensearch.enabled) { - NagSuppressions.addResourceSuppressionsByPath( - this, - [ - `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/Resource`, - `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, - `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspace/Role/DefaultPolicy/Resource`, - ], - [ - { - id: "AwsSolutions-IAM4", - reason: "IAM role implicitly created by CDK.", - }, - { - id: "AwsSolutions-IAM5", - reason: "IAM role implicitly created by CDK.", - }, - ] - ); - } + } + if (props.config.rag.engines.aurora.enabled) { + NagSuppressions.addResourceSuppressionsByPath( + this, + `/${this.stackName}/RagEngines/AuroraPgVector/AuroraDatabase/Secret/Resource`, + [ + { + id: "AwsSolutions-SMG4", + reason: "Secret created implicitly by CDK.", + }, + ] + ); + NagSuppressions.addResourceSuppressionsByPath( + this, + [ + `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupFunction/ServiceRole/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/DatabaseSetupProvider/framework-onEvent/ServiceRole/DefaultPolicy/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/Resource`, + 
`/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, + `/${this.stackName}/RagEngines/AuroraPgVector/CreateAuroraWorkspace/CreateAuroraWorkspace/Role/DefaultPolicy/Resource`, + ], + [ + { + id: "AwsSolutions-IAM4", + reason: "IAM role implicitly created by CDK.", + }, + { + id: "AwsSolutions-IAM5", + reason: "IAM role implicitly created by CDK.", + }, + ] + ); + } + if (props.config.rag.engines.opensearch.enabled) { + NagSuppressions.addResourceSuppressionsByPath( + this, + [ + `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/Resource`, + `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspaceFunction/ServiceRole/DefaultPolicy/Resource`, + `/${this.stackName}/RagEngines/OpenSearchVector/CreateOpenSearchWorkspace/CreateOpenSearchWorkspace/Role/DefaultPolicy/Resource`, + ], + [ + { + id: "AwsSolutions-IAM4", + reason: "IAM role implicitly created by CDK.", + }, + { + id: "AwsSolutions-IAM5", + reason: "IAM role implicitly created by CDK.", + }, + ] + ); } if (props.config.rag.engines.kendra.enabled) { NagSuppressions.addResourceSuppressionsByPath( diff --git a/lib/chatbot-api/functions/api-handler/routes/workspaces.py b/lib/chatbot-api/functions/api-handler/routes/workspaces.py index 5be4a850..bdc3188b 100644 --- a/lib/chatbot-api/functions/api-handler/routes/workspaces.py +++ b/lib/chatbot-api/functions/api-handler/routes/workspaces.py @@ -1,4 +1,4 @@ -from typing import Annotated, List +from typing import Annotated, List, Optional from common.constant import ( SAFE_SHORT_STR_VALIDATION, ) @@ -28,8 +28,8 @@ class CreateWorkspaceAuroraRequest(BaseModel): name: str = Field(min_length=1, max_length=100, pattern=name_regex) embeddingsModelProvider: str = SAFE_SHORT_STR_VALIDATION embeddingsModelName: str = SAFE_SHORT_STR_VALIDATION - crossEncoderModelProvider: str = 
SAFE_SHORT_STR_VALIDATION - crossEncoderModelName: str = SAFE_SHORT_STR_VALIDATION + crossEncoderModelProvider: Optional[str] = SAFE_SHORT_STR_VALIDATION + crossEncoderModelName: Optional[str] = SAFE_SHORT_STR_VALIDATION languages: List[Annotated[str, SAFE_SHORT_STR_VALIDATION]] metric: str = SAFE_SHORT_STR_VALIDATION index: bool @@ -44,8 +44,8 @@ class CreateWorkspaceOpenSearchRequest(BaseModel): name: str = Field(min_length=1, max_length=100, pattern=name_regex) embeddingsModelProvider: str = SAFE_SHORT_STR_VALIDATION embeddingsModelName: str = SAFE_SHORT_STR_VALIDATION - crossEncoderModelProvider: str = SAFE_SHORT_STR_VALIDATION - crossEncoderModelName: str = SAFE_SHORT_STR_VALIDATION + crossEncoderModelProvider: Optional[str] = SAFE_SHORT_STR_VALIDATION + crossEncoderModelName: Optional[str] = SAFE_SHORT_STR_VALIDATION languages: List[Annotated[str, SAFE_SHORT_STR_VALIDATION]] hybridSearch: bool chunkingStrategy: str = SAFE_SHORT_STR_VALIDATION @@ -165,7 +165,7 @@ def _create_workspace_aurora(request: CreateWorkspaceAuroraRequest, config: dict if embeddings_model is None: raise genai_core.types.CommonError("Embeddings model not found") - if cross_encoder_model is None: + if request.crossEncoderModelName is not None and cross_encoder_model is None: raise genai_core.types.CommonError("Cross encoder model not found") embeddings_model_dimensions = embeddings_model["dimensions"] @@ -232,7 +232,7 @@ def _create_workspace_open_search( if embeddings_model is None: raise genai_core.types.CommonError("Embeddings model not found") - if cross_encoder_model is None: + if request.crossEncoderModelName is not None and cross_encoder_model is None: raise genai_core.types.CommonError("Cross encoder model not found") embeddings_model_dimensions = embeddings_model["dimensions"] diff --git a/lib/chatbot-api/rest-api.ts b/lib/chatbot-api/rest-api.ts index 21349316..946af150 100644 --- a/lib/chatbot-api/rest-api.ts +++ b/lib/chatbot-api/rest-api.ts @@ -94,7 +94,7 @@ export class 
ApiResolvers extends Construct { DOCUMENTS_BY_STATUS_INDEX: props.ragEngines?.documentsByStatusIndexName ?? "", SAGEMAKER_RAG_MODELS_ENDPOINT: - props.ragEngines?.sageMakerRagModels?.model.endpoint + props.ragEngines?.sageMakerRagModels?.model?.endpoint ?.attrEndpointName ?? "", DELETE_WORKSPACE_WORKFLOW_ARN: props.ragEngines?.deleteWorkspaceWorkflow?.stateMachineArn ?? "", @@ -275,7 +275,7 @@ export class ApiResolvers extends Construct { props.ragEngines.deleteDocumentWorkflow.grantStartExecution(apiHandler); } - if (props.ragEngines?.sageMakerRagModels) { + if (props.ragEngines?.sageMakerRagModels?.model) { apiHandler.addToRolePolicy( new iam.PolicyStatement({ actions: ["sagemaker:InvokeEndpoint"], diff --git a/lib/chatbot-api/schema/schema.graphql b/lib/chatbot-api/schema/schema.graphql index 69f1f4e0..bdba6963 100644 --- a/lib/chatbot-api/schema/schema.graphql +++ b/lib/chatbot-api/schema/schema.graphql @@ -5,8 +5,8 @@ input CreateWorkspaceAuroraInput { kind: String! embeddingsModelProvider: String! embeddingsModelName: String! - crossEncoderModelProvider: String! - crossEncoderModelName: String! + crossEncoderModelProvider: String + crossEncoderModelName: String languages: [String!]! metric: String! index: Boolean! @@ -35,8 +35,8 @@ input CreateWorkspaceOpenSearchInput { kind: String! embeddingsModelProvider: String! embeddingsModelName: String! - crossEncoderModelProvider: String! - crossEncoderModelName: String! + crossEncoderModelProvider: String + crossEncoderModelName: String languages: [String!]! hybridSearch: Boolean! chunkingStrategy: String! 
diff --git a/lib/model-interfaces/langchain/index.ts b/lib/model-interfaces/langchain/index.ts index 31c1f730..4926d6e3 100644 --- a/lib/model-interfaces/langchain/index.ts +++ b/lib/model-interfaces/langchain/index.ts @@ -66,7 +66,7 @@ export class LangChainInterface extends Construct { props.ragEngines?.auroraPgVector?.database?.clusterEndpoint?.port + "", SAGEMAKER_RAG_MODELS_ENDPOINT: - props.ragEngines?.sageMakerRagModels?.model.endpoint + props.ragEngines?.sageMakerRagModels?.model?.endpoint ?.attrEndpointName ?? "", OPEN_SEARCH_COLLECTION_ENDPOINT: props.ragEngines?.openSearchVector?.openSearchCollectionEndpoint ?? @@ -147,7 +147,7 @@ export class LangChainInterface extends Construct { props.ragEngines.documentsTable.grantReadWriteData(requestHandler); } - if (props.ragEngines?.sageMakerRagModels) { + if (props.ragEngines?.sageMakerRagModels?.model) { requestHandler.addToRolePolicy( new iam.PolicyStatement({ actions: ["sagemaker:InvokeEndpoint"], diff --git a/lib/rag-engines/data-import/index.ts b/lib/rag-engines/data-import/index.ts index 7654ab85..083d5720 100644 --- a/lib/rag-engines/data-import/index.ts +++ b/lib/rag-engines/data-import/index.ts @@ -167,7 +167,7 @@ export class DataImport extends Construct { processingBucket, auroraDatabase: props.auroraDatabase, ragDynamoDBTables: props.ragDynamoDBTables, - sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint, + sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model?.endpoint, openSearchVector: props.openSearchVector, } ); @@ -193,7 +193,7 @@ export class DataImport extends Construct { processingBucket, auroraDatabase: props.auroraDatabase, ragDynamoDBTables: props.ragDynamoDBTables, - sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint, + sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model?.endpoint, openSearchVector: props.openSearchVector, } ); @@ -248,7 +248,7 @@ export class DataImport extends Construct { DOCUMENTS_BY_COMPOUND_KEY_INDEX_NAME: 
props.documentsByCompoundKeyIndexName ?? "", SAGEMAKER_RAG_MODELS_ENDPOINT: - props.sageMakerRagModels?.model.endpoint.attrEndpointName ?? "", + props.sageMakerRagModels?.model?.endpoint.attrEndpointName ?? "", FILE_IMPORT_WORKFLOW_ARN: fileImportWorkflow?.stateMachine.stateMachineArn ?? "", DEFAULT_KENDRA_S3_DATA_SOURCE_BUCKET_NAME: diff --git a/lib/rag-engines/index.ts b/lib/rag-engines/index.ts index 6fdef3b6..4080a152 100644 --- a/lib/rag-engines/index.ts +++ b/lib/rag-engines/index.ts @@ -44,15 +44,10 @@ export class RagEngines extends Construct { }); let sageMakerRagModels: SageMakerRagModels | null = null; - if ( - props.config.rag.engines.aurora.enabled || - props.config.rag.engines.opensearch.enabled - ) { - sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", { - shared: props.shared, - config: props.config, - }); - } + sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", { + shared: props.shared, + config: props.config, + }); let auroraPgVector: AuroraPgVector | null = null; if (props.config.rag.engines.aurora.enabled) { diff --git a/lib/rag-engines/sagemaker-rag-models/index.ts b/lib/rag-engines/sagemaker-rag-models/index.ts index f935be8e..f12ec0ff 100644 --- a/lib/rag-engines/sagemaker-rag-models/index.ts +++ b/lib/rag-engines/sagemaker-rag-models/index.ts @@ -24,27 +24,32 @@ export class SageMakerRagModels extends Construct { .filter((c) => c.provider === "sagemaker") .map((c) => c.name); - const model = new SageMakerModel(this, "Model", { - vpc: props.shared.vpc, - region: cdk.Aws.REGION, - logRetention: props.config.logRetention, - kmsKey: props.shared.kmsKey, - // NVMe based instances (like ml.g4dn.xlarge) do not support KMS encryption - // They instead use an hardware module for encryption - // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/data-protection.html#encryption-rest - enableEndpointKMSEncryption: false, - retainOnDelete: props.config.retainOnDelete, - model: { - type: DeploymentType.CustomInferenceScript, - 
modelId: [ - ...sageMakerEmbeddingsModelIds, - ...sageMakerCrossEncoderModelIds, - ], - codeFolder: path.join(__dirname, "./model"), - instanceType: "ml.g4dn.xlarge", - }, - }); + if ( + sageMakerEmbeddingsModelIds?.length > 0 || + sageMakerCrossEncoderModelIds?.length > 0 + ) { + const model = new SageMakerModel(this, "Model", { + vpc: props.shared.vpc, + region: cdk.Aws.REGION, + logRetention: props.config.logRetention, + kmsKey: props.shared.kmsKey, + // NVMe based instances (like ml.g4dn.xlarge) do not support KMS encryption + // They instead use an hardware module for encryption + // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/data-protection.html#encryption-rest + enableEndpointKMSEncryption: false, + retainOnDelete: props.config.retainOnDelete, + model: { + type: DeploymentType.CustomInferenceScript, + modelId: [ + ...sageMakerEmbeddingsModelIds, + ...sageMakerCrossEncoderModelIds, + ], + codeFolder: path.join(__dirname, "./model"), + instanceType: "ml.g4dn.xlarge", + }, + }); - this.model = model; + this.model = model; + } } } diff --git a/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py b/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py index cf3553be..ba22a85d 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py +++ b/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py @@ -38,13 +38,6 @@ def query_workspace_aurora( if selected_model is None: raise CommonError("Embeddings model not found") - cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( - cross_encoder_model_provider, cross_encoder_model_name - ) - - if cross_encoder_model is None: - raise CommonError("Cross encoder model not found") - query_embeddings = genai_core.embeddings.generate_embeddings( selected_model, [query], Task.RETRIEVE )[0] @@ -186,24 +179,33 @@ def query_workspace_aurora( item["keyword_search_score"] = current["keyword_search_score"] unique_items = list(unique_items.values()) - score_dict = 
dict({}) - if len(unique_items) > 0: - passages = [record["content"] for record in unique_items] - passage_scores = genai_core.cross_encoder.rank_passages( - cross_encoder_model, query, passages + + if cross_encoder_model_name is not None: + cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( + cross_encoder_model_provider, cross_encoder_model_name ) - for i in range(len(unique_items)): - score = passage_scores[i] - unique_items[i]["score"] = score - score_dict[unique_items[i]["chunk_id"]] = score + if cross_encoder_model is None: + raise genai_core.types.CommonError("Cross encoder model not found") - unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) + score_dict = dict({}) + if len(unique_items) > 0: + passages = [record["content"] for record in unique_items] + passage_scores = genai_core.cross_encoder.rank_passages( + cross_encoder_model, query, passages + ) - for record in vector_search_records: - record["score"] = score_dict[record["chunk_id"]] - for record in keyword_search_records: - record["score"] = score_dict[record["chunk_id"]] + for i in range(len(unique_items)): + score = passage_scores[i] + unique_items[i]["score"] = score + score_dict[unique_items[i]["chunk_id"]] = score + + unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) + + for record in vector_search_records: + record["score"] = score_dict[record["chunk_id"]] + for record in keyword_search_records: + record["score"] = score_dict[record["chunk_id"]] if full_response: unique_items = unique_items[:limit] @@ -218,9 +220,13 @@ def query_workspace_aurora( "keyword_search_items": convert_types(keyword_search_records), } else: - ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[ - :limit - ] + if cross_encoder_model_name is not None: + ret_items = list( + filter(lambda val: val["score"] > threshold, unique_items) + )[:limit] + else: + ret_items = unique_items[:limit] + if len(ret_items) < limit: # inner 
product metric is negative hence we sort ascending if metric == "inner": diff --git a/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py b/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py index 96aafbff..c584417d 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py +++ b/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py @@ -37,13 +37,6 @@ def query_workspace_open_search( if selected_model is None: raise CommonError("Embeddings model not found") - cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( - cross_encoder_model_provider, cross_encoder_model_name - ) - - if cross_encoder_model is None: - raise CommonError("Cross encoder model not found") - query_embeddings = genai_core.embeddings.generate_embeddings( selected_model, [query], Task.RETRIEVE )[0] @@ -96,23 +89,32 @@ def query_workspace_open_search( item["keyword_search_score"] = current["keyword_search_score"] unique_items = list(unique_items.values()) - score_dict = dict({}) - if len(unique_items) > 0: - passages = [record["content"] for record in unique_items] - passage_scores = genai_core.cross_encoder.rank_passages( - cross_encoder_model, query, passages + + if cross_encoder_model_name is not None: + cross_encoder_model = genai_core.cross_encoder.get_cross_encoder_model( + cross_encoder_model_provider, cross_encoder_model_name ) - for i in range(len(unique_items)): - score = passage_scores[i] - unique_items[i]["score"] = score - score_dict[unique_items[i]["chunk_id"]] = score - unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) + if cross_encoder_model is None: + raise genai_core.types.CommonError("Cross encoder model not found") - for record in vector_search_records: - record["score"] = score_dict[record["chunk_id"]] - for record in keyword_search_records: - record["score"] = score_dict[record["chunk_id"]] + score_dict = dict({}) + if len(unique_items) > 0: + passages = 
[record["content"] for record in unique_items] + passage_scores = genai_core.cross_encoder.rank_passages( + cross_encoder_model, query, passages + ) + + for i in range(len(unique_items)): + score = passage_scores[i] + unique_items[i]["score"] = score + score_dict[unique_items[i]["chunk_id"]] = score + unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) + + for record in vector_search_records: + record["score"] = score_dict[record["chunk_id"]] + for record in keyword_search_records: + record["score"] = score_dict[record["chunk_id"]] if full_response: unique_items = unique_items[:limit] @@ -125,9 +127,13 @@ def query_workspace_open_search( "keyword_search_items": keyword_search_records, } else: - ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[ - :limit - ] + if cross_encoder_model_name is not None: + ret_items = list( + filter(lambda val: val["score"] > threshold, unique_items) + )[:limit] + else: + ret_items = unique_items[:limit] + if len(ret_items) < limit and len(unique_items) > len(ret_items): unique_items = list( filter( diff --git a/lib/shared/types.ts b/lib/shared/types.ts index 74ca0ace..a0c78dfc 100644 --- a/lib/shared/types.ts +++ b/lib/shared/types.ts @@ -70,6 +70,13 @@ export enum Direction { Out = "OUT", } +export interface ModelConfig { + provider: ModelProvider; + name: string; + dimensions?: number; + default?: boolean; +} + export interface SystemConfig { prefix: string; createCMKs?: boolean; @@ -132,6 +139,7 @@ export interface SystemConfig { }; rag: { enabled: boolean; + deployDefaultSagemakerModels?: boolean; engines: { aurora: { enabled: boolean; @@ -160,17 +168,9 @@ export interface SystemConfig { }[]; }; }; - embeddingsModels: { - provider: ModelProvider; - name: string; - dimensions: number; - default?: boolean; - }[]; - crossEncoderModels: { - provider: ModelProvider; - name: string; - default?: boolean; - }[]; + embeddingsModels: ModelConfig[]; + crossEncodingEnabled: boolean; + 
crossEncoderModels: ModelConfig[]; }; } diff --git a/lib/user-interface/index.ts b/lib/user-interface/index.ts index efbe40db..02f48560 100644 --- a/lib/user-interface/index.ts +++ b/lib/user-interface/index.ts @@ -26,8 +26,6 @@ export interface UserInterfaceProps { readonly api: ChatBotApi; readonly chatbotFilesBucket: s3.Bucket; readonly uploadBucket?: s3.Bucket; - readonly crossEncodersEnabled: boolean; - readonly sagemakerEmbeddingsEnabled: boolean; } export class UserInterface extends Construct { @@ -95,6 +93,9 @@ export class UserInterface extends Construct { redirectSignIn = `https://${this.publishedDomain}`; } + const sagemakerEmbedingModels = props.config.rag.embeddingsModels.filter( + (i) => i.provider === "sagemaker" + ); const exportsAsset = s3deploy.Source.jsonData("aws-exports.json", { aws_project_region: cdk.Aws.REGION, aws_cognito_region: cdk.Aws.REGION, @@ -126,12 +127,16 @@ export class UserInterface extends Construct { } : undefined, rag_enabled: props.config.rag.enabled, - cross_encoders_enabled: props.crossEncodersEnabled, - sagemaker_embeddings_enabled: props.sagemakerEmbeddingsEnabled, - default_embeddings_model: Utils.getDefaultEmbeddingsModel(props.config), - default_cross_encoder_model: Utils.getDefaultCrossEncoderModel( - props.config - ), + cross_encoders_enabled: props.config.rag.crossEncoderModels.length > 0, + sagemaker_embeddings_enabled: sagemakerEmbedingModels.length > 0, + default_embeddings_model: + props.config.rag.embeddingsModels.length > 0 + ? Utils.getDefaultEmbeddingsModel(props.config) + : undefined, + default_cross_encoder_model: + props.config.rag.crossEncoderModels.length > 0 + ? Utils.getDefaultCrossEncoderModel(props.config) + : undefined, privateWebsite: props.config.privateWebsite ? 
true : false, }, }); diff --git a/lib/user-interface/private-website.ts b/lib/user-interface/private-website.ts index 05c4be04..436e2de3 100644 --- a/lib/user-interface/private-website.ts +++ b/lib/user-interface/private-website.ts @@ -18,8 +18,6 @@ export interface PrivateWebsiteProps { readonly userPoolClientId: string; readonly api: ChatBotApi; readonly chatbotFilesBucket: s3.Bucket; - readonly crossEncodersEnabled: boolean; - readonly sagemakerEmbeddingsEnabled: boolean; readonly websiteBucket: s3.Bucket; } diff --git a/lib/user-interface/public-website.ts b/lib/user-interface/public-website.ts index 39693cf8..7067f8d9 100644 --- a/lib/user-interface/public-website.ts +++ b/lib/user-interface/public-website.ts @@ -16,8 +16,6 @@ export interface PublicWebsiteProps { readonly userPoolId: string; readonly userPoolClientId: string; readonly api: ChatBotApi; - readonly crossEncodersEnabled: boolean; - readonly sagemakerEmbeddingsEnabled: boolean; readonly websiteBucket: s3.Bucket; readonly chatbotFilesBucket: s3.Bucket; readonly uploadBucket?: s3.Bucket; diff --git a/lib/user-interface/react-app/src/common/api-client/workspaces-client.ts b/lib/user-interface/react-app/src/common/api-client/workspaces-client.ts index 7f72e82d..cb5adcac 100644 --- a/lib/user-interface/react-app/src/common/api-client/workspaces-client.ts +++ b/lib/user-interface/react-app/src/common/api-client/workspaces-client.ts @@ -56,8 +56,8 @@ export class WorkspacesClient { name: string; embeddingsModelProvider: string; embeddingsModelName: string; - crossEncoderModelProvider: string; - crossEncoderModelName: string; + crossEncoderModelProvider?: string; + crossEncoderModelName?: string; languages: string[]; metric: string; index: boolean; @@ -79,8 +79,8 @@ export class WorkspacesClient { name: string; embeddingsModelProvider: string; embeddingsModelName: string; - crossEncoderModelProvider: string; - crossEncoderModelName: string; + crossEncoderModelProvider?: string; + crossEncoderModelName?: 
string; languages: string[]; hybridSearch: boolean; chunkingStrategy: string; diff --git a/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts b/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts index 93902dcc..56e1bd30 100644 --- a/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts +++ b/lib/user-interface/react-app/src/common/helpers/embeddings-model-helper.ts @@ -1,5 +1,6 @@ import { SelectProps } from "@cloudscape-design/components"; import { EmbeddingModel } from "../../API"; +import { AppConfig } from "../types"; export abstract class EmbeddingsModelHelper { static getSelectOption(model?: string): SelectProps.Option | null { @@ -32,9 +33,18 @@ export abstract class EmbeddingsModelHelper { }; } - static getSelectOptions(embeddingsModels: EmbeddingModel[]) { + static getSelectOptions( + appContext: AppConfig | null, + embeddingsModels: EmbeddingModel[] + ) { const modelsMap = new Map(); embeddingsModels.forEach((model) => { + if ( + model.provider === "sagemaker" && + !appContext?.config.sagemaker_embeddings_enabled + ) { + return; + } let items = modelsMap.get(model.provider); if (!items) { items = []; diff --git a/lib/user-interface/react-app/src/common/helpers/options-helper.ts b/lib/user-interface/react-app/src/common/helpers/options-helper.ts index e06d3151..922bb18f 100644 --- a/lib/user-interface/react-app/src/common/helpers/options-helper.ts +++ b/lib/user-interface/react-app/src/common/helpers/options-helper.ts @@ -33,8 +33,9 @@ export abstract class OptionsHelper { } static getSelectOptionGroups( - data: T[] - ) { + data: T[], + addNone: boolean = false + ): (SelectProps.OptionGroup | SelectProps.Option)[] { const modelsMap = new Map(); data.forEach((item) => { let items = modelsMap.get(item.provider); @@ -63,6 +64,16 @@ export abstract class OptionsHelper { }; }); + if (addNone) { + return [ + { + label: "None", + value: "__none__", + }, + ...options, + ]; + } + return options; } 
diff --git a/lib/user-interface/react-app/src/components/chatbot/chat-input-panel.tsx b/lib/user-interface/react-app/src/components/chatbot/chat-input-panel.tsx index 8888b004..d2ec9611 100644 --- a/lib/user-interface/react-app/src/components/chatbot/chat-input-panel.tsx +++ b/lib/user-interface/react-app/src/components/chatbot/chat-input-panel.tsx @@ -699,9 +699,13 @@ function getSelectedModelOption(models: Model[]): SelectProps.Option | null { ); if (targetModel) { - selectedModelOption = OptionsHelper.getSelectOptionGroups([ - targetModel, - ])[0].options[0]; + const groups = OptionsHelper.getSelectOptionGroups([targetModel]).filter( + (i) => (i as SelectProps.OptionGroup).options + ) as SelectProps.OptionGroup[]; + selectedModelOption = + groups.length > 0 && groups[0].options.length > 0 + ? groups[0].options[0] + : null; } } @@ -745,8 +749,13 @@ function getSelectedModelOption(models: Model[]): SelectProps.Option | null { } if (candidate) { - selectedModelOption = OptionsHelper.getSelectOptionGroups([candidate])[0] - .options[0]; + const groups = OptionsHelper.getSelectOptionGroups([candidate]).filter( + (i) => (i as SelectProps.OptionGroup).options + ) as SelectProps.OptionGroup[]; + selectedModelOption = + groups.length > 0 && groups[0].options.length > 0 + ? 
groups[0].options[0] + : null; } } diff --git a/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx b/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx index 09533e86..2cc57103 100644 --- a/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx +++ b/lib/user-interface/react-app/src/pages/rag/create-workspace/aurora-form.tsx @@ -17,9 +17,11 @@ import { LanguageSelectorField } from "./language-selector-field"; import { CrossEncoderSelectorField } from "./cross-encoder-selector-field"; import { ChunkSelectorField } from "./chunks-selector"; import { HybridSearchField } from "./hybrid-search-field"; +import { useState } from "react"; export interface AuroraFormProps { data: AuroraWorkspaceCreateInput; + crossEncodingEnabled: boolean; onChange: (data: Partial) => void; errors: Record; submitting: boolean; @@ -33,6 +35,7 @@ export default function AuroraForm(props: AuroraFormProps) { footer={ ) => void; errors: Record; submitting: boolean; metrics: RadioGroupProps.RadioButtonDefinition[]; }) { + const [noEncodingSelected, setNoEncodingSelected] = useState( + !props.crossEncodingEnabled + ); return ( @@ -112,16 +119,21 @@ function AuroraFooter(props: { Create an index - { + setNoEncodingSelected(data.crossEncoderModel?.value === "__none__"); + props.onChange(data); + }} + /> + x.value ?? ""), metric: data.metric, index: data.index, - hybridSearch: data.hybridSearch, + hybridSearch: data.hybridSearch && crossEncoderSelected, chunkingStrategy: "recursive", chunkSize: data.chunkSize, chunkOverlap: data.chunkOverlap, @@ -188,6 +197,9 @@ export default function CreateWorkspaceAurora() { > x.value ?? 
""), - hybridSearch: data.hybridSearch, + hybridSearch: data.hybridSearch && crossEncoderSelected, chunkingStrategy: "recursive", chunkSize: data.chunkSize, chunkOverlap: data.chunkOverlap, @@ -153,6 +162,9 @@ export default function CreateWorkspaceOpenSearch() { > ) => void; selectedModel: SelectProps.Option | null; errors: Record; @@ -41,21 +42,29 @@ export function CrossEncoderSelectorField(props: CrossEncoderSelectorProps) { })(); }, [appContext]); - const crossEncoderModelOptions = - OptionsHelper.getSelectOptionGroups(crossEncoderModels); + const crossEncoderModelOptions = OptionsHelper.getSelectOptionGroups( + crossEncoderModels, + true + ); return ( - + props.onChange({ crossEncoderModel: selectedOption }) + } + /> + ); } diff --git a/lib/user-interface/react-app/src/pages/rag/create-workspace/embeddings-selector-field.tsx b/lib/user-interface/react-app/src/pages/rag/create-workspace/embeddings-selector-field.tsx index 8053d345..f1d0b47c 100644 --- a/lib/user-interface/react-app/src/pages/rag/create-workspace/embeddings-selector-field.tsx +++ b/lib/user-interface/react-app/src/pages/rag/create-workspace/embeddings-selector-field.tsx @@ -39,8 +39,10 @@ export default function EmbeddingSelector(props: EmbeddingsSelectionProps) { })(); }, [appContext]); - const embeddingsModelOptions = - EmbeddingsModelHelper.getSelectOptions(embeddingsModels); + const embeddingsModelOptions = EmbeddingsModelHelper.getSelectOptions( + appContext, + embeddingsModels + ); return ( ) => void; checked: boolean; errors: Record; @@ -11,11 +12,11 @@ export function HybridSearchField(props: HybridSearchProps) { return ( props.onChange({ hybridSearch: checked }) diff --git a/lib/user-interface/react-app/src/pages/rag/create-workspace/kb-form.tsx b/lib/user-interface/react-app/src/pages/rag/create-workspace/kb-form.tsx index 5f77774a..e9f05900 100644 --- a/lib/user-interface/react-app/src/pages/rag/create-workspace/kb-form.tsx +++ 
b/lib/user-interface/react-app/src/pages/rag/create-workspace/kb-form.tsx @@ -98,6 +98,7 @@ export default function KBForm(props: KBFormProps) { /> ) => void; errors: Record; submitting: boolean; @@ -27,6 +29,7 @@ export function OpenSearchForm(props: OpenSearchFormProps) { footer={ ) => void; errors: Record; submitting: boolean; }) { + const [noEncodingSelected, setNoEncodingSelected] = useState( + !props.crossEncodingEnabled + ); return ( - { + setNoEncodingSelected(data.crossEncoderModel?.value === "__none__"); + props.onChange(data); + }} + /> +
Cross-encoder provider -
{props.workspace.crossEncoderModelProvider}
+
{props.workspace.crossEncoderModelProvider ?? "None"}
Cross-encoder model -
{props.workspace.crossEncoderModelName}
+
{props.workspace.crossEncoderModelName ?? "None"}
diff --git a/lib/user-interface/react-app/src/pages/rag/workspace/open-search-workspace-settings.tsx b/lib/user-interface/react-app/src/pages/rag/workspace/open-search-workspace-settings.tsx index 258fa48e..a9131d03 100644 --- a/lib/user-interface/react-app/src/pages/rag/workspace/open-search-workspace-settings.tsx +++ b/lib/user-interface/react-app/src/pages/rag/workspace/open-search-workspace-settings.tsx @@ -72,11 +72,11 @@ export default function OpenSearchWorkspaceSettings(
Cross-encoder provider -
{props.workspace.crossEncoderModelProvider}
+
{props.workspace.crossEncoderModelProvider ?? "None"}
Cross-encoder model -
{props.workspace.crossEncoderModelName}
+
{props.workspace.crossEncoderModelName ?? "None"}
diff --git a/tests/__snapshots__/cdk-app.test.ts.snap b/tests/__snapshots__/cdk-app.test.ts.snap index 33aa500c..b4a98941 100644 --- a/tests/__snapshots__/cdk-app.test.ts.snap +++ b/tests/__snapshots__/cdk-app.test.ts.snap @@ -1096,8 +1096,8 @@ input CreateWorkspaceAuroraInput { kind: String! embeddingsModelProvider: String! embeddingsModelName: String! - crossEncoderModelProvider: String! - crossEncoderModelName: String! + crossEncoderModelProvider: String + crossEncoderModelName: String languages: [String!]! metric: String! index: Boolean! @@ -1126,8 +1126,8 @@ input CreateWorkspaceOpenSearchInput { kind: String! embeddingsModelProvider: String! embeddingsModelName: String! - crossEncoderModelProvider: String! - crossEncoderModelName: String! + crossEncoderModelProvider: String + crossEncoderModelName: String languages: [String!]! hybridSearch: Boolean! chunkingStrategy: String! @@ -17391,7 +17391,7 @@ schema { "SharedConfig358B4A20": { "Properties": { "Type": "String", - "Value": "{"prefix":"prefix","privateWebsite":true,"certificate":"","cfGeoRestrictEnable":true,"cfGeoRestrictList":[],"bedrock":{"enabled":true,"region":"us-east-1"},"llms":{"sagemaker":["FalconLite [ml.g5.12xlarge]","Idefics_80b (Multimodal) 
[ml.g5.48xlarge]"]},"rag":{"enabled":true,"engines":{"aurora":{"enabled":true},"opensearch":{"enabled":true},"kendra":{"enabled":true,"createIndex":true,"enterprise":true},"knowledgeBase":{"enabled":false}},"embeddingsModels":[{"provider":"sagemaker","name":"intfloat/multilingual-e5-large","dimensions":1024},{"provider":"sagemaker","name":"sentence-transformers/all-MiniLM-L6-v2","dimensions":384},{"provider":"bedrock","name":"amazon.titan-embed-text-v1","dimensions":1536},{"provider":"bedrock","name":"amazon.titan-embed-image-v1","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-english-v3","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-multilingual-v3","dimensions":1024,"default":true},{"provider":"openai","name":"text-embedding-ada-002","dimensions":1536}],"crossEncoderModels":[{"provider":"sagemaker","name":"cross-encoder/ms-marco-MiniLM-L-12-v2","default":true}]},"createCMKs":true,"retainOnDelete":true}", + "Value": "{"prefix":"prefix","privateWebsite":true,"certificate":"","cfGeoRestrictEnable":true,"cfGeoRestrictList":[],"bedrock":{"enabled":true,"region":"us-east-1"},"llms":{"sagemaker":["FalconLite [ml.g5.12xlarge]","Idefics_80b (Multimodal) 
[ml.g5.48xlarge]"]},"rag":{"crossEncodingEnabled":true,"enabled":true,"engines":{"aurora":{"enabled":true},"opensearch":{"enabled":true},"kendra":{"enabled":true,"createIndex":true,"enterprise":true},"knowledgeBase":{"enabled":false}},"embeddingsModels":[{"provider":"sagemaker","name":"intfloat/multilingual-e5-large","dimensions":1024},{"provider":"sagemaker","name":"sentence-transformers/all-MiniLM-L6-v2","dimensions":384},{"provider":"bedrock","name":"amazon.titan-embed-text-v1","dimensions":1536},{"provider":"bedrock","name":"amazon.titan-embed-image-v1","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-english-v3","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-multilingual-v3","dimensions":1024,"default":true},{"provider":"openai","name":"text-embedding-ada-002","dimensions":1536}],"crossEncoderModels":[{"provider":"sagemaker","name":"cross-encoder/ms-marco-MiniLM-L-12-v2","default":true}]},"createCMKs":true,"retainOnDelete":true}", }, "Type": "AWS::SSM::Parameter", }, diff --git a/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap b/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap index 7695a49a..d7b672bf 100644 --- a/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap +++ b/tests/chatbot-api/__snapshots__/chatbot-api-construct.test.ts.snap @@ -1311,8 +1311,8 @@ input CreateWorkspaceAuroraInput { kind: String! embeddingsModelProvider: String! embeddingsModelName: String! - crossEncoderModelProvider: String! - crossEncoderModelName: String! + crossEncoderModelProvider: String + crossEncoderModelName: String languages: [String!]! metric: String! index: Boolean! @@ -1341,8 +1341,8 @@ input CreateWorkspaceOpenSearchInput { kind: String! embeddingsModelProvider: String! embeddingsModelName: String! - crossEncoderModelProvider: String! - crossEncoderModelName: String! + crossEncoderModelProvider: String + crossEncoderModelName: String languages: [String!]! hybridSearch: Boolean! 
chunkingStrategy: String! @@ -14854,7 +14854,7 @@ schema { "SharedConfig358B4A20": { "Properties": { "Type": "String", - "Value": "{"prefix":"prefix","privateWebsite":true,"certificate":"","cfGeoRestrictEnable":true,"cfGeoRestrictList":[],"bedrock":{"enabled":true,"region":"us-east-1"},"llms":{"sagemaker":["FalconLite [ml.g5.12xlarge]","Idefics_80b (Multimodal) [ml.g5.48xlarge]"]},"rag":{"enabled":true,"engines":{"aurora":{"enabled":true},"opensearch":{"enabled":true},"kendra":{"enabled":true,"createIndex":true,"enterprise":true},"knowledgeBase":{"enabled":false}},"embeddingsModels":[{"provider":"sagemaker","name":"intfloat/multilingual-e5-large","dimensions":1024},{"provider":"sagemaker","name":"sentence-transformers/all-MiniLM-L6-v2","dimensions":384},{"provider":"bedrock","name":"amazon.titan-embed-text-v1","dimensions":1536},{"provider":"bedrock","name":"amazon.titan-embed-image-v1","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-english-v3","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-multilingual-v3","dimensions":1024,"default":true},{"provider":"openai","name":"text-embedding-ada-002","dimensions":1536}],"crossEncoderModels":[{"provider":"sagemaker","name":"cross-encoder/ms-marco-MiniLM-L-12-v2","default":true}]}}", + "Value": "{"prefix":"prefix","privateWebsite":true,"certificate":"","cfGeoRestrictEnable":true,"cfGeoRestrictList":[],"bedrock":{"enabled":true,"region":"us-east-1"},"llms":{"sagemaker":["FalconLite [ml.g5.12xlarge]","Idefics_80b (Multimodal) 
[ml.g5.48xlarge]"]},"rag":{"crossEncodingEnabled":true,"enabled":true,"engines":{"aurora":{"enabled":true},"opensearch":{"enabled":true},"kendra":{"enabled":true,"createIndex":true,"enterprise":true},"knowledgeBase":{"enabled":false}},"embeddingsModels":[{"provider":"sagemaker","name":"intfloat/multilingual-e5-large","dimensions":1024},{"provider":"sagemaker","name":"sentence-transformers/all-MiniLM-L6-v2","dimensions":384},{"provider":"bedrock","name":"amazon.titan-embed-text-v1","dimensions":1536},{"provider":"bedrock","name":"amazon.titan-embed-image-v1","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-english-v3","dimensions":1024},{"provider":"bedrock","name":"cohere.embed-multilingual-v3","dimensions":1024,"default":true},{"provider":"openai","name":"text-embedding-ada-002","dimensions":1536}],"crossEncoderModels":[{"provider":"sagemaker","name":"cross-encoder/ms-marco-MiniLM-L-12-v2","default":true}]}}", }, "Type": "AWS::SSM::Parameter", }, diff --git a/tests/utils/config-util.ts b/tests/utils/config-util.ts index dabe109b..5a1cba7e 100644 --- a/tests/utils/config-util.ts +++ b/tests/utils/config-util.ts @@ -23,6 +23,7 @@ export function getTestConfig(): SystemConfig { ], }, rag: { + crossEncodingEnabled: true, enabled: true, engines: { aurora: { From 23d4da3a87cfe64fdc98d1b76fdbed221fb02538 Mon Sep 17 00:00:00 2001 From: Charles Marion Date: Tue, 22 Oct 2024 11:37:38 -0500 Subject: [PATCH 3/3] bug: Fix cli script when RAG is not enabled (#594) --- cli/magic-config.ts | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/cli/magic-config.ts b/cli/magic-config.ts index f414269f..5e2e42ca 100644 --- a/cli/magic-config.ts +++ b/cli/magic-config.ts @@ -112,16 +112,19 @@ const embeddingModels: ModelConfig[] = [ provider: "bedrock", name: "amazon.titan-embed-image-v1", dimensions: 1024, + default: false, }, { provider: "bedrock", name: "cohere.embed-english-v3", dimensions: 1024, + default: 
false, }, { provider: "bedrock", name: "cohere.embed-multilingual-v3", dimensions: 1024, + default: false, }, { provider: "openai", @@ -200,9 +203,14 @@ const embeddingModels: ModelConfig[] = [ options.ragsToEnable.pop("kendra"); } options.embeddings = config.rag.embeddingsModels.map((m) => m.name); - options.defaultEmbedding = (config.rag.embeddingsModels ?? []).filter( + const defaultEmbeddings = (config.rag.embeddingsModels ?? []).filter( (m) => m.default - )[0].name; + ); + + if (defaultEmbeddings.length > 0) { + options.defaultEmbedding = defaultEmbeddings[0].name; + } + options.kendraExternal = config.rag.engines.kendra.external; options.kbExternal = config.rag.engines.knowledgeBase?.external ?? []; options.kendraEnterprise = config.rag.engines.kendra.enterprise; @@ -373,7 +381,7 @@ async function processCreateOptions(options: any): Promise { { type: "confirm", name: "enableSagemakerModels", - message: "Do you want to use any Sagemaker Models", + message: "Do you want to use any text generation Sagemaker Models", initial: options.enableSagemakerModels || false, }, { @@ -665,10 +673,14 @@ async function processCreateOptions(options: any): Promise { options.kendraExternal.length > 0) || false, skip(): boolean { - return !(this as any).state.answers.enableRag; + return ( + !(this as any).state.answers.enableRag || + !(this as any).state.answers.ragsToEnable.includes("kendra") + ); }, }, ]; + const answers: any = await enquirer.prompt(questions); const kendraExternal: any[] = []; let newKendra = answers.enableRag && answers.kendra; @@ -828,13 +840,14 @@ async function processCreateOptions(options: any): Promise { validate(value: string) { const embeding = embeddingModels.find((i) => i.name === value); if ( + answers.enableRag && embeding && - (this as any).state.answers.deployDefaultSagemakerModels === false && + answers?.deployDefaultSagemakerModels === false && embeding?.provider === "sagemaker" ) { return "SageMaker default models are not enabled. 
Please select another model."; } - if ((this as any).state.answers.enableRag) { + if (answers.enableRag) { return value ? true : "Select a default embedding model"; } return true; @@ -1156,6 +1169,7 @@ async function processCreateOptions(options: any): Promise { initial: false, }, ]); + let advancedSettings: any = {}; if (doAdvancedConfirm.doAdvancedSettings) { advancedSettings = await enquirer.prompt(advancedSettingsPrompts); @@ -1300,22 +1314,14 @@ async function processCreateOptions(options: any): Promise { config.rag.embeddingsModels = embeddingModels.filter( (model) => model.provider !== "sagemaker" ); - for (const model of config.rag.embeddingsModels) { - model.default = model.name === models.defaultEmbedding; - } } else { config.rag.embeddingsModels = []; } - // If we have not enabled rag the default embedding is set to the first model - if (!answers.enableRag) { - (config.rag.embeddingsModels[0] as any).default = true; - } else { - config.rag.embeddingsModels.forEach((m: any) => { - if (m.name === models.defaultEmbedding) { - m.default = true; - } - }); + if (config.rag.embeddingsModels.length > 0 && models.defaultEmbedding) { + for (const model of config.rag.embeddingsModels) { + model.default = model.name === models.defaultEmbedding; + } } config.rag.engines.kendra.createIndex = @@ -1324,9 +1330,10 @@ async function processCreateOptions(options: any): Promise { config.rag.engines.kendra.createIndex || kendraExternal.length > 0; config.rag.engines.kendra.external = [...kendraExternal]; config.rag.engines.kendra.enterprise = answers.kendraEnterprise; + + config.rag.engines.knowledgeBase.external = [...kbExternal]; config.rag.engines.knowledgeBase.enabled = config.rag.engines.knowledgeBase.external.length > 0; - config.rag.engines.knowledgeBase.external = [...kbExternal]; console.log("\n✨ This is the chosen configuration:\n"); console.log(JSON.stringify(config, undefined, 2));