diff --git a/langchain/chains/base.py b/langchain/chains/base.py
index b10a87dc49305..2db63a8fef9e7 100644
--- a/langchain/chains/base.py
+++ b/langchain/chains/base.py
@@ -18,7 +18,7 @@
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory, RunInfo
 
 
 def _get_verbosity() -> bool:
@@ -108,6 +108,8 @@ def __call__(
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.
 
@@ -118,7 +120,10 @@ def __call__(
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
-
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = CallbackManager.configure(
@@ -139,13 +144,20 @@ def __call__(
             run_manager.on_chain_error(e)
             raise e
         run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs
 
     async def acall(
         self,
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.
 
@@ -156,7 +168,10 @@ async def acall(
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
-
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = AsyncCallbackManager.configure(
@@ -177,7 +192,12 @@ async def acall(
             await run_manager.on_chain_error(e)
             raise e
         await run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs
 
     def prep_outputs(
         self,
diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index de2cdd06c0406..dcb4ebebcfa55 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -25,6 +25,7 @@
     HumanMessage,
     LLMResult,
     PromptValue,
+    RunInfo,
 )
 
 
@@ -93,6 +94,8 @@ def generate(
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output
 
     async def agenerate(
@@ -131,6 +134,8 @@ async def agenerate(
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         await run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output
 
     def generate_prompt(
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 21267a96afbcf..84ba2c5c86d86 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -25,6 +25,7 @@
     Generation,
     LLMResult,
     PromptValue,
+    RunInfo,
     get_buffer_string,
 )
 
@@ -190,6 +191,8 @@ def generate(
                 run_manager.on_llm_error(e)
                 raise e
             run_manager.on_llm_end(output)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = callback_manager.on_llm_start(
@@ -210,10 +213,14 @@ def generate(
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
 
     async def agenerate(
         self,
@@ -256,6 +263,8 @@ async def agenerate(
                 await run_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             await run_manager.on_llm_end(output, verbose=self.verbose)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = await callback_manager.on_llm_start(
@@ -278,10 +287,14 @@ async def agenerate(
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
 
     def __call__(
         self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
diff --git a/langchain/schema.py b/langchain/schema.py
index 4a04bd04c5f68..b74b40a7c5e42 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -13,9 +13,12 @@
     TypeVar,
     Union,
 )
+from uuid import UUID
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
+RUN_KEY = "__run"
+
 
 def get_buffer_string(
     messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -156,6 +159,12 @@ def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         return values
 
 
+class RunInfo(BaseModel):
+    """Class that contains all relevant metadata for a Run."""
+
+    run_id: UUID
+
+
 class ChatResult(BaseModel):
     """Class that contains all relevant information for a Chat Result."""
 
@@ -173,6 +182,16 @@ class LLMResult(BaseModel):
     each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
+    run: Optional[RunInfo] = None
+    """Run metadata."""
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, LLMResult):
+            return NotImplemented
+        return (
+            self.generations == other.generations
+            and self.llm_output == other.llm_output
+        )
 
 
 class PromptValue(BaseModel, ABC):
diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py
index 1e5022b89c40e..d60e06a8debf1 100644
--- a/tests/unit_tests/chains/test_base.py
+++ b/tests/unit_tests/chains/test_base.py
@@ -5,7 +5,7 @@
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
@@ -72,6 +72,15 @@ def test_bad_outputs() -> None:
         chain({"foo": "baz"})
 
 
+def test_run_info() -> None:
+    """Test that run_info is returned properly when specified"""
+    chain = FakeChain()
+    output = chain({"foo": "bar"}, include_run_info=True)
+    assert "foo" in output
+    assert "bar" in output
+    assert RUN_KEY in output
+
+
 def test_correct_call() -> None:
     """Test correct call of fake chain."""
     chain = FakeChain()
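Taken together, these changes let callers recover the callback run ID: Chain.__call__ and acall attach a RunInfo object under the RUN_KEY ("__run") output key when include_run_info=True, and the LLM/chat-model generate paths populate the new LLMResult.run field. Below is a minimal usage sketch mirroring the new test_run_info test; the EchoChain class, its input/output keys, and the example question are hypothetical stand-ins for any Chain subclass, not part of this change.

    from typing import Any, Dict, List, Optional

    from langchain.callbacks.manager import CallbackManagerForChainRun
    from langchain.chains.base import Chain
    from langchain.schema import RUN_KEY


    class EchoChain(Chain):
        """Hypothetical toy chain that copies its input to its output."""

        @property
        def input_keys(self) -> List[str]:
            return ["question"]

        @property
        def output_keys(self) -> List[str]:
            return ["answer"]

        def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
        ) -> Dict[str, str]:
            # Echo the input back; a real chain would call an LLM here.
            return {"answer": inputs["question"]}


    chain = EchoChain()
    output = chain({"question": "hello"}, include_run_info=True)
    print(output["answer"])        # -> "hello"
    print(output[RUN_KEY].run_id)  # -> UUID of this chain run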