From 3a9b964d5c0580ca79c9f1b079f93f33c8e74cd7 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 9 Sep 2024 08:57:01 -0700 Subject: [PATCH 1/3] Add Litellm Rerank proxy (#2346) * add ability ot set reranking litellm proxy * add fully functional rerank litellm cards * minor formatting enforcement * remove logs --- ...0f66a_add_support_for_litellm_proxy_in_.py | 26 +++ backend/danswer/db/models.py | 2 + backend/danswer/main.py | 7 +- .../search_nlp_models.py | 4 + backend/danswer/search/models.py | 4 +- .../search/postprocessing/postprocessing.py | 1 + backend/model_server/encoders.py | 36 ++++ backend/shared_configs/enums.py | 1 + backend/shared_configs/model_server_models.py | 1 + .../admin/embeddings/RerankingFormPage.tsx | 155 +++++++++++++++--- web/src/app/admin/embeddings/interfaces.ts | 11 +- .../pages/AdvancedEmbeddingFormPage.tsx | 2 +- .../embeddings/pages/EmbeddingFormPage.tsx | 7 +- 13 files changed, 231 insertions(+), 26 deletions(-) create mode 100644 backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py diff --git a/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py new file mode 100644 index 00000000000..d439030be4a --- /dev/null +++ b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py @@ -0,0 +1,26 @@ +"""add support for litellm proxy in reranking + +Revision ID: ba98eba0f66a +Revises: bceb1e139447 +Create Date: 2024-09-06 10:36:04.507332 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "ba98eba0f66a" +down_revision = "bceb1e139447" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("search_settings", "rerank_api_url") diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index a3264a43765..b41e0581117 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -576,6 +576,8 @@ class SearchSettings(Base): Enum(RerankerProvider, native_enum=False), nullable=True ) rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True) + rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True) + num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS) cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship( diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 788b84df403..c518e463e6d 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.notice( f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." ) - - if search_settings.rerank_model_name and not search_settings.provider_type: + if ( + search_settings.rerank_model_name + and not search_settings.provider_type + and not search_settings.rerank_provider_type + ): warm_up_cross_encoder(search_settings.rerank_model_name) logger.notice("Verifying query preprocessing (NLTK) data is downloaded") diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index 5a18cacab1c..117205761fe 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -242,6 +242,7 @@ def __init__( model_name: str, provider_type: RerankerProvider | None, api_key: str | None, + api_url: str | None, model_server_host: str = MODEL_SERVER_HOST, model_server_port: int = MODEL_SERVER_PORT, ) -> None: @@ -250,6 +251,7 @@ def __init__( self.model_name = model_name self.provider_type = provider_type self.api_key = api_key + self.api_url = api_url def predict(self, query: str, passages: list[str]) -> list[float]: rerank_request = RerankRequest( @@ -258,6 +260,7 @@ def predict(self, query: str, passages: list[str]) -> list[float]: model_name=self.model_name, provider_type=self.provider_type, api_key=self.api_key, + api_url=self.api_url, ) response = requests.post( @@ -400,6 +403,7 @@ def warm_up_cross_encoder( reranking_model = RerankingModel( model_name=rerank_model_name, provider_type=None, + api_url=None, api_key=None, ) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index e9201c97056..678877812a2 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -26,6 +26,7 @@ class RerankingDetails(BaseModel): # If model is None (or num_rerank is 0), then reranking is turned off rerank_model_name: str | None + rerank_api_url: str | None rerank_provider_type: RerankerProvider | None rerank_api_key: str | None = None @@ -42,6 +43,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "RerankingDetails": rerank_provider_type=search_settings.rerank_provider_type, rerank_api_key=search_settings.rerank_api_key, num_rerank=search_settings.num_rerank, + rerank_api_url=search_settings.rerank_api_url, ) @@ -81,7 +83,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings" num_rerank=search_settings.num_rerank, # Multilingual Expansion multilingual_expansion=search_settings.multilingual_expansion, - api_url=search_settings.api_url, + rerank_api_url=search_settings.rerank_api_url, ) diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index 033bdf36e00..b4a1e48bd39 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -100,6 +100,7 @@ def semantic_reranking( model_name=rerank_settings.rerank_model_name, provider_type=rerank_settings.rerank_provider_type, api_key=rerank_settings.rerank_api_key, + api_url=rerank_settings.rerank_api_url, ) passages = [ diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 41167ab1936..860151b3dc4 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -362,6 +362,28 @@ def cohere_rerank( return [result.relevance_score for result in sorted_results] +def litellm_rerank( + query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None +) -> list[float]: + headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model_name, + "query": query, + "documents": docs, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [ + item["relevance_score"] + for item in sorted(result["results"], key=lambda x: x["index"]) + ] + + @router.post("/bi-encoder-embed") async def process_embed_request( embed_request: EmbedRequest, @@ -418,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons model_name=rerank_request.model_name, ) return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.LITELLM: + if rerank_request.api_url is None: + raise ValueError("API URL is required for LiteLLM reranking.") + + sim_scores = litellm_rerank( + query=rerank_request.query, + docs=rerank_request.documents, + api_url=rerank_request.api_url, + model_name=rerank_request.model_name, + api_key=rerank_request.api_key, + ) + + return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.COHERE: if rerank_request.api_key is None: raise RuntimeError("Cohere Rerank Requires an API Key") diff --git a/backend/shared_configs/enums.py b/backend/shared_configs/enums.py index 4dccd43e0a5..b58ac0a8928 100644 --- a/backend/shared_configs/enums.py +++ b/backend/shared_configs/enums.py @@ -11,6 +11,7 @@ class EmbeddingProvider(str, Enum): class RerankerProvider(str, Enum): COHERE = "cohere" + LITELLM = "litellm" class EmbedTextType(str, Enum): diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 5415bebd75a..dd846ed6bad 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -43,6 +43,7 @@ class RerankRequest(BaseModel): model_name: str provider_type: RerankerProvider | None = None api_key: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/web/src/app/admin/embeddings/RerankingFormPage.tsx b/web/src/app/admin/embeddings/RerankingFormPage.tsx index 13cf72e9af1..27b7ce212f7 100644 --- a/web/src/app/admin/embeddings/RerankingFormPage.tsx +++ b/web/src/app/admin/embeddings/RerankingFormPage.tsx @@ -7,7 +7,11 @@ import { rerankingModels, } from "./interfaces"; import { FiExternalLink } from "react-icons/fi"; -import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons"; +import { + CohereIcon, + LiteLLMIcon, + MixedBreadIcon, +} from "@/components/icons/icons"; import { Modal } from "@/components/Modal"; import { Button } from "@tremor/react"; import { TextFormField } from "@/components/admin/connectors/Field"; @@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef< ref ) => { const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false); + const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] = + useState(false); return ( { setSubmitting(false); }} enableReinitialize={true} > - {({ values, setFieldValue }) => { + {({ values, setFieldValue, resetForm }) => { const resetRerankingValues = () => { setRerankingDetails({ ...values, @@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef< ) : rerankingModels.filter( (modelCard) => - modelCard.modelName == - originalRerankingDetails.rerank_model_name + (modelCard.modelName == + originalRerankingDetails.rerank_model_name && + modelCard.rerank_provider_type == + originalRerankingDetails.rerank_provider_type) || + (modelCard.rerank_provider_type == + RerankerProvider.LITELLM && + originalRerankingDetails.rerank_provider_type == + RerankerProvider.LITELLM) ) ).map((card) => { const isSelected = values.rerank_provider_type === card.rerank_provider_type && - values.rerank_model_name === card.modelName; + (card.modelName == null || + values.rerank_model_name === card.modelName); + return (
{ - if (card.rerank_provider_type) { + if ( + card.rerank_provider_type == RerankerProvider.COHERE + ) { setIsApiKeyModalOpen(true); + } else if ( + card.rerank_provider_type == + RerankerProvider.LITELLM + ) { + setShowLiteLLMConfigurationModal(true); + } + + if (!isSelected) { + setRerankingDetails({ + ...values, + rerank_provider_type: card.rerank_provider_type!, + rerank_model_name: card.modelName || null, + rerank_api_key: null, + rerank_api_url: null, + }); + setFieldValue( + "rerank_provider_type", + card.rerank_provider_type + ); + setFieldValue("rerank_model_name", card.modelName); } - setRerankingDetails({ - ...values, - rerank_provider_type: card.rerank_provider_type!, - rerank_model_name: card.modelName, - rerank_api_key: null, - }); - setFieldValue( - "rerank_provider_type", - card.rerank_provider_type - ); - setFieldValue("rerank_model_name", card.modelName); }} >
{card.rerank_provider_type === - RerankerProvider.COHERE ? ( + RerankerProvider.LITELLM ? ( + + ) : RerankerProvider.COHERE ? ( ) : ( @@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef< })}
+ {showLiteLLMConfigurationModal && ( + { + resetForm(); + setShowLiteLLMConfigurationModal(false); + }} + width="w-[800px]" + title="API Key Configuration" + > +
+ ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_url: value, + }); + setFieldValue("rerank_api_url", value); + }} + type="text" + label="LiteLLM Proxy URL" + name="rerank_api_url" + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_key: value, + }); + setFieldValue("rerank_api_key", value); + }} + type="password" + label="LiteLLM Proxy Key" + name="rerank_api_key" + optional + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_model_name: value, + }); + setFieldValue("rerank_model_name", value); + }} + label="LiteLLM Model Name" + name="rerank_model_name" + optional + /> + +
+ +
+
+
+ )} + {isApiKeyModalOpen && ( { @@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef< >
) => { const value = e.target.value; setRerankingDetails({ diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts index 27136438814..70afb9830f8 100644 --- a/web/src/app/admin/embeddings/interfaces.ts +++ b/web/src/app/admin/embeddings/interfaces.ts @@ -5,11 +5,13 @@ export interface RerankingDetails { rerank_model_name: string | null; rerank_provider_type: RerankerProvider | null; rerank_api_key: string | null; + rerank_api_url: string | null; num_rerank: number; } export enum RerankerProvider { COHERE = "cohere", + LITELLM = "litellm", } export interface AdvancedSearchConfiguration { model_name: string; @@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails { export interface RerankingModel { rerank_provider_type: RerankerProvider | null; - modelName: string; + modelName?: string; displayName: string; description: string; link: string; @@ -48,6 +50,13 @@ export interface RerankingModel { } export const rerankingModels: RerankingModel[] = [ + { + rerank_provider_type: RerankerProvider.LITELLM, + cloud: true, + displayName: "LiteLLM", + description: "Host your own reranker or router with LiteLLM proxy", + link: "https://docs.litellm.ai/docs/proxy", + }, { rerank_provider_type: null, cloud: false, diff --git a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx index 3d519e63491..4f4df0a465c 100644 --- a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx @@ -4,7 +4,7 @@ import * as Yup from "yup"; import CredentialSubText from "@/components/credentials/CredentialFields"; import { TrashIcon } from "@/components/icons/icons"; import { FaPlus } from "react-icons/fa"; -import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces"; +import { AdvancedSearchConfiguration } from "../interfaces"; import { BooleanFormField } from "@/components/admin/connectors/Field"; import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput"; diff --git a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx index 09034d8d9b8..6415daf88f5 100644 --- a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx @@ -10,7 +10,7 @@ import { CloudEmbeddingModel, EmbeddingProvider, HostedEmbeddingModel, -} from "../../../../components/embedding/interfaces"; +} from "@/components/embedding/interfaces"; import { errorHandlingFetcher } from "@/lib/fetcher"; import { ErrorCallout } from "@/components/ErrorCallout"; import useSWR, { mutate } from "swr"; @@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading"; import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage"; import { AdvancedSearchConfiguration, - RerankerProvider, RerankingDetails, SavedSearchSettings, } from "../interfaces"; @@ -49,6 +48,7 @@ export default function EmbeddingForm() { num_rerank: 0, rerank_provider_type: null, rerank_model_name: "", + rerank_api_url: null, }); const updateAdvancedEmbeddingDetails = ( @@ -124,6 +124,7 @@ export default function EmbeddingForm() { num_rerank: searchSettings.num_rerank, rerank_provider_type: searchSettings.rerank_provider_type, rerank_model_name: searchSettings.rerank_model_name, + rerank_api_url: searchSettings.rerank_api_url, }); } }, [searchSettings]); @@ -134,12 +135,14 @@ export default function EmbeddingForm() { num_rerank: searchSettings.num_rerank, rerank_provider_type: searchSettings.rerank_provider_type, rerank_model_name: searchSettings.rerank_model_name, + rerank_api_url: searchSettings.rerank_api_url, } : { rerank_api_key: "", num_rerank: 0, rerank_provider_type: null, rerank_model_name: "", + rerank_api_url: null, }; useEffect(() => { From c967f53c02e35fcd9d55181a9fb4bc9f6e0d8b75 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Mon, 9 Sep 2024 11:26:12 -0700 Subject: [PATCH 2/3] docker versions have been deprecated for a while, so fixing the annoying warning (#2372) --- deployment/docker_compose/docker-compose.dev.yml | 1 - deployment/docker_compose/docker-compose.gpu-dev.yml | 1 - deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml | 1 - deployment/docker_compose/docker-compose.prod.yml | 1 - deployment/docker_compose/docker-compose.search-testing.yml | 1 - 5 files changed, 5 deletions(-) diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 7c570219e81..eb5ba5efc88 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml index e12ef698e9b..74da119737e 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml index 57763c8fca7..c06e9ae3480 100644 --- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml +++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index 85d585a5a88..53bfa646b55 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} diff --git a/deployment/docker_compose/docker-compose.search-testing.yml b/deployment/docker_compose/docker-compose.search-testing.yml index 7f2ee1d7392..ecd796f6716 100644 --- a/deployment/docker_compose/docker-compose.search-testing.yml +++ b/deployment/docker_compose/docker-compose.search-testing.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} From e4e4765c60dfbae9226b6427aa8ee447b1737b40 Mon Sep 17 00:00:00 2001 From: hj-danswer Date: Mon, 9 Sep 2024 13:21:31 -0700 Subject: [PATCH 3/3] Add user when they interact outside of UI (e.g. Slack bot) (#2369) * Add user when they interact outside of UI (e.g. Slack bot) * fix mypy errors * don't use user manager to avoid async messiness * fix email is none scenario * fix mypy * make code slightly clearer * PR comments * get slack email in generate button as well * fix alembic migration * update name to be more descriptive --------- Co-authored-by: Hyeong Joon Suh --- ...f7e58d357687_add_has_web_column_to_user.py | 26 +++++++++ backend/danswer/auth/schemas.py | 2 + backend/danswer/auth/users.py | 54 ++++++++++++++++++- .../slack/handlers/handle_buttons.py | 4 ++ .../slack/handlers/handle_message.py | 4 ++ .../slack/handlers/handle_regular_answer.py | 8 +-- backend/danswer/danswerbot/slack/listener.py | 14 ++++- backend/danswer/danswerbot/slack/models.py | 1 + backend/danswer/db/models.py | 2 + backend/danswer/db/users.py | 21 ++++++++ backend/ee/danswer/server/saml.py | 1 + web/src/app/auth/login/EmailPasswordForm.tsx | 2 + 12 files changed, 130 insertions(+), 9 deletions(-) create mode 100644 backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py diff --git a/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py new file mode 100644 index 00000000000..68a104c7b0f --- /dev/null +++ b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py @@ -0,0 +1,26 @@ +"""add has_web_login column to user + +Revision ID: f7e58d357687 +Revises: bceb1e139447 +Create Date: 2024-09-07 20:20:54.522620 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "f7e58d357687" +down_revision = "ba98eba0f66a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "user", + sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"), + ) + + +def downgrade() -> None: + op.drop_column("user", "has_web_login") diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py index 9e0553991cc..db8a97ceb04 100644 --- a/backend/danswer/auth/schemas.py +++ b/backend/danswer/auth/schemas.py @@ -33,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]): class UserCreate(schemas.BaseUserCreate): role: UserRole = UserRole.BASIC + has_web_login: bool | None = True class UserUpdate(schemas.BaseUserUpdate): role: UserRole + has_web_login: bool | None = True diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 453e34cbe3d..44a801e847f 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -16,7 +16,9 @@ from fastapi import Request from fastapi import Response from fastapi import status +from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager +from fastapi_users import exceptions from fastapi_users import FastAPIUsers from fastapi_users import models from fastapi_users import schemas @@ -33,6 +35,7 @@ from danswer.auth.invited_users import get_invited_users from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole +from danswer.auth.schemas import UserUpdate from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import EMAIL_FROM @@ -184,7 +187,7 @@ async def create( user_create: schemas.UC | UserCreate, safe: bool = False, request: Optional[Request] = None, - ) -> models.UP: + ) -> User: verify_email_is_invited(user_create.email) verify_email_domain(user_create.email) if hasattr(user_create, "role"): @@ -193,7 +196,27 @@ async def create( user_create.role = UserRole.ADMIN else: user_create.role = UserRole.BASIC - return await super().create(user_create, safe=safe, request=request) # type: ignore + user = None + try: + user = await super().create(user_create, safe=safe, request=request) # type: ignore + except exceptions.UserAlreadyExists: + user = await self.get_by_email(user_create.email) + # Handle case where user has used product outside of web and is now creating an account through web + if ( + not user.has_web_login + and hasattr(user_create, "has_web_login") + and user_create.has_web_login + ): + user_update = UserUpdate( + password=user_create.password, + has_web_login=True, + role=user_create.role, + is_verified=user_create.is_verified, + ) + user = await self.update(user_update, user) + else: + raise exceptions.UserAlreadyExists() + return user async def oauth_callback( self: "BaseUserManager[models.UOAP, models.ID]", @@ -234,6 +257,17 @@ async def oauth_callback( if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY: await self.user_db.update(user, update_dict={"oidc_expiry": None}) + # Handle case where user has used product outside of web and is now creating an account through web + if not user.has_web_login: + await self.user_db.update( + user, + update_dict={ + "is_verified": is_verified_by_default, + "has_web_login": True, + }, + ) + user.is_verified = is_verified_by_default + user.has_web_login = True return user async def on_after_register( @@ -262,6 +296,22 @@ async def on_after_request_verify( send_user_verification_email(user.email, token) + async def authenticate( + self, credentials: OAuth2PasswordRequestForm + ) -> Optional[User]: + user = await super().authenticate(credentials) + if user is None: + try: + user = await self.get_by_email(credentials.username) + if not user.has_web_login: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD", + ) + except exceptions.UserNotExists: + pass + return user + async def get_user_manager( user_db: SQLAlchemyUserDatabase = Depends(get_user_db), diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index 732be8df9db..9e1c171ee4f 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -11,6 +11,7 @@ from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI +from danswer.connectors.slack.utils import expert_info_from_slack_id from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks from danswer.danswerbot.slack.blocks import get_document_feedback_blocks @@ -87,6 +88,8 @@ def handle_generate_answer_button( message_ts = req.payload["message"]["ts"] thread_ts = req.payload["container"]["thread_ts"] user_id = req.payload["user"]["id"] + expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={}) + email = expert_info.email if expert_info else None if not thread_ts: raise ValueError("Missing thread_ts in the payload") @@ -125,6 +128,7 @@ def handle_generate_answer_button( msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), sender=user_id or None, + email=email or None, bypass_filters=True, is_bot_msg=False, is_bot_dm=False, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 2edbd973553..cce45331ee7 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -21,6 +21,7 @@ from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig +from danswer.db.users import add_non_web_user_if_not_exists from danswer.utils.logger import setup_logger from shared_configs.configs import SLACK_CHANNEL_ID @@ -209,6 +210,9 @@ def handle_message( logger.error(f"Was not able to react to user message due to: {e}") with Session(get_sqlalchemy_engine()) as db_session: + if message_info.email: + add_non_web_user_if_not_exists(message_info.email, db_session) + # first check if we need to respond with a standard answer used_standard_answer = handle_standard_answers( message_info=message_info, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index 12ed9d55673..09ea4e05332 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -22,7 +22,6 @@ from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION -from danswer.connectors.slack.utils import expert_info_from_slack_id from danswer.danswerbot.slack.blocks import build_documents_blocks from danswer.danswerbot.slack.blocks import build_follow_up_block from danswer.danswerbot.slack.blocks import build_qa_response_blocks @@ -103,13 +102,10 @@ def handle_regular_answer( is_bot_msg = message_info.is_bot_msg user = None if message_info.is_bot_dm: - slack_user_info = expert_info_from_slack_id( - message_info.sender, client, user_cache={} - ) - if slack_user_info and slack_user_info.email: + if message_info.email: engine = get_sqlalchemy_engine() with Session(engine) as db_session: - user = get_user_by_email(slack_user_info.email, db_session) + user = get_user_by_email(message_info.email, db_session) document_set_names: list[str] | None = None persona = slack_bot_config.persona if slack_bot_config else None diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index c59f4caf1aa..63f8bcfcd9c 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -13,6 +13,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER +from danswer.connectors.slack.utils import expert_info_from_slack_id from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID @@ -256,6 +257,11 @@ def build_request_details( tagged = event.get("type") == "app_mention" message_ts = event.get("ts") thread_ts = event.get("thread_ts") + sender = event.get("user") or None + expert_info = expert_info_from_slack_id( + sender, client.web_client, user_cache={} + ) + email = expert_info.email if expert_info else None msg = remove_danswer_bot_tag(msg, client=client.web_client) @@ -286,7 +292,8 @@ def build_request_details( channel_to_respond=channel, msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), - sender=event.get("user") or None, + sender=sender, + email=email, bypass_filters=tagged, is_bot_msg=False, is_bot_dm=event.get("channel_type") == "im", @@ -296,6 +303,10 @@ def build_request_details( channel = req.payload["channel_id"] msg = req.payload["text"] sender = req.payload["user_id"] + expert_info = expert_info_from_slack_id( + sender, client.web_client, user_cache={} + ) + email = expert_info.email if expert_info else None single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER) @@ -305,6 +316,7 @@ def build_request_details( msg_to_respond=None, thread_to_respond=None, sender=sender, + email=email, bypass_filters=True, is_bot_msg=True, is_bot_dm=False, diff --git a/backend/danswer/danswerbot/slack/models.py b/backend/danswer/danswerbot/slack/models.py index e4521a759a7..6394eab562d 100644 --- a/backend/danswer/danswerbot/slack/models.py +++ b/backend/danswer/danswerbot/slack/models.py @@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel): msg_to_respond: str | None thread_to_respond: str | None sender: str | None + email: str | None bypass_filters: bool # User has tagged @DanswerBot is_bot_msg: bool # User is using /DanswerBot is_bot_dm: bool # User is direct messaging to DanswerBot diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index b41e0581117..ffc12323a52 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -157,6 +157,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base): notifications: Mapped[list["Notification"]] = relationship( "Notification", back_populates="user" ) + # Whether the user has logged in via web. False if user has only used Danswer through Slack bot + has_web_login: Mapped[bool] = mapped_column(Boolean, default=True) class InputPrompt(Base): diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index d824ccfd921..61ba6e475fe 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -1,9 +1,11 @@ from collections.abc import Sequence from uuid import UUID +from fastapi_users.password import PasswordHelper from sqlalchemy import select from sqlalchemy.orm import Session +from danswer.auth.schemas import UserRole from danswer.db.models import User @@ -30,3 +32,22 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: user = db_session.query(User).filter(User.id == user_id).first() # type: ignore return user + + +def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User: + user = get_user_by_email(email, db_session) + if user is not None: + return user + + fastapi_users_pw_helper = PasswordHelper() + password = fastapi_users_pw_helper.generate() + hashed_pass = fastapi_users_pw_helper.hash(password) + user = User( + email=email, + hashed_password=hashed_pass, + has_web_login=False, + role=UserRole.BASIC, + ) + db_session.add(user) + db_session.commit() + return user diff --git a/backend/ee/danswer/server/saml.py b/backend/ee/danswer/server/saml.py index 5bc62e98d61..38966c15756 100644 --- a/backend/ee/danswer/server/saml.py +++ b/backend/ee/danswer/server/saml.py @@ -65,6 +65,7 @@ async def upsert_saml_user(email: str) -> User: password=hashed_pass, is_verified=True, role=role, + has_web_login=True, ) ) diff --git a/web/src/app/auth/login/EmailPasswordForm.tsx b/web/src/app/auth/login/EmailPasswordForm.tsx index 74dcc1a0a69..6862baa600c 100644 --- a/web/src/app/auth/login/EmailPasswordForm.tsx +++ b/web/src/app/auth/login/EmailPasswordForm.tsx @@ -72,6 +72,8 @@ export function EmailPasswordForm({ let errorMsg = "Unknown error"; if (errorDetail === "LOGIN_BAD_CREDENTIALS") { errorMsg = "Invalid email or password"; + } else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") { + errorMsg = "Create an account to set a password"; } setPopup({ type: "error",