diff --git a/executor/aggfuncs/func_group_concat.go b/executor/aggfuncs/func_group_concat.go index a7fb3b7359ca5..4cb3993be65ad 100644 --- a/executor/aggfuncs/func_group_concat.go +++ b/executor/aggfuncs/func_group_concat.go @@ -293,6 +293,7 @@ type topNRows struct { // ('---', 'ccc') should be poped from heap, so '-' should be appended to result. // eg: 'aaa---bbb---ccc' -> 'aaa---bbb-' isSepTruncated bool + collators []collate.Collator } func (h topNRows) Len() int { @@ -302,7 +303,7 @@ func (h topNRows) Len() int { func (h topNRows) Less(i, j int) bool { n := len(h.rows[i].byItems) for k := 0; k < n; k++ { - ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, h.rows[j].byItems[k]) + ret, err := h.rows[i].byItems[k].Compare(h.sctx.GetSessionVars().StmtCtx, h.rows[j].byItems[k], h.collators[k]) if err != nil { h.err = err return false @@ -411,8 +412,10 @@ func (e *groupConcatOrder) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P func (e *groupConcatOrder) AllocPartialResult() (pr PartialResult, memDelta int64) { desc := make([]bool, len(e.byItems)) + ctors := make([]collate.Collator, 0, len(e.byItems)) for i, byItem := range e.byItems { desc[i] = byItem.Desc + ctors = append(ctors, collate.GetCollator(byItem.Expr.GetType().Collate)) } p := &partialResult4GroupConcatOrder{ topN: &topNRows{ @@ -421,6 +424,7 @@ func (e *groupConcatOrder) AllocPartialResult() (pr PartialResult, memDelta int6 limitSize: e.maxLen, sepSize: uint64(len(e.sep)), isSepTruncated: false, + collators: ctors, }, } return PartialResult(p), DefPartialResult4GroupConcatOrderSize + DefTopNRowsSize @@ -513,8 +517,10 @@ func (e *groupConcatDistinctOrder) AppendFinalResult2Chunk(sctx sessionctx.Conte func (e *groupConcatDistinctOrder) AllocPartialResult() (pr PartialResult, memDelta int64) { desc := make([]bool, len(e.byItems)) + ctors := make([]collate.Collator, 0, len(e.byItems)) for i, byItem := range e.byItems { desc[i] = byItem.Desc + ctors = append(ctors, collate.GetCollator(byItem.Expr.GetType().Collate)) } valSet, setSize := set.NewStringSetWithMemoryUsage() p := &partialResult4GroupConcatOrderDistinct{ @@ -524,6 +530,7 @@ func (e *groupConcatDistinctOrder) AllocPartialResult() (pr PartialResult, memDe limitSize: e.maxLen, sepSize: uint64(len(e.sep)), isSepTruncated: false, + collators: ctors, }, valSet: valSet, } diff --git a/executor/builder.go b/executor/builder.go index 1912f17eb5458..8fcf0cbb2a252 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -55,6 +55,7 @@ import ( "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/admin" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/cteutil" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/logutil" @@ -1059,6 +1060,11 @@ func (b *executorBuilder) buildUnionScanFromReader(reader Executor, v *plannerco reader = sel.children[0] } + us.collators = make([]collate.Collator, 0, len(us.columns)) + for _, tp := range retTypes(us) { + us.collators = append(us.collators, collate.GetCollator(tp.Collate)) + } + switch x := reader.(type) { case *TableReaderExecutor: us.desc = x.desc diff --git a/executor/union_scan.go b/executor/union_scan.go index c796d36bb6d31..9f15ef793090b 100644 --- a/executor/union_scan.go +++ b/executor/union_scan.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" ) // UnionScanExec merges the rows from dirty table and the rows from distsql request. @@ -59,6 +60,7 @@ type UnionScanExec struct { // cacheTable not nil means it's reading from cached table. cacheTable kv.MemBuffer + collators []collate.Collator } // Open implements the Executor Open interface. @@ -273,7 +275,7 @@ func (us *UnionScanExec) compare(a, b []types.Datum) (int, error) { for _, colOff := range us.usedIndex { aColumn := a[colOff] bColumn := b[colOff] - cmp, err := aColumn.CompareDatum(sc, &bColumn) + cmp, err := aColumn.Compare(sc, &bColumn, us.collators[colOff]) if err != nil { return 0, err } @@ -281,5 +283,5 @@ func (us *UnionScanExec) compare(a, b []types.Datum) (int, error) { return cmp, nil } } - return us.belowHandleCols.Compare(a, b) + return us.belowHandleCols.Compare(a, b, us.collators) } diff --git a/expression/aggregation/aggregation.go b/expression/aggregation/aggregation.go index 3a52e6719f087..84380552d7f71 100644 --- a/expression/aggregation/aggregation.go +++ b/expression/aggregation/aggregation.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tipb/go-tipb" ) @@ -68,9 +69,9 @@ func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, sc *stmtctx.St case tipb.ExprType_GroupConcat: return &concatFunction{aggFunction: newAggFunc(ast.AggFuncGroupConcat, args, false)}, nil case tipb.ExprType_Max: - return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMax, args, false), isMax: true}, nil + return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMax, args, false), isMax: true, ctor: collate.GetCollator(args[0].GetType().Collate)}, nil case tipb.ExprType_Min: - return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMin, args, false)}, nil + return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMin, args, false), ctor: collate.GetCollator(args[0].GetType().Collate)}, nil case tipb.ExprType_First: return &firstRowFunction{aggFunction: newAggFunc(ast.AggFuncFirstRow, args, false)}, nil case tipb.ExprType_Agg_BitOr: diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 1d5381f6c973d..30f020e7dfdf2 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -28,6 +28,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/collate" ) // AggFuncDesc describes an aggregation function signature, only used in planner. @@ -230,9 +231,9 @@ func (a *AggFuncDesc) GetAggFunc(ctx sessionctx.Context) Aggregation { } return &concatFunction{aggFunction: aggFunc, maxLen: maxLen} case ast.AggFuncMax: - return &maxMinFunction{aggFunction: aggFunc, isMax: true} + return &maxMinFunction{aggFunction: aggFunc, isMax: true, ctor: collate.GetCollator(a.Args[0].GetType().Collate)} case ast.AggFuncMin: - return &maxMinFunction{aggFunction: aggFunc, isMax: false} + return &maxMinFunction{aggFunction: aggFunc, isMax: false, ctor: collate.GetCollator(a.Args[0].GetType().Collate)} case ast.AggFuncFirstRow: return &firstRowFunction{aggFunction: aggFunc} case ast.AggFuncBitOr: diff --git a/expression/aggregation/max_min.go b/expression/aggregation/max_min.go index be25c7160a188..10f312d275023 100644 --- a/expression/aggregation/max_min.go +++ b/expression/aggregation/max_min.go @@ -18,11 +18,13 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/collate" ) type maxMinFunction struct { aggFunction isMax bool + ctor collate.Collator } // GetResult implements Aggregation interface. @@ -49,7 +51,7 @@ func (mmf *maxMinFunction) Update(evalCtx *AggEvaluateContext, sc *stmtctx.State return nil } var c int - c, err = evalCtx.Value.CompareDatum(sc, &value) + c, err = evalCtx.Value.Compare(sc, &value, mmf.ctor) if err != nil { return err } diff --git a/planner/core/handle_cols.go b/planner/core/handle_cols.go index c1bca6eec7ddf..48d6ab2444edd 100644 --- a/planner/core/handle_cols.go +++ b/planner/core/handle_cols.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/collate" ) // HandleCols is the interface that holds handle columns. @@ -48,7 +49,7 @@ type HandleCols interface { // NumCols returns the number of columns. NumCols() int // Compare compares two datum rows by handle order. - Compare(a, b []types.Datum) (int, error) + Compare(a, b []types.Datum, ctors []collate.Collator) (int, error) // GetFieldTypes return field types of columns. GetFieldsTypes() []*types.FieldType } @@ -145,11 +146,11 @@ func (cb *CommonHandleCols) String() string { } // Compare implements the kv.HandleCols interface. -func (cb *CommonHandleCols) Compare(a, b []types.Datum) (int, error) { - for _, col := range cb.columns { +func (cb *CommonHandleCols) Compare(a, b []types.Datum, ctors []collate.Collator) (int, error) { + for i, col := range cb.columns { aDatum := &a[col.Index] bDatum := &b[col.Index] - cmp, err := aDatum.CompareDatum(cb.sc, bDatum) + cmp, err := aDatum.Compare(cb.sc, bDatum, ctors[i]) if err != nil { return 0, err } @@ -237,7 +238,7 @@ func (ib *IntHandleCols) NumCols() int { } // Compare implements the kv.HandleCols interface. -func (ib *IntHandleCols) Compare(a, b []types.Datum) (int, error) { +func (ib *IntHandleCols) Compare(a, b []types.Datum, ctors []collate.Collator) (int, error) { aInt := a[ib.col.Index].GetInt64() bInt := b[ib.col.Index].GetInt64() if aInt == bInt { diff --git a/util/ranger/detacher.go b/util/ranger/detacher.go index 7b422d10243cc..d7ac4c3efbe85 100644 --- a/util/ranger/detacher.go +++ b/util/ranger/detacher.go @@ -448,7 +448,8 @@ func allSinglePoints(sc *stmtctx.StatementContext, points []*point) []*point { if !left.start || right.start || left.excl || right.excl { return nil } - cmp, err := left.value.CompareDatum(sc, &right.value) + // Since the point's collations are equal to the column's collation, we can use any of them. + cmp, err := left.value.Compare(sc, &right.value, collate.GetCollator(left.value.Collation())) if err != nil || cmp != 0 { return nil } diff --git a/util/ranger/points.go b/util/ranger/points.go index a02f77cc08909..d6f88268a2f17 100644 --- a/util/ranger/points.go +++ b/util/ranger/points.go @@ -467,7 +467,7 @@ func handleEnumFromBinOp(sc *stmtctx.StatementContext, ft *types.FieldType, val } d := types.NewCollateMysqlEnumDatum(tmpEnum, ft.Collate) - if v, err := d.CompareDatum(sc, &val); err == nil { + if v, err := d.Compare(sc, &val, collate.GetCollator(ft.Collate)); err == nil { switch op { case ast.LT: if v < 0 {