-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathMessage.py
113 lines (91 loc) · 3.13 KB
/
Message.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Standard library
from typing import Any, List, Union

# Third party
import json_repair

# Local
from saplings.dtos.ToolCall import ToolCall
class Message(object):
    """A single chat message exchanged with the model.

    Holds the speaker ``role`` plus whichever payload that role carries:
    plain text ``content``, a list of ``ToolCall`` objects (assistant turns
    that request tools), or a ``tool_call_id`` / ``raw_output`` pair (tool
    result turns). Role-specific factory classmethods are provided, and
    ``from_openai_message`` / ``to_openai_message`` convert to and from the
    OpenAI chat-completions message shape.
    """

    def __init__(
        self,
        role: str,
        content: Union[None, str] = None,
        # String forward reference: the signature does not need ToolCall
        # resolved at class-definition time.
        tool_calls: Union[None, List["ToolCall"]] = None,
        tool_call_id: Union[None, str] = None,
        raw_output: Any = None,
    ):
        """Create a message.

        Args:
            role: Speaker role, e.g. "system", "user", "assistant" or "tool".
            content: Text payload; None for pure tool-call messages.
            tool_calls: ToolCall objects requested by an assistant turn.
            tool_call_id: For role "tool", the id of the call this answers
                (OpenAI issues these ids as strings).
            raw_output: For role "tool", the unformatted tool output.
        """
        self.role = role
        self.content = content
        # NOTE: this instance attribute shadows the ``tool_calls`` factory
        # classmethod on instances; call the factory on the class instead.
        self.tool_calls = tool_calls
        self.tool_call_id = tool_call_id
        # For tool messages (unformatted tool call output)
        self.raw_output = raw_output

    @classmethod
    def system(cls, content):
        """Return a system-role message with the given text."""
        return cls("system", content)

    @classmethod
    def user(cls, content):
        """Return a user-role message with the given text."""
        return cls("user", content)

    @classmethod
    def assistant(cls, content):
        """Return an assistant-role message with the given text."""
        return cls("assistant", content)

    @classmethod
    def tool_calls(cls, tool_calls):
        """Return an assistant message that carries only tool calls."""
        return cls(role="assistant", tool_calls=tool_calls)

    @classmethod
    def tool(cls, content, id, raw_output=None):
        """Return a tool-result message.

        Args:
            content: Formatted tool output shown to the model.
            id: Id of the tool call this result answers.
            raw_output: Optional unformatted tool output.
        """
        return cls(
            role="tool",
            content=content,
            tool_call_id=id,
            raw_output=raw_output,
        )

    @classmethod
    def from_openai_message(cls, message):
        """Build a Message from an OpenAI chat-completion message object.

        Reads ``message.role`` and ``message.content``. If the object also
        has a non-empty ``tool_calls`` attribute, each entry's ``id``,
        ``function.name`` and ``function.arguments`` are converted into a
        ``ToolCall``. The arguments string is parsed with
        ``json_repair.loads``.

        NOTE(review): when tool calls are present the result is a pure
        tool-call assistant message; the original ``content`` and ``role``
        are not carried over. Confirm this is intended.
        """
        role = message.role
        content = message.content
        if getattr(message, "tool_calls", None):
            parsed_calls = [
                ToolCall(
                    raw_call.id,
                    raw_call.function.name,
                    json_repair.loads(raw_call.function.arguments),
                )
                for raw_call in message.tool_calls
            ]
            return cls.tool_calls(parsed_calls)
        return cls(role, content=content)

    def to_openai_message(self):
        """Serialize to an OpenAI chat-completions message dict.

        Always includes ``role`` and ``content`` (``content`` may be None).
        Adds ``tool_call_id`` for tool messages. Adds ``tool_calls``, built
        with each call's ``to_dict()``, when this message carries tool calls.
        """
        message = {"role": self.role, "content": self.content}
        if self.role == "tool":
            message["tool_call_id"] = self.tool_call_id
        if self.tool_calls:
            message["tool_calls"] = [tc.to_dict() for tc in self.tool_calls]
        return message

    def __repr__(self):
        # Must return the string: returning None makes repr() raise TypeError.
        return self.__str__()

    def __str__(self):
        """Return an ANSI-coloured, human-readable summary of the message.

        Tool, user and assistant messages render as a bold label followed by
        the quoted content in grey. An assistant message with tool calls
        renders one ``TOOL CALL:`` line per call instead. Any other role
        (e.g. "system") renders as an empty string.
        """
        bold = "\033[1m"
        grey = "\033[37m"
        reset = "\033[0m"
        if self.role == "tool":
            return f'{bold}TOOL OUTPUT:{reset} {grey}"{self.content}"{reset}'
        if self.role == "user":
            return f'{bold}USER INPUT:{reset} {grey}"{self.content}"{reset}'
        if self.role == "assistant":
            if self.tool_calls:
                return "\n".join(
                    f"{bold}TOOL CALL:{reset} {grey}{str(tool_call)}{reset}"
                    for tool_call in self.tool_calls
                )
            return f'{bold}ASSISTANT OUTPUT:{reset} {grey}"{self.content}"{reset}'
        return ""

    def __hash__(self):
        """Hash by role, content and the hashes of any tool calls.

        NOTE(review): no ``__eq__`` is defined, so equality remains object
        identity; ``tool_call_id`` and ``raw_output`` are not part of the hash.
        """
        tool_call_hashes = tuple(hash(tc) for tc in self.tool_calls or [])
        # Messages without tool calls use "" in the tool-call slot.
        return hash((self.role, self.content, tool_call_hashes or ""))