diff --git a/backend/alembic/versions/792d1af3dc44_create_table_user_slack_persona.py b/backend/alembic/versions/792d1af3dc44_create_table_user_slack_persona.py new file mode 100644 index 00000000000..458cca9d0df --- /dev/null +++ b/backend/alembic/versions/792d1af3dc44_create_table_user_slack_persona.py @@ -0,0 +1,33 @@ +"""Create table user_slack_persona + +Revision ID: 792d1af3dc44 +Revises: 3a7802814195 +Create Date: 2025-01-24 04:26:02.844951 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "792d1af3dc44" +down_revision = "3a7802814195" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "user_slack_persona", + sa.Column("sender_id", sa.String(), nullable=False), + sa.Column("persona_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.PrimaryKeyConstraint("sender_id"), + ) + + +def downgrade() -> None: + op.drop_table("user_slack_persona") \ No newline at end of file diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index f4b0b2e02c5..c02e49efd44 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -114,27 +114,29 @@ def combine_message_chain( return "\n\n".join(message_strs) -def reorganize_citations( - answer: str, citations: list[CitationInfo] -) -> tuple[str, list[CitationInfo]]: - """For a complete, citation-aware response, we want to reorganize the citations so that +def reorganize_citations(answer: str, citations: list) -> tuple[str, list]: + """ + For a complete, citation-aware response, we want to reorganize the citations so that they are in the order of the documents that were used in the response. This just looks nicer / avoids - confusion ("Why is there [7] when only 2 documents are cited?").""" + confusion ("Why is there [7] when only 2 documents are cited?"). 
- # Regular expression to find all instances of [[x]](LINK) - pattern = r"\[\[(.*?)\]\]\((.*?)\)" + Now also handles citations in the format [number] in addition to [[number]](LINK). + """ + + pattern = r"\[\[(\d+)\]\]\((.*?)\)|\[(\d+)\]" all_citation_matches = re.findall(pattern, answer) new_citation_info: dict[int, CitationInfo] = {} for citation_match in all_citation_matches: try: - citation_num = int(citation_match[0]) + citation_str = citation_match[0] if citation_match[0] else citation_match[2] + citation_num = int(citation_str) if citation_num in new_citation_info: continue matching_citation = next( - iter([c for c in citations if c.citation_num == int(citation_num)]), + (c for c in citations if c.citation_num == citation_num), None, ) if matching_citation is None: @@ -146,19 +148,30 @@ def reorganize_citations( ) except Exception: pass - # Function to replace citations with their new number def slack_link_format(match: re.Match) -> str: - link_text = match.group(1) - try: - citation_num = int(link_text) - if citation_num in new_citation_info: - link_text = new_citation_info[citation_num].citation_num - except Exception: - pass - - link_url = match.group(2) - return f"[[{link_text}]]({link_url})" + # Case 1: Linked citation ([[number]](LINK)) + if match.group(1): + link_text = match.group(1) + try: + citation_num = int(link_text) + if citation_num in new_citation_info: + link_text = new_citation_info[citation_num].citation_num + except Exception: + pass + link_url = match.group(2) + return f"[[{link_text}]]({link_url})" + # Case 2: Non-linked citation ([number]) + elif match.group(3): + try: + citation_num = int(match.group(3)) + if citation_num in new_citation_info: + citation_num = new_citation_info[citation_num].citation_num + except Exception: + pass + return f"[{citation_num}]" + else: + return match.group(0) # Substitute all matches in the input text new_answer = re.sub(pattern, slack_link_format, answer) diff --git 
a/backend/danswer/connectors/slack/utils.py b/backend/danswer/connectors/slack/utils.py index 21bae6571d8..7a37b65def4 100644 --- a/backend/danswer/connectors/slack/utils.py +++ b/backend/danswer/connectors/slack/utils.py @@ -278,3 +278,9 @@ def replace_special_catchall(message: str) -> str: def add_zero_width_whitespace_after_tag(message: str) -> str: """Add a 0 width whitespace after every @""" return message.replace("@", "@\u200B") + + @staticmethod + def handle_bold_syntax_for_slack(text: str) -> str: + """ Replace instances of '**' with a single '*'""" + corrected_text = text.replace('**', '*') + return corrected_text diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index 3a0209b076f..41f015c4d1c 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -32,10 +32,16 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.feedback import create_chat_message_feedback from danswer.db.feedback import create_doc_retrieval_feedback +from danswer.db.persona import fetch_persona_by_id +from danswer.db.users import fetch_user_slack_persona +from danswer.db.users import add_user_slack_persona +from danswer.db.users import add_slack_persona_for_user from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.utils.logger import setup_logger +from sqlalchemy.orm.exc import NoResultFound + logger_base = setup_logger() @@ -293,3 +299,65 @@ def handle_followup_resolved_button( thread_ts=thread_ts, unfurl=False, ) + + +def handle_persona_selection(req: SocketModeRequest, client: SocketModeClient) -> None: + action = cast(dict[str, Any], req.payload.get("actions", [])[0]) + user_id = req.payload["user"]["id"] + channel_id = req.payload["container"]["channel_id"] + persona_id = action.get( + 
"value" + ) + message_ts_to_respond_to = req.payload.get("container", {}).get("thread_ts") + + with Session(get_sqlalchemy_engine()) as db_session: + try: + persona = fetch_persona_by_id(db_session=db_session, persona_id=persona_id) + + if persona is None: + respond_in_thread( + client=client.web_client, + channel=channel_id, + text=f"Persona not found.", + thread_ts=message_ts_to_respond_to, + ) + return + + user_slack_persona = fetch_user_slack_persona( + db_session=db_session, sender_id=user_id + ) + if user_slack_persona: + add_slack_persona_for_user( + db_session=db_session, + persona=persona, + user_slack_persona=user_slack_persona, + ) + response_text = f"Persona '{persona.name}' has been set!\n" + respond_in_thread( + client=client.web_client, + channel=channel_id, + text=response_text, + thread_ts=message_ts_to_respond_to, + ) + return + + else: + add_user_slack_persona( + db_session=db_session, sender_id=user_id, persona=persona + ) + respond_in_thread( + client=client.web_client, + channel=channel_id, + text=f"'{persona.name}' has been successfully set as the current persona.", + thread_ts=message_ts_to_respond_to, + ) + return + + except NoResultFound: + respond_in_thread( + client=client.web_client, + channel=channel_id, + text="Error in fetching persona", + thread_ts=message_ts_to_respond_to, + ) + return diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 85ddfe70a6c..b33f63dc699 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -2,6 +2,7 @@ import functools import logging from collections.abc import Callable +import re from typing import Any from typing import cast from typing import Optional @@ -47,6 +48,9 @@ from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType from danswer.db.persona import fetch_persona_by_id +from 
danswer.db.persona import get_persona_with_docset_and_prompts +from danswer.db.persona import get_personas +from danswer.db.users import fetch_user_slack_persona from danswer.llm.answering.prompts.citations_prompt import ( compute_max_document_tokens_for_persona, ) @@ -172,11 +176,23 @@ def remove_scheduled_feedback_reminder( ) +def contains_questionmark_outside_links(message: str) -> bool: + """ + Checks if the message contains a question mark outside of URLs. + """ + url_pattern = r"<[^>]+>|https?://\S+" + + message_without_links = re.sub(url_pattern, "", message) + + return "?" in message_without_links + + def handle_message( message_info: SlackMessageInfo, channel_config: SlackBotConfig | None, client: WebClient, feedback_reminder_id: str | None, + channel_name: str | None, num_retries: int = DANSWER_BOT_NUM_RETRIES, answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT, should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS, @@ -206,9 +222,88 @@ def handle_message( bypass_filters = message_info.bypass_filters is_bot_msg = message_info.is_bot_msg is_bot_dm = message_info.is_bot_dm + persona_name = None + + if channel_name is None: + with Session(get_sqlalchemy_engine()) as db_session: + user_slack_persona = fetch_user_slack_persona( + db_session=db_session, sender_id=sender_id + ) + if user_slack_persona: + slack_persona_id = user_slack_persona.persona_id or None + persona = get_persona_with_docset_and_prompts( + persona_id=slack_persona_id, db_session=db_session + ) + persona_name = persona.name + else: + persona = None + else: + persona = channel_config.persona if channel_config else None + + if is_bot_msg: + command = message_info.command + if command == "/personas": + with Session(get_sqlalchemy_engine()) as db_session: + personas = get_personas( + user_id=None, db_session=db_session, include_default=False + ) + + if not personas: + respond_in_thread( + client=client, + channel=channel, + text="No personas are available.", +
thread_ts=message_ts_to_respond_to, + ) + return + + buttons = [ + { + "type": "button", + "text": {"type": "plain_text", "text": persona.name}, + "value": str(persona.id), + "action_id": f"set_persona_{persona.id}", + } + for persona in personas + ] + + blocks = [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "Here are the available personas. Click on one to set it:", + }, + }, + {"type": "actions", "elements": buttons}, + ] + + respond_in_thread( + client=client, + channel=channel, + blocks=blocks, + thread_ts=message_ts_to_respond_to, + ) + return + elif command == "/current_persona": + if persona_name is None: + respond_in_thread( + client=client, + channel=channel, + text="No persona is set. Please use the /personas command to set up a persona.", + thread_ts=message_ts_to_respond_to, + ) + return + else: + respond_in_thread( + client=client, + channel=channel, + text=f"Current persona : {persona_name}", + thread_ts=message_ts_to_respond_to, + ) + return document_set_names: list[str] | None = None - persona = channel_config.persona if channel_config else None prompt = None if persona: document_set_names = [ @@ -249,7 +344,7 @@ def handle_message( if ( "questionmark_prefilter" in channel_conf["answer_filters"] - and "?" not in messages[-1].message + and not contains_questionmark_outside_links(messages[-1].message) ): logger.info( "Skipping message since it does not contain a question mark" @@ -487,21 +582,22 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non except SlackApiError as e: logger.error(f"Failed to remove Reaction due to: {e}") - if answer.answer_valid is False: - logger.info( - "Answer was evaluated to be invalid, throwing it away without responding." 
- ) - update_emote_react( - emoji=DANSWER_FOLLOWUP_EMOJI, - channel=message_info.channel_to_respond, - message_ts=message_info.msg_to_respond, - remove=False, - client=client, - ) - - if answer.answer: - logger.debug(answer.answer) - return True + # Disabled: invalid answers are no longer discarded here; answers without cited documents are rejected further down instead. + # if answer.answer_valid is False: + # logger.info( + # "Answer was evaluated to be invalid, throwing it away without responding." + # ) + # update_emote_react( + # emoji=DANSWER_FOLLOWUP_EMOJI, + # channel=message_info.channel_to_respond, + # message_ts=message_info.msg_to_respond, + # remove=False, + # client=client, + # ) + + # if answer.answer: + # logger.debug(answer.answer) + # return True retrieval_info = answer.docs if not retrieval_info: @@ -571,6 +667,15 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non if matching_doc: cited_docs.append((citation.citation_num, matching_doc)) + if not cited_docs: + respond_in_thread( + client=client, + channel=channel, + text="Could not generate an answer due to a lack of relevant documents. 
Please try refining your search query with more context.", + thread_ts=message_ts_to_respond_to, + ) + return False + cited_docs.sort() citations_block = build_sources_blocks(cited_documents=cited_docs) elif priority_ordered_docs: diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index f8dfb600211..35633f17f6d 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -1,3 +1,4 @@ +import re import time from threading import Event from typing import Any @@ -28,6 +29,7 @@ handle_followup_resolved_button, ) from danswer.danswerbot.slack.handlers.handle_buttons import handle_slack_feedback +from danswer.danswerbot.slack.handlers.handle_buttons import handle_persona_selection from danswer.danswerbot.slack.handlers.handle_message import handle_message from danswer.danswerbot.slack.handlers.handle_message import ( remove_scheduled_feedback_reminder, @@ -93,6 +95,10 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool if not msg: channel_specific_logger.error("Cannot respond to empty message - skipping") return False + + if re.search(r"!darwin", msg, re.IGNORECASE): + channel_specific_logger.info("Ignoring message containing '!darwin'") + return False if ( req.payload.setdefault("event", {}).get("user", "") @@ -197,6 +203,22 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool "Cannot respond to DanswerBot command without sender to respond to." 
) return False + + #Do not respond to messages if the channel is tagged + payload = req.payload + event = payload.get("event", {}) + blocks = event.get("blocks", []) + for block in blocks: + if block.get("type") == "rich_text": + for element in block.get("elements", []): + if element.get("type") == "rich_text_section": + for sub_element in element.get("elements", []): + if ( + sub_element.get("type") == "broadcast" + and sub_element.get("range") in {"channel", "here"} + ): + logger.info("Broadcast message detected; skipping reply.") + return False logger.debug(f"Handling Slack request with Payload: '{req.payload}'") return True @@ -277,6 +299,7 @@ def build_request_details( channel = req.payload["channel_id"] msg = req.payload["text"] sender = req.payload["user_id"] + command = req.payload["command"] single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER) @@ -288,6 +311,7 @@ def build_request_details( bypass_filters=True, is_bot_msg=True, is_bot_dm=False, + command=command ) raise RuntimeError("Programming fault, this should never happen.") @@ -356,6 +380,7 @@ def process_message( channel_config=slack_bot_config, client=client.web_client, feedback_reminder_id=feedback_reminder_id, + channel_name = channel_name ) if failed: @@ -391,6 +416,9 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None: return handle_followup_resolved_button(req, client, immediate=True) elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID: return handle_followup_resolved_button(req, client, immediate=False) + elif action["action_id"].startswith("set_persona_"): + # Persona selection + return handle_persona_selection(req, client) def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None: diff --git a/backend/danswer/danswerbot/slack/models.py b/backend/danswer/danswerbot/slack/models.py index 57a92a29753..375b92d364c 100644 --- a/backend/danswer/danswerbot/slack/models.py +++ b/backend/danswer/danswerbot/slack/models.py 
@@ -11,3 +11,4 @@ class SlackMessageInfo(BaseModel): bypass_filters: bool # User has tagged @DanswerBot is_bot_msg: bool # User is using /DanswerBot is_bot_dm: bool # User is direct messaging to DanswerBot + command: str | None # Slash command used by user diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index 3132aa6f24a..481f49c6b59 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -275,6 +275,7 @@ def remove_slack_text_interactions(slack_str: str) -> str: slack_str = SlackTextCleaner.replace_links(slack_str) slack_str = SlackTextCleaner.replace_special_catchall(slack_str) slack_str = SlackTextCleaner.add_zero_width_whitespace_after_tag(slack_str) + slack_str = SlackTextCleaner.handle_bold_syntax_for_slack(slack_str) return slack_str diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 909236a978b..cf1fdaa7ed9 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1108,6 +1108,16 @@ class SlackBotConfig(Base): persona: Mapped[Persona | None] = relationship("Persona") +class UserSlackPersona(Base): + __tablename__ = "user_slack_persona" + + sender_id: Mapped[str] = mapped_column(primary_key=True) + persona_id: Mapped[int | None] = mapped_column( + ForeignKey("persona.id"), nullable=True + ) + persona: Mapped[Persona | None] = relationship("Persona", foreign_keys=[persona_id]) + + class TaskQueueState(Base): # Currently refers to Celery Tasks __tablename__ = "task_queue_jobs" diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 4726cf42637..08662dc5873 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -10,6 +10,7 @@ from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session +from sqlalchemy.orm import joinedload from danswer.auth.schemas import UserRole from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX @@ 
-621,3 +622,14 @@ def delete_persona_by_name( db_session.execute(stmt) db_session.commit() + + +def get_persona_with_docset_and_prompts( + persona_id: int, db_session: Session +) -> Persona | None: + persona = db_session.scalar( + select(Persona) + .options(joinedload(Persona.document_sets), joinedload(Persona.prompts)) + .filter_by(id=persona_id, deleted=False) + ) + return persona diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index f8a3938027f..34e3ada7d4a 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -1,9 +1,12 @@ from collections.abc import Sequence +from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.schema import Column +from danswer.db.models import Persona from danswer.db.models import User +from danswer.db.models import UserSlackPersona def list_users(db_session: Session, q: str = "") -> Sequence[User]: @@ -19,3 +22,30 @@ def get_user_by_email(email: str, db_session: Session) -> User | None: user = db_session.query(User).filter(User.email == email).first() # type: ignore return user + + +def fetch_user_slack_persona( + db_session: Session, sender_id: str +) -> UserSlackPersona | None: + return db_session.scalar( + select(UserSlackPersona).where(UserSlackPersona.sender_id == sender_id) + ) + + +def add_user_slack_persona( + db_session: Session, sender_id: str, persona: Persona +) -> None: + user_persona = UserSlackPersona( + sender_id=sender_id, persona_id=persona.id, persona=persona + ) + db_session.add(user_persona) + db_session.commit() + + +def add_slack_persona_for_user( + db_session: Session, persona: Persona, user_slack_persona: UserSlackPersona +) -> None: + user_slack_persona.persona_id = persona.id + user_slack_persona.persona = persona + + db_session.commit() diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 6a250d02d2a..694764a6c0a 100644 --- a/backend/danswer/llm/answering/answer.py +++ 
b/backend/danswer/llm/answering/answer.py @@ -1,4 +1,5 @@ from collections.abc import Iterator +import re from typing import cast from uuid import uuid4 @@ -382,6 +383,47 @@ def _raw_output_for_non_explicit_tool_calling_llms( prompt = prompt_builder.build() yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt)) + def _fix_document_references(self, answer: str) -> str: + """ + Searches the input string for DOCUMENT references in any of these forms: + - DOCUMENT <number> (link) + - [DOCUMENT <number>] (link) + - DOCUMENT <number> + - [DOCUMENT <number>] + + and converts them to the proper citation format: + - If a link is provided, returns a linked citation: [[number]](link) + - Otherwise, returns a non-linked citation: [number] + + However, if an adjacent citation (linked or non-linked) for the same number already follows + immediately (ignoring whitespace), the DOCUMENT reference is not converted (i.e. it is removed) + to avoid duplicate citations. + """ + pattern = r"\[?DOCUMENT\s+(\d+)\]?(?:\s*\((.*?)\))?" + + def replacer(match: re.Match) -> str: + try: + num = int(match.group(1)) + except Exception: + return match.group(0) + + if match.group(2) and match.group(2).strip(): + citation = f"[[{num}]]({match.group(2).strip()})" + else: + citation = f"[{num}]" + + post_text = answer[match.end():] + adj_pattern = ( + r"^\s*(\[\[\s*" + re.escape(str(num)) + r"\s*\]\]\([^)]+\)|\[\s*" + re.escape(str(num)) + r"\s*\])" + ) + if re.match(adj_pattern, post_text): + # If an adjacent citation for the same number exists, return an empty string (skip replacement). 
+ return "" + else: + return citation + + return re.sub(pattern, replacer, answer) + @property def processed_streamed_output(self) -> AnswerStream: if self._processed_stream is not None: @@ -465,7 +507,7 @@ def llm_answer(self) -> str: if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: answer += packet.answer_piece - return answer + return self._fix_document_references(answer) @property def citations(self) -> list[CitationInfo]: diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index 0b9924b2bed..c692173dc43 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -72,11 +72,11 @@ def __init__( # Not used here but you probably want a model server that isn't completely open api_key: str | None, timeout: int, - endpoint: str | None = 'https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview', + endpoint: str | None = 'https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-4o-mini-2024-07-18/chat/completions?api-version=2024-06-01', identity_url: str | None = GEN_AI_IDENTITY_ENDPOINT, client_id: str | None = GEN_AI_CLIENT_ID, client_secret: str | None = GEN_AI_CLIENT_SECRET, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + max_output_tokens: int = int(GEN_AI_MAX_OUTPUT_TOKENS), api_version: str | None = GEN_AI_API_VERSION, ): diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index b315c3662c7..62a0439758d 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -310,49 +310,66 @@ def get_search_answer( rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, ) -> OneShotQAResponse: """Collects the streamed one shot answer responses into a single object""" - qa_response = OneShotQAResponse() + max_attempts = 5 + attempt = 0 + qa_response = None - results = 
stream_answer_objects( - query_req=query_req, - user=user, - max_document_tokens=max_document_tokens, - max_history_tokens=max_history_tokens, - db_session=db_session, - bypass_acl=bypass_acl, - use_citations=use_citations, - danswerbot_flow=danswerbot_flow, - timeout=answer_generation_timeout, - retrieval_metrics_callback=retrieval_metrics_callback, - rerank_metrics_callback=rerank_metrics_callback, - ) + while attempt < max_attempts: + qa_response = OneShotQAResponse() - answer = "" - for packet in results: - if isinstance(packet, QueryRephrase): - qa_response.rephrase = packet.rephrased_query - if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: - answer += packet.answer_piece - elif isinstance(packet, QADocsResponse): - qa_response.docs = packet - elif isinstance(packet, LLMRelevanceFilterResponse): - qa_response.llm_chunks_indices = packet.relevant_chunk_indices - elif isinstance(packet, DanswerQuotes): - qa_response.quotes = packet - elif isinstance(packet, CitationInfo): - if qa_response.citations: - qa_response.citations.append(packet) - else: - qa_response.citations = [packet] - elif isinstance(packet, DanswerContexts): - qa_response.contexts = packet - elif isinstance(packet, StreamingError): - qa_response.error_msg = packet.error - elif isinstance(packet, ChatMessageDetail): - qa_response.chat_message_id = packet.message_id - - if answer: - qa_response.answer = answer + results = stream_answer_objects( + query_req=query_req, + user=user, + max_document_tokens=max_document_tokens, + max_history_tokens=max_history_tokens, + db_session=db_session, + bypass_acl=bypass_acl, + use_citations=use_citations, + danswerbot_flow=danswerbot_flow, + timeout=answer_generation_timeout, + retrieval_metrics_callback=retrieval_metrics_callback, + rerank_metrics_callback=rerank_metrics_callback, + ) + answer = "" + for packet in results: + if isinstance(packet, QueryRephrase): + qa_response.rephrase = packet.rephrased_query + if isinstance(packet, 
DanswerAnswerPiece) and packet.answer_piece: + answer += packet.answer_piece + elif isinstance(packet, QADocsResponse): + qa_response.docs = packet + elif isinstance(packet, LLMRelevanceFilterResponse): + qa_response.llm_chunks_indices = packet.relevant_chunk_indices + elif isinstance(packet, DanswerQuotes): + qa_response.quotes = packet + elif isinstance(packet, CitationInfo): + if qa_response.citations: + qa_response.citations.append(packet) + else: + qa_response.citations = [packet] + elif isinstance(packet, DanswerContexts): + qa_response.contexts = packet + elif isinstance(packet, StreamingError): + qa_response.error_msg = packet.error + elif isinstance(packet, ChatMessageDetail): + qa_response.chat_message_id = packet.message_id + + if answer: + qa_response.answer = answer + + if use_citations and qa_response.answer and qa_response.citations: + qa_response.answer, qa_response.citations = reorganize_citations( + qa_response.answer, qa_response.citations + ) + break # Citations found, break out of retry loop. + # If citations are not required, we can exit immediately. + elif not use_citations: + break + + logger.info(f"Citations not found, retrying... 
(attempt {attempt + 1}/{max_attempts})") + attempt += 1 + if enable_reflexion: # Because follow up messages are explicitly tagged, we don't need to verify the answer if len(query_req.messages) == 1: @@ -361,10 +378,4 @@ def get_search_answer( else: qa_response.answer_valid = True - if use_citations and qa_response.answer and qa_response.citations: - # Reorganize citation nums to be in the same order as the answer - qa_response.answer, qa_response.citations = reorganize_citations( - qa_response.answer, qa_response.citations - ) - return qa_response diff --git a/backend/danswer/prompts/chat_prompts.py b/backend/danswer/prompts/chat_prompts.py index bf2a9d8cdc7..2395674290d 100644 --- a/backend/danswer/prompts/chat_prompts.py +++ b/backend/danswer/prompts/chat_prompts.py @@ -2,7 +2,7 @@ from danswer.prompts.constants import QUESTION_PAT REQUIRE_CITATION_STATEMENT = """ -Cite relevant statements INLINE using the format [1], [2], [3], etc to reference the document number, \ +CRUCIAL: Cite relevant statements INLINE using the format [1], [2], [3], etc to reference the document number, \ DO NOT provide a reference section at the end and DO NOT provide any links following the citations. """.rstrip() @@ -11,7 +11,7 @@ """.rstrip() CITATION_REMINDER = """ -Remember to provide inline citations in the format [1], [2], [3], etc. +MANDATORY: For any information sourced from documents in your response, include inline citations formatted as [1], [2], [3], etc. """ ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}." 
diff --git a/darwin-kubernetes/env-configmap.yaml b/darwin-kubernetes/env-configmap.yaml index e796c046e94..d5dd301ffed 100644 --- a/darwin-kubernetes/env-configmap.yaml +++ b/darwin-kubernetes/env-configmap.yaml @@ -15,7 +15,7 @@ data: EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead # Gen AI Settings GEN_AI_MODEL_PROVIDER: "custom" - GEN_AI_API_ENDPOINT: "https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview" + GEN_AI_API_ENDPOINT: "https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-4o-mini-2024-07-18/chat/completions?api-version=2024-06-01" GEN_AI_IDENTITY_ENDPOINT: "https://alpha.uipath.com/identity_/connect/token" GEN_AI_CLIENT_ID: "XXX" GEN_AI_CLIENT_SECRET: "XXX" diff --git a/deployment/docker_compose/docker-compose.local.yml b/deployment/docker_compose/docker-compose.local.yml index 7583b09b836..99179c0747f 100644 --- a/deployment/docker_compose/docker-compose.local.yml +++ b/deployment/docker_compose/docker-compose.local.yml @@ -39,7 +39,7 @@ services: - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview + - GEN_AI_API_ENDPOINT=https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-4o-mini-2024-07-18/chat/completions?api-version=2024-06-01 - GEN_AI_IDENTITY_ENDPOINT=https://alpha.uipath.com/identity_/connect/token - GEN_AI_CLIENT_ID=${GEN_AI_CLIENT_ID:-} - GEN_AI_CLIENT_SECRET=${GEN_AI_CLIENT_SECRET:-} @@ -128,7 +128,7 @@ services: - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-custom} - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview + - 
GEN_AI_API_ENDPOINT=https://alpha.uipath.com/llmgateway_/openai/deployments/gpt-4o-mini-2024-07-18/chat/completions?api-version=2024-06-01 - GEN_AI_IDENTITY_ENDPOINT=https://alpha.uipath.com/identity_/connect/token - GEN_AI_CLIENT_ID=${GEN_AI_CLIENT_ID:-} - GEN_AI_CLIENT_SECRET=${GEN_AI_CLIENT_SECRET:-}