diff --git a/pkg/util/walker/demo/default_visitor_test.go b/pkg/util/walker/demo/default_visitor_test.go new file mode 100644 index 000000000000..e84b5653436a --- /dev/null +++ b/pkg/util/walker/demo/default_visitor_test.go @@ -0,0 +1,45 @@ +package demo + +import ( + "context" + "fmt" + "reflect" + "strings" + "testing" +) + +func TestTree(t *testing.T) { + x := &Bar{ + foo: Foo{ + val: "Hello", + }, + fooPtr: &Foo{ + val: "World!", + }, + } + + depth := 0 + var w strings.Builder + v := &StatementVisitorBase{ + DefaultPre: func(ctx StatementContext, x Statement) (b bool, e error) { + for i := 0; i < depth; i++ { + if _, err := w.WriteString(" "); err != nil { + return false, nil + } + } + w.WriteString(fmt.Sprintf("Name: %s; Type: %s\n", x.Name(), reflect.TypeOf(x))) + depth++ + return true, nil + }, + DefaultPost: func(ctx StatementContext, x Statement) error { + depth-- + return nil + }, + } + if x2, dirty, err := WalkStatement(context.Background(), x, v); err != nil { + t.Fatal(err) + } else { + t.Logf("%v %+v %+v", dirty, x, x2) + } + t.Logf("Output:\n%s", w.String()) +} diff --git a/pkg/util/walker/demo/demo.go b/pkg/util/walker/demo/demo.go new file mode 100644 index 000000000000..cb0eb93fdce6 --- /dev/null +++ b/pkg/util/walker/demo/demo.go @@ -0,0 +1,55 @@ +package demo + +// REVIEWERS: Start reading this file first, then check out +// generated_api.go for the bulk of the user-visible API. + +import ( + "time" + + "github.com/cockroachdb/cockroach/pkg/util/walker" +) + +// The user starts by defining a common base interface for their +// visitable types. It is permissible for a package to define +// multiple base types, such as `Statement` and `Expr`. Any symbols +// derived from a visitable interface will have the interface's name +// included for disambiguation, so prefer shorter names. +type Statement interface { + // The generator looks for this magic type. + walker.Interface + + // This method isn't special in any way, we just need some way to + // distinguish types defined in this package as being assignable + // to Statement. An `isStatement()` marker would also be just fine. + Name() string +} + +// Foo is a "pure value" type which contains no pointers or slices. +type Foo struct { + val string +} + +// Because *Foo implements statement, the user will get *Foo in +// the visitor API. +func (*Foo) Name() string { return "Foo" } + +type Bar struct { + foo Foo + fooPtr *Foo + + quux Quux + quuxPtr *Quux + + fooSlice []Foo + fooPtrSlice []*Foo +} + +func (*Bar) Name() string { return "Bar" } + +type Quux struct { + now time.Time +} + +// Quux implements the Statement interface with a value receiver, +// so the user will also see it by-value. +func (Quux) Name() string { return "Quux" } diff --git a/pkg/util/walker/demo/generated_api.go b/pkg/util/walker/demo/generated_api.go new file mode 100644 index 000000000000..9e4f30aa8858 --- /dev/null +++ b/pkg/util/walker/demo/generated_api.go @@ -0,0 +1,111 @@ +// Code generated by hand. DO NOT EDIT. +// source: demo.go +package demo + +// This file contains the interfaces that users will interact with. + +import ( + "context" + "reflect" +) + +// StatementContext allows for in-place structural modification by a +// visitor. +type StatementContext interface { + context.Context + + // Accepts an arbitrary value to be processed in the current context. + // In general, a visitor will pass itself for the v value, but + // we allow it to be overridden to allow a meta-visitor to choose + // between other visitor implementations to apply. + Accept(v StatementVisitor, n Statement) (res Statement, changed bool) + + // AcceptMany is a slice-oriented version of the above. It guarantees + // that the returned type will be identical to the input type. + AcceptMany(v StatementVisitor, n []Statement) (res []Statement, changed bool) + + CanReplace() bool + Replace(n Statement) + + CanInsertBefore() bool + InsertBefore(n Statement) + + CanInsertAfter() bool + InsertAfter(n Statement) + + CanRemove() bool + Remove() + + // Internal version of Accept to check types. + accept(v StatementVisitor, n statementImpl, assignableTo reflect.Type) (res statementImpl, changed bool) + // Internal version of AcceptMany to check types. + acceptMany(v StatementVisitor, n []statementImpl, assignableTo reflect.Type) (res []statementImpl, changed bool) +} + +// This generated interface will contain pre/post pairs for +// every type that implements the visitable interface. +// Whether or not you get a pointer or a struct type in +// these methods depends on whether or not the struct or +// the pointer type implements the visitable interface. +type StatementVisitor interface { + PreBar(ctx StatementContext, x *Bar) (bool, error) + PreFoo(ctx StatementContext, x *Foo) (bool, error) + PreQuux(ctx StatementContext, x Quux) (bool, error) + + PostBar(ctx StatementContext, x *Bar) error + PostFoo(ctx StatementContext, x *Foo) error + PostQuux(ctx StatementContext, x Quux) error +} + +// A default implementation of the visitor implementation. +// This has provisions for allowing users to provide default +// pre/post methods since we can't call methods defined in a struct +// that has embedded this base type. +type StatementVisitorBase struct { + DefaultPre func(ctx StatementContext, x Statement) (bool, error) + DefaultPost func(ctx StatementContext, x Statement) error +} + +var _ StatementVisitor = &StatementVisitorBase{} + +func (b StatementVisitorBase) PreBar(ctx StatementContext, x *Bar) (bool, error) { + if b.DefaultPre == nil { + return true, nil + } + return b.DefaultPre(ctx, x) +} + +func (b StatementVisitorBase) PreFoo(ctx StatementContext, x *Foo) (bool, error) { + if b.DefaultPre == nil { + return true, nil + } + return b.DefaultPre(ctx, x) +} + +func (b StatementVisitorBase) PreQuux(ctx StatementContext, x Quux) (bool, error) { + if b.DefaultPre == nil { + return true, nil + } + return b.DefaultPre(ctx, x) +} + +func (b StatementVisitorBase) PostBar(ctx StatementContext, x *Bar) error { + if b.DefaultPost == nil { + return nil + } + return b.DefaultPost(ctx, x) +} + +func (b StatementVisitorBase) PostFoo(ctx StatementContext, x *Foo) error { + if b.DefaultPost == nil { + return nil + } + return b.DefaultPost(ctx, x) +} + +func (b StatementVisitorBase) PostQuux(ctx StatementContext, x Quux) error { + if b.DefaultPre == nil { + return nil + } + return b.DefaultPost(ctx, x) +} diff --git a/pkg/util/walker/demo/generated_contexts.go b/pkg/util/walker/demo/generated_contexts.go new file mode 100644 index 000000000000..e2646c6b6097 --- /dev/null +++ b/pkg/util/walker/demo/generated_contexts.go @@ -0,0 +1,248 @@ +// Code generated by hand. DO NOT EDIT. +// source: demo.go +package demo + +import ( + "context" + "reflect" + "sync" + + "github.com/cockroachdb/cockroach/pkg/util/walker" +) + +type baseStatementContext struct { + context.Context + assignableTo reflect.Type + dirty bool +} + +var _ StatementContext = &baseStatementContext{} + +func (c baseStatementContext) Accept(v StatementVisitor, x Statement) (Statement, bool) { + s := ensureStatementImpl(x, statementType) + return c.accept(v, s, statementType) +} + +func (c baseStatementContext) accept( + v StatementVisitor, x statementImpl, assignableTo reflect.Type, +) (statementImpl, bool) { + ctx := newScalarStatementContext(c, assignableTo, + assignableTo.Kind() == reflect.Interface || assignableTo.Kind() == reflect.Ptr) + + recurse, err := x.pre(ctx, v) + if err != nil { + ctx.close() + ctx.unwind(err) + } + if ctx.dirty { + x = ctx.replacement + // If the user has nullified the value that we were visiting, + // exit early. + if x == nil { + ctx.close() + return x, true + } + ctx.replacement = nil + } + if recurse { + x.traverse(ctx, v) + } + if err := x.post(ctx, v); err != nil { + ctx.close() + ctx.unwind(err) + } + if ctx.replacement != nil { + x = ctx.replacement + } + dirty := ctx.dirty + ctx.close() + return x, dirty +} + +func (c baseStatementContext) AcceptMany(v StatementVisitor, n []Statement) ([]Statement, bool) { + newValue, changed := c.acceptMany(v, ensureStatementImpls(n, statementType), statementType) + if !changed { + return n, false + } + ret := make([]Statement, len(newValue)) + for i, j := range ret { + ret[i] = j + } + return ret, true +} + +func (c baseStatementContext) acceptMany( + v StatementVisitor, n []statementImpl, assignableTo reflect.Type, +) ([]statementImpl, bool) { + out := make([]statementImpl, 0, len(n)) + dirty := false + for _, x := range n { + ctx := &sliceStatementContext{ + baseStatementContext: baseStatementContext{ + Context: c, + assignableTo: assignableTo, + }, + } + + if ctx.insertBefore != nil { + dirty = true + out = append(out, ctx.insertBefore...) + } + if ctx.didRemove { + dirty = true + } else { + if ctx.didReplace { + dirty = true + out = append(out, ctx.replace) + } else { + // Not dirty, retaining existing element. + out = append(out, x) + } + } + if ctx.insertAfter != nil { + dirty = true + out = append(out, ctx.insertAfter...) + } + } + if dirty { + return out, true + } + return n, false +} + +func (c baseStatementContext) CanReplace() bool { + return false +} + +func (c baseStatementContext) Replace(n Statement) { + panic("this context cannot replace") +} + +func (c baseStatementContext) CanInsertBefore() bool { + return false +} + +func (c baseStatementContext) InsertBefore(n Statement) { + panic("this context cannot insert") +} + +func (c baseStatementContext) CanInsertAfter() bool { + return false +} + +func (c baseStatementContext) InsertAfter(n Statement) { + panic("this context cannot insert") +} + +func (baseStatementContext) CanRemove() bool { + return false +} + +func (baseStatementContext) Remove() { + panic("this context cannot remove") +} + +// unwind uses panic/recover to allow quickly unwinding up to +// the top-level walk function. +func (c baseStatementContext) unwind(err error) { + panic(walker.WalkError{Reason: err}) +} + +// scalerStatementContext instances should be obtained through +// newScalarStatementContext. +type scalerStatementContext struct { + baseStatementContext + allowRemove bool + replacement statementImpl +} + +var _ StatementContext = &scalerStatementContext{} + +var scalarStatementContextPool = sync.Pool{New: func() interface{} { + return &scalerStatementContext{} +}} + +func newScalarStatementContext( + parent baseStatementContext, assignableTo reflect.Type, allowRemove bool, +) *scalerStatementContext { + ret := scalarStatementContextPool.Get().(*scalerStatementContext) + *ret = scalerStatementContext{ + baseStatementContext: baseStatementContext{ + Context: parent.Context, + assignableTo: assignableTo, + }, + allowRemove: allowRemove, + } + return ret +} + +func (c *scalerStatementContext) close() { + scalarStatementContextPool.Put(c) +} + +func (c *scalerStatementContext) CanRemove() bool { + return c.allowRemove +} + +func (*scalerStatementContext) CanReplace() bool { + return true +} + +func (c *scalerStatementContext) Remove() { + if c.allowRemove { + c.dirty = true + c.replacement = nil + } else { + c.baseStatementContext.Remove() + } +} + +func (c *scalerStatementContext) Replace(n Statement) { + if n == nil { + c.Remove() + } else { + c.dirty = true + c.replacement = ensureStatementImpl(n, c.assignableTo) + } +} + +type sliceStatementContext struct { + baseStatementContext + didRemove bool + didReplace bool + insertAfter []statementImpl + insertBefore []statementImpl + replace statementImpl +} + +func (c *sliceStatementContext) CanInsertAfter() bool { + return true +} +func (c *sliceStatementContext) CanInsertBefore() bool { + return true +} +func (c *sliceStatementContext) CanRemove() bool { + return true +} +func (c *sliceStatementContext) CanReplace() bool { + return true +} +func (c *sliceStatementContext) InsertAfter(val Statement) { + c.dirty = true + c.insertAfter = append(c.insertAfter, ensureStatementImpl(val, c.assignableTo)) +} +func (c *sliceStatementContext) InsertBefore(val Statement) { + c.dirty = true + c.insertBefore = append(c.insertBefore, ensureStatementImpl(val, c.assignableTo)) +} +func (c *sliceStatementContext) Remove() { + c.dirty = true + c.didRemove = true + c.didReplace = false +} +func (c *sliceStatementContext) Replace(x Statement) { + c.dirty = true + c.didRemove = false + c.didReplace = true + c.replace = ensureStatementImpl(x, c.assignableTo) +} diff --git a/pkg/util/walker/demo/generated_enhancements.go b/pkg/util/walker/demo/generated_enhancements.go new file mode 100644 index 000000000000..328d6f077994 --- /dev/null +++ b/pkg/util/walker/demo/generated_enhancements.go @@ -0,0 +1,144 @@ +// Code generated by hand. DO NOT EDIT. +// source: demo.go +package demo + +// This file contains additional methods defined on the visitable types. + +import ( + "context" + "fmt" + "reflect" +) + +// statementImpl is an enhanced Statement. +type statementImpl interface { + Statement + // pre calls the relevant PreXYZ method on the visitor. + pre(ctx StatementContext, v StatementVisitor) (bool, error) + // post calls the relevant PreXYZ method on the visitor. + post(ctx StatementContext, v StatementVisitor) error + // traverse visits the fields within the statement. + traverse(ctx StatementContext, v StatementVisitor) +} + +func ensureStatementImpl(val interface{}, assignableTo reflect.Type) statementImpl { + var ret reflect.Value + valTyp := reflect.TypeOf(val) + if valTyp.ConvertibleTo(assignableTo) && valTyp.AssignableTo(statementImplType) { + ret = reflect.ValueOf(val).Convert(assignableTo) + } else if ptrType := reflect.PtrTo(valTyp); ptrType.ConvertibleTo(assignableTo) && ptrType.AssignableTo(statementImplType) { + ret = reflect.New(valTyp) + ret.Elem().Set(reflect.ValueOf(val)) + } else { + panic(fmt.Sprintf("unhandled conversion %+v to %v", val, assignableTo)) + } + return ret.Interface().(statementImpl) +} + +func ensureStatementImpls(slice interface{}, assignableTo reflect.Type) []statementImpl { + val := reflect.ValueOf(slice) + ln := val.Len() + ret := make([]statementImpl, ln) + for i := 0; i < ln; i++ { + ret[i] = ensureStatementImpl(val.Index(i), assignableTo) + } + return ret +} + +// Whether or not these traverse() methods are generated as +// Foo or *Foo should depend on the receiver type used when +// implementing the user's visitable interface. +var _ statementImpl = &Foo{} +var _ statementImpl = &Bar{} +var _ statementImpl = Quux{} + +func (x *Foo) Walk(ctx context.Context, v StatementVisitor) (*Foo, bool, error) { + ret, changed, err := walkStatement(ctx, x, v, fooPtrType) + return ret.(*Foo), changed, err +} + +func (x *Foo) pre(ctx StatementContext, v StatementVisitor) (bool, error) { + return v.PreFoo(ctx, x) +} + +func (x *Foo) post(ctx StatementContext, v StatementVisitor) error { + return v.PostFoo(ctx, x) +} + +// No fields +func (x *Foo) traverse(ctx StatementContext, v StatementVisitor) {} + +func (x *Bar) Walk(ctx context.Context, v StatementVisitor) (*Bar, bool, error) { + ret, changed, err := walkStatement(ctx, x, v, barPtrType) + return ret.(*Bar), changed, err +} + +func (x *Bar) pre(ctx StatementContext, v StatementVisitor) (bool, error) { + return v.PreBar(ctx, x) +} + +func (x *Bar) post(ctx StatementContext, v StatementVisitor) error { + return v.PostBar(ctx, x) +} + +func (x *Bar) traverse(ctx StatementContext, v StatementVisitor) { + dirty := false + // *Visitable in Visitable field + newFoo := &x.foo + if x, changed := ctx.accept(v, &x.foo, fooPtrType); changed { + dirty = true + newFoo = x.(*Foo) + } + // *Visitable in *Visitable field + newFooPtr := x.fooPtr + if newFooPtr != nil { + if x, changed := ctx.accept(v, x.fooPtr, fooPtrType); changed { + dirty = true + newFooPtr = x.(*Foo) + } + } + // Visitable in Visitable field + newQuux := x.quux + if x, changed := ctx.accept(v, &x.quux, quuxType); changed { + dirty = true + newQuux = x.(Quux) + } + // Visitable in *Visitable field + newQuuxPtr := x.quuxPtr + if newQuuxPtr != nil { + if x, changed := ctx.accept(v, x.quuxPtr, quuxType); changed { + dirty = true + t := x.(Quux) + newQuuxPtr = &t + } + } + + { + // Todo slices + } + + if dirty { + ctx.Replace(&Bar{ + foo: *newFoo, + fooPtr: newFooPtr, + quux: newQuux, + quuxPtr: newQuuxPtr, + }) + } +} + +func (x Quux) Walk(ctx context.Context, v StatementVisitor) (Quux, bool, error) { + ret, changed, err := walkStatement(ctx, x, v, quuxType) + return ret.(Quux), changed, err +} + +func (x Quux) pre(ctx StatementContext, v StatementVisitor) (bool, error) { + return v.PreQuux(ctx, x) +} + +func (x Quux) post(ctx StatementContext, v StatementVisitor) error { + return v.PostQuux(ctx, x) +} + +// No fields. +func (x Quux) traverse(ctx StatementContext, v StatementVisitor) {} diff --git a/pkg/util/walker/demo/generated_walk.go b/pkg/util/walker/demo/generated_walk.go new file mode 100644 index 000000000000..61b9c3176bd1 --- /dev/null +++ b/pkg/util/walker/demo/generated_walk.go @@ -0,0 +1,67 @@ +// Code generated by hand. DO NOT EDIT. +// source: demo.go +package demo + +// This file contains miscellaneous support. + +import ( + "context" + "reflect" + + "github.com/cockroachdb/cockroach/pkg/util/walker" +) + +// Generate some type tokens to prevent inappropriate assignments. +var ( + statementType = reflect.TypeOf([]Statement(nil)).Elem() + statementImplType = reflect.TypeOf([]statementImpl(nil)).Elem() + barPtrType = reflect.TypeOf([]*Bar(nil)).Elem() + fooPtrType = reflect.TypeOf([]*Foo(nil)).Elem() + quuxType = reflect.TypeOf([]Quux(nil)).Elem() +) + +func WalkStatement( + ctx context.Context, tgt Statement, v StatementVisitor, +) (Statement, bool, error) { + return walkStatement(ctx, tgt, v, statementType) +} + +func walkStatement( + ctx context.Context, tgt Statement, v StatementVisitor, assignableTo reflect.Type, +) (statementImpl, bool, error) { + var err error + defer func() { + if r := recover(); r != nil { + if we, ok := r.(*walker.WalkError); ok { + err = we + } else { + panic(r) + } + } + }() + + s := ensureStatementImpl(tgt, assignableTo) + s, changed := (&baseStatementContext{Context: ctx}).accept(v, s, assignableTo) + return s, changed, err +} + +func WalkStatements( + ctx context.Context, tgt []Statement, v StatementVisitor, +) ([]Statement, bool, error) { + panic("unimplemented") + /* + var err error + defer func() { + if r := recover(); r != nil { + if we, ok := r.(*walker.WalkError); ok { + err = we + } else { + panic(r) + } + } + }() + + ret, changed := (&baseStatementContext{Context: ctx}).acceptMany(v, ensureStatementImpls(tgt, statementImplType), statementImplType) + return ret, changed, err + */ +} diff --git a/pkg/util/walker/demo/modification_test.go b/pkg/util/walker/demo/modification_test.go new file mode 100644 index 000000000000..d675985b8da9 --- /dev/null +++ b/pkg/util/walker/demo/modification_test.go @@ -0,0 +1,135 @@ +package demo + +// In this test, we're going to show mutations performed in-place +// as well as mutations performed by replacement. We have visitable +// types *Foo and Quux. We can modify *Foo in place, but must +// replace values of Quux. + +import ( + "context" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/util/timeutil" +) + +type Printer struct { + StatementVisitorBase + w strings.Builder +} + +var _ StatementVisitor = &Printer{} + +func (p *Printer) PreFoo(ctx StatementContext, foo *Foo) (bool, error) { + p.w.WriteString(foo.val) + return false, nil +} + +func (p *Printer) PreQuux(ctx StatementContext, x Quux) (bool, error) { + p.w.WriteString(x.now.String()) + return false, nil +} + +type Mutator struct { + StatementVisitorBase +} + +var _ StatementVisitor = &Mutator{} + +// We're going to mutate Foo's in-place. +func (Mutator) PreFoo(ctx StatementContext, foo *Foo) (bool, error) { + // Via Russ Cox + // https://groups.google.com/d/msg/golang-nuts/oPuBaYJ17t4/PCmhdAyrNVkJ + n := 0 + runes := make([]rune, len(foo.val)) + for _, r := range foo.val { + runes[n] = r + n++ + } + // Account for multi-byte points. + runes = runes[0:n] + // Reverse. + for i := 0; i < n/2; i++ { + runes[i], runes[n-1-i] = runes[n-1-i], runes[i] + } + + // Update in-place. + foo.val = string(runes) + return false, nil +} + +// We're going to replace Quux instances. +func (Mutator) PostQuux(ctx StatementContext, quux Quux) error { + quux.now = timeutil.Now() + ctx.Replace(quux) + // Just to be explicit that once the replacement has happened, + // it's all by-value. + quux.now = timeutil.UnixEpoch + return nil +} + +func TestPrint(t *testing.T) { + x := Bar{ + foo: Foo{ + val: "olleH", + }, + fooPtr: &Foo{ + val: "!dlroW ", + }, + quux: Quux{ + now: timeutil.UnixEpoch, + }, + quuxPtr: &Quux{ + now: timeutil.UnixEpoch, + }, + } + + x2, changed, err := x.Walk(context.Background(), Mutator{}) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("not changed") + } + if x.fooPtr != x2.fooPtr { + t.Fatal("fooPtr should not have changed") + } + + sv := &Printer{} + x3, changed, err := x2.Walk(context.Background(), sv) + if err != nil { + t.Fatal(err) + } + if changed { + t.Fatal("should not have changed") + } + if x2.fooPtr != x3.fooPtr { + t.Fatal("pointer should not have changed") + } + t.Log(sv.w.String()) +} + +func BenchmarkNoop(b *testing.B) { + x := Bar{ + foo: Foo{ + val: "olleH", + }, + fooPtr: &Foo{ + val: "!dlroW ", + }, + quux: Quux{ + now: timeutil.UnixEpoch, + }, + quuxPtr: &Quux{ + now: timeutil.UnixEpoch, + }, + } + v := &StatementVisitorBase{} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, _, err := x.Walk(context.Background(), v); err != nil { + b.Fatal(err) + } + } +} diff --git a/pkg/util/walker/walker.go b/pkg/util/walker/walker.go new file mode 100644 index 000000000000..fd4462199367 --- /dev/null +++ b/pkg/util/walker/walker.go @@ -0,0 +1,19 @@ +package walker + +// Interface represents a visitable node. +type Interface interface { +} + +type WalkError struct { + Reason error +} + +var _ error = &WalkError{} + +func (e WalkError) Cause() error { + return e.Reason +} + +func (e WalkError) Error() string { + return e.Reason.Error() +}