Skip to content

Commit

Permalink
plan: fix race conditions in Exchange node
Browse files Browse the repository at this point in the history
Fixes src-d#828

Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed Sep 30, 2019
1 parent 2e82b0a commit 4a21aa6
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions sql/plan/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ type exchangeRowIter struct {
parallelism int
partitions sql.PartitionIter
tree sql.Node
mut sync.Mutex
tokens chan struct{}
mut sync.RWMutex
tokensChan chan struct{}
started bool
rows chan sql.Row
err chan error
quit chan struct{}

quitMut sync.RWMutex
quitChan chan struct{}
}

func newExchangeRowIter(
Expand All @@ -97,34 +99,40 @@ func newExchangeRowIter(
started: false,
tree: tree,
partitions: iter,
quit: make(chan struct{}),
quitChan: make(chan struct{}),
}
}

func (it *exchangeRowIter) releaseToken() {
it.mut.Lock()
defer it.mut.Unlock()

if it.tokens != nil {
it.tokens <- struct{}{}
if it.tokensChan != nil {
it.tokensChan <- struct{}{}
}
}

func (it *exchangeRowIter) closeTokens() {
it.mut.Lock()
defer it.mut.Unlock()

close(it.tokens)
it.tokens = nil
close(it.tokensChan)
it.tokensChan = nil
}

func (it *exchangeRowIter) tokens() chan struct{} {
it.mut.RLock()
defer it.mut.RUnlock()
return it.tokensChan
}

func (it *exchangeRowIter) fillTokens() {
it.mut.Lock()
defer it.mut.Unlock()

it.tokens = make(chan struct{}, it.parallelism)
it.tokensChan = make(chan struct{}, it.parallelism)
for i := 0; i < it.parallelism; i++ {
it.tokens <- struct{}{}
it.tokensChan <- struct{}{}
}
}

Expand All @@ -142,7 +150,7 @@ func (it *exchangeRowIter) start() {
it.err <- context.Canceled
it.closeTokens()
return
case <-it.quit:
case <-it.quit():
it.closeTokens()
return
case p, ok := <-partitions:
Expand Down Expand Up @@ -179,9 +187,9 @@ func (it *exchangeRowIter) iterPartitions(ch chan<- sql.Partition) {
case <-it.ctx.Done():
it.err <- context.Canceled
return
case <-it.quit:
case <-it.quit():
return
case <-it.tokens:
case <-it.tokens():
}

p, err := it.partitions.Next()
Expand Down Expand Up @@ -226,7 +234,7 @@ func (it *exchangeRowIter) iterPartition(p sql.Partition) {
case <-it.ctx.Done():
it.err <- context.Canceled
return
case <-it.quit:
case <-it.quit():
return
default:
}
Expand Down Expand Up @@ -263,17 +271,29 @@ func (it *exchangeRowIter) Next() (sql.Row, error) {
}
}

func (it *exchangeRowIter) Close() (err error) {
if it.quit != nil {
close(it.quit)
it.quit = nil
func (it *exchangeRowIter) quit() chan struct{} {
it.quitMut.RLock()
defer it.quitMut.RUnlock()
if it.quitChan == nil {
return nil
}

return it.quitChan
}

func (it *exchangeRowIter) Close() error {
it.quitMut.Lock()
if it.quitChan != nil {
close(it.quitChan)
it.quitChan = nil
}
it.quitMut.Unlock()

if it.partitions != nil {
err = it.partitions.Close()
return it.partitions.Close()
}

return err
return nil
}

type exchangePartition struct {
Expand Down

0 comments on commit 4a21aa6

Please sign in to comment.