Skip to content

Commit

Permalink
Add darwin assistant
Browse files Browse the repository at this point in the history
Add changes to reply only to darwinAssistant

Add changes for custom_llm.py
  • Loading branch information
swati354 committed Jan 27, 2025
1 parent 711b14f commit 8c0531e
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Create table user_slack_persona
Revision ID: 792d1af3dc44
Revises: 3a7802814195
Create Date: 2025-01-24 04:26:02.844951
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "792d1af3dc44"
down_revision = "3a7802814195"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"user_slack_persona",
sa.Column("sender_id", sa.String(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("sender_id"),
)


def downgrade() -> None:
op.drop_table("user_slack_persona")
68 changes: 68 additions & 0 deletions backend/danswer/danswerbot/slack/handlers/handle_buttons.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_chat_message_feedback
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.persona import fetch_persona_by_id
from danswer.db.users import fetch_user_slack_persona
from danswer.db.users import add_user_slack_persona
from danswer.db.users import add_slack_persona_for_user
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger

from sqlalchemy.orm.exc import NoResultFound

logger_base = setup_logger()


Expand Down Expand Up @@ -293,3 +299,65 @@ def handle_followup_resolved_button(
thread_ts=thread_ts,
unfurl=False,
)


def handle_persona_selection(req: SocketModeRequest, client: SocketModeClient) -> None:
action = cast(dict[str, Any], req.payload.get("actions", [])[0])
user_id = req.payload["user"]["id"]
channel_id = req.payload["container"]["channel_id"]
persona_id = action.get(
"value"
)
message_ts_to_respond_to = req.payload.get("container", {}).get("thread_ts")

with Session(get_sqlalchemy_engine()) as db_session:
try:
persona = fetch_persona_by_id(db_session=db_session, persona_id=persona_id)

if persona is None:
respond_in_thread(
client=client.web_client,
channel=channel_id,
text=f"Persona not found.",
thread_ts=message_ts_to_respond_to,
)
return

user_slack_persona = fetch_user_slack_persona(
db_session=db_session, sender_id=user_id
)
if user_slack_persona:
add_slack_persona_for_user(
db_session=db_session,
persona=persona,
user_slack_persona=user_slack_persona,
)
response_text = f"Persona '{persona.name}' has been set!\n"
respond_in_thread(
client=client.web_client,
channel=channel_id,
text=response_text,
thread_ts=message_ts_to_respond_to,
)
return

else:
add_user_slack_persona(
db_session=db_session, sender_id=user_id, persona=persona
)
respond_in_thread(
client=client.web_client,
channel=channel_id,
text=f"'{persona.name}' has been successfully set as the current persona.",
thread_ts=message_ts_to_respond_to,
)
return

except NoResultFound:
respond_in_thread(
client=client.web_client,
channel=channel_id,
text="Error in fetching persona",
thread_ts=message_ts_to_respond_to,
)
return
85 changes: 84 additions & 1 deletion backend/danswer/danswerbot/slack/handlers/handle_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.db.persona import get_persona_with_docset_and_prompts
from danswer.db.persona import get_personas
from danswer.db.users import fetch_user_slack_persona
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
Expand Down Expand Up @@ -177,6 +180,7 @@ def handle_message(
channel_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
channel_name: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
Expand Down Expand Up @@ -206,9 +210,88 @@ def handle_message(
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_bot_dm = message_info.is_bot_dm
persona_name = None

if channel_name is None:
with Session(get_sqlalchemy_engine()) as db_session:
user_slack_persona = fetch_user_slack_persona(
db_session=db_session, sender_id=sender_id
)
if user_slack_persona:
slack_persona_id = user_slack_persona.persona_id or None
persona = get_persona_with_docset_and_prompts(
persona_id=slack_persona_id, db_session=db_session
)
persona_name = persona.name
else:
persona = None
else:
persona = channel_config.persona if channel_config else None

if is_bot_msg:
command = message_info.command
if command == "/personas":
with Session(get_sqlalchemy_engine()) as db_session:
personas = get_personas(
user_id=None, db_session=db_session, include_default=False
)

if not personas:
respond_in_thread(
client=client,
channel=channel,
text="No personas are available.",
thread_ts=message_ts_to_respond_to,
)
return

buttons = [
{
"type": "button",
"text": {"type": "plain_text", "text": persona.name},
"value": str(persona.id),
"action_id": f"set_persona_{persona.id}",
}
for persona in personas
]

blocks = [
{
"type": "section",
"text": {
"type": "mrkdwn",
"text": "Here are the available personas. Click on one to set it:",
},
},
{"type": "actions", "elements": buttons},
]

respond_in_thread(
client=client,
channel=channel,
blocks=blocks,
thread_ts=message_ts_to_respond_to,
)
return
elif command == "/current_persona":
if persona_name is None:
respond_in_thread(
client=client,
channel=channel,
text="No persona is set. Please use the /personas command to set up a persona.",
thread_ts=message_ts_to_respond_to,
)
return
else:
respond_in_thread(
client=client,
channel=channel,
text=f"Current persona : {persona_name}",
thread_ts=message_ts_to_respond_to,
)
return

document_set_names: list[str] | None = None
persona = channel_config.persona if channel_config else None
prompt = None
if persona:
document_set_names = [
Expand Down
7 changes: 7 additions & 0 deletions backend/danswer/danswerbot/slack/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
handle_followup_resolved_button,
)
from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback
from danswer.danswerbot.slack.handlers.handle_buttons import handle_persona_selection
from danswer.danswerbot.slack.handlers.handle_message import handle_message
from danswer.danswerbot.slack.handlers.handle_message import (
remove_scheduled_feedback_reminder,
Expand Down Expand Up @@ -277,6 +278,7 @@ def build_request_details(
channel = req.payload["channel_id"]
msg = req.payload["text"]
sender = req.payload["user_id"]
command = req.payload["command"]

single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)

Expand All @@ -288,6 +290,7 @@ def build_request_details(
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
command=command
)

raise RuntimeError("Programming fault, this should never happen.")
Expand Down Expand Up @@ -356,6 +359,7 @@ def process_message(
channel_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
channel_name = channel_name
)

if failed:
Expand Down Expand Up @@ -391,6 +395,9 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
return handle_followup_resolved_button(req, client, immediate=True)
elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID:
return handle_followup_resolved_button(req, client, immediate=False)
elif action["action_id"].startswith("set_persona_"):
# Persona selection
return handle_persona_selection(req, client)


def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/danswerbot/slack/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class SlackMessageInfo(BaseModel):
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot
is_bot_dm: bool # User is direct messaging to DanswerBot
command: str | None # Slash command used by user
10 changes: 10 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,16 @@ class SlackBotConfig(Base):
persona: Mapped[Persona | None] = relationship("Persona")


class UserSlackPersona(Base):
__tablename__ = "user_slack_persona"

sender_id: Mapped[str] = mapped_column(primary_key=True)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
persona: Mapped[Persona | None] = relationship("Persona", foreign_keys=[persona_id])


class TaskQueueState(Base):
# Currently refers to Celery Tasks
__tablename__ = "task_queue_jobs"
Expand Down
12 changes: 12 additions & 0 deletions backend/danswer/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.orm import joinedload

from danswer.auth.schemas import UserRole
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
Expand Down Expand Up @@ -621,3 +622,14 @@ def delete_persona_by_name(
db_session.execute(stmt)

db_session.commit()


def get_persona_with_docset_and_prompts(
persona_id: int, db_session: Session
) -> Persona | None:
persona = db_session.scalar(
select(Persona)
.options(joinedload(Persona.document_sets), joinedload(Persona.prompts))
.filter_by(id=persona_id, deleted=False)
)
return persona
30 changes: 30 additions & 0 deletions backend/danswer/db/users.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from collections.abc import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.schema import Column

from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.models import UserSlackPersona


def list_users(db_session: Session, q: str = "") -> Sequence[User]:
Expand All @@ -19,3 +22,30 @@ def get_user_by_email(email: str, db_session: Session) -> User | None:
user = db_session.query(User).filter(User.email == email).first() # type: ignore

return user


def fetch_user_slack_persona(
db_session: Session, sender_id: str
) -> UserSlackPersona | None:
return db_session.scalar(
select(UserSlackPersona).where(UserSlackPersona.sender_id == sender_id)
)


def add_user_slack_persona(
db_session: Session, sender_id: str, persona: Persona
) -> None:
user_persona = UserSlackPersona(
sender_id=sender_id, persona_id=persona.id, persona=persona
)
db_session.add(user_persona)
db_session.commit()


def add_slack_persona_for_user(
db_session: Session, persona: Persona, user_slack_persona: UserSlackPersona
) -> None:
user_slack_persona.persona_id = persona.id
user_slack_persona.persona = persona

db_session.commit()
2 changes: 1 addition & 1 deletion backend/danswer/llm/custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
identity_url: str | None = GEN_AI_IDENTITY_ENDPOINT,
client_id: str | None = GEN_AI_CLIENT_ID,
client_secret: str | None = GEN_AI_CLIENT_SECRET,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
max_output_tokens: int = int(GEN_AI_MAX_OUTPUT_TOKENS),
api_version: str | None = GEN_AI_API_VERSION,
):

Expand Down

0 comments on commit 8c0531e

Please sign in to comment.