Skip to content

Commit

Permalink
*: support select/ explain select using bind info (pingcap#10284)
Browse files Browse the repository at this point in the history
  • Loading branch information
iamzhoug37 authored and XuHuaiyu committed Apr 29, 2019
1 parent 69b02a3 commit cd10bca
Show file tree
Hide file tree
Showing 9 changed files with 538 additions and 16 deletions.
169 changes: 169 additions & 0 deletions bindinfo/bind.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package bindinfo

import "github.com/pingcap/parser/ast"

// BindHint will add hints for originStmt according to hintedStmt' hints.
func BindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode {
switch x := originStmt.(type) {
case *ast.SelectStmt:
return selectBind(x, hintedStmt.(*ast.SelectStmt))
default:
return originStmt
}
}

func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt {
if hintedNode.TableHints != nil {
originalNode.TableHints = hintedNode.TableHints
}
if originalNode.From != nil {
originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join)
}
if originalNode.Where != nil {
originalNode.Where = exprBind(originalNode.Where, hintedNode.Where).(ast.ExprNode)
}

if originalNode.Having != nil {
originalNode.Having.Expr = exprBind(originalNode.Having.Expr, hintedNode.Having.Expr)
}

if originalNode.OrderBy != nil {
originalNode.OrderBy = orderByBind(originalNode.OrderBy, hintedNode.OrderBy)
}

if originalNode.Fields != nil {
origFields := originalNode.Fields.Fields
hintFields := hintedNode.Fields.Fields
for idx := range origFields {
origFields[idx].Expr = exprBind(origFields[idx].Expr, hintFields[idx].Expr)
}
}
return originalNode
}

func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause {
for idx := 0; idx < len(originalNode.Items); idx++ {
originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedNode.Items[idx].Expr)
}
return originalNode
}

func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode {
switch v := originalNode.(type) {
case *ast.SubqueryExpr:
if v.Query != nil {
v.Query = resultSetNodeBind(v.Query, hintedNode.(*ast.SubqueryExpr).Query)
}
case *ast.ExistsSubqueryExpr:
if v.Sel != nil {
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query)
}
case *ast.PatternInExpr:
if v.Sel != nil {
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query)
}
case *ast.BinaryOperationExpr:
if v.L != nil {
v.L = exprBind(v.L, hintedNode.(*ast.BinaryOperationExpr).L)
}
if v.R != nil {
v.R = exprBind(v.R, hintedNode.(*ast.BinaryOperationExpr).R)
}
case *ast.IsNullExpr:
if v.Expr != nil {
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsNullExpr).Expr)
}
case *ast.IsTruthExpr:
if v.Expr != nil {
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsTruthExpr).Expr)
}
case *ast.PatternLikeExpr:
if v.Pattern != nil {
v.Pattern = exprBind(v.Pattern, hintedNode.(*ast.PatternLikeExpr).Pattern)
}
case *ast.CompareSubqueryExpr:
if v.L != nil {
v.L = exprBind(v.L, hintedNode.(*ast.CompareSubqueryExpr).L)
}
if v.R != nil {
v.R = exprBind(v.R, hintedNode.(*ast.CompareSubqueryExpr).R)
}
case *ast.BetweenExpr:
if v.Left != nil {
v.Left = exprBind(v.Left, hintedNode.(*ast.BetweenExpr).Left)
}
if v.Right != nil {
v.Right = exprBind(v.Right, hintedNode.(*ast.BetweenExpr).Right)
}
case *ast.UnaryOperationExpr:
if v.V != nil {
v.V = exprBind(v.V, hintedNode.(*ast.UnaryOperationExpr).V)
}
case *ast.CaseExpr:
if v.Value != nil {
v.Value = exprBind(v.Value, hintedNode.(*ast.CaseExpr).Value)
}
if v.ElseClause != nil {
v.ElseClause = exprBind(v.ElseClause, hintedNode.(*ast.CaseExpr).ElseClause)
}
}
return originalNode
}

func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode {
switch x := originalNode.(type) {
case *ast.Join:
return joinBind(x, hintedNode.(*ast.Join))
case *ast.TableSource:
ts, _ := hintedNode.(*ast.TableSource)
switch v := x.Source.(type) {
case *ast.SelectStmt:
x.Source = selectBind(v, ts.Source.(*ast.SelectStmt))
case *ast.UnionStmt:
x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt))
case *ast.TableName:
x.Source.(*ast.TableName).IndexHints = ts.Source.(*ast.TableName).IndexHints
}
return x
case *ast.SelectStmt:
return selectBind(x, hintedNode.(*ast.SelectStmt))
case *ast.UnionStmt:
return unionSelectBind(x, hintedNode.(*ast.UnionStmt))
default:
return x
}
}

func joinBind(originalNode, hintedNode *ast.Join) *ast.Join {
if originalNode.Left != nil {
originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left)
}

if hintedNode.Right != nil {
originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right)
}

return originalNode
}

func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode {
selects := originalNode.SelectList.Selects
for i := len(selects) - 1; i >= 0; i-- {
originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i])
}

return originalNode
}
144 changes: 144 additions & 0 deletions bindinfo/bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,147 @@ func (s *testSuite) TestSessionBinding(c *C) {
c.Check(bindData.OriginalSQL, Equals, "select * from t where i > ?")
c.Check(bindData.Status, Equals, "deleted")
}

func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) {
tk := testkit.NewTestKit(c, s.store)
s.cleanBindingEnv(tk)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec("drop table if exists t2")
tk.MustExec("create table t1(id int)")
tk.MustExec("create table t2(id int)")

tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows(
"HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]",
"├─TableReader_12 9990.00 root data:Selection_11",
"│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))",
"│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
"└─TableReader_15 9990.00 root data:Selection_14",
" └─Selection_14 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows(
"MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id",
"├─Sort_11 9990.00 root test.t1.id:asc",
"│ └─TableReader_10 9990.00 root data:Selection_9",
"│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))",
"│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
"└─Sort_15 9990.00 root test.t2.id:asc",
" └─TableReader_14 9990.00 root data:Selection_13",
" └─Selection_13 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")

tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows(
"MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id",
"├─Sort_11 9990.00 root test.t1.id:asc",
"│ └─TableReader_10 9990.00 root data:Selection_9",
"│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))",
"│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
"└─Sort_15 9990.00 root test.t2.id:asc",
" └─TableReader_14 9990.00 root data:Selection_13",
" └─Selection_13 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id")

tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows(
"HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]",
"├─TableReader_12 9990.00 root data:Selection_11",
"│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))",
"│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
"└─TableReader_15 9990.00 root data:Selection_14",
" └─Selection_14 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))
}

func (s *testSuite) TestExplain(c *C) {
tk := testkit.NewTestKit(c, s.store)
s.cleanBindingEnv(tk)
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec("drop table if exists t2")
tk.MustExec("create table t1(id int)")
tk.MustExec("create table t2(id int)")

tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows(
"HashLeftJoin_8 12487.50 root inner join, inner:TableReader_15, equal:[eq(test.t1.id, test.t2.id)]",
"├─TableReader_12 9990.00 root data:Selection_11",
"│ └─Selection_11 9990.00 cop not(isnull(test.t1.id))",
"│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
"└─TableReader_15 9990.00 root data:Selection_14",
" └─Selection_14 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustQuery("explain SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id").Check(testkit.Rows(
"MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id",
"├─Sort_11 9990.00 root test.t1.id:asc",
"│ └─TableReader_10 9990.00 root data:Selection_9",
"│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))",
"│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
"└─Sort_15 9990.00 root test.t2.id:asc",
" └─TableReader_14 9990.00 root data:Selection_13",
" └─Selection_13 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")

tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows(
"MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id",
"├─Sort_11 9990.00 root test.t1.id:asc",
"│ └─TableReader_10 9990.00 root data:Selection_9",
"│ └─Selection_9 9990.00 cop not(isnull(test.t1.id))",
"│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo",
"└─Sort_15 9990.00 root test.t2.id:asc",
" └─TableReader_14 9990.00 root data:Selection_13",
" └─Selection_13 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id")
}

func (s *testSuite) TestErrorBind(c *C) {
tk := testkit.NewTestKit(c, s.store)
s.cleanBindingEnv(tk)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t(i int, s varchar(20))")
tk.MustExec("create table t1(i int, s varchar(20))")
tk.MustExec("create index index_t on t(i,s)")

_, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100")
c.Assert(err, IsNil, Commentf("err %v", err))

bindData := s.domain.BindHandle().GetBindRecord("select * from t where i > ?", "test")
c.Check(bindData, NotNil)
c.Check(bindData.OriginalSQL, Equals, "select * from t where i > ?")
c.Check(bindData.BindSQL, Equals, "select * from t use index(index_t) where i>100")
c.Check(bindData.Db, Equals, "test")
c.Check(bindData.Status, Equals, "using")
c.Check(bindData.Charset, NotNil)
c.Check(bindData.Collation, NotNil)
c.Check(bindData.CreateTime, NotNil)
c.Check(bindData.UpdateTime, NotNil)

tk.MustExec("drop index index_t on t")
_, err = tk.Exec("select * from t where i > 10")
c.Check(err, IsNil)

s.domain.BindHandle().DropInvalidBindRecord()

rs, err := tk.Exec("show global bindings")
c.Assert(err, IsNil)
chk := rs.NewRecordBatch()
err = rs.Next(context.TODO(), chk)
c.Check(err, IsNil)
c.Check(chk.NumRows(), Equals, 0)
}
8 changes: 5 additions & 3 deletions bindinfo/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ import (
)

const (
// using is the bind info's in use status.
using = "using"
// Using is the bind info's in use status.
Using = "using"
// deleted is the bind info's deleted status.
deleted = "deleted"
// Invalid is the bind info's invalid status.
Invalid = "invalid"
)

// BindMeta stores the basic bind info and bindSql astNode.
type BindMeta struct {
*BindRecord
ast ast.StmtNode //ast will be used to do query sql bind check
Ast ast.StmtNode //ast will be used to do query sql bind check
}

// cache is a k-v map, key is original sql, value is a slice of BindMeta.
Expand Down
Loading

0 comments on commit cd10bca

Please sign in to comment.