diff --git a/compiler/kernel/vop.go b/compiler/kernel/vop.go index 869cc94a34..cb7cb4ef0b 100644 --- a/compiler/kernel/vop.go +++ b/compiler/kernel/vop.go @@ -84,7 +84,7 @@ func (b *Builder) compileVamFork(fork *dag.Fork, parents []vector.Puller) ([]vec for _, seq := range fork.Paths { var parent vector.Puller if f != nil && !isEntry(seq) { - parent = f.AddExit() + parent = f.AddBranch() } exit, err := b.compileVamSeq(seq, []vector.Puller{parent}) if err != nil { diff --git a/runtime/vam/op/fork.go b/runtime/vam/op/fork.go index 5a2976776d..343b25e2c5 100644 --- a/runtime/vam/op/fork.go +++ b/runtime/vam/op/fork.go @@ -2,97 +2,29 @@ package op import ( "context" - "sync" "github.com/brimdata/super/vector" ) type Fork struct { - ctx context.Context - parent vector.Puller - - branches []*forkBranch - nblocked int - once sync.Once + router *router } func NewFork(ctx context.Context, parent vector.Puller) *Fork { - return &Fork{ - ctx: ctx, - parent: parent, - } -} - -func (f *Fork) AddExit() vector.Puller { - branch := &forkBranch{f, make(chan result), make(chan struct{}), false} - f.branches = append(f.branches, branch) - return branch -} - -func (f *Fork) run() { - for { - if f.nblocked == len(f.branches) { - // Send done upstream. - if _, err := f.parent.Pull(true); err != nil { - for _, b := range f.branches { - select { - case b.resultCh <- result{nil, err}: - case <-f.ctx.Done(): - } - } - return - } - f.unblockBranches() - } - vec, err := f.parent.Pull(false) - for _, b := range f.branches { - if b.blocked { - continue - } - select { - case b.resultCh <- result{vec, err}: - case <-b.doneCh: - b.blocked = true - f.nblocked++ - case <-f.ctx.Done(): - return - } - } - if vec == nil && err == nil { - // EOS unblocks all branches. - f.unblockBranches() - } - } + f := &Fork{} + f.router = newRouter(ctx, f, parent) + return f } -func (f *Fork) unblockBranches() { - for _, b := range f.branches { - b.blocked = false - } - f.nblocked = 0 +func (f *Fork) AddBranch() vector.Puller { + return f.router.addRoute() } -type forkBranch struct { - fork *Fork - resultCh chan result - doneCh chan struct{} - blocked bool -} - -func (f *forkBranch) Pull(done bool) (vector.Any, error) { - f.fork.once.Do(func() { go f.fork.run() }) - if done { - select { - case f.doneCh <- struct{}{}: - return nil, nil - case <-f.fork.ctx.Done(): - return nil, f.fork.ctx.Err() +func (f *Fork) forward(vec vector.Any) bool { + for _, r := range f.router.routes { + if !r.send(vec, nil) { + return false } } - select { - case r := <-f.resultCh: - return r.vector, r.err - case <-f.fork.ctx.Done(): - return nil, f.fork.ctx.Err() - } + return true } diff --git a/runtime/vam/op/router.go b/runtime/vam/op/router.go new file mode 100644 index 0000000000..c7c4e6f200 --- /dev/null +++ b/runtime/vam/op/router.go @@ -0,0 +1,113 @@ +package op + +import ( + "context" + "sync" + + "github.com/brimdata/super/vector" +) + +type forwarder interface { + forward(vector.Any) bool +} + +type router struct { + ctx context.Context + forwarder forwarder + parent vector.Puller + + routes []*route + nblocked int + once sync.Once +} + +func newRouter(ctx context.Context, f forwarder, parent vector.Puller) *router { + return &router{ctx: ctx, forwarder: f, parent: parent} +} + +func (r *router) addRoute() vector.Puller { + route := &route{r, make(chan result), make(chan struct{}), false} + r.routes = append(r.routes, route) + return route +} + +func (r *router) run() { + for { + if r.nblocked == len(r.routes) { + // Send done upstream. + if _, err := r.parent.Pull(true); err != nil { + for _, route := range r.routes { + select { + case route.resultCh <- result{nil, err}: + case <-r.ctx.Done(): + } + } + return + } + r.unblockBranches() + } + vec, err := r.parent.Pull(false) + if vec != nil && err == nil { + if !r.forwarder.forward(vec) { + return + } + continue + } + for _, route := range r.routes { + if !route.send(vec, err) { + return + } + } + if vec == nil && err == nil { + // EOS unblocks all branches. + r.unblockBranches() + } + } +} + +func (r *router) unblockBranches() { + for _, route := range r.routes { + route.blocked = false + } + r.nblocked = 0 +} + +type route struct { + router *router + resultCh chan result + doneCh chan struct{} + blocked bool +} + +func (r *route) Pull(done bool) (vector.Any, error) { + r.router.once.Do(func() { go r.router.run() }) + if done { + select { + case r.doneCh <- struct{}{}: + return nil, nil + case <-r.router.ctx.Done(): + return nil, r.router.ctx.Err() + } + } + select { + case r := <-r.resultCh: + return r.vector, r.err + case <-r.router.ctx.Done(): + return nil, r.router.ctx.Err() + } +} + +func (r *route) send(vec vector.Any, err error) bool { + if r.blocked { + return true + } + select { + case r.resultCh <- result{vec, err}: + case <-r.doneCh: + r.blocked = true + r.router.nblocked++ + case <-r.router.ctx.Done(): + return false + } + return true +}