diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py
index 153c4e9260..70ee8b79a7 100644
--- a/dspy/retrieve/databricks_rm.py
+++ b/dspy/retrieve/databricks_rm.py
@@ -1,5 +1,6 @@
 import json
 import os
+from dataclasses import dataclass
 from importlib.util import find_spec
 from typing import Any, Dict, List, Optional, Union
 
@@ -11,6 +12,19 @@
 _databricks_sdk_installed = find_spec("databricks.sdk") is not None
 
 
+@dataclass
+class Document:
+    page_content: str
+    metadata: Dict[str, Any]
+    type: str = "Document"
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "page_content": self.page_content,
+            "metadata": self.metadata,
+            "type": self.type,
+        }
+
 class DatabricksRM(dspy.Retrieve):
     """
     A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
@@ -76,6 +90,7 @@ def __init__(
         k: int = 3,
         docs_id_column_name: str = "id",
         text_column_name: str = "text",
+        use_with_databricks_agent_framework: bool = False,
     ):
         """
         Args:
@@ -100,6 +115,8 @@ def __init__(
                 containing document IDs.
             text_column_name (str): The name of the column in the Databricks Vector Search Index
                 containing document text to retrieve.
+            use_with_databricks_agent_framework (bool): If True (Databricks Mosaic Agent Framework mode, requires
+                `mlflow`), `forward` returns a list of document dicts instead of a `dspy.Prediction`.
         """
         super().__init__(k=k)
         self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
@@ -119,6 +136,20 @@ def __init__(
         self.k = k
         self.docs_id_column_name = docs_id_column_name
         self.text_column_name = text_column_name
+        self.use_with_databricks_agent_framework = use_with_databricks_agent_framework
+        if self.use_with_databricks_agent_framework:
+            try:
+                import mlflow
+                mlflow.models.set_retriever_schema(
+                    primary_key="doc_id",
+                    text_column="page_content",
+                    doc_uri="doc_uri",
+                )
+            except ImportError as err:
+                raise ValueError(
+                    "To use the `DatabricksRM` retriever module with the Databricks Mosaic Agent Framework, "
+                    "you must install the mlflow Python library. Please install mlflow via `pip install mlflow`."
+                ) from err
 
     def _extract_doc_ids(self, item: Dict[str, Any]) -> str:
         """Extracts the document id from a search result
@@ -154,7 +185,7 @@ def forward(
         query: Union[str, List[float]],
         query_type: str = "ANN",
         filters_json: Optional[str] = None,
-    ) -> dspy.Prediction:
+    ) -> Union[dspy.Prediction, List[Dict[str, Any]]]:
         """
         Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
         specified query.
@@ -239,12 +270,31 @@ def forward(
         # Sorting results by score in descending order
         sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]
 
-        # Returning the prediction
-        return Prediction(
-            docs=[doc[self.text_column_name] for doc in sorted_docs],
-            doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
-            extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
-        )
+        if self.use_with_databricks_agent_framework:
+            # Agent Framework mode: return plain dicts whose field names (doc_id,
+            # page_content, doc_uri) are the ones passed to
+            # mlflow.models.set_retriever_schema in __init__.
+            return [
+                Document(
+                    page_content=doc[self.text_column_name],
+                    metadata={
+                        "doc_id": self._extract_doc_ids(doc),
+                        "doc_uri": f"index/{self.databricks_index_name}/id/{self._extract_doc_ids(doc)}",
+                    }
+                    | self._get_extra_columns(doc),
+                    # `type` is a required Document field; leaving it out makes the
+                    # dataclass constructor raise TypeError on every call.
+                    type="Document",
+                ).to_dict()
+                for doc in sorted_docs
+            ]
+        else:
+            # Returning the prediction
+            return Prediction(
+                docs=[doc[self.text_column_name] for doc in sorted_docs],
+                doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
+                extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
+            )
 
     @staticmethod
     def _query_via_databricks_sdk(