diff --git a/dspy/__init__.py b/dspy/__init__.py
index fea48caca..524f4d619 100644
--- a/dspy/__init__.py
+++ b/dspy/__init__.py
@@ -3,8 +3,7 @@
 from dspy.retrieve import *
 from dspy.signatures import *
 from dspy.teleprompt import *
-
-import dspy.retrievers
+from dspy.retrievers import *
 
 from dspy.evaluate import Evaluate  # isort: skip
 from dspy.clients import *  # isort: skip
@@ -27,6 +26,7 @@
 import dspy.teleprompt
 
+ColBERTv2 = ColBERTv2
 LabeledFewShot = dspy.teleprompt.LabeledFewShot
 BootstrapFewShot = dspy.teleprompt.BootstrapFewShot
 BootstrapFewShotWithRandomSearch = dspy.teleprompt.BootstrapFewShotWithRandomSearch
diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py
index 0636497d8..3753ee004 100644
--- a/dspy/adapters/json_adapter.py
+++ b/dspy/adapters/json_adapter.py
@@ -15,10 +15,9 @@
 from dspy.adapters.base import Adapter
 from dspy.adapters.utils import find_enum_member, format_field_value, serialize_for_json
-
-from ..adapters.image_utils import Image
-from ..signatures.signature import SignatureMeta
-from ..signatures.utils import get_dspy_field_type
+from dspy.adapters.image_utils import Image
+from dspy.signatures.signature import SignatureMeta
+from dspy.signatures.utils import get_dspy_field_type
 
 _logger = logging.getLogger(__name__)
@@ -38,7 +37,8 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
         try:
             provider = lm.model.split("/", 1)[0] or "openai"
-            if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
+            params = litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider)
+            if params and "response_format" in params:
                 try:
                     response_format = _get_structured_outputs_response_format(signature)
                     outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
diff --git a/dspy/clients/anyscale.py b/dspy/clients/anyscale.py
index 3d3022867..459f3798a 100644
--- a/dspy/clients/anyscale.py
+++ b/dspy/clients/anyscale.py
@@ -5,9 +5,9 @@
 import yaml
 import logging
 
-from dspy.clients.finetune import (
+from dspy.clients.lm import (
     FinetuneJob,
-    # TrainingMethod,
+    TrainingMethod,
     save_data,
 )
 from dspy.clients.openai import openai_data_validation
@@ -182,11 +182,6 @@ def start_remote_training(job_config) -> str:
     return job_id
 
 
-def wait_for_training(job_id):
-    print("Waiting for training to complete")
-    anyscale.job.wait(id=job_id, timeout_s=18000)
-
-
 def get_model_info(job_id):
     print("[Finetune] Retrieving model information from Anyscale Models SDK...")
     info = anyscale.llm.model.get(job_id=job_id).to_dict()
diff --git a/dspy/dsp/colbertv2.py b/dspy/dsp/colbertv2.py
index 6b1fd8cb3..93abb22e8 100644
--- a/dspy/dsp/colbertv2.py
+++ b/dspy/dsp/colbertv2.py
@@ -1,4 +1,5 @@
 import functools
+import importlib.util
 from typing import Any, List, Optional, Union
 
 import requests
@@ -8,7 +9,6 @@
 
 # TODO: Ideally, this takes the name of the index and looks up its port.
 
-
 class ColBERTv2:
     """Wrapper for the ColBERTv2 Retrieval."""
 
@@ -76,7 +76,7 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs):
 colbertv2_post_request = colbertv2_post_request_v2_wrapped
 
 class ColBERTv2RetrieverLocal:
-    def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False):
+    def __init__(self,passages: List[str], colbert_config=None, load_only=False):
         """Colbertv2 retriever module
 
         Args:
@@ -84,6 +84,9 @@ def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False):
             passages (List[str]): list of passages
             colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to None.
             load_only (bool, optional): whether to load the index or build and then load. Defaults to False.
         """
+        if importlib.util.find_spec("colbert") is None:
+            raise ImportError("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
+
         assert colbert_config is not None, "Please pass a valid colbert_config, which you can import from colbert.infra.config import ColBERTConfig and modify it"
         self.colbert_config = colbert_config
@@ -101,12 +104,6 @@ def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False):
         self.searcher = self.get_index()
 
     def build_index(self):
-
-        try:
-            import colbert
-        except ImportError:
-            print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
-
         from colbert import Indexer
         from colbert.infra import Run, RunConfig
         with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)):
@@ -114,11 +111,6 @@ def build_index(self):
             indexer.index(name=self.colbert_config.index_name, collection=self.passages, overwrite=True)
 
     def get_index(self):
-        try:
-            import colbert
-        except ImportError:
-            print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
-
         from colbert import Searcher
         from colbert.infra import Run, RunConfig
 
@@ -153,16 +145,15 @@ def forward(self,query:str,k:int=7,**kwargs):
 
 class ColBERTv2RerankerLocal:
     def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'):
-        try:
-            import colbert
-        except ImportError:
-            print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
         """_summary_
 
         Args:
             colbert_config (ColBERTConfig, optional): Colbert config. Defaults to None.
             checkpoint_name (str, optional): checkpoint for embeddings. Defaults to 'bert-base-uncased'.
         """
+
+        if importlib.util.find_spec("colbert") is None:
+            raise ImportError("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
         self.colbert_config = colbert_config
         self.checkpoint = checkpoint
         self.colbert_config.checkpoint = checkpoint
diff --git a/dspy/dsp/utils/dpr.py b/dspy/dsp/utils/dpr.py
index d4d18f84f..83c58b70e 100644
--- a/dspy/dsp/utils/dpr.py
+++ b/dspy/dsp/utils/dpr.py
@@ -5,10 +5,15 @@
 """
 
 import unicodedata
+import logging
+import copy
 
 import regex
 
+logger = logging.getLogger(__name__)
+
+
 class Tokens:
     """A class to represent a list of tokenized text."""
     TEXT = 0
diff --git a/dspy/dsp/utils/utils.py b/dspy/dsp/utils/utils.py
index 4025dfa3f..b651282da 100644
--- a/dspy/dsp/utils/utils.py
+++ b/dspy/dsp/utils/utils.py
@@ -232,7 +232,7 @@ def load_batch_backgrounds(args, qids):
     for qid in qids:
         back = args.qid2backgrounds[qid]
 
-        if len(back) and type(back[0]) == int:
+        if len(back) and isinstance(back[0], int):
             x = [args.collection[pid] for pid in back]
         else:
             x = [args.collectionX.get(pid, "") for pid in back]
diff --git a/dspy/predict/parallel.py b/dspy/predict/parallel.py
index 78ef3e8ea..f7cb6bbb8 100644
--- a/dspy/predict/parallel.py
+++ b/dspy/predict/parallel.py
@@ -2,8 +2,8 @@
 
 from typing import Tuple, List, Any
 
-from ..primitives.example import Example
-from ..utils.parallelizer import ParallelExecutor
+from dspy.primitives.example import Example
+from dspy.utils.parallelizer import ParallelExecutor
 
 
 class Parallel:
diff --git a/dspy/predict/program_of_thought.py b/dspy/predict/program_of_thought.py
index 4497af12e..84cfc4a0c 100644
--- a/dspy/predict/program_of_thought.py
+++ b/dspy/predict/program_of_thought.py
@@ -3,8 +3,8 @@
 import dspy
 from dspy.signatures.signature import ensure_signature
 
-from ..primitives.program import Module
-from ..primitives.python_interpreter import CodePrompt, PythonInterpreter
+from dspy.primitives.program import Module
+from dspy.primitives.python_interpreter import CodePrompt, PythonInterpreter
 
 
 class ProgramOfThought(Module):
diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py
index 1e760180a..8a2a2a72b 100644
--- a/dspy/primitives/python_interpreter.py
+++ b/dspy/primitives/python_interpreter.py
@@ -26,8 +26,11 @@ class InterpreterError(ValueError):
     expression, due to syntax error or unsupported operations.
""" +class BreakException(Exception): pass +class ContinueException(Exception): + pass class PythonInterpreter: r"""A customized python interpreter to control the execution of diff --git a/dspy/propose/dataset_summary_generator.py b/dspy/propose/dataset_summary_generator.py index 7a4c03a96..e282359aa 100644 --- a/dspy/propose/dataset_summary_generator.py +++ b/dspy/propose/dataset_summary_generator.py @@ -45,7 +45,8 @@ def reorder_keys(match): return ordered_repr def create_dataset_summary(trainset, view_data_batch_size, prompt_model, log_file=None, verbose=False): - if verbose: print("\nBootstrapping dataset summary (this will be used to generate instructions)...") + if verbose: + print("\nBootstrapping dataset summary (this will be used to generate instructions)...") upper_lim = min(len(trainset), view_data_batch_size) prompt_model = prompt_model if prompt_model else dspy.settings.lm with dspy.settings.context(lm=prompt_model): @@ -63,7 +64,8 @@ def create_dataset_summary(trainset, view_data_batch_size, prompt_model, log_fil calls+=1 if calls >= max_calls: break - if verbose: print(f"b: {b}") + if verbose: + print(f"b: {b}") upper_lim = min(len(trainset), b+view_data_batch_size) with dspy.settings.context(lm=prompt_model): output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)(prior_observations=observations, examples=order_input_keys_in_string(trainset[b:upper_lim].__repr__())) @@ -77,17 +79,20 @@ def create_dataset_summary(trainset, view_data_batch_size, prompt_model, log_fil if log_file: log_file.write(f"observations {observations}\n") except Exception as e: - if verbose: print(f"e {e}. using observations from past round for a summary.") + if verbose: + print(f"e {e}. using observations from past round for a summary.") if prompt_model: with dspy.settings.context(lm=prompt_model): summary = dspy.Predict(ObservationSummarizer, n=1, temperature=1.0)(observations=observations) else: summary = dspy.Predict(ObservationSummarizer, n=1, temperature=1.0)(observations=observations) - if verbose: print(f"summary: {summary}") + if log_file: log_file.write(f"summary: {summary}\n") - - if verbose: print(f"\nGenerated summary: {strip_prefix(summary.summary)}\n") + + if verbose: + print(f"summary: {summary}") + print(f"\nGenerated summary: {strip_prefix(summary.summary)}\n") return strip_prefix(summary.summary) diff --git a/dspy/propose/grounded_proposer.py b/dspy/propose/grounded_proposer.py index 15ef13130..31f211dd4 100644 --- a/dspy/propose/grounded_proposer.py +++ b/dspy/propose/grounded_proposer.py @@ -195,7 +195,8 @@ def forward( program_code=self.program_code_string, program_example=task_demos, ).program_description, ) - if self.verbose: print(f"PROGRAM DESCRIPTION: {program_description}") + if self.verbose: + print(f"PROGRAM DESCRIPTION: {program_description}") inputs = [] outputs = [] @@ -218,12 +219,14 @@ def forward( module=module_code, max_depth=10, ).module_description - except: - if self.verbose: print("Error getting program description. Running without program aware proposer.") + except Exception: + if self.verbose: + print("Error getting program description. 
Running without program aware proposer.") self.program_aware = False # Generate an instruction for our chosen module - if self.verbose: print(f"task_demos {task_demos}") + if self.verbose: + print(f"task_demos {task_demos}") instruct = self.generate_module_instruction( dataset_description=data_summary, program_code=self.program_code_string, @@ -237,7 +240,8 @@ def forward( ) if hasattr(instruct, "module_description"): module_description = strip_prefix(instruct.module_description) - if self.verbose: print(f"MODULE DESCRIPTION: {module_description}") + if self.verbose: + print(f"MODULE DESCRIPTION: {module_description}") proposed_instruction = strip_prefix(instruct.proposed_instruction) return dspy.Prediction(proposed_instruction=proposed_instruction) @@ -278,7 +282,8 @@ def __init__( if self.program_aware: try: self.program_code_string = get_dspy_source_code(program) - if self.verbose: print("SOURCE CODE:",self.program_code_string) + if self.verbose: + print("SOURCE CODE:",self.program_code_string) except Exception as e: print(f"Error getting source code: {e}.\n\nRunning without program aware proposer.") self.program_aware = False @@ -289,7 +294,8 @@ def __init__( self.data_summary = create_dataset_summary( trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model, ) - if self.verbose: print(f"DATA SUMMARY: {self.data_summary}") + if self.verbose: + print(f"DATA SUMMARY: {self.data_summary}") except Exception as e: print(f"Error getting data summary: {e}.\n\nRunning without data aware proposer.") self.use_dataset_summary = False @@ -313,12 +319,14 @@ def propose_instructions_for_program( # Randomly select whether or not we're using instruction history use_history = self.rng.random() < 0.5 self.use_instruct_history = use_history - if self.verbose: print(f"Use history T/F: {self.use_instruct_history}") + if self.verbose: + print(f"Use history T/F: {self.use_instruct_history}") num_demos = max(len(demo_candidates[0]) if demo_candidates else N, 1) if not demo_candidates: - if self.verbose: print("No demo candidates provided. Running without task demos.") + if self.verbose: + print("No demo candidates provided. 
Running without task demos.") self.use_task_demos = False # Create an instruction for each predictor @@ -327,14 +335,14 @@ def propose_instructions_for_program( if pred_i not in proposed_instructions: proposed_instructions[pred_i] = [] if self.set_tip_randomly: - if self.verbose: print("Using a randomly generated configuration for our grounded proposer.") + if self.verbose: + print("Using a randomly generated configuration for our grounded proposer.") # Randomly select the tip selected_tip_key = self.rng.choice(list(TIPS.keys())) selected_tip = TIPS[selected_tip_key] - self.use_tip = bool( - selected_tip, - ) - if self.verbose: print(f"Selected tip: {selected_tip_key}") + self.use_tip = bool(selected_tip) + if self.verbose: + print(f"Selected tip: {selected_tip_key}") proposed_instructions[pred_i].append( self.propose_instruction_for_predictor( @@ -399,7 +407,8 @@ def propose_instruction_for_predictor( self.prompt_model.kwargs["temperature"] = original_temp # Log the trace used to generate the new instruction, along with the new instruction itself - if self.verbose: self.prompt_model.inspect_history(n=1) - if self.verbose: print(f"PROPOSED INSTRUCTION: {proposed_instruction}") + if self.verbose: + self.prompt_model.inspect_history(n=1) + print(f"PROPOSED INSTRUCTION: {proposed_instruction}") return strip_prefix(proposed_instruction) diff --git a/dspy/propose/instruction_proposal.py b/dspy/propose/instruction_proposal.py index 2967427ca..52e363d08 100644 --- a/dspy/propose/instruction_proposal.py +++ b/dspy/propose/instruction_proposal.py @@ -1,4 +1,6 @@ import dspy +import dsp + from dspy.signatures import Signature @@ -80,13 +82,6 @@ class BasicGenerateInstructionWithDataObservations(Signature): proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") -class BasicGenerateInstruction(Signature): - """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""" - - basic_instruction = dspy.InputField(desc="The initial instructions before optimization") - proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model") - proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task") - class BasicGenerateInstructionAllFields(Signature): """You are an instruction optimizer for large language models. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.""" ("""You are an instruction optimizer for large language models. 
I will provide you with""") diff --git a/dspy/propose/utils.py b/dspy/propose/utils.py index 99e146f71..ad18a693f 100644 --- a/dspy/propose/utils.py +++ b/dspy/propose/utils.py @@ -2,6 +2,7 @@ import re import dspy import inspect + try: from IPython.core.magics.code import extract_symbols except ImportError: @@ -10,12 +11,14 @@ from dspy.predict.parameter import Parameter -from dspy.teleprompt.utils import get_signature +from dspy.teleprompt.utils import get_signature, new_getfile + def strip_prefix(text): - pattern = r'^[\*\s]*(([\w\'\-]+\s+){0,4}[\w\'\-]+):\s*' - modified_text = re.sub(pattern, '', text) - return modified_text.strip("\"") + pattern = r"^[\*\s]*(([\w\'\-]+\s+){0,4}[\w\'\-]+):\s*" + modified_text = re.sub(pattern, "", text) + return modified_text.strip('"') + def create_instruction_set_history_string(base_program, trial_logs, top_n): program_history = [] @@ -24,10 +27,12 @@ def create_instruction_set_history_string(base_program, trial_logs, top_n): if "program_path" in trial: trial_program = base_program.deepcopy() trial_program.load(trial["program_path"]) - program_history.append({ - "program": trial_program, - "score": trial["score"], - }) + program_history.append( + { + "program": trial_program, + "score": trial["score"], + } + ) # Deduplicate program history based on the program's instruction set seen_programs = set() @@ -38,9 +43,9 @@ def create_instruction_set_history_string(base_program, trial_logs, top_n): if instruction_set not in seen_programs: seen_programs.add(instruction_set) unique_program_history.append(entry) - + # Get the top n programs from program history - top_n_program_history = sorted(unique_program_history, key=lambda x: x['score'], reverse=True)[:top_n] + top_n_program_history = sorted(unique_program_history, key=lambda x: x["score"], reverse=True)[:top_n] top_n_program_history.reverse() # Create formatted string @@ -50,9 +55,10 @@ def create_instruction_set_history_string(base_program, trial_logs, top_n): score = entry["score"] instruction_set = get_program_instruction_set_string(program) instruction_set_history_string += instruction_set + f" | Score: {score}\n\n" - + return instruction_set_history_string + def parse_list_of_instructions(instruction_string): # Try to convert the string representation of a list to an actual list using JSON try: @@ -60,52 +66,56 @@ def parse_list_of_instructions(instruction_string): return instructions except json.JSONDecodeError: pass - + # If JSON decoding fails, extract strings within quotes instructions = re.findall(r'"([^"]*)"', instruction_string) return instructions + def get_program_instruction_set_string(program): instruction_list = [] for _, pred in enumerate(program.predictors()): pred_instructions = get_signature(pred).instructions - instruction_list.append(f"\"{pred_instructions}\"") + instruction_list.append(f'"{pred_instructions}"') # Joining the list into a single string that looks like a list return f"[{', '.join(instruction_list)}]" + def create_predictor_level_history_string(base_program, predictor_i, trial_logs, top_n): instruction_aggregate = {} instruction_history = [] - + # Load trial programs for trial_num in trial_logs: trial = trial_logs[trial_num] if "program_path" in trial: trial_program = base_program.deepcopy() trial_program.load(trial["program_path"]) - instruction_history.append({ - "program": trial_program, - "score": trial["score"], - }) + instruction_history.append( + { + "program": trial_program, + "score": trial["score"], + } + ) # Aggregate scores for each instruction for 
history_item in instruction_history: predictor = history_item["program"].predictors()[predictor_i] instruction = get_signature(predictor).instructions score = history_item["score"] - + if instruction in instruction_aggregate: - instruction_aggregate[instruction]['total_score'] += score - instruction_aggregate[instruction]['count'] += 1 + instruction_aggregate[instruction]["total_score"] += score + instruction_aggregate[instruction]["count"] += 1 else: - instruction_aggregate[instruction] = {'total_score': score, 'count': 1} - + instruction_aggregate[instruction] = {"total_score": score, "count": 1} + # Calculate average score for each instruction and prepare for sorting predictor_history = [] for instruction, data in instruction_aggregate.items(): - average_score = data['total_score'] / data['count'] + average_score = data["total_score"] / data["count"] predictor_history.append((instruction, average_score)) - + # Deduplicate and sort by average score, then select top N seen_instructions = set() unique_predictor_history = [] @@ -116,16 +126,16 @@ def create_predictor_level_history_string(base_program, predictor_i, trial_logs, top_instructions = sorted(unique_predictor_history, key=lambda x: x[1], reverse=True)[:top_n] top_instructions.reverse() - + # Create formatted history string predictor_history_string = "" for instruction, score in top_instructions: predictor_history_string += instruction + f" | Score: {score}\n\n" - + return predictor_history_string -def create_example_string(fields, example): +def create_example_string(fields, example): # Building the output string output = [] for field_name, field_values in fields.items(): @@ -139,7 +149,8 @@ def create_example_string(fields, example): output.append(field_str) # Joining all the field strings - return '\n'.join(output) + return "\n".join(output) + def get_dspy_source_code(module): header = [] @@ -166,18 +177,22 @@ def get_dspy_source_code(module): if item in completed_set: continue if isinstance(item, Parameter): - if hasattr(item, 'signature') and item.signature is not None and item.signature.__pydantic_parent_namespace__['signature_name'] + "_sig" not in completed_set: + if ( + hasattr(item, "signature") + and item.signature is not None + and item.signature.__pydantic_parent_namespace__["signature_name"] + "_sig" not in completed_set + ): try: header.append(inspect.getsource(item.signature)) print(inspect.getsource(item.signature)) except (TypeError, OSError): header.append(str(item.signature)) - completed_set.add(item.signature.__pydantic_parent_namespace__['signature_name'] + "_sig") + completed_set.add(item.signature.__pydantic_parent_namespace__["signature_name"] + "_sig") if isinstance(item, dspy.Module): code = get_dspy_source_code(item).strip() if code not in completed_set: header.append(code) completed_set.add(code) completed_set.add(item) - - return '\n\n'.join(header) + '\n\n' + base_code \ No newline at end of file + + return "\n\n".join(header) + "\n\n" + base_code diff --git a/dspy/retrieve/chromadb_rm.py b/dspy/retrieve/chromadb_rm.py index ee6219ba9..c102f8565 100644 --- a/dspy/retrieve/chromadb_rm.py +++ b/dspy/retrieve/chromadb_rm.py @@ -25,7 +25,6 @@ EmbeddingFunction, ) from chromadb.config import Settings - from chromadb.utils import embedding_functions except ImportError: raise ImportError( "The chromadb library is required to use ChromadbRM. 
Install it with `pip install dspy-ai[chromadb]`", diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index a21cc8c3c..e88b88a8b 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -1,6 +1,9 @@ -import random +from functools import lru_cache from typing import List, Optional, Union +import logging +import random +from dspy import dsp from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction from dspy.utils.callback import with_callbacks @@ -16,12 +19,24 @@ def single_query_passage(passages): return Prediction(**passages_dict) +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=None) +def warn_once(msg: str): + logger.warning(msg) + + class Retrieve(Parameter): name = "Search" input_variable = "query" desc = "takes a search query and returns one or more potentially relevant passages from a corpus" def __init__(self, k=3, callbacks=None): + warn_once( + "Existing retriever integrations under dspy/retrieve inheriting `dspy.Retrieve` are deprecated and will be removed in the DSPy 2.7 release. \n" + "For future retriever integrations, please use the `dspy.Retriever` interface under dspy/retriever/retriever.py and reference any of the custom integrations supported in dspy/retriever/" + ) self.stage = random.randbytes(8).hex() self.k = k self.callbacks = callbacks or [] @@ -57,6 +72,7 @@ def forward( passages = dspy.settings.rm(query, k=k, **kwargs) from collections.abc import Iterable + if not isinstance(passages, Iterable): # it's not an iterable yet; make it one. # TODO: we should unify the type signatures of dspy.Retriever @@ -65,4 +81,64 @@ def forward( return Prediction(passages=passages) + # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. +# TODO potentially add for deprecation/removal in 2.7 +class RetrieveThenRerank(Parameter): + name = "Search" + input_variable = "query" + desc = ( + "takes a search query and returns one or more potentially relevant passages followed by reranking from a corpus" + ) + + def __init__(self, k=3): + self.stage = random.randbytes(8).hex() + self.k = k + + def reset(self): + pass + + def dump_state(self, save_verbose=False): + """save_verbose is set as a default argument to support the inherited Parameter interface for dump_state""" + state_keys = ["k"] + return {k: getattr(self, k) for k in state_keys} + + def load_state(self, state): + for name, value in state.items(): + setattr(self, name, value) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward( + self, + query_or_queries: Union[str, List[str]], + k: Optional[int] = None, + with_metadata: bool = False, + **kwargs, + ) -> Union[List[str], Prediction, List[Prediction]]: + queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries + queries = [query.strip().split("\n")[0].strip() for query in queries] + + # print(queries) + # TODO: Consider removing any quote-like markers that surround the query too. 
+
+        k = k if k is not None else self.k
+        if not with_metadata:
+            passages = dsp.retrieveRerankEnsemble(queries, k=k, **kwargs)
+            return passages
+        else:
+            passages = dsp.retrieveRerankEnsemblewithMetadata(queries, k=k, **kwargs)
+            if isinstance(passages[0], List):
+                pred_returns = []
+                for query_passages in passages:
+                    passages_dict = {key: [] for key in list(query_passages[0].keys())}
+                    for docs in query_passages:
+                        for key, value in docs.items():
+                            passages_dict[key].append(value)
+                    if "long_text" in passages_dict:
+                        passages_dict["passages"] = passages_dict.pop("long_text")
+
+                    pred_returns.append(Prediction(**passages_dict))
+                return pred_returns
+            elif isinstance(passages[0], dict):
+                return single_query_passage(passages=passages)
diff --git a/dspy/retriever/__init__.py b/dspy/retriever/__init__.py
new file mode 100644
index 000000000..ef7ef4a14
--- /dev/null
+++ b/dspy/retriever/__init__.py
@@ -0,0 +1,2 @@
+from .retriever import Retriever
+from .colbertv2 import ColBERTv2
\ No newline at end of file
diff --git a/dspy/retriever/colbertv2.py b/dspy/retriever/colbertv2.py
new file mode 100644
index 000000000..bb641eadc
--- /dev/null
+++ b/dspy/retriever/colbertv2.py
@@ -0,0 +1,89 @@
+from typing import Any, Union, Optional, List
+import functools
+import requests
+
+from dsp.cache_utils import CacheMemory, NotebookCacheMemory
+from dsp.utils import dotdict
+from dspy.retriever import Retriever
+
+
+class ColBERTv2(Retriever):
+    """
+    ColBERTv2 Retriever for retrieval of the top-k most relevant text passages for a given query.
+
+    Args:
+        url (str): Base URL endpoint for the ColBERTv2 server.
+        port (Union[str, int], optional): Port number for server. Appended to URL if provided.
+        post_requests (bool, optional): Determines if POST requests should be used instead of GET requests for querying the server. Defaults to False.
+        k (int, optional): Number of top passages to retrieve. Defaults to 10.
+        callbacks (Optional[List[Any]]): List of callback functions to be called during retrieval.
+        cache (bool, optional): Enable retrieval caching. Disabled by default.
+
+    Returns:
+        An object containing the retrieved passages.
+
+    Example:
+        import dspy
+        results = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=10).passages
+        print(results)
+    """
+
+    def __init__(
+        self,
+        url: str = "http://0.0.0.0",
+        port: Optional[Union[str, int]] = None,
+        post_requests: bool = False,
+        k: int = 10,
+        callbacks: Optional[List[Any]] = None,
+        cache: bool = False,
+    ):
+        super().__init__(embedder=None, k=k, callbacks=callbacks, cache=cache)
+        self.post_requests = post_requests
+        self.url = f"{url}:{port}" if port else url
+
+    def forward(self, query: str, k: int = 10) -> Any:
+        if self.post_requests:
+            topk = colbertv2_post_request(self.url, query, k)
+        else:
+            topk = colbertv2_get_request(self.url, query, k)
+        return dotdict({"passages": [dotdict(psg) for psg in topk]})
+
+
+@CacheMemory.cache
+def colbertv2_get_request_v2(url: str, query: str, k: int):
+    assert k <= 100, "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."
+
+    payload = {"query": query, "k": k}
+    res = requests.get(url, params=payload, timeout=10)
+
+    topk = res.json()["topk"][:k]
+    topk = [{**d, "long_text": d["text"]} for d in topk]
+    return topk[:k]
+
+
+@functools.cache
+@NotebookCacheMemory.cache
+def colbertv2_get_request_v2_wrapped(*args, **kwargs):
+    return colbertv2_get_request_v2(*args, **kwargs)
+
+
+colbertv2_get_request = colbertv2_get_request_v2_wrapped
+
+
+@CacheMemory.cache
+def colbertv2_post_request_v2(url: str, query: str, k: int):
+    headers = {"Content-Type": "application/json; charset=utf-8"}
+    payload = {"query": query, "k": k}
+    res = requests.post(url, json=payload, headers=headers, timeout=10)
+
+    return res.json()["topk"][:k]
+
+
+@functools.cache
+@NotebookCacheMemory.cache
+def colbertv2_post_request_v2_wrapped(*args, **kwargs):
+    return colbertv2_post_request_v2(*args, **kwargs)
+
+
+colbertv2_post_request = colbertv2_post_request_v2_wrapped
diff --git a/dspy/retriever/databricks.py b/dspy/retriever/databricks.py
new file mode 100644
index 000000000..1048546a1
--- /dev/null
+++ b/dspy/retriever/databricks.py
@@ -0,0 +1,348 @@
+import json
+import os
+from importlib.util import find_spec
+from typing import Any, Dict, List, Optional
+
+import requests
+
+from dspy.retriever import Retriever
+from dspy.primitives.prediction import Prediction
+from dspy.clients.embedding import Embedding
+
+_databricks_sdk_installed = find_spec("databricks.sdk") is not None
+
+
+class Databricks(Retriever):
+    """
+    A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
+    embeddings for a given query.
+
+    Examples:
+        Below is a code snippet that shows how to set up a Databricks Vector Search Index
+        and configure a Databricks retriever module to query the index.
+
+        (example adapted from "Databricks: How to create and query a Vector Search Index":
+        https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)
+
+        ```python
+        from databricks.vector_search.client import VectorSearchClient
+
+        # Create a Databricks Vector Search Endpoint
+        client = VectorSearchClient()
+        client.create_endpoint(
+            name="your_vector_search_endpoint_name",
+            endpoint_type="STANDARD"
+        )
+
+        # Create a Databricks Direct Access Vector Search Index
+        index = client.create_direct_access_index(
+            endpoint_name="your_vector_search_endpoint_name",
+            index_name="your_index_name",
+            primary_key="id",
+            embedding_dimension=1024,
+            embedding_vector_column="text_vector",
+            schema={
+                "id": "int",
+                "field2": "str",
+                "field3": "float",
+                "text_vector": "array<float>"
+            }
+        )
+
+        # Create a Databricks retriever module to query the Databricks Direct Access Vector
+        # Search Index
+        from dspy.retriever.databricks import Databricks
+
+        retriever = Databricks(
+            databricks_index_name = "your_index_name",
+            docs_id_column_name="id",
+            text_column_name="field2",
+            k=3
+        )
+        ```
+
+        Below is a code snippet that shows how to query the Databricks Direct Access Vector
+        Search Index using the Databricks retriever module:
+
+        ```python
+        retrieved_results = retriever(query="Example query text")
+        ```
+    """
+
+    def __init__(
+        self,
+        databricks_index_name: str,
+        databricks_endpoint: Optional[str] = None,
+        databricks_token: Optional[str] = None,
+        columns: Optional[List[str]] = None,
+        filters_json: Optional[str] = None,
+        query_type: str = "ANN",
+        k: int = 3,
+        docs_id_column_name: str = "id",
+        text_column_name: str = "text",
+        embedder: Optional[Embedding] = None,
+        callbacks: Optional[List[Any]] = None,
+        cache: bool = False,
+    ):
+        """
+        Args:
+            databricks_index_name (str): The name of the Databricks Vector Search Index to query.
+            databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing
+                the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST``
+                environment variable. If unspecified, the Databricks SDK is used to identify the
+                endpoint based on the current environment.
+            databricks_token (Optional[str]): The Databricks Workspace authentication token to use
+                when querying the Vector Search Index. Defaults to the value of the
+                ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is
+                used to identify the token based on the current environment.
+            columns (Optional[List[str]]): Extra column names to include in response,
+                in addition to the document id and text columns specified by
+                ``docs_id_column_name`` and ``text_column_name``.
+            filters_json (Optional[str]): A JSON string specifying additional query filters.
+                Example filters: ``{"id <": 5}`` selects records that have an ``id`` column value
+                less than 5, and ``{"id >=": 5, "id <": 10}`` selects records that have an ``id``
+                column value greater than or equal to 5 and less than 10.
+            query_type (str): The type of search query to perform. Must be 'ANN', 'HYBRID', or 'VECTOR'.
+            k (int): The number of documents to retrieve. Defaults to 3.
+            docs_id_column_name (str): The name of the column in the Databricks Vector Search Index
+                containing document IDs.
+            text_column_name (str): The name of the column in the Databricks Vector Search Index
+                containing document text to retrieve.
+            embedder (Optional[Embedding]): An embedder to convert query text to vectors when using 'VECTOR' query_type.
+            callbacks (Optional[List[Any]]): List of callback functions.
+            cache (bool, optional): Enable retrieval caching. Disabled by default.
+        """
+        super().__init__(embedder=embedder, k=k, callbacks=callbacks, cache=cache)
+        self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
+        self.databricks_endpoint = (
+            databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST")
+        )
+        if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0:
+            raise ValueError(
+                "To retrieve documents with Databricks Vector Search, you must install the"
+                " databricks-sdk Python library, supply the databricks_token and"
+                " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST"
+                " environment variables."
+            )
+        self.databricks_index_name = databricks_index_name
+        self.columns = list({docs_id_column_name, text_column_name, *(columns or [])})
+        self.filters_json = filters_json
+        self.query_type = query_type
+        self.docs_id_column_name = docs_id_column_name
+        self.text_column_name = text_column_name
+
+    def _extract_doc_ids(self, item: Dict[str, Any]) -> str:
+        """Extracts the document id from a search result.
+
+        Args:
+            item (Dict[str, Any]): A record from the search results.
+
+        Returns:
+            str: Document id.
+        """
+        if self.docs_id_column_name == "metadata":
+            docs_dict = json.loads(item["metadata"])
+            return docs_dict["document_id"]
+        return item[self.docs_id_column_name]
+
+    def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]:
+        """Extracts search result column values, excluding the "text" and "id" columns.
+
+        Args:
+            item (Dict[str, Any]): A record from the search results.
+
+        Returns:
+            Dict[str, Any]: Search result column values, excluding the "text" and "id" columns.
+        """
+        extra_columns = {k: v for k, v in item.items() if k not in [self.docs_id_column_name, self.text_column_name]}
+        if self.docs_id_column_name == "metadata":
+            extra_columns = {
+                **extra_columns,
+                **{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}},
+            }
+        return extra_columns
+
+    def forward(self, query: str, k: Optional[int] = None) -> Prediction:
+        """
+        Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
+        specified query.
+
+        Args:
+            query (str): The query text for which to retrieve relevant documents.
+            k (Optional[int]): The number of documents to retrieve. If None, defaults to self.k.
+
+        Returns:
+            dspy.Prediction: An object containing the retrieved results.
+ """ + k = k or self.k + query_text = query + query_vector = None + + if self.query_type.upper() in ["ANN", "HYBRID"]: + query_text = query + elif self.query_type.upper() == "VECTOR": + if self.embedder: + query_vector = self.embedder.embed(query) + query_text = None + else: + raise ValueError("An embedder must be provided when using 'VECTOR' query_type without providing a query vector.") + else: + raise ValueError(f"Unsupported query_type: {self.query_type}") + + if _databricks_sdk_installed: + results = self._query_via_databricks_sdk( + index_name=self.databricks_index_name, + k=k, + columns=self.columns, + query_type=self.query_type.upper(), + query_text=query_text, + query_vector=query_vector, + databricks_token=self.databricks_token, + databricks_endpoint=self.databricks_endpoint, + filters_json=self.filters_json, + ) + else: + results = self._query_via_requests( + index_name=self.databricks_index_name, + k=k, + columns=self.columns, + databricks_token=self.databricks_token, + databricks_endpoint=self.databricks_endpoint, + query_type=self.query_type.upper(), + query_text=query_text, + query_vector=query_vector, + filters_json=self.filters_json, + ) + + # Checking if defined columns are present in the index columns + col_names = [column["name"] for column in results["manifest"]["columns"]] + + if self.docs_id_column_name not in col_names: + raise Exception( + f"docs_id_column_name: '{self.docs_id_column_name}' is not in the index columns: \n {col_names}" + ) + + if self.text_column_name not in col_names: + raise Exception(f"text_column_name: '{self.text_column_name}' is not in the index columns: \n {col_names}") + + # Extracting the results + items = [] + for data_row in results["result"]["data_array"]: + item = {col_name: val for col_name, val in zip(col_names, data_row)} + items.append(item) + + # Sorting results by score in descending order + sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[:k] + + # Returning the prediction + return Prediction( + passages=[doc[self.text_column_name] for doc in sorted_docs], + doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs], + extra_columns=[self._get_extra_columns(doc) for doc in sorted_docs], + ) + + @staticmethod + def _query_via_databricks_sdk( + index_name: str, + k: int, + columns: List[str], + query_type: str, + query_text: Optional[str], + query_vector: Optional[List[float]], + databricks_token: Optional[str], + databricks_endpoint: Optional[str], + filters_json: Optional[str], + ) -> Dict[str, Any]: + """ + Query a Databricks Vector Search Index via the Databricks SDK. + Assumes that the databricks-sdk Python library is installed. + + Args: + index_name (str): Name of the Databricks vector search index to query + k (int): Number of relevant documents to retrieve. + columns (List[str]): Column names to include in response. + query_text (Optional[str]): Text query for which to find relevant documents. Exactly + one of query_text or query_vector must be specified. + query_vector (Optional[List[float]]): Numeric query vector for which to find relevant + documents. Exactly one of query_text or query_vector must be specified. + filters_json (Optional[str]): JSON string representing additional query filters. + databricks_token (str): Databricks authentication token. If not specified, + the token is resolved from the current environment. + databricks_endpoint (str): Databricks index endpoint url. If not specified, + the endpoint is resolved from the current environment. 
+
+        Returns:
+            Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
+        """
+        from databricks.sdk import WorkspaceClient
+
+        if (query_text, query_vector).count(None) != 1:
+            raise ValueError("Exactly one of query_text or query_vector must be specified.")
+
+        databricks_client = WorkspaceClient(host=databricks_endpoint, token=databricks_token)
+        return databricks_client.vector_search_indexes.query_index(
+            index_name=index_name,
+            query_type=query_type,
+            query_text=query_text,
+            query_vector=query_vector,
+            columns=columns,
+            filters_json=filters_json,
+            num_results=k,
+        ).as_dict()
+
+    @staticmethod
+    def _query_via_requests(
+        index_name: str,
+        k: int,
+        columns: List[str],
+        databricks_token: str,
+        databricks_endpoint: str,
+        query_type: str,
+        query_text: Optional[str],
+        query_vector: Optional[List[float]],
+        filters_json: Optional[str],
+    ) -> Dict[str, Any]:
+        """
+        Query a Databricks Vector Search Index via the Python requests library.
+
+        Args:
+            index_name (str): Name of the Databricks vector search index to query
+            k (int): Number of relevant documents to retrieve.
+            columns (List[str]): Column names to include in response.
+            databricks_token (str): Databricks authentication token.
+            databricks_endpoint (str): Databricks index endpoint url.
+            query_text (Optional[str]): Text query for which to find relevant documents. Exactly
+                one of query_text or query_vector must be specified.
+            query_vector (Optional[List[float]]): Numeric query vector for which to find relevant
+                documents. Exactly one of query_text or query_vector must be specified.
+            filters_json (Optional[str]): JSON string representing additional query filters.
+
+        Returns:
+            Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
+        """
+        if (query_text, query_vector).count(None) != 1:
+            raise ValueError("Exactly one of query_text or query_vector must be specified.")
+
+        headers = {
+            "Authorization": f"Bearer {databricks_token}",
+            "Content-Type": "application/json",
+        }
+        payload = {
+            "columns": columns,
+            "num_results": k,
+            "query_type": query_type,
+        }
+        if filters_json is not None:
+            payload["filters_json"] = filters_json
+        if query_text is not None:
+            payload["query_text"] = query_text
+        elif query_vector is not None:
+            payload["query_vector"] = query_vector
+        response = requests.post(
+            f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query",
+            json=payload,
+            headers=headers,
+        )
+        results = response.json()
+        if "error_code" in results:
+            raise Exception(f"ERROR: {results['error_code']} -- {results['message']}")
+        return results
diff --git a/dspy/retriever/faiss.py b/dspy/retriever/faiss.py
new file mode 100644
index 000000000..b64935694
--- /dev/null
+++ b/dspy/retriever/faiss.py
@@ -0,0 +1,141 @@
+"""Retriever model for faiss: https://github.com/facebookresearch/faiss.
+Author: Jagane Sundar: https://github.com/jagane.
+(modified to support `dspy.Retriever` interface)
+"""
+
+import logging
+from typing import List, Any, Optional
+
+import numpy as np
+
+from dspy.retriever import Retriever
+from dspy.clients.embedding import Embedding
+from dspy.primitives.prediction import Prediction
+
+try:
+    import faiss
+except ImportError:
+    faiss = None
+
+if faiss is None:
+    raise ImportError(
+        """
+        The faiss package is required. Install it using `pip install dspy-ai[faiss-cpu]`
+        """,
+    )
+
+logger = logging.getLogger(__name__)
+
+
+class Faiss(Retriever):
+    """A retrieval module that uses an in-memory Faiss index to return the top passages for a given query.
+
+    Args:
+        document_chunks: The input text chunks.
+        embedder: An instance of `dspy.Embedding` to compute embeddings.
+        k (int, optional): The number of top passages to retrieve. Defaults to 5.
+
+    Returns:
+        dspy.Prediction: An object containing the retrieved passages.
+
+    Examples:
+        Below is a code snippet that shows how to use this as the default retriever:
+
+        ```python
+        import dspy
+        from dspy.retriever.faiss import Faiss
+
+        # Custom embedding function using SentenceTransformers and dspy.Embedding
+        def sentence_transformers_embedder(texts):
+            # (pip install sentence-transformers)
+            from sentence_transformers import SentenceTransformer
+            model = SentenceTransformer('all-MiniLM-L6-v2')
+            embeddings = model.encode(texts, batch_size=256, normalize_embeddings=True)
+            return embeddings.tolist()
+
+        embedder = dspy.Embedding(embedding_model=sentence_transformers_embedder)
+
+        document_chunks = [
+            "The superbowl this year was played between the San Francisco 49ers and the Kansas City Chiefs",
+            "Pop corn is often served in a bowl",
+            "The Rice Bowl is a Chinese Restaurant located in the city of Tucson, Arizona",
+            "Mars is the fourth planet in the Solar System",
+            "An aquarium is a place where children can learn about marine life",
+            "The capital of the United States is Washington, D.C",
+            "Rock and Roll musicians are honored by being inducted in the Rock and Roll Hall of Fame",
+            "Music albums were published on Long Play Records in the 70s and 80s",
+            "Sichuan cuisine is a spicy cuisine from central China",
+            "The interest rates for mortgages are considered to be very high in 2024",
+        ]
+
+        retriever = Faiss(document_chunks, embedder=embedder)
+        results = retriever("I am in the mood for Chinese food").passages
+        print(results)
+        ```
+    """
+
+    def __init__(
+        self,
+        document_chunks: List[str],
+        embedder: Embedding,
+        k: int = 5,
+        callbacks: Optional[List[Any]] = None,
+        cache: bool = False,
+    ):
+        """Inits the faiss retriever.
+
+        Args:
+            document_chunks (List[str]): A list of input strings.
+            embedder (dspy.Embedding): An instance of `dspy.Embedding` to compute embeddings.
+            k (int, optional): Number of passages to retrieve. Defaults to 5.
+            callbacks (Optional[List[Any]]): List of callback functions.
+            cache (bool, optional): Enable retrieval caching. Disabled by default.
+ """ + if embedder is not None and not isinstance(embedder, Embedding): + raise ValueError("If provided, the embedder must be of type `dspy.Embedding`.") + self.embedder = embedder + super().__init__(embedder=self.embedder, k=k, callbacks=callbacks, cache=cache) + embeddings = self.embedder(document_chunks) + xb = np.array(embeddings) + d = xb.shape[1] + logger.info(f"Faiss: embedding size={d}") + if len(xb) < 100: + self._faiss_index = faiss.IndexFlatL2(d) + self._faiss_index.add(xb) + else: + # If we have at least 100 vectors, we use Voronoi cells + nlist = 100 + quantizer = faiss.IndexFlatL2(d) + self._faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist) + self._faiss_index.train(xb) + self._faiss_index.add(xb) + + logger.info(f"{self._faiss_index.ntotal} vectors in faiss index") + self._document_chunks = document_chunks # Save the input document chunks + + def _dump_raw_results(self, queries, index_list, distance_list) -> None: + for i in range(len(queries)): + indices = index_list[i] + distances = distance_list[i] + logger.debug(f"Query: {queries[i]}") + for j in range(len(indices)): + logger.debug( + f" Hit {j} = {indices[j]}/{distances[j]}: {self._document_chunks[indices[j]]}" + ) + return + + def forward(self, query: str, k: Optional[int] = None, **kwargs) -> Prediction: + """Search the faiss index for k or self.k top passages for query. + + Args: + query (str): The query to search for. + + Returns: + dspy.Prediction: An object containing the retrieved passages. + """ + k = k or self.k + embeddings = self.embedder([query]) + emb_npa = np.array(embeddings) + distance_list, index_list = self._faiss_index.search(emb_npa, k) + # self._dump_raw_results([query], index_list, distance_list) + passages = [self._document_chunks[ind] for ind in index_list[0]] + doc_ids = [ind for ind in index_list[0]] + return Prediction(passages=passages, doc_ids=doc_ids) diff --git a/dspy/retriever/milvus.py b/dspy/retriever/milvus.py new file mode 100644 index 000000000..1523140ab --- /dev/null +++ b/dspy/retriever/milvus.py @@ -0,0 +1,119 @@ +""" +Retriever model for Milvus or Zilliz Cloud +""" + +from typing import List, Optional, Any + +from dspy.retriever import Retriever +from dspy.clients.embedding import Embedding +from dspy.primitives.prediction import Prediction + +try: + from pymilvus import MilvusClient +except ImportError: + raise ImportError( + "The pymilvus library is required to use Milvus. Install it with `pip install dspy-ai[milvus]`", + ) + +class Milvus(Retriever): + """ + A retrieval module that uses Milvus to return passages for a given query. + + Assumes that a Milvus collection has been created and populated with the following field: + - text: The text of the passage + + Args: + collection_name (str): The name of the Milvus collection to query against. + uri (str, optional): The Milvus connection URI. Defaults to "http://localhost:19530". + token (str, optional): The Milvus connection token. Defaults to None. + db_name (str, optional): The Milvus database name. Defaults to "default". + embedder (dspy.Embedding): An instance of `dspy.Embedding` to compute embeddings. + k (int, optional): Number of top passages to retrieve. Defaults to 5. + callbacks (Optional[List[Any]]): List of callback functions. + cache (bool, optional): Enable retrieval caching. Disabled by default. + + Returns: + dspy.Prediction: An object containing the retrieved passages. 
+
+    Examples:
+        Below is a code snippet that shows how to use this as the default retriever:
+        ```python
+        import dspy
+        from dspy.retriever.milvus import Milvus
+
+        # Create an Embedding instance
+        embedder = dspy.Embedding(embedding_model="openai/text-embedding-3-small")
+
+        retriever = Milvus(
+            collection_name="",
+            uri="",
+            token="",
+            embedder=embedder,
+            k=3
+        )
+        results = retriever(query).passages
+        print(results)
+        ```
+    """
+
+    def __init__(
+        self,
+        collection_name: str,
+        uri: str = "http://localhost:19530",
+        token: Optional[str] = None,
+        db_name: str = "default",
+        embedder: Embedding = None,
+        k: int = 5,
+        callbacks: Optional[List[Any]] = None,
+        cache: bool = False,
+    ):
+        if embedder is not None and not isinstance(embedder, Embedding):
+            raise ValueError("If provided, the embedder must be of type `dspy.Embedding`.")
+        super().__init__(embedder=embedder, k=k, callbacks=callbacks, cache=cache)
+
+        self.milvus_client = MilvusClient(uri=uri, token=token, db_name=db_name)
+
+        # Check if collection exists
+        if collection_name not in self.milvus_client.list_collections():
+            raise AttributeError(f"Milvus collection not found: {collection_name}")
+        self.collection_name = collection_name
+
+    def forward(self, query: str, k: Optional[int] = None) -> Prediction:
+        """
+        Retrieve passages from Milvus that are relevant to the specified query.
+
+        Args:
+            query (str): The query text for which to retrieve relevant passages.
+            k (Optional[int]): The number of passages to retrieve. If None, defaults to self.k.
+
+        Returns:
+            dspy.Prediction: An object containing the retrieved passages.
+        """
+        k = k or self.k
+        query_embedding = self.embedder([query])[0]
+
+        # Milvus expects embeddings as lists
+        query_embedding = query_embedding.tolist()
+
+        milvus_res = self.milvus_client.search(
+            collection_name=self.collection_name,
+            data=[query_embedding],
+            output_fields=["text"],
+            limit=k,
+        )
+
+        results = []
+        for res in milvus_res:
+            for r in res:
+                text = r["entity"]["text"]
+                doc_id = r["id"]
+                distance = r["distance"]
+                results.append((text, doc_id, distance))
+
+        sorted_results = sorted(results, key=lambda x: x[2], reverse=True)[:k]
+        passages = [x[0] for x in sorted_results]
+        doc_ids = [x[1] for x in sorted_results]
+        distances = [x[2] for x in sorted_results]
+
+        return Prediction(passages=passages, doc_ids=doc_ids, scores=distances)
+    
\ No newline at end of file
diff --git a/dspy/retriever/pinecone.py b/dspy/retriever/pinecone.py
new file mode 100644
index 000000000..c77530d6e
--- /dev/null
+++ b/dspy/retriever/pinecone.py
@@ -0,0 +1,181 @@
+"""
+Retriever model for Pinecone
+Author: Dhar Rawal (@drawal1)
+(modified to support `dspy.Retriever` interface)
+"""
+
+from typing import List, Optional, Any, Union
+
+from dspy.retriever import Retriever
+from dspy.clients.embedding import Embedding
+from dspy.primitives.prediction import Prediction
+from dsp.utils import dotdict
+
+try:
+    import pinecone
+except ImportError:
+    pinecone = None
+
+if pinecone is None:
+    raise ImportError(
+        "The pinecone library is required to use Pinecone. Install it with `pip install dspy-ai[pinecone]`",
+    )
+
+
+class Pinecone(Retriever):
+    """
+    A retrieval module that uses Pinecone to return the top passages for a given query or list of queries.
+
+    Assumes that the Pinecone index has been created and populated with the following metadata:
+        - text: The text of the passage
+
+    Args:
+        pinecone_index_name (str): The name of the Pinecone index to query against.
+        pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
+        pinecone_env (str, optional): The Pinecone environment. Defaults to None.
+        embedder (dspy.Embedding): An instance of `dspy.Embedding` to compute embeddings.
+        k (int, optional): Number of top passages to retrieve. Defaults to 5.
+        callbacks (Optional[List[Any]]): List of callback functions.
+        cache (bool, optional): Enable retrieval caching. Disabled by default.
+
+    Returns:
+        dspy.Prediction: An object containing the retrieved passages.
+
+    Examples:
+        Below is a code snippet that shows how to use this retriever:
+        ```python
+        import dspy
+        from dspy.retriever.pinecone import Pinecone
+
+        # Create an Embedding instance
+        embedder = dspy.Embedding(embedding_model="openai/text-embedding-3-small")
+
+        retriever = Pinecone(
+            pinecone_index_name="",
+            pinecone_api_key="",
+            pinecone_env="",
+            embedder=embedder,
+            k=3
+        )
+
+        results = retriever(query).passages
+        print(results)
+        ```
+    """
+
+    def __init__(
+        self,
+        pinecone_index_name: str,
+        pinecone_api_key: Optional[str] = None,
+        pinecone_env: Optional[str] = None,
+        dimension: Optional[int] = None,
+        distance_metric: Optional[str] = None,
+        embedder: Embedding = None,
+        k: int = 5,
+        callbacks: Optional[List[Any]] = None,
+        cache: bool = False,
+    ):
+        if embedder is None or not isinstance(embedder, Embedding):
+            raise ValueError("An embedder of type `dspy.Embedding` must be provided.")
+        self.embedder = embedder
+        super().__init__(embedder=self.embedder, k=k, callbacks=callbacks, cache=cache)
+
+        self._pinecone_index = self._init_pinecone(
+            index_name=pinecone_index_name,
+            api_key=pinecone_api_key,
+            environment=pinecone_env,
+            dimension=dimension,
+            distance_metric=distance_metric,
+        )
+
+    def _init_pinecone(
+        self,
+        index_name: str,
+        api_key: Optional[str] = None,
+        environment: Optional[str] = None,
+        dimension: Optional[int] = None,
+        distance_metric: Optional[str] = None,
+    ) -> pinecone.Index:
+        """Initialize pinecone and return the loaded index.
+
+        Args:
+            index_name (str): The name of the index to load. If the index does not exist, it will be created.
+            api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
+            environment (str, optional): The environment (i.e. `us-west1-gcp` or `gcp-starter`). Defaults to env var PINECONE_ENVIRONMENT.
+
+        Raises:
+            ValueError: If api_key or environment is not provided and not set as an environment variable.
+
+        Returns:
+            pinecone.Index: The loaded index.
+        """
+
+        # Pinecone init overrides default if kwargs are present, so we need to exclude if None
+        kwargs = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if environment:
+            kwargs["environment"] = environment
+        pinecone.init(**kwargs)
+
+        active_indexes = pinecone.list_indexes()
+        if index_name not in active_indexes:
+            if dimension is None or distance_metric is None:
+                raise ValueError(
+                    "dimension and distance_metric must be provided since the index does not exist and needs to be created."
+                )
+
+            pinecone.create_index(
+                name=index_name,
+                dimension=dimension,
+                metric=distance_metric,
+            )
+
+        return pinecone.Index(index_name)
+
+    def forward(self, query: Union[str, List[str]], k: Optional[int] = None) -> Prediction:
+        """Search with Pinecone for top k passages for the query or queries.
+
+        Args:
+            query (Union[str, List[str]]): The query or list of queries to search for.
+            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
+
+ def forward(self, query: Union[str, List[str]], k: Optional[int] = None) -> Prediction:
+ """Search with Pinecone for top k passages for the query or queries.
+
+ Args:
+ query (Union[str, List[str]]): The query or list of queries to search for.
+ k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
+
+ Returns:
+ dspy.Prediction: An object containing the retrieved passages.
+ """
+ k = k or self.k
+ queries = [query] if isinstance(query, str) else query
+ queries = [q for q in queries if q]
+ embeddings = self.embedder(queries)
+ # For a single query, just look up the top k passages
+ if len(queries) == 1:
+ results_dict = self._pinecone_index.query(
+ embeddings[0], top_k=k, include_metadata=True,
+ )
+
+ # Sort matches by score (Pinecone returns the key as "score")
+ sorted_results = sorted(
+ results_dict["matches"], key=lambda x: x.get("score", 0.0), reverse=True,
+ )
+ passages = [result["metadata"]["text"] for result in sorted_results]
+ passages = [dotdict({"long_text": passage}) for passage in passages]
+ return Prediction(passages=passages)
+
+ # For multiple queries, query each and return the highest scoring passages.
+ # If a passage is returned multiple times, its score is accumulated; for this reason we request 3x as many matches.
+ passage_scores = {}
+ for embedding in embeddings:
+ results_dict = self._pinecone_index.query(
+ embedding, top_k=k * 3, include_metadata=True,
+ )
+ for result in results_dict["matches"]:
+ passage_scores[result["metadata"]["text"]] = (
+ passage_scores.get(result["metadata"]["text"], 0.0)
+ + result["score"]
+ )
+
+ sorted_passages = sorted(
+ passage_scores.items(), key=lambda x: x[1], reverse=True,
+ )[:k]
+ return Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
+
\ No newline at end of file
diff --git a/dspy/retriever/retriever.py b/dspy/retriever/retriever.py
new file mode 100644
index 000000000..088cb47d5
--- /dev/null
+++ b/dspy/retriever/retriever.py
@@ -0,0 +1,60 @@
+import os
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, List, Optional
+
+from diskcache import Cache
+
+from dspy.clients.embedding import Embedding
+from dspy.utils.callback import with_callbacks
+
+DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
+
+
+class Retriever(ABC):
+ def __init__(self, embedder: Optional[Embedding] = None, k: int = 5, callbacks: Optional[List[Any]] = None, cache: bool = False):
+ """
+ Interface for composing retrievers in DSPy to return relevant passages or documents based on a query.
+
+ Args:
+ embedder (Optional[Embedding]): An instance of `dspy.Embedding` used to compute embeddings
+ for queries and documents. If `None`, embedding functionality should be implemented
+ within the subclass. Defaults to `None`.
+ k (int): The default number of top passages to retrieve when not specified in the `forward` method. Defaults to `5`.
+ callbacks (Optional[List[Any]]): A list of callback functions to be called during retrieval.
+ cache (bool): Enable retrieval caching. Disabled by default.
+ """
+ self.embedder = embedder
+ self.k = k
+ self.callbacks = callbacks or []
+ self.cache_enabled = cache
+ self.cache = Cache(directory=DISK_CACHE_DIR) if self.cache_enabled else None
+
+ @abstractmethod
+ def forward(self, query: str, k: Optional[int] = None) -> Any:
+ """
+ Retrievers implement this method with their custom retrieval logic.
+ Must return an object that has a 'passages' attribute (ideally `dspy.Prediction`).
+ """
+ pass
+
+ @with_callbacks
+ def __call__(self, query: str, k: Optional[int] = None) -> Any:
+ """
+ Calls the forward method and checks that the result has a 'passages' attribute.
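+
+ Example (illustrative; assumes a concrete `Retriever` subclass such as the
+ `MyRetriever` sketch in `examples/retriever_migration.ipynb`):
+ embedder = dspy.Embedding(embedding_model="openai/text-embedding-3-small")
+ retriever = MyRetriever(embedder=embedder, k=5)
+ passages = retriever("my query", k=3).passages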
+ """ + k = k if k is not None else self.k + if self.cache_enabled and self.cache is not None: + cache_key = (query, k) + try: + result = self.cache[cache_key] + except KeyError: + result = self.forward(query, k) + self.cache[cache_key] = result + else: + result = self.forward(query, k) + if not hasattr(result, 'passages'): + raise ValueError( + "The 'forward' method must return an object with a 'passages' attribute (ideally `dspy.Prediction`)." + ) + return result diff --git a/dspy/retriever/weaviate.py b/dspy/retriever/weaviate.py new file mode 100644 index 000000000..1a3b4d938 --- /dev/null +++ b/dspy/retriever/weaviate.py @@ -0,0 +1,158 @@ +from typing import Any, List, Optional, Union + +from dspy.retriever import Retriever +from dspy.primitives.prediction import Prediction +from dsp.utils import dotdict + +try: + import weaviate + from weaviate.util import get_valid_uuid + from uuid import uuid4 +except ImportError as err: + raise ImportError( + "The 'weaviate' extra is required to use Weaviate. Install it with `pip install dspy-ai[weaviate]`", + ) from err + + +class Weaviate(Retriever): + """A retrieval module that uses Weaviate to return the top passages for a given query. + + Assumes that a Weaviate collection has been created and populated with the following payload: + - content: passage text + + Args: + weaviate_collection_name (str): Name of the Weaviate collection. + weaviate_client (Union[weaviate.WeaviateClient, weaviate.Client]): An instance of the Weaviate client. + weaviate_collection_text_key (Optional[str]): The key in the Weaviate collection where the passage text is stored. Defaults to "content". + k (int): Number of top passages to retrieve. Defaults to 5. + callbacks (Optional[List[Any]]): List of callback functions. + cache (bool): Enable retrieval caching. Disabled by default. + + Examples: + Below is a code snippet that shows how to use this retriever: + ```python + import weaviate + import dspy + from dspy.retriever.weaviate import Weaviate + + weaviate_client = weaviate.Client("http://localhost:8080") + retriever = Weaviate( + weaviate_collection_name="MyCollection", + weaviate_client=weaviate_client, + k=5 + ) + results = retriever("What are the stages in planning public works?").passages + print(results) + ``` + """ + + def __init__( + self, + weaviate_collection_name: str, + weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], + weaviate_collection_text_key: Optional[str] = "content", + k: int = 5, + callbacks: Optional[List[Any]] = None, + cache: bool = False, + ): + super().__init__(embedder=None, k=k, callbacks=callbacks, cache=cache) + self._weaviate_collection_name = weaviate_collection_name + self._weaviate_client = weaviate_client + self._weaviate_collection_text_key = weaviate_collection_text_key + + # Determine client type (Weaviate v3 or v4) + if hasattr(weaviate_client, "collections"): + self._client_type = "WeaviateClient" # Weaviate v4 + elif hasattr(weaviate_client, "query"): + self._client_type = "Client" # Weaviate v3 + else: + raise ValueError("Unsupported Weaviate client type") + + def forward(self, query: Union[str, List[str]], k: Optional[int] = None) -> Prediction: + """Search with Weaviate for the top k passages for the query or queries. + + Args: + query (Union[str, List[str]]): The query or list of queries to search for. + k (Optional[int]): The number of top passages to retrieve. Defaults to self.k. + + Returns: + dspy.Prediction: An object containing the retrieved passages. 
+ """ + k = k if k is not None else self.k + queries = [query] if isinstance(query, str) else query + queries = [q for q in queries if q] + passages = [] + + for q in queries: + if self._client_type == "WeaviateClient": + # For Weaviate v4 + results = ( + self._weaviate_client.query.get( + self._weaviate_collection_name, + [self._weaviate_collection_text_key], + ) + .with_hybrid(query=q) + .with_limit(k) + .do() + ) + + parsed_results = [ + result[self._weaviate_collection_text_key] + for result in results["data"]["Get"][self._weaviate_collection_name] + ] + + elif self._client_type == "Client": + # For Weaviate v3 + results = ( + self._weaviate_client.query.get( + self._weaviate_collection_name, + [self._weaviate_collection_text_key], + ) + .with_hybrid(query=q) + .with_limit(k) + .do() + ) + + parsed_results = [ + result[self._weaviate_collection_text_key] + for result in results["data"]["Get"][self._weaviate_collection_name] + ] + else: + raise ValueError("Unsupported Weaviate client type") + + passages.extend(dotdict({"long_text": d}) for d in parsed_results) + + return Prediction(passages=passages) + + def get_objects(self, num_samples: int, fields: List[str]) -> List[dict]: + """Get objects from Weaviate using the cursor API.""" + if self._client_type == "WeaviateClient": + objects = [] + counter = 0 + for item in self._weaviate_client.data_object.get(): + if counter >= num_samples: + break + new_object = { + key: item['properties'][key] + for key in item['properties'] + if key in fields + } + objects.append(new_object) + counter += 1 + return objects + else: + raise ValueError( + "`get_objects` is not supported for the v3 Weaviate Python client, please upgrade to v4." + ) + + def insert(self, new_object_properties: dict): + if self._client_type == "WeaviateClient": + self._weaviate_client.data_object.create( + data_object=new_object_properties, + class_name=self._weaviate_collection_name, + uuid=get_valid_uuid(uuid4()), + ) + else: + raise AttributeError( + "`insert` is not supported for the v3 Weaviate Python client, please upgrade to v4." + ) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index fdc48b083..e56dcba35 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -150,7 +150,8 @@ def _bootstrap(self, *, max_bootstraps=None): self.name2traces = {name: [] for name in self.name2predictor} for example_idx, example in enumerate(tqdm.tqdm(self.trainset)): - if len(bootstrapped) >= max_bootstraps: break + if len(bootstrapped) >= max_bootstraps: + break for round_idx in range(self.max_rounds): bootstrap_attempts += 1 diff --git a/examples/retriever_migration.ipynb b/examples/retriever_migration.ipynb new file mode 100644 index 000000000..81781a47b --- /dev/null +++ b/examples/retriever_migration.ipynb @@ -0,0 +1,137 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### dspy.Retriever Interface\n", + "This notebook demonstrates how to use the `dspy.Retriever` interface for custom retriever integrations.\n", + "\n", + "**Supported Integrations**: ColBERTv2, Databricks, FAISS, Milvus, Pinecone, Weaviate\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'text': 'David Gregory (physician) | David Gregory (20 December 1625 – 1720) was a Scottish physician and inventor. His surname is sometimes spelt as Gregorie, the original Scottish spelling. He inherited Kinnairdy Castle in 1664. 
Three of his twenty-nine children became mathematics professors. He is credited with inventing a military cannon that Isaac Newton described as \"being destructive to the human species\". Copies and details of the model no longer exist. Gregory\\'s use of a barometer to predict farming-related weather conditions led him to be accused of witchcraft by Presbyterian ministers from Aberdeen, although he was never convicted.', 'pid': 3296134, 'rank': 1, 'score': 23.595149993896484, 'prob': 0.9494133657446476, 'long_text': 'David Gregory (physician) | David Gregory (20 December 1625 – 1720) was a Scottish physician and inventor. His surname is sometimes spelt as Gregorie, the original Scottish spelling. He inherited Kinnairdy Castle in 1664. Three of his twenty-nine children became mathematics professors. He is credited with inventing a military cannon that Isaac Newton described as \"being destructive to the human species\". Copies and details of the model no longer exist. Gregory\\'s use of a barometer to predict farming-related weather conditions led him to be accused of witchcraft by Presbyterian ministers from Aberdeen, although he was never convicted.'}, {'text': 'David Gregory (mathematician) | David Gregory (originally spelt Gregorie) FRS (? 1659 – 10 October 1708) was a Scottish mathematician and astronomer. He was professor of mathematics at the University of Edinburgh, Savilian Professor of Astronomy at the University of Oxford, and a commentator on Isaac Newton\\'s \"Principia\".', 'pid': 1499187, 'rank': 2, 'score': 19.923200607299805, 'prob': 0.024140595827172887, 'long_text': 'David Gregory (mathematician) | David Gregory (originally spelt Gregorie) FRS (? 1659 – 10 October 1708) was a Scottish mathematician and astronomer. He was professor of mathematics at the University of Edinburgh, Savilian Professor of Astronomy at the University of Oxford, and a commentator on Isaac Newton\\'s \"Principia\".'}, {'text': 'David Gregory (historian) | David Gregory (1696–1767) was an English churchman and academic, Dean of Christ Church, Oxford and the first Regius Professor of Modern History at Oxford.', 'pid': 1943433, 'rank': 3, 'score': 19.197763442993164, 'prob': 0.011686773511174484, 'long_text': 'David Gregory (historian) | David Gregory (1696–1767) was an English churchman and academic, Dean of Christ Church, Oxford and the first Regius Professor of Modern History at Oxford.'}, {'text': 'David Gregory (Royal Navy officer) | Vice Admiral Sir George David Archibald Gregory & Bar (8 October 1909 – 21 March 1975) was a Royal Navy officer who became Flag Officer, Scotland and Northern Ireland.', 'pid': 4237441, 'rank': 4, 'score': 18.8070011138916, 'prob': 0.007906580631912762, 'long_text': 'David Gregory (Royal Navy officer) | Vice Admiral Sir George David Archibald Gregory & Bar (8 October 1909 – 21 March 1975) was a Royal Navy officer who became Flag Officer, Scotland and Northern Ireland.'}, {'text': 'David Gregory (footballer, born 1951) | David Harry Gregory (born 6 October 1951) is an English former footballer who played in the Football League for Blackburn Rovers, Bury, Peterborough United, Portsmouth, Stoke City and Wrexham.', 'pid': 4276505, 'rank': 5, 'score': 18.6639461517334, 'prob': 0.006852684285092242, 'long_text': 'David Gregory (footballer, born 1951) | David Harry Gregory (born 6 October 1951) is an English former footballer who played in the Football League for Blackburn Rovers, Bury, Peterborough United, Portsmouth, Stoke City and Wrexham.'}]\n" + ] + } + ], + "source": 
[
+ "import dspy\n",
+ "\n",
+ "\n",
+ "results = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(\"What's the name of the castle that David Gregory inherited?\", k=5).passages\n",
+ "print(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can create and configure custom retrievers by implementing this interface.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MyRetriever(dspy.Retriever):\n",
+ " def __init__(self, embedder: dspy.Embedding, k: int = 5, cache: bool = False):\n",
+ " # embedder: dspy.Embedding instance specifying the embedding model and/or embedding_function for any embedding computation\n",
+ " # k: number of top query results to return\n",
+ " # cache: enable query caching (disabled by default)\n",
+ " super().__init__(embedder=embedder, k=k, cache=cache)\n",
+ "\n",
+ " def forward(self, query, k=None):\n",
+ " k = k if k is not None else self.k\n",
+ " embeddings = self.embedder([query])\n",
+ "\n",
+ " # Include custom logic here if your retriever supports a custom client, index, vector store, etc.\n",
+ " results = [\n",
+ " {\"passage\": f\"Mock passage {i+1} for query '{query}'\", \"doc_id\": i, \"score\": 1.0 / (i + 1)}\n",
+ " for i in range(k)\n",
+ " ]\n",
+ " passages = [res[\"passage\"] for res in results]\n",
+ " doc_ids = [res[\"doc_id\"] for res in results]\n",
+ " scores = [res[\"score\"] for res in results]\n",
+ " return dspy.Prediction(passages=passages, doc_ids=doc_ids, scores=scores)\n",
+ "\n",
+ "embedder = dspy.Embedding(embedding_model=\"openai/text-embedding-3-small\")\n",
+ "retriever = MyRetriever(embedder=embedder, k=10, cache=True)\n",
+ "\n",
+ "result = retriever(\"What's the name of the castle that David Gregory inherited?\")\n",
+ "print(result.passages)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With this integration, you can easily configure retrievers in DSPy programs and pipelines.\n",
+ "\n",
+ "The example below is a multi-hop program that composes a retriever with other DSPy modules into a single system.\n",
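+ "\n",
+ "Each hop generates a search query from the accumulated context, retrieves `passages_per_hop` passages for it, and the final context feeds the answer generator (see the code below)."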
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class MultiHop(dspy.Module):\n",
+ " def __init__(self, retriever, passages_per_hop):\n",
+ " super().__init__()\n",
+ " self.retriever = retriever\n",
+ " self.passages_per_hop = passages_per_hop\n",
+ " self.generate_query = dspy.ChainOfThought(\"context, question -> search_query\")\n",
+ " self.generate_answer = dspy.ChainOfThought(\"context, question -> answer\")\n",
+ "\n",
+ " def forward(self, question):\n",
+ " context = []\n",
+ " for hop in range(2):\n",
+ " query = self.generate_query(context=context, question=question).search_query\n",
+ " context += self.retriever(query, k=self.passages_per_hop).passages\n",
+ " return dspy.Prediction(\n",
+ " context=context,\n",
+ " answer=self.generate_answer(context=context, question=question).answer,\n",
+ " )\n",
+ "\n",
+ "embedder = dspy.Embedding(embedding_model=\"openai/text-embedding-3-small\")\n",
+ "\n",
+ "multihop = MultiHop(retriever=MyRetriever(embedder=embedder, cache=True), passages_per_hop=3)\n",
+ "result = multihop(question=\"What's the name of the castle that David Gregory inherited?\")\n",
+ "\n",
+ "print(result.context)\n",
+ "print(result.answer)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "NEW_DSPY",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/utils/test_streaming.py b/tests/utils/test_streaming.py
index dbaa8be66..819699383 100644
--- a/tests/utils/test_streaming.py
+++ b/tests/utils/test_streaming.py
@@ -5,7 +5,7 @@
 from tests.test_utils.server import litellm_test_server
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_streamify_yields_expected_response_chunks(litellm_test_server):
 api_base, _ = litellm_test_server
 lm = dspy.LM(
@@ -37,7 +37,7 @@ class TestSignature(dspy.Signature):
 assert last_chunk2.output_text == "Hello!"
 
-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_streaming_response_yields_expected_response_chunks(litellm_test_server):
 api_base, _ = litellm_test_server
 lm = dspy.LM(