From 0278e2bc8aae2165b8a0438823aed32f3000af37 Mon Sep 17 00:00:00 2001
From: Ben Chambers <35960+bjchambers@users.noreply.github.com>
Date: Fri, 12 Jan 2024 13:25:44 -0800
Subject: [PATCH] feat: configurable embedding/llms

---
 app/config.py       |  68 ++++++++++++++++++++++++++--
 app/ingest/embed.py |   7 ---
 app/ingest/store.py | 108 +++++++++++++++++++++++++++++++++++---------
 docker-compose.yml  |   4 +-
 env.template        |  46 +++++++++++++++++++
 5 files changed, 199 insertions(+), 34 deletions(-)
 delete mode 100644 app/ingest/embed.py

diff --git a/app/config.py b/app/config.py
index dc6f424..876efde 100644
--- a/app/config.py
+++ b/app/config.py
@@ -1,6 +1,7 @@
 from typing import Any, Optional

-from pydantic import RedisDsn
+from pydantic import RedisDsn, ValidationInfo, field_validator
+from pydantic_core import Url
 from pydantic_settings import BaseSettings

 from app.constants import Environment
@@ -16,7 +17,6 @@ class Config:
         env_file = ".env"
         env_file_encoding = "utf-8"
-        env_prefix = "KB_"

     ENVIRONMENT: Environment = Environment.PRODUCTION
     """The environment the application is running in."""
@@ -24,11 +24,73 @@ class Config:
     REDIS: Optional[RedisDsn] = None
     """The Redis service to use for queueing, indexing and document storage."""

+    EMBEDDING_MODEL: str = ""
+    """The embedding model to use.
+
+    This is a string of the form `<kind>[:<model>]` with the following options:
+
+    1. `local[:<path or repository>]` -- Run a model locally. If this is a path
+       it will attempt to load a model from that location. Otherwise, it should
+       be a Hugging Face repository from which to retrieve the model.
+    2. `openai[:<model name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
+    3. `ollama:<model name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.
+
+    For the `local` and `openai` kinds, you can omit the second part to use the
+    default model of that kind; `ollama` requires a model name.
+
+    If unset, this will default to `"openai"` if an OpenAI API key is available and
+    otherwise will use `"local"`.
+
+    NOTE: Changing embedding models is not currently supported.
+    """
+
+    LLM_MODEL: str = ""
+    """The LLM model to use.
+
+    This is a string of the form `<kind>[:<model>]` with the following options:
+
+    1. `local[:<path or repository>]` -- Run a model locally. If this is a path
+       it will attempt to load a model from that location. Otherwise, it should
+       be a Hugging Face repository from which to retrieve the model.
+    2. `openai[:<model name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
+    3. `ollama:<model name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.
+
+    For the `local` and `openai` kinds, you can omit the second part to use the
+    default model of that kind; `ollama` requires a model name.
+
+    If unset, this will default to `"openai"` if an OpenAI API key is available and
+    otherwise will use `"local"`.
+    """
+
+    OPENAI_API_KEY: Optional[str] = None
+    """The OpenAI API key to use for OpenAI models.
+
+    This is required for using openai models.
+    """
+
+    OLLAMA_BASE_URL: Optional[Url] = None
+    """The base URL for Ollama.
+
+    This is required for using ollama models.
+    """
+
+    @field_validator("OLLAMA_BASE_URL")
+    @classmethod
+    def validate_ollama_base_url(cls, v, info: ValidationInfo):
+        MODELS = ["LLM_MODEL", "EMBEDDING_MODEL"]
+        if v is None:
+            for model in MODELS:
+                # `info.data` holds previously validated fields; both model
+                # fields are declared before OLLAMA_BASE_URL, so they appear here.
+                value = info.data.get(model, "")
+                if value.startswith("ollama"):
+                    raise ValueError(
+                        f"{info.field_name} must be set to use '{model}={value}'"
+                    )
+        return v
+

 settings = Config()

 app_configs: dict[str, Any] = {
-    "title": "Knowledge Bases API",
+    "title": "Dewy Knowledge Base API",
 }

 if not settings.ENVIRONMENT.is_debug:
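For illustration, a minimal sketch of how the new settings behave (assumes
pydantic v2 and the `app.config` module from this patch; the model string and
URL below are placeholders, not defaults):

    import os

    from pydantic import ValidationError

    from app.config import Config

    os.environ["LLM_MODEL"] = "ollama:llama2"  # placeholder model name

    try:
        Config()
    except ValidationError as e:
        # Rejected: OLLAMA_BASE_URL must be set to use 'LLM_MODEL=ollama:llama2'
        print(e)

    os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434"  # placeholder URL
    Config()  # now validates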
diff --git a/app/ingest/embed.py b/app/ingest/embed.py
deleted file mode 100644
index 2a4fb32..0000000
--- a/app/ingest/embed.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from llama_index.embeddings import resolve_embed_model
-
-EMBED_MODEL = resolve_embed_model("local")  # use the default local embedding
-
-# Determine the dimension of the embeddings. There isn't an easy API for this,
-# so we just perform an embedding and see the dimensions we get back.
-EMBEDDING_DIMENSIONS = len(EMBED_MODEL.get_text_embedding("test embedding"))
diff --git a/app/ingest/store.py b/app/ingest/store.py
index 04bbf3b..ca74b50 100644
--- a/app/ingest/store.py
+++ b/app/ingest/store.py
@@ -2,29 +2,93 @@
 from fastapi import Depends, Request
 from llama_index import ServiceContext, StorageContext, VectorStoreIndex
+from llama_index.embeddings import BaseEmbedding
 from llama_index.ingestion import DocstoreStrategy, IngestionPipeline
 from llama_index.ingestion.cache import IngestionCache, RedisCache
+from llama_index.llms import LLM
 from llama_index.storage.docstore.redis_docstore import RedisDocumentStore
 from llama_index.vector_stores import RedisVectorStore
+from loguru import logger

 from app.config import settings
-from app.ingest.embed import EMBED_MODEL
+
+DEFAULT_OPENAI_EMBEDDING_MODEL: str = "text-embedding-ada-002"
+DEFAULT_HF_EMBEDDING_MODEL: str = "BAAI/bge-small-en"
+DEFAULT_OPENAI_LLM_MODEL: str = "gpt-3.5-turbo"
+DEFAULT_HF_LLM_MODEL: str = "StabilityAI/stablelm-tuned-alpha-3b"
+
+
+def _embedding_model(model: str) -> BaseEmbedding:
+    if not model:
+        if settings.OPENAI_API_KEY:
+            model = "openai"
+        else:
+            model = "local"
+
+    # Split into at most two parts: the kind and an optional model name.
+    split = model.split(":", 1)
+    if split[0] == "openai":
+        from llama_index.embeddings import OpenAIEmbedding
+
+        model = DEFAULT_OPENAI_EMBEDDING_MODEL
+        if len(split) == 2:
+            model = split[1]
+        return OpenAIEmbedding(model=model)
+    elif split[0] == "local":
+        from llama_index.embeddings import HuggingFaceEmbedding
+
+        model = DEFAULT_HF_EMBEDDING_MODEL
+        if len(split) == 2:
+            model = split[1]
+        return HuggingFaceEmbedding(model)
+    elif split[0] == "ollama":
+        from llama_index.embeddings import OllamaEmbedding
+
+        if len(split) != 2:
+            raise ValueError("Ollama models require a name: 'ollama:<model name>'")
+        return OllamaEmbedding(
+            model_name=split[1], base_url=settings.OLLAMA_BASE_URL.unicode_string()
+        )
+    else:
+        raise ValueError(f"Unrecognized embedding model '{model}'")
+
+
+def _llm_model(model: str) -> LLM:
+    if not model:
+        if settings.OPENAI_API_KEY:
+            model = "openai"
+        else:
+            model = "local"
+
+    # Split into at most two parts: the kind and an optional model name.
+    split = model.split(":", 1)
+    if split[0] == "openai":
+        from llama_index.llms import OpenAI
+
+        model = DEFAULT_OPENAI_LLM_MODEL
+        if len(split) == 2:
+            model = split[1]
+        return OpenAI(model=model)
+    elif split[0] == "local":
+        from llama_index.llms import HuggingFaceLLM
+
+        model = DEFAULT_HF_LLM_MODEL
+        if len(split) == 2:
+            model = split[1]
+        return HuggingFaceLLM(model_name=model, tokenizer_name=model)
+    elif split[0] == "ollama":
+        from llama_index.llms import Ollama
+
+        if len(split) != 2:
+            raise ValueError("Ollama models require a name: 'ollama:<model name>'")
+        return Ollama(model=split[1], base_url=settings.OLLAMA_BASE_URL.unicode_string())
+    else:
+        raise ValueError(f"Unrecognized LLM model '{model}'")
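As a quick reference, how a few representative model strings resolve through
the helpers above (the concrete model names are illustrative, not project
defaults):

    _embedding_model("")                        # OpenAIEmbedding if OPENAI_API_KEY is set, else the default local model
    _embedding_model("openai")                  # OpenAIEmbedding(model="text-embedding-ada-002")
    _embedding_model("local:BAAI/bge-base-en")  # HuggingFaceEmbedding for that repository
    _llm_model("openai:gpt-4")                  # OpenAI(model="gpt-4")
    _llm_model("ollama:llama2")                 # Ollama(model="llama2"); requires OLLAMA_BASE_URL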


 class Store:
     """Class managing the vector and document store."""

     def __init__(self) -> None:
-        from llama_index.llms import HuggingFaceLLM
-
-        model = "mistralai/Mistral-7B-v0.1"
-        model_kwargs = {}
-        self.llm = HuggingFaceLLM(
-            model_name=model,
-            model_kwargs=model_kwargs,
-            tokenizer_name=model,
-            tokenizer_kwargs=model_kwargs,
-        )
+        self.embedding = _embedding_model(settings.EMBEDDING_MODEL)
+        self.llm = _llm_model(settings.LLM_MODEL)
+        logger.info("Embedding: {}", self.embedding.to_dict())
+        logger.info("LLM: {}", self.llm.to_dict())

         vector_store = RedisVectorStore(
             index_name="vector_store",
@@ -50,20 +114,20 @@ def __init__(self) -> None:
             HierarchicalNodeParser.from_defaults(chunk_sizes=[2048, 512, 128]),
         ]

-        # if self.llm:
-        #     # Transformations that require an LLM.
-        #     from llama_index.extractors import SummaryExtractor, TitleExtractor
-
-        #     transformations.extend(
-        #         [
-        #             TitleExtractor(self.llm),
-        #             SummaryExtractor(self.llm),
-        #         ]
-        #     )
+        if self.llm:
+            # Transformations that require an LLM.
+            from llama_index.extractors import SummaryExtractor, TitleExtractor
+
+            transformations.extend(
+                [
+                    TitleExtractor(llm=self.llm),
+                    SummaryExtractor(llm=self.llm),
+                ]
+            )

         self.service_context = ServiceContext.from_defaults(
             llm=self.llm,
-            embed_model=EMBED_MODEL,
+            embed_model=self.embedding,
             transformations=transformations,
         )
@@ -74,7 +138,7 @@ def __init__(self) -> None:
         )

         self.ingestion_pipeline = IngestionPipeline(
-            transformations=transformations + [EMBED_MODEL],
+            transformations=transformations + [self.embedding],
             vector_store=vector_store,
             docstore=docstore,
             cache=cache,
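A hedged sketch of how the Store is exercised once constructed (assumes a
reachable Redis at `REDIS` and that `documents` were loaded elsewhere; `run`
is the standard llama-index IngestionPipeline entry point):

    # The configured models flow into both the service context (queries)
    # and the ingestion pipeline (indexing).
    store = Store()
    nodes = store.ingestion_pipeline.run(documents=documents)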
diff --git a/docker-compose.yml b/docker-compose.yml
index a84b378..94bb356 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -4,8 +4,8 @@ services:
   dewy:
     image: app_image
     environment:
-      KB_ENVIRONMENT: LOCAL
-      KB_REDIS: "redis://default:testing123@redis:6379"
+      ENVIRONMENT: LOCAL
+      REDIS: "redis://default:testing123@redis:6379"
       LLAMA_INDEX_CACHE_DIR: "/tmp/cache/llama_index"
       HF_HOME: "/tmp/cache/hf"
     env_file:
diff --git a/env.template b/env.template
index e69de29..4be4a47 100644
--- a/env.template
+++ b/env.template
@@ -0,0 +1,46 @@
+# The embedding model to use.
+#
+# This is a string of the form `<kind>[:<model>]` with the following options:
+#
+# 1. `local[:<path or repository>]` -- Run a model locally. If this is a path
+#    it will attempt to load a model from that location. Otherwise, it should
+#    be a Hugging Face repository from which to retrieve the model.
+# 2. `openai[:<model name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
+# 3. `ollama:<model name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.
+#
+# For the `local` and `openai` kinds, you can omit the second part to use the
+# default model of that kind; `ollama` requires a model name.
+#
+# If unset, this will default to `"openai"` if an OpenAI API key is available and
+# otherwise will use `"local"`.
+#
+# NOTE: Changing embedding models is not currently supported.
+EMBEDDING_MODEL="local"
+
+# The LLM model to use.
+#
+# This is a string of the form `<kind>[:<model>]` with the following options:
+#
+# 1. `local[:<path or repository>]` -- Run a model locally. If this is a path
+#    it will attempt to load a model from that location. Otherwise, it should
+#    be a Hugging Face repository from which to retrieve the model.
+# 2. `openai[:<model name>]` -- The named OpenAI model. `OPENAI_API_KEY` must be set.
+# 3. `ollama:<model name>` -- The named Ollama model. `OLLAMA_BASE_URL` must be set.
+#
+# For the `local` and `openai` kinds, you can omit the second part to use the
+# default model of that kind; `ollama` requires a model name.
+#
+# If unset, this will default to `"openai"` if an OpenAI API key is available and
+# otherwise will use `"local"`.
+#
+LLM_MODEL="local"
+
+# The OpenAI API Key to use for OpenAI models.
+#
+# This is required for using openai models.
+#
+# OPENAI_API_KEY=""
+
+# The Base URL for Ollama.
+#
+# This is required for using ollama models.
+#
+# OLLAMA_BASE_URL="http://localhost:11434"
\ No newline at end of file
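For example, a plausible `.env` derived from this template for an all-Ollama
setup (the host and model names are placeholders, not defaults):

    EMBEDDING_MODEL="ollama:nomic-embed-text"
    LLM_MODEL="ollama:llama2"
    OLLAMA_BASE_URL="http://localhost:11434"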