Skip to content

Commit

Permalink
add cosmos db vector store
Browse files Browse the repository at this point in the history
  • Loading branch information
TheovanKraay committed Feb 19, 2024
1 parent 0ef4a1b commit 656bb23
Show file tree
Hide file tree
Showing 11 changed files with 961 additions and 833 deletions.
10 changes: 10 additions & 0 deletions apps/acme-assist/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@
<artifactId>azure-ai-openai</artifactId>
<version>1.0.0-beta.3</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-mongodb</artifactId>
<version>3.1.2</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.experimental.ai</groupId>
<artifactId>spring-ai-azure-openai-spring-boot-starter</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package com.microsoft.azure.spring.chatgpt.sample.common;
package com.example.acme.assist;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatMessage;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;
import com.azure.ai.openai.models.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

Expand All @@ -14,11 +10,6 @@
@RequiredArgsConstructor
@Slf4j
public class AzureOpenAIClient {

private static final String EMBEDDING_MODEL = "text-embedding-ada-002";

private static final String CHAT_COMPLETION_MODEL = "gpt-35-turbo";

private static final double TEMPERATURE = 0.7;

private final OpenAIClient client;
Expand All @@ -29,14 +20,14 @@ public class AzureOpenAIClient {

public Embeddings getEmbeddings(List<String> texts) {
var response = client.getEmbeddings(embeddingDeploymentId,
new EmbeddingsOptions(texts).setModel(EMBEDDING_MODEL));
new EmbeddingsOptions(texts).setModel(embeddingDeploymentId));
log.info("Finished an embedding call with {} tokens.", response.getUsage().getTotalTokens());
return response;
}

public ChatCompletions getChatCompletions(List<ChatMessage> messages) {
var chatCompletionsOptions = new ChatCompletionsOptions(messages)
.setModel(CHAT_COMPLETION_MODEL)
.setModel(chatDeploymentId)
.setTemperature(TEMPERATURE);
var response = client.getChatCompletions(chatDeploymentId, chatCompletionsOptions);
log.info("Finished a chat completion call with {} tokens", response.getUsage().getTotalTokens());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package com.example.acme.assist;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatRole;
import com.example.acme.assist.model.AcmeChatRequest;
import com.example.acme.assist.model.Product;
import com.example.acme.assist.vectorstore.CosmosDBVectorStore;
import com.example.acme.assist.vectorstore.DocEntry;
import io.micrometer.common.util.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.ai.client.AiClient;
Expand Down Expand Up @@ -30,6 +33,12 @@ public class ChatService {
@Autowired
private SimplePersistentVectorStore store;

@Autowired
private AzureOpenAIClient openAIClient;

@Autowired
private CosmosDBVectorStore cosmosDBVectorStore;

@Autowired
private ProductRepository productRepository;

Expand All @@ -41,6 +50,9 @@ public class ChatService {

@Value("classpath:/prompts/chatWithProductId.st")
private Resource chatWithProductIdResource;

@Value("${spring.data.mongodb.enabled}")
private String cosmosEnabled;
/**
* Chat with the OpenAI API. Use the product details as the context.
*
Expand All @@ -65,8 +77,23 @@ private List<String> chatWithProductId(Product product, List<AcmeChatRequest.Mes
// We have a specific Product
String question = chatRequestMessages.get(chatRequestMessages.size() - 1).getContent();

var response = openAIClient.getEmbeddings(List.of(question));
var embedding = response.getData().get(0).getEmbedding();


List<Document> candidateDocuments = new ArrayList<>();;
// step 1. Query for documents that are related to the question from the vector store
List<Document> candidateDocuments = this.store.similaritySearch(question, 5, 0.4);
if (cosmosEnabled.equals("true")) {
List<DocEntry> cosmosVectorStoreDocs = this.cosmosDBVectorStore.searchTopKNearest(embedding, 5, 0.4);
for (DocEntry docEntry : cosmosVectorStoreDocs) {
Document document = new Document(docEntry.getText());
candidateDocuments.add(document);
}
}
else
{
candidateDocuments = this.store.similaritySearch(question, 5, 0.4);
}

// step 2. Create a SystemMessage that contains the product information in addition to related documents.
List<Message> messages = new ArrayList<>();
Expand All @@ -88,8 +115,22 @@ protected List<String> chatWithoutProductId(List<AcmeChatRequest.Message> acmeCh

String question = acmeChatRequestMessages.get(acmeChatRequestMessages.size() - 1).getContent();

var response = openAIClient.getEmbeddings(List.of(question));
var embedding = response.getData().get(0).getEmbedding();

// step 1. Query for documents that are related to the question from the vector store
List<Document> relatedDocuments = store.similaritySearch(question, 5, 0.4);
List<Document> relatedDocuments = new ArrayList<>();;
if (cosmosEnabled.equals("true")) {
List<DocEntry> cosmosVectorStoreDocs = this.cosmosDBVectorStore.searchTopKNearest(embedding, 5, 0.4);
for (DocEntry docEntry : cosmosVectorStoreDocs) {
Document document = new Document(docEntry.getText());
relatedDocuments.add(document);
}
}
else {
relatedDocuments = this.store.similaritySearch(question, 5, 0.4);
}


// step 2. Create the system message with the related documents;
List<Message> messages = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,82 @@
package com.example.acme.assist.config;

import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.example.acme.assist.AzureOpenAIClient;
import com.example.acme.assist.vectorstore.CosmosDBVectorStore;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.vectorstore.impl.SimplePersistentVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
import org.springframework.data.mongodb.core.MongoTemplate;

import java.io.IOException;

@Configuration
public class FitAssistConfiguration {

@Value("${spring.ai.azure.openai.embedding-model}")
private String embeddingDeploymentId;

@Value("${spring.ai.azure.openai.deployment-name}")
private String chatDeploymentId;

@Value("${spring.ai.azure.openai.endpoint}")
private String endpoint;

@Value("${spring.ai.azure.openai.api-key}")
private String apiKey;

@Value("${vector-store.file}")
private String cosmosVectorJsonFile;

@Value("${spring.data.mongodb.enabled}")
private String cosmosEnabled;

//@Autowired
private MongoTemplate mongoTemplate;
public FitAssistConfiguration(MongoTemplate mongoTemplate) {
this.mongoTemplate = mongoTemplate;
}



@Value("classpath:/vector_store.json")
private Resource vectorDbResource;
@Bean
public SimplePersistentVectorStore simpleVectorStore(EmbeddingClient embeddingClient) {
SimplePersistentVectorStore simpleVectorStore = new SimplePersistentVectorStore(embeddingClient);
simpleVectorStore.load(vectorDbResource);
if (cosmosEnabled.equals("false")) {
simpleVectorStore.load(vectorDbResource);
}
return simpleVectorStore;
}

@Bean
public CosmosDBVectorStore vectorStore() throws IOException {
CosmosDBVectorStore store = null;
if (cosmosEnabled.equals("true")) {
store = new CosmosDBVectorStore(mongoTemplate);
String currentPath = new java.io.File(".").getCanonicalPath();
String path = currentPath + cosmosVectorJsonFile.replace("\\", "//");
store.loadFromJsonFile(path);
}
else {
store = new CosmosDBVectorStore(null);
}
return store;
}

@Bean
public AzureOpenAIClient AzureOpenAIClient() {
var innerClient = new OpenAIClientBuilder()
.endpoint(endpoint)
.credential(new AzureKeyCredential(apiKey))
.buildClient();
return new AzureOpenAIClient(innerClient, embeddingDeploymentId, chatDeploymentId);
}


}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.microsoft.azure.spring.chatgpt.sample.common.vectorstore;
package com.example.acme.assist.vectorstore;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.mongodb.client.AggregateIterable;
Expand All @@ -8,20 +8,16 @@
import org.bson.Document;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.repository.config.EnableMongoRepositories;
import org.springframework.stereotype.Component;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;


@EnableMongoRepositories(basePackages = "com.microsoft.azure.spring.chatgpt.sample.common.vectorstore")
@Component
@EnableMongoRepositories(basePackages = "com.example.acme.assist.vectorstore")
//@Component
@Slf4j
public class CosmosDBVectorStore implements VectorStore {

Expand Down Expand Up @@ -67,10 +63,11 @@ public List<DocEntry> searchTopKNearest(List<Double> embedding, int k, double cu
List<DocEntry> result = new ArrayList<>();
for (Document doc : docs) {
String id = doc.getString("id");
String hash = doc.getString("hash");
String text = doc.getString("text");
MetaData metadata = new MetaData();
metadata.setName(doc.getString("metadata.name"));
List<Double> embedding1 = (List<Double>) doc.get("embedding");
DocEntry docEntry = new DocEntry(id, hash, text, embedding1);
DocEntry docEntry = new DocEntry(embedding1, id, metadata, text);
result.add(docEntry);
}
return result;
Expand All @@ -96,7 +93,7 @@ public List<MongoEntity> loadFromJsonFile(String filePath) {
List<DocEntry> list = new ArrayList<DocEntry>(data.store.values());
List<MongoEntity> mongoEntities = new ArrayList<>();
for (DocEntry docEntry : list) {
MongoEntity doc = new MongoEntity(docEntry.getId(), docEntry.getHash(), docEntry.getText(), docEntry.getEmbedding());
MongoEntity doc = new MongoEntity(docEntry.getEmbedding(), docEntry.getId(), docEntry.getMetadata(), docEntry.getText());
if (dimensions == 0) {
dimensions = docEntry.getEmbedding().size();
} else if (dimensions != docEntry.getEmbedding().size()) {
Expand All @@ -121,7 +118,7 @@ public List<MongoEntity> loadFromJsonFile(String filePath) {
}
}
}
createVectorIndex(100, dimensions, "COS");
createVectorIndex(5, dimensions, "COS");
}
return mongoEntities;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.microsoft.azure.spring.chatgpt.sample.common.vectorstore;
package com.example.acme.assist.vectorstore;

import lombok.Builder;
import lombok.Data;
Expand All @@ -14,19 +14,29 @@
@Document(collection = "vectorstore")
public class DocEntry {

private List<Double> embedding;
@Id
private String id;
private String hash;

public MetaData getMetadata() {
return metadata;
}

public void setMetadata(MetaData metadata) {
this.metadata = metadata;
}

private MetaData metadata;
private String text;
private List<Double> embedding;


public DocEntry() {}

public DocEntry(String id, String hash, String text, List<Double> embedding) {
public DocEntry(List<Double> embedding, String id, MetaData metadata, String text) {
this.id = id;
this.hash = hash;
this.text = text;
this.embedding = embedding;
this.metadata = metadata;
}

public String getId() {
Expand All @@ -37,14 +47,6 @@ public void setId(String id) {
this.id = id;
}

public String getHash() {
return hash;
}

public void setHash(String hash) {
this.hash = hash;
}

public String getText() {
return text;
}
Expand All @@ -65,7 +67,6 @@ public void setEmbedding(List<Double> embedding) {
public String toString() {
return "DocEntry{" +
"id='" + id + '\'' +
", hash='" + hash + '\'' +
", text='" + text + '\'' +
", embedding='" + embedding + '\'' +
'}';
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
package com.example.acme.assist.vectorstore;public class MetaData {
}
package com.example.acme.assist.vectorstore;

public class MetaData {
public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

private String name;
}
Loading

0 comments on commit 656bb23

Please sign in to comment.