diff --git a/osbenchmark/worker_coordinator/runner.py b/osbenchmark/worker_coordinator/runner.py index ece6c7fb0..3ed101bfe 100644 --- a/osbenchmark/worker_coordinator/runner.py +++ b/osbenchmark/worker_coordinator/runner.py @@ -1318,7 +1318,7 @@ def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=F if predictions[j] in truth_set: correct += 1.0 - return float(correct) / min_num_of_results + return correct / min_num_of_results result = { "weight": 1, diff --git a/osbenchmark/workload/params.py b/osbenchmark/workload/params.py index 703a8af7c..d59ff557e 100644 --- a/osbenchmark/workload/params.py +++ b/osbenchmark/workload/params.py @@ -1102,8 +1102,7 @@ def __init__(self, workloads, params, query_params, **kwargs): self.corpora.extend(corpora for corpora in neighbors_corpora if corpora not in self.corpora) def _validate_query_type_parameters(self): - count = sum([self.k is not None, self.distance is not None, self.score is not None]) - if count > 1: + if bool(self.k) + bool(self.distance) + bool(self.score) > 1: raise ValueError("Only one of k, max_distance, or min_score can be specified in vector search.") @staticmethod @@ -1129,14 +1128,16 @@ def _get_query_neighbors(self): return Context.MAX_DISTANCE_NEIGHBORS raise Exception("Unknown query type [%s]" % self.query_type) + def _get_query_size(self): + if self.query_type == self.KNN_QUERY_TYPE: + return self.k + return self.DEFAULT_RADIAL_SEARCH_QUERY_RESULT_SIZE + def _update_body_params(self, vector): # accept body params if passed from workload, else, create empty dictionary body_params = self.query_params.get(self.PARAMS_NAME_BODY) or dict() if self.PARAMS_NAME_SIZE not in body_params: - if self.query_type == self.KNN_QUERY_TYPE: - body_params[self.PARAMS_NAME_SIZE] = self.k - else: - body_params[self.PARAMS_NAME_SIZE] = self.DEFAULT_RADIAL_SEARCH_QUERY_RESULT_SIZE + body_params[self.PARAMS_NAME_SIZE] = self._get_query_size() if self.PARAMS_NAME_QUERY in body_params: self.logger.warning( "[%s] param from body will be replaced with vector search query.", self.PARAMS_NAME_QUERY) @@ -1175,10 +1176,7 @@ def params(self): raise StopIteration vector = self.data_set.read(1)[0] neighbor = self.neighbors_data_set.read(1)[0] - if self.k: - true_neighbors = list(map(str, neighbor[:self.k])) - else: - true_neighbors = list(map(str, neighbor)) + true_neighbors = list(map(str, neighbor[:self.k] if self.k else neighbor)) self.query_params.update({ "neighbors": true_neighbors, })