Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search by asset type endpoint #2691

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ public record Index(
String decapodesContext) {
}

public String getIndex(String root) {
return String.join("_", index.prefix, root, index.suffix);
}

public String getCodeIndex() {
return String.join("_", index.prefix, index.codeRoot, index.suffix);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package software.uncharted.terarium.hmiserver.controller.knn;
package software.uncharted.terarium.hmiserver.controller.search;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
Expand All @@ -11,8 +11,8 @@
import org.springframework.security.access.annotation.Secured;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;

Expand All @@ -31,28 +31,32 @@
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import software.uncharted.terarium.hmiserver.configuration.ElasticsearchConfiguration;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.models.task.TaskStatus;
import software.uncharted.terarium.hmiserver.security.Roles;
import software.uncharted.terarium.hmiserver.service.TaskService;
import software.uncharted.terarium.hmiserver.service.elasticsearch.ElasticsearchService;

@RequestMapping("/knn")
@RequestMapping("/search-by-asset-type")
@RestController
@Slf4j
@RequiredArgsConstructor
public class KNNSearchController {
public class SearchByAssetTypeController {

final private ObjectMapper objectMapper;
final private TaskService taskService;
final private RedissonClient redissonClient;
final private ElasticsearchService elasticsearchService;
final private ElasticsearchService esService;
final private ElasticsearchConfiguration esConfig;
private RMapCache<byte[], List<Float>> queryVectorCache;

final private long CACHE_TTL_SECONDS = 60 * 60 * 24;
final private long REQUEST_TIMEOUT_SECONDS = 10;
final private String EMBEDDING_MODEL = "text-embedding-ada-002";
static final private long CACHE_TTL_SECONDS = 60 * 60 * 2; // 2 hours
static final private long REQUEST_TIMEOUT_SECONDS = 10;
static final private String EMBEDDING_MODEL = "text-embedding-ada-002";
static final private String SEARCHABLE_INDEX_PREFIX = "searchable_";
static final private String REDIS_EMBEDDING_CACHE_KEY = "knn-vector-cache";

@Data
static public class GoLLMSearchRequest {
Expand All @@ -67,39 +71,45 @@ private static class EmbeddingsResponse {
List<Float> response;
}

@Data
static public class KNNSearchRequest {
private String text;
private int numCandidates = 100;
private int k = 10;
}

@PostConstruct
public void init() {
queryVectorCache = redissonClient.getMapCache("knn-vector-cache");
queryVectorCache = redissonClient.getMapCache(REDIS_EMBEDDING_CACHE_KEY);
}

@GetMapping("/{index}")
@GetMapping("/{asset-type}")
@Secured(Roles.USER)
@Operation(summary = "Executes a knn search against provided index")
@Operation(summary = "Executes a knn search against the provided asset type")
@ApiResponses(value = {
@ApiResponse(responseCode = "200", description = "Query results", content = @Content(mediaType = "application/json", schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = JsonNode.class))),
@ApiResponse(responseCode = "204", description = "There was no concept found", content = @Content),
@ApiResponse(responseCode = "500", description = "There was an issue retrieving the concept from the data store", content = @Content)
})
public ResponseEntity<List<JsonNode>> knnSearch(
@PathVariable("index") final String index,
@RequestBody KNNSearchRequest body) {
public ResponseEntity<List<JsonNode>> searchByAssetType(
@PathVariable("asset-type") final String assetType,
@RequestParam(value = "text", required = true) final String text,
@RequestParam(value = "k", defaultValue = "10") final int k,
@RequestParam(value = "num-results", defaultValue = "100") final int numResults,
@RequestParam(value = "num-candidates", defaultValue = "100") final int numCandidates,
@RequestParam(value = "embedding-model", defaultValue = EMBEDDING_MODEL) final String embeddingModel,
@RequestParam(value = "index", defaultValue = "") String index) {

try {

if (body.getK() > body.getNumCandidates()) {
if (index.equals("")) {
index = esConfig.getIndex(SEARCHABLE_INDEX_PREFIX + assetType);
if (!esService.containsIndex(index)) {
log.error("Unsupported asset type: {}, index does not exist", assetType);
return ResponseEntity.badRequest().build();
}
}

if (k > numCandidates) {
return ResponseEntity.badRequest().build();
}

// sha256 the text to use as a cache key
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] hash = md.digest(body.getText().getBytes(StandardCharsets.UTF_8));
byte[] hash = md.digest(text.getBytes(StandardCharsets.UTF_8));

// check if we already have the vectors cached
List<Float> vector = queryVectorCache.get(hash);
Expand All @@ -108,7 +118,7 @@ public ResponseEntity<List<JsonNode>> knnSearch(
// set the embedding model

GoLLMSearchRequest embeddingRequest = new GoLLMSearchRequest();
embeddingRequest.setText(body.getText());
embeddingRequest.setText(text);
embeddingRequest.setEmbeddingModel(EMBEDDING_MODEL);

TaskRequest req = new TaskRequest();
Expand Down Expand Up @@ -139,11 +149,11 @@ public ResponseEntity<List<JsonNode>> knnSearch(
KnnQuery query = new KnnQuery.Builder()
.field("embeddings.vector")
.queryVector(vector)
.k(body.getK())
.numCandidates(body.getNumCandidates())
.k(k)
.numCandidates(numCandidates)
.build();

List<JsonNode> docs = elasticsearchService.knnSearch(index, query, JsonNode.class);
List<JsonNode> docs = esService.knnSearch(index, query, numResults, JsonNode.class);

return ResponseEntity.ok(docs);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public void init() {
}

/**
* Create all indices that are not already present in the cluster
* Check for the existence of an index.
*
* @return True if the index exists, false otherwise
*/
Expand Down Expand Up @@ -334,7 +334,7 @@ public <T> T get(final String index, final String id, final Class<T> tClass) thr
return null;
}

public <T> List<T> knnSearch(String index, KnnQuery query, final Class<T> tClass)
public <T> List<T> knnSearch(String index, KnnQuery query, int numResults, final Class<T> tClass)
throws IOException {
log.info("KNN search on: {}", index);

Expand All @@ -344,7 +344,7 @@ public <T> List<T> knnSearch(String index, KnnQuery query, final Class<T> tClass

SearchRequest req = new SearchRequest.Builder()
.index(index)
.size((int) query.k())
.size(numResults)
.source(src -> src.filter(v -> v.includes("title")))
.knn(query)
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package software.uncharted.terarium.hmiserver.controller.knn;
package software.uncharted.terarium.hmiserver.controller.search;

import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

import java.util.List;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.security.test.context.support.WithUserDetails;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
Expand All @@ -18,37 +18,26 @@
import lombok.extern.slf4j.Slf4j;
import software.uncharted.terarium.hmiserver.TerariumApplicationTests;
import software.uncharted.terarium.hmiserver.configuration.MockUser;
import software.uncharted.terarium.hmiserver.controller.knn.KNNSearchController.KNNSearchRequest;

@Slf4j
public class KNNSearchControllerTests extends TerariumApplicationTests {
public class SearchByAssetTypeControllerTests extends TerariumApplicationTests {

@Autowired
private ObjectMapper objectMapper;

private static final String TEST_INDEX = "tds_covid_tera_1.0";
private static final String TEST_INDEX = "tds_searchable_document_tera_1.0";
private static final String TEST_ASSET = "document";

// @Test
@WithUserDetails(MockUser.ADAM)
public void testKnnSearch() throws Exception {

KNNSearchRequest req = new KNNSearchRequest();
req.setText("Papers that discuss the use of masks to prevent the spread of COVID-19");

// Test that we get a 404 if we provide a project id that doesn't exist
MvcResult res = mockMvc.perform(MockMvcRequestBuilders.get("/knn/" + TEST_INDEX)
.contentType(MediaType.APPLICATION_JSON)
.accept(MediaType.APPLICATION_JSON)
.with(request -> {
try {
request.setMethod("GET");
request.setContent(objectMapper.writeValueAsBytes(req));
} catch (Exception e) {
e.printStackTrace();
}
return request;
})
MvcResult res = mockMvc.perform(MockMvcRequestBuilders.get("/search-by-asset-type/" + TEST_ASSET)
.param("text", "Papers that discuss the use of masks to prevent the spread of COVID-19")
.param("index", TEST_INDEX) // index override
.with(csrf()))
.andDo(print())
.andExpect(status().isOk())
.andReturn();

Expand Down
Loading