Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros committed Oct 8, 2024
1 parent add4469 commit bf3a8e7
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pkg/ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ func buildVectorInfoWithCheck(indexPartSpecifications []*ast.IndexPartSpecificat
if !ok {
return nil, "", dbterror.ErrUnsupportedAddVectorIndex.FastGenByArgs(fmt.Sprintf("unsupported function: %v", idxPart.Expr))
}
distanceMetric, ok := model.FnNameToDistanceMetric[f.FnName.L]
distanceMetric, ok := model.IndexableFnNameToDistanceMetric[f.FnName.L]
if !ok {
return nil, "", dbterror.ErrUnsupportedAddVectorIndex.FastGenByArgs("currently only L2 and Cosine distance is indexable")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ func constructResultOfShowCreateTable(ctx sessionctx.Context, dbName *pmodel.CIS
cols = append(cols, colInfo)
}
if idxInfo.VectorInfo != nil {
funcName := model.DistanceMetricToFnName[idxInfo.VectorInfo.DistanceMetric]
funcName := model.IndexableDistanceMetricToFnName[idxInfo.VectorInfo.DistanceMetric]
fmt.Fprintf(buf, "((%s(%s)))", strings.ToUpper(funcName), strings.Join(cols, ","))
} else {
fmt.Fprintf(buf, "(%s)", strings.Join(cols, ","))
Expand Down
40 changes: 15 additions & 25 deletions pkg/expression/vs_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ package expression
import (
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/intest"
)

var (
Expand All @@ -33,66 +33,56 @@ var (
}
)

// VectorSearchExpr defines a minimal Vector Search expression, which is
// a vector distance function, a column to search with, and a reference vector.
type VectorSearchExpr struct {
// VectorHelper is a helper struct for vector indexes.
type VectorHelper struct {
DistanceFnName model.CIStr
Vec types.VectorFloat32
Column *Column
ColumnID int64
}

// ExtractVectorSearch extracts a VectorSearchExpr from an expression.
// ExtractVectorHelper extracts a VectorHelper from an expression.
// NOTE: not all VectorSearch functions are supported by the index. The caller
// needs to check the distance function name.
func ExtractVectorSearch(expr Expression) (*VectorSearchExpr, error) {
func ExtractVectorHelper(expr Expression) *VectorHelper {
x, ok := expr.(*ScalarFunction)
if !ok {
return nil, nil
return nil
}

if _, isVecFn := vsDistanceFnNamesLower[x.FuncName.L]; !isVecFn {
return nil, nil
return nil
}

args := x.GetArgs()
if len(args) != 2 {
return nil, errors.Errorf("internal: expect 2 args for function %s, but got %d", x.FuncName.L, len(args))
}

// One arg must be a vector column ref, and one arg must be a vector constant.
// Note: this must be run after constant folding.

var vectorConstant *Constant = nil
var vectorColumn *Column = nil
nVectorColumns := 0
nVectorConstants := 0
for _, arg := range args {
if v, ok := arg.(*Column); ok {
if v.RetType.GetType() != mysql.TypeTiDBVectorFloat32 {
break
return nil
}
vectorColumn = v
nVectorColumns++
} else if v, ok := arg.(*Constant); ok {
if v.RetType.GetType() != mysql.TypeTiDBVectorFloat32 {
break
return nil
}
vectorConstant = v
nVectorConstants++
}
}
if nVectorColumns != 1 || nVectorConstants != 1 {
return nil, nil
return nil
}

// All checks passed.
if vectorConstant.Value.Kind() != types.KindVectorFloat32 {
return nil, errors.Errorf("internal: expect vectorFloat32 constant, but got %s", vectorConstant.Value.String())
}
intest.Assert(vectorConstant.Value.Kind() == types.KindVectorFloat32, "internal: expect vectorFloat32 constant, but got %s", vectorConstant.Value.String())

return &VectorSearchExpr{
return &VectorHelper{
DistanceFnName: x.FuncName,
Vec: vectorConstant.Value.GetVectorFloat32(),
Column: vectorColumn,
}, nil
ColumnID: vectorColumn.ID,
}
}
8 changes: 4 additions & 4 deletions pkg/meta/model/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ const (
DistanceMetricInnerProduct DistanceMetric = "INNER_PRODUCT"
)

// FnNameToDistanceMetric maps a distance function name to the distance metric.
// IndexableFnNameToDistanceMetric maps a distance function name to the distance metric.
// Only indexable distance functions should be listed here!
var FnNameToDistanceMetric = map[string]DistanceMetric{
var IndexableFnNameToDistanceMetric = map[string]DistanceMetric{
ast.VecCosineDistance: DistanceMetricCosine,
ast.VecL2Distance: DistanceMetricL2,
}

// DistanceMetricToFnName maps a distance metric to the distance function name.
var DistanceMetricToFnName = map[DistanceMetric]string{
// IndexableDistanceMetricToFnName maps a distance metric to the distance function name.
var IndexableDistanceMetricToFnName = map[DistanceMetric]string{
DistanceMetricCosine: ast.VecCosineDistance,
DistanceMetricL2: ast.VecL2Distance,
}
Expand Down
21 changes: 9 additions & 12 deletions pkg/planner/core/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/paging"
"github.com/pingcap/tidb/pkg/util/plancodec"
Expand Down Expand Up @@ -1024,33 +1025,27 @@ func fixTopNForANNIndex(p *PhysicalTopN) bool {
// supported yet.
return false
}
vs, err := expression.ExtractVectorSearch(p.ByItems[0].Expr)
if err != nil || vs == nil {
vs := expression.ExtractVectorHelper(p.ByItems[0].Expr)
if vs == nil {
return false
}
// Note that even if this is a vector search expression, it may not hit a vector index
// because not all vector search functions are indexable.
distanceMetric, ok := model.FnNameToDistanceMetric[vs.DistanceFnName.L]
if !ok {
return false
}
distanceMetric, ok := model.IndexableFnNameToDistanceMetric[vs.DistanceFnName.L]
// The user may have built the vector index with a different distance metric.
// In this case the index shall not be pushed down.
if distanceMetric != ts.AnnIndexExtra.IndexInfo.VectorInfo.DistanceMetric {
if !ok || distanceMetric != ts.AnnIndexExtra.IndexInfo.VectorInfo.DistanceMetric {
return false
}
// The user may have built the vector index on a different vector column.
// In this case the index shall not be pushed down.
col := ts.Table.Columns[ts.AnnIndexExtra.IndexInfo.Columns[0].Offset]
if col.ID != vs.Column.ID {
if col.ID != vs.ColumnID {
return false
}

distanceMetricPB, ok := tipb.VectorDistanceMetric_value[string(distanceMetric)]
if !ok {
// This should not happen.
return false
}
intest.Assert(distanceMetricPB != 0, "invalid distance metric")
ts.AnnIndexExtra.PushDownQueryInfo = &tipb.ANNQueryInfo{
QueryType: tipb.ANNQueryType_OrderBy,
DistanceMetric: tipb.VectorDistanceMetric(distanceMetricPB),
Expand All @@ -1059,6 +1054,8 @@ func fixTopNForANNIndex(p *PhysicalTopN) bool {
RefVecF32: vs.Vec.SerializeTo(nil),
IndexId: int64(ts.AnnIndexExtra.IndexInfo.ID),
}
ts.AnnIndexExtra.PushDownQueryInfo.ColumnId = new(int64)
*ts.AnnIndexExtra.PushDownQueryInfo.ColumnId = vs.ColumnID
ts.PlanCostInit = false
return true
}
Expand Down

0 comments on commit bf3a8e7

Please sign in to comment.