diff --git a/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py new file mode 100644 index 00000000000..d439030be4a --- /dev/null +++ b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py @@ -0,0 +1,26 @@ +"""add support for litellm proxy in reranking + +Revision ID: ba98eba0f66a +Revises: bceb1e139447 +Create Date: 2024-09-06 10:36:04.507332 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "ba98eba0f66a" +down_revision = "bceb1e139447" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("search_settings", "rerank_api_url") diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index a3264a43765..b41e0581117 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -576,6 +576,8 @@ class SearchSettings(Base): Enum(RerankerProvider, native_enum=False), nullable=True ) rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True) + rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True) + num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS) cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship( diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 788b84df403..c518e463e6d 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.notice( f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." ) - - if search_settings.rerank_model_name and not search_settings.provider_type: + if ( + search_settings.rerank_model_name + and not search_settings.provider_type + and not search_settings.rerank_provider_type + ): warm_up_cross_encoder(search_settings.rerank_model_name) logger.notice("Verifying query preprocessing (NLTK) data is downloaded") diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index 5a18cacab1c..117205761fe 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -242,6 +242,7 @@ def __init__( model_name: str, provider_type: RerankerProvider | None, api_key: str | None, + api_url: str | None, model_server_host: str = MODEL_SERVER_HOST, model_server_port: int = MODEL_SERVER_PORT, ) -> None: @@ -250,6 +251,7 @@ def __init__( self.model_name = model_name self.provider_type = provider_type self.api_key = api_key + self.api_url = api_url def predict(self, query: str, passages: list[str]) -> list[float]: rerank_request = RerankRequest( @@ -258,6 +260,7 @@ def predict(self, query: str, passages: list[str]) -> list[float]: model_name=self.model_name, provider_type=self.provider_type, api_key=self.api_key, + api_url=self.api_url, ) response = requests.post( @@ -400,6 +403,7 @@ def warm_up_cross_encoder( reranking_model = RerankingModel( model_name=rerank_model_name, provider_type=None, + api_url=None, api_key=None, ) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index e9201c97056..678877812a2 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -26,6 +26,7 @@ class RerankingDetails(BaseModel): # If model is None (or num_rerank is 0), then reranking is turned off rerank_model_name: str | None + rerank_api_url: str | None rerank_provider_type: RerankerProvider | None rerank_api_key: str | None = None @@ -42,6 +43,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "RerankingDetails": rerank_provider_type=search_settings.rerank_provider_type, rerank_api_key=search_settings.rerank_api_key, num_rerank=search_settings.num_rerank, + rerank_api_url=search_settings.rerank_api_url, ) @@ -81,7 +83,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings" num_rerank=search_settings.num_rerank, # Multilingual Expansion multilingual_expansion=search_settings.multilingual_expansion, - api_url=search_settings.api_url, + rerank_api_url=search_settings.rerank_api_url, ) diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index 033bdf36e00..b4a1e48bd39 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -100,6 +100,7 @@ def semantic_reranking( model_name=rerank_settings.rerank_model_name, provider_type=rerank_settings.rerank_provider_type, api_key=rerank_settings.rerank_api_key, + api_url=rerank_settings.rerank_api_url, ) passages = [ diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 41167ab1936..860151b3dc4 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -362,6 +362,28 @@ def cohere_rerank( return [result.relevance_score for result in sorted_results] +def litellm_rerank( + query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None +) -> list[float]: + headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model_name, + "query": query, + "documents": docs, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [ + item["relevance_score"] + for item in sorted(result["results"], key=lambda x: x["index"]) + ] + + @router.post("/bi-encoder-embed") async def process_embed_request( embed_request: EmbedRequest, @@ -418,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons model_name=rerank_request.model_name, ) return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.LITELLM: + if rerank_request.api_url is None: + raise ValueError("API URL is required for LiteLLM reranking.") + + sim_scores = litellm_rerank( + query=rerank_request.query, + docs=rerank_request.documents, + api_url=rerank_request.api_url, + model_name=rerank_request.model_name, + api_key=rerank_request.api_key, + ) + + return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.COHERE: if rerank_request.api_key is None: raise RuntimeError("Cohere Rerank Requires an API Key") diff --git a/backend/shared_configs/enums.py b/backend/shared_configs/enums.py index 4dccd43e0a5..b58ac0a8928 100644 --- a/backend/shared_configs/enums.py +++ b/backend/shared_configs/enums.py @@ -11,6 +11,7 @@ class EmbeddingProvider(str, Enum): class RerankerProvider(str, Enum): COHERE = "cohere" + LITELLM = "litellm" class EmbedTextType(str, Enum): diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 5415bebd75a..dd846ed6bad 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -43,6 +43,7 @@ class RerankRequest(BaseModel): model_name: str provider_type: RerankerProvider | None = None api_key: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/web/src/app/admin/embeddings/RerankingFormPage.tsx b/web/src/app/admin/embeddings/RerankingFormPage.tsx index 13cf72e9af1..27b7ce212f7 100644 --- a/web/src/app/admin/embeddings/RerankingFormPage.tsx +++ b/web/src/app/admin/embeddings/RerankingFormPage.tsx @@ -7,7 +7,11 @@ import { rerankingModels, } from "./interfaces"; import { FiExternalLink } from "react-icons/fi"; -import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons"; +import { + CohereIcon, + LiteLLMIcon, + MixedBreadIcon, +} from "@/components/icons/icons"; import { Modal } from "@/components/Modal"; import { Button } from "@tremor/react"; import { TextFormField } from "@/components/admin/connectors/Field"; @@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef< ref ) => { const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false); + const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] = + useState(false); return ( { setSubmitting(false); }} enableReinitialize={true} > - {({ values, setFieldValue }) => { + {({ values, setFieldValue, resetForm }) => { const resetRerankingValues = () => { setRerankingDetails({ ...values, @@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef< ) : rerankingModels.filter( (modelCard) => - modelCard.modelName == - originalRerankingDetails.rerank_model_name + (modelCard.modelName == + originalRerankingDetails.rerank_model_name && + modelCard.rerank_provider_type == + originalRerankingDetails.rerank_provider_type) || + (modelCard.rerank_provider_type == + RerankerProvider.LITELLM && + originalRerankingDetails.rerank_provider_type == + RerankerProvider.LITELLM) ) ).map((card) => { const isSelected = values.rerank_provider_type === card.rerank_provider_type && - values.rerank_model_name === card.modelName; + (card.modelName == null || + values.rerank_model_name === card.modelName); + return (
{ - if (card.rerank_provider_type) { + if ( + card.rerank_provider_type == RerankerProvider.COHERE + ) { setIsApiKeyModalOpen(true); + } else if ( + card.rerank_provider_type == + RerankerProvider.LITELLM + ) { + setShowLiteLLMConfigurationModal(true); + } + + if (!isSelected) { + setRerankingDetails({ + ...values, + rerank_provider_type: card.rerank_provider_type!, + rerank_model_name: card.modelName || null, + rerank_api_key: null, + rerank_api_url: null, + }); + setFieldValue( + "rerank_provider_type", + card.rerank_provider_type + ); + setFieldValue("rerank_model_name", card.modelName); } - setRerankingDetails({ - ...values, - rerank_provider_type: card.rerank_provider_type!, - rerank_model_name: card.modelName, - rerank_api_key: null, - }); - setFieldValue( - "rerank_provider_type", - card.rerank_provider_type - ); - setFieldValue("rerank_model_name", card.modelName); }} >
{card.rerank_provider_type === - RerankerProvider.COHERE ? ( + RerankerProvider.LITELLM ? ( + + ) : RerankerProvider.COHERE ? ( ) : ( @@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef< })}
+ {showLiteLLMConfigurationModal && ( + { + resetForm(); + setShowLiteLLMConfigurationModal(false); + }} + width="w-[800px]" + title="API Key Configuration" + > +
+ ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_url: value, + }); + setFieldValue("rerank_api_url", value); + }} + type="text" + label="LiteLLM Proxy URL" + name="rerank_api_url" + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_key: value, + }); + setFieldValue("rerank_api_key", value); + }} + type="password" + label="LiteLLM Proxy Key" + name="rerank_api_key" + optional + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_model_name: value, + }); + setFieldValue("rerank_model_name", value); + }} + label="LiteLLM Model Name" + name="rerank_model_name" + optional + /> + +
+ +
+
+
+ )} + {isApiKeyModalOpen && ( { @@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef< >
) => { const value = e.target.value; setRerankingDetails({ diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts index 27136438814..70afb9830f8 100644 --- a/web/src/app/admin/embeddings/interfaces.ts +++ b/web/src/app/admin/embeddings/interfaces.ts @@ -5,11 +5,13 @@ export interface RerankingDetails { rerank_model_name: string | null; rerank_provider_type: RerankerProvider | null; rerank_api_key: string | null; + rerank_api_url: string | null; num_rerank: number; } export enum RerankerProvider { COHERE = "cohere", + LITELLM = "litellm", } export interface AdvancedSearchConfiguration { model_name: string; @@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails { export interface RerankingModel { rerank_provider_type: RerankerProvider | null; - modelName: string; + modelName?: string; displayName: string; description: string; link: string; @@ -48,6 +50,13 @@ export interface RerankingModel { } export const rerankingModels: RerankingModel[] = [ + { + rerank_provider_type: RerankerProvider.LITELLM, + cloud: true, + displayName: "LiteLLM", + description: "Host your own reranker or router with LiteLLM proxy", + link: "https://docs.litellm.ai/docs/proxy", + }, { rerank_provider_type: null, cloud: false, diff --git a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx index 3d519e63491..4f4df0a465c 100644 --- a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx @@ -4,7 +4,7 @@ import * as Yup from "yup"; import CredentialSubText from "@/components/credentials/CredentialFields"; import { TrashIcon } from "@/components/icons/icons"; import { FaPlus } from "react-icons/fa"; -import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces"; +import { AdvancedSearchConfiguration } from "../interfaces"; import { BooleanFormField } from "@/components/admin/connectors/Field"; import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput"; diff --git a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx index 09034d8d9b8..6415daf88f5 100644 --- a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx @@ -10,7 +10,7 @@ import { CloudEmbeddingModel, EmbeddingProvider, HostedEmbeddingModel, -} from "../../../../components/embedding/interfaces"; +} from "@/components/embedding/interfaces"; import { errorHandlingFetcher } from "@/lib/fetcher"; import { ErrorCallout } from "@/components/ErrorCallout"; import useSWR, { mutate } from "swr"; @@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading"; import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage"; import { AdvancedSearchConfiguration, - RerankerProvider, RerankingDetails, SavedSearchSettings, } from "../interfaces"; @@ -49,6 +48,7 @@ export default function EmbeddingForm() { num_rerank: 0, rerank_provider_type: null, rerank_model_name: "", + rerank_api_url: null, }); const updateAdvancedEmbeddingDetails = ( @@ -124,6 +124,7 @@ export default function EmbeddingForm() { num_rerank: searchSettings.num_rerank, rerank_provider_type: searchSettings.rerank_provider_type, rerank_model_name: searchSettings.rerank_model_name, + rerank_api_url: searchSettings.rerank_api_url, }); } }, [searchSettings]); @@ -134,12 +135,14 @@ export default function EmbeddingForm() { num_rerank: searchSettings.num_rerank, rerank_provider_type: searchSettings.rerank_provider_type, rerank_model_name: searchSettings.rerank_model_name, + rerank_api_url: searchSettings.rerank_api_url, } : { rerank_api_key: "", num_rerank: 0, rerank_provider_type: null, rerank_model_name: "", + rerank_api_url: null, }; useEffect(() => {