Skip to content

Commit

Permalink
Update embedding.js
Browse files Browse the repository at this point in the history
  • Loading branch information
serefyarar committed Dec 1, 2024
1 parent 8c29085 commit 43d30c8
Showing 1 changed file with 102 additions and 23 deletions.
125 changes: 102 additions & 23 deletions api/src/services/embedding.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ import { getCurrentDateTime } from "../utils/helpers.js";
import { IndexService } from "./index.js";
import pkg from "knex";
import { ChromaClient } from 'chromadb';
import OpenAI from "openai";
import { getModelInfo } from "../utils/mode.js";


const { knex } = pkg;

Expand All @@ -19,6 +22,44 @@ const collection = await chromaClient.getOrCreateCollection({
name: process.env.CHROMA_COLLECTION_NAME || "index_mainnet",
});

// OpenAI API client (key taken from the environment); used further down in this
// file to create embedding vectors via `openai.embeddings.create`.
const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

/**
 * Render a stream document as plain text suitable for embedding generation.
 *
 * @param {Object} doc - The document's stream content (shape depends on its model).
 * @param {Object} metadata - Must carry `modelId`, the document's model stream id.
 * @param {Object} runtimeDefinition - Runtime model registry; `models.Cast.id` and
 *   `models.Event.id` are compared against `metadata.modelId`.
 * @returns {string} A markdown-ish text block for Cast/Event models, or the raw
 *   JSON serialization of `doc` for any other model.
 */
const getDocText = (doc, metadata, runtimeDefinition) => {
    const { models } = runtimeDefinition;

    if (metadata.modelId === models.Cast.id) {
        const { username } = doc.author;
        // Fall back to the username when no display name is set.
        const displayName = doc.author.name || username;
        const profileUrl = `https://warpcast.com/${username}`;
        // Warpcast cast URLs use a 12-character hash prefix.
        const castLink = `${profileUrl}/${doc.hash.substring(0, 12)}`;

        const lines = [
            'Cast details:',
            `- text: ${doc.text}`,
            `- link: ${castLink}`,
            `- author: [${displayName}](${profileUrl})`,
            `- created_at: ${doc.timestamp}`,
            '----',
        ];
        return lines.join('\n');
    }

    if (metadata.modelId === models.Event.id) {
        const lines = [
            'Event details:',
            `- title: ${doc.title}`,
            `- location: ${doc.location}`,
            `- start time: ${doc.start_time}`,
            `- end time: ${doc.end_time}`,
            `- link: ${doc.link}`,
            `- description: ${doc.description}`,
            '----',
        ];
        return lines.join('\n');
    }

    // Unknown model: fall back to a raw JSON dump of the document.
    return JSON.stringify(doc);
};


export class EmbeddingService {
constructor(definition) {
Expand Down Expand Up @@ -333,6 +374,8 @@ export class EmbeddingService {
const BATCH_SIZE = 1000;
let processedCount = 0;

const { runtimeDefinition } = await getModelInfo();

while (true) {
const embeddings = await cli('index_embeddings')
.select('*')
Expand All @@ -343,31 +386,67 @@ export class EmbeddingService {

if (embeddings.length === 0) break;

const newIds = embeddings.map(e => e.stream_id);
const newVectors = embeddings.map(e => JSON.parse(e.vector));
const newMetadatas = embeddings.map(embedding => ({
modelName: embedding.model_name,
modelId: embedding.model_id,
indexId: embedding.index_id,
itemId: embedding.item_id,
createdAt: new Date(embedding.created_at).toISOString(),
updatedAt: new Date(embedding.updated_at).toISOString(),
}));
const newDatas = await Promise.all(embeddings.map(async (a) => {
const itemStream = await cli(a.model_id)
.select('stream_content')
.where('stream_id', a.item_id)
.first();
return JSON.stringify(itemStream.stream_content);
}));

await collection.upsert({
ids: newIds,
embeddings: newVectors, // Use existing vectors
metadatas: newMetadatas,
documents: newDatas,
// Get IDs to check in Chroma
const ids = embeddings.map(e => e.stream_id);

// Check which IDs already exist in Chroma
const existingEmbeddings = await collection.get({
ids,
include: ['metadatas']
});

// Filter out embeddings that already exist and have same metadata
const newEmbeddings = embeddings.filter(embedding => {
const existing = existingEmbeddings.metadatas?.find(m =>
m?.itemId === embedding.item_id &&
m?.modelId === embedding.model_id &&
m?.indexId === embedding.index_id
);
return !existing;
});

if (newEmbeddings.length > 0) {
// Process only new embeddings
const datas = await Promise.all(newEmbeddings.map(async (a) => {
const itemStream = await cli(a.model_id)
.select('stream_content')
.where('stream_id', a.item_id)
.first();
return itemStream.stream_content;
}));

const formattedTexts = datas.map((doc, index) =>
getDocText(doc, {
modelId: newEmbeddings[index].model_id
}, runtimeDefinition)
);

const newEmbeddingVectors = await Promise.all(formattedTexts.map(async (text) => {
const response = await openai.embeddings.create({
model: MODEL_EMBEDDING,
input: text,
});
return response.data[0].embedding;
}));

const newIds = newEmbeddings.map(e => e.stream_id);
const newMetadatas = newEmbeddings.map(embedding => ({
modelName: embedding.model_name,
modelId: embedding.model_id,
indexId: embedding.index_id,
itemId: embedding.item_id,
createdAt: new Date(embedding.created_at).toISOString(),
updatedAt: new Date(embedding.updated_at).toISOString(),
}));

await collection.upsert({
ids: newIds,
embeddings: newEmbeddingVectors,
metadatas: newMetadatas,
documents: datas.map(JSON.stringify)
});

}
console.log(processedCount)
processedCount += embeddings.length;
if (embeddings.length < BATCH_SIZE) break;
Expand Down

0 comments on commit 43d30c8

Please sign in to comment.