split up batch llm calls into separate runs (#5804)
agola11 authored Jun 25, 2023
1 parent 1da99ce commit e1b801b
Showing 14 changed files with 401 additions and 293 deletions.
212 changes: 119 additions & 93 deletions langchain/callbacks/manager.py
@@ -672,66 +672,72 @@ def on_llm_start(
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
-        run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-
-        _handle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        return CallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        managers = []
+        for prompt in prompts:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_llm_start",
+                "ignore_llm",
+                serialized,
+                [prompt],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                tags=self.tags,
+                **kwargs,
+            )
+
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        return managers

     def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
-        run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-        _handle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        # Re-use the LLM Run Manager since the outputs are treated
-        # the same for now
-        return CallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        managers = []
+        for message_list in messages:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_chat_model_start",
+                "ignore_chat_model",
+                serialized,
+                [message_list],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                tags=self.tags,
+                **kwargs,
+            )
+
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        return managers

     def on_chain_start(
         self,
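With this change, on_llm_start and on_chat_model_start emit a separate start event and return a separate CallbackManagerForLLMRun for each prompt (or message list) in the batch, instead of one manager for the whole batch. A minimal caller-side sketch of the new contract; the generate_one helper and the serialized payload here are illustrative assumptions, not part of this commit:

    from typing import List

    from langchain.callbacks.manager import CallbackManager
    from langchain.schema import Generation, LLMResult

    def generate_one(prompt: str) -> str:
        """Stand-in for a real single-prompt model call."""
        return "stub completion for: " + prompt

    def generate_batch(callback_manager: CallbackManager, prompts: List[str]) -> None:
        # One start event, one run_id, and one run manager per prompt.
        run_managers = callback_manager.on_llm_start({"name": "my_llm"}, prompts)
        for prompt, run_manager in zip(prompts, run_managers):
            try:
                text = generate_one(prompt)
            except Exception as e:
                # Errors are now reported against the failing prompt's run only.
                run_manager.on_llm_error(e)
                raise
            run_manager.on_llm_end(LLMResult(generations=[[Generation(text=text)]]))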
@@ -830,64 +836,84 @@ async def on_llm_start(
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
-        run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> AsyncCallbackManagerForLLMRun:
+    ) -> List[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-
-        await _ahandle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        return AsyncCallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        tasks = []
+        managers = []
+
+        for prompt in prompts:
+            run_id_ = uuid4()
+
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_llm_start",
+                    "ignore_llm",
+                    serialized,
+                    [prompt],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    **kwargs,
+                )
+            )
+
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        await asyncio.gather(*tasks)
+
+        return managers

     async def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
-        run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        if run_id is None:
-            run_id = uuid4()
-
-        await _ahandle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        return AsyncCallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        tasks = []
+        managers = []
+
+        for message_list in messages:
+            run_id_ = uuid4()
+
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_chat_model_start",
+                    "ignore_chat_model",
+                    serialized,
+                    [message_list],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    **kwargs,
+                )
+            )
+
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        await asyncio.gather(*tasks)
+        return managers

     async def on_chain_start(
         self,
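The async variants follow the same per-prompt pattern but collect the start-event coroutines and dispatch them concurrently with asyncio.gather before returning the managers in prompt order. A comparable sketch for async callers, again with an illustrative stand-in for the model call:

    import asyncio
    from typing import List

    from langchain.callbacks.manager import AsyncCallbackManager
    from langchain.schema import Generation, LLMResult

    async def agenerate_one(prompt: str) -> str:
        """Stand-in for a real single-prompt model call."""
        return "stub completion for: " + prompt

    async def agenerate_batch(callback_manager: AsyncCallbackManager, prompts: List[str]) -> None:
        # Start events for all prompts are emitted concurrently.
        run_managers = await callback_manager.on_llm_start({"name": "my_llm"}, prompts)

        async def run_one(prompt: str, run_manager) -> None:
            try:
                text = await agenerate_one(prompt)
            except Exception as e:
                await run_manager.on_llm_error(e)
                raise
            await run_manager.on_llm_end(LLMResult(generations=[[Generation(text=text)]]))

        # Each prompt's generation and end/error events proceed as an independent run.
        await asyncio.gather(*(run_one(p, m) for p, m in zip(prompts, run_managers)))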
62 changes: 2 additions & 60 deletions langchain/callbacks/openai_info.py
@@ -1,8 +1,8 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult

MODEL_COST_PER_1K_TOKENS = {
# GPT-4 input
@@ -152,64 +152,6 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         self.prompt_tokens += prompt_tokens
         self.completion_tokens += completion_tokens

-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        """Print out the log in specified color."""
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """If not the final action, print out observation."""
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-        """Run on agent action."""
-        pass
-
-    def on_agent_finish(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        """Run on agent end."""
-        pass
-
     def __copy__(self) -> "OpenAICallbackHandler":
         """Return a copy of the callback handler."""
         return self
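The methods removed above were all no-op overrides of hooks that BaseCallbackHandler already defines, so the handler's observable behavior is unchanged: it still counts tokens only in on_llm_end. A usage sketch under that assumption (model name and prompts illustrative); with the per-prompt runs above, a batched generate should trigger on_llm_end once per run, and the shared handler keeps summing usage across them:

    from langchain.callbacks import get_openai_callback
    from langchain.llms import OpenAI

    llm = OpenAI(model_name="text-davinci-003")
    with get_openai_callback() as cb:
        # The two prompts now become two separate runs, but the shared
        # handler still accumulates token usage and cost across both.
        llm.generate(["Tell me a joke", "Tell me a poem"])
    print(cb.total_tokens, cb.total_cost)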