Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server, tidb-server: improve unix socket handling #8836

Merged
merged 17 commits into from
Jan 9, 2019
Merged
61 changes: 55 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
Expand Down Expand Up @@ -80,6 +81,7 @@ type Server struct {
tlsConfig *tls.Config
driver IDriver
listener net.Listener
socket net.Listener
rwlock *sync.RWMutex
concurrentLimiter *TokenLimiter
clients map[uint32]*clientConn
Expand Down Expand Up @@ -133,6 +135,39 @@ func (s *Server) isUnixSocket() bool {
return s.cfg.Socket != ""
}

func (s *Server) forwardUnixSocketToTCP() {
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
for {
if s.listener == nil {
return // server shutdown has started
}
if uconn, err := s.socket.Accept(); err == nil {
log.Infof("server socket forwarding from [%s] to [%s]", s.cfg.Socket, addr)
go s.handleForwardedConnection(uconn, addr)
} else {
if s.listener != nil {
log.Errorf("server failed to forward from [%s] to [%s], err: %s", s.cfg.Socket, addr, err)
}
}
}
}

func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than logging warnings, I would return an error and have the caller deal with it (the caller can log or retry).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, handleForwardedConnection is an async call (called as a go routine), so I think it is important to handle its own?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the caller would expand to an anonymous function that calls this function and does something with the error.

defer terror.Call(uconn.Close)
if tconn, err := net.Dial("tcp", addr); err == nil {
go func() {
if _, err := io.Copy(uconn, tconn); err != nil {
log.Warningf("copy server to socket failed: %s", err)
}
}()
if _, err := io.Copy(tconn, uconn); err != nil {
log.Warningf("socket forward copy failed: %s", err)
}
} else {
log.Warningf("socket forward failed: could not connect to [%s], err: %s", addr, err)
}
}

// NewServer creates a new Server.
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
s := &Server{
Expand All @@ -151,15 +186,24 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
}

var err error
if cfg.Socket != "" {
if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket)
}
} else {

if s.cfg.Host != "" && s.cfg.Port != 0 {
morgo marked this conversation as resolved.
Show resolved Hide resolved
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
if s.listener, err = net.Listen("tcp", addr); err == nil {
log.Infof("Server is running MySQL Protocol at [%s]", addr)
if cfg.Socket != "" {
if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil {
log.Infof("Server redirecting [%s] to [%s]", s.cfg.Socket, addr)
go s.forwardUnixSocketToTCP()
}
}
}
} else if cfg.Socket != "" {
if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket)
}
} else {
err = errors.New("Server not configured to listen on either -socket or -host and -port")
}

if cfg.ProxyProtocol.Networks != "" {
Expand Down Expand Up @@ -292,6 +336,11 @@ func (s *Server) Close() {
terror.Log(errors.Trace(err))
s.listener = nil
}
if s.socket != nil {
err := s.socket.Close()
terror.Log(errors.Trace(err))
morgo marked this conversation as resolved.
Show resolved Hide resolved
s.socket = nil
}
if s.statusServer != nil {
err := s.statusServer.Close()
terror.Log(errors.Trace(err))
Expand Down Expand Up @@ -419,7 +468,7 @@ func (s *Server) kickIdleConnection() {
for _, cc := range conns {
err := cc.Close()
if err != nil {
log.Error("close connection error:", err)
log.Errorf("close connection error: %s", err)
}
}
}
Expand Down
26 changes: 26 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,34 @@ func (ts *TidbTestSuite) TestMultiStatements(c *C) {
runTestMultiStatements(c)
}

func (ts *TidbTestSuite) TestSocketForwarding(c *C) {
cfg := config.NewConfig()
cfg.Socket = "/tmp/tidbtest.sock"
cfg.Port = 3999
os.Remove(cfg.Socket)
cfg.Status.ReportStatus = false

server, err := NewServer(cfg, ts.tidbdrv)
c.Assert(err, IsNil)
go server.Run()
time.Sleep(time.Millisecond * 100)
defer server.Close()

runTestRegression(c, func(config *mysql.Config) {
config.User = "root"
config.Net = "unix"
config.Addr = "/tmp/tidbtest.sock"
config.DBName = "test"
config.Strict = true
}, "SocketRegression")
}

func (ts *TidbTestSuite) TestSocket(c *C) {
cfg := config.NewConfig()
cfg.Socket = "/tmp/tidbtest.sock"
cfg.Port = 0
os.Remove(cfg.Socket)
cfg.Host = ""
cfg.Status.ReportStatus = false

server, err := NewServer(cfg, ts.tidbdrv)
Expand All @@ -178,6 +203,7 @@ func (ts *TidbTestSuite) TestSocket(c *C) {
config.DBName = "test"
config.Strict = true
}, "SocketRegression")

}

// generateCert generates a private key and a certificate in PEM format based on parameters.
Expand Down