Skip to content

Commit

Permalink
support system variable wait_timeout. (#8245) (#8346)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhu-cc authored and jackysp committed Nov 28, 2018
1 parent e5dc251 commit bb7bb14
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 7 deletions.
29 changes: 25 additions & 4 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ func (cc *clientConn) writePacket(data []byte) error {
return cc.pkt.writePacket(data)
}

// getSessionVarsWaitTimeout get session variable wait_timeout
func (cc *clientConn) getSessionVarsWaitTimeout() uint64 {
valStr, _ := cc.ctx.GetSessionVars().GetSystemVar(variable.WaitTimeout)
waitTimeout, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
log.Errorf("con:%d get sysval wait_timeout error, use default value.", cc.connectionID)
// if get waitTimeout error, use default value
waitTimeout = variable.DefWaitTimeout
}
return waitTimeout
}

type handshakeResponse41 struct {
Capability uint32
Collation uint8
Expand Down Expand Up @@ -449,13 +461,22 @@ func (cc *clientConn) Run() {
}

cc.alloc.Reset()
// close connection when idle time is more than wait_timout
waitTimeout := cc.getSessionVarsWaitTimeout()
cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second)
start := time.Now()
data, err := cc.readPacket()
if err != nil {
if terror.ErrorNotEqual(err, io.EOF) {
errStack := errors.ErrorStack(err)
if !strings.Contains(errStack, "use of closed network connection") {
log.Errorf("con:%d read packet error, close this connection %s",
cc.connectionID, errStack)
if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() {
idleTime := time.Now().Sub(start)
log.Infof("con:%d read packet timeout, close this connection, idle: %v, wait_timeout: %v", cc.connectionID, idleTime, waitTimeout)
} else {
errStack := errors.ErrorStack(err)
if !strings.Contains(errStack, "use of closed network connection") {
log.Errorf("con:%d read packet error, close this connection %s",
cc.connectionID, errStack)
}
}
}
return
Expand Down
32 changes: 31 additions & 1 deletion server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,16 @@ import (

. "github.com/pingcap/check"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/store/mockstore"
)

type ConnTestSuite struct{}
type ConnTestSuite struct {
dom *domain.Domain
store kv.Storage
}

var _ = Suite(ConnTestSuite{})

Expand Down Expand Up @@ -149,6 +156,29 @@ func (ts ConnTestSuite) TestInitialHandshake(c *C) {
c.Assert(outBuffer.Bytes()[4:], DeepEquals, expected.Bytes())
}

func (ts ConnTestSuite) testGetSessionVarsWaitTimeout(c *C) {
c.Parallel()
var err error
ts.store, err = mockstore.NewMockTikvStore()
c.Assert(err, IsNil)
ts.dom, err = session.BootstrapSession(ts.store)
c.Assert(err, IsNil)
se, err := session.CreateSession4Test(ts.store)
c.Assert(err, IsNil)
tc := &TiDBContext{
session: se,
stmts: make(map[int]*TiDBStatement),
}
cc := &clientConn{
connectionID: 1,
server: &Server{
capability: defaultCapability,
},
ctx: tc,
}
c.Assert(cc.getSessionVarsWaitTimeout(), Equals, 28800)
}

func mapIdentical(m1, m2 map[string]string) bool {
return mapBelong(m1, m2) && mapBelong(m2, m1)
}
Expand Down
17 changes: 16 additions & 1 deletion server/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ package server
import (
"bufio"
"io"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
Expand All @@ -50,6 +51,7 @@ type packetIO struct {
bufReadConn *bufferedReadConn
bufWriter *bufio.Writer
sequence uint8
readTimeout time.Duration
}

func newPacketIO(bufReadConn *bufferedReadConn) *packetIO {
Expand All @@ -63,9 +65,17 @@ func (p *packetIO) setBufferedReadConn(bufReadConn *bufferedReadConn) {
p.bufWriter = bufio.NewWriterSize(bufReadConn, defaultWriterSize)
}

func (p *packetIO) setReadTimeout(timeout time.Duration) {
p.readTimeout = timeout
}

func (p *packetIO) readOnePacket() ([]byte, error) {
var header [4]byte

if p.readTimeout > 0 {
if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil {
return nil, err
}
}
if _, err := io.ReadFull(p.bufReadConn, header[:]); err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -80,6 +90,11 @@ func (p *packetIO) readOnePacket() ([]byte, error) {
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

data := make([]byte, length)
if p.readTimeout > 0 {
if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil {
return nil, err
}
}
if _, err := io.ReadFull(p.bufReadConn, data); err != nil {
return nil, errors.Trace(err)
}
Expand Down
1 change: 1 addition & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,7 @@ const loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variab
variable.MaxAllowedPacket + quoteCommaQuote +
variable.TimeZone + quoteCommaQuote +
variable.BlockEncryptionMode + quoteCommaQuote +
variable.WaitTimeout + quoteCommaQuote +
/* TiDB specific global variables: */
variable.TiDBSkipUTF8Check + quoteCommaQuote +
variable.TiDBIndexJoinBatchSize + quoteCommaQuote +
Expand Down
4 changes: 3 additions & 1 deletion sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ var defaultSysVars = []*SysVar{
{ScopeGlobal, "innodb_buffer_pool_size", "134217728"},
{ScopeGlobal, "innodb_adaptive_flushing", "ON"},
{ScopeNone, "datadir", "/usr/local/mysql/data/"},
{ScopeGlobal | ScopeSession, "wait_timeout", "28800"},
{ScopeGlobal | ScopeSession, WaitTimeout, strconv.FormatInt(DefWaitTimeout, 10)},
{ScopeGlobal, "innodb_monitor_enable", ""},
{ScopeNone, "date_format", "%Y-%m-%d"},
{ScopeGlobal, "innodb_buffer_pool_filename", "ib_buffer_pool"},
Expand Down Expand Up @@ -774,6 +774,8 @@ const (
SyncBinlog = "sync_binlog"
// BlockEncryptionMode is the name for 'block_encryption_mode' system variable.
BlockEncryptionMode = "block_encryption_mode"
// WaitTimeout is the name for 'wait_timeout' system variable.
WaitTimeout = "wait_timeout"
// ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable.
ValidatePasswordNumberCount = "validate_password_number_count"
// ValidatePasswordLength is the name of 'validate_password_length' system variable.
Expand Down
1 change: 1 addition & 0 deletions sessionctx/variable/tidb_vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ const (
DefCurretTS = 0
DefMaxChunkSize = 32
DefDMLBatchSize = 20000
DefWaitTimeout = 28800
DefTiDBMemQuotaHashJoin = 32 << 30 // 32GB.
DefTiDBMemQuotaMergeJoin = 32 << 30 // 32GB.
DefTiDBMemQuotaSort = 32 << 30 // 32GB.
Expand Down
2 changes: 2 additions & 0 deletions sessionctx/variable/varsutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string,
return checkUInt64SystemVar(name, value, 400, 524288, vars)
case TmpTableSize:
return checkUInt64SystemVar(name, value, 1024, math.MaxUint64, vars)
case WaitTimeout:
return checkUInt64SystemVar(name, value, 1, 31536000, vars)
case TimeZone:
if strings.EqualFold(value, "SYSTEM") {
return "SYSTEM", nil
Expand Down

0 comments on commit bb7bb14

Please sign in to comment.