From 2ba29a28d32e3331a4d8841f43baaef7d36a061a Mon Sep 17 00:00:00 2001 From: EricZequan <110292382+EricZequan@users.noreply.github.com> Date: Mon, 12 Aug 2024 11:15:02 +0800 Subject: [PATCH] expression: Add vector functions (#55021) ref pingcap/tidb#54245 --- pkg/expression/aggregation/aggregation.go | 19 ++ pkg/expression/builtin.go | 11 +- pkg/expression/builtin_arithmetic.go | 118 ++++++++ pkg/expression/builtin_cast.go | 2 +- pkg/expression/builtin_compare.go | 80 +++++ pkg/expression/builtin_control.go | 86 ++++++ pkg/expression/builtin_other.go | 37 +++ pkg/expression/builtin_vec.go | 276 ++++++++++++++++++ pkg/expression/distsql_builtin.go | 48 +++ pkg/expression/distsql_builtin_test.go | 6 + pkg/expression/infer_pushdown.go | 9 +- pkg/expression/integration_test/BUILD.bazel | 2 +- .../integration_test/integration_test.go | 146 ++++++++- pkg/types/vector.go | 13 + pkg/types/vector_functions.go | 229 +++++++++++++++ tests/integrationtest/r/executor/show.result | 5 + 16 files changed, 1073 insertions(+), 14 deletions(-) diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go index d02625b7865d7..24d9d9710db4e 100644 --- a/pkg/expression/aggregation/aggregation.go +++ b/pkg/expression/aggregation/aggregation.go @@ -239,6 +239,9 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp if aggFunc.Name == ast.AggFuncApproxPercentile { return false } + if !checkVectorAggPushDown(ctx, aggFunc) { + return false + } ret := true switch storeType { case kv.TiFlash: @@ -253,6 +256,22 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp return ret } +// checkVectorAggPushDown returns false if this aggregate function is not supported to push down. +// - The aggregate function is not calculated over a Vector column (returns true) +// - The aggregate function is calculated over a Vector column and the function is supported (returns true) +// - The aggregate function is calculated over a Vector column and the function is not supported (returns false) +func checkVectorAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc) bool { + switch aggFunc.Name { + case ast.AggFuncCount, ast.AggFuncMin, ast.AggFuncMax, ast.AggFuncFirstRow: + return true + default: + if aggFunc.Args[0].GetType(ctx).GetType() == mysql.TypeTiDBVectorFloat32 { + return false + } + } + return true +} + // CheckAggPushFlash checks whether an agg function can be pushed to flash storage. func CheckAggPushFlash(ctx expression.EvalContext, aggFunc *AggFuncDesc) bool { for _, arg := range aggFunc.Args { diff --git a/pkg/expression/builtin.go b/pkg/expression/builtin.go index 942201750f771..e01405efda29a 100644 --- a/pkg/expression/builtin.go +++ b/pkg/expression/builtin.go @@ -929,9 +929,14 @@ var funcs = map[string]functionClass{ ast.JSONLength: &jsonLengthFunctionClass{baseFunctionClass{ast.JSONLength, 1, 2}}, // vector functions (TiDB extension) - ast.VecDims: &vecDimsFunctionClass{baseFunctionClass{ast.VecDims, 1, 1}}, - ast.VecFromText: &vecFromTextFunctionClass{baseFunctionClass{ast.VecFromText, 1, 1}}, - ast.VecAsText: &vecAsTextFunctionClass{baseFunctionClass{ast.VecAsText, 1, 1}}, + ast.VecDims: &vecDimsFunctionClass{baseFunctionClass{ast.VecDims, 1, 1}}, + ast.VecL1Distance: &vecL1DistanceFunctionClass{baseFunctionClass{ast.VecL1Distance, 2, 2}}, + ast.VecL2Distance: &vecL2DistanceFunctionClass{baseFunctionClass{ast.VecL2Distance, 2, 2}}, + ast.VecNegativeInnerProduct: &vecNegativeInnerProductFunctionClass{baseFunctionClass{ast.VecNegativeInnerProduct, 2, 2}}, + ast.VecCosineDistance: &vecCosineDistanceFunctionClass{baseFunctionClass{ast.VecCosineDistance, 2, 2}}, + ast.VecL2Norm: &vecL2NormFunctionClass{baseFunctionClass{ast.VecL2Norm, 1, 1}}, + ast.VecFromText: &vecFromTextFunctionClass{baseFunctionClass{ast.VecFromText, 1, 1}}, + ast.VecAsText: &vecAsTextFunctionClass{baseFunctionClass{ast.VecAsText, 1, 1}}, // TiDB internal function. ast.TiDBDecodeKey: &tidbDecodeKeyFunctionClass{baseFunctionClass{ast.TiDBDecodeKey, 1, 1}}, diff --git a/pkg/expression/builtin_arithmetic.go b/pkg/expression/builtin_arithmetic.go index d2cdb64f9b329..acdf87f8a27a0 100644 --- a/pkg/expression/builtin_arithmetic.go +++ b/pkg/expression/builtin_arithmetic.go @@ -57,6 +57,10 @@ var ( _ builtinFunc = &builtinArithmeticModIntSignedSignedSig{} _ builtinFunc = &builtinArithmeticModRealSig{} _ builtinFunc = &builtinArithmeticModDecimalSig{} + + _ builtinFunc = &builtinArithmeticPlusVectorFloat32Sig{} + _ builtinFunc = &builtinArithmeticMinusVectorFloat32Sig{} + _ builtinFunc = &builtinArithmeticMultiplyVectorFloat32Sig{} ) // isConstantBinaryLiteral return true if expr is constant binary literal @@ -167,6 +171,15 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx BuildContext, args []Expre if err := c.verifyArgs(args); err != nil { return nil, err } + if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() { + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32) + if err != nil { + return nil, err + } + sig := &builtinArithmeticPlusVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32) + return sig, nil + } lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1]) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) @@ -317,6 +330,15 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx BuildContext, args []Expr if err := c.verifyArgs(args); err != nil { return nil, err } + if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() { + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32) + if err != nil { + return nil, err + } + sig := &builtinArithmeticMinusVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32) + return sig, nil + } lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1]) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal) @@ -500,6 +522,15 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx BuildContext, args []E if err := c.verifyArgs(args); err != nil { return nil, err } + if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() { + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32) + if err != nil { + return nil, err + } + sig := &builtinArithmeticMultiplyVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32) + return sig, nil + } lhsTp, rhsTp := args[0].GetType(ctx.GetEvalCtx()), args[1].GetType(ctx.GetEvalCtx()) lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1]) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { @@ -1157,3 +1188,90 @@ func (s *builtinArithmeticModIntSignedSignedSig) evalInt(ctx EvalContext, row ch return a % b, false, nil } + +type builtinArithmeticPlusVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (s *builtinArithmeticPlusVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinArithmeticPlusVectorFloat32Sig{} + newSig.cloneFrom(&s.baseBuiltinFunc) + return newSig +} + +func (s *builtinArithmeticPlusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isLHSNull, err + } + b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isRHSNull, err + } + if isLHSNull || isRHSNull { + return types.ZeroVectorFloat32, true, nil + } + v, err := a.Add(b) + if err != nil { + return types.ZeroVectorFloat32, true, err + } + return v, false, nil +} + +type builtinArithmeticMinusVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (s *builtinArithmeticMinusVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinArithmeticMinusVectorFloat32Sig{} + newSig.cloneFrom(&s.baseBuiltinFunc) + return newSig +} + +func (s *builtinArithmeticMinusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isLHSNull, err + } + b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isRHSNull, err + } + if isLHSNull || isRHSNull { + return types.ZeroVectorFloat32, true, nil + } + v, err := a.Sub(b) + if err != nil { + return types.ZeroVectorFloat32, true, err + } + return v, false, nil +} + +type builtinArithmeticMultiplyVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (s *builtinArithmeticMultiplyVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinArithmeticMultiplyVectorFloat32Sig{} + newSig.cloneFrom(&s.baseBuiltinFunc) + return newSig +} + +func (s *builtinArithmeticMultiplyVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + a, isLHSNull, err := s.args[0].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isLHSNull, err + } + b, isRHSNull, err := s.args[1].EvalVectorFloat32(ctx, row) + if err != nil { + return types.ZeroVectorFloat32, isRHSNull, err + } + if isLHSNull || isRHSNull { + return types.ZeroVectorFloat32, true, nil + } + v, err := a.Mul(b) + if err != nil { + return types.ZeroVectorFloat32, true, err + } + return v, false, nil +} diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index 3ae35a314f7f5..16f7a4c2f44eb 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -691,7 +691,7 @@ func (c *castAsVectorFloat32FunctionClass) getFunction(ctx BuildContext, args [] sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32) case types.ETString: sig = &builtinCastStringAsVectorFloat32Sig{bf} - sig.setPbCode(tipb.ScalarFuncSig_CastStringAsVectorFloat32) + // sig.setPbCode(tipb.ScalarFuncSig_CastStringAsVectorFloat32) default: return nil, errors.Errorf("cannot cast from %s to %s", argTp, "VectorFloat32") } diff --git a/pkg/expression/builtin_compare.go b/pkg/expression/builtin_compare.go index 2987bc52e49bb..2ab2bcd2ae800 100644 --- a/pkg/expression/builtin_compare.go +++ b/pkg/expression/builtin_compare.go @@ -46,6 +46,7 @@ var ( _ builtinFunc = &builtinCoalesceStringSig{} _ builtinFunc = &builtinCoalesceTimeSig{} _ builtinFunc = &builtinCoalesceDurationSig{} + _ builtinFunc = &builtinCoalesceVectorFloat32Sig{} _ builtinFunc = &builtinGreatestIntSig{} _ builtinFunc = &builtinGreatestRealSig{} @@ -54,6 +55,7 @@ var ( _ builtinFunc = &builtinGreatestDurationSig{} _ builtinFunc = &builtinGreatestTimeSig{} _ builtinFunc = &builtinGreatestCmpStringAsTimeSig{} + _ builtinFunc = &builtinGreatestVectorFloat32Sig{} _ builtinFunc = &builtinLeastIntSig{} _ builtinFunc = &builtinLeastRealSig{} _ builtinFunc = &builtinLeastDecimalSig{} @@ -61,6 +63,7 @@ var ( _ builtinFunc = &builtinLeastTimeSig{} _ builtinFunc = &builtinLeastDurationSig{} _ builtinFunc = &builtinLeastCmpStringAsTimeSig{} + _ builtinFunc = &builtinLeastVectorFloat32Sig{} _ builtinFunc = &builtinIntervalIntSig{} _ builtinFunc = &builtinIntervalRealSig{} @@ -167,6 +170,9 @@ func (c *coalesceFunctionClass) getFunction(ctx BuildContext, args []Expression) case types.ETJson: sig = &builtinCoalesceJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CoalesceJson) + case types.ETVectorFloat32: + sig = &builtinCoalesceVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CoalesceVectorFloat32) default: return nil, errors.Errorf("%s is not supported for COALESCE()", retEvalTp) } @@ -331,6 +337,28 @@ func (b *builtinCoalesceJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (res t return res, isNull, err } +// builtinCoalesceVectorFloat32Sig is builtin function coalesce signature which return type vector float32. +// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_coalesce +type builtinCoalesceVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinCoalesceVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinCoalesceVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCoalesceVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + for _, a := range b.getArgs() { + res, isNull, err = a.EvalVectorFloat32(ctx, row) + if err != nil || !isNull { + break + } + } + return res, isNull, err +} + func aggregateType(ctx EvalContext, args []Expression) *types.FieldType { fieldTypes := make([]*types.FieldType, len(args)) for i := range fieldTypes { @@ -499,6 +527,9 @@ func (c *greatestFunctionClass) getFunction(ctx BuildContext, args []Expression) sig = &builtinGreatestTimeSig{bf, false} sig.setPbCode(tipb.ScalarFuncSig_GreatestTime) } + case types.ETVectorFloat32: + sig = &builtinGreatestVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_GreatestVectorFloat32) default: return nil, errors.Errorf("unsupported type %s during evaluation", argTp) } @@ -754,6 +785,29 @@ func (b *builtinGreatestDurationSig) evalDuration(ctx EvalContext, row chunk.Row return res, false, nil } +type builtinGreatestVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinGreatestVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinGreatestVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinGreatestVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + for i := 0; i < len(b.args); i++ { + v, isNull, err := b.args[i].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return types.VectorFloat32{}, true, err + } + if i == 0 || v.Compare(res) > 0 { + res = v + } + } + return res, false, nil +} + type leastFunctionClass struct { baseFunctionClass } @@ -814,6 +868,9 @@ func (c *leastFunctionClass) getFunction(ctx BuildContext, args []Expression) (s sig = &builtinLeastTimeSig{bf, false} sig.setPbCode(tipb.ScalarFuncSig_LeastTime) } + case types.ETVectorFloat32: + sig = &builtinLeastVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_LeastVectorFloat32) default: return nil, errors.Errorf("unsupported type %s during evaluation", argTp) } @@ -1039,6 +1096,29 @@ func (b *builtinLeastDurationSig) evalDuration(ctx EvalContext, row chunk.Row) ( return res, false, nil } +type builtinLeastVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinLeastVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinLeastVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinLeastVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) { + for i := 0; i < len(b.args); i++ { + v, isNull, err := b.args[i].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return types.VectorFloat32{}, true, err + } + if i == 0 || v.Compare(res) < 0 { + res = v + } + } + return res, false, nil +} + type intervalFunctionClass struct { baseFunctionClass } diff --git a/pkg/expression/builtin_control.go b/pkg/expression/builtin_control.go index 4232597bce08c..099d0b7443249 100644 --- a/pkg/expression/builtin_control.go +++ b/pkg/expression/builtin_control.go @@ -38,6 +38,7 @@ var ( _ builtinFunc = &builtinCaseWhenTimeSig{} _ builtinFunc = &builtinCaseWhenDurationSig{} _ builtinFunc = &builtinCaseWhenJSONSig{} + _ builtinFunc = &builtinCaseWhenVectorFloat32Sig{} _ builtinFunc = &builtinIfNullIntSig{} _ builtinFunc = &builtinIfNullRealSig{} _ builtinFunc = &builtinIfNullDecimalSig{} @@ -45,6 +46,7 @@ var ( _ builtinFunc = &builtinIfNullTimeSig{} _ builtinFunc = &builtinIfNullDurationSig{} _ builtinFunc = &builtinIfNullJSONSig{} + _ builtinFunc = &builtinIfNullVectorFloat32Sig{} _ builtinFunc = &builtinIfIntSig{} _ builtinFunc = &builtinIfRealSig{} _ builtinFunc = &builtinIfDecimalSig{} @@ -52,6 +54,7 @@ var ( _ builtinFunc = &builtinIfTimeSig{} _ builtinFunc = &builtinIfDurationSig{} _ builtinFunc = &builtinIfJSONSig{} + _ builtinFunc = &builtinIfVectorFloat32Sig{} ) func maxlen(lhsFlen, rhsFlen int) int { @@ -373,6 +376,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx BuildContext, args []Expression) case types.ETJson: sig = &builtinCaseWhenJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CaseWhenJson) + case types.ETVectorFloat32: + sig = &builtinCaseWhenVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_CaseWhenVectorFloat32) default: return nil, errors.Errorf("%s is not supported for CASE WHEN", tp) } @@ -629,6 +635,40 @@ func (b *builtinCaseWhenJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (ret t return ret, true, nil } +type builtinCaseWhenVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinCaseWhenVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinCaseWhenVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalVectorFloat32 evals a builtinCaseWhenVectorFloat32Sig. +// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#operator_case +func (b *builtinCaseWhenVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (ret types.VectorFloat32, isNull bool, err error) { + var condition int64 + args, l := b.getArgs(), len(b.getArgs()) + for i := 0; i < l-1; i += 2 { + condition, isNull, err = args[i].EvalInt(ctx, row) + if err != nil { + return + } + if isNull || condition == 0 { + continue + } + return args[i+1].EvalVectorFloat32(ctx, row) + } + // when clause(condition, result) -> args[i], args[i+1]; (i >= 0 && i+1 < l-1) + // else clause -> args[l-1] + // If case clause has else clause, l%2 == 1. + if l%2 == 1 { + return args[l-1].EvalVectorFloat32(ctx, row) + } + return ret, true, nil +} + type ifFunctionClass struct { baseFunctionClass } @@ -676,6 +716,9 @@ func (c *ifFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig case types.ETJson: sig = &builtinIfJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_IfJson) + case types.ETVectorFloat32: + sig = &builtinIfVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_IfVectorFloat32) default: return nil, errors.Errorf("%s is not supported for IF()", evalTps) } @@ -829,6 +872,27 @@ func (b *builtinIfJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (ret types.B return b.args[2].EvalJSON(ctx, row) } +type builtinIfVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinIfVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinIfVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinIfVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (ret types.VectorFloat32, isNull bool, err error) { + arg0, isNull0, err := b.args[0].EvalInt(ctx, row) + if err != nil { + return ret, true, err + } + if !isNull0 && arg0 != 0 { + return b.args[1].EvalVectorFloat32(ctx, row) + } + return b.args[2].EvalVectorFloat32(ctx, row) +} + type ifNullFunctionClass struct { baseFunctionClass } @@ -878,6 +942,9 @@ func (c *ifNullFunctionClass) getFunction(ctx BuildContext, args []Expression) ( case types.ETJson: sig = &builtinIfNullJSONSig{bf} sig.setPbCode(tipb.ScalarFuncSig_IfNullJson) + case types.ETVectorFloat32: + sig = &builtinIfNullVectorFloat32Sig{bf} + // sig.setPbCode(tipb.ScalarFuncSig_IfNullVectorFloat32) default: return nil, errors.Errorf("%s is not supported for IFNULL()", evalTps) } @@ -1016,3 +1083,22 @@ func (b *builtinIfNullJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (types.B arg1, isNull, err := b.args[1].EvalJSON(ctx, row) return arg1, isNull || err != nil, err } + +type builtinIfNullVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinIfNullVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinIfNullVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinIfNullVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) { + arg0, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if !isNull { + return arg0, err != nil, err + } + arg1, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + return arg1, isNull || err != nil, err +} diff --git a/pkg/expression/builtin_other.go b/pkg/expression/builtin_other.go index 1f672c69b8315..aa6fd599f7e5b 100644 --- a/pkg/expression/builtin_other.go +++ b/pkg/expression/builtin_other.go @@ -57,6 +57,7 @@ var ( _ builtinFunc = &builtinInTimeSig{} _ builtinFunc = &builtinInDurationSig{} _ builtinFunc = &builtinInJSONSig{} + _ builtinFunc = &builtinInVectorFloat32Sig{} _ builtinFunc = &builtinRowSig{} _ builtinFunc = &builtinSetStringVarSig{} _ builtinFunc = &builtinSetIntVarSig{} @@ -153,6 +154,9 @@ func (c *inFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig case types.ETJson: sig = &builtinInJSONSig{baseBuiltinFunc: bf} sig.setPbCode(tipb.ScalarFuncSig_InJson) + case types.ETVectorFloat32: + sig = &builtinInVectorFloat32Sig{baseBuiltinFunc: bf} + // sig.setPbCode(tipb.ScalarFuncSig_InVectorFloat32) default: return nil, errors.Errorf("%s is not supported for IN()", args[0].GetType(ctx.GetEvalCtx()).EvalType()) } @@ -683,6 +687,39 @@ func (b *builtinInJSONSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, return 0, hasNull, nil } +type builtinInVectorFloat32Sig struct { + baseBuiltinFunc +} + +func (b *builtinInVectorFloat32Sig) Clone() builtinFunc { + newSig := &builtinInVectorFloat32Sig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinInVectorFloat32Sig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg0, isNull0, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull0 || err != nil { + return 0, isNull0, err + } + var hasNull bool + for _, arg := range b.args[1:] { + evaledArg, isNull, err := arg.EvalVectorFloat32(ctx, row) + if err != nil { + return 0, true, err + } + if isNull { + hasNull = true + continue + } + result := arg0.Compare(evaledArg) + if result == 0 { + return 1, false, nil + } + } + return 0, hasNull, nil +} + type rowFunctionClass struct { baseFunctionClass } diff --git a/pkg/expression/builtin_vec.go b/pkg/expression/builtin_vec.go index 44cdc690bb9c0..e1d8374461828 100644 --- a/pkg/expression/builtin_vec.go +++ b/pkg/expression/builtin_vec.go @@ -15,11 +15,35 @@ package expression import ( + "math" + "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tipb/go-tipb" ) +var ( + _ functionClass = &vecDimsFunctionClass{} + _ functionClass = &vecL1DistanceFunctionClass{} + _ functionClass = &vecL2DistanceFunctionClass{} + _ functionClass = &vecNegativeInnerProductFunctionClass{} + _ functionClass = &vecCosineDistanceFunctionClass{} + _ functionClass = &vecL2NormFunctionClass{} + _ functionClass = &vecFromTextFunctionClass{} + _ functionClass = &vecAsTextFunctionClass{} +) + +var ( + _ builtinFunc = &builtinVecDimsSig{} + _ builtinFunc = &builtinVecL1DistanceSig{} + _ builtinFunc = &builtinVecL2DistanceSig{} + _ builtinFunc = &builtinVecNegativeInnerProductSig{} + _ builtinFunc = &builtinVecCosineDistanceSig{} + _ builtinFunc = &builtinVecL2NormSig{} + _ builtinFunc = &builtinVecFromTextSig{} + _ builtinFunc = &builtinVecAsTextSig{} +) + type vecDimsFunctionClass struct { baseFunctionClass } @@ -59,6 +83,258 @@ func (b *builtinVecDimsSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, return int64(v.Len()), false, nil } +type vecL1DistanceFunctionClass struct { + baseFunctionClass +} + +type builtinVecL1DistanceSig struct { + baseBuiltinFunc +} + +func (b *builtinVecL1DistanceSig) Clone() builtinFunc { + newSig := &builtinVecL1DistanceSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecL1DistanceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecL1DistanceSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecL1DistanceSig) + return sig, nil +} + +func (b *builtinVecL1DistanceSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.L1Distance(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecL2DistanceFunctionClass struct { + baseFunctionClass +} + +type builtinVecL2DistanceSig struct { + baseBuiltinFunc +} + +func (b *builtinVecL2DistanceSig) Clone() builtinFunc { + newSig := &builtinVecL2DistanceSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecL2DistanceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecL2DistanceSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecL2DistanceSig) + return sig, nil +} + +func (b *builtinVecL2DistanceSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.L2Distance(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecNegativeInnerProductFunctionClass struct { + baseFunctionClass +} + +type builtinVecNegativeInnerProductSig struct { + baseBuiltinFunc +} + +func (b *builtinVecNegativeInnerProductSig) Clone() builtinFunc { + newSig := &builtinVecNegativeInnerProductSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecNegativeInnerProductFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecNegativeInnerProductSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecNegativeInnerProductSig) + return sig, nil +} + +func (b *builtinVecNegativeInnerProductSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.NegativeInnerProduct(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecCosineDistanceFunctionClass struct { + baseFunctionClass +} + +type builtinVecCosineDistanceSig struct { + baseBuiltinFunc +} + +func (b *builtinVecCosineDistanceSig) Clone() builtinFunc { + newSig := &builtinVecCosineDistanceSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecCosineDistanceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecCosineDistanceSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecCosineDistanceSig) + return sig, nil +} + +func (b *builtinVecCosineDistanceSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v1, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + v2, isNull, err := b.args[1].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d, err := v1.CosineDistance(v2) + if err != nil { + return res, false, err + } + + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + +type vecL2NormFunctionClass struct { + baseFunctionClass +} + +type builtinVecL2NormSig struct { + baseBuiltinFunc +} + +func (b *builtinVecL2NormSig) Clone() builtinFunc { + newSig := &builtinVecL2NormSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *vecL2NormFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETVectorFloat32) + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...) + if err != nil { + return nil, err + } + sig := &builtinVecL2NormSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_VecL2NormSig) + return sig, nil +} + +func (b *builtinVecL2NormSig) evalReal(ctx EvalContext, row chunk.Row) (res float64, isNull bool, err error) { + v, isNull, err := b.args[0].EvalVectorFloat32(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + d := v.L2Norm() + if math.IsNaN(d) { + return 0, true, nil + } + return d, false, nil +} + type vecFromTextFunctionClass struct { baseFunctionClass } diff --git a/pkg/expression/distsql_builtin.go b/pkg/expression/distsql_builtin.go index 0012410a60c80..ed83d0cdd6688 100644 --- a/pkg/expression/distsql_builtin.go +++ b/pkg/expression/distsql_builtin.go @@ -1076,6 +1076,42 @@ func getSignatureByPB(ctx BuildContext, sigCode tipb.ScalarFuncSig, tp *tipb.Fie case tipb.ScalarFuncSig_FromBinary: // TODO: set the `cannotConvertStringAsWarning` accordingly f = &builtinInternalFromBinarySig{base, false} + case tipb.ScalarFuncSig_CastVectorFloat32AsString: + f = &builtinCastVectorFloat32AsStringSig{base} + case tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32: + f = &builtinCastVectorFloat32AsVectorFloat32Sig{base} + case tipb.ScalarFuncSig_LTVectorFloat32: + f = &builtinLTVectorFloat32Sig{base} + case tipb.ScalarFuncSig_LEVectorFloat32: + f = &builtinLEVectorFloat32Sig{base} + case tipb.ScalarFuncSig_GTVectorFloat32: + f = &builtinGTVectorFloat32Sig{base} + case tipb.ScalarFuncSig_GEVectorFloat32: + f = &builtinGEVectorFloat32Sig{base} + case tipb.ScalarFuncSig_NEVectorFloat32: + f = &builtinNEVectorFloat32Sig{base} + case tipb.ScalarFuncSig_EQVectorFloat32: + f = &builtinEQVectorFloat32Sig{base} + case tipb.ScalarFuncSig_NullEQVectorFloat32: + f = &builtinNullEQVectorFloat32Sig{base} + case tipb.ScalarFuncSig_VectorFloat32AnyValue: + f = &builtinVectorFloat32AnyValueSig{base} + case tipb.ScalarFuncSig_VectorFloat32IsNull: + f = &builtinVectorFloat32IsNullSig{base} + case tipb.ScalarFuncSig_VecAsTextSig: + f = &builtinVecAsTextSig{base} + case tipb.ScalarFuncSig_VecDimsSig: + f = &builtinVecDimsSig{base} + case tipb.ScalarFuncSig_VecL1DistanceSig: + f = &builtinVecL1DistanceSig{base} + case tipb.ScalarFuncSig_VecL2DistanceSig: + f = &builtinVecL2DistanceSig{base} + case tipb.ScalarFuncSig_VecNegativeInnerProductSig: + f = &builtinVecNegativeInnerProductSig{base} + case tipb.ScalarFuncSig_VecCosineDistanceSig: + f = &builtinVecCosineDistanceSig{base} + case tipb.ScalarFuncSig_VecL2NormSig: + f = &builtinVecL2NormSig{base} default: e = ErrFunctionNotExists.GenWithStackByArgs("FUNCTION", sigCode) @@ -1149,6 +1185,8 @@ func PBToExpr(ctx BuildContext, expr *tipb.Expr, tps []*types.FieldType) (Expres return convertJSON(expr.Val) case tipb.ExprType_MysqlEnum: return convertEnum(expr.Val, expr.FieldType) + case tipb.ExprType_TiDBVectorFloat32: + return convertVectorFloat32(expr.Val) } if expr.Tp != tipb.ExprType_ScalarFunc { panic("should be a tipb.ExprType_ScalarFunc") @@ -1293,6 +1331,16 @@ func convertJSON(val []byte) (*Constant, error) { return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeJSON)}, nil } +func convertVectorFloat32(val []byte) (*Constant, error) { + v, _, err := types.ZeroCopyDeserializeVectorFloat32(val) + if err != nil { + return nil, errors.Errorf("invalid VectorFloat32 %x", val) + } + var d types.Datum + d.SetVectorFloat32(v) + return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeTiDBVectorFloat32)}, nil +} + func convertEnum(val []byte, tp *tipb.FieldType) (*Constant, error) { _, uVal, err := codec.DecodeUint(val) if err != nil { diff --git a/pkg/expression/distsql_builtin_test.go b/pkg/expression/distsql_builtin_test.go index 47214f4dd14e8..ca52e767c69a4 100644 --- a/pkg/expression/distsql_builtin_test.go +++ b/pkg/expression/distsql_builtin_test.go @@ -907,6 +907,12 @@ func datumExpr(t *testing.T, d types.Datum) *tipb.Expr { expr.Val = make([]byte, 0, 1024) expr.Val, err = codec.EncodeValue(time.UTC, expr.Val, d) require.NoError(t, err) + case types.KindVectorFloat32: + expr.Tp = tipb.ExprType_TiDBVectorFloat32 + var err error + expr.Val = make([]byte, 0, 1024) + expr.Val, err = codec.EncodeValue(nil, expr.Val, d) + require.NoError(t, err) case types.KindMysqlTime: expr.Tp = tipb.ExprType_MysqlTime var err error diff --git a/pkg/expression/infer_pushdown.go b/pkg/expression/infer_pushdown.go index 48f7a6a51e778..60648002d4c55 100644 --- a/pkg/expression/infer_pushdown.go +++ b/pkg/expression/infer_pushdown.go @@ -121,10 +121,6 @@ func canScalarFuncPushDown(ctx PushDownContext, scalarFunc *ScalarFunction, stor func canExprPushDown(ctx PushDownContext, expr Expression, storeType kv.StoreType, canEnumPush bool) bool { pc := ctx.PbConverter() - if expr.GetType(ctx.EvalCtx()).GetType() == mysql.TypeTiDBVectorFloat32 { - // For both TiKV and TiFlash, currently Vector cannot be pushed. - return false - } if storeType == kv.TiFlash { switch expr.GetType(ctx.EvalCtx()).GetType() { case mysql.TypeEnum, mysql.TypeBit, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeUnspecified: @@ -202,6 +198,9 @@ func scalarExprSupportedByTiKV(ctx EvalContext, sf *ScalarFunction) bool { ast.JSONInsert, ast.JSONReplace, ast.JSONRemove, ast.JSONLength, ast.JSONMergePatch, ast.JSONUnquote, ast.JSONContains, ast.JSONValid, ast.JSONMemberOf, ast.JSONArrayAppend, + // vector functions. + ast.VecDims, ast.VecL1Distance, ast.VecL2Distance, ast.VecNegativeInnerProduct, ast.VecCosineDistance, ast.VecL2Norm, ast.VecAsText, + // date functions. ast.Date, ast.Week /* ast.YearWeek, ast.ToSeconds */, ast.DateDiff, /* ast.TimeDiff, ast.AddTime, ast.SubTime, */ @@ -404,6 +403,8 @@ func scalarExprSupportedByFlash(ctx EvalContext, function *ScalarFunction) bool return true case ast.IsIPv4, ast.IsIPv6: return true + case ast.VecDims, ast.VecL1Distance, ast.VecL2Distance, ast.VecNegativeInnerProduct, ast.VecCosineDistance, ast.VecL2Norm, ast.VecAsText: + return true case ast.Grouping: // grouping function for grouping sets identification. return true } diff --git a/pkg/expression/integration_test/BUILD.bazel b/pkg/expression/integration_test/BUILD.bazel index 14960eabf9a92..22eddae5b4cc2 100644 --- a/pkg/expression/integration_test/BUILD.bazel +++ b/pkg/expression/integration_test/BUILD.bazel @@ -8,7 +8,7 @@ go_test( "main_test.go", ], flaky = True, - shard_count = 35, + shard_count = 40, deps = [ "//pkg/config", "//pkg/domain", diff --git a/pkg/expression/integration_test/integration_test.go b/pkg/expression/integration_test/integration_test.go index afa49ac9d3d16..a557f8a942cbe 100644 --- a/pkg/expression/integration_test/integration_test.go +++ b/pkg/expression/integration_test/integration_test.go @@ -181,8 +181,19 @@ func TestVector(t *testing.T) { "[1.4,4.5,8.5,7.7,6.2]", )) - // Arithmatic: Currently not implemented. - tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1,2]') + VEC_FROM_TEXT('[2,3]')`) + // Golang produce different results in different Arch for float points. + // Adding a ROUND to make this test stable. + // See https://go.dev/ref/spec#Arithmetic_operators + tk.MustQuery(`SELECT val, + ROUND(VEC_Cosine_Distance(val, '[1,2,3,4,5]'), 5) AS d + FROM t ORDER BY d DESC; + `).Check(testkit.Rows( + "[8.7,5.7,7.7,9.8,1.5] 0.25641", + "[3.6,9.7,2.4,6.6,4.9] 0.18577", + "[7.7,6.7,8.3,7.8,5.7] 0.12677", + "[4.7,4.9,2.6,5.2,7.4] 0.06925", + "[1.4,4.5,8.5,7.7,6.2] 0.04973", + )) } func TestVectorOperators(t *testing.T) { @@ -204,9 +215,9 @@ func TestVectorOperators(t *testing.T) { tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NOT NULL`).Check(testkit.Rows("1")) tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NULL`).Check(testkit.Rows("0")) tk.MustQuery(`SELECT * FROM t WHERE embedding = VEC_FROM_TEXT('[1,2,3]');`).Check(testkit.Rows("[1,2,3]")) - tk.MustQuery(`SELECT * FROM t WHERE embedding BETWEEN '[1,2,3]' AND '[4,5,6]'`).Check(testkit.Rows("[1,2,3]", "[4,5,6]")) - tk.MustExecToErr(`SELECT * FROM t WHERE embedding IN ('[1, 2, 3]', '[4, 5, 6]')`) - tk.MustExecToErr(`SELECT * FROM t WHERE embedding NOT IN ('[1, 2, 3]', '[4, 5, 6]')`) + tk.MustQuery(`SELECT * FROM t WHERE embedding BETWEEN '[1, 2, 3]' AND '[4, 5, 6]'`).Check(testkit.Rows("[1,2,3]", "[4,5,6]")) + tk.MustQuery(`SELECT * FROM t WHERE embedding IN ('[1, 2, 3]', '[4, 5, 6]')`).Check(testkit.Rows("[1,2,3]", "[4,5,6]")) + tk.MustQuery(`SELECT * FROM t WHERE embedding NOT IN ('[1, 2, 3]', '[4, 5, 6]')`).Check(testkit.Rows("[7,8,9]")) } func TestVectorCompare(t *testing.T) { @@ -234,6 +245,15 @@ func TestVectorCompare(t *testing.T) { tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') >= '[1]';").Check(testkit.Rows("1")) tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') < '[1]';").Check(testkit.Rows("0")) tk.MustQuery("SELECT VEC_FROM_TEXT('[1, 2, 3]') <= '[1]';").Check(testkit.Rows("0")) + + tk.MustQuery(`SELECT GREATEST(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'), VEC_FROM_TEXT('[7, 8, 9]')) AS result;`).Check(testkit.Rows("[7,8,9]")) + tk.MustQuery(`SELECT LEAST(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'), VEC_FROM_TEXT('[7, 8, 9]')) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]')) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(NULL, VEC_FROM_TEXT('[1, 2, 3]')) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(VEC_FROM_TEXT('[1, 2, 3]'), 1) AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(VEC_FROM_TEXT('[1, 2, 3]'), '1') AS result;`).Check(testkit.Rows("[1,2,3]")) + tk.MustQuery(`SELECT COALESCE(1, VEC_FROM_TEXT('[1, 2, 3]'), 1) AS result;`).Check(testkit.Rows("1")) + tk.MustQuery(`SELECT COALESCE('1', VEC_FROM_TEXT('[1, 2, 3]'), '1') AS result;`).Check(testkit.Rows("1")) } func TestVectorConversion(t *testing.T) { @@ -300,6 +320,58 @@ func TestVectorConversion(t *testing.T) { require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)") } +func TestVectorAssignVariable(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec(`SET @a = VEC_FROM_TEXT('[1,2,3]');`) + tk.MustQuery(`SELECT @a;`).Check(testkit.Rows("[1,2,3]")) +} + +func TestVectorControlFlow(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + // IF + tk.MustQuery("SELECT IF(VEC_FROM_TEXT('[1, 2, 3]'), 1, 0);").Check(testkit.Rows("1")) + tk.MustQuery("SELECT IF(TRUE, VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'));").Check(testkit.Rows("[1,2,3]")) + + // IFNULL + tk.MustQuery("SELECT IFNULL(VEC_FROM_TEXT('[1, 2, 3]'), 1);").Check(testkit.Rows("[1,2,3]")) + tk.MustQuery("SELECT IFNULL(NULL, VEC_FROM_TEXT('[1, 2, 3]'));").Check(testkit.Rows("[1,2,3]")) + + // NULLIF + tk.MustQuery("SELECT NULLIF(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[1, 2, 3]'));").Check(testkit.Rows("")) + tk.MustQuery("SELECT NULLIF(VEC_FROM_TEXT('[1, 2, 3]'), VEC_FROM_TEXT('[4, 5, 6]'));").Check(testkit.Rows("[1,2,3]")) + + // CASE WHEN + tk.MustQuery("SELECT CASE WHEN TRUE THEN VEC_FROM_TEXT('[1, 2, 3]') ELSE VEC_FROM_TEXT('[4, 5, 6]') END;").Check(testkit.Rows("[1,2,3]")) +} + +func TestVectorStringCompare(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustExec("DROP TABLE IF EXISTS t1;") + tk.MustExec("CREATE TABLE t1 (val vector);") + tk.MustExec("INSERT INTO t1 VALUES ('[1,2,3]'), ('[4,5,6]');") + + // LIKE + tk.MustQuery("SELECT * FROM t1 WHERE val LIKE '%2%';").Check(testkit.Rows("[1,2,3]")) + + // ILIKE + tk.MustQuery("SELECT * FROM t1 WHERE val ILIKE '%2%';").Check(testkit.Rows("[1,2,3]")) + + // STRCMP + tk.MustQuery("SELECT STRCMP('[1,2,3]', VEC_FROM_TEXT('[1,2,3]'));").Check(testkit.Rows("0")) + tk.MustQuery("SELECT STRCMP('[4,5,6]', VEC_FROM_TEXT('[1,2,3]'));").Check(testkit.Rows("1")) +} + func TestVectorAggregations(t *testing.T) { store := testkit.CreateMockStore(t) @@ -429,6 +501,70 @@ func TestVectorSetOperation(t *testing.T) { )) } +func TestVectorArithmatic(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + tk.MustExec(`CREATE TABLE t(embedding VECTOR);`) + tk.MustExec(`INSERT INTO t VALUES + ('[1, 2, 3]'), + ('[4, 5, 6]'), + ('[7, 8, 9]'); + `) + tk.MustQuery(`SELECT embedding + '[1, 2, 3]' FROM t;`).Check(testkit.Rows("[2,4,6]", "[5,7,9]", "[8,10,12]")) + tk.MustQuery(`SELECT embedding + embedding FROM t;`).Check(testkit.Rows("[2,4,6]", "[8,10,12]", "[14,16,18]")) + tk.MustQueryToErr(`SELECT embedding + 1 FROM t;`) + tk.MustQueryToErr(`SELECT embedding + '[]' FROM t;`) + tk.MustQuery(`SELECT embedding - '[1, 2, 3]' FROM t;`).Check(testkit.Rows("[0,0,0]", "[3,3,3]", "[6,6,6]")) + tk.MustQuery(`SELECT embedding - embedding FROM t;`).Check(testkit.Rows("[0,0,0]", "[0,0,0]", "[0,0,0]")) + tk.MustQueryToErr(`SELECT embedding - '[1]' FROM t;`) + + tk.MustQuery(`SELECT VEC_FROM_TEXT('[1,2]') + VEC_FROM_TEXT('[2,3]');`).Check(testkit.Rows("[3,5]")) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[1,2]') + '[2,3]';`).Check(testkit.Rows("[3,5]")) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1,2]') + '[2,3,4]';`) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1]') + 2;`) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1]') + '2';`) + + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[3e38]') + '[3e38]';`) + tk.MustQuery(`SELECT VEC_FROM_TEXT('[1,2,3]') * '[4,5,6]';`).Check(testkit.Rows("[4,10,18]")) + tk.MustQueryToErr(`SELECT VEC_FROM_TEXT('[1e37]') * '[1e37]';`) +} + +func TestVectorFunctions(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("USE test;") + + tk.MustQuery(`SELECT VEC_L1_DISTANCE('[0,0]', '[3,4]');`).Check(testkit.Rows("7")) + tk.MustQuery(`SELECT VEC_L1_DISTANCE('[0,0]', '[0,1]');`).Check(testkit.Rows("1")) + tk.MustQueryToErr("SELECT VEC_L1_DISTANCE('[1,2]', '[3]');") + tk.MustQuery(`SELECT VEC_L1_DISTANCE('[3e38]', '[-3e38]');`).Check(testkit.Rows("+Inf")) + + tk.MustQuery(`SELECT VEC_L2_DISTANCE('[0,0]', '[3,4]');`).Check(testkit.Rows("5")) + tk.MustQuery(`SELECT VEC_L2_DISTANCE('[0,0]', '[0,1]');`).Check(testkit.Rows("1")) + tk.MustQueryToErr(`SELECT VEC_L2_DISTANCE('[1,2]', '[3]');`) + tk.MustQuery(`SELECT VEC_L2_DISTANCE('[3e38]', '[-3e38]');`).Check(testkit.Rows("+Inf")) + + tk.MustQuery(`SELECT VEC_NEGATIVE_INNER_PRODUCT('[1,2]', '[3,4]');`).Check(testkit.Rows("-11")) + tk.MustQueryToErr(`SELECT VEC_NEGATIVE_INNER_PRODUCT('[1,2]', '[3]');`) + tk.MustQuery(`SELECT VEC_NEGATIVE_INNER_PRODUCT('[3e38]', '[3e38]');`).Check(testkit.Rows("-Inf")) + + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,2]', '[2,4]');`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,2]', '[0,0]');`).Check(testkit.Rows("")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[1,1]');`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,0]', '[0,2]');`).Check(testkit.Rows("1")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[-1,-1]');`).Check(testkit.Rows("2")) + tk.MustQueryToErr(`SELECT VEC_COSINE_DISTANCE('[1,2]', '[3]');`) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[1.1,1.1]');`).Check(testkit.Rows("0")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[1,1]', '[-1.1,-1.1]');`).Check(testkit.Rows("2")) + tk.MustQuery(`SELECT VEC_COSINE_DISTANCE('[3e38]', '[3e38]');`).Check(testkit.Rows("")) + + tk.MustQuery(`SELECT VEC_L2_NORM('[3,4]');`).Check(testkit.Rows("5")) + tk.MustQuery(`SELECT VEC_L2_NORM('[0,1]');`).Check(testkit.Rows("1")) +} + func TestGetLock(t *testing.T) { ctx := context.Background() store := testkit.CreateMockStore(t, mockstore.WithStoreType(mockstore.EmbedUnistore)) diff --git a/pkg/types/vector.go b/pkg/types/vector.go index 8bb86f51378ae..0cd2d21093d4e 100644 --- a/pkg/types/vector.go +++ b/pkg/types/vector.go @@ -16,6 +16,7 @@ package types import ( "encoding/binary" + "math" "strconv" "unsafe" @@ -151,16 +152,28 @@ func ZeroCopyDeserializeVectorFloat32(b []byte) (VectorFloat32, []byte, error) { // ParseVectorFloat32 parses a string into a vector. func ParseVectorFloat32(s string) (VectorFloat32, error) { var values []float32 + var valueError error // We explicitly use a JSON float parser to reject other JSON types. parser := jsoniter.ParseString(jsoniter.ConfigDefault, s) parser.ReadArrayCB(func(parser *jsoniter.Iterator) bool { v := parser.ReadFloat64() + if math.IsNaN(v) { + valueError = errors.Errorf("NaN not allowed in vector") + return false + } + if math.IsInf(v, 0) { + valueError = errors.Errorf("infinite value not allowed in vector") + return false + } values = append(values, float32(v)) return true }) if parser.Error != nil { return ZeroVectorFloat32, errors.Errorf("Invalid vector text: %s", s) } + if valueError != nil { + return ZeroVectorFloat32, valueError + } dim := len(values) if err := CheckVectorDimValid(dim); err != nil { diff --git a/pkg/types/vector_functions.go b/pkg/types/vector_functions.go index 226a714bc8157..c75ade920b21c 100644 --- a/pkg/types/vector_functions.go +++ b/pkg/types/vector_functions.go @@ -14,6 +14,235 @@ package types +import ( + "math" + + "github.com/pingcap/errors" +) + +func (a VectorFloat32) checkIdenticalDims(b VectorFloat32) error { + if a.Len() != b.Len() { + return errors.Errorf("vectors have different dimensions: %d and %d", a.Len(), b.Len()) + } + return nil +} + +// L2SquaredDistance returns the squared L2 distance between two vectors. +// This saves a sqrt calculation. +func (a VectorFloat32) L2SquaredDistance(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + diff := va[i] - vb[i] + distance += diff * diff + } + + return float64(distance), nil +} + +// L2Distance returns the L2 distance between two vectors. +func (a VectorFloat32) L2Distance(b VectorFloat32) (float64, error) { + d, err := a.L2SquaredDistance(b) + if err != nil { + return 0, errors.Trace(err) + } + return math.Sqrt(d), nil +} + +// InnerProduct returns the inner product of two vectors. +func (a VectorFloat32) InnerProduct(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + distance += va[i] * vb[i] + } + + return float64(distance), nil +} + +// NegativeInnerProduct returns the negative inner product of two vectors. +func (a VectorFloat32) NegativeInnerProduct(b VectorFloat32) (float64, error) { + d, err := a.InnerProduct(b) + if err != nil { + return 0, errors.Trace(err) + } + return d * -1, nil +} + +// CosineDistance returns the cosine distance between two vectors. +func (a VectorFloat32) CosineDistance(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + var norma float32 = 0.0 + var normb float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + distance += va[i] * vb[i] + norma += va[i] * va[i] + normb += vb[i] * vb[i] + } + + similarity := float64(distance) / math.Sqrt(float64(norma)*float64(normb)) + + if math.IsNaN(similarity) { + // Divide by zero + return math.NaN(), nil + } + + if similarity > 1.0 { + similarity = 1.0 + } else if similarity < -1.0 { + similarity = -1.0 + } + + return 1.0 - similarity, nil +} + +// L1Distance returns the L1 distance between two vectors. +func (a VectorFloat32) L1Distance(b VectorFloat32) (float64, error) { + if err := a.checkIdenticalDims(b); err != nil { + return 0, errors.Trace(err) + } + + var distance float32 = 0.0 + va := a.Elements() + vb := b.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + diff := va[i] - vb[i] + if diff < 0 { + diff = -diff + } + distance += diff + } + + return float64(distance), nil +} + +// L2Norm returns the L2 norm of the vector. +func (a VectorFloat32) L2Norm() float64 { + // Note: We align the impl with pgvector: Only l2_norm use double + // precision during calculation. + var norm float64 = 0.0 + + va := a.Elements() + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + norm += float64(va[i]) * float64(va[i]) + } + return math.Sqrt(norm) +} + +// Add adds two vectors. The vectors must have the same dimension. +func (a VectorFloat32) Add(b VectorFloat32) (VectorFloat32, error) { + if err := a.checkIdenticalDims(b); err != nil { + return ZeroVectorFloat32, errors.Trace(err) + } + + result := InitVectorFloat32(a.Len()) + + va := a.Elements() + vb := b.Elements() + vr := result.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + vr[i] = va[i] + vb[i] + } + for i, iMax := 0, a.Len(); i < iMax; i++ { + if math.IsInf(float64(vr[i]), 0) { + return ZeroVectorFloat32, errors.Errorf("value out of range: overflow") + } + if math.IsNaN(float64(vr[i])) { + return ZeroVectorFloat32, errors.Errorf("value out of range: NaN") + } + } + + return result, nil +} + +// Sub subtracts two vectors. The vectors must have the same dimension. +func (a VectorFloat32) Sub(b VectorFloat32) (VectorFloat32, error) { + if err := a.checkIdenticalDims(b); err != nil { + return ZeroVectorFloat32, errors.Trace(err) + } + + result := InitVectorFloat32(a.Len()) + + va := a.Elements() + vb := b.Elements() + vr := result.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + vr[i] = va[i] - vb[i] + } + + for i, iMax := 0, a.Len(); i < iMax; i++ { + if math.IsInf(float64(vr[i]), 0) { + return ZeroVectorFloat32, errors.Errorf("value out of range: overflow") + } + if math.IsNaN(float64(vr[i])) { + return ZeroVectorFloat32, errors.Errorf("value out of range: NaN") + } + } + + return result, nil +} + +// Mul multiplies two vectors. The vectors must have the same dimension. +func (a VectorFloat32) Mul(b VectorFloat32) (VectorFloat32, error) { + if err := a.checkIdenticalDims(b); err != nil { + return ZeroVectorFloat32, errors.Trace(err) + } + + result := InitVectorFloat32(a.Len()) + + va := a.Elements() + vb := b.Elements() + vr := result.Elements() + + for i, iMax := 0, a.Len(); i < iMax; i++ { + // Hope this can be vectorized. + vr[i] = va[i] * vb[i] + } + + for i, iMax := 0, a.Len(); i < iMax; i++ { + if math.IsInf(float64(vr[i]), 0) { + return ZeroVectorFloat32, errors.Errorf("value out of range: overflow") + } + if math.IsNaN(float64(vr[i])) { + return ZeroVectorFloat32, errors.Errorf("value out of range: NaN") + } + + // TODO: Check for underflow. + // See https://github.com/pgvector/pgvector/blob/81d13bd40f03890bb5b6360259628cd473c2e467/src/vector.c#L873 + } + + return result, nil +} + // Compare returns an integer comparing two vectors. The result will be 0 if a==b, -1 if a < b, and +1 if a > b. func (a VectorFloat32) Compare(b VectorFloat32) int { la := a.Len() diff --git a/tests/integrationtest/r/executor/show.result b/tests/integrationtest/r/executor/show.result index d7088b1fc386d..be1fe6fc7e470 100644 --- a/tests/integrationtest/r/executor/show.result +++ b/tests/integrationtest/r/executor/show.result @@ -876,8 +876,13 @@ uuid_short uuid_to_bin validate_password_strength vec_as_text +vec_cosine_distance vec_dims vec_from_text +vec_l1_distance +vec_l2_distance +vec_l2_norm +vec_negative_inner_product version vitess_hash week