Skip to content

Commit

Permalink
Add Litellm Rerank proxy (onyx-dot-app#2346)
Browse files Browse the repository at this point in the history
* add ability to set reranking litellm proxy

* add fully functional rerank litellm cards

* minor formatting enforcement

* remove logs
  • Loading branch information
pablonyx authored and rajiv chodisetti committed Oct 2, 2024
1 parent 3992ded commit fee6b98
Show file tree
Hide file tree
Showing 13 changed files with 231 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the migration it
# revises; both mirror the "Revision ID" / "Revises" lines in the module docstring.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
# Neither branch labels nor cross-migration dependencies are set for this revision.
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add the nullable string column ``rerank_api_url`` to ``search_settings``."""
    # Nullable, so existing rows need no backfill when the column is added.
    rerank_api_url_column = sa.Column("rerank_api_url", sa.String(), nullable=True)
    op.add_column("search_settings", rerank_api_url_column)


def downgrade() -> None:
    """Remove the ``rerank_api_url`` column from ``search_settings``."""
    table, column = "search_settings", "rerank_api_url"
    op.drop_column(table, column)
2 changes: 2 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,8 @@ class SearchSettings(Base):
Enum(RerankerProvider, native_enum=False), nullable=True
)
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)

num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)

cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
Expand Down
7 changes: 5 additions & 2 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)

if search_settings.rerank_model_name and not search_settings.provider_type:
if (
search_settings.rerank_model_name
and not search_settings.provider_type
and not search_settings.rerank_provider_type
):
warm_up_cross_encoder(search_settings.rerank_model_name)

logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
model_name: str,
provider_type: RerankerProvider | None,
api_key: str | None,
api_url: str | None,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
Expand All @@ -250,6 +251,7 @@ def __init__(
self.model_name = model_name
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url

def predict(self, query: str, passages: list[str]) -> list[float]:
rerank_request = RerankRequest(
Expand All @@ -258,6 +260,7 @@ def predict(self, query: str, passages: list[str]) -> list[float]:
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
api_url=self.api_url,
)

response = requests.post(
Expand Down Expand Up @@ -400,6 +403,7 @@ def warm_up_cross_encoder(
reranking_model = RerankingModel(
model_name=rerank_model_name,
provider_type=None,
api_url=None,
api_key=None,
)

Expand Down
4 changes: 3 additions & 1 deletion backend/danswer/search/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
class RerankingDetails(BaseModel):
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
rerank_api_url: str | None
rerank_provider_type: RerankerProvider | None
rerank_api_key: str | None = None

Expand All @@ -42,6 +43,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "RerankingDetails":
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
rerank_api_url=search_settings.rerank_api_url,
)


Expand Down Expand Up @@ -81,7 +83,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings"
num_rerank=search_settings.num_rerank,
# Multilingual Expansion
multilingual_expansion=search_settings.multilingual_expansion,
api_url=search_settings.api_url,
rerank_api_url=search_settings.rerank_api_url,
)


Expand Down
1 change: 1 addition & 0 deletions backend/danswer/search/postprocessing/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def semantic_reranking(
model_name=rerank_settings.rerank_model_name,
provider_type=rerank_settings.rerank_provider_type,
api_key=rerank_settings.rerank_api_key,
api_url=rerank_settings.rerank_api_url,
)

passages = [
Expand Down
36 changes: 36 additions & 0 deletions backend/model_server/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,28 @@ def cohere_rerank(
return [result.relevance_score for result in sorted_results]


def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
with httpx.Client() as client:
response = client.post(
api_url,
json={
"model": model_name,
"query": query,
"documents": docs,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [
item["relevance_score"]
for item in sorted(result["results"], key=lambda x: x["index"])
]


@router.post("/bi-encoder-embed")
async def process_embed_request(
embed_request: EmbedRequest,
Expand Down Expand Up @@ -418,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.LITELLM:
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")

sim_scores = litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)

return RerankResponse(scores=sim_scores)

elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
Expand Down
1 change: 1 addition & 0 deletions backend/shared_configs/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class EmbeddingProvider(str, Enum):

class RerankerProvider(str, Enum):
    """String-valued identifiers for the supported external reranking providers."""

    # Selects the Cohere rerank branch.
    COHERE = "cohere"
    # Selects the LiteLLM rerank branch.
    LITELLM = "litellm"


class EmbedTextType(str, Enum):
Expand Down
1 change: 1 addition & 0 deletions backend/shared_configs/model_server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class RerankRequest(BaseModel):
model_name: str
provider_type: RerankerProvider | None = None
api_key: str | None = None
api_url: str | None = None

# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
Expand Down
Loading

0 comments on commit fee6b98

Please sign in to comment.