Skip to content

Commit

Permalink
Reset stateful expressions on EOS
Browse files Browse the repository at this point in the history
This commit fixes a bug where stateful expressions were not reset
when an op encountered EOS. Most notably, this resulted in
unexpected values when stateful expressions were used inside lateral
queries.

Closes #4943
  • Loading branch information
mattnibs committed Mar 13, 2024
1 parent 38763f8 commit d015b27
Show file tree
Hide file tree
Showing 20 changed files with 381 additions and 194 deletions.
179 changes: 99 additions & 80 deletions compiler/kernel/expr.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions compiler/kernel/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (f *Filter) AsEvaluator() (expr.Evaluator, error) {
if f == nil {
return nil, nil
}
return f.builder.compileExpr(f.pushdown)
return f.builder.compileExpr(nil, f.pushdown)
}

func (f *Filter) AsBufferFilter() (*expr.BufferFilter, error) {
Expand All @@ -39,7 +39,7 @@ func (f *DeleteFilter) AsEvaluator() (expr.Evaluator, error) {
// expression so we get all values that don't match. We also add a missing
// call so if the expression results in an error("missing") the value is
// kept.
return f.builder.compileExpr(&dag.BinaryExpr{
return f.builder.compileExpr(nil, &dag.BinaryExpr{
Kind: "BinaryExpr",
Op: "or",
LHS: &dag.UnaryExpr{
Expand Down
21 changes: 11 additions & 10 deletions compiler/kernel/groupby.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@ import (
)

func (b *Builder) compileGroupBy(parent zbuf.Puller, summarize *dag.Summarize) (*groupby.Op, error) {
keys, err := b.compileAssignments(summarize.Keys)
var ectx exprContext
keys, err := b.compileAssignments(&ectx, summarize.Keys)
if err != nil {
return nil, err
}
names, reducers, err := b.compileAggAssignments(summarize.Aggs)
names, reducers, err := b.compileAggAssignments(&ectx, summarize.Aggs)
if err != nil {
return nil, err
}
dir := order.Direction(summarize.InputSortDir)
return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut)
return groupby.New(b.rctx, parent, keys, names, reducers, summarize.Limit, dir, summarize.PartialsIn, summarize.PartialsOut, ectx.resetters)
}

func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) {
func (b *Builder) compileAggAssignments(ectx *exprContext, assignments []dag.Assignment) (field.List, []*expr.Aggregator, error) {
names := make(field.List, 0, len(assignments))
aggs := make([]*expr.Aggregator, 0, len(assignments))
for _, assignment := range assignments {
name, agg, err := b.compileAggAssignment(assignment)
name, agg, err := b.compileAggAssignment(ectx, assignment)
if err != nil {
return nil, nil, err
}
Expand All @@ -39,7 +40,7 @@ func (b *Builder) compileAggAssignments(assignments []dag.Assignment) (field.Lis
return names, aggs, nil
}

func (b *Builder) compileAggAssignment(assignment dag.Assignment) (field.Path, *expr.Aggregator, error) {
func (b *Builder) compileAggAssignment(ectx *exprContext, assignment dag.Assignment) (field.Path, *expr.Aggregator, error) {
aggAST, ok := assignment.RHS.(*dag.Agg)
if !ok {
return nil, nil, errors.New("aggregator is not an aggregation expression")
Expand All @@ -48,23 +49,23 @@ func (b *Builder) compileAggAssignment(assignment dag.Assignment) (field.Path, *
if !ok {
return nil, nil, fmt.Errorf("internal error: aggregator assignment LHS is not a static path: %#v", assignment.LHS)
}
m, err := b.compileAgg(aggAST)
m, err := b.compileAgg(ectx, aggAST)
return this.Path, m, err
}

func (b *Builder) compileAgg(agg *dag.Agg) (*expr.Aggregator, error) {
func (b *Builder) compileAgg(ectx *exprContext, agg *dag.Agg) (*expr.Aggregator, error) {
name := agg.Name
var err error
var arg expr.Evaluator
if agg.Expr != nil {
arg, err = b.compileExpr(agg.Expr)
arg, err = b.compileExpr(ectx, agg.Expr)
if err != nil {
return nil, err
}
}
var where expr.Evaluator
if agg.Where != nil {
where, err = b.compileExpr(agg.Where)
where, err = b.compileExpr(ectx, agg.Where)
if err != nil {
return nil, err
}
Expand Down
109 changes: 62 additions & 47 deletions compiler/kernel/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error)
case *dag.Summarize:
return b.compileGroupBy(parent, v)
case *dag.Cut:
assignments, err := b.compileAssignments(v.Args)
var ectx exprContext
assignments, err := b.compileAssignments(&ectx, v.Args)
if err != nil {
return nil, err
}
Expand All @@ -134,7 +135,7 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error)
if v.Quiet {
cutter.Quiet()
}
return op.NewApplier(b.rctx, parent, cutter), nil
return op.NewApplier(b.rctx, parent, cutter, ectx.resetters), nil
case *dag.Drop:
if len(v.Args) == 0 {
return nil, errors.New("drop: no fields given")
Expand All @@ -148,13 +149,14 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error)
fields = append(fields, field.Path)
}
dropper := expr.NewDropper(b.rctx.Zctx, fields)
return op.NewApplier(b.rctx, parent, dropper), nil
return op.NewApplier(b.rctx, parent, dropper, expr.NopResetter), nil
case *dag.Sort:
fields, err := b.compileExprs(v.Args)
var ectx exprContext
fields, err := b.compileExprs(&ectx, v.Args)
if err != nil {
return nil, err
}
sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst)
sort, err := sort.New(b.rctx, parent, fields, v.Order, v.NullsFirst, ectx.resetters)
if err != nil {
return nil, fmt.Errorf("compiling sort: %w", err)
}
Expand All @@ -176,40 +178,44 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error)
case *dag.Pass:
return pass.New(parent), nil
case *dag.Filter:
f, err := b.compileExpr(v.Expr)
var ectx exprContext
f, err := b.compileExpr(&ectx, v.Expr)
if err != nil {
return nil, fmt.Errorf("compiling filter: %w", err)
}
return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f)), nil
return op.NewApplier(b.rctx, parent, expr.NewFilterApplier(b.rctx.Zctx, f), ectx.resetters), nil
case *dag.Top:
fields, err := b.compileExprs(v.Args)
var ectx exprContext
fields, err := b.compileExprs(&ectx, v.Args)
if err != nil {
return nil, fmt.Errorf("compiling top: %w", err)
}
return top.New(b.rctx.Zctx, parent, v.Limit, fields, v.Flush), nil
return top.New(b.rctx.Zctx, parent, v.Limit, fields, ectx.resetters, v.Flush), nil
case *dag.Put:
clauses, err := b.compileAssignments(v.Args)
var ectx exprContext
clauses, err := b.compileAssignments(&ectx, v.Args)
if err != nil {
return nil, err
}
putter := expr.NewPutter(b.rctx.Zctx, clauses)
return op.NewApplier(b.rctx, parent, putter), nil
return op.NewApplier(b.rctx, parent, putter, ectx.resetters), nil
case *dag.Rename:
var ectx exprContext
var srcs, dsts []*expr.Lval
for _, a := range v.Args {
src, err := b.compileLval(a.RHS)
src, err := b.compileLval(&ectx, a.RHS)
if err != nil {
return nil, err
}
dst, err := b.compileLval(a.LHS)
dst, err := b.compileLval(&ectx, a.LHS)
if err != nil {
return nil, err
}
srcs = append(srcs, src)
dsts = append(dsts, dst)
}
renamer := expr.NewRenamer(b.rctx.Zctx, srcs, dsts)
return op.NewApplier(b.rctx, parent, renamer), nil
return op.NewApplier(b.rctx, parent, renamer, ectx.resetters), nil
case *dag.Fuse:
return fuse.New(b.rctx, parent)
case *dag.Shape:
Expand All @@ -223,19 +229,21 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error)
if err != nil {
return nil, err
}
args, err := b.compileExprs(v.Args)
var ectx exprContext
args, err := b.compileExprs(&ectx, v.Args)
if err != nil {
return nil, err
}
return explode.New(b.rctx.Zctx, parent, args, typ, v.As)
return explode.New(b.rctx.Zctx, parent, args, ectx.resetters, typ, v.As)
case *dag.Over:
return b.compileOver(parent, v)
case *dag.Yield:
exprs, err := b.compileExprs(v.Exprs)
var ectx exprContext
exprs, err := b.compileExprs(&ectx, v.Exprs)
if err != nil {
return nil, err
}
t := yield.New(parent, exprs)
t := yield.New(parent, exprs, ectx.resetters)
return t, nil
case *dag.PoolScan:
if parent != nil {
Expand Down Expand Up @@ -352,11 +360,11 @@ func (b *Builder) compileLeaf(o dag.Op, parent zbuf.Puller) (zbuf.Puller, error)
}
}

func (b *Builder) compileDefs(defs []dag.Def) ([]string, []expr.Evaluator, error) {
func (b *Builder) compileDefs(ectx *exprContext, defs []dag.Def) ([]string, []expr.Evaluator, error) {
exprs := make([]expr.Evaluator, 0, len(defs))
names := make([]string, 0, len(defs))
for _, def := range defs {
e, err := b.compileExpr(def.Expr)
e, err := b.compileExpr(ectx, def.Expr)
if err != nil {
return nil, nil, err
}
Expand All @@ -370,15 +378,16 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller,
if len(over.Defs) != 0 && over.Body == nil {
return nil, errors.New("internal error: over operator has defs but no body")
}
withNames, withExprs, err := b.compileDefs(over.Defs)
var ectx exprContext
withNames, withExprs, err := b.compileDefs(&ectx, over.Defs)
if err != nil {
return nil, err
}
exprs, err := b.compileExprs(over.Exprs)
exprs, err := b.compileExprs(&ectx, over.Exprs)
if err != nil {
return nil, err
}
enter := traverse.NewOver(b.rctx, parent, exprs)
enter := traverse.NewOver(b.rctx, parent, exprs, ectx.resetters)
if over.Body == nil {
return enter, nil
}
Expand All @@ -398,10 +407,10 @@ func (b *Builder) compileOver(parent zbuf.Puller, over *dag.Over) (zbuf.Puller,
return scope.NewExit(exit), nil
}

func (b *Builder) compileAssignments(assignments []dag.Assignment) ([]expr.Assignment, error) {
func (b *Builder) compileAssignments(ectx *exprContext, assignments []dag.Assignment) ([]expr.Assignment, error) {
keys := make([]expr.Assignment, 0, len(assignments))
for _, assignment := range assignments {
a, err := b.compileAssignment(&assignment)
a, err := b.compileAssignment(ectx, &assignment)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -433,7 +442,11 @@ func (b *Builder) compileSeq(seq dag.Seq, parents []zbuf.Puller) ([]zbuf.Puller,
}

func (b *Builder) compileScope(scope *dag.Scope, parents []zbuf.Puller) ([]zbuf.Puller, error) {
if err := b.compileFuncs(scope.Funcs); err != nil {
// XXX We need to fix how udfs are compiled since there is currently a bug
// where aggregation expressions in udfs do not have separate state per
// invocation. The fix for this might use exprContext to compile udf
// expressions per invocation.
if err := b.compileFuncs(&exprContext{}, scope.Funcs); err != nil {
return nil, err
}
return b.compileSeq(scope.Body, parents)
Expand Down Expand Up @@ -481,7 +494,7 @@ func (b *Builder) compileScatter(par *dag.Scatter, parents []zbuf.Puller) ([]zbu
return ops, nil
}

func (b *Builder) compileFuncs(fns []*dag.Func) error {
func (b *Builder) compileFuncs(ectx *exprContext, fns []*dag.Func) error {
udfs := make([]*expr.UDF, 0, len(fns))
for _, f := range fns {
if _, ok := b.funcs[f.Name]; ok {
Expand All @@ -493,7 +506,7 @@ func (b *Builder) compileFuncs(fns []*dag.Func) error {
}
for i := range fns {
var err error
if udfs[i].Body, err = b.compileExpr(fns[i].Expr); err != nil {
if udfs[i].Body, err = b.compileExpr(ectx, fns[i].Expr); err != nil {
return err
}
}
Expand All @@ -505,11 +518,12 @@ func (b *Builder) compileExprSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([
if len(parents) > 1 {
parent = combine.New(b.rctx, parents)
}
e, err := b.compileExpr(swtch.Expr)
var ectx exprContext
e, err := b.compileExpr(&ectx, swtch.Expr)
if err != nil {
return nil, err
}
s := exprswitch.New(b.rctx, parent, e)
s := exprswitch.New(b.rctx, parent, e, ectx.resetters)
var exits []zbuf.Puller
for _, c := range swtch.Cases {
var val *zed.Value
Expand Down Expand Up @@ -537,20 +551,19 @@ func (b *Builder) compileSwitch(swtch *dag.Switch, parents []zbuf.Puller) ([]zbu
if len(parents) > 1 {
parent = combine.New(b.rctx, parents)
}
n := len(swtch.Cases)
switcher := switcher.New(b.rctx, parent)
parents = []zbuf.Puller{}
for _, c := range swtch.Cases {
f, err := b.compileExpr(c.Expr)
var ectx exprContext
cases := make([]expr.Evaluator, len(swtch.Cases))
for i, c := range swtch.Cases {
var err error
cases[i], err = b.compileExpr(&ectx, c.Expr)
if err != nil {
return nil, fmt.Errorf("compiling switch case filter: %w", err)
}
sc := switcher.AddCase(f)
parents = append(parents, sc)
}
switcher := switcher.New(b.rctx, parent, ectx.resetters)
var ops []zbuf.Puller
for k := 0; k < n; k++ {
o, err := b.compileSeq(swtch.Cases[k].Path, []zbuf.Puller{parents[k]})
for i, c := range cases {
o, err := b.compileSeq(swtch.Cases[i].Path, []zbuf.Puller{switcher.AddCase(c)})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -578,16 +591,17 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error
if len(parents) != 2 {
return nil, ErrJoinParents
}
assignments, err := b.compileAssignments(o.Args)
var ectx exprContext
assignments, err := b.compileAssignments(&ectx, o.Args)
if err != nil {
return nil, err
}
lhs, rhs := splitAssignments(assignments)
leftKey, err := b.compileExpr(o.LeftKey)
leftKey, err := b.compileExpr(&ectx, o.LeftKey)
if err != nil {
return nil, err
}
rightKey, err := b.compileExpr(o.RightKey)
rightKey, err := b.compileExpr(&ectx, o.RightKey)
if err != nil {
return nil, err
}
Expand All @@ -607,18 +621,19 @@ func (b *Builder) compile(o dag.Op, parents []zbuf.Puller) ([]zbuf.Puller, error
default:
return nil, fmt.Errorf("unknown kind of join: '%s'", o.Style)
}
join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs)
join, err := join.New(b.rctx, anti, inner, leftParent, rightParent, leftKey, rightKey, leftDir, rightDir, lhs, rhs, ectx.resetters)
if err != nil {
return nil, err
}
return []zbuf.Puller{join}, nil
case *dag.Merge:
e, err := b.compileExpr(o.Expr)
var ectx exprContext
e, err := b.compileExpr(&ectx, o.Expr)
if err != nil {
return nil, err
}
cmp := expr.NewComparator(true, o.Order == order.Desc, e).WithMissingAsNull()
return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare)}, nil
return []zbuf.Puller{merge.New(b.rctx, parents, cmp.Compare, ectx.resetters)}, nil
case *dag.Combine:
return []zbuf.Puller{combine.New(b.rctx, parents)}, nil
default:
Expand Down Expand Up @@ -671,7 +686,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) {
if in == nil {
return zed.Null, nil
}
e, err := b.compileExpr(in)
e, err := b.compileExpr(nil, in)
if err != nil {
return zed.Null, err
}
Expand All @@ -687,7 +702,7 @@ func (b *Builder) evalAtCompileTime(in dag.Expr) (val zed.Value, err error) {

func compileExpr(in dag.Expr) (expr.Evaluator, error) {
b := NewBuilder(runtime.NewContext(context.Background(), zed.NewContext()), nil)
return b.compileExpr(in)
return b.compileExpr(nil, in)
}

func EvalAtCompileTime(zctx *zed.Context, in dag.Expr) (val zed.Value, err error) {
Expand Down
Loading

0 comments on commit d015b27

Please sign in to comment.