fix: some small changes and fixes for AI assistant #1333

Merged · 1 commit · Sep 20, 2023
14 changes: 5 additions & 9 deletions containers/bundled_querybook_config.yaml
@@ -13,25 +13,21 @@ ELASTICSEARCH_HOST: http://elasticsearch:9200
# AI_ASSISTANT_CONFIG:
# default:
# model_args:
# model_name: gpt-3.5-turbo-16k
# model_name: gpt-3.5-turbo
# temperature: 0
# context_length: 16384
# reserved_tokens: 2048
# reserved_tokens: 1024
# table_summary:
# model_args:
# model_name: gpt-3.5-turbo-16k
# model_name: gpt-3.5-turbo
# temperature: 0
# context_length: 16384
# sql_summary:
# model_args:
# model_name: gpt-3.5-turbo-16k
# model_name: gpt-3.5-turbo
# temperature: 0
# context_length: 16384
# table_select:
# model_args:
# model_name: gpt-3.5-turbo-16k
# model_name: gpt-3.5-turbo
# temperature: 0
# context_length: 16384

# Uncomment below to enable vector store to support embedding based table search.
# Please check langchain doc for the configs of each provider.
10 changes: 6 additions & 4 deletions docs_website/docs/integrations/add_ai_assistant.md
@@ -16,15 +16,17 @@ The AI Assistant plugin will allow users to do title generation, text to sql and

Please follow the steps below to enable the AI assistant plugin:

1. [Optional] Create your own AI assistant provider if needed. Please refer to `querybook/server/lib/ai_assistant/openai_assistant.py` as an example.
1. Add the `langchain` package dependency by adding `-r ai/langchain.txt` to `requirements/local.txt`.

2. Add your provider in `plugins/ai_assistant_plugin/__init__.py`
2. [Optional] Create your own AI assistant provider if needed. Please refer to `querybook/server/lib/ai_assistant/openai_assistant.py` as an example.

3. Add configs in the `querybook_config.yaml`. Please refer to `containers/bundled_querybook_config.yaml` as an example. Please also check the model's official doc for all available model args.
3. Add your provider in `plugins/ai_assistant_plugin/__init__.py` (see the sketch after these steps)

4. Add configs in the `querybook_config.yaml`. Please refer to `containers/bundled_querybook_config.yaml` as an example. Please also check the model's official doc for all available model args.

- Don't forget to set the proper environment variables for your provider, e.g. for OpenAI you'll need `OPENAI_API_KEY`.

4. Enable it in `querybook/config/querybook_public_config.yaml`
5. Enable it in `querybook/config/querybook_public_config.yaml`
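
For illustration, a minimal sketch of what step 3 could look like. The `CustomAssistant` class and the `ALL_PLUGIN_AI_ASSISTANTS` hook are assumptions made for this sketch rather than anything defined in this PR; `querybook/server/lib/ai_assistant/assistants/openai_assistant.py` remains the reference implementation.

```python
# plugins/ai_assistant_plugin/__init__.py -- illustrative sketch only
from lib.ai_assistant.base_ai_assistant import BaseAIAssistant


class CustomAssistant(BaseAIAssistant):
    """Skeleton provider; see openai_assistant.py for a complete example."""

    def name(self) -> str:
        return "custom"

    def _get_context_length_by_model(self, model_name: str) -> int:
        # Context window per supported model; the numbers are placeholders.
        return {"custom-8k": 8192}.get(model_name, 4096)

    def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None):
        # Return a langchain LLM / chat model configured for this command.
        raise NotImplementedError


# Hypothetical registration hook -- check the plugin loader for the exact
# name it expects before relying on this.
ALL_PLUGIN_AI_ASSISTANTS = [CustomAssistant]
```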

## Vector Store Plugin

2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "querybook",
"version": "3.28.0",
"version": "3.28.1",
"description": "A Big Data Webapp",
"private": true,
"scripts": {
4 changes: 0 additions & 4 deletions querybook/config/querybook_default_config.yaml
@@ -92,11 +92,7 @@ AI_ASSISTANT_CONFIG:
model_args:
model_name: ~
temperature: ~
context_length: ~
reserved_tokens: ~
table_select:
fetch_k: ~
top_n: ~

EMBEDDINGS_PROVIDER: ~
EMBEDDINGS_CONFIG: ~
4 changes: 3 additions & 1 deletion querybook/server/datasources/search.py
@@ -17,7 +17,6 @@
from lib.elasticsearch.suggest_table import construct_suggest_table_query
from lib.elasticsearch.suggest_user import construct_suggest_user_query
from lib.elasticsearch.search_utils import ES_CONFIG
from logic import vector_store as vs_logic

LOG = get_logger(__file__)

@@ -123,6 +122,9 @@ def vector_search_tables(
keywords,
filters=None,
):
# delayed import only if vector search is enabled
from logic import vector_store as vs_logic

verify_metastore_permission(metastore_id)
return vs_logic.search_tables(metastore_id, keywords, filters)

19 changes: 17 additions & 2 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
@@ -9,6 +9,15 @@
LOG = get_logger(__file__)


OPENAI_MODEL_CONTEXT_WINDOW_SIZE = {
"gpt-3.5-turbo": 4097,
"gpt-3.5-turbo-16k": 16385,
"gpt-4": 8192,
"gpt-4-32k": 32768,
}
DEFAULT_MODEL_NAME = "gpt-3.5-turbo"


class OpenAIAssistant(BaseAIAssistant):
"""To use it, please set the following environment variable:
OPENAI_API_KEY: OpenAI API key
@@ -18,10 +27,16 @@ class OpenAIAssistant(BaseAIAssistant):
def name(self) -> str:
return "openai"

def _get_context_length_by_model(self, model_name: str) -> int:
return (
OPENAI_MODEL_CONTEXT_WINDOW_SIZE.get(model_name)
or OPENAI_MODEL_CONTEXT_WINDOW_SIZE[DEFAULT_MODEL_NAME]
)

def _get_default_llm_config(self):
default_config = super()._get_default_llm_config()
if not default_config.get("model_name"):
default_config["model_name"] = "gpt-3.5-turbo"
default_config["model_name"] = DEFAULT_MODEL_NAME

return default_config
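
To put numbers on the lookup above: combined with the `reserved_tokens: 1024` now used in the bundled config, the usable prompt budget per model works out as in this standalone sketch (`usable_tokens` is an illustrative helper, not a function added by the PR):

```python
# Standalone sketch of the token budget implied by the mapping above.
OPENAI_MODEL_CONTEXT_WINDOW_SIZE = {
    "gpt-3.5-turbo": 4097,
    "gpt-3.5-turbo-16k": 16385,
    "gpt-4": 8192,
    "gpt-4-32k": 32768,
}
DEFAULT_MODEL_NAME = "gpt-3.5-turbo"


def usable_tokens(model_name: str, reserved_tokens: int = 1024) -> int:
    """Context window minus the tokens reserved for the model's completion."""
    window = (
        OPENAI_MODEL_CONTEXT_WINDOW_SIZE.get(model_name)
        or OPENAI_MODEL_CONTEXT_WINDOW_SIZE[DEFAULT_MODEL_NAME]
    )
    return window - reserved_tokens


print(usable_tokens("gpt-3.5-turbo"))      # 4097 - 1024 = 3073
print(usable_tokens("gpt-3.5-turbo-16k"))  # 16385 - 1024 = 15361
print(usable_tokens("not-a-known-model"))  # falls back to the default: 3073
```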

@@ -36,7 +51,7 @@ def _get_error_msg(self, error) -> str:

return super()._get_error_msg(error)

def _get_llm(self, ai_command: str, callback_handler=None):
def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None):
config = self._get_llm_config(ai_command)
if not callback_handler:
# non-streaming
82 changes: 64 additions & 18 deletions querybook/server/lib/ai_assistant/base_ai_assistant.py
@@ -4,11 +4,10 @@

from app.db import with_session
from const.ai_assistant import (
AICommandType,
DEFAUTL_TABLE_SELECT_LIMIT,
MAX_SAMPLE_QUERY_COUNT_FOR_TABLE_SUMMARY,
AICommandType,
)
from langchain.chains import LLMChain
from lib.logger import get_logger
from lib.query_analysis.lineage import process_query
from lib.vector_store import get_vector_store
@@ -28,7 +27,11 @@
from .prompts.table_summary_prompt import TABLE_SUMMARY_PROMPT
from .prompts.text_to_sql_prompt import TEXT_TO_SQL_PROMPT
from .streaming_web_socket_callback_handler import StreamingWebsocketCallbackHandler
from .tools.table_schema import get_table_schema_by_name, get_table_schemas_by_names
from .tools.table_schema import (
get_slimmed_table_schemas,
get_table_schema_by_name,
get_table_schemas_by_names,
)

LOG = get_logger(__file__)

@@ -71,13 +74,17 @@ def _get_llm_config(self, ai_command: str):
**self._config.get(ai_command, {}).get("model_args", {}),
}

@abstractmethod
def _get_context_length_by_model(self, model_name: str) -> int:
"""Get the context window size of the model."""
raise NotImplementedError()

def _get_usable_token_count(self, ai_command: str) -> int:
ai_command_config = self._config.get(ai_command, {})
default_config = self._config.get("default", {})

max_context_length = ai_command_config.get(
"context_length"
) or default_config.get("context_length", 0)
model_name = self._get_llm_config(ai_command)["model_name"]
max_context_length = self._get_context_length_by_model(model_name)
reserved_tokens = ai_command_config.get(
"reserved_tokens"
) or default_config.get("reserved_tokens", 0)
@@ -86,21 +93,44 @@ def _get_usable_token_count(self, ai_command: str) -> int:

@abstractmethod
def _get_llm(
self, ai_command, callback_handler: StreamingWebsocketCallbackHandler = None
self,
ai_command: str,
prompt_length: int,
callback_handler: StreamingWebsocketCallbackHandler = None,
):
"""return the large language model to use"""
"""return the large language model to use.

Args:
ai_command (str): AI command type
prompt_length (int): The number of tokens in the prompt. Can be used to decide which model to use.
callback_handler (StreamingWebsocketCallbackHandler, optional): Callback handler to handle the streaming result.
"""
raise NotImplementedError()
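
The docstring notes that `prompt_length` can be used to decide which model to use. A hypothetical override sketching that idea is below; the diff does not show whether the bundled OpenAI assistant applies exactly this policy, so the class, threshold, and fallback model here are assumptions:

```python
# Hypothetical provider that steps up to a larger-context model when the
# prompt would not fit the configured one. Illustrative only.
from langchain.chat_models import ChatOpenAI

from lib.ai_assistant.assistants.openai_assistant import (
    DEFAULT_MODEL_NAME,
    OPENAI_MODEL_CONTEXT_WINDOW_SIZE,
    OpenAIAssistant,
)


class SizeAwareOpenAIAssistant(OpenAIAssistant):
    def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None):
        config = self._get_llm_config(ai_command)
        model_name = config.get("model_name", DEFAULT_MODEL_NAME)

        # Illustrative policy: if the prompt exceeds the 4k window, switch to
        # the 16k variant instead of failing.
        if (
            model_name == "gpt-3.5-turbo"
            and prompt_length > OPENAI_MODEL_CONTEXT_WINDOW_SIZE[model_name]
        ):
            model_name = "gpt-3.5-turbo-16k"

        return ChatOpenAI(
            model_name=model_name,
            temperature=config.get("temperature", 0),
            streaming=callback_handler is not None,
            callbacks=[callback_handler] if callback_handler else None,
        )
```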

def _get_sql_title_prompt(self, query):
return SQL_TITLE_PROMPT.format(query=query)

def _get_text_to_sql_prompt(self, dialect, question, table_schemas, original_query):
return TEXT_TO_SQL_PROMPT.format(
context_limit = self._get_usable_token_count(AICommandType.TEXT_TO_SQL.value)
prompt = TEXT_TO_SQL_PROMPT.format(
dialect=dialect,
question=question,
table_schemas=table_schemas,
original_query=original_query,
)
token_count = self._get_token_count(AICommandType.TEXT_TO_SQL.value, prompt)

if token_count > context_limit:
# if the prompt is too long, use slimmed table schemas
prompt = TEXT_TO_SQL_PROMPT.format(
dialect=dialect,
question=question,
table_schemas=get_slimmed_table_schemas(table_schemas),
original_query=original_query,
)

# TODO: need a better way to handle it if the prompt is still too long
return prompt
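
The fallback above trades schema detail for fit. A self-contained sketch of the same idea, where a whitespace word count stands in for real token counting and `slim()` is an illustrative stand-in for `get_slimmed_table_schemas`:

```python
# Self-contained sketch of the "retry with a slimmer prompt" fallback.
def build_prompt(question: str, schemas: list[str], budget: int) -> str:
    def render(schema_lines: list[str]) -> str:
        return "Tables:\n" + "\n".join(schema_lines) + f"\nQuestion: {question}"

    def slim(schema_lines: list[str]) -> list[str]:
        # Drop trailing column/table descriptions, keeping only the shapes.
        return [line.split(" -- ")[0] for line in schema_lines]

    prompt = render(schemas)
    if len(prompt.split()) > budget:    # crude stand-in for a tokenizer
        prompt = render(slim(schemas))  # fall back to slimmed schemas
    return prompt                       # may still be too long (see the TODO)


schemas = [
    "sales.orders(id, user_id, amount) -- one row per order, amount in USD",
    "sales.users(id, email, created_at) -- registered users",
]
print(build_prompt("Total order amount per user?", schemas, budget=12))
```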

def _get_sql_fix_prompt(self, dialect, query, error, table_schemas):
return SQL_FIX_PROMPT.format(
@@ -133,11 +163,6 @@ def _get_table_select_prompt(self, top_n, question, table_schemas):
table_schemas=table_schemas,
)

def _get_llm_chain(self, prompt, socket):
callback_handler = StreamingWebsocketCallbackHandler(socket)
llm = self._get_llm(callback_handler=callback_handler)
return LLMChain(llm=llm, prompt=prompt)

def _get_error_msg(self, error) -> str:
"""Override this method to return specific error messages for your own assistant."""
if isinstance(error, ValidationError):
@@ -207,6 +232,9 @@ def generate_sql_query(
)
llm = self._get_llm(
ai_command=AICommandType.TEXT_TO_SQL.value,
prompt_length=self._get_token_count(
AICommandType.TEXT_TO_SQL.value, prompt
),
callback_handler=StreamingWebsocketCallbackHandler(socket),
)
return llm.predict(text=prompt)
@@ -224,6 +252,7 @@ def generate_title_from_query(self, query, socket=None):
prompt = self._get_sql_title_prompt(query=query)
llm = self._get_llm(
ai_command=AICommandType.SQL_TITLE.value,
prompt_length=self._get_token_count(AICommandType.SQL_TITLE.value, prompt),
callback_handler=StreamingWebsocketCallbackHandler(socket),
)
return llm.predict(text=prompt)
@@ -269,6 +298,7 @@ def query_auto_fix(
)
llm = self._get_llm(
ai_command=AICommandType.SQL_FIX.value,
prompt_length=self._get_token_count(AICommandType.SQL_FIX.value, prompt),
callback_handler=StreamingWebsocketCallbackHandler(socket),
)
return llm.predict(text=prompt)
@@ -301,7 +331,11 @@ def summarize_table(
table_schema=table_schema, sample_queries=sample_queries
)
llm = self._get_llm(
ai_command=AICommandType.TABLE_SUMMARY.value, callback_handler=None
ai_command=AICommandType.TABLE_SUMMARY.value,
prompt_length=self._get_token_count(
AICommandType.TABLE_SUMMARY.value, prompt
),
callback_handler=None,
)
return llm.predict(text=prompt)

@@ -325,7 +359,11 @@ def summarize_query(

prompt = self._get_sql_summary_prompt(table_schemas=table_schemas, query=query)
llm = self._get_llm(
ai_command=AICommandType.SQL_SUMMARY.value, callback_handler=None
ai_command=AICommandType.SQL_SUMMARY.value,
prompt_length=self._get_token_count(
AICommandType.SQL_SUMMARY.value, prompt
),
callback_handler=None,
)
return llm.predict(text=prompt)

@@ -349,7 +387,11 @@ def find_tables(self, metastore_id, question, session=None):
AICommandType.TABLE_SELECT.value
)
for full_table_name in table_names:
table_schema, table_name = full_table_name.split(".")
full_table_name_parts = full_table_name.split(".")
if len(full_table_name_parts) != 2:
continue

table_schema, table_name = full_table_name_parts
table = get_table_by_name(
schema_name=table_schema,
name=table_name,
@@ -374,7 +416,11 @@
question=question,
)
llm = self._get_llm(
ai_command=AICommandType.TABLE_SELECT.value, callback_handler=None
ai_command=AICommandType.TABLE_SELECT.value,
prompt_length=self._get_token_count(
AICommandType.TABLE_SELECT.value, prompt
),
callback_handler=None,
)
return json.loads(llm.predict(text=prompt))
except Exception as e: