diff --git a/compiler/kernel/expr.go b/compiler/kernel/expr.go index 2ed8f9d346..8c9443b320 100644 --- a/compiler/kernel/expr.go +++ b/compiler/kernel/expr.go @@ -16,6 +16,19 @@ import ( "golang.org/x/text/unicode/norm" ) +type exprContext struct { + resetters expr.Resetters +} + +func newExprContext() *exprContext { return new(exprContext) } + +func (r *exprContext) addResetter(resetter expr.Resetter) { + if r == nil { + return + } + r.resetters = append(r.resetters, resetter) +} + // compileExpr compiles the given Expression into an object // that evaluates the expression against a provided Record. It returns an // error if compilation fails for any reason. @@ -46,7 +59,7 @@ import ( // TBD: string values and net.IP address do not need to be copied because they // are allocated by go libraries and temporary buffers are not used. This will // change down the road when we implement no-allocation string and IP conversion. -func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { +func (b *Builder) compileExpr(ectx *exprContext, e dag.Expr) (expr.Evaluator, error) { if e == nil { return nil, errors.New("null expression not allowed") } @@ -60,39 +73,44 @@ func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { case *dag.Var: return expr.NewVar(e.Slot), nil case *dag.Search: - return b.compileSearch(e) + return b.compileSearch(ectx, e) case *dag.This: return expr.NewDottedExpr(b.zctx(), field.Path(e.Path)), nil case *dag.Dot: - return b.compileDotExpr(e) + return b.compileDotExpr(ectx, e) case *dag.UnaryExpr: - return b.compileUnary(*e) + return b.compileUnary(ectx, *e) case *dag.BinaryExpr: - return b.compileBinary(e) + return b.compileBinary(ectx, e) case *dag.Conditional: - return b.compileConditional(*e) + return b.compileConditional(ectx, *e) case *dag.Call: - return b.compileCall(*e) + return b.compileCall(ectx, *e) case *dag.RegexpMatch: - return b.compileRegexpMatch(e) + return b.compileRegexpMatch(ectx, e) case *dag.RegexpSearch: - return b.compileRegexpSearch(e) + return b.compileRegexpSearch(ectx, e) case *dag.RecordExpr: - return b.compileRecordExpr(e) + return b.compileRecordExpr(ectx, e) case *dag.ArrayExpr: - return b.compileArrayExpr(e) + return b.compileArrayExpr(ectx, e) case *dag.SetExpr: - return b.compileSetExpr(e) + return b.compileSetExpr(ectx, e) case *dag.MapCall: - return b.compileMapCall(e) + return b.compileMapCall(ectx, e) case *dag.MapExpr: - return b.compileMapExpr(e) + return b.compileMapExpr(ectx, e) case *dag.Agg: - agg, err := b.compileAgg(e) + agg, err := b.compileAgg(ectx, e) if err != nil { return nil, err } - return expr.NewAggregatorExpr(b.zctx(), agg), nil + if ectx == nil { + panic("system error: exprContext is nil") + } + aggexpr := expr.NewAggregatorExpr(b.zctx(), agg) + ectx.addResetter(aggexpr) + return aggexpr, nil case *dag.OverExpr: return b.compileOverExpr(e) default: @@ -100,31 +118,31 @@ func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { } } -func (b *Builder) compileExprWithEmpty(e dag.Expr) (expr.Evaluator, error) { +func (b *Builder) compileExprWithEmpty(ectx *exprContext, e dag.Expr) (expr.Evaluator, error) { if e == nil { return nil, nil } - return b.compileExpr(e) + return b.compileExpr(ectx, e) } -func (b *Builder) compileBinary(e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileBinary(ectx *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { if slice, ok := e.RHS.(*dag.BinaryExpr); ok && slice.Op == ":" { - return b.compileSlice(e.LHS, slice) + return b.compileSlice(ectx, e.LHS, slice) } if e.Op == "in" { // Do a faster comparison if the LHS is a compile-time constant expression. - if in, err := b.compileConstIn(e); in != nil && err == nil { + if in, err := b.compileConstIn(ectx, e); in != nil && err == nil { return in, err } } - if e, err := b.compileConstCompare(e); e != nil && err == nil { + if e, err := b.compileConstCompare(ectx, e); e != nil && err == nil { return e, nil } - lhs, err := b.compileExpr(e.LHS) + lhs, err := b.compileExpr(ectx, e.LHS) if err != nil { return nil, err } - rhs, err := b.compileExpr(e.RHS) + rhs, err := b.compileExpr(ectx, e.RHS) if err != nil { return nil, err } @@ -148,7 +166,7 @@ func (b *Builder) compileBinary(e *dag.BinaryExpr) (expr.Evaluator, error) { } } -func (b *Builder) compileConstIn(e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileConstIn(r *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { literal, err := b.evalAtCompileTime(e.LHS) if err != nil || literal.IsError() { // If the RHS here is a literal value, it would be good @@ -161,14 +179,14 @@ func (b *Builder) compileConstIn(e *dag.BinaryExpr) (expr.Evaluator, error) { if eql == nil || err != nil { return nil, nil } - operand, err := b.compileExpr(e.RHS) + operand, err := b.compileExpr(r, e.RHS) if err != nil { return nil, err } return expr.NewFilter(operand, expr.Contains(eql)), nil } -func (b *Builder) compileConstCompare(e *dag.BinaryExpr) (expr.Evaluator, error) { +func (b *Builder) compileConstCompare(ectx *exprContext, e *dag.BinaryExpr) (expr.Evaluator, error) { switch e.Op { case "==", "!=", "<", "<=", ">", ">=": default: @@ -185,19 +203,19 @@ func (b *Builder) compileConstCompare(e *dag.BinaryExpr) (expr.Evaluator, error) // non-error situation that isn't a simple comparison. return nil, nil } - operand, err := b.compileExpr(e.LHS) + operand, err := b.compileExpr(ectx, e.LHS) if err != nil { return nil, err } return expr.NewFilter(operand, comparison), nil } -func (b *Builder) compileSearch(search *dag.Search) (expr.Evaluator, error) { +func (b *Builder) compileSearch(ectx *exprContext, search *dag.Search) (expr.Evaluator, error) { val, err := zson.ParseValue(b.zctx(), search.Value) if err != nil { return nil, err } - e, err := b.compileExpr(search.Expr) + e, err := b.compileExpr(ectx, search.Expr) if err != nil { return nil, err } @@ -210,24 +228,24 @@ func (b *Builder) compileSearch(search *dag.Search) (expr.Evaluator, error) { return expr.NewSearch(search.Text, val, e) } -func (b *Builder) compileSlice(container dag.Expr, slice *dag.BinaryExpr) (expr.Evaluator, error) { - from, err := b.compileExprWithEmpty(slice.LHS) +func (b *Builder) compileSlice(ectx *exprContext, container dag.Expr, slice *dag.BinaryExpr) (expr.Evaluator, error) { + from, err := b.compileExprWithEmpty(ectx, slice.LHS) if err != nil { return nil, err } - to, err := b.compileExprWithEmpty(slice.RHS) + to, err := b.compileExprWithEmpty(ectx, slice.RHS) if err != nil { return nil, err } - e, err := b.compileExpr(container) + e, err := b.compileExpr(ectx, container) if err != nil { return nil, err } return expr.NewSlice(b.zctx(), e, from, to), nil } -func (b *Builder) compileUnary(unary dag.UnaryExpr) (expr.Evaluator, error) { - e, err := b.compileExpr(unary.Operand) +func (b *Builder) compileUnary(ectx *exprContext, unary dag.UnaryExpr) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, unary.Operand) if err != nil { return nil, err } @@ -241,48 +259,48 @@ func (b *Builder) compileUnary(unary dag.UnaryExpr) (expr.Evaluator, error) { } } -func (b *Builder) compileConditional(node dag.Conditional) (expr.Evaluator, error) { - predicate, err := b.compileExpr(node.Cond) +func (b *Builder) compileConditional(ectx *exprContext, node dag.Conditional) (expr.Evaluator, error) { + predicate, err := b.compileExpr(ectx, node.Cond) if err != nil { return nil, err } - thenExpr, err := b.compileExpr(node.Then) + thenExpr, err := b.compileExpr(ectx, node.Then) if err != nil { return nil, err } - elseExpr, err := b.compileExpr(node.Else) + elseExpr, err := b.compileExpr(ectx, node.Else) if err != nil { return nil, err } return expr.NewConditional(b.zctx(), predicate, thenExpr, elseExpr), nil } -func (b *Builder) compileDotExpr(dot *dag.Dot) (expr.Evaluator, error) { - record, err := b.compileExpr(dot.LHS) +func (b *Builder) compileDotExpr(ectx *exprContext, dot *dag.Dot) (expr.Evaluator, error) { + record, err := b.compileExpr(ectx, dot.LHS) if err != nil { return nil, err } return expr.NewDotExpr(b.zctx(), record, dot.RHS), nil } -func (b *Builder) compileLval(e dag.Expr) (*expr.Lval, error) { +func (b *Builder) compileLval(ectx *exprContext, e dag.Expr) (*expr.Lval, error) { switch e := e.(type) { case *dag.BinaryExpr: if e.Op != "[" { return nil, fmt.Errorf("internal error: invalid lval %#v", e) } - lhs, err := b.compileLval(e.LHS) + lhs, err := b.compileLval(ectx, e.LHS) if err != nil { return nil, err } - rhs, err := b.compileExpr(e.RHS) + rhs, err := b.compileExpr(ectx, e.RHS) if err != nil { return nil, err } lhs.Elems = append(lhs.Elems, expr.NewExprLvalElem(b.zctx(), rhs)) return lhs, nil case *dag.Dot: - lhs, err := b.compileLval(e.LHS) + lhs, err := b.compileLval(ectx, e.LHS) if err != nil { return nil, err } @@ -298,21 +316,21 @@ func (b *Builder) compileLval(e dag.Expr) (*expr.Lval, error) { return nil, fmt.Errorf("internal error: invalid lval %#v", e) } -func (b *Builder) compileAssignment(node *dag.Assignment) (expr.Assignment, error) { - lhs, err := b.compileLval(node.LHS) +func (b *Builder) compileAssignment(ectx *exprContext, node *dag.Assignment) (expr.Assignment, error) { + lhs, err := b.compileLval(ectx, node.LHS) if err != nil { return expr.Assignment{}, err } - rhs, err := b.compileExpr(node.RHS) + rhs, err := b.compileExpr(ectx, node.RHS) if err != nil { return expr.Assignment{}, fmt.Errorf("rhs of assigment expression: %w", err) } return expr.Assignment{LHS: lhs, RHS: rhs}, err } -func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { +func (b *Builder) compileCall(ectx *exprContext, call dag.Call) (expr.Evaluator, error) { if tf := expr.NewShaperTransform(call.Name); tf != 0 { - return b.compileShaper(call, tf) + return b.compileShaper(ectx, call, tf) } var path field.Path // First check if call is to a user defined function, otherwise check for @@ -330,42 +348,42 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { dagPath := &dag.This{Kind: "This", Path: path} args = append([]dag.Expr{dagPath}, args...) } - exprs, err := b.compileExprs(args) + exprs, err := b.compileExprs(ectx, args) if err != nil { return nil, fmt.Errorf("%s(): bad argument: %w", call.Name, err) } return expr.NewCall(b.zctx(), fn, exprs), nil } -func (b *Builder) compileMapCall(a *dag.MapCall) (expr.Evaluator, error) { - e, err := b.compileExpr(a.Expr) +func (b *Builder) compileMapCall(ectx *exprContext, a *dag.MapCall) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, a.Expr) if err != nil { return nil, err } - inner, err := b.compileExpr(a.Inner) + inner, err := b.compileExpr(ectx, a.Inner) if err != nil { return nil, err } return expr.NewMapCall(b.zctx(), e, inner), nil } -func (b *Builder) compileShaper(node dag.Call, tf expr.ShaperTransform) (expr.Evaluator, error) { +func (b *Builder) compileShaper(ectx *exprContext, node dag.Call, tf expr.ShaperTransform) (expr.Evaluator, error) { args := node.Args - field, err := b.compileExpr(args[0]) + field, err := b.compileExpr(ectx, args[0]) if err != nil { return nil, err } - typExpr, err := b.compileExpr(args[1]) + typExpr, err := b.compileExpr(ectx, args[1]) if err != nil { return nil, err } return expr.NewShaper(b.zctx(), field, typExpr, tf) } -func (b *Builder) compileExprs(in []dag.Expr) ([]expr.Evaluator, error) { +func (b *Builder) compileExprs(ectx *exprContext, in []dag.Expr) ([]expr.Evaluator, error) { var exprs []expr.Evaluator for _, e := range in { - ev, err := b.compileExpr(e) + ev, err := b.compileExpr(ectx, e) if err != nil { return nil, err } @@ -374,8 +392,8 @@ func (b *Builder) compileExprs(in []dag.Expr) ([]expr.Evaluator, error) { return exprs, nil } -func (b *Builder) compileRegexpMatch(match *dag.RegexpMatch) (expr.Evaluator, error) { - e, err := b.compileExpr(match.Expr) +func (b *Builder) compileRegexpMatch(ectx *exprContext, match *dag.RegexpMatch) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, match.Expr) if err != nil { return nil, err } @@ -386,8 +404,8 @@ func (b *Builder) compileRegexpMatch(match *dag.RegexpMatch) (expr.Evaluator, er return expr.NewRegexpMatch(re, e), nil } -func (b *Builder) compileRegexpSearch(search *dag.RegexpSearch) (expr.Evaluator, error) { - e, err := b.compileExpr(search.Expr) +func (b *Builder) compileRegexpSearch(ectx *exprContext, search *dag.RegexpSearch) (expr.Evaluator, error) { + e, err := b.compileExpr(ectx, search.Expr) if err != nil { return nil, err } @@ -399,12 +417,12 @@ func (b *Builder) compileRegexpSearch(search *dag.RegexpSearch) (expr.Evaluator, return expr.SearchByPredicate(expr.Contains(match), e), nil } -func (b *Builder) compileRecordExpr(record *dag.RecordExpr) (expr.Evaluator, error) { +func (b *Builder) compileRecordExpr(ectx *exprContext, record *dag.RecordExpr) (expr.Evaluator, error) { var elems []expr.RecordElem for _, elem := range record.Elems { switch elem := elem.(type) { case *dag.Field: - e, err := b.compileExpr(elem.Value) + e, err := b.compileExpr(ectx, elem.Value) if err != nil { return nil, err } @@ -413,7 +431,7 @@ func (b *Builder) compileRecordExpr(record *dag.RecordExpr) (expr.Evaluator, err Field: e, }) case *dag.Spread: - e, err := b.compileExpr(elem.Expr) + e, err := b.compileExpr(ectx, elem.Expr) if err != nil { return nil, err } @@ -423,34 +441,34 @@ func (b *Builder) compileRecordExpr(record *dag.RecordExpr) (expr.Evaluator, err return expr.NewRecordExpr(b.zctx(), elems) } -func (b *Builder) compileArrayExpr(array *dag.ArrayExpr) (expr.Evaluator, error) { - elems, err := b.compileVectorElems(array.Elems) +func (b *Builder) compileArrayExpr(ectx *exprContext, array *dag.ArrayExpr) (expr.Evaluator, error) { + elems, err := b.compileVectorElems(ectx, array.Elems) if err != nil { return nil, err } return expr.NewArrayExpr(b.zctx(), elems), nil } -func (b *Builder) compileSetExpr(set *dag.SetExpr) (expr.Evaluator, error) { - elems, err := b.compileVectorElems(set.Elems) +func (b *Builder) compileSetExpr(ectx *exprContext, set *dag.SetExpr) (expr.Evaluator, error) { + elems, err := b.compileVectorElems(ectx, set.Elems) if err != nil { return nil, err } return expr.NewSetExpr(b.zctx(), elems), nil } -func (b *Builder) compileVectorElems(elems []dag.VectorElem) ([]expr.VectorElem, error) { +func (b *Builder) compileVectorElems(ectx *exprContext, elems []dag.VectorElem) ([]expr.VectorElem, error) { var out []expr.VectorElem for _, elem := range elems { switch elem := elem.(type) { case *dag.Spread: - e, err := b.compileExpr(elem.Expr) + e, err := b.compileExpr(ectx, elem.Expr) if err != nil { return nil, err } out = append(out, expr.VectorElem{Spread: e}) case *dag.VectorValue: - e, err := b.compileExpr(elem.Expr) + e, err := b.compileExpr(ectx, elem.Expr) if err != nil { return nil, err } @@ -460,14 +478,14 @@ func (b *Builder) compileVectorElems(elems []dag.VectorElem) ([]expr.VectorElem, return out, nil } -func (b *Builder) compileMapExpr(m *dag.MapExpr) (expr.Evaluator, error) { +func (b *Builder) compileMapExpr(ectx *exprContext, m *dag.MapExpr) (expr.Evaluator, error) { var entries []expr.Entry for _, f := range m.Entries { - key, err := b.compileExpr(f.Key) + key, err := b.compileExpr(ectx, f.Key) if err != nil { return nil, err } - val, err := b.compileExpr(f.Value) + val, err := b.compileExpr(ectx, f.Value) if err != nil { return nil, err } @@ -480,16 +498,17 @@ func (b *Builder) compileOverExpr(over *dag.OverExpr) (expr.Evaluator, error) { if over.Body == nil { return nil, errors.New("over expression requires a lateral query body") } - names, lets, err := b.compileDefs(over.Defs) + var ectx exprContext + names, lets, err := b.compileDefs(&ectx, over.Defs) if err != nil { return nil, err } - exprs, err := b.compileExprs(over.Exprs) + exprs, err := b.compileExprs(&ectx, over.Exprs) if err != nil { return nil, err } parent := traverse.NewExpr(b.rctx.Context, b.zctx()) - enter := traverse.NewOver(b.rctx, parent, exprs) + enter := traverse.NewOver(b.rctx, parent, exprs, expr.NopResetter) scope := enter.AddScope(b.rctx.Context, names, lets) exits, err := b.compileSeq(over.Body, []zbuf.Puller{scope}) if err != nil { diff --git a/compiler/kernel/filter.go b/compiler/kernel/filter.go index 51b50a630c..75b4c65868 100644 --- a/compiler/kernel/filter.go +++ b/compiler/kernel/filter.go @@ -17,7 +17,7 @@ func (f *Filter) AsEvaluator() (expr.Evaluator, error) { if f == nil { return nil, nil } - return f.builder.compileExpr(f.pushdown) + return f.builder.compileExpr(nil, f.pushdown) } func (f *Filter) AsBufferFilter() (*expr.BufferFilter, error) { @@ -39,7 +39,7 @@ func (f *DeleteFilter) AsEvaluator() (expr.Evaluator, error) { // expression so we get all values that don't match. We also add a missing // call so if the expression results in an error("missing") the value is // kept. - return f.builder.compileExpr(&dag.BinaryExpr{ + return f.builder.compileExpr(nil, &dag.BinaryExpr{ Kind: "BinaryExpr", Op: "or", LHS: &dag.UnaryExpr{ diff --git a/compiler/kernel/groupby.go b/compiler/kernel/groupby.go index 45959766b1..0e66bed0a4 100644 --- a/compiler/kernel/groupby.go +++ b/compiler/kernel/groupby.go @@ -13,23 +13,24 @@ import ( ) func (b *Builder) compileGroupBy(parent zbuf.Puller, summarize *dag.Summarize) (*groupby.Op, error) { - keys, err := b.compileAssignments(summarize.Keys) + var ectx exprContext + keys, err := b.compileAssignments(&ectx, summarize.Keys) if err != nil { return nil, err } - names, reducers, err := b.compileAggAssignments(summarize.Aggs) + names, reducers, err := b.compileAggAssignments(&ectx, summarize.Aggs) if err != nil { return nil, err } dir := order.Direction(summarize.InputSortDir) - return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut) + return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut, ectx.resetters) } -func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) { +func (b *Builder) compileAggAssignments(ectx *exprContext, assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) { names := make(field.List, 0, len(assignments)) aggs := make([]*expr.Aggregator, 0, len(assignments)) for _, assignment := range assignments { - name, agg, err := b.compileAggAssignment(assignment) + name, agg, err := b.compileAggAssignment(ectx, assignment) if err != nil { return nil, nil, err } @@ -39,7 +40,7 @@ func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.Lis return names, aggs, nil } -func (b *Builder) compileAggAssignment(assignment dag.Assignment) (field.Path, *expr.Aggregator, error) { +func (b *Builder) compileAggAssignment(ectx *exprContext, assignment dag.Assignment) (field.Path, *expr.Aggregator, error) { aggAST, ok := assignment.RHS.(*dag.Agg) if !ok { return nil, nil, errors.New("aggregator is not an aggregation expression") @@ -48,23 +49,23 @@ func (b *Builder) compileAggAssignment(assignment dag.Assignment) (field.Path, * if !ok { return nil, nil, fmt.Errorf("internal error: aggregator assignment LHS is not a static path: %#v", assignment.LHS) } - m, err := b.compileAgg(aggAST) + m, err := b.compileAgg(ectx, aggAST) return this.Path, m, err } -func (b *Builder) compileAgg(agg *dag.Agg) (*expr.Aggregator, error) { +func (b *Builder) compileAgg(ectx *exprContext, agg *dag.Agg) (*expr.Aggregator, error) { name := agg.Name var err error var arg expr.Evaluator if agg.Expr != nil { - arg, err = b.compileExpr(agg.Expr) + arg, err = b.compileExpr(ectx, agg.Expr) if err != nil { return nil, err } } var where expr.Evaluator if agg.Where != nil { - where, err = b.compileExpr(agg.Where) + where, err = b.compileExpr(ectx, agg.Where) if err != nil { return nil, err } diff --git a/compiler/kernel/op.go b/compiler/kernel/op.go index 7613b1ec49..7acabbc998 100644 --- a/compiler/kernel/op.go +++ b/compiler/kernel/op.go @@ -125,7 +125,8 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) case *dag.Summarize: return b.compileGroupBy(parent, v) case *dag.Cut: - assignments, err := b.compileAssignments(v.Args) + var ectx exprContext + assignments, err := b.compileAssignments(&ectx, v.Args) if err != nil { return nil, err } @@ -134,7 +135,7 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) if v.Quiet { cutter.Quiet() } - return op.NewApplier(b.rctx, parent, cutter), nil + return op.NewApplier(b.rctx, parent, cutter, ectx.resetters), nil case *dag.Drop: if len(v.Args) == 0 { return nil, errors.New("drop: no fields given") @@ -148,13 +149,14 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) fields = append(fields, field.Path) } dropper := expr.NewDropper(b.rctx.Zctx, fields) - return op.NewApplier(b.rctx, parent, dropper), nil + return op.NewApplier(b.rctx, parent, dropper, expr.NopResetter), nil case *dag.Sort: - fields, err := b.compileExprs(v.Args) + var ectx exprContext + fields, err := b.compileExprs(&ectx, v.Args) if err != nil { return nil, err } - sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst) + sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst, ectx.resetters) if err != nil { return nil, fmt.Errorf("compiling sort: %w", err) } @@ -176,32 +178,36 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) case *dag.Pass: return pass.New(parent), nil case *dag.Filter: - f, err := b.compileExpr(v.Expr) + var ectx exprContext + f, err := b.compileExpr(&ectx, v.Expr) if err != nil { return nil, fmt.Errorf("compiling filter: %w", err) } - return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f)), nil + return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f), ectx.resetters), nil case *dag.Top: - fields, err := b.compileExprs(v.Args) + var ectx exprContext + fields, err := b.compileExprs(&ectx, v.Args) if err != nil { return nil, fmt.Errorf("compiling top: %w", err) } - return top.New(b.rctx.Zctx, parent, v.Limit, fields, v.Flush), nil + return top.New(b.rctx.Zctx, parent, v.Limit, fields, ectx.resetters, v.Flush), nil case *dag.Put: - clauses, err := b.compileAssignments(v.Args) + var ectx exprContext + clauses, err := b.compileAssignments(&ectx, v.Args) if err != nil { return nil, err } putter := expr.NewPutter(b.rctx.Zctx, clauses) - return op.NewApplier(b.rctx, parent, putter), nil + return op.NewApplier(b.rctx, parent, putter, ectx.resetters), nil case *dag.Rename: + var ectx exprContext var srcs, dsts []*expr.Lval for _, a := range v.Args { - src, err := b.compileLval(a.RHS) + src, err := b.compileLval(&ectx, a.RHS) if err != nil { return nil, err } - dst, err := b.compileLval(a.LHS) + dst, err := b.compileLval(&ectx, a.LHS) if err != nil { return nil, err } @@ -209,7 +215,7 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) dsts = append(dsts, dst) } renamer := expr.NewRenamer(b.rctx.Zctx, srcs, dsts) - return op.NewApplier(b.rctx, parent, renamer), nil + return op.NewApplier(b.rctx, parent, renamer, ectx.resetters), nil case *dag.Fuse: return fuse.New(b.rctx, parent) case *dag.Shape: @@ -223,19 +229,21 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) if err != nil { return nil, err } - args, err := b.compileExprs(v.Args) + var ectx exprContext + args, err := b.compileExprs(&ectx, v.Args) if err != nil { return nil, err } - return explode.New(b.rctx.Zctx, parent, args, typ, v.As) + return explode.New(b.rctx.Zctx, parent, args, ectx.resetters, typ, v.As) case *dag.Over: return b.compileOver(parent, v) case *dag.Yield: - exprs, err := b.compileExprs(v.Exprs) + var ectx exprContext + exprs, err := b.compileExprs(&ectx, v.Exprs) if err != nil { return nil, err } - t := yield.New(parent, exprs) + t := yield.New(parent, exprs, ectx.resetters) return t, nil case *dag.PoolScan: if parent != nil { @@ -352,11 +360,11 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error) } } -func (b *Builder) compileDefs(defs []dag.Def) ([]string, []expr.Evaluator, error) { +func (b *Builder) compileDefs(ectx *exprContext, defs []dag.Def) ([]string, []expr.Evaluator, error) { exprs := make([]expr.Evaluator, 0, len(defs)) names := make([]string, 0, len(defs)) for _, def := range defs { - e, err := b.compileExpr(def.Expr) + e, err := b.compileExpr(ectx, def.Expr) if err != nil { return nil, nil, err } @@ -370,15 +378,16 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller, if len(over.Defs) != 0 && over.Body == nil { return nil, errors.New("internal error: over operator has defs but no body") } - withNames, withExprs, err := b.compileDefs(over.Defs) + var ectx exprContext + withNames, withExprs, err := b.compileDefs(&ectx, over.Defs) if err != nil { return nil, err } - exprs, err := b.compileExprs(over.Exprs) + exprs, err := b.compileExprs(&ectx, over.Exprs) if err != nil { return nil, err } - enter := traverse.NewOver(b.rctx, parent, exprs) + enter := traverse.NewOver(b.rctx, parent, exprs, ectx.resetters) if over.Body == nil { return enter, nil } @@ -398,10 +407,10 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller, return scope.NewExit(exit), nil } -func (b *Builder) compileAssignments(assignments []dag.Assignment) ([]expr.Assignment, error) { +func (b *Builder) compileAssignments(ectx *exprContext, assignments []dag.Assignment) ([]expr.Assignment, error) { keys := make([]expr.Assignment, 0, len(assignments)) for _, assignment := range assignments { - a, err := b.compileAssignment(&assignment) + a, err := b.compileAssignment(ectx, &assignment) if err != nil { return nil, err } @@ -433,7 +442,11 @@ func (b *Builder) compileSeq(seq dag.Seq, parents []zbuf.Puller) ([]zbuf.Puller, } func (b *Builder) compileScope(scope *dag.Scope, parents []zbuf.Puller) ([]zbuf.Puller, error) { - if err := b.compileFuncs(scope.Funcs); err != nil { + // XXX We need to fix how udfs are compiled since there is currently a bug + // where aggregation expressions in udfs do not have separate state per + // invocation. The fix for this might use exprContext to compile udf + // expressions per invocation. + if err := b.compileFuncs(&exprContext{}, scope.Funcs); err != nil { return nil, err } return b.compileSeq(scope.Body, parents) @@ -481,7 +494,7 @@ func (b *Builder) compileScatter(par *dag.Scatter, parents []zbuf.Puller) ([]zbu return ops, nil } -func (b *Builder) compileFuncs(fns []*dag.Func) error { +func (b *Builder) compileFuncs(ectx *exprContext, fns []*dag.Func) error { udfs := make([]*expr.UDF, 0, len(fns)) for _, f := range fns { if _, ok := b.funcs[f.Name]; ok { @@ -493,7 +506,7 @@ func (b *Builder) compileFuncs(fns []*dag.Func) error { } for i := range fns { var err error - if udfs[i].Body, err = b.compileExpr(fns[i].Expr); err != nil { + if udfs[i].Body, err = b.compileExpr(ectx, fns[i].Expr); err != nil { return err } } @@ -505,11 +518,12 @@ func (b *Builder) compileExprSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([ if len(parents) > 1 { parent = combine.New(b.rctx, parents) } - e, err := b.compileExpr(swtch.Expr) + var ectx exprContext + e, err := b.compileExpr(&ectx, swtch.Expr) if err != nil { return nil, err } - s := exprswitch.New(b.rctx, parent, e) + s := exprswitch.New(b.rctx, parent, e, ectx.resetters) var exits []zbuf.Puller for _, c := range swtch.Cases { var val *zed.Value @@ -537,20 +551,19 @@ func (b *Builder) compileSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([]zbu if len(parents) > 1 { parent = combine.New(b.rctx, parents) } - n := len(swtch.Cases) - switcher := switcher.New(b.rctx, parent) - parents = []zbuf.Puller{} - for _, c := range swtch.Cases { - f, err := b.compileExpr(c.Expr) + var ectx exprContext + cases := make([]expr.Evaluator, len(swtch.Cases)) + for i, c := range swtch.Cases { + var err error + cases[i], err = b.compileExpr(&ectx, c.Expr) if err != nil { return nil, fmt.Errorf("compiling switch case filter: %w", err) } - sc := switcher.AddCase(f) - parents = append(parents, sc) } + switcher := switcher.New(b.rctx, parent, ectx.resetters) var ops []zbuf.Puller - for k := 0; k < n; k++ { - o, err := b.compileSeq(swtch.Cases[k].Path, []zbuf.Puller{parents[k]}) + for i, c := range cases { + o, err := b.compileSeq(swtch.Cases[i].Path, []zbuf.Puller{switcher.AddCase(c)}) if err != nil { return nil, err } @@ -578,16 +591,17 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error if len(parents) != 2 { return nil, ErrJoinParents } - assignments, err := b.compileAssignments(o.Args) + var ectx exprContext + assignments, err := b.compileAssignments(&ectx, o.Args) if err != nil { return nil, err } lhs, rhs := splitAssignments(assignments) - leftKey, err := b.compileExpr(o.LeftKey) + leftKey, err := b.compileExpr(&ectx, o.LeftKey) if err != nil { return nil, err } - rightKey, err := b.compileExpr(o.RightKey) + rightKey, err := b.compileExpr(&ectx, o.RightKey) if err != nil { return nil, err } @@ -607,18 +621,19 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error default: return nil, fmt.Errorf("unknown kind of join: '%s'", o.Style) } - join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs) + join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs, ectx.resetters) if err != nil { return nil, err } return []zbuf.Puller{join}, nil case *dag.Merge: - e, err := b.compileExpr(o.Expr) + var ectx exprContext + e, err := b.compileExpr(&ectx, o.Expr) if err != nil { return nil, err } cmp := expr.NewComparator(true, o.Order == order.Desc, e).WithMissingAsNull() - return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare)}, nil + return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare, ectx.resetters)}, nil case *dag.Combine: return []zbuf.Puller{combine.New(b.rctx, parents)}, nil default: @@ -671,7 +686,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) { if in == nil { return zed.Null, nil } - e, err := b.compileExpr(in) + e, err := b.compileExpr(nil, in) if err != nil { return zed.Null, err } @@ -687,7 +702,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) { func compileExpr(in dag.Expr) (expr.Evaluator, error) { b := NewBuilder(runtime.NewContext(context.Background(), zed.NewContext()), nil) - return b.compileExpr(in) + return b.compileExpr(nil, in) } func EvalAtCompileTime(zctx *zed.Context, in dag.Expr) (val zed.Value, err error) { diff --git a/runtime/sam/expr/agg.go b/runtime/sam/expr/agg.go index f89558e50a..21fa90faeb 100644 --- a/runtime/sam/expr/agg.go +++ b/runtime/sam/expr/agg.go @@ -5,6 +5,24 @@ import ( "github.com/brimdata/zed/runtime/sam/expr/agg" ) +type Resetter interface { + Reset() +} + +type Resetters []Resetter + +func (rs Resetters) Reset() { + for _, r := range rs { + r.Reset() + } +} + +var NopResetter = nopResetter{} + +type nopResetter struct{} + +func (nopResetter) Reset() {} + type Aggregator struct { pattern agg.Pattern expr Evaluator @@ -48,22 +66,27 @@ func (a *Aggregator) Apply(zctx *zed.Context, ectx Context, f agg.Function, this // NewAggregatorExpr returns an Evaluator from agg. The returned Evaluator // retains the same functionality of the aggregation only it returns it's // current state every time a new value is consumed. -func NewAggregatorExpr(zctx *zed.Context, agg *Aggregator) Evaluator { - return &aggregatorExpr{agg: agg, zctx: zctx} +func NewAggregatorExpr(zctx *zed.Context, agg *Aggregator) *AggregatorExpr { + return &AggregatorExpr{agg: agg, zctx: zctx} } -type aggregatorExpr struct { +type AggregatorExpr struct { agg *Aggregator fn agg.Function zctx *zed.Context } -var _ Evaluator = (*aggregatorExpr)(nil) +var ( + _ Evaluator = (*AggregatorExpr)(nil) + _ Resetter = (*AggregatorExpr)(nil) +) -func (s *aggregatorExpr) Eval(ectx Context, val zed.Value) zed.Value { +func (s *AggregatorExpr) Eval(ectx Context, val zed.Value) zed.Value { if s.fn == nil { s.fn = s.agg.NewFunction() } s.agg.Apply(s.zctx, ectx, s.fn, val) return s.fn.Result(s.zctx) } + +func (s *AggregatorExpr) Reset() { s.fn = nil } diff --git a/runtime/sam/op/apply.go b/runtime/sam/op/apply.go index 43e0736bcc..c31056d91a 100644 --- a/runtime/sam/op/apply.go +++ b/runtime/sam/op/apply.go @@ -8,16 +8,18 @@ import ( ) type applier struct { - rctx *runtime.Context - parent zbuf.Puller - expr expr.Evaluator + rctx *runtime.Context + parent zbuf.Puller + expr expr.Evaluator + resetter expr.Resetter } -func NewApplier(rctx *runtime.Context, parent zbuf.Puller, expr expr.Evaluator) *applier { +func NewApplier(rctx *runtime.Context, parent zbuf.Puller, expr expr.Evaluator, r expr.Resetter) *applier { return &applier{ - rctx: rctx, - parent: parent, - expr: expr, + rctx: rctx, + parent: parent, + expr: expr, + resetter: r, } } @@ -25,6 +27,7 @@ func (a *applier) Pull(done bool) (zbuf.Batch, error) { for { batch, err := a.parent.Pull(done) if batch == nil || err != nil { + a.resetter.Reset() return nil, err } vals := batch.Values() diff --git a/runtime/sam/op/explode/explode.go b/runtime/sam/op/explode/explode.go index 03c4c6678a..b97c44f01c 100644 --- a/runtime/sam/op/explode/explode.go +++ b/runtime/sam/op/explode/explode.go @@ -11,20 +11,22 @@ import ( // zng type T, outputs one record for each field of the input record of // type T. It is useful for type-based indexing. type Op struct { - parent zbuf.Puller - outType zed.Type - typ zed.Type - args []expr.Evaluator + parent zbuf.Puller + outType zed.Type + typ zed.Type + args []expr.Evaluator + resetter expr.Resetter } // New creates a exploder for type typ, where the // output records' single field is named name. -func New(zctx *zed.Context, parent zbuf.Puller, args []expr.Evaluator, typ zed.Type, name string) (zbuf.Puller, error) { +func New(zctx *zed.Context, parent zbuf.Puller, args []expr.Evaluator, resetter expr.Resetter, typ zed.Type, name string) (zbuf.Puller, error) { return &Op{ - parent: parent, - outType: zctx.MustLookupTypeRecord([]zed.Field{{Name: name, Type: typ}}), - typ: typ, - args: args, + parent: parent, + outType: zctx.MustLookupTypeRecord([]zed.Field{{Name: name, Type: typ}}), + typ: typ, + args: args, + resetter: resetter, }, nil } @@ -32,6 +34,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { for { batch, err := o.parent.Pull(done) if batch == nil || err != nil { + o.resetter.Reset() return nil, err } vals := batch.Values() diff --git a/runtime/sam/op/exprswitch/exprswitch.go b/runtime/sam/op/exprswitch/exprswitch.go index 1000b713d4..8bcb07eb56 100644 --- a/runtime/sam/op/exprswitch/exprswitch.go +++ b/runtime/sam/op/exprswitch/exprswitch.go @@ -10,6 +10,7 @@ import ( type ExprSwitch struct { *op.Router + expr.Resetter expr expr.Evaluator cases map[string]*switchCase defaultCase *switchCase @@ -22,12 +23,13 @@ type switchCase struct { vals []zed.Value } -func New(rctx *runtime.Context, parent zbuf.Puller, e expr.Evaluator) *ExprSwitch { +func New(rctx *runtime.Context, parent zbuf.Puller, e expr.Evaluator, r expr.Resetter) *ExprSwitch { router := op.NewRouter(rctx, parent) s := &ExprSwitch{ - Router: router, - expr: e, - cases: make(map[string]*switchCase), + Router: router, + Resetter: r, + expr: e, + cases: make(map[string]*switchCase), } router.Link(s) return s diff --git a/runtime/sam/op/groupby/groupby.go b/runtime/sam/op/groupby/groupby.go index 8a7d3fd882..c90a145a6e 100644 --- a/runtime/sam/op/groupby/groupby.go +++ b/runtime/sam/op/groupby/groupby.go @@ -28,6 +28,7 @@ type Op struct { resultCh chan op.Result doneCh chan struct{} batch zbuf.Batch + resetter expr.Resetter } // Aggregator performs the core aggregation computation for a @@ -110,7 +111,7 @@ func NewAggregator(ctx context.Context, zctx *zed.Context, keyRefs, keyExprs, ag }, nil } -func New(rctx *runtime.Context, parent zbuf.Puller, keys []expr.Assignment, aggNames field.List, aggs []*expr.Aggregator, limit int, inputSortDir order.Direction, partialsIn, partialsOut bool) (*Op, error) { +func New(rctx *runtime.Context, parent zbuf.Puller, keys []expr.Assignment, aggNames field.List, aggs []*expr.Aggregator, limit int, inputSortDir order.Direction, partialsIn, partialsOut bool, resetter expr.Resetter) (*Op, error) { names := make(field.List, 0, len(keys)+len(aggNames)) for _, e := range keys { p, ok := e.LHS.Path() @@ -144,6 +145,7 @@ func New(rctx *runtime.Context, parent zbuf.Puller, keys []expr.Assignment, aggN agg: agg, resultCh: make(chan op.Result), doneCh: make(chan struct{}), + resetter: resetter, }, nil } @@ -249,6 +251,10 @@ func (o *Op) run() { } func (o *Op) sendResult(b zbuf.Batch, err error) (bool, bool) { + if b == nil { + // Reset stateful aggregation expression on EOS. + o.resetter.Reset() + } select { case o.resultCh <- op.Result{Batch: b, Err: err}: return false, true @@ -288,6 +294,7 @@ func (o *Op) reset() { o.batch.Unref() o.batch = nil } + o.resetter.Reset() } // Consume adds a value to an aggregation. diff --git a/runtime/sam/op/join/join.go b/runtime/sam/op/join/join.go index 066511952f..58d1cba798 100644 --- a/runtime/sam/op/join/join.go +++ b/runtime/sam/op/join/join.go @@ -30,11 +30,12 @@ type Op struct { joinKey *zed.Value joinSet []zed.Value types map[int]map[int]*zed.TypeRecord + resetter expr.Resetter } func New(rctx *runtime.Context, anti, inner bool, left, right zbuf.Puller, leftKey, rightKey expr.Evaluator, leftDir, rightDir order.Direction, lhs []*expr.Lval, - rhs []expr.Evaluator) (*Op, error) { + rhs []expr.Evaluator, resetter expr.Resetter) (*Op, error) { var o order.Which switch { case leftDir != order.Unknown: @@ -45,13 +46,13 @@ func New(rctx *runtime.Context, anti, inner bool, left, right zbuf.Puller, leftK var err error // Add sorts if needed. if !leftDir.HasOrder(o) { - left, err = sort.New(rctx, left, []expr.Evaluator{leftKey}, o, false) + left, err = sort.New(rctx, left, []expr.Evaluator{leftKey}, o, false, resetter) if err != nil { return nil, err } } if !rightDir.HasOrder(o) { - right, err = sort.New(rctx, right, []expr.Evaluator{rightKey}, o, false) + right, err = sort.New(rctx, right, []expr.Evaluator{rightKey}, o, false, resetter) if err != nil { return nil, err } @@ -70,6 +71,7 @@ func New(rctx *runtime.Context, anti, inner bool, left, right zbuf.Puller, leftK compare: expr.NewValueCompareFn(o, true), cutter: expr.NewCutter(rctx.Zctx, lhs, rhs), types: make(map[int]map[int]*zed.TypeRecord), + resetter: resetter, }, nil } @@ -90,6 +92,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { } if leftRec == nil { if len(out) == 0 { + o.resetter.Reset() return nil, nil } //XXX See issue #3427. diff --git a/runtime/sam/op/merge/merge.go b/runtime/sam/op/merge/merge.go index 819597b6ee..258e8ad343 100644 --- a/runtime/sam/op/merge/merge.go +++ b/runtime/sam/op/merge/merge.go @@ -27,21 +27,23 @@ type Op struct { // The head-of-line (hol) queue is maintained as a min-heap on cmp of // hol.vals[0] (see Less) so that the next Read always returns // hol[0].vals[0]. - hol []*puller + hol []*puller + resetter expr.Resetter } var _ zbuf.Puller = (*Op)(nil) var _ zio.Reader = (*Op)(nil) -func New(ctx context.Context, parents []zbuf.Puller, cmp expr.CompareFn) *Op { +func New(ctx context.Context, parents []zbuf.Puller, cmp expr.CompareFn, r expr.Resetter) *Op { pullers := make([]*puller, 0, len(parents)) for _, p := range parents { pullers = append(pullers, newPuller(ctx, p)) } return &Op{ - ctx: ctx, - cmp: cmp, - parents: pullers, + ctx: ctx, + cmp: cmp, + parents: pullers, + resetter: r, } } @@ -118,6 +120,7 @@ func (o *Op) run() error { // each parent, e.g., a parent may be immediately blocked because it has // no data at (re)start and should not be re-entered into the HOL queue. func (o *Op) start() error { + o.resetter.Reset() o.hol = o.hol[:0] for _, parent := range o.parents { parent.blocked = false diff --git a/runtime/sam/op/merge/merge_test.go b/runtime/sam/op/merge/merge_test.go index a771a91070..d87ff2a618 100644 --- a/runtime/sam/op/merge/merge_test.go +++ b/runtime/sam/op/merge/merge_test.go @@ -9,6 +9,7 @@ import ( "github.com/brimdata/zed" "github.com/brimdata/zed/order" "github.com/brimdata/zed/pkg/field" + "github.com/brimdata/zed/runtime/sam/expr" "github.com/brimdata/zed/runtime/sam/op/merge" "github.com/brimdata/zed/zbuf" "github.com/brimdata/zed/zio" @@ -102,7 +103,7 @@ func TestParallelOrder(t *testing.T) { } sortKey := order.NewSortKey(c.order, field.DottedList(c.field)) cmp := zbuf.NewComparator(zctx, sortKey).Compare - om := merge.New(context.Background(), parents, cmp) + om := merge.New(context.Background(), parents, cmp, expr.NopResetter) var sb strings.Builder err := zbuf.CopyPuller(zsonio.NewWriter(zio.NopCloser(&sb), zsonio.WriterOpts{}), om) diff --git a/runtime/sam/op/meta/sequence.go b/runtime/sam/op/meta/sequence.go index 217536ea4c..8fad6bbf07 100644 --- a/runtime/sam/op/meta/sequence.go +++ b/runtime/sam/op/meta/sequence.go @@ -197,7 +197,7 @@ func newObjectsScanner(ctx context.Context, zctx *zed.Context, pool *lake.Pool, if len(pullers) == 1 { return pullers[0], nil } - return merge.New(ctx, pullers, lake.ImportComparator(zctx, pool).Compare), nil + return merge.New(ctx, pullers, lake.ImportComparator(zctx, pool).Compare, expr.NopResetter), nil } func newObjectScanner(ctx context.Context, zctx *zed.Context, pool *lake.Pool, object *data.Object, ranges []seekindex.Range, filter zbuf.Filter, progress *zbuf.Progress) (zbuf.Puller, error) { diff --git a/runtime/sam/op/router.go b/runtime/sam/op/router.go index b695f9562f..30f3aafa2d 100644 --- a/runtime/sam/op/router.go +++ b/runtime/sam/op/router.go @@ -5,6 +5,7 @@ import ( "slices" "sync" + "github.com/brimdata/zed/runtime/sam/expr" "github.com/brimdata/zed/zbuf" ) @@ -93,6 +94,11 @@ func (r *Router) blocked() bool { // after receiving the EOS, it's done will be captured as soon as we unblock // all channels. func (r *Router) sendEOS(err error) bool { + defer func() { + if r, ok := r.selector.(expr.Resetter); ok { + r.Reset() + } + }() // First, we need to send EOS to all non-blocked legs and // catch any dones in progress. This result in all routes // being blocked. diff --git a/runtime/sam/op/sort/sort.go b/runtime/sam/op/sort/sort.go index e1db24c5a1..78b47e243e 100644 --- a/runtime/sam/op/sort/sort.go +++ b/runtime/sam/op/sort/sort.go @@ -28,9 +28,10 @@ type Op struct { once sync.Once resultCh chan op.Result comparator *expr.Comparator + resetter expr.Resetter } -func New(rctx *runtime.Context, parent zbuf.Puller, fields []expr.Evaluator, order order.Which, nullsFirst bool) (*Op, error) { +func New(rctx *runtime.Context, parent zbuf.Puller, fields []expr.Evaluator, order order.Which, nullsFirst bool, r expr.Resetter) (*Op, error) { return &Op{ rctx: rctx, parent: parent, @@ -38,6 +39,7 @@ func New(rctx *runtime.Context, parent zbuf.Puller, fields []expr.Evaluator, ord nullsFirst: nullsFirst, fieldResolvers: fields, resultCh: make(chan op.Result), + resetter: r, }, nil } @@ -170,6 +172,10 @@ func (o *Op) sendSpills(spiller *spill.MergeSort) bool { } func (o *Op) sendResult(b zbuf.Batch, err error) bool { + if b == nil && err == nil { + // Reset Evaluators as EOS + o.resetter.Reset() + } select { case o.resultCh <- op.Result{Batch: b, Err: err}: return true diff --git a/runtime/sam/op/switcher/switch.go b/runtime/sam/op/switcher/switch.go index e32d4ab980..eacd9c1a7f 100644 --- a/runtime/sam/op/switcher/switch.go +++ b/runtime/sam/op/switcher/switch.go @@ -10,6 +10,7 @@ import ( type Selector struct { *op.Router + expr.Resetter cases []*switchCase } @@ -21,10 +22,11 @@ type switchCase struct { vals []zed.Value } -func New(rctx *runtime.Context, parent zbuf.Puller) *Selector { +func New(rctx *runtime.Context, parent zbuf.Puller, r expr.Resetter) *Selector { router := op.NewRouter(rctx, parent) s := &Selector{ - Router: router, + Router: router, + Resetter: r, } router.Link(s) return s diff --git a/runtime/sam/op/top/top.go b/runtime/sam/op/top/top.go index f7ed00986d..325a905ed9 100644 --- a/runtime/sam/op/top/top.go +++ b/runtime/sam/op/top/top.go @@ -21,12 +21,13 @@ type Op struct { zctx *zed.Context limit int fields []expr.Evaluator + resetter expr.Resetter records *expr.RecordSlice compare expr.CompareFn flushEvery bool } -func New(zctx *zed.Context, parent zbuf.Puller, limit int, fields []expr.Evaluator, flushEvery bool) *Op { +func New(zctx *zed.Context, parent zbuf.Puller, limit int, fields []expr.Evaluator, r expr.Resetter, flushEvery bool) *Op { if limit == 0 { limit = defaultTopLimit } @@ -35,6 +36,7 @@ func New(zctx *zed.Context, parent zbuf.Puller, limit int, fields []expr.Evaluat limit: limit, fields: fields, flushEvery: flushEvery, + resetter: r, } } @@ -45,6 +47,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { return nil, err } if batch == nil { + defer o.resetter.Reset() return o.sorted(), nil } vals := batch.Values() diff --git a/runtime/sam/op/traverse/over.go b/runtime/sam/op/traverse/over.go index e0673ca0e2..096eaef8eb 100644 --- a/runtime/sam/op/traverse/over.go +++ b/runtime/sam/op/traverse/over.go @@ -11,19 +11,21 @@ import ( ) type Over struct { - parent zbuf.Puller - exprs []expr.Evaluator - outer []zed.Value - batch zbuf.Batch - enter *Enter - zctx *zed.Context + parent zbuf.Puller + exprs []expr.Evaluator + outer []zed.Value + batch zbuf.Batch + enter *Enter + resetter expr.Resetter + zctx *zed.Context } -func NewOver(rctx *runtime.Context, parent zbuf.Puller, exprs []expr.Evaluator) *Over { +func NewOver(rctx *runtime.Context, parent zbuf.Puller, exprs []expr.Evaluator, resetter expr.Resetter) *Over { return &Over{ - parent: parent, - exprs: exprs, - zctx: rctx.Zctx, + parent: parent, + exprs: exprs, + resetter: resetter, + zctx: rctx.Zctx, } } @@ -36,12 +38,14 @@ func (o *Over) AddScope(ctx context.Context, names []string, exprs []expr.Evalua func (o *Over) Pull(done bool) (zbuf.Batch, error) { if done { o.outer = nil + o.resetter.Reset() return o.parent.Pull(true) } for { if len(o.outer) == 0 { batch, err := o.parent.Pull(false) if batch == nil || err != nil { + o.resetter.Reset() return nil, err } o.batch = batch diff --git a/runtime/sam/op/yield/yield.go b/runtime/sam/op/yield/yield.go index 3d67ade000..e06aa4f098 100644 --- a/runtime/sam/op/yield/yield.go +++ b/runtime/sam/op/yield/yield.go @@ -7,14 +7,16 @@ import ( ) type Op struct { - parent zbuf.Puller - exprs []expr.Evaluator + parent zbuf.Puller + exprs []expr.Evaluator + resetter expr.Resetter } -func New(parent zbuf.Puller, exprs []expr.Evaluator) *Op { +func New(parent zbuf.Puller, exprs []expr.Evaluator, r expr.Resetter) *Op { return &Op{ - parent: parent, - exprs: exprs, + parent: parent, + exprs: exprs, + resetter: r, } } @@ -22,6 +24,7 @@ func (o *Op) Pull(done bool) (zbuf.Batch, error) { for { batch, err := o.parent.Pull(done) if batch == nil || err != nil { + o.resetter.Reset() return nil, err } vals := batch.Values() diff --git a/runtime/sam/op/ztests/stateful-expr-reset.yaml b/runtime/sam/op/ztests/stateful-expr-reset.yaml new file mode 100644 index 0000000000..1f44e685e0 --- /dev/null +++ b/runtime/sam/op/ztests/stateful-expr-reset.yaml @@ -0,0 +1,83 @@ +script: | + echo '// yield' + echo null | zq -z -I yield.zed - + echo '// filter' + echo null | zq -z -I filter.zed - + echo '// switch' + echo null | zq -z -I switch.zed - + echo '// exprswitch' + echo null | zq -z -I exprswitch.zed - + echo '// over' + echo null | zq -z -I over.zed - + echo '// over with' + echo null | zq -z -I over-with.zed - + echo '// summarize' + echo null | zq -z -I summarize.zed - + +inputs: + - name: yield.zed + data: | + yield null, null + | over this => ( yield count() ) + - name: filter.zed + data: | + yield [1,2,3,4], [5,6,7] + | over this => ( where count() % 3 == 0 ) + - name: switch.zed + data: | + yield [1], [1] + | over this => ( + switch sum(this) ( + case 1 => yield "sum is 1" + ) + ) + - name: exprswitch.zed + data: | + yield [1], [1] + | over this => ( + switch ( + case sum(this) == 1 => yield "sum is 1" + ) + ) + - name: over.zed + data: | + yield null, null + | over this => ( + over count() + ) + - name: over-with.zed + data: | + yield [1], [1] + | over this => ( + over this with count = count() => ( yield count ) + ) + - name: summarize.zed + data: | + yield [1], [1] + | over this => ( sum(this) by c := count() ) + + +outputs: + - name: stdout + data: | + // yield + 1(uint64) + 1(uint64) + // filter + 3 + 7 + // switch + "sum is 1" + "sum is 1" + // exprswitch + "sum is 1" + "sum is 1" + // over + 1(uint64) + 1(uint64) + // over with + 1(uint64) + 1(uint64) + // summarize + {c:1(uint64),sum:1} + {c:1(uint64),sum:1}