-
Notifications
You must be signed in to change notification settings - Fork 24
/
messageHistory.m
303 lines (273 loc) · 12.4 KB
/
messageHistory.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
classdef (Sealed) messageHistory
    %messageHistory - Create an object to manage and store messages in a conversation.
    %   messages = messageHistory creates a messageHistory object.
    %
    %   messageHistory functions:
    %     addSystemMessage         - Add system message.
    %     addUserMessage           - Add user message.
    %     addUserMessageWithImages - Add user message with images for
    %                                GPT-4 Turbo with Vision.
    %     addToolMessage           - Add a tool message.
    %     addResponseMessage       - Add a response message.
    %     removeMessage            - Remove message from history.
    %
    %   messageHistory properties:
    %     Messages                 - Messages in the conversation history.
    %
    %   messageHistory is a value class: every add/remove method returns a
    %   modified copy, so assign the result back (messages = addX(messages, ...)).

    % Copyright 2023-2024 The MathWorks, Inc.

    properties(SetAccess=private)
        %MESSAGES - Messages in the conversation history.
        %   Cell array in which each element is a scalar struct with (at
        %   least) a "role" field, in the order the messages were added.
        Messages = {}
    end

    methods
        function this = addSystemMessage(this, name, content)
            %addSystemMessage Add system message.
            %
            %   MESSAGES = addSystemMessage(MESSAGES, NAME, CONTENT) adds a system
            %   message with the specified name and content. NAME and CONTENT
            %   must be text scalars.
            %
            %   Example:
            %   % Create messages object
            %   messages = messageHistory;
            %
            %   % Add system messages to provide examples of the conversation
            %   messages = addSystemMessage(messages, "example_user", "Hello, how are you?");
            %   messages = addSystemMessage(messages, "example_assistant", "Olá, como vai?");
            %   messages = addSystemMessage(messages, "example_user", "The sky is beautiful today");
            %   messages = addSystemMessage(messages, "example_assistant", "O céu está lindo hoje.");
            arguments
                this (1,1) messageHistory
                name {mustBeNonzeroLengthTextScalar}
                content {mustBeNonzeroLengthTextScalar}
            end
            newMessage = struct("role", "system", "name", string(name), "content", string(content));
            this.Messages{end+1} = newMessage;
        end

        function this = addUserMessage(this, content)
            %addUserMessage Add user message.
            %
            %   MESSAGES = addUserMessage(MESSAGES, CONTENT) adds a user message
            %   with the specified content to MESSAGES. CONTENT must be a text scalar.
            %
            %   Example:
            %   % Create messages object
            %   messages = messageHistory;
            %
            %   % Add user message
            %   messages = addUserMessage(messages, "Where is Natick located?");
            arguments
                this (1,1) messageHistory
                content {mustBeNonzeroLengthTextScalar}
            end
            newMessage = struct("role", "user", "content", string(content));
            this.Messages{end+1} = newMessage;
        end

        function this = addUserMessageWithImages(this, content, images, nvp)
            %addUserMessageWithImages Add user message with images
            %
            %   MESSAGES = addUserMessageWithImages(MESSAGES, CONTENT, IMAGES)
            %   adds a user message with the specified content and images
            %   to MESSAGES. CONTENT must be a text scalar. IMAGES must be
            %   text (a string array, character vector, or cell array of
            %   character vectors) containing image URLs or file paths; it
            %   is stored in the message as a string array.
            %
            %   messages = addUserMessageWithImages(__,Detail="low");
            %   specify how the model should process the images using
            %   "Detail" parameter. The default is "auto".
            %   - When set to "low", the model scales the image to 512x512
            %   - When set to "high", the model scales the image to 512x512
            %     and also creates detailed 512x512 crops of the image
            %   - When set to "auto", the model chooses which mode to use
            %     depending on the input image.
            %
            %   Example:
            %
            %   % Create a chat with GPT-4 Turbo with Vision
            %   chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-vision-preview");
            %
            %   % Create messages object
            %   messages = messageHistory;
            %
            %   % Add user message with an image
            %   content = "What is in this picture?";
            %   images = "peppers.png";
            %   messages = addUserMessageWithImages(messages, content, images);
            %
            %   % Generate a response
            %   [text, response] = generate(chat, messages, MaxNumTokens=300);
            arguments
                this (1,1) messageHistory
                content {mustBeNonzeroLengthTextScalar}
                images (1,:) {mustBeNonzeroLengthText}
                nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto"
            end
            % Normalize IMAGES to a string array. Passing a cell array
            % directly to STRUCT would create a struct array (one element
            % per cell) instead of a single message; converting also keeps
            % the stored type consistent with CONTENT.
            newMessage = struct("role", "user", "content", string(content), ...
                "images", string(images), "image_detail", nvp.Detail);
            this.Messages{end+1} = newMessage;
        end

        function this = addToolMessage(this, id, name, content)
            %addToolMessage Add Tool message.
            %
            %   MESSAGES = addToolMessage(MESSAGES, ID, NAME, CONTENT)
            %   adds a tool message with the specified id, name and content.
            %   ID, NAME and CONTENT must be text scalars.
            %
            %   Example:
            %   % Create messages object
            %   messages = messageHistory;
            %
            %   % Add tool message, containing the result of
            %   % calling strcat("Hello", " World")
            %   messages = addToolMessage(messages, "call_123", "strcat", "Hello World");
            arguments
                this (1,1) messageHistory
                id {mustBeNonzeroLengthTextScalar}
                name {mustBeNonzeroLengthTextScalar}
                content {mustBeNonzeroLengthTextScalar}
            end
            % Store every text field as a string, whether the caller passed
            % a character vector or a string.
            newMessage = struct("tool_call_id", string(id), "role", "tool", ...
                "name", string(name), "content", string(content));
            this.Messages{end+1} = newMessage;
        end

        function this = addResponseMessage(this, messageStruct)
            %addResponseMessage Add response message.
            %
            %   MESSAGES = addResponseMessage(MESSAGES, messageStruct) adds a response
            %   message with the specified messageStruct. The input
            %   messageStruct should be a struct with field 'role' and
            %   value 'assistant' and with field 'content'. This response
            %   can be obtained from calling the GENERATE function.
            %
            %   If messageStruct also has a 'tool_calls' field, each tool
            %   call is validated and stored with its id, type, and
            %   function name/arguments; otherwise 'content' must be a
            %   nonzero-length text scalar.
            %
            %   Example:
            %
            %   % Create a chat object
            %   chat = openAIChat("You are a helpful AI Assistant.");
            %
            %   % Create messages object
            %   messages = messageHistory;
            %
            %   % Add user message
            %   messages = addUserMessage(messages, "What is the capital of England?");
            %
            %   % Generate a response
            %   [text, response] = generate(chat, messages);
            %
            %   % Add response to history
            %   messages = addResponseMessage(messages, response);
            arguments
                this (1,1) messageHistory
                messageStruct (1,1) struct
            end
            if ~isfield(messageStruct, "role")||~isequal(messageStruct.role, "assistant")||~isfield(messageStruct, "content")
                error("llms:mustBeAssistantCall",llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantCall"));
            end
            if isfield(messageStruct, "tool_calls")
                % Assistant is asking for one or more tool calls
                toolCalls = messageStruct.tool_calls;
                validateAssistantWithToolCalls(toolCalls);
                this = addAssistantMessage(this, messageStruct.content, toolCalls);
            else
                % Simple assistant response
                validateRegularAssistant(messageStruct.content);
                this = addAssistantMessage(this, messageStruct.content);
            end
        end

        function this = removeMessage(this, idx)
            %removeMessage Remove message.
            %
            %   MESSAGES = removeMessage(MESSAGES, IDX) removes a message at the specified
            %   index from MESSAGES. IDX must be a positive integer.
            %
            %   Example:
            %
            %   % Create messages object
            %   messages = messageHistory;
            %
            %   % Add user messages
            %   messages = addUserMessage(messages, "What is the capital of England?");
            %   messages = addUserMessage(messages, "What is the capital of Italy?");
            %
            %   % Remove the first message
            %   messages = removeMessage(messages,1);
            arguments
                this (1,1) messageHistory
                idx (1,1) {mustBeInteger, mustBePositive}
            end
            if isempty(this.Messages)
                error("llms:removeFromEmptyHistory",llms.utils.errorMessageCatalog.getMessage("llms:removeFromEmptyHistory"));
            end
            if idx>numel(this.Messages)
                error("llms:mustBeValidIndex",llms.utils.errorMessageCatalog.getMessage("llms:mustBeValidIndex", string(numel(this.Messages))));
            end
            this.Messages(idx) = [];
        end
    end

    methods(Access=private)
        function this = addAssistantMessage(this, content, toolCalls)
            %addAssistantMessage Append an assistant message to the history.
            %   THIS = addAssistantMessage(THIS, CONTENT) appends a plain
            %   assistant reply. THIS = addAssistantMessage(THIS, CONTENT,
            %   TOOLCALLS) appends a reply with a "tool_calls" field built
            %   from the id, type, and function name/arguments of each
            %   element of TOOLCALLS.
            arguments
                this (1,1) messageHistory
                content string
                toolCalls struct = []
            end
            if isempty(toolCalls)
                % Default assistant response
                newMessage = struct("role", "assistant", "content", content);
            else
                % Copy only the tool-call fields that are kept in the history
                toolsStruct = repmat(struct("id",[],"type",[],"function",[]),size(toolCalls));
                for i = 1:numel(toolCalls)
                    toolsStruct(i).id = toolCalls(i).id;
                    toolsStruct(i).type = toolCalls(i).type;
                    toolsStruct(i).function = struct( ...
                        "name", toolCalls(i).function.name, ...
                        "arguments", toolCalls(i).function.arguments);
                end
                if numel(toolsStruct) > 1
                    newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct);
                else
                    % A single tool call is wrapped in a cell so it stays a
                    % one-element list rather than a bare struct (presumably
                    % so JSON encoding yields an array — confirm).
                    newMessage = struct("role", "assistant", "content", content, "tool_calls", []);
                    newMessage.tool_calls = {toolsStruct};
                end
            end
            % {end+1} also works when Messages is still the empty cell {}.
            this.Messages{end+1} = newMessage;
        end
    end
end
function mustBeNonzeroLengthTextScalar(value)
%mustBeNonzeroLengthTextScalar Require VALUE to be a single, nonempty piece of text.
%   Runs mustBeNonzeroLengthText first and then mustBeTextScalar, letting
%   the error of whichever check fails first propagate unchanged.
mustBeNonzeroLengthText(value);
mustBeTextScalar(value);
end
function validateRegularAssistant(content)
%validateRegularAssistant Require an assistant reply to carry text content.
%   validateRegularAssistant(CONTENT) returns silently when CONTENT is a
%   nonzero-length text scalar. Otherwise it throws an error with
%   identifier "llms:mustBeAssistantWithContent" in place of the generic
%   validator error.
contentIsValid = true;
try
    mustBeNonzeroLengthText(content);
    mustBeTextScalar(content);
catch
    contentIsValid = false;
end
if ~contentIsValid
    error("llms:mustBeAssistantWithContent", ...
        llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithContent"))
end
end
function validateAssistantWithToolCalls(toolCallStruct)
%validateAssistantWithToolCalls Validate the tool calls of an assistant message.
%   validateAssistantWithToolCalls(TOOLCALLSTRUCT) returns silently when
%   TOOLCALLSTRUCT is a nonempty struct array with fields "id" and
%   "function", and the "function" value of every element is a scalar
%   struct with fields "name" and "arguments" holding nonzero-length text
%   scalars.
%
%   Errors (each check is applied to all elements before the next one):
%     llms:mustBeAssistantWithIdAndFunction      - not a struct, or the
%         "id" or "function" field is missing.
%     llms:mustBeAssistantWithNameAndArguments   - the struct array is
%         empty, or some "function" value is not a scalar struct with
%         "name" and "arguments" fields.
%     llms:assistantMustHaveTextNameAndArguments - some "name" or
%         "arguments" value is not a nonzero-length text scalar.
%
%   Each element's "function" value is inspected on its own instead of
%   concatenating them, so elements whose "function" structs carry
%   different extra fields do not fail with a raw concatenation error.

if ~(isstruct(toolCallStruct) && isfield(toolCallStruct, "id") && isfield(toolCallStruct, "function"))
    error("llms:mustBeAssistantWithIdAndFunction", ...
        llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithIdAndFunction"))
end

% Pass 1: every element must describe one function with name and
% arguments. An empty struct array describes no function and is rejected.
hasNameAndArguments = ~isempty(toolCallStruct);
for k = 1:numel(toolCallStruct)
    fcn = toolCallStruct(k).function;
    if ~(isstruct(fcn) && isscalar(fcn) && isfield(fcn, "name") && isfield(fcn, "arguments"))
        hasNameAndArguments = false;
        break
    end
end
if ~hasNameAndArguments
    error("llms:mustBeAssistantWithNameAndArguments", ...
        llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithNameAndArguments"))
end

% Pass 2: names and arguments must be nonzero-length text scalars.
try
    for k = 1:numel(toolCallStruct)
        fcn = toolCallStruct(k).function;
        mustBeNonzeroLengthText(fcn.name)
        mustBeTextScalar(fcn.name)
        mustBeNonzeroLengthText(fcn.arguments)
        mustBeTextScalar(fcn.arguments)
    end
catch
    error("llms:assistantMustHaveTextNameAndArguments", ...
        llms.utils.errorMessageCatalog.getMessage("llms:assistantMustHaveTextNameAndArguments"))
end
end