Skip to content

Commit

Permalink
Cleaning up the code
Browse files Browse the repository at this point in the history
  • Loading branch information
kszapsza committed Dec 23, 2024
1 parent 8c9a1b3 commit f31ef8b
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.context.properties.ConfigurationPropertiesScan;

@SpringBootApplication
@ConfigurationPropertiesScan
public class SpringAiRagApplication {
public static void main(String[] args) {
SpringApplication.run(SpringAiRagApplication.class, args);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package io.github.kszapsza.springairag.adapter.application;

import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import io.github.kszapsza.springairag.adapter.llm.ChatProperties;
import io.github.kszapsza.springairag.domain.chat.ChatProvider;
import io.github.kszapsza.springairag.domain.chat.ChatService;

@Configuration
public class ChatConfiguration {
@EnableConfigurationProperties(ChatProperties.class)
public class ApplicationConfiguration {

@Bean
ChatService chatService(ChatProvider chatProvider) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.github.kszapsza.springairag.adapter.llm;

import java.util.Map;

import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(prefix = "app.chat")
public record ChatProperties(
Embedding embedding,
SystemPrompt systemPrompt) {

public record Embedding(String document) {
}

public record SystemPrompt(
String resource,
Map<String, Object> placeholders) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,23 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.JsonReader;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

import io.github.kszapsza.springairag.adapter.llm.EmbeddingDocumentsProvider;

@Component
public class ClasspathEmbeddingDocumentsProvider implements EmbeddingDocumentsProvider {

private static final Logger logger = LoggerFactory.getLogger(ClasspathEmbeddingDocumentsProvider.class);

private final List<Document> data;

public ClasspathEmbeddingDocumentsProvider(
@Value("classpath:/embedding/faq-data.json") Resource exampleDocumentResource) {
if (exampleDocumentResource == null || !exampleDocumentResource.exists()
|| !exampleDocumentResource.isReadable()) {
throw new IllegalStateException("RAG input data is missing or not readable");
public ClasspathEmbeddingDocumentsProvider(Resource documentsResource) {
if (documentsResource == null || !documentsResource.exists() || !documentsResource.isReadable()) {
throw new IllegalArgumentException("RAG input data is missing or not readable");
}
var reader = new JsonReader(exampleDocumentResource, "question", "answer", "category");
var reader = new JsonReader(documentsResource, "question", "answer", "category");
this.data = reader.read();
logger.info("Successfully loaded RAG input data from resource: {}", exampleDocumentResource);
logger.info("Loaded {} documents from resource: {}", data.size(), documentsResource);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.github.kszapsza.springairag.adapter.llm.classpath;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;

import io.github.kszapsza.springairag.adapter.llm.EmbeddingDocumentsProvider;
import io.github.kszapsza.springairag.adapter.llm.SystemPromptTemplateProvider;

@Configuration
class ClasspathResourcesConfiguration {

@Bean
EmbeddingDocumentsProvider embeddingDocumentsProvider(
@Value("${app.chat.embedding.document}") Resource documentResource) {
return new ClasspathEmbeddingDocumentsProvider(documentResource);
}

@Bean
SystemPromptTemplateProvider systemPromptTemplateProvider(
@Value("${app.chat.system-prompt.resource}") Resource systemPromptResource) {
return new ClasspathSystemPromptTemplateProvider(systemPromptResource);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,19 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

import io.github.kszapsza.springairag.adapter.llm.SystemPromptTemplateProvider;

@Component
public class ClasspathSystemPromptTemplateProvider implements SystemPromptTemplateProvider {

private static final Logger logger = LoggerFactory.getLogger(ClasspathSystemPromptTemplateProvider.class);

private final SystemPromptTemplate systemPromptTemplate;

public ClasspathSystemPromptTemplateProvider(
@Value("classpath:/chat/system-message.txt") Resource systemPromptResource) {
public ClasspathSystemPromptTemplateProvider(Resource systemPromptResource) {
if (systemPromptResource == null || !systemPromptResource.exists() || !systemPromptResource.isReadable()) {
throw new IllegalStateException("System prompt resource is missing or not readable");
throw new IllegalArgumentException("System prompt resource is missing or not readable");
}
this.systemPromptTemplate = new SystemPromptTemplate(systemPromptResource);
logger.info("Successfully loaded SystemPromptTemplate from resource: {}", systemPromptResource);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import io.github.kszapsza.springairag.adapter.llm.function.realestate.RealEstateSearchFunction;

@Configuration
public class FunctionConfiguration {
class FunctionConfiguration {

@Bean
@Description("Searches real estate listings by location, price range, bedrooms, and active status")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package io.github.kszapsza.springairag.adapter.llm.openai;

import java.util.List;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import io.github.kszapsza.springairag.adapter.llm.ChatProperties;
import io.github.kszapsza.springairag.adapter.llm.SystemPromptTemplateProvider;

@Configuration
class OpenAiChatConfiguration {

@Bean
ChatClient chatClient(OpenAiChatModel chatModel) {
return ChatClient.create(chatModel);
}

@Bean
ChatOptions chatOptions() {
return OpenAiChatOptions.builder()
.withFunction("realEstateSearchFunction")
.build();
}

@Bean
List<Advisor> chatAdvisors(VectorStore vectorStore) {
return List.of(
new SimpleLoggerAdvisor(),
new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()));
}

@Bean
Message systemMessage(
SystemPromptTemplateProvider systemPromptTemplateProvider,
ChatProperties chatProperties) {
return systemPromptTemplateProvider.getSystemPromptTemplate()
.createMessage(chatProperties.systemPrompt().placeholders());
}
}
Original file line number Diff line number Diff line change
@@ -1,78 +1,67 @@
package io.github.kszapsza.springairag.adapter.llm.openai;

import java.util.List;
import java.util.Optional;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Component;
import org.springframework.context.annotation.Configuration;

import io.github.kszapsza.springairag.adapter.application.SystemMessageProperties;
import io.github.kszapsza.springairag.adapter.llm.SystemPromptTemplateProvider;
import io.github.kszapsza.springairag.domain.chat.ChatProvider;
import io.github.kszapsza.springairag.domain.chat.ChatRequest;
import io.github.kszapsza.springairag.domain.chat.ChatResponse;

@Component
@Configuration
public class OpenAiChatProvider implements ChatProvider {

private static final Logger logger = LoggerFactory.getLogger(OpenAiChatProvider.class);

private final OpenAiChatModel chatModel;
private final VectorStore vectorStore;
private final SystemMessageProperties systemMessageProperties;
private final SystemPromptTemplate systemPromptTemplate;
private final ChatClient chatClient;
private final ChatOptions chatOptions;
private final List<Advisor> advisors;
private final Message systemMessage;

public OpenAiChatProvider(
OpenAiChatModel chatModel,
VectorStore vectorStore,
SystemMessageProperties systemMessageProperties,
SystemPromptTemplateProvider systemPromptTemplateProvider) {
this.chatModel = chatModel;
this.vectorStore = vectorStore;
this.systemMessageProperties = systemMessageProperties;
this.systemPromptTemplate = systemPromptTemplateProvider.getSystemPromptTemplate();
ChatClient chatClient,
ChatOptions chatOptions,
List<Advisor> advisors,
Message systemMessage) {
this.chatClient = chatClient;
this.chatOptions = chatOptions;
this.advisors = advisors;
this.systemMessage = systemMessage;
}

@Override
public ChatResponse chat(ChatRequest request) {
if (request.message() == null || request.message().trim().isEmpty()) {
logger.warn("Received a null or empty request message");
return new ChatResponse.Failure("Request message cannot be null or empty");
}
try {
var content = callChatModel(request.message()).getContent();
return new ChatResponse.Success(content);
return callChatModel(request.message())
.map((generation) -> generation.getOutput().getContent())
.map((content) -> (ChatResponse) new ChatResponse.Success(content))
.orElseGet(() -> {
logger.warn("Received null generation result");
return new ChatResponse.Error.ServerError("Received null generation result");
});
} catch (Exception ex) {
return new ChatResponse.Failure(ex.getMessage());
logger.error("Error during chat model call", ex);
return new ChatResponse.Error.ServerError("An internal error occurred during processing.");
}
}

private Message callChatModel(String userMessageContent) {
return ChatClient.create(chatModel)
.prompt()
.options(buildOptions())
.advisors(
new SimpleLoggerAdvisor(),
new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()))
.system(systemPromptTemplate.createMessage(systemMessageProperties.placeholders()).getContent())
.user(userMessageContent)
.call()
.chatResponse()
.getResult()
.getOutput();
}

private ChatOptions buildOptions() {
return OpenAiChatOptions.builder()
.withFunction("realEstateSearchFunction")
.build();
private Optional<Generation> callChatModel(String userMessageContent) {
return Optional.ofNullable(
chatClient.prompt()
.options(chatOptions)
.advisors(advisors)
.system(systemMessage.getContent())
.user(userMessageContent)
.call()
.chatResponse()
.getResult());
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.kszapsza.springairag.adapter.rest;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
Expand All @@ -10,31 +12,43 @@
import io.github.kszapsza.springairag.domain.chat.ChatRequest;
import io.github.kszapsza.springairag.domain.chat.ChatResponse;
import io.github.kszapsza.springairag.domain.chat.ChatService;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotBlank;

@RestController
@RequestMapping("/api/chat")
public class ChatController {
private static final Logger logger = LoggerFactory.getLogger(ChatController.class);

private final ChatService chatService;

public ChatController(ChatService chatService) {
this.chatService = chatService;
}

@PostMapping
public ResponseEntity<ChatResponseDto> chat(@RequestBody ChatRequestDto request) {
public ResponseEntity<ChatResponseDto> chat(@Valid @RequestBody ChatRequestDto request) {
var chatResponse = chatService.chat(request.toDomain());

return switch (chatResponse) {
case ChatResponse.Success success ->
ResponseEntity.ok(ChatResponseDto.fromDomain(success));
case ChatResponse.Failure failure ->
ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE)
case ChatResponse.Success success -> {
yield ResponseEntity.ok(ChatResponseDto.fromDomain(success));
}
case ChatResponse.Error.ClientError failure -> {
logger.warn("Client error occurred: {}", failure.errorMessage());
yield ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body(new ChatResponseDto(failure.errorMessage()));
}
case ChatResponse.Error.ServerError failure -> {
logger.error("Server error occurred: {}", failure.errorMessage());
yield ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE)
.body(new ChatResponseDto("The chat is temporarily unavailable. Please try again later."));
}
};
}
}

record ChatRequestDto(String message) {
record ChatRequestDto(@NotBlank String message) {
public ChatRequest toDomain() {
return new ChatRequest(message());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ public sealed interface ChatResponse {
record Success(String content) implements ChatResponse {
}

record Failure(String errorMessage) implements ChatResponse {
sealed interface Error extends ChatResponse {
record ClientError(String errorMessage) implements Error {
}

record ServerError(String errorMessage) implements Error {
}
}
}
Loading

0 comments on commit f31ef8b

Please sign in to comment.