Skip to content

Commit

Permalink
Resolve feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
  • Loading branch information
junqiu-lei committed Jul 30, 2024
1 parent e4d5a8d commit 1616704
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion it/resources/benchmark-os-it.ini
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ datastore.password =


[workloads]
default.url = https://github.com/junqiu-lei/opensearch-benchmark-workloads
default.url = https://github.com/opensearch-project/opensearch-benchmark-workloads

[provision_configs]
default.dir = default-provision-config
Expand Down
2 changes: 1 addition & 1 deletion osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=F
if predictions[j] in truth_set:
correct += 1.0

return float(correct) / min_num_of_results
return correct / min_num_of_results

result = {
"weight": 1,
Expand Down
18 changes: 8 additions & 10 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,8 +1102,7 @@ def __init__(self, workloads, params, query_params, **kwargs):
self.corpora.extend(corpora for corpora in neighbors_corpora if corpora not in self.corpora)

def _validate_query_type_parameters(self):
count = sum([self.k is not None, self.distance is not None, self.score is not None])
if count > 1:
if bool(self.k) + bool(self.distance) + bool(self.score) > 1:
raise ValueError("Only one of k, max_distance, or min_score can be specified in vector search.")

@staticmethod
Expand All @@ -1129,14 +1128,16 @@ def _get_query_neighbors(self):
return Context.MAX_DISTANCE_NEIGHBORS
raise Exception("Unknown query type [%s]" % self.query_type)

def _get_query_size(self):
if self.query_type == self.KNN_QUERY_TYPE:
return self.k
return self.DEFAULT_RADIAL_SEARCH_QUERY_RESULT_SIZE

def _update_body_params(self, vector):
# accept body params if passed from workload, else, create empty dictionary
body_params = self.query_params.get(self.PARAMS_NAME_BODY) or dict()
if self.PARAMS_NAME_SIZE not in body_params:
if self.query_type == self.KNN_QUERY_TYPE:
body_params[self.PARAMS_NAME_SIZE] = self.k
else:
body_params[self.PARAMS_NAME_SIZE] = self.DEFAULT_RADIAL_SEARCH_QUERY_RESULT_SIZE
body_params[self.PARAMS_NAME_SIZE] = self._get_query_size()
if self.PARAMS_NAME_QUERY in body_params:
self.logger.warning(
"[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY)
Expand Down Expand Up @@ -1175,10 +1176,7 @@ def params(self):
raise StopIteration
vector = self.data_set.read(1)[0]
neighbor = self.neighbors_data_set.read(1)[0]
if self.k:
true_neighbors = list(map(str, neighbor[:self.k]))
else:
true_neighbors = list(map(str, neighbor))
true_neighbors = list(map(str, neighbor[:self.k] if self.k else neighbor))
self.query_params.update({
"neighbors": true_neighbors,
})
Expand Down

0 comments on commit 1616704

Please sign in to comment.