support returning run info for llms, chat models and chains (#5666)
Returning the run ID is important for accessing the run later on.
agola11 authored Jun 6, 2023
1 parent 65111eb commit b177a29
Showing 5 changed files with 74 additions and 8 deletions.
langchain/chains/base.py (30 changes: 25 additions & 5 deletions)
@@ -18,7 +18,7 @@
CallbackManagerForChainRun,
Callbacks,
)
from langchain.schema import BaseMemory
from langchain.schema import RUN_KEY, BaseMemory, RunInfo


def _get_verbosity() -> bool:
@@ -108,6 +108,8 @@ def __call__(
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -118,7 +120,10 @@
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
@@ -139,13 +144,20 @@
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs

async def acall(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -156,7 +168,10 @@ async def acall(
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
@@ -177,7 +192,12 @@ async def acall(
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs

def prep_outputs(
self,
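
To illustrate the new keyword-only flag on Chain.__call__, here is a minimal usage sketch. EchoChain is a toy subclass invented purely for this example; only include_run_info and RUN_KEY come from this commit.

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY


class EchoChain(Chain):
    """Toy chain used only to demonstrate include_run_info."""

    @property
    def input_keys(self) -> List[str]:
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        return ["bar"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        return {"bar": inputs["foo"] + "!"}


output = EchoChain()({"foo": "hello"}, include_run_info=True)
print(output["bar"])           # "hello!"
print(output[RUN_KEY].run_id)  # UUID of this run, usable for looking the run up later

Passing include_run_info=True adds the RunInfo under the reserved "__run" key instead of changing the return type, so existing callers that ignore the flag are unaffected.
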
langchain/chat_models/base.py (5 changes: 5 additions & 0 deletions)
@@ -25,6 +25,7 @@
HumanMessage,
LLMResult,
PromptValue,
RunInfo,
)


@@ -93,6 +94,8 @@ def generate(
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output

async def agenerate(
@@ -131,6 +134,8 @@ async def agenerate(
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output

def generate_prompt(
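
The same run metadata is now attached to the LLMResult returned by chat model generate calls. A hedged sketch; ChatOpenAI is used only as a familiar example (it needs the openai package and an OPENAI_API_KEY to actually run), and any BaseChatModel subclass should behave the same way:

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI()  # example model only; requires OPENAI_API_KEY
result = chat.generate([[HumanMessage(content="Say hi")]])
print(result.run.run_id)  # RunInfo attached by the callback manager
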
langchain/llms/base.py (17 changes: 15 additions & 2 deletions)
@@ -25,6 +25,7 @@
Generation,
LLMResult,
PromptValue,
RunInfo,
get_buffer_string,
)

@@ -190,6 +191,8 @@ def generate(
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output
if len(missing_prompts) > 0:
run_manager = callback_manager.on_llm_start(
@@ -210,10 +213,14 @@
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
else:
llm_output = {}
run_info = None
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output)
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

async def agenerate(
self,
@@ -256,6 +263,8 @@ async def agenerate(
await run_manager.on_llm_error(e, verbose=self.verbose)
raise e
await run_manager.on_llm_end(output, verbose=self.verbose)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output
if len(missing_prompts) > 0:
run_manager = await callback_manager.on_llm_start(
@@ -278,10 +287,14 @@
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
else:
llm_output = {}
run_info = None
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output)
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

def __call__(
self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
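
For plain LLMs, generate() now sets run info on the non-cached path and passes it through the LLMResult constructor when some prompts miss the cache; a fully cached call starts no callback run, so run stays None. A minimal sketch, assuming the FakeListLLM test helper available in this version:

from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["pong"])

result = llm.generate(["ping"])  # no cache configured, so the fresh-generation path runs
print(result.run.run_id)         # set from the callback manager's run right after on_llm_end
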
langchain/schema.py (19 changes: 19 additions & 0 deletions)
@@ -13,9 +13,12 @@
TypeVar,
Union,
)
from uuid import UUID

from pydantic import BaseModel, Extra, Field, root_validator

RUN_KEY = "__run"


def get_buffer_string(
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -156,6 +159,12 @@ def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values


class RunInfo(BaseModel):
"""Class that contains all relevant metadata for a Run."""

run_id: UUID


class ChatResult(BaseModel):
"""Class that contains all relevant information for a Chat Result."""

@@ -173,6 +182,16 @@ class LLMResult(BaseModel):
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
run: Optional[RunInfo] = None
"""Run metadata."""

def __eq__(self, other: object) -> bool:
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)


class PromptValue(BaseModel, ABC):
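
The schema additions can be exercised on their own. A small sketch showing RunInfo and the new __eq__, which deliberately ignores run metadata so that otherwise identical results still compare equal:

from uuid import uuid4

from langchain.schema import Generation, LLMResult, RunInfo

a = LLMResult(generations=[[Generation(text="hi")]], run=RunInfo(run_id=uuid4()))
b = LLMResult(generations=[[Generation(text="hi")]], run=RunInfo(run_id=uuid4()))

assert a == b          # run IDs differ, but __eq__ only compares generations and llm_output
assert a.run != b.run  # the per-call metadata itself is still distinct
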
tests/unit_tests/chains/test_base.py (11 changes: 10 additions & 1 deletion)
@@ -5,7 +5,7 @@

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from langchain.schema import RUN_KEY, BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@@ -72,6 +72,15 @@ def test_bad_outputs() -> None:
chain({"foo": "baz"})


def test_run_info() -> None:
"""Test that run_info is returned properly when specified"""
chain = FakeChain()
output = chain({"foo": "bar"}, include_run_info=True)
assert "foo" in output
assert "bar" in output
assert RUN_KEY in output


def test_correct_call() -> None:
"""Test correct call of fake chain."""
chain = FakeChain()
