ccl/sqlproxyccl: support waitResumed on the processor to block until resumption

Previously, there was a race where suspend() could be called right after
resuming the processors: if the processor goroutines had not yet started
running, suspend() would return immediately, violating the invariant that the
processors are suspended before the caller proceeds. This commit adds a new
waitResumed method on the processor that allows callers to block until the
processors have been resumed.

Release justification: sqlproxy-only change.

Release note: None
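
As context for reviewers: the fix is the standard condition-variable
handshake. Below is a minimal, self-contained sketch of the pattern, assuming
a sync.Cond guards a resumed flag; the names (proc, newProc, resumed) are
illustrative and do not match the forwarder's actual types.

package main

import (
    "context"
    "sync"
)

type proc struct {
    mu      sync.Mutex
    cond    *sync.Cond // signaled whenever resumed changes
    resumed bool
}

func newProc() *proc {
    p := &proc{}
    p.cond = sync.NewCond(&p.mu)
    return p
}

// resume marks the processor as running and wakes all waiters.
func (p *proc) resume() {
    p.mu.Lock()
    defer p.mu.Unlock()
    p.resumed = true
    p.cond.Broadcast()
}

// waitResumed blocks until resume has run, so a subsequent suspend is
// guaranteed to observe a running processor. Note that a context
// cancellation is only observed after something wakes the waiter.
func (p *proc) waitResumed(ctx context.Context) error {
    p.mu.Lock()
    defer p.mu.Unlock()
    for !p.resumed {
        if ctx.Err() != nil {
            return ctx.Err()
        }
        p.cond.Wait()
    }
    return nil
}

func main() {
    p := newProc()
    go p.resume()
    _ = p.waitResumed(context.Background()) // returns once resume has run
}

Broadcast (rather than Signal) wakes every waiter, and the wait loop re-checks
both the flag and the context on each wakeup.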
jaylim-crl committed Mar 10, 2022
1 parent 1f52149 commit 05b6f49
Showing 3 changed files with 129 additions and 40 deletions.
34 changes: 30 additions & 4 deletions pkg/ccl/sqlproxyccl/forwarder.go
@@ -86,7 +86,7 @@ func forward(
     metrics *metrics,
     clientConn net.Conn,
     serverConn net.Conn,
-) *forwarder {
+) (*forwarder, error) {
     ctx, cancelFn := context.WithCancel(ctx)
     f := &forwarder{
         ctx: ctx,
@@ -100,8 +100,10 @@
     clockFn := makeLogicalClockFn()
     f.request = newProcessor(clockFn, f.clientConn, f.serverConn) // client -> server
     f.response = newProcessor(clockFn, f.serverConn, f.clientConn) // server -> client
-    f.resumeProcessors()
-    return f
+    if err := f.resumeProcessors(); err != nil {
+        return nil, err
+    }
+    return f, nil
 }
 
 // Close closes the forwarder and all connections. This is idempotent.
@@ -130,7 +132,7 @@ func (f *forwarder) Close() {
 // asynchronously. The forwarder will be closed if any of the processors
 // return an error while resuming. This is idempotent as resume() will return
 // nil if the processor has already been started.
-func (f *forwarder) resumeProcessors() {
+func (f *forwarder) resumeProcessors() error {
     go func() {
         if err := f.request.resume(f.ctx); err != nil {
             f.tryReportError(wrapClientToServerError(err))
@@ -141,6 +143,13 @@ func (f *forwarder) resumeProcessors() {
             f.tryReportError(wrapServerToClientError(err))
         }
     }()
+    if err := f.request.waitResumed(f.ctx); err != nil {
+        return err
+    }
+    if err := f.response.waitResumed(f.ctx); err != nil {
+        return err
+    }
+    return nil
 }
 
 // tryReportError tries to send err to errCh, and closes the forwarder if
@@ -252,6 +261,7 @@ func (p *processor) resume(ctx context.Context) error {
             return errProcessorResumed
         }
         p.mu.resumed = true
+        p.mu.cond.Broadcast()
         return nil
     }
     exitResume := func() {
@@ -328,6 +338,22 @@ func (p *processor) resume(ctx context.Context) error {
     return ctx.Err()
 }
 
+// waitResumed waits until the processor has been resumed. This can be used to
+// ensure that suspend actually suspends the running processor, and there won't
+// be a race where the goroutines have not started running, and suspend returns.
+func (p *processor) waitResumed(ctx context.Context) error {
+    p.mu.Lock()
+    defer p.mu.Unlock()
+
+    for !p.mu.resumed {
+        if ctx.Err() != nil {
+            return ctx.Err()
+        }
+        p.mu.cond.Wait()
+    }
+    return nil
+}
+
 // suspend requests for the processor to be suspended if it is in a safe state,
 // and blocks until the processor has been terminated. If the suspend request
 // failed, suspend returns an error, and the caller is safe to retry again.
128 changes: 93 additions & 35 deletions pkg/ccl/sqlproxyccl/forwarder_test.go
@@ -35,12 +35,13 @@ func TestForward(t *testing.T) {
     t.Run("closed_when_processors_error", func(t *testing.T) {
         p1, p2 := net.Pipe()
 
+        f, err := forward(bgCtx, nil /* connector */, nil /* metrics */, p1, p2)
+        defer f.Close()
+        require.NoError(t, err)
+
         // Close the connection right away to simulate processor error.
         p1.Close()
 
-        f := forward(bgCtx, nil /* connector */, nil /* metrics */, p1, p2)
-        defer f.Close()
-
         // We have to wait for the goroutine to run. Once the forwarder stops,
         // we're good.
         testutils.SucceedsSoon(t, func() error {
@@ -60,17 +61,22 @@
         clientProxy, client := net.Pipe()
         serverProxy, server := net.Pipe()
 
-        f := forward(ctx, nil /* connector */, nil /* metrics */, clientProxy, serverProxy)
+        f, err := forward(ctx, nil /* connector */, nil /* metrics */, clientProxy, serverProxy)
         defer f.Close()
+        require.NoError(t, err)
         require.Nil(t, f.ctx.Err())
 
-        request := f.request
-        initialClock := request.logicalClockFn()
+        requestProc := f.request
+        initialClock := requestProc.logicalClockFn()
         barrier := make(chan struct{})
-        request.testingKnobs.beforeForwardMsg = func() {
+        requestProc.testingKnobs.beforeForwardMsg = func() {
             <-barrier
         }
 
+        requestProc.mu.Lock()
+        require.True(t, requestProc.mu.resumed)
+        requestProc.mu.Unlock()
+
         // Client writes some pgwire messages.
         errChan := make(chan error, 1)
         go func() {
@@ -105,10 +111,10 @@ func TestForward(t *testing.T) {
         require.True(t, ok)
         require.Equal(t, "SELECT 1", m1.String)
 
-        request.mu.Lock()
-        require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), request.mu.lastMessageType)
-        require.Equal(t, initialClock+1, request.mu.lastMessageTransferredAt)
-        request.mu.Unlock()
+        requestProc.mu.Lock()
+        require.Equal(t, byte(pgwirebase.ClientMsgSimpleQuery), requestProc.mu.lastMessageType)
+        require.Equal(t, initialClock+1, requestProc.mu.lastMessageTransferredAt)
+        requestProc.mu.Unlock()
 
         barrier <- struct{}{}
         msg, err = backend.Receive()
@@ -118,10 +124,10 @@
         require.Equal(t, "foobar", m2.Portal)
         require.Equal(t, uint32(42), m2.MaxRows)
 
-        request.mu.Lock()
-        require.Equal(t, byte(pgwirebase.ClientMsgExecute), request.mu.lastMessageType)
-        require.Equal(t, initialClock+2, request.mu.lastMessageTransferredAt)
-        request.mu.Unlock()
+        requestProc.mu.Lock()
+        require.Equal(t, byte(pgwirebase.ClientMsgExecute), requestProc.mu.lastMessageType)
+        require.Equal(t, initialClock+2, requestProc.mu.lastMessageTransferredAt)
+        requestProc.mu.Unlock()
 
         barrier <- struct{}{}
         msg, err = backend.Receive()
@@ -130,10 +136,10 @@
         require.True(t, ok)
         require.Equal(t, byte('P'), m3.ObjectType)
 
-        request.mu.Lock()
-        require.Equal(t, byte(pgwirebase.ClientMsgClose), request.mu.lastMessageType)
-        require.Equal(t, initialClock+3, request.mu.lastMessageTransferredAt)
-        request.mu.Unlock()
+        requestProc.mu.Lock()
+        require.Equal(t, byte(pgwirebase.ClientMsgClose), requestProc.mu.lastMessageType)
+        require.Equal(t, initialClock+3, requestProc.mu.lastMessageTransferredAt)
+        requestProc.mu.Unlock()
 
         select {
         case err = <-errChan:
@@ -151,17 +157,22 @@
         clientProxy, client := net.Pipe()
         serverProxy, server := net.Pipe()
 
-        f := forward(ctx, nil /* connector */, nil /* metrics */, clientProxy, serverProxy)
+        f, err := forward(ctx, nil /* connector */, nil /* metrics */, clientProxy, serverProxy)
         defer f.Close()
+        require.NoError(t, err)
         require.Nil(t, f.ctx.Err())
 
-        response := f.response
-        initialClock := response.logicalClockFn()
+        responseProc := f.response
+        initialClock := responseProc.logicalClockFn()
         barrier := make(chan struct{})
-        response.testingKnobs.beforeForwardMsg = func() {
+        responseProc.testingKnobs.beforeForwardMsg = func() {
             <-barrier
         }
 
+        responseProc.mu.Lock()
+        require.True(t, responseProc.mu.resumed)
+        responseProc.mu.Unlock()
+
         // Server writes some pgwire messages.
         errChan := make(chan error, 1)
         go func() {
@@ -191,10 +202,10 @@ func TestForward(t *testing.T) {
         require.Equal(t, "100", m1.Code)
         require.Equal(t, "foobarbaz", m1.Message)
 
-        response.mu.Lock()
-        require.Equal(t, byte(pgwirebase.ServerMsgErrorResponse), response.mu.lastMessageType)
-        require.Equal(t, initialClock+1, response.mu.lastMessageTransferredAt)
-        response.mu.Unlock()
+        responseProc.mu.Lock()
+        require.Equal(t, byte(pgwirebase.ServerMsgErrorResponse), responseProc.mu.lastMessageType)
+        require.Equal(t, initialClock+1, responseProc.mu.lastMessageTransferredAt)
+        responseProc.mu.Unlock()
 
         barrier <- struct{}{}
         msg, err = frontend.Receive()
@@ -203,10 +214,10 @@
         require.True(t, ok)
         require.Equal(t, byte('I'), m2.TxStatus)
 
-        response.mu.Lock()
-        require.Equal(t, byte(pgwirebase.ServerMsgReady), response.mu.lastMessageType)
-        require.Equal(t, initialClock+2, response.mu.lastMessageTransferredAt)
-        response.mu.Unlock()
+        responseProc.mu.Lock()
+        require.Equal(t, byte(pgwirebase.ServerMsgReady), responseProc.mu.lastMessageType)
+        require.Equal(t, initialClock+2, responseProc.mu.lastMessageTransferredAt)
+        responseProc.mu.Unlock()
 
         select {
         case err = <-errChan:
@@ -221,8 +232,9 @@ func TestForwarder_Close(t *testing.T) {
 
     p1, p2 := net.Pipe()
 
-    f := forward(context.Background(), nil /* connector */, nil /* metrics */, p1, p2)
+    f, err := forward(context.Background(), nil /* connector */, nil /* metrics */, p1, p2)
     defer f.Close()
+    require.NoError(t, err)
     require.Nil(t, f.ctx.Err())
 
     f.Close()
@@ -234,8 +246,9 @@ func TestForwarder_tryReportError(t *testing.T) {
 
     p1, p2 := net.Pipe()
 
-    f := forward(context.Background(), nil /* connector */, nil /* metrics */, p1, p2)
+    f, err := forward(context.Background(), nil /* connector */, nil /* metrics */, p1, p2)
     defer f.Close()
+    require.NoError(t, err)
 
     select {
     case err := <-f.errCh:
@@ -255,7 +268,7 @@ func TestForwarder_tryReportError(t *testing.T) {
     }
 
     // Forwarder should be closed.
-    _, err := p1.Write([]byte("foobarbaz"))
+    _, err = p1.Write([]byte("foobarbaz"))
     require.Regexp(t, "closed pipe", err)
     require.EqualError(t, f.ctx.Err(), context.Canceled.Error())
 }
@@ -367,6 +380,46 @@ func TestSuspendResumeProcessor(t *testing.T) {
         require.EqualError(t, p.suspend(ctx), context.Canceled.Error())
     })
 
+    t.Run("wait_for_resumed", func(t *testing.T) {
+        ctx, cancel := context.WithCancel(context.Background())
+        defer cancel()
+
+        clientProxy, serverProxy := net.Pipe()
+        defer clientProxy.Close()
+        defer serverProxy.Close()
+
+        p := newProcessor(
+            makeLogicalClockFn(),
+            interceptor.NewPGConn(clientProxy),
+            interceptor.NewPGConn(serverProxy),
+        )
+
+        errCh := make(chan error)
+        go func() {
+            errCh <- p.waitResumed(ctx)
+        }()
+
+        select {
+        case <-errCh:
+            t.Fatal("expected not resumed")
+        default:
+            // We're good.
+        }
+
+        go func() { _ = p.resume(ctx) }()
+
+        var lastErr error
+        require.Eventually(t, func() bool {
+            select {
+            case lastErr = <-errCh:
+                return true
+            default:
+                return false
+            }
+        }, 10*time.Second, 100*time.Millisecond)
+        require.NoError(t, lastErr)
+    })
+
     // This tests that resume() and suspend() can be called multiple times.
     // As an aside, we also check that we can suspend when blocked on PeekMsg
     // because there are no messages.
@@ -395,6 +448,8 @@ func TestSuspendResumeProcessor(t *testing.T) {
         require.EqualError(t, err, errProcessorResumed.Error())
 
         // Suspend the last goroutine.
+        err = p.waitResumed(ctx)
+        require.NoError(t, err)
         err = p.suspend(ctx)
         require.NoError(t, err)
 
@@ -529,6 +584,9 @@ func TestSuspendResumeProcessor(t *testing.T) {
         // have been forwarded.
         go func(p *processor) { _ = p.resume(ctx) }(p)
 
+        err := p.waitResumed(ctx)
+        require.NoError(t, err)
+
         // Now read all the messages on the server for correctness.
         for i := 0; i < queryCount; i++ {
             msg := <-msgCh
@@ -542,7 +600,7 @@ func TestSuspendResumeProcessor(t *testing.T) {
         }
 
         // Suspend the final goroutine.
-        err := p.suspend(ctx)
+        err = p.suspend(ctx)
         require.NoError(t, err)
     })
 }
7 changes: 6 additions & 1 deletion pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -333,8 +333,13 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
     }()
 
     // Pass ownership of conn and crdbConn to the forwarder.
-    f := forward(ctx, connector, handler.metrics, conn, crdbConn)
+    f, err := forward(ctx, connector, handler.metrics, conn, crdbConn)
     defer f.Close()
+    if err != nil {
+        // Don't send to the client here for the same reason below.
+        handler.metrics.updateForError(err)
+        return err
+    }
 
     // Block until an error is received, or when the stopper starts quiescing,
     // whichever that happens first.
