Skip to content

Commit

Permalink
sessionctx, types, executor: support encoding and decoding user-defin…
Browse files Browse the repository at this point in the history
…ed variables (#35343)

close #35288
  • Loading branch information
djshow832 authored Jun 15, 2022
1 parent 0c9460c commit 4fc9551
Show file tree
Hide file tree
Showing 15 changed files with 410 additions and 10 deletions.
9 changes: 0 additions & 9 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6070,12 +6070,3 @@ func TestIsFastPlan(t *testing.T) {
require.Equal(t, ca.isFastPlan, ok)
}
}

func TestShowSessionStates(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustQuery("show session_states").Check(testkit.Rows())
tk.MustExec("set session_states 'x'")
tk.MustGetErrCode("set session_states 1", errno.ErrParse)
}
24 changes: 24 additions & 0 deletions executor/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/privilege/privileges"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/helper"
Expand Down Expand Up @@ -1930,6 +1931,29 @@ func (e *ShowExec) fetchShowBuiltins() error {
}

func (e *ShowExec) fetchShowSessionStates(ctx context.Context) error {
sessionStates := &sessionstates.SessionStates{}
err := e.ctx.EncodeSessionStates(ctx, e.ctx, sessionStates)
if err != nil {
return err
}
stateBytes, err := gjson.Marshal(sessionStates)
if err != nil {
return errors.Trace(err)
}
stateJSON := json.BinaryJSON{}
if err = stateJSON.UnmarshalJSON(stateBytes); err != nil {
return err
}
// This will be implemented in future PRs.
tokenBytes, err := gjson.Marshal("")
if err != nil {
return errors.Trace(err)
}
tokenJSON := json.BinaryJSON{}
if err = tokenJSON.UnmarshalJSON(tokenBytes); err != nil {
return err
}
e.appendRow([]interface{}{stateJSON, tokenJSON})
return nil
}

Expand Down
11 changes: 10 additions & 1 deletion executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package executor

import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"strings"
Expand All @@ -41,6 +43,7 @@ import (
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -1687,7 +1690,13 @@ func asyncDelayShutdown(p *os.Process, delay time.Duration) {
}

func (e *SimpleExec) executeSetSessionStates(ctx context.Context, s *ast.SetSessionStatesStmt) error {
return nil
var sessionStates sessionstates.SessionStates
decoder := json.NewDecoder(bytes.NewReader([]byte(s.SessionStates)))
decoder.UseNumber()
if err := decoder.Decode(&sessionStates); err != nil {
return errors.Trace(err)
}
return e.ctx.DecodeSessionStates(ctx, e.ctx, &sessionStates)
}

func (e *SimpleExec) executeAdmin(s *ast.AdminStmt) error {
Expand Down
11 changes: 11 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/sessiontxn/legacy"
"github.com/pingcap/tidb/sessiontxn/staleread"
Expand Down Expand Up @@ -3500,3 +3501,13 @@ func (s *session) getSnapshotInterceptor() kv.SnapshotInterceptor {
func (s *session) GetStmtStats() *stmtstats.StatementStats {
return s.stmtStats
}

// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface.
func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
return s.sessionVars.EncodeSessionStates(ctx, sessionStates)
}

// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface.
func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
return s.sessionVars.DecodeSessionStates(ctx, sessionStates)
}
10 changes: 10 additions & 0 deletions sessionctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/kvcache"
Expand All @@ -41,8 +42,17 @@ type InfoschemaMetaVersion interface {
SchemaMetaVersion() int64
}

// SessionStatesHandler is an interface for encoding and decoding session states.
type SessionStatesHandler interface {
// EncodeSessionStates encodes session states into a JSON.
EncodeSessionStates(context.Context, Context, *sessionstates.SessionStates) error
// DecodeSessionStates decodes a map into session states.
DecodeSessionStates(context.Context, Context, *sessionstates.SessionStates) error
}

// Context is an interface for transaction and executive args environment.
type Context interface {
SessionStatesHandler
// NewTxn creates a new transaction for further execution.
// If old transaction is valid, it is committed first.
// It's used in BEGIN statement and DDL statements to commit old transaction.
Expand Down
27 changes: 27 additions & 0 deletions sessionctx/sessionstates/session_states.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sessionstates

import (
ptypes "github.com/pingcap/tidb/parser/types"
"github.com/pingcap/tidb/types"
)

// SessionStates contains all the states in the session that should be migrated when the session
// is migrated to another server. It is shown by `show session_states` and recovered by `set session_states`.
type SessionStates struct {
UserVars map[string]*types.Datum `json:"user-var-values,omitempty"`
UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"`
}
91 changes: 91 additions & 0 deletions sessionctx/sessionstates/session_states_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sessionstates_test

import (
"fmt"
"strings"
"testing"

"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
)

func TestGrammar(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
rows := tk.MustQuery("show session_states").Rows()
require.Len(t, rows, 1)
tk.MustExec("set session_states '{}'")
tk.MustGetErrCode("set session_states 1", errno.ErrParse)
}

func TestUserVars(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("create table test.t1(" +
"j json, b blob, s varchar(255), st set('red', 'green', 'blue'), en enum('red', 'green', 'blue'))")
tk.MustExec("insert into test.t1 values('{\"color:\": \"red\"}', 'red', 'red', 'red,green', 'red')")

tests := []string{
"",
"set @%s=null",
"set @%s=1",
"set @%s=1.0e10",
"set @%s=1.0-1",
"set @%s=now()",
"set @%s=1, @%s=1.0-1",
"select @%s:=1+1",
// TiDB doesn't support following features.
//"select j into @%s from test.t1",
//"select j,b,s,st,en into @%s,@%s,@%s,@%s,@%s from test.t1",
}

for _, tt := range tests {
tk1 := testkit.NewTestKit(t, store)
tk2 := testkit.NewTestKit(t, store)
namesNum := strings.Count(tt, "%s")
names := make([]any, 0, namesNum)
for i := 0; i < namesNum; i++ {
names = append(names, fmt.Sprintf("a%d", i))
}
var sql string
if len(tt) > 0 {
sql = fmt.Sprintf(tt, names...)
tk1.MustExec(sql)
}
showSessionStatesAndSet(t, tk1, tk2)
for _, name := range names {
sql := fmt.Sprintf("select @%s", name)
msg := fmt.Sprintf("sql: %s, var name: %s", sql, name)
value1 := tk1.MustQuery(sql).Rows()[0][0]
value2 := tk2.MustQuery(sql).Rows()[0][0]
require.Equal(t, value1, value2, msg)
}
}
}

func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) {
rows := tk1.MustQuery("show session_states").Rows()
require.Len(t, rows, 1)
state := rows[0][0].(string)
state = strings.ReplaceAll(state, "\\", "\\\\")
state = strings.ReplaceAll(state, "'", "\\'")
setSQL := fmt.Sprintf("set session_states '%s'", state)
tk2.MustExec(setSQL)
}
35 changes: 35 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package variable

import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"fmt"
Expand All @@ -41,6 +42,8 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
ptypes "github.com/pingcap/tidb/parser/types"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/stmtctx"
pumpcli "github.com/pingcap/tidb/tidb-binlog/pump_client"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -1834,6 +1837,38 @@ func (s *SessionVars) GetTemporaryTable(tblInfo *model.TableInfo) tableutil.Temp
return nil
}

// EncodeSessionStates saves session states into SessionStates.
func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *sessionstates.SessionStates) (err error) {
// Encode user-defined variables.
s.UsersLock.RLock()
sessionStates.UserVars = make(map[string]*types.Datum, len(s.Users))
for name, userVar := range s.Users {
sessionStates.UserVars[name] = userVar.Clone()
}
sessionStates.UserVarTypes = make(map[string]*ptypes.FieldType, len(s.UserVarTypes))
for name, userVarType := range s.UserVarTypes {
sessionStates.UserVarTypes[name] = userVarType.Clone()
}
s.UsersLock.RUnlock()
return
}

// DecodeSessionStates restores session states from SessionStates.
func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *sessionstates.SessionStates) (err error) {
// Decode user-defined variables.
s.UsersLock.Lock()
s.Users = make(map[string]types.Datum, len(sessionStates.UserVars))
for name, userVar := range sessionStates.UserVars {
s.Users[name] = *userVar.Clone()
}
s.UserVarTypes = make(map[string]*ptypes.FieldType, len(sessionStates.UserVarTypes))
for name, userVarType := range sessionStates.UserVarTypes {
s.UserVarTypes[name] = userVarType.Clone()
}
s.UsersLock.Unlock()
return
}

// TableDelta stands for the changed count for one table or partition.
type TableDelta struct {
Delta int64
Expand Down
57 changes: 57 additions & 0 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package types

import (
gjson "encoding/json"
"fmt"
"math"
"sort"
Expand Down Expand Up @@ -2008,6 +2009,62 @@ func (d *Datum) MemUsage() (sum int64) {
return EmptyDatumSize + int64(cap(d.b)) + int64(len(d.collation))
}

type jsonDatum struct {
K byte `json:"k"`
Decimal uint16 `json:"decimal,omitempty"`
Length uint32 `json:"length,omitempty"`
I int64 `json:"i,omitempty"`
Collation string `json:"collation,omitempty"`
B []byte `json:"b,omitempty"`
Time Time `json:"time,omitempty"`
MyDecimal *MyDecimal `json:"mydecimal,omitempty"`
}

// MarshalJSON implements Marshaler.MarshalJSON interface.
func (d *Datum) MarshalJSON() ([]byte, error) {
jd := &jsonDatum{
K: d.k,
Decimal: d.decimal,
Length: d.length,
I: d.i,
Collation: d.collation,
B: d.b,
}
switch d.k {
case KindMysqlTime:
jd.Time = d.GetMysqlTime()
case KindMysqlDecimal:
jd.MyDecimal = d.GetMysqlDecimal()
default:
if d.x != nil {
return nil, errors.New(fmt.Sprintf("unsupported type: %d", d.k))
}
}
return gjson.Marshal(jd)
}

// UnmarshalJSON implements Unmarshaler.UnmarshalJSON interface.
func (d *Datum) UnmarshalJSON(data []byte) error {
var jd jsonDatum
if err := gjson.Unmarshal(data, &jd); err != nil {
return err
}
d.k = jd.K
d.decimal = jd.Decimal
d.length = jd.Length
d.i = jd.I
d.collation = jd.Collation
d.b = jd.B

switch jd.K {
case KindMysqlTime:
d.SetMysqlTime(jd.Time)
case KindMysqlDecimal:
d.SetMysqlDecimal(jd.MyDecimal)
}
return nil
}

func invalidConv(d *Datum, tp byte) (Datum, error) {
return Datum{}, errors.Errorf("cannot convert datum from %s to type %s", KindStr(d.Kind()), TypeStr(tp))
}
Expand Down
Loading

0 comments on commit 4fc9551

Please sign in to comment.