Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DatabricksRM compatible with Mosaic agent framework #1800

Merged
merged 1 commit into from
Dec 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions dspy/retrieve/databricks_rm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Any, Dict, List, Optional, Union

Expand All @@ -11,6 +12,19 @@
_databricks_sdk_installed = find_spec("databricks.sdk") is not None


@dataclass
class Document:
page_content: str
metadata: Dict[str, Any]
type: str

def to_dict(self) -> Dict[str, Any]:
return {
"page_content": self.page_content,
"metadata": self.metadata,
"type": self.type,
}

class DatabricksRM(dspy.Retrieve):
"""
A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
Expand Down Expand Up @@ -76,6 +90,7 @@ def __init__(
k: int = 3,
docs_id_column_name: str = "id",
text_column_name: str = "text",
use_with_databricks_agent_framework: bool = False,
):
"""
Args:
Expand All @@ -100,6 +115,8 @@ def __init__(
containing document IDs.
text_column_name (str): The name of the column in the Databricks Vector Search Index
containing document text to retrieve.
use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is
compatible with the Databricks Mosaic Agent Framework.
"""
super().__init__(k=k)
self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
Expand All @@ -119,6 +136,20 @@ def __init__(
self.k = k
self.docs_id_column_name = docs_id_column_name
self.text_column_name = text_column_name
self.use_with_databricks_agent_framework = use_with_databricks_agent_framework
if self.use_with_databricks_agent_framework:
try:
import mlflow
mlflow.models.set_retriever_schema(
primary_key="doc_id",
text_column="page_content",
doc_uri="doc_uri",
)
except ImportError:
raise ValueError(
"To use the `DatabricksRM` retriever module with the Databricks Mosaic Agent Framework, "
"you must install the mlflow Python library. Please install mlflow via `pip install mlflow`."
)

def _extract_doc_ids(self, item: Dict[str, Any]) -> str:
"""Extracts the document id from a search result
Expand Down Expand Up @@ -154,7 +185,7 @@ def forward(
query: Union[str, List[float]],
query_type: str = "ANN",
filters_json: Optional[str] = None,
) -> dspy.Prediction:
) -> Union[dspy.Prediction, List[Dict[str, Any]]]:
"""
Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
specified query.
Expand All @@ -172,7 +203,9 @@ def forward(
parameter overrides the `filters_json` parameter passed to the constructor.

Returns:
dspy.Prediction: An object containing the retrieved results.
A list of dictionaries when ``use_with_databricks_agent_framework`` is ``True``,
or a ``dspy.Prediction`` object when ``use_with_databricks_agent_framework`` is
``False``.
"""
if query_type in ["vector", "text"]:
# Older versions of DSPy used a `query_type` argument to disambiguate between text
Expand Down Expand Up @@ -239,12 +272,24 @@ def forward(
# Sorting results by score in descending order
sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]

# Returning the prediction
return Prediction(
docs=[doc[self.text_column_name] for doc in sorted_docs],
doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
)
if self.use_with_databricks_agent_framework:
return [Document(
page_content=doc[self.text_column_name],
metadata={
"doc_id": self._extract_doc_ids(doc),
"doc_uri": f"index/{self.databricks_index_name}/id/{self._extract_doc_ids(doc)}",
}
| self._get_extra_columns(doc),
type="Document",
).to_dict() for doc in sorted_docs]
else:
# Returning the prediction
return Prediction(
docs=[doc[self.text_column_name] for doc in sorted_docs],
doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
)


@staticmethod
def _query_via_databricks_sdk(
Expand Down
Loading