Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add darwin assistant #14

Open
wants to merge 4 commits into
base: feature/darwin
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Create table user_slack_persona

Revision ID: 792d1af3dc44
Revises: 3a7802814195
Create Date: 2025-01-24 04:26:02.844951

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "792d1af3dc44"
down_revision = "3a7802814195"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"user_slack_persona",
sa.Column("sender_id", sa.String(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("sender_id"),
)


def downgrade() -> None:
op.drop_table("user_slack_persona")
6 changes: 6 additions & 0 deletions backend/danswer/connectors/slack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,9 @@ def replace_special_catchall(message: str) -> str:
def add_zero_width_whitespace_after_tag(message: str) -> str:
"""Add a 0 width whitespace after every @"""
return message.replace("@", "@\u200B")

@staticmethod
def handle_bold_syntax_for_slack(text: str) -> str:
""" Replace instances of '**' with a single '*'"""
corrected_text = text.replace('**', '*')
return corrected_text
68 changes: 68 additions & 0 deletions backend/danswer/danswerbot/slack/handlers/handle_buttons.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_chat_message_feedback
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.persona import fetch_persona_by_id
from danswer.db.users import fetch_user_slack_persona
from danswer.db.users import add_user_slack_persona
from danswer.db.users import add_slack_persona_for_user
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger

from sqlalchemy.orm.exc import NoResultFound

logger_base = setup_logger()


Expand Down Expand Up @@ -293,3 +299,65 @@ def handle_followup_resolved_button(
thread_ts=thread_ts,
unfurl=False,
)


def handle_persona_selection(req: SocketModeRequest, client: SocketModeClient) -> None:
action = cast(dict[str, Any], req.payload.get("actions", [])[0])
user_id = req.payload["user"]["id"]
channel_id = req.payload["container"]["channel_id"]
persona_id = action.get(
"value"
)
message_ts_to_respond_to = req.payload.get("container", {}).get("thread_ts")

with Session(get_sqlalchemy_engine()) as db_session:
try:
persona = fetch_persona_by_id(db_session=db_session, persona_id=persona_id)

if persona is None:
respond_in_thread(
client=client.web_client,
channel=channel_id,
text=f"Persona not found.",
thread_ts=message_ts_to_respond_to,
)
return

user_slack_persona = fetch_user_slack_persona(
db_session=db_session, sender_id=user_id
)
if user_slack_persona:
add_slack_persona_for_user(
db_session=db_session,
persona=persona,
user_slack_persona=user_slack_persona,
)
response_text = f"Persona '{persona.name}' has been set!\n"
respond_in_thread(
client=client.web_client,
channel=channel_id,
text=response_text,
thread_ts=message_ts_to_respond_to,
)
return

else:
add_user_slack_persona(
db_session=db_session, sender_id=user_id, persona=persona
)
respond_in_thread(
client=client.web_client,
channel=channel_id,
text=f"'{persona.name}' has been successfully set as the current persona.",
thread_ts=message_ts_to_respond_to,
)
return

except NoResultFound:
respond_in_thread(
client=client.web_client,
channel=channel_id,
text="Error in fetching persona",
thread_ts=message_ts_to_respond_to,
)
return
85 changes: 84 additions & 1 deletion backend/danswer/danswerbot/slack/handlers/handle_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.db.persona import get_persona_with_docset_and_prompts
from danswer.db.persona import get_personas
from danswer.db.users import fetch_user_slack_persona
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
Expand Down Expand Up @@ -177,6 +180,7 @@ def handle_message(
channel_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
channel_name: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
Expand Down Expand Up @@ -206,9 +210,88 @@ def handle_message(
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_bot_dm = message_info.is_bot_dm
persona_name = None

if channel_name is None:
with Session(get_sqlalchemy_engine()) as db_session:
user_slack_persona = fetch_user_slack_persona(
db_session=db_session, sender_id=sender_id
)
if user_slack_persona:
slack_persona_id = user_slack_persona.persona_id or None
persona = get_persona_with_docset_and_prompts(
persona_id=slack_persona_id, db_session=db_session
)
persona_name = persona.name
else:
persona = None
else:
persona = channel_config.persona if channel_config else None

if is_bot_msg:
command = message_info.command
if command == "/personas":
with Session(get_sqlalchemy_engine()) as db_session:
personas = get_personas(
user_id=None, db_session=db_session, include_default=False
)

if not personas:
respond_in_thread(
client=client,
channel=channel,
text="No personas are available.",
thread_ts=message_ts_to_respond_to,
)
return

buttons = [
{
"type": "button",
"text": {"type": "plain_text", "text": persona.name},
"value": str(persona.id),
"action_id": f"set_persona_{persona.id}",
}
for persona in personas
]

blocks = [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "Here are the available personas. Click on one to set it:",
},
},
{"type": "actions", "elements": buttons},
]

respond_in_thread(
client=client,
channel=channel,
blocks=blocks,
thread_ts=message_ts_to_respond_to,
)
return
elif command == "/current_persona":
if persona_name is None:
respond_in_thread(
client=client,
channel=channel,
text="No persona is set. Please use the /personas command to set up a persona.",
thread_ts=message_ts_to_respond_to,
)
return
else:
respond_in_thread(
client=client,
channel=channel,
text=f"Current persona : {persona_name}",
thread_ts=message_ts_to_respond_to,
)
return

document_set_names: list[str] | None = None
persona = channel_config.persona if channel_config else None
prompt = None
if persona:
document_set_names = [
Expand Down
12 changes: 12 additions & 0 deletions backend/danswer/danswerbot/slack/listener.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import time
from threading import Event
from typing import Any
Expand Down Expand Up @@ -28,6 +29,7 @@
handle_followup_resolved_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
from danswer.danswerbot.slack.handlers.handle_buttons import handle_persona_selection
from danswer.danswerbot.slack.handlers.handle_message import handle_message
from danswer.danswerbot.slack.handlers.handle_message import (
remove_scheduled_feedback_reminder,
Expand Down Expand Up @@ -93,6 +95,10 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
if not msg:
channel_specific_logger.error("Cannot respond to empty message - skipping")
return False

if re.search(r"!darwin", msg, re.IGNORECASE):
channel_specific_logger.info("Ignoring message containing '!darwin'")
return False

if (
req.payload.setdefault("event", {}).get("user", "")
Expand Down Expand Up @@ -277,6 +283,7 @@ def build_request_details(
channel = req.payload["channel_id"]
msg = req.payload["text"]
sender = req.payload["user_id"]
command = req.payload["command"]

single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)

Expand All @@ -288,6 +295,7 @@ def build_request_details(
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
command=command
)

raise RuntimeError("Programming fault, this should never happen.")
Expand Down Expand Up @@ -356,6 +364,7 @@ def process_message(
channel_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
channel_name = channel_name
)

if failed:
Expand Down Expand Up @@ -391,6 +400,9 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
return handle_followup_resolved_button(req, client, immediate=True)
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
return handle_followup_resolved_button(req, client, immediate=False)
elif action["action_id"].startswith("set_persona_"):
# Persona selection
return handle_persona_selection(req, client)


def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/danswerbot/slack/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class SlackMessageInfo(BaseModel):
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot
is_bot_dm: bool # User is direct messaging to DanswerBot
command: str | None # Slash command used by user
1 change: 1 addition & 0 deletions backend/danswer/danswerbot/slack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def remove_slack_text_interactions(slack_str: str) -> str:
slack_str = SlackTextCleaner.replace_links(slack_str)
slack_str = SlackTextCleaner.replace_special_catchall(slack_str)
slack_str = SlackTextCleaner.add_zero_width_whitespace_after_tag(slack_str)
slack_str = SlackTextCleaner.handle_bold_syntax_for_slack(slack_str)
return slack_str


Expand Down
10 changes: 10 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,16 @@ class SlackBotConfig(Base):
persona: Mapped[Persona | None] = relationship("Persona")


class UserSlackPersona(Base):
__tablename__ = "user_slack_persona"

sender_id: Mapped[str] = mapped_column(primary_key=True)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
persona: Mapped[Persona | None] = relationship("Persona", foreign_keys=[persona_id])


class TaskQueueState(Base):
# Currently refers to Celery Tasks
__tablename__ = "task_queue_jobs"
Expand Down
12 changes: 12 additions & 0 deletions backend/danswer/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.orm import joinedload

from danswer.auth.schemas import UserRole
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
Expand Down Expand Up @@ -621,3 +622,14 @@ def delete_persona_by_name(
db_session.execute(stmt)

db_session.commit()


def get_persona_with_docset_and_prompts(
persona_id: int, db_session: Session
) -> Persona | None:
persona = db_session.scalar(
select(Persona)
.options(joinedload(Persona.document_sets), joinedload(Persona.prompts))
.filter_by(id=persona_id, deleted=False)
)
return persona
30 changes: 30 additions & 0 deletions backend/danswer/db/users.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from collections.abc import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.schema import Column

from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.models import UserSlackPersona


def list_users(db_session: Session, q: str = "") -> Sequence[User]:
Expand All @@ -19,3 +22,30 @@ def get_user_by_email(email: str, db_session: Session) -> User | None:
user = db_session.query(User).filter(User.email == email).first() # type: ignore

return user


def fetch_user_slack_persona(
db_session: Session, sender_id: str
) -> UserSlackPersona | None:
return db_session.scalar(
select(UserSlackPersona).where(UserSlackPersona.sender_id == sender_id)
)


def add_user_slack_persona(
db_session: Session, sender_id: str, persona: Persona
) -> None:
user_persona = UserSlackPersona(
sender_id=sender_id, persona_id=persona.id, persona=persona
)
db_session.add(user_persona)
db_session.commit()


def add_slack_persona_for_user(
db_session: Session, persona: Persona, user_slack_persona: UserSlackPersona
) -> None:
user_slack_persona.persona_id = persona.id
user_slack_persona.persona = persona

db_session.commit()
Loading
Loading