Use Search Pipeline processors, Remote Inference and HttpConnector to enable Retrieval Augmented Generation (RAG) #1195

Merged
1 change: 1 addition & 0 deletions plugin/build.gradle
@@ -44,6 +44,7 @@ opensearchplugin {
dependencies {
implementation project(':opensearch-ml-common')
implementation project(':opensearch-ml-algorithms')
implementation project(':opensearch-ml-search-processors')
implementation project(':opensearch-ml-memory')

implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
MachineLearningPlugin.java
@@ -184,10 +184,19 @@
import org.opensearch.monitor.os.OsService;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
@@ -197,7 +206,7 @@

import lombok.SneakyThrows;

public class MachineLearningPlugin extends Plugin implements ActionPlugin {
public class MachineLearningPlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin {
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons.";
public static final String GENERAL_THREAD_POOL = "opensearch_ml_general";
public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute";
@@ -648,4 +657,26 @@ public List<Setting<?>> getSettings() {
);
return settings;
}

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List
.of(
new SearchPlugin.SearchExtSpec<>(
GenerativeQAParamExtBuilder.PARAMETER_NAME,
input -> new GenerativeQAParamExtBuilder(input),
parser -> GenerativeQAParamExtBuilder.parse(parser)
)
);
}

@Override
public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
return Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory());
}

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
return Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client));
}
Comment on lines +662 to +682 — austintlee (Collaborator, PR author):
@msfroh We want to release these behind a feature flag for v2.10.0. I am thinking of returning empty Lists/Maps if the flag is false. What do you think? Is there a cleaner way to handle this?

Reply:

I think that's a very clean way of handling it.

In particular, feature flags are loaded on startup, and the set of available processors is also determined at startup. Saying that the processors are available if and only if the feature flag is set sounds correct to me.
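
A minimal sketch of the gating discussed in this thread — return empty collections when the flag is off. The field name `ragSearchPipelineEnabled` and the way it is populated from a setting are assumptions for illustration, not code from this PR:

```java
// Sketch only (assumed field/setting names): when the flag is off at startup,
// the plugin registers no RAG search pipeline processors at all.
private volatile boolean ragSearchPipelineEnabled; // read from a cluster setting at startup (assumed)

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
    if (!ragSearchPipelineEnabled) {
        return Map.of(); // feature disabled: expose no RAG response processor
    }
    return Map.of(
        GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE,
        new GenerativeQAResponseProcessor.Factory(this.client)
    );
}
```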

}
95 changes: 95 additions & 0 deletions search-processors/README.md
@@ -0,0 +1,95 @@
# conversational-search-processors

OpenSearch search processors providing conversational search capabilities

# Plugin for Conversations Using Search Processors in OpenSearch
This repository contains a WIP plugin for handling conversations in OpenSearch ([per this RFC](https://github.com/opensearch-project/ml-commons/issues/1150)).

Conversational Retrieval Augmented Generation (RAG) is implemented via search processors that combine user questions and OpenSearch query results into input for an LLM (e.g., OpenAI) and return the generated answers.

## Creating a search pipeline with the GenerativeQAResponseProcessor

```
PUT /_search/pipeline/<search pipeline name>
{
"response_processors": [
{
"retrieval_augmented_generation": {
"tag": <tag>,
"description": <description>,
"model_id": "<model_id>",
"context_field_list": [<field>] (e.g. ["text"])
}
}
]
}
```

The `model_id` parameter must refer to a model of type REMOTE that has an HttpConnector instance associated with it.
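
For example, a concrete pipeline definition following the template above might look like this (the pipeline name, tag, description, and model ID are illustrative placeholders, not values from this PR):

```
PUT /_search/pipeline/rag_pipeline
{
  "response_processors": [
    {
      "retrieval_augmented_generation": {
        "tag": "rag_demo",
        "description": "Generative QA over the text field",
        "model_id": "<REMOTE model id with an HttpConnector>",
        "context_field_list": ["text"]
      }
    }
  ]
}
```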

## Making a search request against an index using the above processor
```
GET /<index>/_search?search_pipeline=<search pipeline name>
{
"_source": ["title", "text"],
"query" : {
"neural": {
"text_vector": {
"query_text": <query string>,
"k": <integer> (e.g. 10),
"model_id": <model_id>
}
}
},
"ext": {
"generative_qa_parameters": {
"llm_model": <LLM model> (e.g. "gpt-3.5-turbo"),
"llm_question": <question string>
}
}
}
```

## Retrieval Augmented Generation response
```
{
"took": 3,
"timed_out": false,
"_shards": {
"total": 3,
"successful": 3,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 110,
"relation": "eq"
},
"max_score": 0.55129033,
"hits": [
{
"_index": "...",
"_id": "...",
"_score": 0.55129033,
"_source": {
"text": "...",
"title": "..."
}
},
{
...
}
...
{
...
}
]
}, // end of hits
"ext": {
"retrieval_augmented_generation": {
"answer": "..."
}
}
}
```
The RAG answer is returned in the "ext" section of the SearchResponse, following the "hits" array.
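
As an illustration, a client could pull the generated answer out of the raw response body along these lines (a hypothetical helper, not part of this PR; it uses the org.json library that the search-processors module already declares as a dependency):

```java
import org.json.JSONObject;

// Hypothetical client-side helper: extracts the RAG answer from a raw search response body.
public class RagAnswerExtractor {
    public static String extractAnswer(String searchResponseBody) {
        JSONObject root = new JSONObject(searchResponseBody);
        // The response processor writes its output under ext.retrieval_augmented_generation.answer.
        return root.getJSONObject("ext")
                .getJSONObject("retrieval_augmented_generation")
                .getString("answer");
    }
}
```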
74 changes: 74 additions & 0 deletions search-processors/build.gradle
@@ -0,0 +1,74 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
}

repositories {
mavenCentral()
mavenLocal()
}

dependencies {

compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation 'org.apache.commons:commons-lang3:3.12.0'
implementation project(':opensearch-ml-client')
implementation project(':opensearch-ml-common')
implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}"
// https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5
implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1'
implementation("com.google.guava:guava:32.0.1-jre")
implementation group: 'org.json', name: 'json', version: '20230227'
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
testImplementation "org.opensearch.test:framework:${opensearch_version}"
}

test {
include '**/*Tests.class'
systemProperty 'tests.security.manager', 'false'
}

jacocoTestReport {
dependsOn /*integTest,*/ test
reports {
xml.required = true
html.required = true
}
}

jacocoTestCoverageVerification {
violationRules {
rule {
limit {
counter = 'LINE'
minimum = 0.65 //TODO: increase coverage to 0.90
}
limit {
counter = 'BRANCH'
minimum = 0.55 //TODO: increase coverage to 0.85
}
}
}
dependsOn jacocoTestReport
}

check.dependsOn jacocoTestCoverageVerification
//jacocoTestCoverageVerification.dependsOn jacocoTestReport
GenerativeQAProcessorConstants.java
@@ -0,0 +1,36 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.searchpipelines.questionanswering.generative;

public class GenerativeQAProcessorConstants {

// Identifier for the generative QA request processor
public static final String REQUEST_PROCESSOR_TYPE = "question_rewrite";

// Identifier for the generative QA response processor
public static final String RESPONSE_PROCESSOR_TYPE = "retrieval_augmented_generation";

// The model_id of the model registered and deployed in OpenSearch.
public static final String CONFIG_NAME_MODEL_ID = "model_id";

// The name of the LLM model to use, e.g. "gpt-3.5-turbo" for OpenAI.
public static final String CONFIG_NAME_LLM_MODEL = "llm_model";

// The list of search result fields that contain the context to be sent to the LLM.
public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list";
}
GenerativeQARequestProcessor.java
@@ -0,0 +1,69 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

import java.util.Map;

/**
* Defines the request processor for generative QA search pipelines.
*/
public class GenerativeQARequestProcessor extends AbstractProcessor implements SearchRequestProcessor {

private String modelId;

protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
}

@Override
public SearchRequest processRequest(SearchRequest request) throws Exception {

// TODO Use chat history to rephrase the question with full conversation context.

return request;
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE;
}

public static final class Factory implements Processor.Factory<SearchRequestProcessor> {

@Override
public SearchRequestProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
return new GenerativeQARequestProcessor(tag, description, ignoreFailure,
ConfigurationUtils.readStringProperty(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, tag, config, GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID)
);
}
}
}
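
For reference, registering this request processor in a pipeline would look roughly like the following — a sketch derived from the REQUEST_PROCESSOR_TYPE and CONFIG_NAME_MODEL_ID constants above; note that the processor currently passes the request through unchanged:

```
PUT /_search/pipeline/<search pipeline name>
{
  "request_processors": [
    {
      "question_rewrite": {
        "tag": <tag>,
        "description": <description>,
        "model_id": "<model_id>"
      }
    }
  ]
}
```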