Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distance and similarity metric as output in KNN search #1260

Merged
merged 7 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions python/hello_infinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_english():

res = (
table.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
)

pds_df = pds.DataFrame(res)
Expand All @@ -70,8 +70,8 @@ def test_english():
table_obj = db.get_table("my_table")
qb_result = (
table_obj.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
)
print("------tabular -------")
print("------vector-------")
Expand All @@ -85,10 +85,10 @@ def test_english():

qb_result2 = (
table_obj.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.match("body", "blooms", "topn=1")
.fusion("rrf")
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.match("body", "blooms", "topn=1")
.fusion("rrf")
.to_pl()
)
print("------vector+fulltext-------")
print(qb_result2)
Expand Down Expand Up @@ -171,8 +171,8 @@ def test_chinese():
print("------json-------")
res = (
table.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 2)
.to_pl()
)
pds_df = pds.DataFrame(res)
json_data = pds_df.to_json()
Expand All @@ -183,8 +183,8 @@ def test_chinese():
print("------vector-------")
qb_result = (
table_obj.output(["num", "body"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 3)
.to_pl()
)
print(qb_result)

Expand All @@ -199,8 +199,8 @@ def test_chinese():
for question in questions:
qb_result = (
table_obj.output(["num", "body", "_score"])
.match("body", question, "topn=10")
.to_pl()
.match("body", question, "topn=10")
.to_pl()
)
print(f"question: {question}")
print(qb_result)
Expand All @@ -209,10 +209,10 @@ def test_chinese():
for question in questions:
qb_result = (
table_obj.output(["num", "body", "_score"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 10)
.match("body", question, "topn=10")
.fusion("rrf")
.to_pl()
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 10)
.match("body", question, "topn=10")
.fusion("rrf")
.to_pl()
)
print(f"question: {question}")
print(qb_result)
Expand Down
5 changes: 5 additions & 0 deletions python/infinity/remote_thrift/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ def output(self, columns: Optional[list]) -> InfinityThriftQueryBuilder:
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
select_list.append(parsed_expr)
case "_similarity":
func_expr = FunctionExpr(function_name="similarity", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
select_list.append(parsed_expr)
case "_distance":
func_expr = FunctionExpr(function_name="distance", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
Expand Down
2 changes: 1 addition & 1 deletion python/infinity/remote_thrift/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def insert(self, data: Union[INSERT_DATA, list[INSERT_DATA]]):
literal_type=ttypes.LiteralType.DoubleTensorArray,
f64_tensor_array_value=value)
else:
raise InfinityException(3069, "Invalid constant expression")
raise InfinityException(3069, f"Invalid constant expression: {type(value)}")

expr_type = ttypes.ParsedExprType(
constant_expr=constant_expression)
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "infinity_sdk"
version = "0.2.0.dev2"
version = "0.2.0.dev3"
dependencies = [
"sqlglot~=11.7.1",
"pydantic~=2.7.1",
Expand Down
6 changes: 3 additions & 3 deletions python/test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_knn_on_vector_column(self, get_infinity_db, check_data, column_name):
copy_data("tmp_20240116.csv")
test_csv_dir = "/var/infinity/test_data/tmp_20240116.csv"
table_obj.import_data(test_csv_dir, None)
res = table_obj.output(["variant_id", "_row_id", "_distance"]).knn(
res = table_obj.output(["variant_id", "_row_id", "_similarity"]).knn(
column_name, [1.0] * 4, "float", "ip", 2).to_pl()
print(res)

Expand Down Expand Up @@ -302,12 +302,12 @@ def test_valid_embedding_data_type(self, get_infinity_db, check_data, embedding_
test_csv_dir = "/var/infinity/test_data/tmp_20240116.csv"
table_obj.import_data(test_csv_dir, None)
if embedding_data_type[1]:
res = table_obj.output(["variant_id"]).knn("gender_vector", embedding_data, embedding_data_type[0],
res = table_obj.output(["variant_id", "_distance"]).knn("gender_vector", embedding_data, embedding_data_type[0],
"l2",
2).to_pl()
print(res)
else:
res = table_obj.output(["variant_id"]).knn("gender_vector", embedding_data, embedding_data_type[0],
res = table_obj.output(["variant_id", "_similarity"]).knn("gender_vector", embedding_data, embedding_data_type[0],
"ip",
2).to_pl()

Expand Down
11 changes: 7 additions & 4 deletions src/function/builtin_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,14 @@ void BuiltinFunctions::RegisterSpecialFunction() {
SharedPtr<SpecialFunction> row_function = MakeShared<SpecialFunction>("ROW_ID", DataType(LogicalType::kBigInt), 1, SpecialType::kRowID);
Catalog::AddSpecialFunction(catalog_ptr_.get(), row_function);

SharedPtr<SpecialFunction> create_ts_function = MakeShared<SpecialFunction>("DISTANCE", DataType(LogicalType::kFloat), 2, SpecialType::kDistance);
Catalog::AddSpecialFunction(catalog_ptr_.get(), create_ts_function);
SharedPtr<SpecialFunction> distance_function = MakeShared<SpecialFunction>("DISTANCE", DataType(LogicalType::kFloat), 2, SpecialType::kDistance);
Catalog::AddSpecialFunction(catalog_ptr_.get(), distance_function);

SharedPtr<SpecialFunction> delete_ts_function = MakeShared<SpecialFunction>("SCORE", DataType(LogicalType::kFloat), 3, SpecialType::kScore);
Catalog::AddSpecialFunction(catalog_ptr_.get(), delete_ts_function);
SharedPtr<SpecialFunction> similarity_function = MakeShared<SpecialFunction>("SIMILARITY", DataType(LogicalType::kFloat), 3, SpecialType::kSimilarity);
Catalog::AddSpecialFunction(catalog_ptr_.get(), similarity_function);

SharedPtr<SpecialFunction> score_function = MakeShared<SpecialFunction>("SCORE", DataType(LogicalType::kFloat), 4, SpecialType::kScore);
Catalog::AddSpecialFunction(catalog_ptr_.get(), score_function);
}

} // namespace infinity
1 change: 1 addition & 0 deletions src/function/special_function.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace infinity {
export enum class SpecialType {
kRowID,
kDistance,
kSimilarity,
kScore,
kCreateTs,
kDeleteTs,
Expand Down
1 change: 1 addition & 0 deletions src/parser/type/data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ std::shared_ptr<DataType> DataType::Deserialize(const nlohmann::json &data_type_
}
case LogicalType::kSparse: {
type_info = SparseInfo::Deserialize(type_info_json);
break;
}
default:
// There's no type_info for other types
Expand Down
39 changes: 39 additions & 0 deletions src/planner/bind_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import column_identifer;
import block_index;
import column_expr;
import logger;
import knn_expr;

namespace infinity {

Expand Down Expand Up @@ -368,6 +369,44 @@ const Binding *BindContext::GetBindingFromCurrentOrParentByName(const String &bi
return binding_iter->second.get();
}

// Inspect the SEARCH clause and record which special output functions the
// query may use, by setting allow_distance / allow_similarity / allow_score on
// this bind context. The expression binder consults these flags when it
// encounters DISTANCE(), SIMILARITY() or SCORE().
//
// expr: the parsed search expression; it is cast to SearchExpr. A null pointer
//       leaves all flags untouched.
//
// Calls UnrecoverableError when the KNN metric is not one of the handled
// distance types.
void BindContext::BoundSearch(ParsedExpr *expr) {
    if (expr == nullptr) {
        return;
    }
    auto search_expr = (SearchExpr *)expr;

    // DISTANCE() / SIMILARITY() are only considered for a KNN search without fusion.
    if (!search_expr->knn_exprs_.empty() && search_expr->fusion_exprs_.empty()) {
        const SizeT expr_count = search_expr->knn_exprs_.size();
        const KnnDistanceType first_distance_type = search_expr->knn_exprs_[0]->distance_type_;

        // All KNN expressions must share one metric; with mixed metrics neither
        // DISTANCE() nor SIMILARITY() is enabled.
        bool same_metric = true;
        for (SizeT idx = 1; idx < expr_count; ++idx) {
            if (search_expr->knn_exprs_[idx]->distance_type_ != first_distance_type) {
                same_metric = false;
                break;
            }
        }

        if (same_metric) {
            switch (first_distance_type) {
                case KnnDistanceType::kL2:
                case KnnDistanceType::kHamming: {
                    allow_distance = true;
                    break;
                }
                case KnnDistanceType::kInnerProduct:
                case KnnDistanceType::kCosine: {
                    allow_similarity = true;
                    break;
                }
                default: {
                    String error_message = "Invalid KNN metric type";
                    LOG_ERROR(error_message);
                    UnrecoverableError(error_message);
                }
            }
        }
    }

    // Always evaluated, including when the KNN metrics are mixed: the previous
    // early return in that case skipped this assignment.
    allow_score = !search_expr->match_exprs_.empty() || !search_expr->match_tensor_exprs_.empty() || !(search_expr->fusion_exprs_.empty());
}

// void
// BindContext::AddChild(const SharedPtr<BindContext>& child) {
// child->binding_context_id_ = GenerateBindingContextIndex();
Expand Down
11 changes: 2 additions & 9 deletions src/planner/bind_context.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ public:
bool single_row = false;

bool allow_distance = false;
bool allow_similarity = false;
bool allow_score = false;

public:
Expand Down Expand Up @@ -166,15 +167,7 @@ public:

void BoundTable(const String &table_name) { bound_table_set_.insert(table_name); }

void BoundSearch(ParsedExpr *expr) {
if (expr == nullptr) {
return;
}
auto search_expr = (SearchExpr *)expr;

allow_distance = !search_expr->knn_exprs_.empty() && search_expr->fusion_exprs_.empty();
allow_score = !search_expr->match_exprs_.empty() || !search_expr->match_tensor_exprs_.empty() || !(search_expr->fusion_exprs_.empty());
}
void BoundSearch(ParsedExpr *expr);

void AddSubqueryBinding(const String &name,
u64 table_index,
Expand Down
10 changes: 9 additions & 1 deletion src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,15 @@ Optional<SharedPtr<BaseExpression>> ExpressionBinder::TryBuildSpecialFuncExpr(co
switch (special_function_ptr->special_type()) {
case SpecialType::kDistance: {
if (!bind_context_ptr->allow_distance) {
Status status = Status::SyntaxError("DISTANCE() needs to be allowed only when there is only MATCH VECTOR");
Status status = Status::SyntaxError("DISTANCE() needs to be allowed only when there is only MATCH VECTOR with distance metrics, like L2");
LOG_ERROR(status.message());
RecoverableError(status);
}
break;
}
case SpecialType::kSimilarity: {
if (!bind_context_ptr->allow_similarity) {
Status status = Status::SyntaxError("SIMILARITY() needs to be allowed only when there is only MATCH VECTOR with similarity metrics, like Inner product");
LOG_ERROR(status.message());
RecoverableError(status);
}
Expand Down
1 change: 1 addition & 0 deletions src/planner/optimizer/column_remapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ SharedPtr<BaseExpression> BindingRemapper::VisitReplace(const SharedPtr<ColumnEx
column_cnt_ - 1);
}
case SpecialType::kScore:
case SpecialType::kSimilarity:
case SpecialType::kDistance: {
return ReferenceExpression::Make(expression->Type(),
expression->table_name(),
Expand Down
5 changes: 4 additions & 1 deletion test/sql/dql/knn/test_knn_ip.slt
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ SELECT c2 FROM test_knn_ip SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float
0.2,0.1,0.3,0.4

query II
SELECT c1, ROW_ID(), DISTANCE() FROM test_knn_ip SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'ip', 3);
SELECT c1, ROW_ID(), SIMILARITY() FROM test_knn_ip SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'ip', 3);
----
8 3 0.270000
6 2 0.250000
4 1 0.230000

# DISTANCE() must be rejected for the similarity metric 'ip'; query this file's own table so the error comes from the metric check
statement error
SELECT c1, ROW_ID(), DISTANCE() FROM test_knn_ip SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'ip', 3);

# copy to create another new block
# there will be 2 knn_scan operators to scan the blocks, and one merge_knn to merge
statement ok
Expand Down
3 changes: 3 additions & 0 deletions test/sql/dql/knn/test_knn_l2.slt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ SELECT c1, ROW_ID(), DISTANCE() FROM test_knn_l2 SEARCH MATCH VECTOR (c2, [0.3,
6 2 0.060000
4 1 0.100000

statement error
SELECT c1, ROW_ID(), SIMILARITY() FROM test_knn_l2 SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'l2', 3);

# copy to create another new block
# there will be 2 knn_scan operators to scan the blocks, and one merge_knn to merge
statement ok
Expand Down
Loading