Skip to content

Commit

Permalink
feat: add max-sigterm-delay flag (#1256)
Browse files Browse the repository at this point in the history
This is the v2 version of term_timeout and will make proxy wait at most
for the specified duration before exiting once a SIGTERM has been
received by the process.
  • Loading branch information
enocom authored Jul 7, 2022
1 parent 188b089 commit 73f509a
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 14 deletions.
12 changes: 9 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ any client SSL certificates.`,
cmd.PersistentFlags().Uint64Var(&c.conf.MaxConnections, "max-connections", 0,
`Limits the number of connections by refusing any additional connections.
When this flag is not set, there is no limit.`)
cmd.PersistentFlags().DurationVar(&c.conf.WaitOnClose, "max-sigterm-delay", 0,
`Maximum amount of time to wait after for any open connections
to close after receiving a TERM signal. The proxy will shut
down when the number of open connections reaches 0 or when
the maximum time has passed. Defaults to 0s.`)

cmd.PersistentFlags().StringVar(&c.telemetryProject, "telemetry-project", "",
"Enable Cloud Monitoring and Cloud Trace integration with the provided project ID.")
cmd.PersistentFlags().BoolVar(&c.disableTraces, "disable-traces", false,
Expand Down Expand Up @@ -434,7 +440,7 @@ func runSignalWrapper(cmd *Command) error {
cmd.Println("The proxy has started successfully and is ready for new connections!")
defer func() {
if cErr := p.Close(); cErr != nil {
cmd.PrintErrf("error during shutdown: %v\n", cErr)
cmd.PrintErrf("The proxy failed to close cleanly: %v\n", cErr)
}
}()

Expand All @@ -445,9 +451,9 @@ func runSignalWrapper(cmd *Command) error {
err := <-shutdownCh
switch {
case errors.Is(err, errSigInt):
cmd.PrintErrln("SIGINT signal received. Shuting down...")
cmd.PrintErrln("SIGINT signal received. Shutting down...")
case errors.Is(err, errSigTerm):
cmd.PrintErrln("SIGTERM signal received. Shuting down...")
cmd.PrintErrln("SIGTERM signal received. Shutting down...")
default:
cmd.PrintErrf("The proxy has encountered a terminal error: %v\n", err)
}
Expand Down
7 changes: 7 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,13 @@ func TestNewCommandArguments(t *testing.T) {
MaxConnections: 1,
}),
},
{
desc: "using wait after signterm flag",
args: []string{"--max-sigterm-delay", "10s", "proj:region:inst"},
want: withDefaults(&proxy.Config{
WaitOnClose: 10 * time.Second,
}),
},
{
desc: "using the private-ip flag",
args: []string{"--private-ip", "proj:region:inst"},
Expand Down
42 changes: 38 additions & 4 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ type Config struct {
// connections. A zero-value indicates no limit.
MaxConnections uint64

// WaitOnClose sets the duration to wait for connections to close before
// shutting down. Not setting this field means to close immediately
// regardless of any open connections.
WaitOnClose time.Duration

// PrivateIP enables connections via the database server's private IP address
// for all instances.
PrivateIP bool
Expand Down Expand Up @@ -218,6 +223,10 @@ type Client struct {

// mnts is a list of all mounted sockets for this client
mnts []*socketMount

// waitOnClose is the maximum duration to wait for open connections to close
// when shutting down.
waitOnClose time.Duration
}

// NewClient completes the initial setup required to get the proxy to a "steady" state.
Expand Down Expand Up @@ -265,10 +274,11 @@ func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client,
mnts = append(mnts, m)
}
c := &Client{
mnts: mnts,
cmd: cmd,
dialer: d,
maxConns: conf.MaxConnections,
mnts: mnts,
cmd: cmd,
dialer: d,
maxConns: conf.MaxConnections,
waitOnClose: conf.WaitOnClose,
}
return c, nil
}
Expand Down Expand Up @@ -318,16 +328,40 @@ func (m MultiErr) Error() string {
// Close triggers the proxyClient to shutdown.
func (c *Client) Close() error {
var mErr MultiErr
// First, close all open socket listeners to prevent additional connections.
for _, m := range c.mnts {
err := m.Close()
if err != nil {
mErr = append(mErr, err)
}
}
// Next, close the dialer to prevent any additional refreshes.
cErr := c.dialer.Close()
if cErr != nil {
mErr = append(mErr, cErr)
}
if c.waitOnClose == 0 {
if len(mErr) > 0 {
return mErr
}
return nil
}
timeout := time.After(c.waitOnClose)
tick := time.Tick(100 * time.Millisecond)
for {
select {
case <-tick:
if atomic.LoadUint64(&c.connCount) > 0 {
continue
}
case <-timeout:
}
break
}
open := atomic.LoadUint64(&c.connCount)
if open > 0 {
mErr = append(mErr, fmt.Errorf("%d connection(s) still open after waiting %v", open, c.waitOnClose))
}
if len(mErr) > 0 {
return mErr
}
Expand Down
77 changes: 70 additions & 7 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type errorDialer struct {
fakeDialer
}

func (errorDialer) Close() error {
func (*errorDialer) Close() error {
return errors.New("errorDialer returns error on Close")
}

Expand Down Expand Up @@ -314,6 +314,71 @@ func TestClientLimitsMaxConnections(t *testing.T) {
}
}

func tryTCPDial(t *testing.T, addr string) net.Conn {
attempts := 10
var (
conn net.Conn
err error
)
for i := 0; i < attempts; i++ {
conn, err = net.Dial("tcp", addr)
if err != nil {
time.Sleep(100 * time.Millisecond)
continue
}
return conn
}

t.Fatalf("failed to dial in %v attempts: %v", attempts, err)
return nil
}

func TestClientCloseWaitsForActiveConnections(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Port: 5000,
Instances: []proxy.InstanceConnConfig{
{Name: "proj:region:pg"},
},
Dialer: &fakeDialer{},
}
c, err := proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
go c.Serve(context.Background())

conn := tryTCPDial(t, "127.0.0.1:5000")
_ = conn.Close()

if err := c.Close(); err != nil {
t.Fatalf("c.Close error: %v", err)
}

in.WaitOnClose = time.Second
in.Port = 5001
c, err = proxy.NewClient(context.Background(), &cobra.Command{}, in)
if err != nil {
t.Fatalf("proxy.NewClient error: %v", err)
}
go c.Serve(context.Background())

var open []net.Conn
for i := 0; i < 5; i++ {
conn = tryTCPDial(t, "127.0.0.1:5001")
open = append(open, conn)
}
defer func() {
for _, o := range open {
o.Close()
}
}()

if err := c.Close(); err == nil {
t.Fatal("c.Close should error, got = nil")
}
}

func TestClientClosesCleanly(t *testing.T) {
in := &proxy.Config{
Addr: "127.0.0.1",
Expand All @@ -328,12 +393,8 @@ func TestClientClosesCleanly(t *testing.T) {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

conn, dErr := net.Dial("tcp", "127.0.0.1:5000")
if dErr != nil {
t.Fatalf("net.Dial error = %v", dErr)
}
conn := tryTCPDial(t, "127.0.0.1:5000")
_ = conn.Close()

if err := c.Close(); err != nil {
Expand All @@ -355,7 +416,9 @@ func TestClosesWithError(t *testing.T) {
t.Fatalf("proxy.NewClient error want = nil, got = %v", err)
}
go c.Serve(context.Background())
time.Sleep(time.Second) // allow the socket to start listening

conn := tryTCPDial(t, "127.0.0.1:5000")
defer conn.Close()

if err = c.Close(); err == nil {
t.Fatal("c.Close() should error, got nil")
Expand Down

0 comments on commit 73f509a

Please sign in to comment.