Skip to content

Commit

Permalink
feat: store image metadata in db
Browse files Browse the repository at this point in the history
  • Loading branch information
radityaharya committed Apr 16, 2024
1 parent 31ffbc0 commit 0b7ceba
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 46 deletions.
Binary file modified bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"gpt3-tokenizer": "^1.1.5",
"hono": "^4.2.4",
"langchain": "^0.1.33",
"langsmith": "^0.1.14",
"lodash": "^4.17.21",
"openai": "^4.33.1",
"pino": "^8.20.0",
Expand Down
6 changes: 6 additions & 0 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,10 @@ model RssPooler {
lastChecked DateTime
lastCheckedString String? // Nullable
etag String? // Nullable
}

// Caches the generated analysis/context for an image attachment so each
// attachment is only analyzed once; looked up and written by message ID
// in createChatCompletion (src/lib/llm.ts via utils/metadataLogger).
model AnalyzedAttachmentMetadata {
// Surrogate primary key.
id String @id @default(uuid())
// Discord message ID the analyzed attachment belongs to.
// NOTE(review): not marked @unique — presumably one tracked attachment per
// message (only attachments.first() is analyzed); confirm before adding data.
messageId String
// Serialized analysis text produced for the attachment.
metadata String
}
20 changes: 17 additions & 3 deletions src/events/message-create.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,24 @@ async function handleThreadMessage(

await channel.sendTyping();

if (message.attachments.size > 0) {
const attachment = message.attachments.first();
if (attachment) {
const file = await tempFile(attachment.url);
message.content = `data:${file.mimeType};base64,${file.base64}`;
}
}

const typingInterval = setInterval(() => {
channel.sendTyping();
}, 5000);

const completion = await createChatCompletion(
buildThreadContext(messages, message.content, client.user.id),
buildThreadContext(messages, message, client.user.id),
);

clearInterval(typingInterval);

if (completion.status !== CompletionStatus.Ok) {
await handleFailedRequest(
channel,
Expand Down Expand Up @@ -157,7 +171,7 @@ async function handleDirectMessage(
}, 5000);

const completion = await createChatCompletion(
buildDirectMessageContext(messages, message.content, client.user.id),
buildDirectMessageContext(messages, message, client.user.id),
);

clearInterval(typingInterval);
Expand Down Expand Up @@ -252,7 +266,7 @@ async function splitSend(completion: CompletionResponse, channel: DMChannel) {
if (message.trim() !== '') {
await channel.sendTyping();
await new Promise((resolve) =>
setTimeout(resolve, (message.length / 20) * 1000),
setTimeout(resolve, (message.length / 30) * 1000),
);
await channel.send({
content: message,
Expand Down
55 changes: 41 additions & 14 deletions src/lib/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@ import {
type ThreadChannel,
} from 'discord.js';
import GPT3Tokenizer from 'gpt3-tokenizer';
import type OpenAI from 'openai';

import config from '@/config';

// Minimal chat-message shape exchanged between the context builders in this
// file and createChatCompletion (src/lib/llm.ts).
export type MessageContext = {
// Chat role; values produced in this file are 'system', 'user', 'assistant'
// and 'function' — createChatCompletion throws on anything else.
role: string;
// Message text; may instead be a 'data:image...' data-URL/marker when the
// originating Discord message carried an image attachment.
content: string;
// Discord message ID ('system' for the synthetic system message); used to
// look up cached attachment analysis metadata by message ID.
id: string;
};

// TODO: inject multimodal context metadata here
export function buildContext(
messages: Array<any>,
userMessage: string,
messages: Array<MessageContext>,
userMessage: Message,
instruction?: string,
): Array<any> {
): Array<MessageContext> {
let finalInstruction = instruction;

if (!finalInstruction || finalInstruction === 'Default') {
Expand All @@ -29,15 +34,19 @@ export function buildContext(
finalInstruction += '.';
}

finalInstruction.replace('{{user}}', userMessage.author.username);

const systemMessageContext = {
role: 'system',
content: `${finalInstruction} The current date is ${format(new Date(), 'PPP')}.`,
content: `${finalInstruction} The current date is ${format(new Date(), 'PPP')}. The latest message is form ${userMessage.author.username}.`,
name: 'system',
id: 'system',
};

const userMessageContext = {
role: 'user',
content: userMessage,
content: userMessage.content,
id: userMessage.id,
};

if (messages.length === 0) {
Expand All @@ -53,14 +62,15 @@ export function buildContext(
for (let i = 0; i < messages.length; i++) {
const message = messages[i];
const content = message.content as string;
const encoded = tokenizer.encode(content);

const encoded = tokenizer.encode(content);
tokenCount += encoded.text.length;

if (tokenCount > maxTokens) {
contexts.push({
role: message.role,
content: content.slice(0, tokenCount - maxTokens),
id: message.id,
});

break;
Expand All @@ -69,6 +79,7 @@ export function buildContext(
contexts.push({
role: message.role,
content,
id: message.id,
});
}

Expand All @@ -77,10 +88,10 @@ export function buildContext(

export function buildThreadContext(
messages: Collection<string, Message>,
userMessage: string,
userMessage: Message,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
botId: string,
): Array<OpenAI.Chat.ChatCompletionMessageParam> {
): Array<MessageContext> {
if (messages.size === 0) {
return buildContext([], userMessage);
}
Expand Down Expand Up @@ -108,20 +119,28 @@ export function buildThreadContext(
}

const context = [
{ role: 'user', content: prompt, name: 'user' },
{ role: 'user', content: prompt, name: 'user', id: initialMessage.id },
...messages
.filter(
(message) =>
message.type === MessageType.Default &&
message.content &&
(message.content || message.attachments.size > 0) &&
message.embeds.length === 0 &&
(message.mentions.members?.size ?? 0) === 0,
)
.map((message) => {
if (message.attachments.size > 0) {
const attachment = message.attachments.first();
if (attachment) {
message.content = 'data:image';
}
}

return {
role: 'function',
content: message.content,
name: 'someName',
id: message.id,
};
})
.reverse(),
Expand All @@ -132,9 +151,9 @@ export function buildThreadContext(

export function buildDirectMessageContext(
messages: Collection<string, Message>,
userMessage: string,
userMessage: Message,
botId: string,
): Array<OpenAI.Chat.ChatCompletionMessageParam> {
): Array<MessageContext> {
if (messages.size === 0) {
return buildContext([], userMessage);
}
Expand All @@ -143,14 +162,22 @@ export function buildDirectMessageContext(
.filter(
(message) =>
message.type === MessageType.Default &&
message.content &&
(message.content || message.attachments.size > 0) &&
message.embeds.length === 0 &&
(message.mentions.members?.size ?? 0) === 0,
)
.map((message) => {
if (message.attachments.size > 0) {
const attachment = message.attachments.first();
if (attachment) {
message.content = 'data:image';
}
}

return {
role: message.author.id === botId ? 'assistant' : 'user',
content: message.content,
id: message.id,
};
})
.reverse();
Expand Down
85 changes: 61 additions & 24 deletions src/lib/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import {
import { ChatOpenAI } from '@langchain/openai';
import { ChatGoogleGenerativeAI } from '@langchain/google-genai';
import { getAnimeDetails, getAnimeSauce, TraceMoeResultItem } from './tracemoe';
import {
getAnalyzedAttachmentMetadataByMessageId,
setAttachmentMetadata,
} from '@/utils/metadataLogger';
import { MessageContext } from './helpers';

const openai = new OpenAI({
apiKey: config.openai.api_key,
Expand Down Expand Up @@ -72,13 +77,38 @@ async function traceAnimeContext(base64Image: string) {
episode: match.episode,
episodes: anilistResult.data.Media.episodes,
genres: anilistResult.data.Media.genres,
score: anilistResult.data.Media.averageScore,
description: anilistResult.data.Media.description,
video: match.video,
image: match.image,
characters: anilistResult.data.Media.characters.edges
.slice(0, 5)
.map((edge) => {
const { name, gender, description } = edge.node;
const truncatedDescription = description
? `${description.substring(0, 47)}...`
: 'No description available';
return `${name.full} (Gender: ${gender}, Description: ${truncatedDescription})`;
})
.join(', ')
.replace(/, ([^,]*)$/, ' and $1'),
nextAiringDatetime: new Date(
anilistResult.data.Media.nextAiringEpisode?.airingAt * 1000,
).toLocaleString(),
relations: anilistResult.data.Media.relations.edges
.map(
(edge) =>
edge.node.title.english ||
edge.node.title.romaji ||
edge.node.title.native,
)
.join(', '),
startDate: new Date(
anilistResult.data.Media.startDate.year,
anilistResult.data.Media.startDate.month - 1,
anilistResult.data.Media.startDate.day,
).toLocaleDateString(),
score: anilistResult.data.Media.averageScore,
};

additionalContext = `The image is from the anime titled "${anime.title}". This anime falls under the genres: ${anime.genres.join(', ')}. It has an average score of ${anime.score}. The specific scene in the image is from episode ${anime.episode} out of the total ${anime.episodes} episodes. Here is a brief description of the anime: "${anime.description}". Do note that the context provided is based on the image and may not be 100% accurate.`;
additionalContext = `The image is from the anime titled "${anime.title}". This anime falls under the genres: ${anime.genres.join(', ')}. It has an average score of ${anime.score}. The specific scene in the image is from episode ${anime.episode} out of the total ${anime.episodes} episodes. Here is a brief description of the anime: "${anime.description}". The main characters in this anime are ${anime.characters}. The next episode is scheduled to air on ${anime.nextAiringDatetime}. The anime is set to release on ${anime.startDate}. The anime has relations with the following anime or mangas: ${anime.relations}.`;
}
} catch (error) {
console.error('Error tracing anime context:', error);
Expand All @@ -92,7 +122,6 @@ async function generateImageContext(file: string) {
return additionalContext;
}

// TODO: Save context metadata in db and asign it to the history?
async function identifyImage(file: string) {
const additionalContext = await generateImageContext(file);

Expand All @@ -117,30 +146,38 @@ async function identifyImage(file: string) {
return response;
}
export async function createChatCompletion(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
messages: Array<any>,
messages: Array<MessageContext>,
): Promise<CompletionResponse> {
try {
const chatMessages = messages.map(async (message) => {
switch (message.role) {
case 'system':
return new SystemMessage(message.content);
case 'user': {
if (message.content.startsWith('data:image')) {
return new SystemMessage(
await identifyImage(message.content as string),
);
const chatMessages = await Promise.all(
messages.map(async (message) => {
switch (message.role) {
case 'system':
return new SystemMessage(message.content);
case 'user': {
if (message.content.startsWith('data:image')) {
const analyzedAttachment =
await getAnalyzedAttachmentMetadataByMessageId(message.id);
let metadata = analyzedAttachment
? analyzedAttachment.metadata
: null;
if (!metadata) {
metadata = await identifyImage(message.content as string);
await setAttachmentMetadata(message.id, metadata);
}
return new SystemMessage(metadata);
}
return new HumanMessage(message.content);
}
return new HumanMessage(message.content);
case 'assistant':
return new AIMessage(message.content);
default:
throw new Error(`Invalid message role: ${message.role}`);
}
case 'assistant':
return new AIMessage(message.content);
default:
throw new Error(`Invalid message role: ${message.role}`);
}
});
}),
);

const completion = await chat.invoke(await Promise.all(chatMessages));
const completion = await chat.invoke(chatMessages);
const message = completion.content;
if (message) {
return {
Expand Down
Loading

0 comments on commit 0b7ceba

Please sign in to comment.