Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sessionctx: support encoding and decoding session variables #35531

Merged
merged 8 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import (
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/temptable"
"github.com/pingcap/tidb/util/logutil/consistency"
"github.com/pingcap/tidb/util/sem"
"github.com/pingcap/tidb/util/topsql"
topsqlstate "github.com/pingcap/tidb/util/topsql/state"
"github.com/pingcap/tidb/util/topsql/stmtstats"
Expand Down Expand Up @@ -3496,10 +3497,44 @@ func (s *session) GetStmtStats() *stmtstats.StatementStats {

// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface.
func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
return s.sessionVars.EncodeSessionStates(ctx, sessionStates)
if err = s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil {
return err
}

// Encode session variables. We put it here instead of SessionVars to avoid cycle import.
sessionStates.SystemVars = make(map[string]string)
for _, sv := range variable.GetSysVars() {
switch {
case sv.Hidden, sv.HasNoneScope(), sv.HasInstanceScope(), !sv.HasSessionScope():
// Hidden, none-scoped, and instance-scoped variables cannot be modified.
djshow832 marked this conversation as resolved.
Show resolved Hide resolved
// Noop variables should also be migrated even if they are noop.
continue
case sv.ReadOnly:
// Skip read-only variables here. We encode them into SessionStates manually.
continue
case sem.IsEnabled() && sem.IsInvisibleSysVar(sv.Name):
// If they are shown, there will be a security issue.
continue
}
// Get all session variables because the default values may change between versions.
if val, keep, err := variable.GetSessionStatesSystemVar(s.sessionVars, sv.Name); err == nil && keep {
sessionStates.SystemVars[sv.Name] = val
}
}
return
}

// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface.
func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
return s.sessionVars.DecodeSessionStates(ctx, sessionStates)
if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil {
return err
}

// Decode session variables.
for name, val := range sessionStates.SystemVars {
if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil {
return err
}
}
return err
}
1 change: 1 addition & 0 deletions sessionctx/sessionstates/session_states.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ import (
type SessionStates struct {
UserVars map[string]*types.Datum `json:"user-var-values,omitempty"`
UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"`
SystemVars map[string]string `json:"sys-vars,omitempty"`
}
131 changes: 128 additions & 3 deletions sessionctx/sessionstates/session_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ package sessionstates_test

import (
"fmt"
"strconv"
"strings"
"testing"

"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/util/sem"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -80,12 +83,134 @@ func TestUserVars(t *testing.T) {
}
}

func TestSystemVars(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()

tests := []struct {
stmts []string
varName string
inSessionStates bool
checkStmt string
expectedValue string
}{
{
// normal variable
inSessionStates: true,
varName: variable.TiDBMaxTiFlashThreads,
expectedValue: strconv.Itoa(variable.DefTiFlashMaxThreads),
},
{
// hidden variable
inSessionStates: false,
varName: variable.TiDBTxnReadTS,
},
{
// none-scoped variable
inSessionStates: false,
varName: variable.DataDir,
expectedValue: "/usr/local/mysql/data/",
},
{
// instance-scoped variable
inSessionStates: false,
varName: variable.TiDBGeneralLog,
expectedValue: "0",
},
{
// global-scoped variable
inSessionStates: false,
varName: variable.TiDBAutoAnalyzeStartTime,
expectedValue: variable.DefAutoAnalyzeStartTime,
},
{
// sem invisible variable
inSessionStates: false,
varName: variable.TiDBAllowRemoveAutoInc,
},
{
// noop variables
stmts: []string{"set sql_buffer_result=true"},
inSessionStates: true,
varName: "sql_buffer_result",
expectedValue: "1",
},
{
stmts: []string{"set transaction isolation level repeatable read"},
inSessionStates: true,
varName: "tx_isolation_one_shot",
expectedValue: "REPEATABLE-READ",
},
{
inSessionStates: false,
varName: variable.Timestamp,
},
{
stmts: []string{"set timestamp=100"},
inSessionStates: true,
varName: variable.Timestamp,
expectedValue: "100",
},
{
stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000"},
inSessionStates: true,
varName: variable.RandSeed1,
checkStmt: "select rand()",
expectedValue: "0.028870999839968048",
},
{
stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000", "select rand()"},
inSessionStates: true,
varName: variable.RandSeed1,
checkStmt: "select rand()",
expectedValue: "0.11641535266900002",
},
}

sem.Enable()
for _, tt := range tests {
tk1 := testkit.NewTestKit(t, store)
for _, stmt := range tt.stmts {
if strings.HasPrefix(stmt, "select") {
tk1.MustQuery(stmt)
} else {
tk1.MustExec(stmt)
}
}
tk2 := testkit.NewTestKit(t, store)
rows := tk1.MustQuery("show session_states").Rows()
state := rows[0][0].(string)
msg := fmt.Sprintf("var name: '%s', expected value: '%s'", tt.varName, tt.expectedValue)
require.Equal(t, tt.inSessionStates, strings.Contains(state, tt.varName), msg)
state = strconv.Quote(state)
setSQL := fmt.Sprintf("set session_states %s", state)
tk2.MustExec(setSQL)
if len(tt.expectedValue) > 0 {
checkStmt := tt.checkStmt
if len(checkStmt) == 0 {
checkStmt = fmt.Sprintf("select @@%s", tt.varName)
}
tk2.MustQuery(checkStmt).Check(testkit.Rows(tt.expectedValue))
}
}

{
// The session value should not change even if the global value changes.
tk1 := testkit.NewTestKit(t, store)
tk1.MustQuery("select @@autocommit").Check(testkit.Rows("1"))
tk2 := testkit.NewTestKit(t, store)
tk2.MustExec("set global autocommit=0")
tk3 := testkit.NewTestKit(t, store)
showSessionStatesAndSet(t, tk1, tk3)
tk3.MustQuery("select @@autocommit").Check(testkit.Rows("1"))
}
}

func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) {
rows := tk1.MustQuery("show session_states").Rows()
require.Len(t, rows, 1)
state := rows[0][0].(string)
state = strings.ReplaceAll(state, "\\", "\\\\")
state = strings.ReplaceAll(state, "'", "\\'")
setSQL := fmt.Sprintf("set session_states '%s'", state)
state = strconv.Quote(state)
setSQL := fmt.Sprintf("set session_states %s", state)
tk2.MustExec(setSQL)
}
16 changes: 16 additions & 0 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ var defaultSysVars = []*SysVar{
}
timestamp := s.StmtCtx.GetOrStoreStmtCache(stmtctx.StmtNowTsCacheKey, time.Now()).(time.Time)
return types.ToString(float64(timestamp.UnixNano()) / float64(time.Second))
}, GetStateValue: func(s *SessionVars) (string, bool, error) {
timestamp, ok := s.systems[Timestamp]
return timestamp, ok && timestamp != DefTimestamp, nil
}},
{Scope: ScopeSession, Name: WarningCount, Value: "0", ReadOnly: true, skipInit: true, GetSession: func(s *SessionVars) (string, error) {
return strconv.Itoa(s.SysWarningCount), nil
Expand All @@ -86,9 +89,13 @@ var defaultSysVars = []*SysVar{
}},
{Scope: ScopeSession, Name: LastInsertID, Value: "", skipInit: true, Type: TypeInt, AllowEmpty: true, MinValue: 0, MaxValue: math.MaxInt64, GetSession: func(s *SessionVars) (string, error) {
return strconv.FormatUint(s.StmtCtx.PrevLastInsertID, 10), nil
}, GetStateValue: func(s *SessionVars) (string, bool, error) {
return "", false, nil
}},
{Scope: ScopeSession, Name: Identity, Value: "", skipInit: true, Type: TypeInt, AllowEmpty: true, MinValue: 0, MaxValue: math.MaxInt64, GetSession: func(s *SessionVars) (string, error) {
return strconv.FormatUint(s.StmtCtx.PrevLastInsertID, 10), nil
}, GetStateValue: func(s *SessionVars) (string, bool, error) {
return "", false, nil
}},
/* TiDB specific variables */
// TODO: TiDBTxnScope is hidden because local txn feature is not done.
Expand Down Expand Up @@ -192,6 +199,11 @@ var defaultSysVars = []*SysVar{
s.txnIsolationLevelOneShot.state = oneShotSet
s.txnIsolationLevelOneShot.value = val
return nil
}, GetStateValue: func(s *SessionVars) (string, bool, error) {
if s.txnIsolationLevelOneShot.state != oneShotDef {
return s.txnIsolationLevelOneShot.value, true, nil
}
return "", false, nil
}},
{Scope: ScopeSession, Name: TiDBOptimizerSelectivityLevel, Value: strconv.Itoa(DefTiDBOptimizerSelectivityLevel), skipInit: true, Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error {
s.OptimizerSelectivityLevel = tidbOptPositiveInt32(val, DefTiDBOptimizerSelectivityLevel)
Expand Down Expand Up @@ -307,12 +319,16 @@ var defaultSysVars = []*SysVar{
return nil
}, GetSession: func(s *SessionVars) (string, error) {
return "0", nil
}, GetStateValue: func(s *SessionVars) (string, bool, error) {
return strconv.FormatUint(uint64(s.Rng.GetSeed1()), 10), true, nil
}},
{Scope: ScopeSession, Name: RandSeed2, Type: TypeInt, Value: "0", skipInit: true, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error {
s.Rng.SetSeed2(uint32(tidbOptPositiveInt32(val, 0)))
return nil
}, GetSession: func(s *SessionVars) (string, error) {
return "0", nil
}, GetStateValue: func(s *SessionVars) (string, bool, error) {
return strconv.FormatUint(uint64(s.Rng.GetSeed2()), 10), true, nil
}},
{Scope: ScopeSession, Name: TiDBReadConsistency, Value: string(ReadConsistencyStrict), Type: TypeStr, Hidden: true,
Validation: func(_ *SessionVars, normalized string, _ string, _ ScopeFlag) (string, error) {
Expand Down
3 changes: 3 additions & 0 deletions sessionctx/variable/variable.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ type SysVar struct {
GetSession func(*SessionVars) (string, error)
// GetGlobal is a getter function for global scope.
GetGlobal func(*SessionVars) (string, error)
// GetStateValue gets the value for session states, which is used for migrating sessions.
// We need a function to override GetSession sometimes, because GetSession may not return the real value.
GetStateValue func(*SessionVars) (string, bool, error)
bb7133 marked this conversation as resolved.
Show resolved Hide resolved
// skipInit defines if the sysvar should be loaded into the session on init.
// This is only important to set for sysvars that include session scope,
// since global scoped sysvars are not-applicable.
Expand Down
23 changes: 23 additions & 0 deletions sessionctx/variable/varsutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,29 @@ func GetSessionOrGlobalSystemVar(s *SessionVars, name string) (string, error) {
return sv.GetGlobalFromHook(s)
}

// GetSessionStatesSystemVar gets the session variable value for session states.
// It's only used for encoding session states when migrating a session.
// The returned boolean indicates whether to keep this value in the session states.
func GetSessionStatesSystemVar(s *SessionVars, name string) (string, bool, error) {
sv := GetSysVar(name)
if sv == nil {
return "", false, ErrUnknownSystemVar.GenWithStackByArgs(name)
}
// Call GetStateValue first if it exists. Otherwise, call GetSession.
if sv.GetStateValue != nil {
return sv.GetStateValue(s)
}
if sv.GetSession != nil {
val, err := sv.GetSessionFromHook(s)
return val, err == nil, err
}
// Only get the cached value. No need to check the global or default value.
if val, ok := s.systems[sv.Name]; ok {
return val, true, nil
}
return "", false, nil
xhebox marked this conversation as resolved.
Show resolved Hide resolved
}

// GetGlobalSystemVar gets a global system variable.
func GetGlobalSystemVar(s *SessionVars, name string) (string, error) {
sv := GetSysVar(name)
Expand Down
19 changes: 19 additions & 0 deletions sessionctx/variable/varsutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,22 @@ func TestStmtVars(t *testing.T) {
err = SetStmtVar(vars, MaxExecutionTime, "100")
require.NoError(t, err)
}

func TestSessionStatesSystemVar(t *testing.T) {
vars := NewSessionVars()
err := SetSessionSystemVar(vars, "autocommit", "1")
require.NoError(t, err)
val, keep, err := GetSessionStatesSystemVar(vars, "autocommit")
require.NoError(t, err)
require.Equal(t, "ON", val)
require.Equal(t, true, keep)
_, keep, err = GetSessionStatesSystemVar(vars, Timestamp)
require.NoError(t, err)
require.Equal(t, false, keep)
err = SetSessionSystemVar(vars, MaxAllowedPacket, "1024")
require.NoError(t, err)
val, keep, err = GetSessionStatesSystemVar(vars, MaxAllowedPacket)
require.NoError(t, err)
require.Equal(t, "1024", val)
require.Equal(t, true, keep)
}
14 changes: 14 additions & 0 deletions util/mathutil/rand.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,17 @@ func (rng *MysqlRng) SetSeed2(seed uint32) {
defer rng.mu.Unlock()
rng.seed2 = seed
}

// GetSeed1 is an interface to get seed1. It's only used for getting session states.
func (rng *MysqlRng) GetSeed1() uint32 {
rng.mu.Lock()
defer rng.mu.Unlock()
return rng.seed1
}

// GetSeed2 is an interface to get seed2. It's only used for getting session states.
func (rng *MysqlRng) GetSeed2() uint32 {
rng.mu.Lock()
defer rng.mu.Unlock()
return rng.seed2
}
2 changes: 2 additions & 0 deletions util/mathutil/rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,6 @@ func TestRandWithSeed1AndSeed2(t *testing.T) {
require.Equal(t, rng.Gen(), 0.028870999839968048)
require.Equal(t, rng.Gen(), 0.11641535266900002)
require.Equal(t, rng.Gen(), 0.49546379455874096)
require.Equal(t, rng.GetSeed1(), uint32(532000198))
require.Equal(t, rng.GetSeed2(), uint32(689000330))
}