Skip to content

Commit

Permalink
Merge branch 'main' into issue-571
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-marion authored Oct 24, 2024
2 parents 48a4d41 + 23d4da3 commit 2f9ef99
Show file tree
Hide file tree
Showing 36 changed files with 614 additions and 298 deletions.
8 changes: 7 additions & 1 deletion bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ import { existsSync, readFileSync } from "fs";

export function getConfig(): SystemConfig {
if (existsSync("./bin/config.json")) {
return JSON.parse(readFileSync("./bin/config.json").toString("utf8"));
return JSON.parse(
readFileSync("./bin/config.json").toString("utf8")
) as SystemConfig;
}
// Default config
return {
Expand Down Expand Up @@ -48,11 +50,13 @@ export function getConfig(): SystemConfig {
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
Expand Down Expand Up @@ -80,8 +84,10 @@ export function getConfig(): SystemConfig {
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
],
crossEncodingEnabled: false,
crossEncoderModels: [
{
provider: "sagemaker",
Expand Down
92 changes: 68 additions & 24 deletions cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
SupportedSageMakerModels,
SystemConfig,
SupportedBedrockRegion,
ModelConfig,
} from "../lib/shared/types";
import { LIB_VERSION } from "./version.js";
import * as fs from "fs";
Expand All @@ -34,7 +35,6 @@ function getTimeZonesWithCurrentTime(): { message: string; name: string }[] {
function getCountryCodesAndNames(): { message: string; name: string }[] {
// Use country-list to get an array of countries with their codes and names
const countries = getData();

// Map the country data to match the desired output structure
const countryInfo = countries.map(({ code, name }) => {
return { message: `${name} (${code})`, name: code };
Expand Down Expand Up @@ -88,42 +88,49 @@ const secretManagerArnRegExp = RegExp(
/arn:aws:secretsmanager:[\w-_]+:\d+:secret:[\w-_]+/
);

const embeddingModels = [
/**
 * Embedding models known to this CLI, expressed as `ModelConfig` entries
 * (provider, model name, dimensions, default flag).
 *
 * Every entry is declared here with `default: false`; this literal does not
 * pick a default model itself.
 * NOTE(review): the order of entries appears to matter to code that indexes
 * this array — confirm before reordering.
 */
const embeddingModels: ModelConfig[] = [
{
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
dimensions: 1024,
default: false,
},
{
provider: "sagemaker",
name: "sentence-transformers/all-MiniLM-L6-v2",
dimensions: 384,
default: false,
},
{
provider: "bedrock",
name: "amazon.titan-embed-text-v1",
dimensions: 1536,
default: false,
},
// Support for inputImage is not yet implemented for amazon.titan-embed-image-v1
{
provider: "bedrock",
name: "amazon.titan-embed-image-v1",
dimensions: 1024,
default: false,
},
{
provider: "bedrock",
name: "cohere.embed-english-v3",
dimensions: 1024,
default: false,
},
{
provider: "bedrock",
name: "cohere.embed-multilingual-v3",
dimensions: 1024,
default: false,
},
{
provider: "openai",
name: "text-embedding-ada-002",
dimensions: 1536,
default: false,
},
];

Expand Down Expand Up @@ -179,6 +186,8 @@ const embeddingModels = [
options.startScheduleEndDate =
config.llms?.sagemakerSchedule?.startScheduleEndDate;
options.enableRag = config.rag.enabled;
options.deployDefaultSagemakerModels =
config.rag.deployDefaultSagemakerModels;
options.ragsToEnable = Object.keys(config.rag.engines ?? {}).filter(
(v: string) =>
(
Expand All @@ -194,9 +203,14 @@ const embeddingModels = [
options.ragsToEnable.pop("kendra");
}
options.embeddings = config.rag.embeddingsModels.map((m) => m.name);
options.defaultEmbedding = (config.rag.embeddingsModels ?? []).filter(
const defaultEmbeddings = (config.rag.embeddingsModels ?? []).filter(
(m) => m.default
)[0].name;
);

if (defaultEmbeddings.length > 0) {
options.defaultEmbedding = defaultEmbeddings[0].name;
}

options.kendraExternal = config.rag.engines.kendra.external;
options.kbExternal = config.rag.engines.knowledgeBase?.external ?? [];
options.kendraEnterprise = config.rag.engines.kendra.enterprise;
Expand Down Expand Up @@ -367,7 +381,7 @@ async function processCreateOptions(options: any): Promise<void> {
{
type: "confirm",
name: "enableSagemakerModels",
message: "Do you want to use any Sagemaker Models",
message: "Do you want to use any text generation Sagemaker Models",
initial: options.enableSagemakerModels || false,
},
{
Expand Down Expand Up @@ -608,6 +622,16 @@ async function processCreateOptions(options: any): Promise<void> {
message: "Do you want to enable RAG",
initial: options.enableRag || false,
},
{
type: "confirm",
name: "deployDefaultSagemakerModels",
message:
"Do you want to deploy the default embedding and cross-encoder models via SageMaker?",
initial: options.deployDefaultSagemakerModels || false,
skip(): boolean {
return !(this as any).state.answers.enableRag;
},
},
{
type: "multiselect",
name: "ragsToEnable",
Expand Down Expand Up @@ -649,10 +673,14 @@ async function processCreateOptions(options: any): Promise<void> {
options.kendraExternal.length > 0) ||
false,
skip(): boolean {
return !(this as any).state.answers.enableRag;
return (
!(this as any).state.answers.enableRag ||
!(this as any).state.answers.ragsToEnable.includes("kendra")
);
},
},
];

const answers: any = await enquirer.prompt(questions);
const kendraExternal: any[] = [];
let newKendra = answers.enableRag && answers.kendra;
Expand Down Expand Up @@ -810,10 +838,18 @@ async function processCreateOptions(options: any): Promise<void> {
choices: embeddingModels.map((m) => ({ name: m.name, value: m })),
initial: options.defaultEmbedding,
validate(value: string) {
if ((this as any).state.answers.enableRag) {
const embeding = embeddingModels.find((i) => i.name === value);
if (
answers.enableRag &&
embeding &&
answers?.deployDefaultSagemakerModels === false &&
embeding?.provider === "sagemaker"
) {
return "SageMaker default models are not enabled. Please select another model.";
}
if (answers.enableRag) {
return value ? true : "Select a default embedding model";
}

return true;
},
skip() {
Expand Down Expand Up @@ -1133,6 +1169,7 @@ async function processCreateOptions(options: any): Promise<void> {
initial: false,
},
]);

let advancedSettings: any = {};
if (doAdvancedConfirm.doAdvancedSettings) {
advancedSettings = await enquirer.prompt(advancedSettingsPrompts);
Expand Down Expand Up @@ -1219,6 +1256,7 @@ async function processCreateOptions(options: any): Promise<void> {
}
: undefined,
llms: {
enableSagemakerModels: answers.enableSagemakerModels,
rateLimitPerAIP: advancedSettings?.llmRateLimitPerIP
? Number(advancedSettings?.llmRateLimitPerIP)
: undefined,
Expand All @@ -1241,6 +1279,7 @@ async function processCreateOptions(options: any): Promise<void> {
},
rag: {
enabled: answers.enableRag,
deployDefaultSagemakerModels: answers.deployDefaultSagemakerModels,
engines: {
aurora: {
enabled: answers.ragsToEnable.includes("aurora"),
Expand All @@ -1259,37 +1298,42 @@ async function processCreateOptions(options: any): Promise<void> {
external: [{}],
},
},
embeddingsModels: [{}],
crossEncoderModels: [{}],
embeddingsModels: [] as ModelConfig[],
crossEncoderModels: [] as ModelConfig[],
},
};

// If we have not enabled rag the default embedding is set to the first model
if (!answers.enableRag) {
models.defaultEmbedding = embeddingModels[0].name;
if (config.rag.enabled && config.rag.deployDefaultSagemakerModels) {
config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
} else if (config.rag.enabled) {
config.rag.embeddingsModels = embeddingModels.filter(
(model) => model.provider !== "sagemaker"
);
} else {
config.rag.embeddingsModels = [];
}

config.rag.crossEncoderModels[0] = {
provider: "sagemaker",
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.embeddingsModels = embeddingModels;
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
m.default = true;
if (config.rag.embeddingsModels.length > 0 && models.defaultEmbedding) {
for (const model of config.rag.embeddingsModels) {
model.default = model.name === models.defaultEmbedding;
}
});
}

config.rag.engines.kendra.createIndex =
answers.ragsToEnable.includes("kendra");
config.rag.engines.kendra.enabled =
config.rag.engines.kendra.createIndex || kendraExternal.length > 0;
config.rag.engines.kendra.external = [...kendraExternal];
config.rag.engines.kendra.enterprise = answers.kendraEnterprise;

config.rag.engines.knowledgeBase.external = [...kbExternal];
config.rag.engines.knowledgeBase.enabled =
config.rag.engines.knowledgeBase.external.length > 0;
config.rag.engines.knowledgeBase.external = [...kbExternal];

console.log("\n✨ This is the chosen configuration:\n");
console.log(JSON.stringify(config, undefined, 2));
Expand Down
75 changes: 56 additions & 19 deletions integtests/chatbot-api/aurora_workspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def run_before_and_after_tests(client: AppSyncClient):
for workspace in client.list_workspaces():
if (
workspace.get("name") == "INTEG_TEST_AURORA"
and workspace.get("status") == "ready"
):
or workspace.get("name") == "INTEG_TEST_AURORA_WITHOUT_RERANK"
) and workspace.get("status") == "ready":
client.delete_workspace(workspace.get("id"))


Expand All @@ -22,23 +22,25 @@ def test_create(client: AppSyncClient, default_embed_model):
if engine.get("enabled") == False:
pytest.skip_flag = True
pytest.skip("Aurora is not enabled.")
pytest.workspace = client.create_aurora_workspace(
input={
"kind": "auro2",
"name": "INTEG_TEST_AURORA",
"embeddingsModelProvider": "bedrock",
"embeddingsModelName": default_embed_model,
"crossEncoderModelName": "cross-encoder/ms-marco-MiniLM-L-12-v2",
"crossEncoderModelProvider": "sagemaker",
"languages": ["english"],
"index": True,
"hybridSearch": True,
"metric": "inner",
"chunkingStrategy": "recursive",
"chunkSize": 1000,
"chunkOverlap": 200,
}
)
input = {
"kind": "auro2",
"name": "INTEG_TEST_AURORA_WITHOUT_RERANK",
"embeddingsModelProvider": "bedrock",
"embeddingsModelName": default_embed_model,
"languages": ["english"],
"index": True,
"hybridSearch": True,
"metric": "inner",
"chunkingStrategy": "recursive",
"chunkSize": 1000,
"chunkOverlap": 200,
}
input_with_rerank = input.copy()
input_with_rerank["name"] = "INTEG_TEST_AURORA"
input_with_rerank["crossEncoderModelName"] = "cross-encoder/ms-marco-MiniLM-L-12-v2"
input_with_rerank["crossEncoderModelProvider"] = "sagemaker"
pytest.workspace = client.create_aurora_workspace(input=input_with_rerank)
pytest.workspace_no_re_rank = client.create_aurora_workspace(input=input)

ready = False
retries = 0
Expand All @@ -56,6 +58,7 @@ def test_create(client: AppSyncClient, default_embed_model):
def test_add_rss(client: AppSyncClient):
if pytest.skip_flag == True:
pytest.skip("Aurora is not enabled.")

pytest.document = client.add_rss_feed(
input={
"workspaceId": pytest.workspace.get("id"),
Expand All @@ -67,6 +70,17 @@ def test_add_rss(client: AppSyncClient):
"limit": 2,
}
)
client.add_rss_feed(
input={
"workspaceId": pytest.workspace_no_re_rank.get("id"),
"title": "INTEG_TEST_AURORA_TITLE",
"address": "https://github.com/aws-samples/aws-genai-llm-chatbot/"
+ "releases.atom",
"contentTypes": ["text/html"],
"followLinks": True,
"limit": 2,
}
)

ready = False
retries = 0
Expand Down Expand Up @@ -137,6 +151,29 @@ def test_search_document(client: AppSyncClient):
assert ready == True


# Integration test: run a semantic search against the workspace stored in
# pytest.workspace_no_re_rank and check the shape of the response.
def test_search_document_no_reank(client: AppSyncClient):
# Skip when pytest.skip_flag has been set (reported as "Aurora is not enabled.").
if pytest.skip_flag == True:
pytest.skip("Aurora is not enabled.")
ready = False
retries = 0
# Wait for the page to be crawled. This starts on a cron every 5 min.
# Poll at most 50 times, sleeping 15 seconds before each attempt.
while not ready and retries < 50:
time.sleep(15)
retries += 1
result = client.semantic_search(
input={
"workspaceId": pytest.workspace_no_re_rank.get("id"),
"query": "Release github",
}
)
# More than one returned item is treated as "documents are indexed".
if len(result.get("items")) > 1:
ready = True
assert result.get("engine") == "aurora"
# Re-ranking score is not set but the results are ordered by Aurora.
assert result.get("items")[0].get("score") is None
# Fails the test if the polling loop ended without ever seeing results.
assert ready == True


def test_query_llm(client, default_model, default_provider):
if pytest.skip_flag == True:
pytest.skip("Aurora is not enabled.")
Expand Down
Loading

0 comments on commit 2f9ef99

Please sign in to comment.