Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Tcp keep alive and provide keep alive period setting #13434

Merged
merged 7 commits into from
Jul 11, 2023
1 change: 1 addition & 0 deletions go/flags/endtoend/vtgate.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Usage of vtgate:
--max_payload_size int The threshold for query payloads in bytes. A payload greater than this threshold will result in a failure to handle the query.
--message_stream_grace_period duration the amount of time to give for a vttablet to resume if it ends a message stream, usually because of a reparent. (default 30s)
--min_number_serving_vttablets int The minimum number of vttablets for each replicating tablet_type (e.g. replica, rdonly) that will be continue to be used even with replication lag above discovery_low_replication_lag, but still below discovery_high_replication_lag_minimum_serving. (default 2)
--mysql-server-keepalive-period duration TCP period between keep-alives
--mysql-server-pool-conn-read-buffers If set, the server will pool incoming connection read buffers
--mysql_allow_clear_text_without_tls If set, the server will allow the use of a clear text password over non-SSL connections.
--mysql_auth_server_impl string Which auth server implementation to use. Options: none, ldap, clientcert, static, vault. (default "static")
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestValidCert(t *testing.T) {
authServer := newAuthServerClientCert()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -114,7 +114,7 @@ func TestNoCert(t *testing.T) {
authServer := newAuthServerClientCert()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -221,7 +221,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -294,7 +294,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -341,7 +341,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -424,7 +424,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down
6 changes: 3 additions & 3 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func New(t testing.TB) *DB {
authServer := mysql.NewAuthServerNone()

// Start listening.
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false)
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -382,7 +382,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
if db.shouldClose.Load() {
c.Close()

//log error
// log error
if err := callback(&sqltypes.Result{}); err != nil {
log.Errorf("callback failed : %v", err)
}
Expand All @@ -393,7 +393,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
// The driver may send this at connection time, and we don't want it to
// interfere.
if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") {
//log error
// log error
if err := callback(&sqltypes.Result{}); err != nil {
log.Errorf("callback failed : %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestClearTextClientAuth(t *testing.T) {
defer authServer.close()

// Create the listener.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestSSLConnection(t *testing.T) {
defer authServer.close()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
89 changes: 62 additions & 27 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ type Listener struct {
// connBufferPooling configures if vtgate server pools connection buffers
connBufferPooling bool

// connKeepAlivePeriod is period between tcp keep-alives.
connKeepAlivePeriod time.Duration

// shutdown indicates that Shutdown method was called.
shutdown atomic.Bool

Expand All @@ -208,6 +211,8 @@ type Listener struct {
// handled further by the MySQL handler. An non-nil error will stop
// processing the connection by the MySQL handler.
PreHandleFunc func(context.Context, net.Conn, uint32) (net.Conn, error)

TcpPropFunc func(*net.TCPConn, time.Duration) error
}

// NewFromListener creates a new mysql listener from an existing net.Listener
Expand All @@ -218,15 +223,17 @@ func NewFromListener(
connReadTimeout time.Duration,
connWriteTimeout time.Duration,
connBufferPooling bool,
keepAlivePeriod time.Duration,
) (*Listener, error) {
cfg := ListenerConfig{
Listener: l,
AuthServer: authServer,
Handler: handler,
ConnReadTimeout: connReadTimeout,
ConnWriteTimeout: connWriteTimeout,
ConnReadBufferSize: connBufferSize,
ConnBufferPooling: connBufferPooling,
Listener: l,
AuthServer: authServer,
Handler: handler,
ConnReadTimeout: connReadTimeout,
ConnWriteTimeout: connWriteTimeout,
ConnReadBufferSize: connBufferSize,
ConnBufferPooling: connBufferPooling,
ConnKeepAlivePeriod: keepAlivePeriod,
}
return NewListenerWithConfig(cfg)
}
Expand All @@ -240,31 +247,33 @@ func NewListener(
connWriteTimeout time.Duration,
proxyProtocol bool,
connBufferPooling bool,
keepAlivePeriod time.Duration,
) (*Listener, error) {
listener, err := net.Listen(protocol, address)
if err != nil {
return nil, err
}
if proxyProtocol {
proxyListener := &proxyproto.Listener{Listener: listener}
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod)
}

return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod)
}

// ListenerConfig should be used with NewListenerWithConfig to specify listener parameters.
type ListenerConfig struct {
// Protocol-Address pair and Listener are mutually exclusive parameters
Protocol string
Address string
Listener net.Listener
AuthServer AuthServer
Handler Handler
ConnReadTimeout time.Duration
ConnWriteTimeout time.Duration
ConnReadBufferSize int
ConnBufferPooling bool
Protocol string
Address string
Listener net.Listener
AuthServer AuthServer
Handler Handler
ConnReadTimeout time.Duration
ConnWriteTimeout time.Duration
ConnReadBufferSize int
ConnBufferPooling bool
ConnKeepAlivePeriod time.Duration
}

// NewListenerWithConfig creates new listener using provided config. There are
Expand All @@ -282,15 +291,17 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) {
}

return &Listener{
authServer: cfg.AuthServer,
handler: cfg.Handler,
listener: l,
ServerVersion: servenv.AppVersion.MySQLVersion(),
connectionID: 1,
connReadTimeout: cfg.ConnReadTimeout,
connWriteTimeout: cfg.ConnWriteTimeout,
connReadBufferSize: cfg.ConnReadBufferSize,
connBufferPooling: cfg.ConnBufferPooling,
authServer: cfg.AuthServer,
handler: cfg.Handler,
listener: l,
ServerVersion: servenv.AppVersion.MySQLVersion(),
connectionID: 1,
connReadTimeout: cfg.ConnReadTimeout,
connWriteTimeout: cfg.ConnWriteTimeout,
connReadBufferSize: cfg.ConnReadBufferSize,
connBufferPooling: cfg.ConnBufferPooling,
connKeepAlivePeriod: cfg.ConnKeepAlivePeriod,
TcpPropFunc: setTcpConnProperties,
}, nil
}

Expand Down Expand Up @@ -336,6 +347,14 @@ func (l *Listener) Accept() {
// handle is called in a go routine for each client connection.
// FIXME(alainjobart) handle per-connection logs in a way that makes sense.
func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Time) {

// Enable KeepAlive on TCP connections and change keep-alive period if provided.
if tcpConn, ok := conn.(*net.TCPConn); ok {
if err := l.TcpPropFunc(tcpConn, l.connKeepAlivePeriod); err != nil {
log.Errorf("error in setting tcp properties: %v", err)
}
}

if l.connReadTimeout != 0 || l.connWriteTimeout != 0 {
conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout)
}
Expand Down Expand Up @@ -531,6 +550,22 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}
}

func setTcpConnProperties(conn *net.TCPConn, keepAlivePeriod time.Duration) error {
if err := conn.SetKeepAlive(true); err != nil {
return vterrors.Wrapf(err, "unable to enable keepalive on tcp connection")
}

if keepAlivePeriod <= 0 {
return nil
}

if err := conn.SetKeepAlivePeriod(keepAlivePeriod); err != nil {
return vterrors.Wrapf(err, "unable to set keepalive period on tcp connection")
}

return nil
}

// Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed.
func (l *Listener) Close() {
l.listener.Close()
Expand Down
Loading