Merge pull request #127 from Portkey-AI/feat/langchainCallbackHandler
LangChain and LlamaIndex callback handlers
Showing 10 changed files with 530 additions and 4 deletions.
portkey_ai/api_resources/apis/logger.py (33 additions, 0 deletions)
@@ -0,0 +1,33 @@
import json
import os
from typing import Optional
import requests

from portkey_ai.api_resources.global_constants import PORTKEY_BASE_URL


class Logger:
    def __init__(
        self,
        api_key: Optional[str] = None,
    ) -> None:
        api_key = api_key or os.getenv("PORTKEY_API_KEY")
        if api_key is None:
            raise ValueError("API key is required to use the Logger API")

        self.headers = {
            "Content-Type": "application/json",
            "x-portkey-api-key": api_key,
        }

        self.url = PORTKEY_BASE_URL + "/logs"

    def log(
        self,
        log_object: dict,
    ):
        response = requests.post(
            url=self.url, data=json.dumps(log_object), headers=self.headers
        )

        return response
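The Logger is a thin wrapper around a single POST to Portkey's /logs endpoint, authenticated via the x-portkey-api-key header. A minimal usage sketch follows; the shape of log_object is inferred from the callback handlers later in this diff, so treat the exact schema as an assumption rather than a documented contract:

from portkey_ai.api_resources.apis.logger import Logger

# "PORTKEY_API_KEY" below is a placeholder; the env var of the same name also works.
logger = Logger(api_key="PORTKEY_API_KEY")

# Hypothetical log entry mirroring the request/response pairs the handlers build.
log_object = {
    "request": {
        "method": "POST",
        "url": "chat/completions",
        "headers": {},
        "body": {"messages": [{"role": "user", "content": "Hello"}]},
    },
    "response": {
        "status": 200,
        "body": {"choices": []},
        "time": 120,
        "headers": {},
        "streamingMode": False,
    },
}

res = logger.log(log_object=log_object)  # returns the raw requests.Response
print(res.status_code)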
portkey_ai/llms/langchain/__init__.py
@@ -1,4 +1,5 @@
from .chat import ChatPortkey
from .completion import PortkeyLLM
from .portkey_langchain_callback import PortkeyLangchain

__all__ = ["ChatPortkey", "PortkeyLLM"]
__all__ = ["ChatPortkey", "PortkeyLLM", "PortkeyLangchain"]
portkey_ai/llms/langchain/portkey_langchain_callback.py (170 additions, 0 deletions)
@@ -0,0 +1,170 @@
from datetime import datetime
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger

try:
    from langchain_core.callbacks import BaseCallbackHandler
except ImportError:
    raise ImportError("Please pip install langchain-core to use PortkeyLangchain")


class PortkeyLangchain(BaseCallbackHandler):
    def __init__(
        self,
        api_key: str,
    ) -> None:
        super().__init__()
        self.startTimestamp: float = 0
        self.endTimestamp: float = 0

        self.api_key = api_key

        self.portkey_logger = Logger(api_key=api_key)

        self.log_object: Dict[str, Any] = {}
        self.prompt_records: Any = []

        self.request: Any = {}
        self.response: Any = {}

        # self.responseHeaders: Dict[str, Any] = {}
        self.responseBody: Any = None
        self.responseStatus: int = 0

        self.streamingMode: bool = False

        if not api_key:
            raise ValueError("Please provide an API key to use PortkeyCallbackHandler")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # Reconstruct chat messages from the flattened "role: content" prompt lines.
        for prompt in prompts:
            messages = prompt.split("\n")
            for message in messages:
                if ":" not in message:
                    continue  # skip lines that are not in "role: content" form
                role, content = message.split(":", 1)
                self.prompt_records.append(
                    {"role": role.lower(), "content": content.strip()}
                )

        self.startTimestamp = float(datetime.now().timestamp())

        # Default to {} so a missing "invocation_params" cannot crash the .get chain.
        self.streamingMode = kwargs.get("invocation_params", {}).get("stream", False)

        self.request["method"] = "POST"
        self.request["url"] = serialized.get("kwargs", {}).get(
            "base_url", "chat/completions"
        )
        self.request["provider"] = serialized["id"][2]
        self.request["headers"] = serialized.get("kwargs", {}).get(
            "default_headers", {}
        )
        self.request["headers"].update({"provider": serialized["id"][2]})
        self.request["body"] = {"messages": self.prompt_records}
        self.request["body"].update({**kwargs.get("invocation_params", {})})

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """Run when chain starts running."""

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        self.endTimestamp = float(datetime.now().timestamp())
        responseTime = self.endTimestamp - self.startTimestamp

        usage = (response.llm_output or {}).get("token_usage", "")  # type: ignore[union-attr]

        self.response["status"] = (
            200 if self.responseStatus == 0 else self.responseStatus
        )
        self.response["body"] = {
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response.generations[0][0].text,
                    },
                    "logprobs": response.generations[0][0].generation_info.get("logprobs", ""),  # type: ignore[union-attr] # noqa: E501
                    "finish_reason": response.generations[0][0].generation_info.get("finish_reason", ""),  # type: ignore[union-attr] # noqa: E501
                }
            ]
        }
        self.response["body"].update({"usage": usage})
        self.response["body"].update({"id": str(kwargs.get("run_id", ""))})
        self.response["body"].update({"created": int(time.time())})
        self.response["body"].update({"model": (response.llm_output or {}).get("model_name", "")})  # type: ignore[union-attr] # noqa: E501
        self.response["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")})  # type: ignore[union-attr] # noqa: E501
        self.response["time"] = int(responseTime * 1000)
        self.response["headers"] = {}
        self.response["streamingMode"] = self.streamingMode

        self.log_object.update(
            {
                "request": self.request,
                "response": self.response,
            }
        )

        self.portkey_logger.log(log_object=self.log_object)

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """Run when chain ends running."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        self.responseBody = error
        self.responseStatus = error.status_code  # type: ignore[attr-defined]

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        self.responseBody = error
        self.responseStatus = error.status_code  # type: ignore[attr-defined]

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        self.responseBody = error
        self.responseStatus = error.status_code  # type: ignore[attr-defined]

    def on_text(self, text: str, **kwargs: Any) -> None:
        pass

    def on_agent_finish(self, finish: Any, **kwargs: Any) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.streamingMode = True

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        pass

    def on_agent_action(self, action: Any, **kwargs: Any) -> Any:
        pass

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        pass
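A minimal wiring sketch for the handler above. ChatOpenAI and the langchain-openai package are assumptions for illustration, not part of this PR; any LangChain model that accepts callbacks should behave the same way:

# Assumes `pip install langchain-openai` and OPENAI_API_KEY in the environment.
from langchain_openai import ChatOpenAI
from portkey_ai.llms.langchain import PortkeyLangchain

handler = PortkeyLangchain(api_key="PORTKEY_API_KEY")  # placeholder key
llm = ChatOpenAI(callbacks=[handler])

# on_llm_start captures the outgoing request, on_llm_end the response, and the
# pair is POSTed to Portkey's /logs endpoint by the Logger shown earlier.
print(llm.invoke("Tell me a joke").content)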
portkey_ai/llms/llama_index/__init__.py
@@ -1,3 +1,3 @@
from .completions import PortkeyLLM
from .portkey_llama_callback import PortkeyLlamaindex

__all__ = ["PortkeyLLM"]
__all__ = ["PortkeyLlamaindex"]
portkey_ai/llms/llama_index/portkey_llama_callback.py (160 additions, 0 deletions)
@@ -0,0 +1,160 @@
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger
from datetime import datetime

try:
    from llama_index.core.callbacks.base_handler import (
        BaseCallbackHandler as LlamaIndexBaseCallbackHandler,
    )
    from llama_index.core.utilities.token_counting import TokenCounter
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "Please install llama-index to use Portkey Callback Handler"
    )
except ImportError:
    raise ImportError("Please pip install llama-index to use Portkey Callback Handler")


class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler):
    startTimestamp: int = 0
    endTimestamp: float = 0

    def __init__(
        self,
        api_key: str,
    ) -> None:
        super().__init__(
            event_starts_to_ignore=[],
            event_ends_to_ignore=[],
        )

        self.api_key = api_key

        self.portkey_logger = Logger(api_key=api_key)

        # Token counts are estimated locally, since the event payload does not
        # always carry provider usage numbers.
        self._token_counter = TokenCounter()
        self.completion_tokens = 0
        self.prompt_tokens = 0
        self.token_llm = 0

        self.log_object: Dict[str, Any] = {}
        self.prompt_records: Any = []

        self.request: Any = {}
        self.response: Any = {}

        self.responseTime: int = 0
        self.streamingMode: bool = False

        if not api_key:
            raise ValueError("Please provide an API key to use PortkeyCallbackHandler")

    def on_event_start(  # type: ignore[return]
        self,
        event_type: Any,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Run when an event starts and return id of event."""
        if event_type == "llm":
            self.llm_event_start(payload)

    def on_event_end(
        self,
        event_type: Any,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Run when an event ends."""
        if event_type == "llm":
            self.llm_event_stop(payload, event_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""
        self.startTimestamp = int(datetime.now().timestamp())

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""

    def llm_event_start(self, payload: Any) -> None:
        if "messages" in payload:
            messages = payload.get("messages", {})
            self.prompt_tokens = self._token_counter.estimate_tokens_in_messages(
                messages
            )
            self.prompt_records = [
                {"role": m.role.value, "content": m.content} for m in messages
            ]
        self.request["method"] = "POST"
        self.request["url"] = payload.get("serialized", {}).get(
            "api_base", "chat/completions"
        )
        self.request["provider"] = payload.get("serialized", {}).get("class_name", "")
        self.request["headers"] = {}
        self.request["body"] = {"messages": self.prompt_records}
        self.request["body"].update(
            {"model": payload.get("serialized", {}).get("model", "")}
        )
        self.request["body"].update(
            {"temperature": payload.get("serialized", {}).get("temperature", "")}
        )

        return None

    def llm_event_stop(self, payload: Any, event_id: str) -> None:
        self.endTimestamp = float(datetime.now().timestamp())
        responseTime = self.endTimestamp - self.startTimestamp

        data = payload.get("response", {})

        chunks = payload.get("messages", {})
        self.completion_tokens = self._token_counter.estimate_tokens_in_messages(chunks)
        self.token_llm = self.prompt_tokens + self.completion_tokens
        self.response["status"] = 200
        self.response["body"] = {
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": data.message.role.value,
                        "content": data.message.content,
                    },
                    "logprobs": data.logprobs,
                    "finish_reason": "done",
                }
            ]
        }
        self.response["body"].update(
            {
                "usage": {
                    "prompt_tokens": self.prompt_tokens,
                    "completion_tokens": self.completion_tokens,
                    "total_tokens": self.token_llm,
                }
            }
        )
        self.response["body"].update({"id": event_id})
        self.response["body"].update({"created": int(time.time())})
        self.response["body"].update({"model": data.raw.get("model", "")})
        self.response["time"] = int(responseTime * 1000)
        self.response["headers"] = {}
        self.response["streamingMode"] = self.streamingMode

        self.log_object.update(
            {
                "request": self.request,
                "response": self.response,
            }
        )
        self.portkey_logger.log(log_object=self.log_object)

        return None
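A sketch of registering the handler globally through LlamaIndex's callback manager. The Settings/CallbackManager wiring is standard llama-index usage rather than part of this diff, and the import path repeats the directory assumption noted above:

# Assumes `pip install llama-index` and a configured LLM.
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager
from portkey_ai.llms.llama_index import PortkeyLlamaindex

handler = PortkeyLlamaindex(api_key="PORTKEY_API_KEY")  # placeholder key
Settings.callback_manager = CallbackManager([handler])

# From here on, every "llm" event start/stop pair that llama-index emits is
# converted into a request/response log and sent to Portkey's /logs endpoint.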