Skip to content

Commit

Permalink
feat(RAGCommandService): 어드민이 직접 학습시키는 API 구현 완료 (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
zbqmgldjfh authored Aug 31, 2024
1 parent 985a9d7 commit 3b74b1c
Show file tree
Hide file tree
Showing 14 changed files with 206 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.kustacks.kuring.admin.application.port.in.dto.RealNotificationCommand;
import com.kustacks.kuring.admin.domain.AdminRole;
import com.kustacks.kuring.alert.application.port.in.dto.AlertCreateCommand;
import com.kustacks.kuring.alert.application.port.in.dto.DataEmbeddingCommand;
import com.kustacks.kuring.auth.authorization.AuthenticationPrincipal;
import com.kustacks.kuring.auth.context.Authentication;
import com.kustacks.kuring.auth.secured.Secured;
Expand All @@ -24,9 +25,9 @@
import org.springframework.http.ResponseEntity;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import static com.kustacks.kuring.common.dto.ResponseCodeAndMessages.ADMIN_REAL_NOTICE_CREATE_SUCCESS;
import static com.kustacks.kuring.common.dto.ResponseCodeAndMessages.ADMIN_TEST_NOTICE_CREATE_SUCCESS;
import static com.kustacks.kuring.common.dto.ResponseCodeAndMessages.*;

@Tag(name = "Admin-Command", description = "관리자가 주체가 되는 정보 수정")
@Validated
Expand Down Expand Up @@ -88,6 +89,16 @@ public void cancelAlert(
adminCommandUseCase.cancelAlertSchedule(id);
}

@Operation(summary = "파일 임베딩", description = "어드민이 원하는 파일을 임베딩 하여 쿠링봇에서 사용할 수 있다")
@SecurityRequirement(name = "JWT")
@Secured(AdminRole.ROLE_ROOT)
@PostMapping("/embedding")
public ResponseEntity<BaseResponse<String>> embeddingCustomData(@RequestParam(name = "file") MultipartFile file) {
adminCommandUseCase.embeddingCustomData(new DataEmbeddingCommand(file));

return ResponseEntity.ok().body(new BaseResponse<>(ADMIN_EMBEDDING_NOTICE_SUCCESS, null));
}

@Hidden
@Secured(AdminRole.ROLE_ROOT)
@GetMapping("/subscribe/all")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.kustacks.kuring.admin.adapter.out.event;

import com.kustacks.kuring.admin.application.port.out.AiEventPort;
import com.kustacks.kuring.ai.adapter.in.event.dto.DataEmbeddingEvent;
import com.kustacks.kuring.common.domain.Events;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

@Component
@RequiredArgsConstructor
public class AdminAiEventAdapter implements AiEventPort {

@Override
public void sendDataEmbeddingEvent(String originName, String extension, String contentType, Resource resource) {
Events.raise(new DataEmbeddingEvent(originName, extension, contentType, resource));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import com.kustacks.kuring.admin.application.port.in.dto.RealNotificationCommand;
import com.kustacks.kuring.admin.application.port.in.dto.TestNotificationCommand;
import com.kustacks.kuring.alert.application.port.in.dto.AlertCreateCommand;
import com.kustacks.kuring.alert.application.port.in.dto.DataEmbeddingCommand;

public interface AdminCommandUseCase {

void createTestNotice(TestNotificationCommand command);

void createRealNoticeForAllUser(RealNotificationCommand command);

void subscribeAllUserSameTopic();

void addAlertSchedule(AlertCreateCommand command);

void cancelAlertSchedule(Long id);

void embeddingCustomData(DataEmbeddingCommand command);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.kustacks.kuring.admin.application.port.out;

import org.springframework.core.io.Resource;

public interface AiEventPort {

void sendDataEmbeddingEvent(String originName, String extension, String contentType, Resource resource);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@
import com.kustacks.kuring.admin.application.port.out.AdminAlertEventPort;
import com.kustacks.kuring.admin.application.port.out.AdminEventPort;
import com.kustacks.kuring.admin.application.port.out.AdminUserFeedbackPort;
import com.kustacks.kuring.admin.application.port.out.AiEventPort;
import com.kustacks.kuring.admin.domain.Admin;
import com.kustacks.kuring.alert.application.port.in.dto.AlertCreateCommand;
import com.kustacks.kuring.alert.application.port.in.dto.DataEmbeddingCommand;
import com.kustacks.kuring.auth.userdetails.UserDetailsServicePort;
import com.kustacks.kuring.common.annotation.UseCase;
import com.kustacks.kuring.common.properties.ServerProperties;
import com.kustacks.kuring.notice.domain.CategoryName;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.InputStreamResource;
import org.springframework.core.io.Resource;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
Expand All @@ -33,6 +39,7 @@ public class AdminCommandService implements AdminCommandUseCase {
private final AdminUserFeedbackPort adminUserFeedbackPort;
private final AdminAlertEventPort adminAlertEventPort;
private final AdminEventPort adminEventPort;
private final AiEventPort aiEventPort;
private final NoticeProperties noticeProperties;
private final ServerProperties serverProperties;
private final PasswordEncoder passwordEncoder;
Expand Down Expand Up @@ -79,6 +86,27 @@ public void cancelAlertSchedule(Long id) {
adminAlertEventPort.cancelAlertSchedule(id);
}

@Override
public void embeddingCustomData(DataEmbeddingCommand command) {
try {
MultipartFile file = command.file();

String originalFilename = file.getOriginalFilename();
String contentType = file.getContentType();
Resource resource = new InputStreamResource(file.getInputStream());
String extension = extractExtension(originalFilename);

aiEventPort.sendDataEmbeddingEvent(
originalFilename,
extension,
contentType,
resource
);
} catch (IOException e) {
log.error("file read error", e);
}
}

/**
* TODO : 1회성 API - client v2 배포 후, 단 한번 모든 사용자를 공통 topic에 구독시킨 후 제거 예정
*/
Expand All @@ -100,4 +128,9 @@ public void subscribeAllUserSameTopic() {
private boolean isNotMatchPassword(final String commandPassword, final String adminPassword) {
return !passwordEncoder.matches(commandPassword, adminPassword);
}

private String extractExtension(String originalFilename) {
int pos = originalFilename.lastIndexOf(".");
return originalFilename.substring(pos + 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.kustacks.kuring.ai.adapter.in.event;

import com.kustacks.kuring.ai.adapter.in.event.dto.DataEmbeddingEvent;
import com.kustacks.kuring.ai.application.port.in.RAGCommandUseCase;
import lombok.RequiredArgsConstructor;
import org.springframework.context.event.EventListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;

@Component
@RequiredArgsConstructor
public class DataEmbeddingEventListener {

private final RAGCommandUseCase ragCommandUseCase;

@Async
@EventListener
public void dataEmbeddingEvent(
DataEmbeddingEvent event
) {
ragCommandUseCase.dataEmbedding(
event.fileName(),
event.extension(),
event.contentType(),
event.resource()
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.kustacks.kuring.ai.adapter.in.event.dto;

import org.springframework.core.io.Resource;

public record DataEmbeddingEvent(
String fileName,
String extension,
String contentType,
Resource resource
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.time.LocalDateTime;
import java.util.List;

@Component
Expand Down Expand Up @@ -48,6 +50,15 @@ public void embedding(List<PageTextDto> extractTextResults, CategoryName categor
}
}

@Override
public void embeddingSingleTextFile(String originName, Resource resource) throws IOException {
TokenTextSplitter textSplitter = new TokenTextSplitter();

List<Document> documents = createDocument(originName, resource);
List<Document> splitDocuments = textSplitter.apply(documents);
chromaVectorStore.accept(splitDocuments);
}

private List<Document> createDocuments(CategoryName categoryName, PageTextDto textResult) {
Resource resource = new ByteArrayResource(textResult.text().getBytes()) {
@Override
Expand All @@ -62,4 +73,19 @@ public String getFilename() {
textReader.getCustomMetadata().put("category", categoryName.getName());
return textReader.get();
}

private List<Document> createDocument(String originName, Resource resource) throws IOException {
Resource byteResource = new ByteArrayResource(resource.getContentAsByteArray()) {
@Override
public String getFilename() {
return originName;
}
};

TextReader textReader = new TextReader(byteResource);
textReader.getCustomMetadata().put("articleId", "");
textReader.getCustomMetadata().put("date", LocalDateTime.now().toString());
textReader.getCustomMetadata().put("category", originName);
return textReader.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.context.annotation.Profile;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;

import java.util.Collections;
Expand Down Expand Up @@ -39,6 +40,11 @@ public void embedding(List<PageTextDto> extractTextResults, CategoryName categor
log.info("[InMemoryQueryVectorStoreAdapter] embedding {}", categoryName);
}

@Override
public void embeddingSingleTextFile(String originName, Resource resource) {
log.info("[InMemoryQueryVectorStoreAdapter] embeddingSingleTextFile {}", originName);
}

private Document createDocument(HashMap<String, Object> metadata) {
return new Document(
"a5a7414f-f676-409b-9f2e-1042f9846c97",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.kustacks.kuring.ai.application.port.in;

import org.springframework.core.io.Resource;

public interface RAGCommandUseCase {

void dataEmbedding(String originName, String extension, String contentType, Resource resource);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import com.kustacks.kuring.notice.domain.CategoryName;
import com.kustacks.kuring.worker.parser.notice.PageTextDto;
import org.springframework.core.io.Resource;

import java.io.IOException;
import java.util.List;

public interface CommandVectorStorePort {

void embedding(List<PageTextDto> extractTextResults, CategoryName categoryName);

void embeddingSingleTextFile(String originName, Resource resource) throws IOException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.kustacks.kuring.ai.application.service;

import com.kustacks.kuring.ai.application.port.in.RAGCommandUseCase;
import com.kustacks.kuring.ai.application.port.out.CommandVectorStorePort;
import com.kustacks.kuring.common.annotation.UseCase;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.Resource;

import java.io.IOException;

@Slf4j
@UseCase
@RequiredArgsConstructor
public class RAGCommandService implements RAGCommandUseCase {

private final CommandVectorStorePort commandVectorStorePort;
private static final String EXTENSION_PDF = "pdf";
private static final String EXTENSION_TXT = "txt";

@Override
public void dataEmbedding(String originName, String extension, String contentType, Resource resource) {
try {
if (extension.equals(EXTENSION_PDF)) {
// TODO: pdf embedding
} else if (extension.equals(EXTENSION_TXT)) {
commandVectorStorePort.embeddingSingleTextFile(originName, resource);
} else {
log.warn("not supported file type : {}", extension);
}
} catch (IOException e) {
log.warn("file embedding fail : {}", originName);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.kustacks.kuring.alert.application.port.in.dto;

import org.springframework.web.multipart.MultipartFile;

public record DataEmbeddingCommand(
MultipartFile file
) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public enum ResponseCodeAndMessages {
/* Admin */
ADMIN_TEST_NOTICE_CREATE_SUCCESS(HttpStatus.OK.value(), "테스트 공지 생성에 성공하였습니다"),
ADMIN_REAL_NOTICE_CREATE_SUCCESS(HttpStatus.OK.value(), "실제 공지 생성에 성공하였습니다"),
ADMIN_EMBEDDING_NOTICE_SUCCESS(HttpStatus.OK.value(), "데이터 임베딩에 생성에 성공하였습니다"),

/* User */
USER_REGISTER_SUCCESS(HttpStatus.OK.value(), "회원가입에 성공하였습니다"),
Expand Down

0 comments on commit 3b74b1c

Please sign in to comment.