Skip to content

Commit

Permalink
[Feat] Rag system service logic #189
Browse files Browse the repository at this point in the history
* setting: Spring AI 1.0.0-M1 의존성 설정 완료

* test: Spring AI 질문 인수테스트 작성

* feat(Member): Member Entity question_count 필드 추가

* feat(ErrorCode): AI관련 error code 추가

* feat(RAGConfiguration): SimpleVectorDB 구현

* feat: RAG API 구현

* feat(RAGQueryAiModelAdapter): Flux를 사용한 응답 stream 구현

* test(AiAcceptanceTest): AI에게 질문하는 인수테스트 작성

* feat(InMemoryQueryAiModelAdapter): 테스트 용도의 InMemory AI model 응답 Fake 구현

* feat(AiAcceptanceTest): 가능한 요청 횟수를 넘긴 경우 TOO_MANY_REQUESTS를 반환받는다

* fix(RAGQueryApiV2): 질문 API의 produces를 text/event-stream 으로 변경

* fix(InMemoryQueryAiModelAdapter): InMemory model의 응답 수정

* feat(AiAcceptanceTest): 질문 한도수 초과시 예외가 발행하는 인수테스트 작성

* feat(RAGEventAdapter): RAG 시스템의 이벤트 포트 별도로 분리

기존에는 service 내부에서 그냥 Event 를 raise 하였는데, 이 방식 보다는 포트에 전달하고 Adapter 에서 raise 하도록 변경

* feat: dev profile ai 설정 추가

* feat(InMemoryQueryAiModelAdapter): 기본 응답 부분을 단건 문자로 분리

* feat(RAGQueryApiV2): API를 POST에서 GET으로 변경

* chore: 테스트 코드에 필요없는 예외 제거

* test(UserRepositoryTest): 사용자 질문 토큰감소 리포지토리 테스트 추가

* feat(User): Deprecated 된 Where 에너테이션을 SQLRestriction 으로 변경

* refactor(NoticeQueryRepositoryImpl): 중복 문자열 상수로 추출
  • Loading branch information
zbqmgldjfh committed Jul 23, 2024
1 parent 22fca60 commit 6e89cde
Show file tree
Hide file tree
Showing 36 changed files with 919 additions and 28 deletions.
17 changes: 16 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
plugins {
id 'org.springframework.boot' version '3.2.3'
id 'io.spring.dependency-management' version '1.1.4'
id 'io.spring.dependency-management' version '1.1.5'
id 'java'
id 'org.asciidoctor.jvm.convert' version "3.3.2"
id 'org.sonarqube' version '3.5.0.2730' // sonarqube gradle plugin 의존성
Expand All @@ -20,6 +20,8 @@ configurations {

repositories {
mavenCentral()
maven { url 'https://repo.spring.io/milestone' }
maven { url 'https://repo.spring.io/snapshot' }
}

sonarqube {
Expand All @@ -30,6 +32,10 @@ sonarqube {
}
}

ext {
set('springAiVersion', "1.0.0-M1")
}

dependencies {
// Web
implementation 'org.springframework.boot:spring-boot-starter-web'
Expand All @@ -38,6 +44,9 @@ dependencies {
implementation 'org.springframework:spring-aspects'
annotationProcessor 'org.springframework.boot:spring-boot-configuration-processor'

// AI
implementation "org.springframework.ai:spring-ai-openai-spring-boot-starter:${springAiVersion}"

// DB
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
runtimeOnly 'com.h2database:h2'
Expand Down Expand Up @@ -98,6 +107,12 @@ dependencies {
testImplementation 'org.testcontainers:mariadb:1.19.3'
}

dependencyManagement {
imports {
mavenBom "org.springframework.ai:spring-ai-bom:${springAiVersion}"
}
}

// Swagger force conflict resolution
configurations.all {
resolutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.kustacks.kuring.ai.adapter.in.web;

import com.kustacks.kuring.ai.application.port.in.RAGQueryUseCase;
import com.kustacks.kuring.common.annotation.RestWebAdapter;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.security.SecurityRequirement;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestParam;
import reactor.core.publisher.Flux;

@RequiredArgsConstructor
@RestWebAdapter(path = "/api/v2/ai/messages")
public class RAGQueryApiV2 {

private static final String USER_TOKEN_HEADER_KEY = "User-Token";

private final RAGQueryUseCase ragQueryUseCase;

@Operation(summary = "사용자 AI에 질문요청", description = "사용자가 궁금한 학교 정보를 AI에게 질문합니다.")
@SecurityRequirement(name = "User-Token")
@GetMapping(produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> askAIQuery(
@RequestParam("question") String question,
@RequestHeader(USER_TOKEN_HEADER_KEY) String id
) {
return ragQueryUseCase.askAiModel(question, id);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.kustacks.kuring.ai.adapter.in.web.dto;

import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Size;

public record UserQuestionRequest(
@NotBlank @Size(min = 5, max = 256) String question
) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.kustacks.kuring.ai.adapter.out.event;

import com.kustacks.kuring.ai.application.port.out.RAGEventPort;
import com.kustacks.kuring.common.domain.Events;
import com.kustacks.kuring.user.adapter.in.event.dto.UserDecreaseQuestionCountEvent;
import org.springframework.stereotype.Component;

@Component
public class RAGEventAdapter implements RAGEventPort {

@Override
public void userDecreaseQuestionCountEvent(String userId) {
Events.raise(new UserDecreaseQuestionCountEvent(userId));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.kustacks.kuring.ai.adapter.out.model;

import com.kustacks.kuring.ai.application.port.out.QueryAiModelPort;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

@Slf4j
@Component
@Profile("dev | local | test")
@RequiredArgsConstructor
public class InMemoryQueryAiModelAdapter implements QueryAiModelPort {

@Value("classpath:/ai/docs/ku-uni-register.txt")
private Resource kuUniRegisterInfo;

@Override
public Flux<String> call(Prompt prompt) {
if (prompt.getContents().contains("교내,외 장학금 및 학자금 대출 관련 전화번호들을 안내를 해줘")) {
return Flux.just("학", "생", "복", "지", "처", " ", "장", "학", "복", "지", "팀", "의",
" ", "전", "화", "번", "호", "는", " ", "0", "2", "-", "4", "5", "0", "-", "3", "2", "1",
"1", "~", "2", "이", "며", ",", " ", "건", "국", "사", "랑", "/", "장", "학", "사", "정",
"관", "장", "학", "/", "기", "금", "장", "학", "과", " ", "관", "련", "된", " ", "문", "의",
"는", " ", "0", "2", "-", "4", "5", "0", "-", "3", "9", "6", "7", "로", " ", "하", "시",
"면", " ", "됩", "니", "다", "."
);
}

return Flux.just("미", "리", " ", "준", "비", "된", " ",
"테", "스", "트", "질", "문", "이", " ", "아", "닙", "니", "다");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.kustacks.kuring.ai.adapter.out.model;

import com.kustacks.kuring.ai.application.port.out.QueryAiModelPort;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

@Slf4j
@Component
@Profile("prod")
@RequiredArgsConstructor
public class QueryAiModelAdapter implements QueryAiModelPort {

private final OpenAiChatModel openAiChatModel;

@Override
public Flux<String> call(Prompt prompt) {
return openAiChatModel.stream(prompt)
.filter(chatResponse -> chatResponse.getResult().getOutput().getContent() != null)
.flatMap(chatResponse -> Flux.just(chatResponse.getResult().getOutput().getContent()))
.doOnError(throwable -> log.error("[RAGQueryAiModelAdapter] {}", throwable.getMessage()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.kustacks.kuring.ai.adapter.out.persistence;

import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.List;
import java.util.stream.Stream;

@Slf4j
@Profile("local | dev | test")
@Component
@RequiredArgsConstructor
public class InMemoryQueryVectorStoreAdapter implements QueryVectorStorePort {

@Override
public List<String> findSimilarityContents(String question) {
HashMap<String, Object> metadata = createMetaData();

Document document = createDocument(metadata);

return Stream.of(document)
.map(Document::getContent)
.toList();
}

private Document createDocument(HashMap<String, Object> metadata) {
return new Document(
"a5a7414f-f676-409b-9f2e-1042f9846c97",
"● 등록금 전액 완납 또는 분할납부 1차분을 정해진 기간에 미납할 경우 분할납부 신청은 자동 취소되며, 미납 등록금은 이후\n" +
"추가 등록기간에 전액 납부해야 함.\n",
metadata);
}

private HashMap<String, Object> createMetaData() {
HashMap<String, Object> metadata = new HashMap<>();
metadata.put("charset", "UTF-8");
metadata.put("filename", "ku-uni-register.txt");
metadata.put("source", "ku-uni-register.txt");
return metadata;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.kustacks.kuring.ai.adapter.out.persistence;

import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

import java.util.Collections;
import java.util.List;

@Component
@Profile("prod")
@RequiredArgsConstructor
public class QueryVectorStoreAdapter implements QueryVectorStorePort {

@Override
public List<String> findSimilarityContents(String question) {
return Collections.emptyList();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.kustacks.kuring.ai.application.port.in;

import reactor.core.publisher.Flux;

public interface RAGQueryUseCase {
Flux<String> askAiModel(String question, String id);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.kustacks.kuring.ai.application.port.out;

import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;

public interface QueryAiModelPort {
Flux<String> call(Prompt prompt);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.kustacks.kuring.ai.application.port.out;

import java.util.List;

public interface QueryVectorStorePort {
List<String> findSimilarityContents(String question);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.kustacks.kuring.ai.application.port.out;

public interface RAGEventPort {
void userDecreaseQuestionCountEvent(String userId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.kustacks.kuring.ai.application.service;

import com.kustacks.kuring.ai.application.port.in.RAGQueryUseCase;
import com.kustacks.kuring.ai.application.port.out.QueryAiModelPort;
import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import com.kustacks.kuring.ai.application.port.out.RAGEventPort;
import com.kustacks.kuring.common.annotation.UseCase;
import com.kustacks.kuring.common.exception.InvalidStateException;
import com.kustacks.kuring.common.exception.code.ErrorCode;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import reactor.core.publisher.Flux;

import javax.annotation.PostConstruct;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@UseCase
@RequiredArgsConstructor
public class RAGQueryService implements RAGQueryUseCase {

private final QueryVectorStorePort vectorStorePort;
private final QueryAiModelPort ragChatModel;
private final RAGEventPort ragEventPort;

@Value("classpath:/ai/prompts/rag-prompt-template.st")
private Resource ragPromptTemplate;
private PromptTemplate promptTemplate;

@Override
public Flux<String> askAiModel(String question, String id) {
Prompt completePrompt = buildCompletePrompt(question);
ragEventPort.userDecreaseQuestionCountEvent(id);
return ragChatModel.call(completePrompt);
}

@PostConstruct
private void init() {
this.promptTemplate = new PromptTemplate(ragPromptTemplate);
}

private Prompt buildCompletePrompt(String question) {
List<String> similarDocuments = vectorStorePort.findSimilarityContents(question);
if(similarDocuments.isEmpty()) {
throw new InvalidStateException(ErrorCode.AI_SIMILAR_DOCUMENTS_NOT_FOUND);
}

Map<String, Object> promptParameters = createQuestions(question, similarDocuments);
return promptTemplate.create(promptParameters);
}

private Map<String, Object> createQuestions(String question, List<String> contentList) {
Map<String, Object> promptParameters = new HashMap<>();
promptParameters.put("input", question);
promptParameters.put("documents", String.join("In", contentList));
return promptParameters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.kustacks.kuring.common.exception;

import com.kustacks.kuring.common.exception.code.ErrorCode;

public class InvalidStateException extends BusinessException {

public InvalidStateException(ErrorCode errorCode) {
super(errorCode);
}
public InvalidStateException(ErrorCode errorCode, Exception e) {
super(errorCode, e);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ public enum ErrorCode {

CAT_NOT_EXIST_CATEGORY(HttpStatus.BAD_REQUEST, "서버에서 지원하지 않는 카테고리입니다."),

// STAFF_SCRAPER_TAG_NOT_EXIST("Jsoup - 찾고자 하는 태그가 존재하지 않습니다."),
STAFF_SCRAPER_EXCEED_RETRY_LIMIT("교직원 업데이트 재시도 횟수를 초과했습니다."),
STAFF_SCRAPER_CANNOT_SCRAP("건국대학교 홈페이지가 불안정합니다. 교직원 정보를 가져올 수 없습니다."),
STAFF_SCRAPER_CANNOT_PARSE("교직원 페이지 HTML 파싱에 실패했습니다."),
Expand All @@ -77,11 +76,15 @@ public enum ErrorCode {

USER_NOT_FOUND(HttpStatus.NOT_FOUND, "해당 사용자를 찾을 수 없습니다."),

// AI 관련
AI_SIMILAR_DOCUMENTS_NOT_FOUND(HttpStatus.NOT_FOUND, "죄송합니다, 관련된 내용에 대하여 알지 못합니다."),

/**
* ErrorCodes about DomainLogicException
*/
DOMAIN_CANNOT_CREATE("해당 도메인을 생성할 수 없습니다."),
DEPARTMENT_NOT_FOUND("해당 학과를 찾을 수 없습니다.");
DEPARTMENT_NOT_FOUND("해당 학과를 찾을 수 없습니다."),
QUESTION_COUNT_NOT_ENOUGH(HttpStatus.TOO_MANY_REQUESTS, "남은 질문 횟수가 부족합니다.");

private final HttpStatus httpStatus;
private final String message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.kustacks.kuring.common.dto.ErrorResponse;
import com.kustacks.kuring.common.exception.AdminException;
import com.kustacks.kuring.common.exception.InternalLogicException;
import com.kustacks.kuring.common.exception.InvalidStateException;
import com.kustacks.kuring.common.exception.NotFoundException;
import com.kustacks.kuring.common.exception.code.ErrorCode;
import com.kustacks.kuring.message.application.service.exception.FirebaseSubscribeException;
Expand Down Expand Up @@ -69,6 +70,13 @@ public ResponseEntity<ErrorResponse> FirebaseSubscribeExceptionHandler(FirebaseS
.body(new ErrorResponse(ErrorCode.API_FB_SERVER_ERROR));
}

@ExceptionHandler
public ResponseEntity<ErrorResponse> InvalidStateExceptionHandler(InvalidStateException exception) {
log.info("[InvalidStateException] {}", exception.getMessage());
return ResponseEntity.status(exception.getErrorCode().getHttpStatus())
.body(new ErrorResponse(exception.getErrorCode()));
}

@ExceptionHandler
public void InternalLogicExceptionHandler(InternalLogicException e) {
log.warn("[InternalLogicException] {}", e.getErrorCode().getMessage(), e);
Expand Down
Loading

0 comments on commit 6e89cde

Please sign in to comment.