From d015b275e61fa8b38cffd77120867e19e66aad97 Mon Sep 17 00:00:00 2001 From: Matthew Nibecker Date: Wed, 13 Mar 2024 15:01:58 -0700 Subject: [PATCH] Reset stateful expressions on EOS This commit fixes a bug where stateful expressions were not reset when an op encountered EOS. Most notably, this resulted in unexpected values when using stateful expressions inside of lateral queries. Closes #4943 --- compiler/kernel/expr.go | 179 ++++++++++-------- compiler/kernel/filter.go | 4 +- compiler/kernel/groupby.go | 21 +- compiler/kernel/op.go | 109 ++++++----- runtime/sam/expr/agg.go | 33 +++- runtime/sam/op/apply.go | 17 +- runtime/sam/op/explode/explode.go | 21 +- runtime/sam/op/exprswitch/exprswitch.go | 10 +- runtime/sam/op/groupby/groupby.go | 9 +- runtime/sam/op/join/join.go | 9 +- runtime/sam/op/merge/merge.go | 13 +- runtime/sam/op/merge/merge_test.go | 3 +- runtime/sam/op/meta/sequence.go | 2 +- runtime/sam/op/router.go | 6 + runtime/sam/op/sort/sort.go | 8 +- runtime/sam/op/switcher/switch.go | 6 +- runtime/sam/op/top/top.go | 5 +- runtime/sam/op/traverse/over.go | 24 ++- runtime/sam/op/yield/yield.go | 13 +- .../sam/op/ztests/stateful-expr-reset.yaml | 83 ++++++++ 20 files changed, 381 insertions(+), 194 deletions(-) create mode 100644 runtime/sam/op/ztests/stateful-expr-reset.yaml diff --git a/compiler/kernel/expr.go b/compiler/kernel/expr.go index 2ed8f9d346..8c9443b320 100644 --- a/compiler/kernel/expr.go +++ b/compiler/kernel/expr.go @@ -16,6 +16,19 @@ import ( "golang.org/x/text/unicode/norm" ) +type exprContext struct { + resetters expr.Resetters +} + +func newExprContext() *exprContext { return new(exprContext) } + +func (r *exprContext) addResetter(resetter expr.Resetter) { + if r == nil { + return + } + r.resetters = append(r.resetters, resetter) +} + // compileExpr compiles the given Expression into an object // that evaluates the expression against a provided Record. It returns an // error if compilation fails for any reason. 
@@ -46,7 +59,7 @@ import ( // TBD: string values and net.IP address do not need to be copied because they // are allocated by go libraries and temporary buffers are not used. This will // change down the road when we implement no-allocation string and IP conversion. -func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { +func (b *Builder) compileExpr(ectx *exprContext, e dag.Expr) (expr.Evaluator, error) { if e == nil { return nil, errors.New("null expression not allowed") } @@ -60,39 +73,44 @@ func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { case *dag.Var: return expr.NewVar(e.Slot), nil case *dag.Search: - return b.compileSearch(e) + return b.compileSearch(ectx, e) case *dag.This: return expr.NewDottedExpr(b.zctx(), field.Path(e.Path)), nil case *dag.Dot: - return b.compileDotExpr(e) + return b.compileDotExpr(ectx, e) case *dag.UnaryExpr: - return b.compileUnary(*e) + return b.compileUnary(ectx, *e) case *dag.BinaryExpr: - return b.compileBinary(e) + return b.compileBinary(ectx, e) case *dag.Conditional: - return b.compileConditional(*e) + return b.compileConditional(ectx, *e) case *dag.Call: - return b.compileCall(*e) + return b.compileCall(ectx, *e) case *dag.RegexpMatch: - return b.compileRegexpMatch(e) + return b.compileRegexpMatch(ectx, e) case *dag.RegexpSearch: - return b.compileRegexpSearch(e) + return b.compileRegexpSearch(ectx, e) case *dag.RecordExpr: - return b.compileRecordExpr(e) + return b.compileRecordExpr(ectx, e) case *dag.ArrayExpr: - return b.compileArrayExpr(e) + return b.compileArrayExpr(ectx, e) case *dag.SetExpr: - return b.compileSetExpr(e) + return b.compileSetExpr(ectx, e) case *dag.MapCall: - return b.compileMapCall(e) + return b.compileMapCall(ectx, e) case *dag.MapExpr: - return b.compileMapExpr(e) + return b.compileMapExpr(ectx, e) case *dag.Agg: - agg, err := b.compileAgg(e) + agg, err := b.compileAgg(ectx, e) if err != nil { return nil, err } - return expr.NewAggregatorExpr(b.zctx(), agg), nil + if 
ectx == nil { + panic("system error: exprContext is nil") + } + aggexpr := expr.NewAggregatorExpr(b.zctx(), agg) + ectx.addResetter(aggexpr) + return aggexpr, nil case *dag.OverExpr: return b.compileOverExpr(e) default: @@ -100,31 +118,31 @@ func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { } } -func (b *Builder) compileExprWithEmpty(e dag.Expr) (expr.Evaluator, error) { +func (b *Builder) compileExprWithEmpty(ectx *exprContext, e dag.Expr) (expr.Evaluator, error) { if e == nil { return nil, nil } - return b.compileExpr(e) + return b.compileExpr(ectx, e) } -func (b *Builder) compileBinary(e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileBinary(ectx *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { if slice, ok := e.RHS.(*dag.BinaryExpr); ok && slice.Op == ":" { - return b.compileSlice(e.LHS, slice) + return b.compileSlice(ectx, e.LHS, slice) } if e.Op == "in" { // Do a faster comparison if the LHS is a compile-time constant expression. - if in, err := b.compileConstIn(e); in != nil && err == nil { + if in, err := b.compileConstIn(ectx, e); in != nil && err == nil { return in, err } } - if e, err := b.compileConstCompare(e); e != nil && err == nil { + if e, err := b.compileConstCompare(ectx, e); e != nil && err == nil { return e, nil } - lhs, err := b.compileExpr(e.LHS) + lhs, err := b.compileExpr(ectx, e.LHS) if err != nil { return nil, err } - rhs, err := b.compileExpr(e.RHS) + rhs, err := b.compileExpr(ectx, e.RHS) if err != nil { return nil, err } @@ -148,7 +166,7 @@ func (b *Builder) compileBinary(e *dag.BinaryExpr) (expr.Evaluator, error) { } } -func (b *Builder) compileConstIn(e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileConstIn(r *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { literal, err := b.evalAtCompileTime(e.LHS) if err != nil || literal.IsError() { // If the RHS here is a literal value, it would be good @@ -161,14 +179,14 @@ func (b *Builder) compileConstIn(e 
*dag.BinaryExpr) (expr.Evaluator, error) { if eql == nil || err != nil { return nil, nil } - operand, err := b.compileExpr(e.RHS) + operand, err := b.compileExpr(r, e.RHS) if err != nil { return nil, err } return expr.NewFilter(operand, expr.Contains(eql)), nil } -func (b *Builder) compileConstCompare(e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileConstCompare(ectx *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { switch e.Op { case "==", "!=", "<", "<=", ">", ">=": default: @@ -185,19 +203,19 @@ func (b *Builder) compileConstCompare(e *dag.BinaryExpr) (expr.Evaluator, error) // non-error situation that isn't a simple comparison. return nil, nil } - operand, err := b.compileExpr(e.LHS) + operand, err := b.compileExpr(ectx, e.LHS) if err != nil { return nil, err } return expr.NewFilter(operand, comparison), nil } -func (b *Builder) compileSearch(search *dag.Search) (expr.Evaluator, error) { +func (b *Builder) compileSearch(ectx *exprContext, search *dag.Search) (expr.Evaluator, error) { val, err := zson.ParseValue(b.zctx(), search.Value) if err != nil { return nil, err } - e, err := b.compileExpr(search.Expr) + e, err := b.compileExpr(ectx, search.Expr) if err != nil { return nil, err } @@ -210,24 +228,24 @@ func (b *Builder) compileSearch(search *dag.Search) (expr.Evaluator, error) { return expr.NewSearch(search.Text, val, e) } -func (b *Builder) compileSlice(container dag.Expr, slice *dag.BinaryExpr) (expr.Evaluator, error) { - from, err := b.compileExprWithEmpty(slice.LHS) +func (b *Builder) compileSlice(ectx *exprContext, container dag.Expr, slice *dag.BinaryExpr) (expr.Evaluator, error) { + from, err := b.compileExprWithEmpty(ectx, slice.LHS) if err != nil { return nil, err } - to, err := b.compileExprWithEmpty(slice.RHS) + to, err := b.compileExprWithEmpty(ectx, slice.RHS) if err != nil { return nil, err } - e, err := b.compileExpr(container) + e, err := b.compileExpr(ectx, container) if err != nil { return nil, err } return 
expr.NewSlice(b.zctx(), e, from, to), nil } -func (b *Builder) compileUnary(unary dag.UnaryExpr) (expr.Evaluator, error) { - e, err := b.compileExpr(unary.Operand) +func (b *Builder) compileUnary(ectx *exprContext, unary dag.UnaryExpr) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, unary.Operand) if err != nil { return nil, err } @@ -241,48 +259,48 @@ func (b *Builder) compileUnary(unary dag.UnaryExpr) (expr.Evaluator, error) { } } -func (b *Builder) compileConditional(node dag.Conditional) (expr.Evaluator, error) { - predicate, err := b.compileExpr(node.Cond) +func (b *Builder) compileConditional(ectx *exprContext, node dag.Conditional) (expr.Evaluator, error) { + predicate, err := b.compileExpr(ectx, node.Cond) if err != nil { return nil, err } - thenExpr, err := b.compileExpr(node.Then) + thenExpr, err := b.compileExpr(ectx, node.Then) if err != nil { return nil, err } - elseExpr, err := b.compileExpr(node.Else) + elseExpr, err := b.compileExpr(ectx, node.Else) if err != nil { return nil, err } return expr.NewConditional(b.zctx(), predicate, thenExpr, elseExpr), nil } -func (b *Builder) compileDotExpr(dot *dag.Dot) (expr.Evaluator, error) { - record, err := b.compileExpr(dot.LHS) +func (b *Builder) compileDotExpr(ectx *exprContext, dot *dag.Dot) (expr.Evaluator, error) { + record, err := b.compileExpr(ectx, dot.LHS) if err != nil { return nil, err } return expr.NewDotExpr(b.zctx(), record, dot.RHS), nil } -func (b *Builder) compileLval(e dag.Expr) (*expr.Lval, error) { +func (b *Builder) compileLval(ectx *exprContext, e dag.Expr) (*expr.Lval, error) { switch e := e.(type) { case *dag.BinaryExpr: if e.Op != "[" { return nil, fmt.Errorf("internal error: invalid lval %#v", e) } - lhs, err := b.compileLval(e.LHS) + lhs, err := b.compileLval(ectx, e.LHS) if err != nil { return nil, err } - rhs, err := b.compileExpr(e.RHS) + rhs, err := b.compileExpr(ectx, e.RHS) if err != nil { return nil, err } lhs.Elems = append(lhs.Elems, expr.NewExprLvalElem(b.zctx(), 
rhs)) return lhs, nil case *dag.Dot: - lhs, err := b.compileLval(e.LHS) + lhs, err := b.compileLval(ectx, e.LHS) if err != nil { return nil, err } @@ -298,21 +316,21 @@ func (b *Builder) compileLval(e dag.Expr) (*expr.Lval, error) { return nil, fmt.Errorf("internal error: invalid lval %#v", e) } -func (b *Builder) compileAssignment(node *dag.Assignment) (expr.Assignment, error) { - lhs, err := b.compileLval(node.LHS) +func (b *Builder) compileAssignment(ectx *exprContext, node *dag.Assignment) (expr.Assignment, error) { + lhs, err := b.compileLval(ectx, node.LHS) if err != nil { return expr.Assignment{}, err } - rhs, err := b.compileExpr(node.RHS) + rhs, err := b.compileExpr(ectx, node.RHS) if err != nil { return expr.Assignment{}, fmt.Errorf("rhs of assigment expression: %w", err) } return expr.Assignment{LHS: lhs, RHS: rhs}, err } -func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { +func (b *Builder) compileCall(ectx *exprContext, call dag.Call) (expr.Evaluator, error) { if tf := expr.NewShaperTransform(call.Name); tf != 0 { - return b.compileShaper(call, tf) + return b.compileShaper(ectx, call, tf) } var path field.Path // First check if call is to a user defined function, otherwise check for @@ -330,42 +348,42 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { dagPath := &dag.This{Kind: "This", Path: path} args = append([]dag.Expr{dagPath}, args...) 
} - exprs, err := b.compileExprs(args) + exprs, err := b.compileExprs(ectx, args) if err != nil { return nil, fmt.Errorf("%s(): bad argument: %w", call.Name, err) } return expr.NewCall(b.zctx(), fn, exprs), nil } -func (b *Builder) compileMapCall(a *dag.MapCall) (expr.Evaluator, error) { - e, err := b.compileExpr(a.Expr) +func (b *Builder) compileMapCall(ectx *exprContext, a *dag.MapCall) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, a.Expr) if err != nil { return nil, err } - inner, err := b.compileExpr(a.Inner) + inner, err := b.compileExpr(ectx, a.Inner) if err != nil { return nil, err } return expr.NewMapCall(b.zctx(), e, inner), nil } -func (b *Builder) compileShaper(node dag.Call, tf expr.ShaperTransform) (expr.Evaluator, error) { +func (b *Builder) compileShaper(ectx *exprContext, node dag.Call, tf expr.ShaperTransform) (expr.Evaluator, error) { args := node.Args - field, err := b.compileExpr(args[0]) + field, err := b.compileExpr(ectx, args[0]) if err != nil { return nil, err } - typExpr, err := b.compileExpr(args[1]) + typExpr, err := b.compileExpr(ectx, args[1]) if err != nil { return nil, err } return expr.NewShaper(b.zctx(), field, typExpr, tf) } -func (b *Builder) compileExprs(in []dag.Expr) ([]expr.Evaluator, error) { +func (b *Builder) compileExprs(ectx *exprContext, in []dag.Expr) ([]expr.Evaluator, error) { var exprs []expr.Evaluator for _, e := range in { - ev, err := b.compileExpr(e) + ev, err := b.compileExpr(ectx, e) if err != nil { return nil, err } @@ -374,8 +392,8 @@ func (b *Builder) compileExprs(in []dag.Expr) ([]expr.Evaluator, error) { return exprs, nil } -func (b *Builder) compileRegexpMatch(match *dag.RegexpMatch) (expr.Evaluator, error) { - e, err := b.compileExpr(match.Expr) +func (b *Builder) compileRegexpMatch(ectx *exprContext, match *dag.RegexpMatch) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, match.Expr) if err != nil { return nil, err } @@ -386,8 +404,8 @@ func (b *Builder) compileRegexpMatch(match 
*dag.RegexpMatch) (expr.Evaluator, er return expr.NewRegexpMatch(re, e), nil } -func (b *Builder) compileRegexpSearch(search *dag.RegexpSearch) (expr.Evaluator, error) { - e, err := b.compileExpr(search.Expr) +func (b *Builder) compileRegexpSearch(ectx *exprContext, search *dag.RegexpSearch) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, search.Expr) if err != nil { return nil, err } @@ -399,12 +417,12 @@ func (b *Builder) compileRegexpSearch(search *dag.RegexpSearch) (expr.Evaluator, return expr.SearchByPredicate(expr.Contains(match), e), nil } -func (b *Builder) compileRecordExpr(record *dag.RecordExpr) (expr.Evaluator, error) { +func (b *Builder) compileRecordExpr(ectx *exprContext, record *dag.RecordExpr) (expr.Evaluator, error) { var elems []expr.RecordElem for _, elem := range record.Elems { switch elem := elem.(type) { case *dag.Field: - e, err := b.compileExpr(elem.Value) + e, err := b.compileExpr(ectx, elem.Value) if err != nil { return nil, err } @@ -413,7 +431,7 @@ func (b *Builder) compileRecordExpr(record *dag.RecordExpr) (expr.Evaluator, err Field: e, }) case *dag.Spread: - e, err := b.compileExpr(elem.Expr) + e, err := b.compileExpr(ectx, elem.Expr) if err != nil { return nil, err } @@ -423,34 +441,34 @@ func (b *Builder) compileRecordExpr(record *dag.RecordExpr) (expr.Evaluator, err return expr.NewRecordExpr(b.zctx(), elems) } -func (b *Builder) compileArrayExpr(array *dag.ArrayExpr) (expr.Evaluator, error) { - elems, err := b.compileVectorElems(array.Elems) +func (b *Builder) compileArrayExpr(ectx *exprContext, array *dag.ArrayExpr) (expr.Evaluator, error) { + elems, err := b.compileVectorElems(ectx, array.Elems) if err != nil { return nil, err } return expr.NewArrayExpr(b.zctx(), elems), nil } -func (b *Builder) compileSetExpr(set *dag.SetExpr) (expr.Evaluator, error) { - elems, err := b.compileVectorElems(set.Elems) +func (b *Builder) compileSetExpr(ectx *exprContext, set *dag.SetExpr) (expr.Evaluator, error) { + elems, err := 
b.compileVectorElems(ectx, set.Elems) if err != nil { return nil, err } return expr.NewSetExpr(b.zctx(), elems), nil } -func (b *Builder) compileVectorElems(elems []dag.VectorElem) ([]expr.VectorElem, error) { +func (b *Builder) compileVectorElems(ectx *exprContext, elems []dag.VectorElem) ([]expr.VectorElem, error) { var out []expr.VectorElem for _, elem := range elems { switch elem := elem.(type) { case *dag.Spread: - e, err := b.compileExpr(elem.Expr) + e, err := b.compileExpr(ectx, elem.Expr) if err != nil { return nil, err } out = append(out, expr.VectorElem{Spread: e}) case *dag.VectorValue: - e, err := b.compileExpr(elem.Expr) + e, err := b.compileExpr(ectx, elem.Expr) if err != nil { return nil, err } @@ -460,14 +478,14 @@ func (b *Builder) compileVectorElems(elems []dag.VectorElem) ([]expr.VectorElem, return out, nil } -func (b *Builder) compileMapExpr(m *dag.MapExpr) (expr.Evaluator, error) { +func (b *Builder) compileMapExpr(ectx *exprContext, m *dag.MapExpr) (expr.Evaluator, error) { var entries []expr.Entry for _, f := range m.Entries { - key, err := b.compileExpr(f.Key) + key, err := b.compileExpr(ectx, f.Key) if err != nil { return nil, err } - val, err := b.compileExpr(f.Value) + val, err := b.compileExpr(ectx, f.Value) if err != nil { return nil, err } @@ -480,16 +498,17 @@ func (b *Builder) compileOverExpr(over *dag.OverExpr) (expr.Evaluator, error) { if over.Body == nil { return nil, errors.New("over expression requires a lateral query body") } - names, lets, err := b.compileDefs(over.Defs) + var ectx exprContext + names, lets, err := b.compileDefs(&ectx, over.Defs) if err != nil { return nil, err } - exprs, err := b.compileExprs(over.Exprs) + exprs, err := b.compileExprs(&ectx, over.Exprs) if err != nil { return nil, err } parent := traverse.NewExpr(b.rctx.Context, b.zctx()) - enter := traverse.NewOver(b.rctx, parent, exprs) + enter := traverse.NewOver(b.rctx, parent, exprs, expr.NopResetter) scope := enter.AddScope(b.rctx.Context, names, lets) 
exits, err := b.compileSeq(over.Body, []zbuf.Puller{scope}) if err != nil { diff --git a/compiler/kernel/filter.go b/compiler/kernel/filter.go index 51b50a630c..75b4c65868 100644 --- a/compiler/kernel/filter.go +++ b/compiler/kernel/filter.go @@ -17,7 +17,7 @@ func (f *Filter) AsEvaluator() (expr.Evaluator, error) { if f == nil { return nil, nil } - return f.builder.compileExpr(f.pushdown) + return f.builder.compileExpr(nil, f.pushdown) } func (f *Filter) AsBufferFilter() (*expr.BufferFilter, error) { @@ -39,7 +39,7 @@ func (f *DeleteFilter) AsEvaluator() (expr.Evaluator, error) { // expression so we get all values that don't match. We also add a missing // call so if the expression results in an error("missing") the value is // kept. - return f.builder.compileExpr(&dag.BinaryExpr{ + return f.builder.compileExpr(nil, &dag.BinaryExpr{ Kind: "BinaryExpr", Op: "or", LHS: &dag.UnaryExpr{ diff --git a/compiler/kernel/groupby.go b/compiler/kernel/groupby.go index 45959766b1..0e66bed0a4 100644 --- a/compiler/kernel/groupby.go +++ b/compiler/kernel/groupby.go @@ -13,23 +13,24 @@ import ( ) func (b *Builder) compileGroupBy(parent zbuf.Puller, summarize *dag.Summarize) (*groupby.Op, error) { - keys, err := b.compileAssignments(summarize.Keys) + var ectx exprContext + keys, err := b.compileAssignments(&ectx, summarize.Keys) if err != nil { return nil, err } - names, reducers, err := b.compileAggAssignments(summarize.Aggs) + names, reducers, err := b.compileAggAssignments(&ectx, summarize.Aggs) if err != nil { return nil, err } dir := order.Direction(summarize.InputSortDir) - return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut) + return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut, ectx.resetters) } -func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) { +func (b *Builder) 
compileAggAssignments(ectx *exprContext, assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) { names := make(field.List, 0, len(assignments)) aggs := make([]*expr.Aggregator, 0, len(assignments)) for _, assignment := range assignments { - name, agg, err := b.compileAggAssignment(assignment) + name, agg, err := b.compileAggAssignment(ectx, assignment) if err != nil { return nil, nil, err } @@ -39,7 +40,7 @@ func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.Lis return names, aggs, nil } -func (b *Builder) compileAggAssignment(assignment dag.Assignment) (field.Path, *expr.Aggregator, error) { +func (b *Builder) compileAggAssignment(ectx *exprContext, assignment dag.Assignment) (field.Path, *expr.Aggregator, error) { aggAST, ok := assignment.RHS.(*dag.Agg) if !ok { return nil, nil, errors.New("aggregator is not an aggregation expression") @@ -48,23 +49,23 @@ func (b *Builder) compileAggAssignment(assignment dag.Assignment) (field.Path, * if !ok { return nil, nil, fmt.Errorf("internal error: aggregator assignment LHS is not a static path: %#v", assignment.LHS) } - m, err := b.compileAgg(aggAST) + m, err := b.compileAgg(ectx, aggAST) return this.Path, m, err } -func (b *Builder) compileAgg(agg *dag.Agg) (*expr.Aggregator, error) { +func (b *Builder) compileAgg(ectx *exprContext, agg *dag.Agg) (*expr.Aggregator, error) { name := agg.Name var err error var arg expr.Evaluator if agg.Expr != nil { - arg, err = b.compileExpr(agg.Expr) + arg, err = b.compileExpr(ectx, agg.Expr) if err != nil { return nil, err } } var where expr.Evaluator if agg.Where != nil { - where, err = b.compileExpr(agg.Where) + where, err = b.compileExpr(ectx, agg.Where) if err != nil { return nil, err } diff --git a/compiler/kernel/op.go b/compiler/kernel/op.go index 7613b1ec49..7acabbc998 100644 --- a/compiler/kernel/op.go +++ b/compiler/kernel/op.go @@ -125,7 +125,8 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) case 
*dag.Summarize: return b.compileGroupBy(parent, v) case *dag.Cut: - assignments, err := b.compileAssignments(v.Args) + var ectx exprContext + assignments, err := b.compileAssignments(&ectx, v.Args) if err != nil { return nil, err } @@ -134,7 +135,7 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) if v.Quiet { cutter.Quiet() } - return op.NewApplier(b.rctx, parent, cutter), nil + return op.NewApplier(b.rctx, parent, cutter, ectx.resetters), nil case *dag.Drop: if len(v.Args) == 0 { return nil, errors.New("drop: no fields given") @@ -148,13 +149,14 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) fields = append(fields, field.Path) } dropper := expr.NewDropper(b.rctx.Zctx, fields) - return op.NewApplier(b.rctx, parent, dropper), nil + return op.NewApplier(b.rctx, parent, dropper, expr.NopResetter), nil case *dag.Sort: - fields, err := b.compileExprs(v.Args) + var ectx exprContext + fields, err := b.compileExprs(&ectx, v.Args) if err != nil { return nil, err } - sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst) + sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst, ectx.resetters) if err != nil { return nil, fmt.Errorf("compiling sort: %w", err) } @@ -176,32 +178,36 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) case *dag.Pass: return pass.New(parent), nil case *dag.Filter: - f, err := b.compileExpr(v.Expr) + var ectx exprContext + f, err := b.compileExpr(&ectx, v.Expr) if err != nil { return nil, fmt.Errorf("compiling filter: %w", err) } - return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f)), nil + return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f), ectx.resetters), nil case *dag.Top: - fields, err := b.compileExprs(v.Args) + var ectx exprContext + fields, err := b.compileExprs(&ectx, v.Args) if err != nil { return nil, fmt.Errorf("compiling top: %w", err) } - return 
top.New(b.rctx.Zctx, parent, v.Limit, fields, v.Flush), nil + return top.New(b.rctx.Zctx, parent, v.Limit, fields, ectx.resetters, v.Flush), nil case *dag.Put: - clauses, err := b.compileAssignments(v.Args) + var ectx exprContext + clauses, err := b.compileAssignments(&ectx, v.Args) if err != nil { return nil, err } putter := expr.NewPutter(b.rctx.Zctx, clauses) - return op.NewApplier(b.rctx, parent, putter), nil + return op.NewApplier(b.rctx, parent, putter, ectx.resetters), nil case *dag.Rename: + var ectx exprContext var srcs, dsts []*expr.Lval for _, a := range v.Args { - src, err := b.compileLval(a.RHS) + src, err := b.compileLval(&ectx, a.RHS) if err != nil { return nil, err } - dst, err := b.compileLval(a.LHS) + dst, err := b.compileLval(&ectx, a.LHS) if err != nil { return nil, err } @@ -209,7 +215,7 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) dsts = append(dsts, dst) } renamer := expr.NewRenamer(b.rctx.Zctx, srcs, dsts) - return op.NewApplier(b.rctx, parent, renamer), nil + return op.NewApplier(b.rctx, parent, renamer, ectx.resetters), nil case *dag.Fuse: return fuse.New(b.rctx, parent) case *dag.Shape: @@ -223,19 +229,21 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) if err != nil { return nil, err } - args, err := b.compileExprs(v.Args) + var ectx exprContext + args, err := b.compileExprs(&ectx, v.Args) if err != nil { return nil, err } - return explode.New(b.rctx.Zctx, parent, args, typ, v.As) + return explode.New(b.rctx.Zctx, parent, args, ectx.resetters, typ, v.As) case *dag.Over: return b.compileOver(parent, v) case *dag.Yield: - exprs, err := b.compileExprs(v.Exprs) + var ectx exprContext + exprs, err := b.compileExprs(&ectx, v.Exprs) if err != nil { return nil, err } - t := yield.New(parent, exprs) + t := yield.New(parent, exprs, ectx.resetters) return t, nil case *dag.PoolScan: if parent != nil { @@ -352,11 +360,11 @@ func (b *Builder) compileLeaf(o dag.Op, parent 
zbuf.Puller) (zbuf.Puller, error) } } -func (b *Builder) compileDefs(defs []dag.Def) ([]string, []expr.Evaluator, error) { +func (b *Builder) compileDefs(ectx *exprContext, defs []dag.Def) ([]string, []expr.Evaluator, error) { exprs := make([]expr.Evaluator, 0, len(defs)) names := make([]string, 0, len(defs)) for _, def := range defs { - e, err := b.compileExpr(def.Expr) + e, err := b.compileExpr(ectx, def.Expr) if err != nil { return nil, nil, err } @@ -370,15 +378,16 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller, if len(over.Defs) != 0 && over.Body == nil { return nil, errors.New("internal error: over operator has defs but no body") } - withNames, withExprs, err := b.compileDefs(over.Defs) + var ectx exprContext + withNames, withExprs, err := b.compileDefs(&ectx, over.Defs) if err != nil { return nil, err } - exprs, err := b.compileExprs(over.Exprs) + exprs, err := b.compileExprs(&ectx, over.Exprs) if err != nil { return nil, err } - enter := traverse.NewOver(b.rctx, parent, exprs) + enter := traverse.NewOver(b.rctx, parent, exprs, ectx.resetters) if over.Body == nil { return enter, nil } @@ -398,10 +407,10 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller, return scope.NewExit(exit), nil } -func (b *Builder) compileAssignments(assignments []dag.Assignment) ([]expr.Assignment, error) { +func (b *Builder) compileAssignments(ectx *exprContext, assignments []dag.Assignment) ([]expr.Assignment, error) { keys := make([]expr.Assignment, 0, len(assignments)) for _, assignment := range assignments { - a, err := b.compileAssignment(&assignment) + a, err := b.compileAssignment(ectx, &assignment) if err != nil { return nil, err } @@ -433,7 +442,11 @@ func (b *Builder) compileSeq(seq dag.Seq, parents []zbuf.Puller) ([]zbuf.Puller, } func (b *Builder) compileScope(scope *dag.Scope, parents []zbuf.Puller) ([]zbuf.Puller, error) { - if err := b.compileFuncs(scope.Funcs); err != nil { + // XXX We need to fix how 
udfs are compiled since there is currently a bug + // where aggregation expressions in udfs do not have separate state per + // invocation. The fix for this might use exprContext to compile udf + // expressions per invocation. + if err := b.compileFuncs(&exprContext{}, scope.Funcs); err != nil { return nil, err } return b.compileSeq(scope.Body, parents) @@ -481,7 +494,7 @@ func (b *Builder) compileScatter(par *dag.Scatter, parents []zbuf.Puller) ([]zbu return ops, nil } -func (b *Builder) compileFuncs(fns []*dag.Func) error { +func (b *Builder) compileFuncs(ectx *exprContext, fns []*dag.Func) error { udfs := make([]*expr.UDF, 0, len(fns)) for _, f := range fns { if _, ok := b.funcs[f.Name]; ok { @@ -493,7 +506,7 @@ func (b *Builder) compileFuncs(fns []*dag.Func) error { } for i := range fns { var err error - if udfs[i].Body, err = b.compileExpr(fns[i].Expr); err != nil { + if udfs[i].Body, err = b.compileExpr(ectx, fns[i].Expr); err != nil { return err } } @@ -505,11 +518,12 @@ func (b *Builder) compileExprSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([ if len(parents) > 1 { parent = combine.New(b.rctx, parents) } - e, err := b.compileExpr(swtch.Expr) + var ectx exprContext + e, err := b.compileExpr(&ectx, swtch.Expr) if err != nil { return nil, err } - s := exprswitch.New(b.rctx, parent, e) + s := exprswitch.New(b.rctx, parent, e, ectx.resetters) var exits []zbuf.Puller for _, c := range swtch.Cases { var val *zed.Value @@ -537,20 +551,19 @@ func (b *Builder) compileSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([]zbu if len(parents) > 1 { parent = combine.New(b.rctx, parents) } - n := len(swtch.Cases) - switcher := switcher.New(b.rctx, parent) - parents = []zbuf.Puller{} - for _, c := range swtch.Cases { - f, err := b.compileExpr(c.Expr) + var ectx exprContext + cases := make([]expr.Evaluator, len(swtch.Cases)) + for i, c := range swtch.Cases { + var err error + cases[i], err = b.compileExpr(&ectx, c.Expr) if err != nil { return nil, fmt.Errorf("compiling 
switch case filter: %w", err) } - sc := switcher.AddCase(f) - parents = append(parents, sc) } + switcher := switcher.New(b.rctx, parent, ectx.resetters) var ops []zbuf.Puller - for k := 0; k < n; k++ { - o, err := b.compileSeq(swtch.Cases[k].Path, []zbuf.Puller{parents[k]}) + for i, c := range cases { + o, err := b.compileSeq(swtch.Cases[i].Path, []zbuf.Puller{switcher.AddCase(c)}) if err != nil { return nil, err } @@ -578,16 +591,17 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error if len(parents) != 2 { return nil, ErrJoinParents } - assignments, err := b.compileAssignments(o.Args) + var ectx exprContext + assignments, err := b.compileAssignments(&ectx, o.Args) if err != nil { return nil, err } lhs, rhs := splitAssignments(assignments) - leftKey, err := b.compileExpr(o.LeftKey) + leftKey, err := b.compileExpr(&ectx, o.LeftKey) if err != nil { return nil, err } - rightKey, err := b.compileExpr(o.RightKey) + rightKey, err := b.compileExpr(&ectx, o.RightKey) if err != nil { return nil, err } @@ -607,18 +621,19 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error default: return nil, fmt.Errorf("unknown kind of join: '%s'", o.Style) } - join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs) + join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs, ectx.resetters) if err != nil { return nil, err } return []zbuf.Puller{join}, nil case *dag.Merge: - e, err := b.compileExpr(o.Expr) + var ectx exprContext + e, err := b.compileExpr(&ectx, o.Expr) if err != nil { return nil, err } cmp := expr.NewComparator(true, o.Order == order.Desc, e).WithMissingAsNull() - return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare)}, nil + return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare, ectx.resetters)}, nil case *dag.Combine: return []zbuf.Puller{combine.New(b.rctx, parents)}, nil default: 
@@ -671,7 +686,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) { if in == nil { return zed.Null, nil } - e, err := b.compileExpr(in) + e, err := b.compileExpr(nil, in) if err != nil { return zed.Null, err } @@ -687,7 +702,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) { func compileExpr(in dag.Expr) (expr.Evaluator, error) { b := NewBuilder(runtime.NewContext(context.Background(), zed.NewContext()), nil) - return b.compileExpr(in) + return b.compileExpr(nil, in) } func EvalAtCompileTime(zctx *zed.Context, in dag.Expr) (val zed.Value, err error) { diff --git a/runtime/sam/expr/agg.go b/runtime/sam/expr/agg.go index f89558e50a..21fa90faeb 100644 --- a/runtime/sam/expr/agg.go +++ b/runtime/sam/expr/agg.go @@ -5,6 +5,24 @@ import ( "github.com/brimdata/zed/runtime/sam/expr/agg" ) +type Resetter interface { + Reset() +} + +type Resetters []Resetter + +func (rs Resetters) Reset() { + for _, r := range rs { + r.Reset() + } +} + +var NopResetter = nopResetter{} + +type nopResetter struct{} + +func (nopResetter) Reset() {} + type Aggregator struct { pattern agg.Pattern expr Evaluator @@ -48,22 +66,27 @@ func (a *Aggregator) Apply(zctx *zed.Context, ectx Context, f agg.Function, this // NewAggregatorExpr returns an Evaluator from agg. The returned Evaluator // retains the same functionality of the aggregation only it returns it's // current state every time a new value is consumed. 
-func NewAggregatorExpr(zctx *zed.Context, agg *Aggregator) Evaluator { - return &aggregatorExpr{agg: agg, zctx: zctx} +func NewAggregatorExpr(zctx *zed.Context, agg *Aggregator) *AggregatorExpr { + return &AggregatorExpr{agg: agg, zctx: zctx} } -type aggregatorExpr struct { +type AggregatorExpr struct { agg *Aggregator fn agg.Function zctx *zed.Context } -var _ Evaluator = (*aggregatorExpr)(nil) +var ( + _ Evaluator = (*AggregatorExpr)(nil) + _ Resetter = (*AggregatorExpr)(nil) +) -func (s *aggregatorExpr) Eval(ectx Context, val zed.Value) zed.Value { +func (s *AggregatorExpr) Eval(ectx Context, val zed.Value) zed.Value { if s.fn == nil { s.fn = s.agg.NewFunction() } s.agg.Apply(s.zctx, ectx, s.fn, val) return s.fn.Result(s.zctx) } + +func (s *AggregatorExpr) Reset() { s.fn = nil } diff --git a/runtime/sam/op/apply.go b/runtime/sam/op/apply.go index 43e0736bcc..c31056d91a 100644 --- a/runtime/sam/op/apply.go +++ b/runtime/sam/op/apply.go @@ -8,16 +8,18 @@ import ( ) type applier struct { - rctx *runtime.Context - parent zbuf.Puller - expr expr.Evaluator + rctx *runtime.Context + parent zbuf.Puller + expr expr.Evaluator + resetter expr.Resetter } -func NewApplier(rctx *runtime.Context, parent zbuf.Puller, expr expr.Evaluator) *applier { +func NewApplier(rctx *runtime.Context, parent zbuf.Puller, expr expr.Evaluator, r expr.Resetter) *applier { return &applier{ - rctx: rctx, - parent: parent, - expr: expr, + rctx: rctx, + parent: parent, + expr: expr, + resetter: r, } } @@ -25,6 +27,7 @@ func (a *applier) Pull(done bool) (zbuf.Batch, error) { for { batch, err := a.parent.Pull(done) if batch == nil || err != nil { + a.resetter.Reset() return nil, err } vals := batch.Values() diff --git a/runtime/sam/op/explode/explode.go b/runtime/sam/op/explode/explode.go index 03c4c6678a..b97c44f01c 100644 --- a/runtime/sam/op/explode/explode.go +++ b/runtime/sam/op/explode/explode.go @@ -11,20 +11,22 @@ import ( // zng type T, outputs one record for each field of the input record 
of // type T. It is useful for type-based indexing. type Op struct { - parent zbuf.Puller - outType zed.Type - typ zed.Type - args []expr.Evaluator + parent zbuf.Puller + outType zed.Type + typ zed.Type + args []expr.Evaluator + resetter expr.Resetter } // New creates a exploder for type typ, where the // output records' single field is named name. -func New(zctx *zed.Context, parent zbuf.Puller, args []expr.Evaluator, typ zed.Type, name string) (zbuf.Puller, error) { +func New(zctx *zed.Context, parent zbuf.Puller, args []expr.Evaluator, resetter expr.Resetter, typ zed.Type, name string) (zbuf.Puller, error) { return &Op{ - parent: parent, - outType: zctx.MustLookupTypeRecord([]zed.Field{{Name: name, Type: typ}}), - typ: typ, - args: args, + parent: parent, + outType: zctx.MustLookupTypeRecord([]zed.Field{{Name: name, Type: typ}}), + typ: typ, + args: args, + resetter: resetter, }, nil } @@ -32,6 +34,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { for { batch, err := o.parent.Pull(done) if batch == nil || err != nil { + o.resetter.Reset() return nil, err } vals := batch.Values() diff --git a/runtime/sam/op/exprswitch/exprswitch.go b/runtime/sam/op/exprswitch/exprswitch.go index 1000b713d4..8bcb07eb56 100644 --- a/runtime/sam/op/exprswitch/exprswitch.go +++ b/runtime/sam/op/exprswitch/exprswitch.go @@ -10,6 +10,7 @@ import ( type ExprSwitch struct { *op.Router + expr.Resetter expr expr.Evaluator cases map[string]*switchCase defaultCase *switchCase @@ -22,12 +23,13 @@ type switchCase struct { vals []zed.Value } -func New(rctx *runtime.Context, parent zbuf.Puller, e expr.Evaluator) *ExprSwitch { +func New(rctx *runtime.Context, parent zbuf.Puller, e expr.Evaluator, r expr.Resetter) *ExprSwitch { router := op.NewRouter(rctx, parent) s := &ExprSwitch{ - Router: router, - expr: e, - cases: make(map[string]*switchCase), + Router: router, + Resetter: r, + expr: e, + cases: make(map[string]*switchCase), } router.Link(s) return s diff --git 
a/runtime/sam/op/groupby/groupby.go b/runtime/sam/op/groupby/groupby.go index 8a7d3fd882..c90a145a6e 100644 --- a/runtime/sam/op/groupby/groupby.go +++ b/runtime/sam/op/groupby/groupby.go @@ -28,6 +28,7 @@ type Op struct { resultCh chan op.Result doneCh chan struct{} batch zbuf.Batch + resetter expr.Resetter } // Aggregator performs the core aggregation computation for a @@ -110,7 +111,7 @@ func NewAggregator(ctx context.Context, zctx *zed.Context, keyRefs, keyExprs, ag }, nil } -func New(rctx *runtime.Context, parent zbuf.Puller, keys []expr.Assignment, aggNames field.List, aggs []*expr.Aggregator, limit int, inputSortDir order.Direction, partialsIn, partialsOut bool) (*Op, error) { +func New(rctx *runtime.Context, parent zbuf.Puller, keys []expr.Assignment, aggNames field.List, aggs []*expr.Aggregator, limit int, inputSortDir order.Direction, partialsIn, partialsOut bool, resetter expr.Resetter) (*Op, error) { names := make(field.List, 0, len(keys)+len(aggNames)) for _, e := range keys { p, ok := e.LHS.Path() @@ -144,6 +145,7 @@ func New(rctx *runtime.Context, parent zbuf.Puller, keys []expr.Assignment, aggN agg: agg, resultCh: make(chan op.Result), doneCh: make(chan struct{}), + resetter: resetter, }, nil } @@ -249,6 +251,10 @@ func (o *Op) run() { } func (o *Op) sendResult(b zbuf.Batch, err error) (bool, bool) { + if b == nil { + // Reset stateful aggregation expression on EOS. + o.resetter.Reset() + } select { case o.resultCh <- op.Result{Batch: b, Err: err}: return false, true @@ -288,6 +294,7 @@ func (o *Op) reset() { o.batch.Unref() o.batch = nil } + o.resetter.Reset() } // Consume adds a value to an aggregation. 
diff --git a/runtime/sam/op/join/join.go b/runtime/sam/op/join/join.go index 066511952f..58d1cba798 100644 --- a/runtime/sam/op/join/join.go +++ b/runtime/sam/op/join/join.go @@ -30,11 +30,12 @@ type Op struct { joinKey *zed.Value joinSet []zed.Value types map[int]map[int]*zed.TypeRecord + resetter expr.Resetter } func New(rctx *runtime.Context, anti, inner bool, left, right zbuf.Puller, leftKey, rightKey expr.Evaluator, leftDir, rightDir order.Direction, lhs []*expr.Lval, - rhs []expr.Evaluator) (*Op, error) { + rhs []expr.Evaluator, resetter expr.Resetter) (*Op, error) { var o order.Which switch { case leftDir != order.Unknown: @@ -45,13 +46,13 @@ func New(rctx *runtime.Context, anti, inner bool, left, right zbuf.Puller, leftK var err error // Add sorts if needed. if !leftDir.HasOrder(o) { - left, err = sort.New(rctx, left, []expr.Evaluator{leftKey}, o, false) + left, err = sort.New(rctx, left, []expr.Evaluator{leftKey}, o, false, resetter) if err != nil { return nil, err } } if !rightDir.HasOrder(o) { - right, err = sort.New(rctx, right, []expr.Evaluator{rightKey}, o, false) + right, err = sort.New(rctx, right, []expr.Evaluator{rightKey}, o, false, resetter) if err != nil { return nil, err } @@ -70,6 +71,7 @@ func New(rctx *runtime.Context, anti, inner bool, left, right zbuf.Puller, leftK compare: expr.NewValueCompareFn(o, true), cutter: expr.NewCutter(rctx.Zctx, lhs, rhs), types: make(map[int]map[int]*zed.TypeRecord), + resetter: resetter, }, nil } @@ -90,6 +92,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { } if leftRec == nil { if len(out) == 0 { + o.resetter.Reset() return nil, nil } //XXX See issue #3427. 
diff --git a/runtime/sam/op/merge/merge.go b/runtime/sam/op/merge/merge.go index 819597b6ee..258e8ad343 100644 --- a/runtime/sam/op/merge/merge.go +++ b/runtime/sam/op/merge/merge.go @@ -27,21 +27,23 @@ type Op struct { // The head-of-line (hol) queue is maintained as a min-heap on cmp of // hol.vals[0] (see Less) so that the next Read always returns // hol[0].vals[0]. - hol []*puller + hol []*puller + resetter expr.Resetter } var _ zbuf.Puller = (*Op)(nil) var _ zio.Reader = (*Op)(nil) -func New(ctx context.Context, parents []zbuf.Puller, cmp expr.CompareFn) *Op { +func New(ctx context.Context, parents []zbuf.Puller, cmp expr.CompareFn, r expr.Resetter) *Op { pullers := make([]*puller, 0, len(parents)) for _, p := range parents { pullers = append(pullers, newPuller(ctx, p)) } return &Op{ - ctx: ctx, - cmp: cmp, - parents: pullers, + ctx: ctx, + cmp: cmp, + parents: pullers, + resetter: r, } } @@ -118,6 +120,7 @@ func (o *Op) run() error { // each parent, e.g., a parent may be immediately blocked because it has // no data at (re)start and should not be re-entered into the HOL queue. 
func (o *Op) start() error { + o.resetter.Reset() o.hol = o.hol[:0] for _, parent := range o.parents { parent.blocked = false diff --git a/runtime/sam/op/merge/merge_test.go b/runtime/sam/op/merge/merge_test.go index a771a91070..d87ff2a618 100644 --- a/runtime/sam/op/merge/merge_test.go +++ b/runtime/sam/op/merge/merge_test.go @@ -9,6 +9,7 @@ import ( "github.com/brimdata/zed" "github.com/brimdata/zed/order" "github.com/brimdata/zed/pkg/field" + "github.com/brimdata/zed/runtime/sam/expr" "github.com/brimdata/zed/runtime/sam/op/merge" "github.com/brimdata/zed/zbuf" "github.com/brimdata/zed/zio" @@ -102,7 +103,7 @@ func TestParallelOrder(t *testing.T) { } sortKey := order.NewSortKey(c.order, field.DottedList(c.field)) cmp := zbuf.NewComparator(zctx, sortKey).Compare - om := merge.New(context.Background(), parents, cmp) + om := merge.New(context.Background(), parents, cmp, expr.NopResetter) var sb strings.Builder err := zbuf.CopyPuller(zsonio.NewWriter(zio.NopCloser(&sb), zsonio.WriterOpts{}), om) diff --git a/runtime/sam/op/meta/sequence.go b/runtime/sam/op/meta/sequence.go index 217536ea4c..8fad6bbf07 100644 --- a/runtime/sam/op/meta/sequence.go +++ b/runtime/sam/op/meta/sequence.go @@ -197,7 +197,7 @@ func newObjectsScanner(ctx context.Context, zctx *zed.Context, pool *lake.Pool, if len(pullers) == 1 { return pullers[0], nil } - return merge.New(ctx, pullers, lake.ImportComparator(zctx, pool).Compare), nil + return merge.New(ctx, pullers, lake.ImportComparator(zctx, pool).Compare, expr.NopResetter), nil } func newObjectScanner(ctx context.Context, zctx *zed.Context, pool *lake.Pool, object *data.Object, ranges []seekindex.Range, filter zbuf.Filter, progress *zbuf.Progress) (zbuf.Puller, error) { diff --git a/runtime/sam/op/router.go b/runtime/sam/op/router.go index b695f9562f..30f3aafa2d 100644 --- a/runtime/sam/op/router.go +++ b/runtime/sam/op/router.go @@ -5,6 +5,7 @@ import ( "slices" "sync" + "github.com/brimdata/zed/runtime/sam/expr" 
"github.com/brimdata/zed/zbuf" ) @@ -93,6 +94,11 @@ func (r *Router) blocked() bool { // after receiving the EOS, it's done will be captured as soon as we unblock // all channels. func (r *Router) sendEOS(err error) bool { + defer func() { + if r, ok := r.selector.(expr.Resetter); ok { + r.Reset() + } + }() // First, we need to send EOS to all non-blocked legs and // catch any dones in progress. This result in all routes // being blocked. diff --git a/runtime/sam/op/sort/sort.go b/runtime/sam/op/sort/sort.go index e1db24c5a1..78b47e243e 100644 --- a/runtime/sam/op/sort/sort.go +++ b/runtime/sam/op/sort/sort.go @@ -28,9 +28,10 @@ type Op struct { once sync.Once resultCh chan op.Result comparator *expr.Comparator + resetter expr.Resetter } -func New(rctx *runtime.Context, parent zbuf.Puller, fields []expr.Evaluator, order order.Which, nullsFirst bool) (*Op, error) { +func New(rctx *runtime.Context, parent zbuf.Puller, fields []expr.Evaluator, order order.Which, nullsFirst bool, r expr.Resetter) (*Op, error) { return &Op{ rctx: rctx, parent: parent, @@ -38,6 +39,7 @@ func New(rctx *runtime.Context, parent zbuf.Puller, fields []expr.Evaluator, ord nullsFirst: nullsFirst, fieldResolvers: fields, resultCh: make(chan op.Result), + resetter: r, }, nil } @@ -170,6 +172,10 @@ func (o *Op) sendSpills(spiller *spill.MergeSort) bool { } func (o *Op) sendResult(b zbuf.Batch, err error) bool { + if b == nil && err == nil { + // Reset Evaluators as EOS + o.resetter.Reset() + } select { case o.resultCh <- op.Result{Batch: b, Err: err}: return true diff --git a/runtime/sam/op/switcher/switch.go b/runtime/sam/op/switcher/switch.go index e32d4ab980..eacd9c1a7f 100644 --- a/runtime/sam/op/switcher/switch.go +++ b/runtime/sam/op/switcher/switch.go @@ -10,6 +10,7 @@ import ( type Selector struct { *op.Router + expr.Resetter cases []*switchCase } @@ -21,10 +22,11 @@ type switchCase struct { vals []zed.Value } -func New(rctx *runtime.Context, parent zbuf.Puller) *Selector { +func New(rctx 
*runtime.Context, parent zbuf.Puller, r expr.Resetter) *Selector { router := op.NewRouter(rctx, parent) s := &Selector{ - Router: router, + Router: router, + Resetter: r, } router.Link(s) return s diff --git a/runtime/sam/op/top/top.go b/runtime/sam/op/top/top.go index f7ed00986d..325a905ed9 100644 --- a/runtime/sam/op/top/top.go +++ b/runtime/sam/op/top/top.go @@ -21,12 +21,13 @@ type Op struct { zctx *zed.Context limit int fields []expr.Evaluator + resetter expr.Resetter records *expr.RecordSlice compare expr.CompareFn flushEvery bool } -func New(zctx *zed.Context, parent zbuf.Puller, limit int, fields []expr.Evaluator, flushEvery bool) *Op { +func New(zctx *zed.Context, parent zbuf.Puller, limit int, fields []expr.Evaluator, r expr.Resetter, flushEvery bool) *Op { if limit == 0 { limit = defaultTopLimit } @@ -35,6 +36,7 @@ func New(zctx *zed.Context, parent zbuf.Puller, limit int, fields []expr.Evaluat limit: limit, fields: fields, flushEvery: flushEvery, + resetter: r, } } @@ -45,6 +47,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { return nil, err } if batch == nil { + defer o.resetter.Reset() return o.sorted(), nil } vals := batch.Values() diff --git a/runtime/sam/op/traverse/over.go b/runtime/sam/op/traverse/over.go index e0673ca0e2..096eaef8eb 100644 --- a/runtime/sam/op/traverse/over.go +++ b/runtime/sam/op/traverse/over.go @@ -11,19 +11,21 @@ import ( ) type Over struct { - parent zbuf.Puller - exprs []expr.Evaluator - outer []zed.Value - batch zbuf.Batch - enter *Enter - zctx *zed.Context + parent zbuf.Puller + exprs []expr.Evaluator + outer []zed.Value + batch zbuf.Batch + enter *Enter + resetter expr.Resetter + zctx *zed.Context } -func NewOver(rctx *runtime.Context, parent zbuf.Puller, exprs []expr.Evaluator) *Over { +func NewOver(rctx *runtime.Context, parent zbuf.Puller, exprs []expr.Evaluator, resetter expr.Resetter) *Over { return &Over{ - parent: parent, - exprs: exprs, - zctx: rctx.Zctx, + parent: parent, + exprs: exprs, + resetter: 
resetter, + zctx: rctx.Zctx, } } @@ -36,12 +38,14 @@ func (o *Over) AddScope(ctx context.Context, names []string, exprs []expr.Evalua func (o *Over) Pull(done bool) (zbuf.Batch, error) { if done { o.outer = nil + o.resetter.Reset() return o.parent.Pull(true) } for { if len(o.outer) == 0 { batch, err := o.parent.Pull(false) if batch == nil || err != nil { + o.resetter.Reset() return nil, err } o.batch = batch diff --git a/runtime/sam/op/yield/yield.go b/runtime/sam/op/yield/yield.go index 3d67ade000..e06aa4f098 100644 --- a/runtime/sam/op/yield/yield.go +++ b/runtime/sam/op/yield/yield.go @@ -7,14 +7,16 @@ import ( ) type Op struct { - parent zbuf.Puller - exprs []expr.Evaluator + parent zbuf.Puller + exprs []expr.Evaluator + resetter expr.Resetter } -func New(parent zbuf.Puller, exprs []expr.Evaluator) *Op { +func New(parent zbuf.Puller, exprs []expr.Evaluator, r expr.Resetter) *Op { return &Op{ - parent: parent, - exprs: exprs, + parent: parent, + exprs: exprs, + resetter: r, } } @@ -22,6 +24,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { for { batch, err := o.parent.Pull(done) if batch == nil || err != nil { + o.resetter.Reset() return nil, err } vals := batch.Values() diff --git a/runtime/sam/op/ztests/stateful-expr-reset.yaml b/runtime/sam/op/ztests/stateful-expr-reset.yaml new file mode 100644 index 0000000000..1f44e685e0 --- /dev/null +++ b/runtime/sam/op/ztests/stateful-expr-reset.yaml @@ -0,0 +1,83 @@ +script: | + echo '// yield' + echo null | zq -z -I yield.zed - + echo '// filter' + echo null | zq -z -I filter.zed - + echo '// switch' + echo null | zq -z -I switch.zed - + echo '// exprswitch' + echo null | zq -z -I exprswitch.zed - + echo '// over' + echo null | zq -z -I over.zed - + echo '// over with' + echo null | zq -z -I over-with.zed - + echo '// summarize' + echo null | zq -z -I summarize.zed - + +inputs: + - name: yield.zed + data: | + yield null, null + | over this => ( yield count() ) + - name: filter.zed + data: | + yield [1,2,3,4], 
[5,6,7] + | over this => ( where count() % 3 == 0 ) + - name: switch.zed + data: | + yield [1], [1] + | over this => ( + switch sum(this) ( + case 1 => yield "sum is 1" + ) + ) + - name: exprswitch.zed + data: | + yield [1], [1] + | over this => ( + switch ( + case sum(this) == 1 => yield "sum is 1" + ) + ) + - name: over.zed + data: | + yield null, null + | over this => ( + over count() + ) + - name: over-with.zed + data: | + yield [1], [1] + | over this => ( + over this with count = count() => ( yield count ) + ) + - name: summarize.zed + data: | + yield [1], [1] + | over this => ( sum(this) by c := count() ) + + +outputs: + - name: stdout + data: | + // yield + 1(uint64) + 1(uint64) + // filter + 3 + 7 + // switch + "sum is 1" + "sum is 1" + // exprswitch + "sum is 1" + "sum is 1" + // over + 1(uint64) + 1(uint64) + // over with + 1(uint64) + 1(uint64) + // summarize + {c:1(uint64),sum:1} + {c:1(uint64),sum:1}