diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index f2d21af1111..e7563e8f258 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -815,6 +815,18 @@ func (cached *builtinDegrees) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinElt) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinExp) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -827,6 +839,18 @@ func (cached *builtinExp) CachedSize(alloc bool) int64 { size += cached.CallExpr.CachedSize(false) return size } +func (cached *builtinField) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr + size += cached.CallExpr.CachedSize(false) + return size +} func (cached *builtinFloor) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 2bda9826be9..babf23feb5a 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2347,6 +2347,167 @@ func (asm *assembler) Fn_BIT_LENGTH() { }, "FN BIT_LENGTH VARCHAR(SP-1)") } +func (asm *assembler) Fn_FIELD_i(args int) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } + + tar := env.vm.stack[env.vm.sp-args].(*evalInt64) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalInt64) + + if tar.i == arg.i { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD INT64(SP-%d)...INT64(SP-1)", args) +} + +func (asm *assembler) Fn_FIELD_b(args int, col colldata.Collation) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } + + tar := env.vm.stack[env.vm.sp-args].(*evalBytes) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + str := env.vm.stack[env.vm.sp-args+i+1].(*evalBytes) + + // We cannot do these comparison earlier in the compilation, + // because if we convert everything first, we error on cases + // where there is a match. MySQL will do an element for element + // comparison where if there's a match already, it doesn't matter + // if there was an invalid conversion later on. + // + // This means we also must convert here in this compiler function + // and can't eagerly do the conversion. + toCharset := col.Charset() + fromCharset := colldata.Lookup(str.col.Collation).Charset() + if fromCharset != toCharset && !toCharset.IsSuperset(fromCharset) { + str, env.vm.err = evalToVarchar(str, col.ID(), true) + if env.vm.err != nil { + env.vm.stack[env.vm.sp-args] = nil + env.vm.sp -= args - 1 + return 1 + } + } + + // Compare target and current string + if col.Collate(tar.bytes, str.bytes, false) == 0 { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD VARCHAR(SP-%d)...VARCHAR(SP-1)", args) +} + +func (asm *assembler) Fn_FIELD_d(args int) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } + + tar := env.vm.stack[env.vm.sp-args].(*evalDecimal) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalDecimal) + + if tar.dec.Equal(arg.dec) { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD DECIMAL(SP-%d)...DECIMAL(SP-1)", args) +} + +func (asm *assembler) Fn_FIELD_f(args int) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + if env.vm.stack[env.vm.sp-args] == nil { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + } + + tar := env.vm.stack[env.vm.sp-args].(*evalFloat) + + for i := range args - 1 { + if env.vm.stack[env.vm.sp-args+i+1] == nil { + continue + } + + arg := env.vm.stack[env.vm.sp-args+i+1].(*evalFloat) + + if tar.f == arg.f { + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(int64(i + 1)) + env.vm.sp -= args - 1 + return 1 + } + } + + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalInt64(0) + env.vm.sp -= args - 1 + return 1 + }, "FN FIELD FLOAT64(SP-%d)...FLOAT64(SP-1)", args) +} + +func (asm *assembler) Fn_ELT(args int, tt sqltypes.Type, tc collations.TypedCollation) { + asm.adjustStack(-args + 1) + asm.emit(func(env *ExpressionEnv) int { + i := env.vm.stack[env.vm.sp-args].(*evalInt64) + + if i.i < 1 || int(i.i) >= args || env.vm.stack[env.vm.sp-args+int(i.i)] == nil { + env.vm.stack[env.vm.sp-args] = nil + } else { + b := env.vm.stack[env.vm.sp-args+int(i.i)].(*evalBytes) + env.vm.stack[env.vm.sp-args] = env.vm.arena.newEvalRaw(b.bytes, tt, tc) + } + + env.vm.sp -= args - 1 + return 1 + }, "FN ELT INT64(SP-%d) VARCHAR(SP-%d)...VARCHAR(SP-1)", args, args-1) +} + func (asm *assembler) Fn_INSERT(col collations.TypedCollation) { asm.adjustStack(-3) diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index e0887037c0a..663475327e5 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -30,6 +30,16 @@ import ( ) type ( + builtinField struct { + CallExpr + collate collations.ID + } + + builtinElt struct { + CallExpr + collate collations.ID + } + builtinInsert struct { CallExpr collate collations.ID @@ -141,6 +151,8 @@ type ( } ) +var _ IR = (*builtinField)(nil) +var _ IR = (*builtinElt)(nil) var _ IR = (*builtinInsert)(nil) var _ IR = (*builtinChangeCase)(nil) var _ IR = (*builtinCharLength)(nil) @@ -164,6 +176,296 @@ var _ IR = (*builtinConcat)(nil) var _ IR = (*builtinConcatWs)(nil) var _ IR = (*builtinReplace)(nil) +func fieldSQLType(arg sqltypes.Type, tt sqltypes.Type) sqltypes.Type { + if sqltypes.IsNull(arg) { + // If we have a NULL combined with only so far numerical types, + // we have to convert it all to DOUBLE. + if sqltypes.IsIntegral(tt) || sqltypes.IsDecimal(tt) { + return sqltypes.Float64 + } + return tt + } + + if typeIsTextual(arg) && typeIsTextual(tt) { + return sqltypes.VarChar + } else if sqltypes.IsIntegral(arg) && sqltypes.IsIntegral(tt) { + return sqltypes.Int64 + } + + if (sqltypes.IsIntegral(arg) || sqltypes.IsDecimal(arg)) && (sqltypes.IsIntegral(tt) || sqltypes.IsDecimal(tt)) { + return sqltypes.Decimal + } + + return sqltypes.Float64 +} + +func (call *builtinField) eval(env *ExpressionEnv) (eval, error) { + args, err := call.args(env) + if err != nil { + return nil, err + } + if args[0] == nil { + return newEvalInt64(0), nil + } + + // If the arguments contain both integral and string values + // MySQL converts all the arguments to DOUBLE + tt := args[0].SQLType() + + for _, arg := range args[1:] { + var at sqltypes.Type + if arg == nil { + at = sqltypes.Null + } else { + at = arg.SQLType() + } + + tt = fieldSQLType(at, tt) + } + + if tt == sqltypes.Int64 { + tar := evalToInt64(args[0]) + + for i, arg := range args[1:] { + if arg == nil { + continue + } + + e := evalToInt64(arg) + if tar.i == e.i { + return newEvalInt64(int64(i + 1)), nil + } + } + } else if tt == sqltypes.VarChar { + col := evalCollation(args[0]) + collation := colldata.Lookup(col.Collation) + tar := args[0].(*evalBytes) + + for i, arg := range args[1:] { + if arg == nil { + continue + } + + e, err := evalToVarchar(arg, col.Collation, true) + if err != nil { + return nil, err + } + + // Compare target and current string + if collation.Collate(tar.bytes, e.bytes, false) == 0 { + return newEvalInt64(int64(i + 1)), nil + } + } + } else if tt == sqltypes.Decimal { + tar := evalToDecimal(args[0], 0, 0) + + for i, arg := range args[1:] { + if arg == nil { + continue + } + + e := evalToDecimal(arg, 0, 0) + if tar.dec.Equal(e.dec) { + return newEvalInt64(int64(i + 1)), nil + } + } + } else { + tar, _ := evalToFloat(args[0]) + + for i, arg := range args[1:] { + if arg == nil { + continue + } + + e, _ := evalToFloat(arg) + if tar.f == e.f { + return newEvalInt64(int64(i + 1)), nil + } + } + } + + return newEvalInt64(0), nil +} + +func (call *builtinField) compile(c *compiler) (ctype, error) { + strs := make([]ctype, len(call.Arguments)) + + for i, arg := range call.Arguments { + var err error + strs[i], err = arg.compile(c) + if err != nil { + return ctype{}, err + } + } + + // If the arguments contain both integral and string values + // MySQL converts all the arguments to DOUBLE + tt := strs[0].Type + col := strs[0].Col + + for _, str := range strs { + tt = fieldSQLType(str.Type, tt) + } + + if tt == sqltypes.Int64 { + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) + + switch str.Type { + case sqltypes.Int64: + default: + c.asm.Convert_xi(offset) + } + c.asm.jumpDestination(skip) + } + + c.asm.Fn_FIELD_i(len(call.Arguments)) + } else if tt == sqltypes.VarChar { + collation := colldata.Lookup(col.Collation) + c.asm.Fn_FIELD_b(len(call.Arguments), collation) + } else if tt == sqltypes.Decimal { + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) + + switch str.Type { + case sqltypes.Decimal: + default: + c.asm.Convert_xd(offset, 0, 0) + } + c.asm.jumpDestination(skip) + } + + c.asm.Fn_FIELD_d(len(call.Arguments)) + } else { + for i, str := range strs { + offset := len(strs) - i + skip := c.compileNullCheckOffset(str, offset) + + switch str.Type { + case sqltypes.Float64: + default: + c.asm.Convert_xf(offset) + } + + c.asm.jumpDestination(skip) + } + + c.asm.Fn_FIELD_f(len(call.Arguments)) + } + + return ctype{Type: sqltypes.Int64, Col: collationNumeric}, nil +} + +func (call *builtinElt) eval(env *ExpressionEnv) (eval, error) { + var ca collationAggregation + tt := sqltypes.VarChar + + args, err := call.args(env) + if err != nil { + return nil, err + } + + if args[0] == nil { + return nil, nil + } + + i := evalToInt64(args[0]).i + if i < 1 || i >= int64(len(args)) || args[i] == nil { + return nil, nil + } + + for _, arg := range args[1:] { + if arg == nil { + continue + } + + tt = concatSQLType(arg.SQLType(), tt) + err = ca.add(evalCollation(arg), env.collationEnv) + if err != nil { + return nil, err + } + } + + tc := ca.result() + // If we only had numbers, we instead fall back to the default + // collation instead of using the numeric collation. + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(tt, call.collate) + } + + b, err := evalToVarchar(args[i], tc.Collation, true) + if err != nil { + return nil, err + } + + return newEvalRaw(tt, b.bytes, b.col), nil +} + +func (call *builtinElt) compile(c *compiler) (ctype, error) { + args := make([]ctype, len(call.Arguments)) + + var ca collationAggregation + tt := sqltypes.VarChar + + var skip *jump + for i, arg := range call.Arguments { + var err error + args[i], err = arg.compile(c) + if err != nil { + return ctype{}, nil + } + + if i == 0 { + skip = c.compileNullCheck1(args[i]) + continue + } + + tt = concatSQLType(args[i].Type, tt) + err = ca.add(args[i].Col, c.env.CollationEnv()) + if err != nil { + return ctype{}, err + } + } + + tc := ca.result() + // If we only had numbers, we instead fall back to the default + // collation instead of using the numeric collation. + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(tt, call.collate) + } + + _ = c.compileToInt64(args[0], len(args)) + + for i, arg := range args[1:] { + offset := len(args) - (i + 1) + skip := c.compileNullCheckOffset(arg, offset) + + switch arg.Type { + case sqltypes.VarBinary, sqltypes.Binary, sqltypes.Blob: + if tc.Collation != collations.CollationBinaryID { + c.asm.Convert_xce(offset, arg.Type, tc.Collation) + } + case sqltypes.VarChar, sqltypes.Char, sqltypes.Text: + fromCharset := colldata.Lookup(arg.Col.Collation).Charset() + toCharset := colldata.Lookup(tc.Collation).Charset() + if fromCharset != toCharset && !toCharset.IsSuperset(fromCharset) { + c.asm.Convert_xce(offset, arg.Type, tc.Collation) + } + default: + c.asm.Convert_xce(offset, arg.Type, tc.Collation) + } + + c.asm.jumpDestination(skip) + } + + c.asm.Fn_ELT(len(args), tt, tc) + c.asm.jumpDestination(skip) + + return ctype{Type: tt, Col: tc, Flag: flagNullable}, nil +} + func insert(str, newstr *evalBytes, pos, l int) []byte { pos-- diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 7d0139e6bbb..64dbd773a44 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -63,6 +63,8 @@ var Cases = []TestCase{ {Run: TupleComparisons}, {Run: Comparisons}, {Run: InStatement}, + {Run: FnField}, + {Run: FnElt}, {Run: FnInsert}, {Run: FnLower}, {Run: FnUpper}, @@ -1339,6 +1341,90 @@ var JSONExtract_Schema = []*querypb.Field{ }, } +func FnField(yield Query) { + for _, s1 := range inputStrings { + for _, s2 := range inputStrings { + for _, s3 := range inputStrings { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } + + for _, s1 := range radianInputs { + for _, s2 := range radianInputs { + for _, s3 := range radianInputs { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } + + // Contains failing testcases + for _, s1 := range inputStrings { + for _, s2 := range radianInputs { + for _, s3 := range inputStrings { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } + + // Contains failing testcases + for _, s1 := range inputBitwise { + for _, s2 := range inputBitwise { + for _, s3 := range inputBitwise { + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + } + } + } + + mysqlDocSamples := []string{ + "FIELD('Bb', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", + "FIELD('Gg', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", + } + for _, q := range mysqlDocSamples { + yield(q, nil) + } +} + +func FnElt(yield Query) { + for _, s1 := range inputStrings { + for _, n := range inputBitwise { + yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil) + } + } + + for _, s1 := range inputStrings { + for _, s2 := range inputStrings { + for _, n := range inputBitwise { + yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil) + } + } + } + + validIndex := []string{ + "1", + "2", + "3", + } + for _, s1 := range inputStrings { + for _, s2 := range inputStrings { + for _, s3 := range inputStrings { + for _, n := range validIndex { + yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil) + } + } + } + } + + mysqlDocSamples := []string{ + "ELT(1, 'Aa', 'Bb', 'Cc', 'Dd')", + "ELT(4, 'Aa', 'Bb', 'Cc', 'Dd')", + } + + for _, q := range mysqlDocSamples { + yield(q, nil) + } +} + func FnInsert(yield Query) { for _, s := range insertStrings { for _, ns := range insertStrings { diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 710245257ed..4f7ba1a451c 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -278,6 +278,16 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) { return nil, argError(method) } return &builtinPad{CallExpr: call, collate: ast.cfg.Collation, left: method == "lpad"}, nil + case "field": + if len(args) < 2 { + return nil, argError(method) + } + return &builtinField{CallExpr: call, collate: ast.cfg.Collation}, nil + case "elt": + if len(args) < 2 { + return nil, argError(method) + } + return &builtinElt{CallExpr: call, collate: ast.cfg.Collation}, nil case "lower", "lcase": if len(args) != 1 { return nil, argError(method)