Skip to content

Commit

Permalink
Implement hybrid search
Browse files Browse the repository at this point in the history
  • Loading branch information
kdid committed Mar 11, 2024
1 parent 8ab431b commit 355bb28
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 111 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ The `env.json` file contains environment variable values for the lambda function
Some of the values can be found as follows:

- `API_TOKEN_SECRET` - already defined; value has to exist but doesn't matter in dev mode
- `ELASTICSEARCH_ENDPOINT` - run the following command:
- `OPENSEARCH_ENDPOINT` - run the following command:
```
aws secretsmanager get-secret-value \
--secret-id dev-environment/config/meadow --query SecretString \
Expand Down
2 changes: 1 addition & 1 deletion chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
K_VALUE = 5
MAX_K = 100
TEMPERATURE = 0.2
TEXT_KEY = "title"
TEXT_KEY = "id"
VERSION = "2023-07-01-preview"

@dataclass
Expand Down
134 changes: 79 additions & 55 deletions chat/src/handlers/opensearch_neural_search.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,87 @@
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from opensearchpy import OpenSearch
from typing import Any, List

class OpensearchNeuralSearch(VectorStore):
"""Read-only OpenSearch vectorstore with neural search."""

def __init__(
self,
endpoint: str,
index: str,
model_id: str,
vector_field: str = "embedding",
search_pipeline: str = None,
**kwargs: Any
):
self.client = OpenSearch(hosts=[{"host": endpoint, "port": "443", "use_ssl": True}], **kwargs)
self.index = index
self.model_id = model_id
self.vector_field = vector_field
self.search_pipeline = search_pipeline

# Allow for hybrid searching
# Allow for different types of searches
# Allow for _source override

def similarity_search(
self,
query: str,
k: int = 10,
subquery: Any = None,
**kwargs: Any
) -> List[Document]:
"""Return docs most similar to query."""
dsl = {
'size': k,
'query': {
'hybrid': {
'queries': [
{
'neural': {
self.vector_field: {
'query_text': query,
'model_id': self.model_id,
'k': k
from typing import Any, List, Optional, Tuple


class OpenSearchNeuralSearch(VectorStore):
    """Read-only OpenSearch vector store backed by server-side neural search.

    Query embedding happens inside OpenSearch (via ``model_id`` in the
    ``neural`` clause), so no client-side embedding function is required.
    Write operations (``add_texts`` / ``from_texts``) are intentional no-ops.
    """

    def __init__(
        self,
        endpoint: str,
        index: str,
        model_id: str,
        vector_field: str = "embedding",
        search_pipeline: Optional[str] = None,
        text_field: str = "id",
        **kwargs: Any,
    ):
        """Create a read-only search client.

        Args:
            endpoint: OpenSearch hostname (no scheme); HTTPS on port 443 is assumed.
            index: Name of the index to search.
            model_id: ID of the deployed OpenSearch ML model used by the
                ``neural`` query clause to embed the query text.
            vector_field: Name of the kNN vector field in the index.
            search_pipeline: Optional search pipeline name.
                NOTE(review): stored but never sent with the search request —
                OpenSearch hybrid queries normally need a pipeline with a
                normalization processor; confirm it is configured as the index
                default, or pass it via request params.
            text_field: ``_source`` field used as each Document's page_content.
            **kwargs: Extra arguments forwarded to the OpenSearch client
                (e.g. ``http_auth``, ``connection_class``).
        """
        self.client = OpenSearch(
            hosts=[{"host": endpoint, "port": "443", "use_ssl": True}], **kwargs
        )
        self.index = index
        self.model_id = model_id
        self.vector_field = vector_field
        self.search_pipeline = search_pipeline
        self.text_field = text_field

    def similarity_search(
        self, query: str, k: int = 10, subquery: Any = None, **kwargs: Any
    ) -> List[Document]:
        """Return the ``k`` documents most similar to ``query`` (scores dropped)."""
        docs_with_scores = self.similarity_search_with_score(
            query, k, subquery, **kwargs
        )
        return [doc for doc, _score in docs_with_scores]

    def similarity_search_with_score(
        self, query: str, k: int = 10, subquery: Any = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return ``(Document, score)`` pairs for a hybrid neural query.

        Args:
            query: Natural-language query text; embedded server-side by
                ``model_id``.
            k: Number of hits to request (also used as the kNN ``k``).
            subquery: Optional extra OpenSearch query clause appended to the
                hybrid query's ``queries`` list (this is the "hybrid" part).
            **kwargs: Merged into the top level of the request body DSL
                (e.g. ``_source``, ``min_score``); later keys overwrite
                earlier ones.
        """
        dsl = {
            "size": k,
            "query": {
                "hybrid": {
                    "queries": [
                        {
                            "neural": {
                                self.vector_field: {
                                    "query_text": query,
                                    "model_id": self.model_id,
                                    "k": k,
                                }
                            }
                        }
                    ]
                }
            },
        }

        if subquery:
            dsl["query"]["hybrid"]["queries"].append(subquery)

        # Allow callers to override or extend top-level request-body keys.
        for key, value in kwargs.items():
            dsl[key] = value

        response = self.client.search(index=self.index, body=dsl)

        return [
            (
                Document(
                    page_content=hit["_source"][self.text_field],
                    metadata=hit["_source"],
                ),
                hit["_score"],
            )
            for hit in response["hits"]["hits"]
        ]

    def add_texts(
        self, texts: List[str], metadatas: List[dict], **kwargs: Any
    ) -> None:
        """No-op: this vector store is read-only."""
        pass

    @classmethod
    def from_texts(
        cls, texts: List[str], metadatas: List[dict], **kwargs: Any
    ) -> None:
        """No-op: a read-only store cannot be built from raw texts."""
        pass
2 changes: 1 addition & 1 deletion chat/src/helpers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def extract_prompt_value(v):
def prepare_response(config):
try:
docs = config.opensearch.similarity_search(
config.question, k=config.k, vector_field="embedding", text_field="id"
query=config.question, k=config.k
)
original_question = get_and_send_original_question(config, docs)
response = config.chain({"question": config.question, "input_documents": docs})
Expand Down
22 changes: 8 additions & 14 deletions chat/src/setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from content_handler import ContentHandler
from langchain_community.chat_models import AzureChatOpenAI
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch
from handlers.opensearch_neural_search import OpenSearchNeuralSearch
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
import os
Expand All @@ -22,7 +22,7 @@ def opensearch_client(region_name=os.getenv("AWS_REGION")):
print(region_name)
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())
endpoint = os.getenv("ELASTICSEARCH_ENDPOINT")
endpoint = os.getenv("OPENSEARCH_ENDPOINT")

return OpenSearch(
hosts=[{'host': endpoint, 'port': 443}],
Expand All @@ -35,20 +35,14 @@ def opensearch_vector_store(region_name=os.getenv("AWS_REGION")):
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())

sagemaker_client = session.client(service_name="sagemaker-runtime", region_name=session.region_name)
embeddings = SagemakerEndpointEmbeddings(
client=sagemaker_client,
region_name=session.region_name,
endpoint_name=os.getenv("EMBEDDING_ENDPOINT"),
content_handler=ContentHandler()
)

docsearch = OpenSearchVectorSearch(
index_name=prefix("dc-v2-work"),
embedding_function=embeddings,
opensearch_url="https://" + os.getenv("ELASTICSEARCH_ENDPOINT"),
docsearch = OpenSearchNeuralSearch(
index=prefix("dc-v2-work"),
model_id=os.getenv("OPENSEARCH_MODEL_ID"),
endpoint=os.getenv("OPENSEARCH_ENDPOINT"),
connection_class=RequestsHttpConnection,
http_auth=awsauth,
search_pipeline=prefix("dc-v2-work-pipeline"),
text_field= "id"
)
return docsearch

Expand Down
22 changes: 6 additions & 16 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,18 @@ Parameters:
AzureOpenaiApiKey:
Type: String
Description: Azure OpenAI API Key
AzureOpenaiEmbeddingDeploymentId:
Type: String
Description: Azure OpenAI Embedding Deployment ID
AzureOpenaiLlmDeploymentId:
Type: String
Description: Azure OpenAI LLM Deployment ID
AzureOpenaiResourceName:
Type: String
Description: Azure OpenAI Resource Name
ElasticsearchEndpoint:
OpenSearchEndpoint:
Type: String
Description: Elasticsearch URL
EmbeddingEndpoint:
Description: OpenSearch Endpoint
OpenSearchModelId:
Type: String
Description: Sagemaker Inference Endpoint
Description: OpenSearch Model ID
Resources:
ApiGwAccountConfig:
Type: "AWS::ApiGateway::Account"
Expand Down Expand Up @@ -199,11 +196,10 @@ Resources:
Variables:
API_TOKEN_SECRET: !Ref ApiTokenSecret
AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey
AZURE_OPENAI_EMBEDDING_DEPLOYMENT_ID: !Ref AzureOpenaiEmbeddingDeploymentId
AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
ELASTICSEARCH_ENDPOINT: !Ref ElasticsearchEndpoint
EMBEDDING_ENDPOINT: !Ref EmbeddingEndpoint
OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId
Policies:
- Statement:
- Effect: Allow
Expand All @@ -217,12 +213,6 @@ Resources:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
- Statement:
- Effect: Allow
Action:
- 'sagemaker:InvokeEndpoint'
- 'sagemaker:InvokeEndpointAsync'
Resource: !Sub 'arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:endpoint/${EmbeddingEndpoint}'
Metadata:
BuildMethod: nodejs18.x
Deployment:
Expand Down
2 changes: 1 addition & 1 deletion dev/env.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"Parameters": {
"API_TOKEN_SECRET": "DEVELOPMENT_SECRET",
"ELASTICSEARCH_ENDPOINT": "",
"OPENSEARCH_ENDPOINT": "",
"ENV_PREFIX": "",
"DC_URL": ""
}
Expand Down
10 changes: 5 additions & 5 deletions node/src/api/opensearch.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const { HttpRequest } = require("@aws-sdk/protocol-http");
const { awsFetch } = require("../aws/fetch");
const { elasticsearchEndpoint, prefix } = require("../environment");
const { openSearchEndpoint, prefix } = require("../environment");
const Honeybadger = require("../honeybadger-setup");

async function getCollection(id, opts) {
Expand Down Expand Up @@ -65,7 +65,7 @@ function isVisible(doc, { allowPrivate, allowUnpublished }) {
}

function initRequest(path) {
const endpoint = elasticsearchEndpoint();
const endpoint = openSearchEndpoint();

return new HttpRequest({
method: "GET",
Expand All @@ -80,7 +80,7 @@ function initRequest(path) {

async function search(targets, body, optionsQuery = {}) {
Honeybadger.addBreadcrumb("Searching", { metadata: { targets, body } });
const endpoint = elasticsearchEndpoint();
const endpoint = openSearchEndpoint();

const request = new HttpRequest({
method: "POST",
Expand All @@ -98,7 +98,7 @@ async function search(targets, body, optionsQuery = {}) {
}

async function scroll(scrollId) {
const endpoint = elasticsearchEndpoint();
const endpoint = openSearchEndpoint();

const request = new HttpRequest({
method: "POST",
Expand All @@ -114,7 +114,7 @@ async function scroll(scrollId) {
}

async function deleteScroll(scrollId) {
const endpoint = elasticsearchEndpoint();
const endpoint = openSearchEndpoint();

const request = new HttpRequest({
method: "DELETE",
Expand Down
6 changes: 3 additions & 3 deletions node/src/environment.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ function dcUrl() {
return process.env.DC_URL;
}

function elasticsearchEndpoint() {
return process.env.ELASTICSEARCH_ENDPOINT;
/**
 * Hostname of the OpenSearch cluster, read from the environment.
 * @returns {string|undefined} the value of OPENSEARCH_ENDPOINT
 */
function openSearchEndpoint() {
  return process.env.OPENSEARCH_ENDPOINT;
}

function prefix(value) {
Expand All @@ -61,7 +61,7 @@ module.exports = {
appInfo,
dcApiEndpoint,
dcUrl,
elasticsearchEndpoint,
openSearchEndpoint,
prefix,
region,
};
2 changes: 1 addition & 1 deletion node/test/test-helpers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function mockIndex() {
const mock = nock("https://index.test.library.northwestern.edu");

beforeEach(function () {
process.env.ELASTICSEARCH_ENDPOINT = "index.test.library.northwestern.edu";
process.env.OPENSEARCH_ENDPOINT = "index.test.library.northwestern.edu";
});

afterEach(function () {
Expand Down
4 changes: 2 additions & 2 deletions node/test/unit/aws/environment.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ describe("environment", function () {
helpers.saveEnvironment();

it("returns the index endpoint", function () {
process.env.ELASTICSEARCH_ENDPOINT = "index.test.library.northwestern.edu";
expect(environment.elasticsearchEndpoint()).to.eq(
process.env.OPENSEARCH_ENDPOINT = "index.test.library.northwestern.edu";
expect(environment.openSearchEndpoint()).to.eq(
"index.test.library.northwestern.edu"
);
});
Expand Down
Loading

0 comments on commit 355bb28

Please sign in to comment.