Skip to content

Commit

Permalink
Minor
Browse files Browse the repository at this point in the history
  • Loading branch information
msveshnikov committed May 2, 2024
1 parent 2eb231e commit 5879798
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 46 deletions.
66 changes: 29 additions & 37 deletions server/claude.js
Original file line number Diff line number Diff line change
Expand Up @@ -55,47 +55,39 @@ export const getTextClaude = async (prompt, temperature, imageBase64, fileType,
},
];

let response = await anthropic.beta.tools.messages.create({
model,
max_tokens: 4096,
temperature: temperature || 0.5,
tools: webTools ? tools : [],
messages,
});
let response = await getResponse();
let toolUses, toolResults;

if (!response) {
throw new Error("Claude Error");
} else {
let toolUses, toolResults;

while (response.stop_reason === "tool_use") {
toolUses = response.content.filter((block) => block.type === "tool_use");
if (!toolUses.length) {
return response?.content?.[0]?.text;
}
while (response?.stop_reason === "tool_use") {
toolUses = response.content.filter((block) => block.type === "tool_use");
if (!toolUses.length) {
return response?.content?.[0]?.text;
}

toolResults = await Promise.all(
toolUses.map(async (toolUse) => {
const toolResult = await handleToolCall(toolUse.name, toolUse.input, userId);
return { tool_use_id: toolUse.id, content: toolResult };
})
);
toolResults = await Promise.all(
toolUses.map(async (toolUse) => {
const toolResult = await handleToolCall(toolUse.name, toolUse.input, userId);
return { tool_use_id: toolUse.id, content: toolResult };
})
);

messages.push({ role: "assistant", content: response.content.filter((c) => c.type !== "text" || c.text) });
messages.push({
role: "user",
content: toolResults.map((toolResult) => ({ type: "tool_result", ...toolResult })),
});
messages.push({ role: "assistant", content: response.content.filter((c) => c.type !== "text" || c.text) });
messages.push({
role: "user",
content: toolResults.map((toolResult) => ({ type: "tool_result", ...toolResult })),
});
response = await getResponse();
}

response = await anthropic.beta.tools.messages.create({
model,
max_tokens: 4096,
temperature: temperature || 0.5,
tools: webTools ? tools : [],
messages,
});
}
return response?.content?.[0]?.text;

return response?.content?.[0]?.text;
async function getResponse() {
return anthropic.beta.tools.messages.create({
model,
max_tokens: 4096,
temperature: temperature || 0.5,
tools: webTools ? tools : [],
messages,
});
}
};
14 changes: 14 additions & 0 deletions server/gemini.js
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ export async function getTextGemini(prompt, temperature, imageBase64, fileType,
data: imageBase64,
},
});
} else if (fileType === "ogg") {
parts.push({
inlineData: {
mimeType: "audio/ogg",
data: imageBase64,
},
});
} else if (fileType === "wav") {
parts.push({
inlineData: {
mimeType: "audio/wav",
data: imageBase64,
},
});
}

const contents = {
Expand Down
14 changes: 7 additions & 7 deletions server/openai.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export const getTextGpt = async (prompt, temperature, userId, model, apiKey, web
const openAiTools = tools.map(renameProperty).map((f) => ({ type: "function", function: f }));
const messages = [{ role: "user", content: prompt }];

const getMessage = async () => {
const getResponse = async () => {
const completion = await openai.chat.completions.create({
model: model || "gpt-3.5-turbo",
max_tokens: 3000,
Expand All @@ -29,10 +29,10 @@ export const getTextGpt = async (prompt, temperature, userId, model, apiKey, web
return completion?.choices?.[0]?.message;
};

let responseMessage = await getMessage();
while (responseMessage?.tool_calls) {
const toolCalls = responseMessage?.tool_calls;
messages.push(responseMessage);
let response = await getResponse();
while (response?.tool_calls) {
const toolCalls = response?.tool_calls;
messages.push(response);
for (const toolCall of toolCalls) {
const toolResult = await handleToolCall(
toolCall.function.name,
Expand All @@ -46,7 +46,7 @@ export const getTextGpt = async (prompt, temperature, userId, model, apiKey, web
content: toolResult,
});
}
responseMessage = await getMessage();
response = await getResponse();
}
return responseMessage?.content;
return response?.content;
};
2 changes: 2 additions & 0 deletions src/components/ChatHistory.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ const getFileTypeIcon = (mimeType) => {
case "jpg":
return null;
case "mp3":
case "ogg":
case "wav":
case "mpeg":
case "x-m4a":
return "🔊";
Expand Down
4 changes: 2 additions & 2 deletions src/components/FileSelector.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const FileSelector = ({ onFileSelect, selectedFile, allowedFileTypes }) => {
if (type === "image") {
acc.push(".png", ".jpeg", ".jpg");
} else if (type === "audio") {
acc.push(".mp3", ".m4a");
acc.push(".mp3", ".m4a", ".ogg", ".wav");
} else if (type === "video") {
acc.push(".mp4");
} else if (type === "document") {
Expand All @@ -25,7 +25,7 @@ const FileSelector = ({ onFileSelect, selectedFile, allowedFileTypes }) => {
return acc;
}, [])
.join(",")
: ".pdf,.doc,.docx,.xls,.xlsx,.png,.jpeg,.jpg,.mp4,.mp3,.m4a";
: ".pdf,.doc,.docx,.xls,.xlsx,.png,.jpeg,.jpg,.mp4,.mp3,.m4a,.ogg";

return (
<Tooltip title="Upload file">
Expand Down

0 comments on commit 5879798

Please sign in to comment.