Skip to content

Commit

Permalink
expression: let cast function supports explicit set charset (pingca…
Browse files Browse the repository at this point in the history
  • Loading branch information
Defined2014 authored and Benjamin2037 committed Sep 11, 2024
1 parent 468d429 commit 381d1ff
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pkg/expression/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1449,7 +1449,7 @@ func genVecBuiltinFuncBenchCase(ctx BuildContext, funcName string, testCase vecE
case types.ETJson:
fc = &castAsJSONFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETString:
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, false}
}
baseFunc, err = fc.getFunction(ctx, cols)
} else if funcName == ast.GetVar {
Expand Down
29 changes: 29 additions & 0 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,35 @@ func newBaseBuiltinCastFunc(builtinFunc baseBuiltinFunc, inUnion bool) baseBuilt
}
}

func newBaseBuiltinCastFunc4String(ctx BuildContext, funcName string, args []Expression, tp *types.FieldType, isExplicitCharset bool) (baseBuiltinFunc, error) {
var bf baseBuiltinFunc
var err error
if isExplicitCharset {
bf = baseBuiltinFunc{
bufAllocator: newLocalColumnPool(),
childrenVectorizedOnce: new(sync.Once),

args: args,
tp: tp,
}
bf.SetCharsetAndCollation(tp.GetCharset(), tp.GetCollate())
bf.setCollator(collate.GetCollator(tp.GetCollate()))
bf.SetCoercibility(CoercibilityExplicit)
bf.SetExplicitCharset(true)
if tp.GetCharset() == charset.CharsetASCII {
bf.SetRepertoire(ASCII)
} else {
bf.SetRepertoire(UNICODE)
}
} else {
bf, err = newBaseBuiltinFunc(ctx, funcName, args, tp)
if err != nil {
return baseBuiltinFunc{}, err
}
}
return bf, nil
}

// vecBuiltinFunc contains all vectorized methods for a builtin function.
type vecBuiltinFunc interface {
// vectorized returns if this builtin function itself supports vectorized evaluation.
Expand Down
13 changes: 7 additions & 6 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,15 @@ func (c *castAsDecimalFunctionClass) getFunction(ctx BuildContext, args []Expres
type castAsStringFunctionClass struct {
baseFunctionClass

tp *types.FieldType
tp *types.FieldType
isExplicitCharset bool
}

func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFunc(ctx, c.funcName, args, c.tp)
bf, err := newBaseBuiltinCastFunc4String(ctx, c.funcName, args, c.tp, c.isExplicitCharset)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2265,7 +2266,7 @@ func CanImplicitEvalReal(expr Expression) bool {
// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union
// Expression.
func BuildCastFunction4Union(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) {
res, err := BuildCastFunctionWithCheck(ctx, expr, tp, true)
res, err := BuildCastFunctionWithCheck(ctx, expr, tp, true, false)
terror.Log(err)
return
}
Expand Down Expand Up @@ -2302,13 +2303,13 @@ func BuildCastCollationFunction(ctx BuildContext, expr Expression, ec *ExprColla

// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(ctx BuildContext, expr Expression, tp *types.FieldType) (res Expression) {
res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false)
res, err := BuildCastFunctionWithCheck(ctx, expr, tp, false, false)
terror.Log(err)
return
}

// BuildCastFunctionWithCheck builds a CAST ScalarFunction from the Expression and return error if any.
func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.FieldType, inUnion bool) (res Expression, err error) {
func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.FieldType, inUnion bool, isExplicitCharset bool) (res Expression, err error) {
argType := expr.GetType(ctx.GetEvalCtx())
// If source argument's nullable, then target type should be nullable
if !mysql.HasNotNullFlag(argType.GetFlag()) {
Expand Down Expand Up @@ -2336,7 +2337,7 @@ func BuildCastFunctionWithCheck(ctx BuildContext, expr Expression, tp *types.Fie
case types.ETVectorFloat32:
fc = &castAsVectorFloat32FunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
case types.ETString:
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp}
fc = &castAsStringFunctionClass{baseFunctionClass{ast.Cast, 1, 1}, tp, isExplicitCharset}
if expr.GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeBit {
tp.SetFlen((expr.GetType(ctx.GetEvalCtx()).GetFlen() + 7) / 8)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func TestCastFuncSig(t *testing.T) {
tp := types.NewFieldType(mysql.TypeVarString)
tp.SetCharset(charset.CharsetBin)
args := []Expression{c.before}
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false)
require.NoError(t, err)
switch i {
case 0:
Expand Down Expand Up @@ -742,7 +742,7 @@ func TestCastFuncSig(t *testing.T) {
tp := types.NewFieldType(mysql.TypeVarString)
tp.SetFlen(c.flen)
tp.SetCharset(charset.CharsetBin)
stringFunc, err := newBaseBuiltinFunc(ctx, "", args, tp)
stringFunc, err := newBaseBuiltinCastFunc4String(ctx, "", args, tp, false)
require.NoError(t, err)
switch i {
case 0:
Expand Down Expand Up @@ -1099,7 +1099,7 @@ func TestCastFuncSig(t *testing.T) {
// null case
args := []Expression{&Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}}
row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(nil)})
bf, err := newBaseBuiltinFunc(ctx, "", args, types.NewFieldType(mysql.TypeVarString))
bf, err := newBaseBuiltinCastFunc4String(ctx, "", args, types.NewFieldType(mysql.TypeVarString), false)
require.NoError(t, err)
sig = &builtinCastRealAsStringSig{bf}
sRes, err := evalBuiltinFunc(sig, ctx, row.ToRow())
Expand Down Expand Up @@ -1694,7 +1694,7 @@ func TestCastArrayFunc(t *testing.T) {
},
}
for _, tt := range tbl {
f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp, false)
f, err := BuildCastFunctionWithCheck(ctx, datumsToConstants(types.MakeDatums(types.CreateBinaryJSON(tt.input)))[0], tt.tp, false, false)
if !tt.buildFuncSuccess {
require.Error(t, err, tt.input)
continue
Expand Down
25 changes: 21 additions & 4 deletions pkg/expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ type collationInfo struct {

charset string
collation string

isExplicitCharset bool
}

// Hash64 implements the base.Hasher.<0th> interface.
Expand All @@ -55,6 +57,7 @@ func (c *collationInfo) Hash64(h base.Hasher) {
h.HashInt(int(c.repertoire))
h.HashString(c.charset)
h.HashString(c.collation)
h.HashBool(c.isExplicitCharset)
}

// Equals implements the base.Hasher.<1th> interface.
Expand All @@ -76,7 +79,8 @@ func (c *collationInfo) Equals(other any) bool {
c.coerInit.Load() == c2.coerInit.Load() &&
c.repertoire == c2.repertoire &&
c.charset == c2.charset &&
c.collation == c2.collation
c.collation == c2.collation &&
c.isExplicitCharset == c2.isExplicitCharset
}

func (c *collationInfo) HasCoercibility() bool {
Expand Down Expand Up @@ -109,6 +113,14 @@ func (c *collationInfo) CharsetAndCollation() (string, string) {
return c.charset, c.collation
}

func (c *collationInfo) IsExplicitCharset() bool {
return c.isExplicitCharset
}

func (c *collationInfo) SetExplicitCharset(explicit bool) {
c.isExplicitCharset = explicit
}

// CollationInfo contains all interfaces about dealing with collation.
type CollationInfo interface {
// HasCoercibility returns if the Coercibility value is initialized.
Expand All @@ -131,6 +143,12 @@ type CollationInfo interface {

// SetCharsetAndCollation sets charset and collation.
SetCharsetAndCollation(chs, coll string)

// IsExplicitCharset return the charset is explicit set or not.
IsExplicitCharset() bool

// SetExplicitCharset set the charset is explicit or not.
SetExplicitCharset(bool)
}

// Coercibility values are used to check whether the collation of one item can be coerced to
Expand Down Expand Up @@ -279,9 +297,8 @@ func deriveCollation(ctx BuildContext, funcName string, args []Expression, retTy
case ast.Cast:
// We assume all the cast are implicit.
ec = &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType(ctx.GetEvalCtx()).GetCharset(), args[0].GetType(ctx.GetEvalCtx()).GetCollate()}
// Non-string type cast to string type should use @@character_set_connection and @@collation_connection.
// String type cast to string type should keep its original charset and collation. It should not happen.
if retType == types.ETString && argTps[0] != types.ETString {
// Cast to string type should use @@character_set_connection and @@collation_connection.
if retType == types.ETString {
ec.Charset, ec.Collation = ctx.GetCharsetInfo()
}
return ec, nil
Expand Down
10 changes: 10 additions & 0 deletions pkg/expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,16 @@ func (sf *ScalarFunction) SetRepertoire(r Repertoire) {
sf.Function.SetRepertoire(r)
}

// IsExplicitCharset return the charset is explicit set or not.
func (sf *ScalarFunction) IsExplicitCharset() bool {
return sf.Function.IsExplicitCharset()
}

// SetExplicitCharset set the charset is explicit or not.
func (sf *ScalarFunction) SetExplicitCharset(explicit bool) {
sf.Function.SetExplicitCharset(explicit)
}

const emptyScalarFunctionSize = int64(unsafe.Sizeof(ScalarFunction{}))

// MemoryUsage return the memory usage of ScalarFunction
Expand Down
4 changes: 3 additions & 1 deletion pkg/expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,10 @@ func ColumnSubstituteImpl(ctx BuildContext, expr Expression, schema *Schema, new
if substituted {
flag := v.RetType.GetFlag()
var e Expression
var err error
if v.FuncName.L == ast.Cast {
e = BuildCastFunction(ctx, newArg, v.RetType)
e, err = BuildCastFunctionWithCheck(ctx, newArg, v.RetType, false, v.Function.IsExplicitCharset())
terror.Log(err)
} else {
// for grouping function recreation, use clone (meta included) instead of newFunction
e = v.Clone()
Expand Down
2 changes: 2 additions & 0 deletions pkg/expression/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@ func (m *MockExpr) Coercibility() Coercibility { return
func (m *MockExpr) SetCoercibility(Coercibility) {}
func (m *MockExpr) Repertoire() Repertoire { return UNICODE }
func (m *MockExpr) SetRepertoire(Repertoire) {}
func (m *MockExpr) IsExplicitCharset() bool { return false }
func (m *MockExpr) SetExplicitCharset(bool) {}

func (m *MockExpr) CharsetAndCollation() (string, string) {
return "", ""
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
return retNode, false
}

castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, false)
castFunction, err := expression.BuildCastFunctionWithCheck(er.sctx, arg, v.Tp, false, v.ExplicitCharSet)
if err != nil {
er.err = err
return retNode, false
Expand Down
1 change: 1 addition & 0 deletions tests/integrationtest/r/executor/executor.result
Original file line number Diff line number Diff line change
Expand Up @@ -4379,4 +4379,5 @@ LOCK TABLE executor__executor.t WRITE, test2.t2 WRITE;
LOCK TABLE executor__executor.t WRITE, test2.t2 WRITE;
Error 8020 (HY000): Table 't' was locked in WRITE by server: <server> session: <session>
unlock tables;
unlock tables;
drop user 'testuser'@'localhost';
30 changes: 30 additions & 0 deletions tests/integrationtest/r/expression/cast.result
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,33 @@ select 1.194192591e9 > t0.c0 from t0;
select 1.194192591e9 < t0.c0 from t0;
1.194192591e9 < t0.c0
0
drop table if exists test;
CREATE TABLE `test` (
`id` bigint(20) NOT NULL,
`update_user` varchar(32) DEFAULT NULL,
PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char character set gbk) desc , id limit 3;
id update_user
1 张三
3 张三
2 李四
drop table test;
CREATE TABLE `test` (
`id` bigint NOT NULL,
`update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char) desc , id limit 3;
id update_user
2 李四
4 李四
1 张三
3 changes: 3 additions & 0 deletions tests/integrationtest/t/executor/executor.test
Original file line number Diff line number Diff line change
Expand Up @@ -2709,6 +2709,9 @@ connection default;
--error 8020
LOCK TABLE executor__executor.t WRITE, test2.t2 WRITE;

connection conn1;
unlock tables;

disconnect conn1;
unlock tables;
drop user 'testuser'@'localhost';
26 changes: 26 additions & 0 deletions tests/integrationtest/t/expression/cast.test
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,29 @@ select t0.c0 > 1.194192591e9 from t0;
select t0.c0 < 1.194192591e9 from t0;
select 1.194192591e9 > t0.c0 from t0;
select 1.194192591e9 < t0.c0 from t0;

# TestCastAsStringExplicitCharSet
drop table if exists test;
CREATE TABLE `test` (
`id` bigint(20) NOT NULL,
`update_user` varchar(32) DEFAULT NULL,
PRIMARY KEY (`id`) /*T![clustered_index] CLUSTERED */
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char character set gbk) desc , id limit 3;

drop table test;
CREATE TABLE `test` (
`id` bigint NOT NULL,
`update_user` varchar(32) CHARACTER SET gbk COLLATE gbk_chinese_ci DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
insert into test values(1,'张三');
insert into test values(2,'李四');
insert into test values(3,'张三');
insert into test values(4,'李四');
select * from test order by cast(update_user as char) desc , id limit 3;

0 comments on commit 381d1ff

Please sign in to comment.