Skip to content

Commit

Permalink
*: refactor encoding and uniform usages (#30288)
Browse files Browse the repository at this point in the history
  • Loading branch information
tangenta authored Dec 20, 2021
1 parent e3c56b7 commit ab35db1
Show file tree
Hide file tree
Showing 29 changed files with 1,051 additions and 1,028 deletions.
9 changes: 5 additions & 4 deletions cmd/explaintest/r/new_character_set_builtin.result
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
set @@sql_mode = '';
drop table if exists t;
create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20));
insert into t values ('一二三', '一二三', '一二三');
Expand Down Expand Up @@ -244,17 +245,17 @@ insert into t values ('65'), ('123456'), ('123456789');
select char(a using gbk), char(a using utf8), char(a) from t;
char(a using gbk) char(a using utf8) char(a)
A A A
釦 �@ �@
NULL [� [�
釦  �@
[ [ [�
select char(12345678 using gbk);
char(12345678 using gbk)
糰N
set @@tidb_enable_vectorized_expression = true;
select char(a using gbk), char(a using utf8), char(a) from t;
char(a using gbk) char(a using utf8) char(a)
A A A
釦 �@ �@
NULL [� [�
釦  �@
[ [ [�
select char(12345678 using gbk);
char(12345678 using gbk)
糰N
Expand Down
1 change: 1 addition & 0 deletions cmd/explaintest/t/new_character_set_builtin.test
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
set @@sql_mode = '';
-- test for builtin function hex(), length(), ascii(), octet_length()
drop table if exists t;
create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20));
Expand Down
152 changes: 102 additions & 50 deletions expression/builtin_convert_charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package expression

import (
"fmt"
"unicode/utf8"

"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/ast"
Expand All @@ -27,6 +26,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/dbterror"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tipb/go-tipb"
)

Expand Down Expand Up @@ -92,9 +92,9 @@ func (b *builtinInternalToBinarySig) evalString(row chunk.Row) (res string, isNu
return res, isNull, err
}
tp := b.args[0].GetType()
enc := charset.NewEncoding(tp.Charset)
res, err = enc.EncodeString(val)
return res, false, err
enc := charset.FindEncoding(tp.Charset)
ret, err := enc.Transform(nil, hack.Slice(val), charset.OpEncode)
return string(ret), false, err
}

func (b *builtinInternalToBinarySig) vectorized() bool {
Expand All @@ -111,19 +111,19 @@ func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *c
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return err
}
enc := charset.NewEncoding(b.args[0].GetType().Charset)
enc := charset.FindEncoding(b.args[0].GetType().Charset)
result.ReserveString(n)
var encodedBuf []byte
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}
strBytes, err := enc.Encode(encodedBuf, buf.GetBytes(i))
encodedBuf, err = enc.Transform(encodedBuf, buf.GetBytes(i), charset.OpEncode)
if err != nil {
return err
}
result.AppendBytes(strBytes)
result.AppendBytes(encodedBuf)
}
return nil
}
Expand Down Expand Up @@ -170,9 +170,13 @@ func (b *builtinInternalFromBinarySig) evalString(row chunk.Row) (res string, is
if isNull || err != nil {
return val, isNull, err
}
transferString := b.getTransferFunc()
tBytes, err := transferString([]byte(val))
return string(tBytes), false, err
enc := charset.FindEncoding(b.tp.Charset)
ret, err := enc.Transform(nil, hack.Slice(val), charset.OpDecode)
if err != nil {
strHex := fmt.Sprintf("%X", val)
err = errCannotConvertString.GenWithStackByArgs(strHex, charset.CharsetBin, b.tp.Charset)
}
return string(ret), false, err
}

func (b *builtinInternalFromBinarySig) vectorized() bool {
Expand All @@ -189,45 +193,25 @@ func (b *builtinInternalFromBinarySig) vecEvalString(input *chunk.Chunk, result
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return err
}
transferString := b.getTransferFunc()
enc := charset.FindEncoding(b.tp.Charset)
var encBuf []byte
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}
str, err := transferString(buf.GetBytes(i))
str := buf.GetBytes(i)
encBuf, err = enc.Transform(encBuf, str, charset.OpDecode)
if err != nil {
return err
strHex := fmt.Sprintf("%X", str)
return errCannotConvertString.GenWithStackByArgs(strHex, charset.CharsetBin, b.tp.Charset)
}
result.AppendBytes(str)
result.AppendBytes(encBuf)
}
return nil
}

// getTransferFunc returns the function that converts a binary value into the
// target charset b.tp.Charset.
//
// For utf8mb4 / utf8 the input bytes are handed back unchanged once they pass
// UTF-8 validation. For any other charset the bytes are decoded with the
// encoding created for that charset. In both cases a failure is reported as
// errCannotConvertString, carrying the upper-case hex form of the input, the
// binary charset name and the target charset name.
func (b *builtinInternalFromBinarySig) getTransferFunc() func([]byte) ([]byte, error) {
	switch b.tp.Charset {
	case charset.CharsetUTF8MB4, charset.CharsetUTF8:
		// Nothing to transcode: only validity is checked.
		return func(s []byte) ([]byte, error) {
			if utf8.Valid(s) {
				return s, nil
			}
			return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset)
		}
	default:
		decoder := charset.NewEncoding(b.tp.Charset)
		// scratch is passed to every Decode call and is never reassigned.
		var scratch []byte
		return func(s []byte) ([]byte, error) {
			decoded, err := decoder.Decode(scratch, s)
			if err != nil {
				return nil, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), charset.CharsetBin, b.tp.Charset)
			}
			return decoded, nil
		}
	}
}

// BuildToBinaryFunction builds to_binary function.
func BuildToBinaryFunction(ctx sessionctx.Context, expr Expression) (res Expression) {
fc := &tidbToBinaryFunctionClass{baseFunctionClass{InternalFuncToBinary, 1, 1}}
Expand Down Expand Up @@ -258,26 +242,94 @@ func BuildFromBinaryFunction(ctx sessionctx.Context, expr Expression, tp *types.
return FoldConstant(res)
}

// funcProp describes how the string arguments of a builtin function are
// wrapped with to_binary() / from_binary().
type funcProp int8

const (
	// funcPropNone is the zero value: no wrapping rule is attached to the function.
	funcPropNone funcProp = iota
	// funcPropBinAware marks functions whose arguments are wrapped with
	// to_binary(). For compatibility reasons, arguments in a legacy charset
	// are not wrapped. Legacy charsets: utf8mb4, utf8, latin1, ascii, binary.
	funcPropBinAware
	// funcPropAuto marks functions whose arguments are wrapped with
	// to_binary() or from_binary(), depending on the charset of the evaluated
	// result and the charset of the argument:
	//   - binary argument && string result: wrap with from_binary();
	//   - string argument && binary result: wrap with to_binary().
	funcPropAuto
)

// convertActionMap groups builtin function names by the argument-wrapping
// rule (funcProp) that applies to them. Each name is listed exactly once.
// The function list is collected from
// https://dev.mysql.com/doc/refman/8.0/en/string-functions.html.
var convertActionMap = map[funcProp][]string{
	funcPropNone: {
		/* args != strings */
		ast.Bin, ast.CharFunc, ast.DateFormat, ast.Oct, ast.Space,
		/* only 1 string arg, no implicit conversion */
		ast.CharLength, ast.CharacterLength, ast.FromBase64, ast.Lcase, ast.Left, ast.LoadFile,
		ast.Lower, ast.LTrim, ast.Mid, ast.Ord, ast.Quote, ast.Repeat, ast.Reverse, ast.Right,
		ast.RTrim, ast.Soundex, ast.Substr, ast.Substring, ast.Ucase, ast.Unhex, ast.Upper, ast.WeightString,
		/* args are independent, no implicit conversion */
		ast.Elt,
	},
	funcPropBinAware: {
		/* result is binary-aware */
		ast.ASCII, ast.BitLength, ast.Hex, ast.Length, ast.OctetLength, ast.ToBase64,
		/* encrypt functions */
		ast.AesDecrypt, ast.Decode, ast.Encode, ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1,
		ast.SHA2, ast.Compress, ast.AesEncrypt,
	},
	funcPropAuto: {
		/* string functions */
		ast.Concat, ast.ConcatWS, ast.ExportSet, ast.Field, ast.FindInSet,
		ast.InsertFunc, ast.Instr, ast.Locate, ast.Lpad, ast.MakeSet, ast.Position,
		ast.Replace, ast.Rpad, ast.SubstringIndex, ast.Trim,
		/* operators */
		ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, ast.NE, ast.NullEQ, ast.If, ast.Ifnull, ast.In,
		ast.Case,
		/* string comparing */
		ast.Like, ast.Strcmp,
		/* regex */
		ast.Regexp,
	},
}

// convertFuncsMap is the reverse index of convertActionMap: it maps a function
// name to the funcProp group the name is listed under. It is filled by init.
var convertFuncsMap = map[string]funcProp{}

// init flattens convertActionMap into convertFuncsMap.
func init() {
	for prop, names := range convertActionMap {
		for i := range names {
			convertFuncsMap[names[i]] = prop
		}
	}
}

// HandleBinaryLiteral wraps `expr` with to_binary or from_binary sig.
func HandleBinaryLiteral(ctx sessionctx.Context, expr Expression, ec *ExprCollation, funcName string) Expression {
switch funcName {
case ast.Concat, ast.ConcatWS, ast.Lower, ast.Lcase, ast.Reverse, ast.Upper, ast.Ucase, ast.Quote, ast.Coalesce,
ast.Left, ast.Right, ast.Repeat, ast.Trim, ast.LTrim, ast.RTrim, ast.Substr, ast.SubstringIndex, ast.Replace,
ast.Substring, ast.Mid, ast.Translate, ast.InsertFunc, ast.Lpad, ast.Rpad, ast.Elt, ast.ExportSet, ast.MakeSet,
ast.FindInSet, ast.Regexp, ast.Field, ast.Locate, ast.Instr, ast.Position, ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ,
ast.NE, ast.NullEQ, ast.Strcmp, ast.If, ast.Ifnull, ast.Like, ast.In, ast.DateFormat, ast.TimeFormat:
if ec.Charset == charset.CharsetBin && expr.GetType().Charset != charset.CharsetBin {
argChs, dstChs := expr.GetType().Charset, ec.Charset
switch convertFuncsMap[funcName] {
case funcPropNone:
return expr
case funcPropBinAware:
if isLegacyCharset(argChs) {
return expr
}
return BuildToBinaryFunction(ctx, expr)
case funcPropAuto:
if argChs != charset.CharsetBin && dstChs == charset.CharsetBin {
if isLegacyCharset(argChs) {
return expr
}
return BuildToBinaryFunction(ctx, expr)
} else if ec.Charset != charset.CharsetBin && expr.GetType().Charset == charset.CharsetBin {
} else if argChs == charset.CharsetBin && dstChs != charset.CharsetBin {
ft := expr.GetType().Clone()
ft.Charset, ft.Collate = ec.Charset, ec.Collation
return BuildFromBinaryFunction(ctx, expr, ft)
}
case ast.Hex, ast.Length, ast.OctetLength, ast.ASCII, ast.ToBase64, ast.AesEncrypt, ast.AesDecrypt, ast.Decode, ast.Encode,
ast.PasswordFunc, ast.MD5, ast.SHA, ast.SHA1, ast.SHA2, ast.Compress:
if _, err := charset.GetDefaultCollationLegacy(expr.GetType().Charset); err != nil {
return BuildToBinaryFunction(ctx, expr)
}
}
return expr
}

// isLegacyCharset reports whether chs is one of the legacy charsets
// (utf8, utf8mb4, ascii, latin1 or binary).
func isLegacyCharset(chs string) bool {
	return chs == charset.CharsetUTF8 ||
		chs == charset.CharsetUTF8MB4 ||
		chs == charset.CharsetASCII ||
		chs == charset.CharsetLatin1 ||
		chs == charset.CharsetBin
}
46 changes: 27 additions & 19 deletions expression/builtin_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package expression

import (
"encoding/hex"
"fmt"
"strings"
"testing"

Expand Down Expand Up @@ -91,9 +92,10 @@ func TestSQLEncode(t *testing.T) {
d, err := f.Eval(chunk.Row{})
require.NoError(t, err)
if test.origin != nil {
result, err := charset.NewEncoding(test.chs).EncodeString(test.origin.(string))
enc := charset.FindEncoding(test.chs)
result, err := enc.Transform(nil, []byte(test.origin.(string)), charset.OpEncode)
require.NoError(t, err)
require.Equal(t, types.NewCollationStringDatum(result, test.chs), d)
require.Equal(t, types.NewCollationStringDatum(string(result), test.chs), d)
} else {
result := types.NewDatum(test.origin)
require.Equal(t, result.GetBytes(), d.GetBytes())
Expand Down Expand Up @@ -163,7 +165,8 @@ func TestAESEncrypt(t *testing.T) {
testAmbiguousInput(t, ctx, ast.AesEncrypt)

// Test GBK String
gbkStr, _ := charset.NewEncoding("gbk").EncodeString("你好")
enc := charset.FindEncoding("gbk")
gbkStr, _ := enc.Transform(nil, []byte("你好"), charset.OpEncode)
gbkTests := []struct {
mode string
chs string
Expand All @@ -188,19 +191,20 @@ func TestAESEncrypt(t *testing.T) {
}

for _, tt := range gbkTests {
msg := fmt.Sprintf("%v", tt)
err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs)
require.NoError(t, err)
require.NoError(t, err, msg)
err = variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode)
require.NoError(t, err)
require.NoError(t, err, msg)

args := datumsToConstants([]types.Datum{types.NewDatum(tt.origin)})
args := primitiveValsToConstants(ctx, []interface{}{tt.origin})
args = append(args, primitiveValsToConstants(ctx, tt.params)...)
f, err := fc.getFunction(ctx, args)

require.NoError(t, err)
require.NoError(t, err, msg)
crypt, err := evalBuiltinFunc(f, chunk.Row{})
require.NoError(t, err)
require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt))
require.NoError(t, err, msg)
require.Equal(t, types.NewDatum(tt.crypt), toHex(crypt), msg)
}
}

Expand All @@ -209,29 +213,32 @@ func TestAESDecrypt(t *testing.T) {

fc := funcs[ast.AesDecrypt]
for _, tt := range aesTests {
msg := fmt.Sprintf("%v", tt)
err := variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode)
require.NoError(t, err)
require.NoError(t, err, msg)
args := []types.Datum{fromHex(tt.crypt)}
for _, param := range tt.params {
args = append(args, types.NewDatum(param))
}
f, err := fc.getFunction(ctx, datumsToConstants(args))
require.NoError(t, err)
require.NoError(t, err, msg)
str, err := evalBuiltinFunc(f, chunk.Row{})
require.NoError(t, err)
require.NoError(t, err, msg)
if tt.origin == nil {
require.True(t, str.IsNull())
continue
}
require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str)
require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str, msg)
}
err := variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, "aes-128-ecb")
require.NoError(t, err)
testNullInput(t, ctx, ast.AesDecrypt)
testAmbiguousInput(t, ctx, ast.AesDecrypt)

// Test GBK String
gbkStr, _ := charset.NewEncoding("gbk").EncodeString("你好")
enc := charset.FindEncoding("gbk")
r, _ := enc.Transform(nil, []byte("你好"), charset.OpEncode)
gbkStr := string(r)
gbkTests := []struct {
mode string
chs string
Expand All @@ -256,18 +263,19 @@ func TestAESDecrypt(t *testing.T) {
}

for _, tt := range gbkTests {
msg := fmt.Sprintf("%v", tt)
err := ctx.GetSessionVars().SetSystemVar(variable.CharacterSetConnection, tt.chs)
require.NoError(t, err)
require.NoError(t, err, msg)
err = variable.SetSessionSystemVar(ctx.GetSessionVars(), variable.BlockEncryptionMode, tt.mode)
require.NoError(t, err)
require.NoError(t, err, msg)
// Set charset and collate except first argument
args := datumsToConstants([]types.Datum{fromHex(tt.crypt)})
args = append(args, primitiveValsToConstants(ctx, tt.params)...)
f, err := fc.getFunction(ctx, args)
require.NoError(t, err)
require.NoError(t, err, msg)
str, err := evalBuiltinFunc(f, chunk.Row{})
require.NoError(t, err)
require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str)
require.NoError(t, err, msg)
require.Equal(t, types.NewCollationStringDatum(tt.origin.(string), charset.CollationBin), str, msg)
}
}

Expand Down
Loading

0 comments on commit ab35db1

Please sign in to comment.