Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/danswer-ai/danswer into fea…
Browse files Browse the repository at this point in the history
…ture/background_processing

rebase alembic migration
  • Loading branch information
rkuo-danswer committed Sep 9, 2024
2 parents ab77692 + e4e4765 commit fe9bb06
Show file tree
Hide file tree
Showing 30 changed files with 362 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# revision identifiers, used by Alembic.
revision = "52a219fb5233"
down_revision = "bceb1e139447"
down_revision = "f7e58d357687"
branch_labels = None
depends_on = None

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
)


def downgrade() -> None:
op.drop_column("search_settings", "rerank_api_url")
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add has_web_login column to user
Revision ID: f7e58d357687
Revises: bceb1e139447
Create Date: 2024-09-07 20:20:54.522620
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "f7e58d357687"
down_revision = "ba98eba0f66a"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)


def downgrade() -> None:
op.drop_column("user", "has_web_login")
2 changes: 2 additions & 0 deletions backend/danswer/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]):

class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True


class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True
54 changes: 52 additions & 2 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
Expand All @@ -33,6 +35,7 @@
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
Expand Down Expand Up @@ -184,7 +187,7 @@ async def create(
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
) -> User:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
Expand All @@ -193,7 +196,27 @@ async def create(
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user

async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
Expand Down Expand Up @@ -234,6 +257,17 @@ async def oauth_callback(
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, update_dict={"oidc_expiry": None})

# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login:
await self.user_db.update(
user,
update_dict={
"is_verified": is_verified_by_default,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True
return user

async def on_after_register(
Expand Down Expand Up @@ -262,6 +296,22 @@ async def on_after_request_verify(

send_user_verification_email(user.email, token)

async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
user = await super().authenticate(credentials)
if user is None:
try:
user = await self.get_by_email(credentials.username)
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
except exceptions.UserNotExists:
pass
return user


async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/danswerbot/slack/handlers/handle_buttons.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
Expand Down Expand Up @@ -87,6 +88,8 @@ def handle_generate_answer_button(
message_ts = req.payload["message"]["ts"]
thread_ts = req.payload["container"]["thread_ts"]
user_id = req.payload["user"]["id"]
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
email = expert_info.email if expert_info else None

if not thread_ts:
raise ValueError("Missing thread_ts in the payload")
Expand Down Expand Up @@ -125,6 +128,7 @@ def handle_generate_answer_button(
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=user_id or None,
email=email or None,
bypass_filters=True,
is_bot_msg=False,
is_bot_dm=False,
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/danswerbot/slack/handlers/handle_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.db.users import add_non_web_user_if_not_exists
from danswer.utils.logger import setup_logger
from shared_configs.configs import SLACK_CHANNEL_ID

Expand Down Expand Up @@ -209,6 +210,9 @@ def handle_message(
logger.error(f"Was not able to react to user message due to: {e}")

with Session(get_sqlalchemy_engine()) as db_session:
if message_info.email:
add_non_web_user_if_not_exists(message_info.email, db_session)

# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(
message_info=message_info,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
Expand Down Expand Up @@ -103,13 +102,10 @@ def handle_regular_answer(
is_bot_msg = message_info.is_bot_msg
user = None
if message_info.is_bot_dm:
slack_user_info = expert_info_from_slack_id(
message_info.sender, client, user_cache={}
)
if slack_user_info and slack_user_info.email:
if message_info.email:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
user = get_user_by_email(slack_user_info.email, db_session)
user = get_user_by_email(message_info.email, db_session)

document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
Expand Down
14 changes: 13 additions & 1 deletion backend/danswer/danswerbot/slack/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
Expand Down Expand Up @@ -256,6 +257,11 @@ def build_request_details(
tagged = event.get("type") == "app_mention"
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
sender = event.get("user") or None
expert_info = expert_info_from_slack_id(
sender, client.web_client, user_cache={}
)
email = expert_info.email if expert_info else None

msg = remove_danswer_bot_tag(msg, client=client.web_client)

Expand Down Expand Up @@ -286,7 +292,8 @@ def build_request_details(
channel_to_respond=channel,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=event.get("user") or None,
sender=sender,
email=email,
bypass_filters=tagged,
is_bot_msg=False,
is_bot_dm=event.get("channel_type") == "im",
Expand All @@ -296,6 +303,10 @@ def build_request_details(
channel = req.payload["channel_id"]
msg = req.payload["text"]
sender = req.payload["user_id"]
expert_info = expert_info_from_slack_id(
sender, client.web_client, user_cache={}
)
email = expert_info.email if expert_info else None

single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)

Expand All @@ -305,6 +316,7 @@ def build_request_details(
msg_to_respond=None,
thread_to_respond=None,
sender=sender,
email=email,
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/danswerbot/slack/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel):
msg_to_respond: str | None
thread_to_respond: str | None
sender: str | None
email: str | None
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot
is_bot_dm: bool # User is direct messaging to DanswerBot
4 changes: 4 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
notifications: Mapped[list["Notification"]] = relationship(
"Notification", back_populates="user"
)
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)


class InputPrompt(Base):
Expand Down Expand Up @@ -591,6 +593,8 @@ class SearchSettings(Base):
Enum(RerankerProvider, native_enum=False), nullable=True
)
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)

num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)

cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
Expand Down
21 changes: 21 additions & 0 deletions backend/danswer/db/users.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections.abc import Sequence
from uuid import UUID

from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.auth.schemas import UserRole
from danswer.db.models import User


Expand All @@ -30,3 +32,22 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore

return user


def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user

fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user = User(
email=email,
hashed_password=hashed_pass,
has_web_login=False,
role=UserRole.BASIC,
)
db_session.add(user)
db_session.commit()
return user
7 changes: 5 additions & 2 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)

if search_settings.rerank_model_name and not search_settings.provider_type:
if (
search_settings.rerank_model_name
and not search_settings.provider_type
and not search_settings.rerank_provider_type
):
warm_up_cross_encoder(search_settings.rerank_model_name)

logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
model_name: str,
provider_type: RerankerProvider | None,
api_key: str | None,
api_url: str | None,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
Expand All @@ -250,6 +251,7 @@ def __init__(
self.model_name = model_name
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url

def predict(self, query: str, passages: list[str]) -> list[float]:
rerank_request = RerankRequest(
Expand All @@ -258,6 +260,7 @@ def predict(self, query: str, passages: list[str]) -> list[float]:
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
api_url=self.api_url,
)

response = requests.post(
Expand Down Expand Up @@ -400,6 +403,7 @@ def warm_up_cross_encoder(
reranking_model = RerankingModel(
model_name=rerank_model_name,
provider_type=None,
api_url=None,
api_key=None,
)

Expand Down
Loading

0 comments on commit fe9bb06

Please sign in to comment.