Skip to content

Commit

Permalink
lib: add evaluation state dump utility type
Browse files Browse the repository at this point in the history
  • Loading branch information
efd6 committed Oct 15, 2024
1 parent f538a8d commit 10ba417
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 22 deletions.
174 changes: 174 additions & 0 deletions lib/dump.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package lib

import (
"bytes"
"encoding/json"
"fmt"
"sort"
"strings"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
)

// NewDump returns an evaluation dump that can be used to examine the complete
// set of evaluation states from a CEL program. The program must have been
// constructed with a cel.Env.Program call including the cel.OptTrackState
// evaluation option. The ast and details parameters must be valid for the
// program.
func NewDump(ast *cel.Ast, details *cel.EvalDetails) *Dump {
if ast == nil || details == nil {
return nil
}
return &Dump{ast: ast, det: details}
}

// Dump is an evaluation dump.
type Dump struct {
ast *cel.Ast
det *cel.EvalDetails
}

func (d *Dump) String() string {
if d == nil {
return ""
}
var buf strings.Builder
for i, v := range d.NodeValues() {
if i != 0 {
buf.WriteByte('\n')
}
fmt.Fprint(&buf, v)
}
return buf.String()
}

// NodeValues returns the evaluation results, source location and source
// snippets for the expressions in the dump. The nodes are sorted in
// source order.
func (d *Dump) NodeValues() []NodeValue {
if d == nil {
return nil
}
es := d.det.State()
var values []NodeValue
for _, id := range es.IDs() {
if id == 0 {
continue
}
v, ok := es.Value(id)
if !ok {
continue
}
values = append(values, d.nodeValue(v, id))
}
sort.Slice(values, func(i, j int) bool {
vi := values[i].loc
vj := values[j].loc
switch {
case vi.Line() < vj.Line():
return true
case vi.Line() > vj.Line():
return false
default:
}
switch {
case vi.Column() < vj.Column():
return true
case vi.Column() > vj.Column():
return false
default:
// If we are here we have executed more than once
// and have different values, so sort lexically.
// This is not ideal given that values may include
// maps which do not render consistently and so
// we're breaking the sort invariant that comparisons
// will be consistent. For what we are doing this is
// good enough.
return fmt.Sprint(values[i].val) < fmt.Sprint(values[j].val)
}
})
return values
}

func (d *Dump) nodeValue(val any, id int64) NodeValue {
v := NodeValue{
loc: d.ast.NativeRep().SourceInfo().GetStartLocation(id),
src: d.ast.Source(),
val: val,
}
return v
}

// NodeValue is a CEL expression node value and annotation.
type NodeValue struct {
loc common.Location
src common.Source
val any
}

func (v NodeValue) MarshalJSON() ([]byte, error) {
type val struct {
Location string `json:"loc"`
Src string `json:"src"`
Val any `json:"val"`
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
err := enc.Encode(val{
Location: v.Loc(),
Src: v.Src(),
Val: v.val,
})
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func (v NodeValue) String() string {
return fmt.Sprintf("%s\n%s\n%v\n", v.Loc(), v.Src(), v.Val())
}

func (v NodeValue) Val() any {
return v.val
}

func (v NodeValue) Loc() string {
return fmt.Sprintf("%s:%d:%d", v.src.Description(), v.loc.Line(), v.loc.Column()+1)
}

func (v NodeValue) Src() string {
snippet, ok := v.src.Snippet(v.loc.Line())
if !ok {
return ""
}
src := " | " + strings.Replace(snippet, "\t", " ", -1)
ind := "\n | " + strings.Repeat(".", minInt(v.loc.Column(), len(snippet))) + "^"
return src + ind
}

func minInt(a, b int) int {
if a < b {
return a
}
return b
}
44 changes: 30 additions & 14 deletions mito.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func Main() int {
logTrace := flag.Bool("log_requests", false, "log request traces to stderr (go1.21+)")
maxTraceBody := flag.Int("max_log_body", 1000, "maximum length of body logged in request traces (go1.21+)")
fold := flag.Bool("fold", false, "apply constant folding optimisation")
dumpState := flag.String("dump", "", "dump eval state ('always' or 'error')")
version := flag.Bool("version", false, "print version and exit")
flag.Parse()
if *version {
Expand Down Expand Up @@ -195,8 +196,14 @@ func Main() int {
}

for n := int(0); *maxExecutions < 0 || n < *maxExecutions; n++ {
res, val, err := eval(string(b), root, input, *fold, libs...)
res, val, dump, err := eval(string(b), root, input, *fold, *dumpState != "", libs...)
if *dumpState == "always" {
fmt.Fprint(os.Stderr, dump)
}
if err != nil {
if *dumpState == "error" {
fmt.Fprint(os.Stderr, dump)
}
fmt.Fprintln(os.Stderr, err)
return 1
}
Expand Down Expand Up @@ -325,15 +332,20 @@ func debug(tag string, value any) {
fmt.Fprintf(os.Stderr, "%s: logging %q: %v\n", level, tag, value)
}

func eval(src, root string, input interface{}, fold bool, libs ...cel.EnvOption) (string, any, error) {
prg, ast, err := compile(src, root, fold, libs...)
func eval(src, root string, input interface{}, fold, details bool, libs ...cel.EnvOption) (string, any, *lib.Dump, error) {
prg, ast, err := compile(src, root, fold, details, libs...)
if err != nil {
return "", nil, fmt.Errorf("failed program instantiation: %v", err)
return "", nil, nil, fmt.Errorf("failed program instantiation: %v", err)
}
return run(prg, ast, false, input)
res, val, det, err := run(prg, ast, false, input)
var dump *lib.Dump
if details {
dump = lib.NewDump(ast, det)
}
return res, val, dump, err
}

func compile(src, root string, fold bool, libs ...cel.EnvOption) (cel.Program, *cel.Ast, error) {
func compile(src, root string, fold, details bool, libs ...cel.EnvOption) (cel.Program, *cel.Ast, error) {
opts := append([]cel.EnvOption{
cel.Declarations(decls.NewVar(root, decls.Dyn)),
}, libs...)
Expand All @@ -358,40 +370,44 @@ func compile(src, root string, fold bool, libs ...cel.EnvOption) (cel.Program, *
}
}

prg, err := env.Program(ast)
var progOpts []cel.ProgramOption
if details {
progOpts = []cel.ProgramOption{cel.EvalOptions(cel.OptTrackState)}
}
prg, err := env.Program(ast, progOpts...)
if err != nil {
return nil, nil, fmt.Errorf("failed program instantiation: %v", err)
}
return prg, ast, nil
}

func run(prg cel.Program, ast *cel.Ast, fast bool, input interface{}) (string, any, error) {
func run(prg cel.Program, ast *cel.Ast, fast bool, input interface{}) (string, any, *cel.EvalDetails, error) {
if input == nil {
input = interpreter.EmptyActivation()
}
out, _, err := prg.Eval(input)
out, det, err := prg.Eval(input)
if err != nil {
return "", nil, fmt.Errorf("failed eval: %v", lib.DecoratedError{AST: ast, Err: err})
return "", nil, det, fmt.Errorf("failed eval: %v", lib.DecoratedError{AST: ast, Err: err})
}

v, err := out.ConvertToNative(reflect.TypeOf(&structpb.Value{}))
if err != nil {
return "", nil, fmt.Errorf("failed proto conversion: %v", err)
return "", nil, det, fmt.Errorf("failed proto conversion: %v", err)
}
val := v.(*structpb.Value).AsInterface()
if fast {
b, err := protojson.MarshalOptions{}.Marshal(v.(proto.Message))
if err != nil {
return "", nil, fmt.Errorf("failed native conversion: %v", err)
return "", nil, det, fmt.Errorf("failed native conversion: %v", err)
}
return string(b), val, nil
return string(b), val, det, nil
}
var buf strings.Builder
enc := json.NewEncoder(&buf)
enc.SetEscapeHTML(false)
enc.SetIndent("", "\t")
err = enc.Encode(val)
return strings.TrimRight(buf.String(), "\n"), val, err
return strings.TrimRight(buf.String(), "\n"), val, det, err
}

// rot13 is provided for testing purposes.
Expand Down
Loading

0 comments on commit 10ba417

Please sign in to comment.