Skip to content

Commit

Permalink
✨ Added multi tenancy to vector dbs
Browse files Browse the repository at this point in the history
  • Loading branch information
naelob committed Sep 16, 2024
1 parent c9f0652 commit a666b60
Show file tree
Hide file tree
Showing 20 changed files with 386 additions and 169 deletions.
5 changes: 4 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,22 @@ NEXT_PUBLIC_DISTRIBUTION=selfhost # selfhost or managed
## pinecone
PINECONE_API_KEY=
PINECONE_INDEX_NAME=

## qdrant
QDRANT_BASE_URL=
QDRANT_API_KEY=
QDRANT_COLLECTION_NAME=
## chroma
CHROMADB_URL=
CHROMADB_COLLECTION_NAME=
## weaviate
WEAVIATE_URL=
WEAVIATE_API_KEY=
WEAVIATE_CLASS_NAME=
# turbopuffer
TURBOPUFFER_API_KEY=
# milvus
MILVUS_ADDRESS=
MILVUS_COLLECTION_NAME=

# EMBEDDINGS
JINA_API_KEY=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ const formSchema = z.object({
url: z.string().optional(),
indexName: z.string().optional(),
embeddingApiKey: z.string().optional(),
collectionName: z.string().optional(),
className: z.string().optional(),
});

interface ItemDisplayProps {
Expand Down Expand Up @@ -183,11 +185,11 @@ export function RAGItemDisplay({ item, type }: ItemDisplayProps) {
case 'pinecone':
return ['apiKey', 'indexName'];
case 'qdrant':
return ['apiKey', 'baseUrl'];
return ['apiKey', 'baseUrl', 'collectionName'];
case 'chromadb':
return ['url'];
return ['url', 'collectionName'];
case 'weaviate':
return ['apiKey', 'url'];
return ['apiKey', 'url', 'className'];
case 'openai_ada_small_1536':
case 'openai_ada_large_3072':
case 'openai_ada_002':
Expand All @@ -210,13 +212,16 @@ export function RAGItemDisplay({ item, type }: ItemDisplayProps) {
case 'qdrant':
form.setValue("apiKey", data[0]);
form.setValue("baseUrl", data[1]);
form.setValue("collectionName", data[1]);
break;
case 'chromadb':
form.setValue("url", data[0]);
form.setValue("collectionName", data[1]);
break;
case 'weaviate':
form.setValue("apiKey", data[0]);
form.setValue("url", data[1]);
form.setValue("className", data[1]);
break;
case 'openai_ada_small_1536':
case 'openai_ada_large_3072':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { cn } from "@/lib/utils"
import { Button } from "@/components/ui/button"
import { ScrollArea } from "@/components/ui/scroll-area"
import { vectorDatabases, embeddingModels } from "./utils"
import { useRagItem } from "./useRAGItem"
import { useRagItem } from "./useRagItem"

interface RAGItemListProps {
items: (typeof vectorDatabases[number] | typeof embeddingModels[number])[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import * as React from "react"
import { RAGItemDisplay } from "./RAGItemDisplay"
import { RAGItemList } from "./RAGItemList"
import { embeddingModels, vectorDatabases } from "./utils"
import { useRagItem } from "./useRAGItem"
import { useRagItem } from "./useRagItem"

interface Props {
items: (typeof vectorDatabases[number] | typeof embeddingModels[number])[];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as React from "react"
import { RAGLayout } from "./RagLayout";
import { RAGLayout } from "./RAGLayout";
import { embeddingModels, TabType, vectorDatabases } from "./utils";

export default function RAGSettingsPage() {
Expand Down
4 changes: 4 additions & 0 deletions docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,15 @@ services:
PINECONE_INDEX_NAME: ${PINECONE_INDEX_NAME}
QDRANT_BASE_URL: ${QDRANT_BASE_URL}
QDRANT_API_KEY: ${QDRANT_API_KEY}
QDRANT_COLLECTION_NAME: ${QDRANT_COLLECTION_NAME}
CHROMADB_URL: ${CHROMADB_URL}
CHROMADB_COLLECTION_NAME: ${CHROMADB_COLLECTION_NAME}
WEAVIATE_URL: ${WEAVIATE_URL}
WEAVIATE_API_KEY: ${WEAVIATE_API_KEY}
WEAVIATE_CLASS_NAME: ${WEAVIATE_CLASS_NAME}
TURBOPUFFER_API_KEY: ${TURBOPUFFER_API_KEY}
MILVUS_ADDRESS: ${MILVUS_ADDRESS}
MILVUS_COLLECTION_NAME: ${MILVUS_COLLECTION_NAME}

restart: unless-stopped
ports:
Expand Down
4 changes: 4 additions & 0 deletions docker-compose.source.yml
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,15 @@ services:
PINECONE_INDEX_NAME: ${PINECONE_INDEX_NAME}
QDRANT_BASE_URL: ${QDRANT_BASE_URL}
QDRANT_API_KEY: ${QDRANT_API_KEY}
QDRANT_COLLECTION_NAME: ${QDRANT_COLLECTION_NAME}
CHROMADB_URL: ${CHROMADB_URL}
CHROMADB_COLLECTION_NAME: ${CHROMADB_COLLECTION_NAME}
WEAVIATE_URL: ${WEAVIATE_URL}
WEAVIATE_API_KEY: ${WEAVIATE_API_KEY}
WEAVIATE_CLASS_NAME: ${WEAVIATE_CLASS_NAME}
TURBOPUFFER_API_KEY: ${TURBOPUFFER_API_KEY}
MILVUS_ADDRESS: ${MILVUS_ADDRESS}
MILVUS_COLLECTION_NAME: ${MILVUS_COLLECTION_NAME}

restart: unless-stopped
ports:
Expand Down
4 changes: 4 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,15 @@ services:
PINECONE_INDEX_NAME: ${PINECONE_INDEX_NAME}
QDRANT_BASE_URL: ${QDRANT_BASE_URL}
QDRANT_API_KEY: ${QDRANT_API_KEY}
QDRANT_COLLECTION_NAME: ${QDRANT_COLLECTION_NAME}
CHROMADB_URL: ${CHROMADB_URL}
CHROMADB_COLLECTION_NAME: ${CHROMADB_COLLECTION_NAME}
WEAVIATE_URL: ${WEAVIATE_URL}
WEAVIATE_API_KEY: ${WEAVIATE_API_KEY}
WEAVIATE_CLASS_NAME: ${WEAVIATE_CLASS_NAME}
TURBOPUFFER_API_KEY: ${TURBOPUFFER_API_KEY}
MILVUS_ADDRESS: ${MILVUS_ADDRESS}
MILVUS_COLLECTION_NAME: ${MILVUS_COLLECTION_NAME}

restart: unless-stopped
ports:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,19 @@ export class EnvironmentService {
};
}

getChromaCreds(): string {
return this.configService.get<string>('CHROMADB_URL');
getChromaCreds() {
return {
url: this.configService.get<string>('CHROMADB_URL'),
collectionName: this.configService.get<string>(
'CHROMADB_COLLECTION_NAME',
),
};
}

getMilvusCreds() {
return {
address: this.configService.get<string>('MILVUS_ADDRESS'),
collectionName: this.configService.get<string>('MILVUS_COLLECTION_NAME'),
};
}
getPineconeCreds() {
Expand All @@ -85,6 +91,7 @@ export class EnvironmentService {
return {
url: this.configService.get<string>('WEAVIATE_URL'),
apiKey: this.configService.get<string>('WEAVIATE_API_KEY'),
className: this.configService.get<string>('WEAVIATE_CLASS_NAME'),
};
}

Expand All @@ -96,6 +103,7 @@ export class EnvironmentService {
return {
baseUrl: this.configService.get<string>('QDRANT_BASE_URL'),
apiKey: this.configService.get<string>('QDRANT_API_KEY'),
collectionName: this.configService.get<string>('QDRANT_COLLECTION_NAME'),
};
}

Expand Down
11 changes: 8 additions & 3 deletions packages/api/src/@core/rag/document.processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ export class ProcessDocumentProcessor {

@Process('batchDocs')
async processDocuments(
job: Job<{ filesInfo: FileInfo[]; projectId: string }>,
job: Job<{
filesInfo: FileInfo[];
projectId: string;
linkedUserId: string;
}>,
) {
const { filesInfo, projectId } = job.data;
const { filesInfo, projectId, linkedUserId } = job.data;
const results = [];

for (const fileInfo of filesInfo) {
Expand All @@ -40,7 +44,7 @@ export class ProcessDocumentProcessor {
// console.log(`chunks for ${fileInfo.id} are ` + JSON.stringify(chunks));
const embeddings = await this.embeddingService.generateEmbeddings(
chunks,
projectId
projectId,
);
// Split embeddings into smaller batches
const batchSize = 100; // Adjust this value as needed
Expand All @@ -52,6 +56,7 @@ export class ProcessDocumentProcessor {
batchChunks,
batchEmbeddings,
projectId,
linkedUserId,
);
}
results.push(`Successfully processed document ${fileInfo.id}`);
Expand Down
35 changes: 22 additions & 13 deletions packages/api/src/@core/rag/rag.controller.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
import { Controller, Post, Body, UseGuards } from '@nestjs/common';
import { RagService } from './rag.service';
import { ApiKeyAuthGuard } from '@@core/auth/guards/api-key.guard';
import {
ApiBody,
ApiOperation,
ApiParam,
ApiQuery,
ApiTags,
ApiHeader,
//ApiKeyAuth,
} from '@nestjs/swagger';
import { ConnectionUtils } from '@@core/connections/@utils';
import { Body, Controller, Headers, Post, UseGuards } from '@nestjs/common';
import { ApiHeader } from '@nestjs/swagger';
import { RagService } from './rag.service';

@Controller('rag')
export class RagController {
Expand All @@ -21,9 +13,26 @@ export class RagController {
) {}

@Post('query')
@ApiHeader({
name: 'x-connection-token',
required: true,
description: 'The connection token',
example: 'b008e199-eda9-4629-bd41-a01b6195864a',
})
@UseGuards(ApiKeyAuthGuard)
async queryEmbeddings(@Body() body: { query: string; topK?: number }) {
return this.documentEmbeddingService.queryEmbeddings(body.query, body.topK);
async queryEmbeddings(
@Body() body: { query: string; topK?: number },
@Headers('x-connection-token') connection_token: string,
) {
const { linkedUserId, remoteSource, connectionId, projectId } =
await this.connectionUtils.getConnectionMetadataFromConnectionToken(
connection_token,
);
return this.documentEmbeddingService.queryEmbeddings(
body.query,
body.topK,
linkedUserId,
);
}

/*
Expand Down
3 changes: 2 additions & 1 deletion packages/api/src/@core/rag/rag.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ export class RagService {
private s3Service: S3Service,
) {}

async queryEmbeddings(query: string, topK = 5) {
async queryEmbeddings(query: string, topK = 5, linkedUserId: string) {
const queryEmbedding = await this.embeddingService.embedQuery(query);
const results = await this.vectorDatabaseService.queryEmbeddings(
queryEmbedding,
topK,
linkedUserId,
);
return results.map((match: any) => ({
chunk: match.metadata.text,
Expand Down
47 changes: 25 additions & 22 deletions packages/api/src/@core/rag/vecdb/chromadb/chromadb.service.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,55 @@
import { EnvironmentService } from '@@core/@core-services/environment/environment.service';
import { ProcessedChunk } from '@@core/rag/types';
import { Injectable } from '@nestjs/common';
import { ChromaClient } from 'chromadb';
import { ChromaClient, Collection } from 'chromadb';

@Injectable()
export class ChromaDBService {
private client: ChromaClient;
private collection: Collection;

constructor(private envService: EnvironmentService) {
//this.initialize();
constructor(private envService: EnvironmentService) {}

async onModuleInit() {
return;
}

async initialize() {
async initialize(credentials: string[]) {
this.client = new ChromaClient({
path: this.envService.getChromaCreds(),
path: credentials[0],
});
this.collection = await this.client.getOrCreateCollection({
name: credentials[1],
});
}

async storeEmbeddings(
fileId: string,
chunks: ProcessedChunk[],
embeddings: number[][],
linkedUserId: string,
) {
const collection = await this.client.createCollection({ name: fileId });
await collection.add({
await this.collection.add({
ids: chunks.map((_, i) => `${fileId}_${i}`),
embeddings: embeddings,
metadatas: chunks.map((chunk) => ({
text: chunk.text,
...chunk.metadata,
user_id: `ns_${linkedUserId}`,
})),
});
}

async queryEmbeddings(queryEmbedding: number[], topK: number) {
const collections = await this.client.listCollections();
const results = await Promise.all(
collections.map(async (collection) => {
const collectionInstance = await this.client.getCollection({
name: collection.name,
});
const result = await collectionInstance.query({
queryEmbeddings: [queryEmbedding],
nResults: topK,
});
return result.metadatas[0];
}),
);
return results.flat().slice(0, topK);
async queryEmbeddings(
queryEmbedding: number[],
topK: number,
linkedUserId: string,
) {
const result = await this.collection.query({
queryEmbeddings: [queryEmbedding],
nResults: topK,
where: { user_id: `ns_${linkedUserId}` },
});
return result.metadatas[0];
}
}
Loading

0 comments on commit a666b60

Please sign in to comment.