Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Rag system #188 #189

Merged
merged 21 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
ff5dee0
setting: Spring AI 1.0.0-M1 의존성 설정 완료
zbqmgldjfh Jul 8, 2024
366ae24
test: Spring AI 질문 인수테스트 작성
zbqmgldjfh Jul 8, 2024
e70165f
feat(Member): Member Entity question_count 필드 추가
zbqmgldjfh Jul 9, 2024
1efcacf
feat(ErrorCode): AI관련 error code 추가
zbqmgldjfh Jul 9, 2024
1fbc53c
feat(RAGConfiguration): SimpleVectorDB 구현
zbqmgldjfh Jul 9, 2024
513fbe4
feat: RAG API 구현
zbqmgldjfh Jul 9, 2024
ade75b0
feat(RAGQueryAiModelAdapter): Flux를 사용한 응답 stream 구현
zbqmgldjfh Jul 10, 2024
77fdf25
test(AiAcceptanceTest): AI에게 질문하는 인수테스트 작성
zbqmgldjfh Jul 10, 2024
018bf90
feat(InMemoryQueryAiModelAdapter): 테스트 용도의 InMemory AI model 응답 Fake 구현
zbqmgldjfh Jul 10, 2024
967ca51
feat(AiAcceptanceTest): 가능한 요청 횟수를 넘긴 경우 TOO_MANY_REQUESTS를 반환받는다
zbqmgldjfh Jul 10, 2024
276a78d
fix(RAGQueryApiV2): 질문 API의 produces를 text/event-stream 으로 변경
zbqmgldjfh Jul 11, 2024
8229e8e
fix(InMemoryQueryAiModelAdapter): InMemory model의 응답 수정
zbqmgldjfh Jul 11, 2024
acad157
feat(AiAcceptanceTest): 질문 한도수 초과시 예외가 발행하는 인수테스트 작성
zbqmgldjfh Jul 13, 2024
a25e5d1
feat(RAGEventAdapter): RAG 시스템의 이벤트 포트 별도로 분리
zbqmgldjfh Jul 13, 2024
f44ecea
feat: dev profile ai 설정 추가
zbqmgldjfh Jul 13, 2024
487abf5
feat(InMemoryQueryAiModelAdapter): 기본 응답 부분을 단건 문자로 분리
zbqmgldjfh Jul 13, 2024
6c0445f
feat(RAGQueryApiV2): API를 POST에서 GET으로 변경
zbqmgldjfh Jul 13, 2024
58ddd5e
chore: 테스트 코드에 필요없는 예외 제거
zbqmgldjfh Jul 13, 2024
92adca9
test(UserRepositoryTest): 사용자 질문 토큰감소 리포지토리 테스트 추가
zbqmgldjfh Jul 13, 2024
e390cb5
feat(User): Deprecated 된 Where 에너테이션을 SQLRestriction 으로 변경
zbqmgldjfh Jul 13, 2024
535ea2d
refactor(NoticeQueryRepositoryImpl): 중복 문자열 상수로 추출
zbqmgldjfh Jul 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
plugins {
id 'org.springframework.boot' version '3.2.3'
id 'io.spring.dependency-management' version '1.1.4'
id 'io.spring.dependency-management' version '1.1.5'
id 'java'
id 'org.asciidoctor.jvm.convert' version "3.3.2"
id 'org.sonarqube' version '3.5.0.2730' // sonarqube gradle plugin 의존성
Expand All @@ -20,6 +20,8 @@ configurations {

repositories {
mavenCentral()
maven { url 'https://repo.spring.io/milestone' }
maven { url 'https://repo.spring.io/snapshot' }
}

sonarqube {
Expand All @@ -30,6 +32,10 @@ sonarqube {
}
}

ext {
set('springAiVersion', "1.0.0-M1")
}

dependencies {
// Web
implementation 'org.springframework.boot:spring-boot-starter-web'
Expand All @@ -38,6 +44,9 @@ dependencies {
implementation 'org.springframework:spring-aspects'
annotationProcessor 'org.springframework.boot:spring-boot-configuration-processor'

// AI
implementation "org.springframework.ai:spring-ai-openai-spring-boot-starter:${springAiVersion}"

// DB
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
runtimeOnly 'com.h2database:h2'
Expand Down Expand Up @@ -98,6 +107,12 @@ dependencies {
testImplementation 'org.testcontainers:mariadb:1.19.3'
}

dependencyManagement {
imports {
mavenBom "org.springframework.ai:spring-ai-bom:${springAiVersion}"
}
}

// Swagger force conflict resolution
configurations.all {
resolutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.kustacks.kuring.ai.adapter.in.web;

import com.kustacks.kuring.ai.application.port.in.RAGQueryUseCase;
import com.kustacks.kuring.common.annotation.RestWebAdapter;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.security.SecurityRequirement;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestParam;
import reactor.core.publisher.Flux;

@RequiredArgsConstructor
@RestWebAdapter(path = "/api/v2/ai/messages")
public class RAGQueryApiV2 {

private static final String USER_TOKEN_HEADER_KEY = "User-Token";

private final RAGQueryUseCase ragQueryUseCase;

@Operation(summary = "사용자 AI에 질문요청", description = "사용자가 궁금한 학교 정보를 AI에게 질문합니다.")
@SecurityRequirement(name = "User-Token")
@GetMapping(produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> askAIQuery(
@RequestParam("question") String question,
@RequestHeader(USER_TOKEN_HEADER_KEY) String id
) {
return ragQueryUseCase.askAiModel(question, id);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.kustacks.kuring.ai.adapter.in.web.dto;

import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Size;

public record UserQuestionRequest(
@NotBlank @Size(min = 5, max = 256) String question
) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.kustacks.kuring.ai.adapter.out.event;

import com.kustacks.kuring.ai.application.port.out.RAGEventPort;
import com.kustacks.kuring.common.domain.Events;
import com.kustacks.kuring.user.adapter.in.event.dto.UserDecreaseQuestionCountEvent;
import org.springframework.stereotype.Component;

@Component
public class RAGEventAdapter implements RAGEventPort {

@Override
public void userDecreaseQuestionCountEvent(String userId) {
Events.raise(new UserDecreaseQuestionCountEvent(userId));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.kustacks.kuring.ai.adapter.out.model;

import com.kustacks.kuring.ai.application.port.out.QueryAiModelPort;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

@Slf4j
@Component
@Profile("dev | local | test")
@RequiredArgsConstructor
public class InMemoryQueryAiModelAdapter implements QueryAiModelPort {

@Value("classpath:/ai/docs/ku-uni-register.txt")
private Resource kuUniRegisterInfo;

@Override
public Flux<String> call(Prompt prompt) {
if (prompt.getContents().contains("교내,외 장학금 및 학자금 대출 관련 전화번호들을 안내를 해줘")) {
return Flux.just("학", "생", "복", "지", "처", " ", "장", "학", "복", "지", "팀", "의",
" ", "전", "화", "번", "호", "는", " ", "0", "2", "-", "4", "5", "0", "-", "3", "2", "1",
"1", "~", "2", "이", "며", ",", " ", "건", "국", "사", "랑", "/", "장", "학", "사", "정",
"관", "장", "학", "/", "기", "금", "장", "학", "과", " ", "관", "련", "된", " ", "문", "의",
"는", " ", "0", "2", "-", "4", "5", "0", "-", "3", "9", "6", "7", "로", " ", "하", "시",
"면", " ", "됩", "니", "다", "."
);
}

return Flux.just("미", "리", " ", "준", "비", "된", " ",
"테", "스", "트", "질", "문", "이", " ", "아", "닙", "니", "다");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.kustacks.kuring.ai.adapter.out.model;

import com.kustacks.kuring.ai.application.port.out.QueryAiModelPort;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

@Slf4j
@Component
@Profile("prod")
@RequiredArgsConstructor
public class QueryAiModelAdapter implements QueryAiModelPort {

private final OpenAiChatModel openAiChatModel;

@Override
public Flux<String> call(Prompt prompt) {
return openAiChatModel.stream(prompt)
.filter(chatResponse -> chatResponse.getResult().getOutput().getContent() != null)
.flatMap(chatResponse -> Flux.just(chatResponse.getResult().getOutput().getContent()))
.doOnError(throwable -> log.error("[RAGQueryAiModelAdapter] {}", throwable.getMessage()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.kustacks.kuring.ai.adapter.out.persistence;

import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.List;
import java.util.stream.Stream;

@Slf4j
@Profile("local | dev | test")
@Component
@RequiredArgsConstructor
public class InMemoryQueryVectorStoreAdapter implements QueryVectorStorePort {

@Override
public List<String> findSimilarityContents(String question) {
HashMap<String, Object> metadata = createMetaData();

Document document = createDocument(metadata);

return Stream.of(document)
.map(Document::getContent)
.toList();
}

private Document createDocument(HashMap<String, Object> metadata) {
return new Document(
"a5a7414f-f676-409b-9f2e-1042f9846c97",
"● 등록금 전액 완납 또는 분할납부 1차분을 정해진 기간에 미납할 경우 분할납부 신청은 자동 취소되며, 미납 등록금은 이후\n" +
"추가 등록기간에 전액 납부해야 함.\n",
metadata);
}

private HashMap<String, Object> createMetaData() {
HashMap<String, Object> metadata = new HashMap<>();
metadata.put("charset", "UTF-8");
metadata.put("filename", "ku-uni-register.txt");
metadata.put("source", "ku-uni-register.txt");
return metadata;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.kustacks.kuring.ai.adapter.out.persistence;

import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

import java.util.Collections;
import java.util.List;

@Component
@Profile("prod")
@RequiredArgsConstructor
public class QueryVectorStoreAdapter implements QueryVectorStorePort {

@Override
public List<String> findSimilarityContents(String question) {
return Collections.emptyList();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.kustacks.kuring.ai.application.port.in;

import reactor.core.publisher.Flux;

public interface RAGQueryUseCase {
Flux<String> askAiModel(String question, String id);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.kustacks.kuring.ai.application.port.out;

import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;

public interface QueryAiModelPort {
Flux<String> call(Prompt prompt);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.kustacks.kuring.ai.application.port.out;

import java.util.List;

public interface QueryVectorStorePort {
List<String> findSimilarityContents(String question);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.kustacks.kuring.ai.application.port.out;

public interface RAGEventPort {
void userDecreaseQuestionCountEvent(String userId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.kustacks.kuring.ai.application.service;

import com.kustacks.kuring.ai.application.port.in.RAGQueryUseCase;
import com.kustacks.kuring.ai.application.port.out.QueryAiModelPort;
import com.kustacks.kuring.ai.application.port.out.QueryVectorStorePort;
import com.kustacks.kuring.ai.application.port.out.RAGEventPort;
import com.kustacks.kuring.common.annotation.UseCase;
import com.kustacks.kuring.common.exception.InvalidStateException;
import com.kustacks.kuring.common.exception.code.ErrorCode;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import reactor.core.publisher.Flux;

import javax.annotation.PostConstruct;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@UseCase
@RequiredArgsConstructor
public class RAGQueryService implements RAGQueryUseCase {

private final QueryVectorStorePort vectorStorePort;
private final QueryAiModelPort ragChatModel;
private final RAGEventPort ragEventPort;

@Value("classpath:/ai/prompts/rag-prompt-template.st")
private Resource ragPromptTemplate;
private PromptTemplate promptTemplate;

@Override
public Flux<String> askAiModel(String question, String id) {
Prompt completePrompt = buildCompletePrompt(question);
ragEventPort.userDecreaseQuestionCountEvent(id);
return ragChatModel.call(completePrompt);
}

@PostConstruct
private void init() {
this.promptTemplate = new PromptTemplate(ragPromptTemplate);
}

private Prompt buildCompletePrompt(String question) {
List<String> similarDocuments = vectorStorePort.findSimilarityContents(question);
if(similarDocuments.isEmpty()) {
throw new InvalidStateException(ErrorCode.AI_SIMILAR_DOCUMENTS_NOT_FOUND);
}

Map<String, Object> promptParameters = createQuestions(question, similarDocuments);
return promptTemplate.create(promptParameters);
}

private Map<String, Object> createQuestions(String question, List<String> contentList) {
Map<String, Object> promptParameters = new HashMap<>();
promptParameters.put("input", question);
promptParameters.put("documents", String.join("In", contentList));
return promptParameters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.kustacks.kuring.common.exception;

import com.kustacks.kuring.common.exception.code.ErrorCode;

public class InvalidStateException extends BusinessException {

public InvalidStateException(ErrorCode errorCode) {
super(errorCode);
}
public InvalidStateException(ErrorCode errorCode, Exception e) {
super(errorCode, e);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ public enum ErrorCode {

CAT_NOT_EXIST_CATEGORY(HttpStatus.BAD_REQUEST, "서버에서 지원하지 않는 카테고리입니다."),

// STAFF_SCRAPER_TAG_NOT_EXIST("Jsoup - 찾고자 하는 태그가 존재하지 않습니다."),
STAFF_SCRAPER_EXCEED_RETRY_LIMIT("교직원 업데이트 재시도 횟수를 초과했습니다."),
STAFF_SCRAPER_CANNOT_SCRAP("건국대학교 홈페이지가 불안정합니다. 교직원 정보를 가져올 수 없습니다."),
STAFF_SCRAPER_CANNOT_PARSE("교직원 페이지 HTML 파싱에 실패했습니다."),
Expand All @@ -77,11 +76,15 @@ public enum ErrorCode {

USER_NOT_FOUND(HttpStatus.NOT_FOUND, "해당 사용자를 찾을 수 없습니다."),

// AI 관련
AI_SIMILAR_DOCUMENTS_NOT_FOUND(HttpStatus.NOT_FOUND, "죄송합니다, 관련된 내용에 대하여 알지 못합니다."),

/**
* ErrorCodes about DomainLogicException
*/
DOMAIN_CANNOT_CREATE("해당 도메인을 생성할 수 없습니다."),
DEPARTMENT_NOT_FOUND("해당 학과를 찾을 수 없습니다.");
DEPARTMENT_NOT_FOUND("해당 학과를 찾을 수 없습니다."),
QUESTION_COUNT_NOT_ENOUGH(HttpStatus.TOO_MANY_REQUESTS, "남은 질문 횟수가 부족합니다.");

private final HttpStatus httpStatus;
private final String message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.kustacks.kuring.common.dto.ErrorResponse;
import com.kustacks.kuring.common.exception.AdminException;
import com.kustacks.kuring.common.exception.InternalLogicException;
import com.kustacks.kuring.common.exception.InvalidStateException;
import com.kustacks.kuring.common.exception.NotFoundException;
import com.kustacks.kuring.common.exception.code.ErrorCode;
import com.kustacks.kuring.message.application.service.exception.FirebaseSubscribeException;
Expand Down Expand Up @@ -69,6 +70,13 @@ public ResponseEntity<ErrorResponse> FirebaseSubscribeExceptionHandler(FirebaseS
.body(new ErrorResponse(ErrorCode.API_FB_SERVER_ERROR));
}

@ExceptionHandler
public ResponseEntity<ErrorResponse> InvalidStateExceptionHandler(InvalidStateException exception) {
log.info("[InvalidStateException] {}", exception.getMessage());
return ResponseEntity.status(exception.getErrorCode().getHttpStatus())
.body(new ErrorResponse(exception.getErrorCode()));
}

@ExceptionHandler
public void InternalLogicExceptionHandler(InternalLogicException e) {
log.warn("[InternalLogicException] {}", e.getErrorCode().getMessage(), e);
Expand Down
Loading
Loading