Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add max-sigterm-delay flag #1256

Merged
merged 11 commits into from
Jul 7, 2022
10 changes: 7 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ any client SSL certificates.`,
cmd.PersistentFlags().Uint64Var(&c.conf.MaxConnections, "max-connections", 0,
`Limits the number of connections by refusing any additional connections.
When this flag is not set, there is no limit.`)
cmd.PersistentFlags().DurationVar(&c.conf.WaitOnClose, "wait-after-sigterm", 0,
`Amount of time to wait for any open connections to close before
kurtisvg marked this conversation as resolved.
Show resolved Hide resolved
shutting down the proxy. When this flag is not set, the proxy
will shutdown immediately after receiving a SIGTERM.`)
cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "",
"Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.")
cmd.PersistentFlags().BoolVar(&c.disableTraces, "disable-traces", false,
Expand Down Expand Up @@ -434,7 +438,7 @@ func runSignalWrapper(cmd *Command) error {
cmd.Println("The proxy has started successfully and is ready for new connections!")
defer func() {
if cErr := p.Close(); cErr != nil {
cmd.PrintErrf("error during shutdown: %v\n", cErr)
cmd.PrintErrf("The proxy failed to close cleanly: %v\n", cErr)
}
}()

Expand All @@ -445,9 +449,9 @@ func runSignalWrapper(cmd *Command) error {
err := <-shutdownCh
switch {
case errors.Is(err, errSigInt):
cmd.PrintErrln("SIGINT signal received. Shuting down...")
cmd.PrintErrln("SIGINT signal received. Shutting down...")
case errors.Is(err, errSigTerm):
cmd.PrintErrln("SIGTERM signal received. Shuting down...")
cmd.PrintErrln("SIGTERM signal received. Shutting down...")
default:
cmd.PrintErrf("The proxy has encountered a terminal error: %v\n", err)
}
Expand Down
7 changes: 7 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,13 @@ func TestNewCommandArguments(t *testing.T) {
MaxConnections: 1,
}),
},
{
desc: "using wait after signterm flag",
args: []string{"--wait-after-sigterm", "10s", "proj:region:inst"},
want: withDefaults(&proxy.Config{
WaitOnClose: 10 * time.Second,
}),
},
{
desc: "using the private-ip flag",
args: []string{"--private-ip", "proj:region:inst"},
Expand Down
42 changes: 38 additions & 4 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ type Config struct {
// connections. A zero-value indicates no limit.
MaxConnections uint64

// WaitOnClose sets the duration to wait for connections to close before
// shutting down. Not setting this field means to close immediately
// regardless of any open connections.
WaitOnClose time.Duration

// PrivateIP enables connections via the database server's private IP address
// for all instances.
PrivateIP bool
Expand Down Expand Up @@ -218,6 +223,10 @@ type Client struct {

// mnts is a list of all mounted sockets for this client
mnts []*socketMount

// waitOnClose is the maximum duration to wait for open connections to close
// when shutting down.
waitOnClose time.Duration
}

// NewClient completes the initial setup required to get the proxy to a "steady" state.
Expand Down Expand Up @@ -265,10 +274,11 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
mnts = append(mnts, m)
}
c := &Client{
mnts: mnts,
cmd: cmd,
dialer: d,
maxConns: conf.MaxConnections,
mnts: mnts,
cmd: cmd,
dialer: d,
maxConns: conf.MaxConnections,
waitOnClose: conf.WaitOnClose,
}
return c, nil
}
Expand Down Expand Up @@ -318,16 +328,40 @@ func (m MultiErr) Error() string {
// Close triggers the proxyClient to shutdown.
func (c *Client) Close() error {
var mErr MultiErr
// First, close all open socket listeners to prevent additional connections.
for _, m := range c.mnts {
err := m.Close()
if err != nil {
mErr = append(mErr, err)
}
}
// Next, close the dialer to prevent any additional refreshes.
cErr := c.dialer.Close()
if cErr != nil {
mErr = append(mErr, cErr)
}
if c.waitOnClose == 0 {
if len(mErr) > 0 {
return mErr
}
return nil
}
timeout := time.After(c.waitOnClose)
tick := time.Tick(100 * time.Millisecond)
for {
select {
case <-tick:
if atomic.LoadUint64(&c.connCount) > 0 {
continue
}
case <-timeout:
}
break
}
open := atomic.LoadUint64(&c.connCount)
if open > 0 {
mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.waitOnClose))
}
if len(mErr) > 0 {
return mErr
}
Expand Down
77 changes: 70 additions & 7 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type errorDialer struct {
fakeDialer
}

func (errorDialer) Close() error {
func (*errorDialer) Close() error {
return errors.New("errorDialer returns error on Close")
}

Expand Down Expand Up @@ -314,6 +314,71 @@ func TestClientLimitsMaxConnections(t *testing.T) {
}
}

func tryTCPDial(t *testing.T, addr string) net.Conn {
attempts := 10
var (
conn net.Conn
err error
)
for i := 0; i < attempts; i++ {
conn, err = net.Dial("tcp", addr)
if err != nil {
time.Sleep(100 * time.Millisecond)
continue
}
return conn
}

t.Fatalf("failed to dial in %v attempts: %v", attempts, err)
return nil
}

func TestClientCloseWaitsForActiveConnections(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Port: 5000,
Instances: []proxy.InstanceConnConfig{
{Name: "proj:region:pg"},
},
Dialer: &fakeDialer{},
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
go c.Serve(context.Background())

conn := tryTCPDial(t, "127.0.0.1:5000")
_ = conn.Close()

if err := c.Close(); err != nil {
t.Fatalf("c.Close error: %v", err)
}

in.WaitOnClose = time.Second
in.Port = 5001
c, err = proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
go c.Serve(context.Background())

var open []net.Conn
for i := 0; i < 5; i++ {
conn = tryTCPDial(t, "127.0.0.1:5001")
open = append(open, conn)
}
defer func() {
for _, o := range open {
o.Close()
}
}()

if err := c.Close(); err == nil {
t.Fatal("c.Close should error, got = nil")
}
}

func TestClientClosesCleanly(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Expand All @@ -328,12 +393,8 @@ func TestClientClosesCleanly(t *testing.T) {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

conn, dErr := net.Dial("tcp", "127.0.0.1:5000")
if dErr != nil {
t.Fatalf("net.Dial error = %v", dErr)
}
conn := tryTCPDial(t, "127.0.0.1:5000")
_ = conn.Close()

if err := c.Close(); err != nil {
Expand All @@ -355,7 +416,9 @@ func TestClosesWithError(t *testing.T) {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

conn := tryTCPDial(t, "127.0.0.1:5000")
defer conn.Close()

if err = c.Close(); err == nil {
t.Fatal("c.Close() should error, got nil")
Expand Down