From 43d30c8407587c117309ac86962663eeff0e42b5 Mon Sep 17 00:00:00 2001 From: serefyarar Date: Sun, 1 Dec 2024 01:20:05 -0500 Subject: [PATCH] Update embedding.js --- api/src/services/embedding.js | 125 +++++++++++++++++++++++++++------- 1 file changed, 102 insertions(+), 23 deletions(-) diff --git a/api/src/services/embedding.js b/api/src/services/embedding.js index 0d462446..737a061f 100644 --- a/api/src/services/embedding.js +++ b/api/src/services/embedding.js @@ -3,6 +3,9 @@ import { getCurrentDateTime } from "../utils/helpers.js"; import { IndexService } from "./index.js"; import pkg from "knex"; import { ChromaClient } from 'chromadb'; +import OpenAI from "openai"; +import { getModelInfo } from "../utils/mode.js"; + const { knex } = pkg; @@ -19,6 +22,44 @@ const collection = await chromaClient.getOrCreateCollection({ name: process.env.CHROMA_COLLECTION_NAME || "index_mainnet", }); +const openai = new OpenAI({ + apiKey: process.env.OPENAI_API_KEY, +}); + +const getDocText = (doc, metadata, runtimeDefinition) => { + + + if (metadata.modelId === runtimeDefinition.models.Cast.id) { + const authorName = doc.author.name || doc.author.username; + const castUrl = `https://warpcast.com/${doc.author.username}/${doc.hash.substring(0, 12)}`; + const authorUrl = `https://warpcast.com/${doc.author.username}`; + + return [ + 'Cast details:', + `- text: ${doc.text}`, + `- link: ${castUrl}`, + `- author: [${authorName}](${authorUrl})`, + `- created_at: ${doc.timestamp}`, + '----' + ].join('\n'); + } + + if (metadata.modelId === runtimeDefinition.models.Event.id) { + return [ + 'Event details:', + `- title: ${doc.title}`, + `- location: ${doc.location}`, + `- start time: ${doc.start_time}`, + `- end time: ${doc.end_time}`, + `- link: ${doc.link}`, + `- description: ${doc.description}`, + '----' + ].join('\n'); + } + + return JSON.stringify(doc); +}; + export class EmbeddingService { constructor(definition) { @@ -333,6 +374,8 @@ export class EmbeddingService { const BATCH_SIZE = 1000; let processedCount = 0; + const { runtimeDefinition } = await getModelInfo(); + while (true) { const embeddings = await cli('index_embeddings') .select('*') @@ -343,31 +386,67 @@ export class EmbeddingService { if (embeddings.length === 0) break; - const newIds = embeddings.map(e => e.stream_id); - const newVectors = embeddings.map(e => JSON.parse(e.vector)); - const newMetadatas = embeddings.map(embedding => ({ - modelName: embedding.model_name, - modelId: embedding.model_id, - indexId: embedding.index_id, - itemId: embedding.item_id, - createdAt: new Date(embedding.created_at).toISOString(), - updatedAt: new Date(embedding.updated_at).toISOString(), - })); - const newDatas = await Promise.all(embeddings.map(async (a) => { - const itemStream = await cli(a.model_id) - .select('stream_content') - .where('stream_id', a.item_id) - .first(); - return JSON.stringify(itemStream.stream_content); - })); - - await collection.upsert({ - ids: newIds, - embeddings: newVectors, // Use existing vectors - metadatas: newMetadatas, - documents: newDatas, + // Get IDs to check in Chroma + const ids = embeddings.map(e => e.stream_id); + + // Check which IDs already exist in Chroma + const existingEmbeddings = await collection.get({ + ids, + include: ['metadatas'] }); + // Filter out embeddings that already exist and have same metadata + const newEmbeddings = embeddings.filter(embedding => { + const existing = existingEmbeddings.metadatas?.find(m => + m?.itemId === embedding.item_id && + m?.modelId === embedding.model_id && + m?.indexId === embedding.index_id + ); + return !existing; + }); + + if (newEmbeddings.length > 0) { + // Process only new embeddings + const datas = await Promise.all(newEmbeddings.map(async (a) => { + const itemStream = await cli(a.model_id) + .select('stream_content') + .where('stream_id', a.item_id) + .first(); + return itemStream.stream_content; + })); + + const formattedTexts = datas.map((doc, index) => + getDocText(doc, { + modelId: newEmbeddings[index].model_id + }, runtimeDefinition) + ); + + const newEmbeddingVectors = await Promise.all(formattedTexts.map(async (text) => { + const response = await openai.embeddings.create({ + model: MODEL_EMBEDDING, + input: text, + }); + return response.data[0].embedding; + })); + + const newIds = newEmbeddings.map(e => e.stream_id); + const newMetadatas = newEmbeddings.map(embedding => ({ + modelName: embedding.model_name, + modelId: embedding.model_id, + indexId: embedding.index_id, + itemId: embedding.item_id, + createdAt: new Date(embedding.created_at).toISOString(), + updatedAt: new Date(embedding.updated_at).toISOString(), + })); + + await collection.upsert({ + ids: newIds, + embeddings: newEmbeddingVectors, + metadatas: newMetadatas, + documents: datas.map(JSON.stringify) + }); + + } console.log(processedCount) processedCount += embeddings.length; if (embeddings.length < BATCH_SIZE) break;