Skip to content

Commit

Permalink
engine: fix race when reading fields in Concatenate (#14324)
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <vmg@strn.cat>
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
Signed-off-by: Andres Taylor <andres@planetscale.com>
Co-authored-by: Dirkjan Bussink <d.bussink@gmail.com>
Co-authored-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
3 people authored Oct 23, 2023
1 parent 24fb7d0 commit 2dab2cb
Showing 1 changed file with 69 additions and 53 deletions.
122 changes: 69 additions & 53 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package engine

import (
"context"
"slices"
"sync"
"sync/atomic"

"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -236,92 +238,106 @@ func (c *Concatenate) TryStreamExecute(ctx context.Context, vcursor VCursor, bin
}

func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, in func(*sqltypes.Result) error) error {
// Scoped context; any early exit triggers cancel() to clean up ongoing work.
ctx, cancel := context.WithCancel(inCtx)
defer cancel()
var outerErr error

var cbMu sync.Mutex
var wg, fieldMu sync.WaitGroup
var fieldRec atomic.Int64
fieldRec.Store(int64(len(c.Sources)))
fieldMu.Add(1)

rest := make([]*sqltypes.Result, len(c.Sources))
var fields []*querypb.Field
// Mutexes for dealing with concurrent access to shared state.
var (
muCallback sync.Mutex // Protects callback
muFields sync.Mutex // Protects field state
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fields []*querypb.Field // Cached final field types
)

// Process each result chunk, considering type coercion.
callback := func(res *sqltypes.Result, srcIdx int) error {
cbMu.Lock()
defer cbMu.Unlock()
muCallback.Lock()
defer muCallback.Unlock()

// Check if type coercion needed for this source.
// We only need to check if fields are not in NoNeedToTypeCheck set.
needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
_, ok := c.NoNeedToTypeCheck[idx]
if !ok && fields[idx].Type != field.Type {
_, skip := c.NoNeedToTypeCheck[idx]
if !skip && fields[idx].Type != field.Type {
needsCoercion = true
break
}
}

// Apply type coercion if needed.
if needsCoercion {
for _, row := range res.Rows {
err := c.coerceValuesTo(row, fields)
if err != nil {
if err := c.coerceValuesTo(row, fields); err != nil {
return err
}
}
}
return in(res)
}

once := sync.Once{}

// Start streaming query execution in parallel for all sources.
for i, source := range c.Sources {
wg.Add(1)
currIndex, currSource := i, source

go func() {
defer wg.Done()
wg.Go(func() error {
err := vcursor.StreamExecutePrimitive(ctx, currSource, bindVars, true, func(resultChunk *sqltypes.Result) error {
// if we have fields to compare, make sure all the fields are all the same
if fieldRec.Load() > 0 && resultChunk.Fields != nil {
rest[currIndex] = resultChunk
res := fieldRec.Add(-1)
if res == 0 {
// We have received fields from all sources. We can now calculate the output types
var err error
fields, err = c.getFields(rest)
if err != nil {
return err
// Process fields when they arrive; coordinate field agreement across sources.
if resultChunk.Fields != nil {
muFields.Lock()

// Capture the initial result chunk to determine field types later.
if rest[currIndex] == nil {
rest[currIndex] = resultChunk

// If this was the last source to report its fields, derive the final output fields.
if !slices.Contains(rest, nil) {
muFields.Unlock()

// We have received fields from all sources. We can now calculate the output types
var err error
fields, err = c.getFields(rest)
if err != nil {
return err
}
resultChunk.Fields = fields

defer condFields.Broadcast()
return callback(resultChunk, currIndex)
}
resultChunk.Fields = fields
defer once.Do(func() {
fieldMu.Done()
})

return callback(resultChunk, currIndex)
} else {
fieldMu.Wait()
}
// Wait for fields from all sources.
for slices.Contains(rest, nil) {
condFields.Wait()
}
muFields.Unlock()
}

// If we get here, all the fields have been received
select {
case <-ctx.Done():
// Context check to avoid extra work.
if ctx.Err() != nil {
return nil
default:
return callback(resultChunk, currIndex)
}
return callback(resultChunk, currIndex)
})

// Error handling and context cleanup for this source.
if err != nil {
outerErr = err
muFields.Lock()
if rest[currIndex] == nil {
// Signal that this source is done, even if by failure, to unblock field waiting.
rest[currIndex] = &sqltypes.Result{}
}
cancel()
once.Do(func() {
fieldMu.Done()
})
condFields.Broadcast()
muFields.Unlock()
}
}()

return err
})
}
wg.Wait()
return outerErr
// Wait for all sources to complete.
return wg.Wait()
}

func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) error {
Expand Down

0 comments on commit 2dab2cb

Please sign in to comment.