diff --git a/expression/bench_test.go b/expression/bench_test.go index b8e60d2fa9f95..a983305ce8a3d 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -406,6 +406,19 @@ func (g *numStrGener) gen() interface{} { return fmt.Sprintf("%v", g.rangeInt64Gener.gen()) } +// realStrGener is used to generate real number strings. +type realStrGener struct { + rangeRealGener +} + +func (g *realStrGener) gen() interface{} { + val := g.rangeRealGener.gen() + if val == nil { + return nil + } + return fmt.Sprintf("%v", val) +} + // ipv6StrGener is used to generate ipv6 strings. type ipv6StrGener struct { } diff --git a/expression/builtin_cast_vec.go b/expression/builtin_cast_vec.go index 0358bfe094cba..897af3e904c9e 100644 --- a/expression/builtin_cast_vec.go +++ b/expression/builtin_cast_vec.go @@ -1240,11 +1240,47 @@ func (b *builtinCastJSONAsDecimalSig) vecEvalDecimal(input *chunk.Chunk, result } func (b *builtinCastStringAsRealSig) vectorized() bool { - return false + return true } func (b *builtinCastStringAsRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Column) error { - return errors.Errorf("not implemented") + if IsBinaryLiteral(b.args[0]) { + // This block is skipped by `castAsRealFunctionClass.getFunction()` + return b.args[0].VecEvalReal(b.ctx, input, result) + } + + n := input.NumRows() + bufStrings, err := b.bufAllocator.get(types.ETString, n) + if err != nil { + return err + } + defer b.bufAllocator.put(bufStrings) + if err := b.args[0].VecEvalString(b.ctx, input, bufStrings); err != nil { + return err + } + + result.ResizeFloat64(n, false) + result.MergeNulls(bufStrings) + f64s := result.Float64s() + sc := b.ctx.GetSessionVars().StmtCtx + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + val := bufStrings.GetString(i) + res, err := types.StrToFloat(sc, val) + if err != nil { + return err + } + if b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && res < 0 { + res = 0 + } + if res, err = types.ProduceFloatWithSpecifiedTp(res, b.tp, sc); err != nil { + return err + } + f64s[i] = res + } + return nil } func (b *builtinCastStringAsDecimalSig) vectorized() bool { diff --git a/expression/builtin_cast_vec_test.go b/expression/builtin_cast_vec_test.go index 5033506fd4a96..b96a7979b72ef 100644 --- a/expression/builtin_cast_vec_test.go +++ b/expression/builtin_cast_vec_test.go @@ -20,6 +20,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" ) @@ -38,6 +39,12 @@ var vecBuiltinCastCases = map[string][]vecExprBenchCase{ {retEvalType: types.ETReal, childrenTypes: []types.EvalType{types.ETJson}}, {retEvalType: types.ETReal, childrenTypes: []types.EvalType{types.ETDecimal}}, {retEvalType: types.ETReal, childrenTypes: []types.EvalType{types.ETDatetime}}, + {retEvalType: types.ETReal, childrenTypes: []types.EvalType{types.ETString}, + geners: []dataGenerator{&realStrGener{rangeRealGener{begin: -100000.0, end: 100000.0, nullRation: 0.5}}}, + }, + {retEvalType: types.ETReal, childrenTypes: []types.EvalType{types.ETString}, + constants: []*Constant{{Value: types.NewBinaryLiteralDatum([]byte("TiDB")), RetType: types.NewFieldType(mysql.TypeVarString)}}, + }, {retEvalType: types.ETDuration, childrenTypes: []types.EvalType{types.ETDatetime}, geners: []dataGenerator{&dateTimeGenerWithFsp{ defaultGener: defaultGener{nullRation: 0.2, eType: types.ETDatetime},