diff --git a/constructor.go b/constructor.go index 8e694a26..00bba546 100644 --- a/constructor.go +++ b/constructor.go @@ -53,7 +53,11 @@ type constructorNode struct { // Type information about constructor results. resultList resultList - order int // order of this node in graphHolder + // order of this node in each Scopes' graphHolders. + orders map[*Scope]int + + // scope this node was originally provided to. + s *Scope } type constructorOptions struct { @@ -65,12 +69,12 @@ type constructorOptions struct { Location *digreflect.Func } -func newConstructorNode(ctor interface{}, c containerStore, opts constructorOptions) (*constructorNode, error) { +func newConstructorNode(ctor interface{}, s *Scope, opts constructorOptions) (*constructorNode, error) { cval := reflect.ValueOf(ctor) ctype := cval.Type() cptr := cval.Pointer() - params, err := newParamList(ctype, c) + params, err := newParamList(ctype, s) if err != nil { return nil, err } @@ -99,8 +103,10 @@ func newConstructorNode(ctor interface{}, c containerStore, opts constructorOpti id: dot.CtorID(cptr), paramList: params, resultList: results, + orders: make(map[*Scope]int), + s: s, } - n.order = c.newGraphNode(n) + s.newGraphNode(n, n.orders) return n, nil } @@ -109,7 +115,7 @@ func (n *constructorNode) ParamList() paramList { return n.paramList } func (n *constructorNode) ResultList() resultList { return n.resultList } func (n *constructorNode) ID() dot.CtorID { return n.id } func (n *constructorNode) CType() reflect.Type { return n.ctype } -func (n *constructorNode) Order() int { return n.order } +func (n *constructorNode) Order(s *Scope) int { return n.orders[s] } func (n *constructorNode) String() string { return fmt.Sprintf("deps: %v, ctor: %v", n.paramList, n.ctype) diff --git a/container.go b/container.go index 4b2045f2..ef5c5a85 100644 --- a/container.go +++ b/container.go @@ -21,12 +21,9 @@ package dig import ( - "bytes" "fmt" "math/rand" "reflect" - "sort" - "time" "go.uber.org/dig/internal/dot" ) @@ -63,32 +60,12 @@ type Option interface { } // Container is a directed acyclic graph of types and their dependencies. +// A Container is the root Scope that represents the top-level scoped +// directed acyclic graph of the dependencies. type Container struct { - // Mapping from key to all the constructor node that can provide a value for that - // key. - providers map[key][]*constructorNode - - nodes []*constructorNode - - // Values that have already been generated in the container. - values map[key]reflect.Value - - // Values groups that have already been generated in the container. - groups map[key][]reflect.Value - - // Source of randomness. - rand *rand.Rand - - // Flag indicating whether the graph has been checked for cycles. - isVerifiedAcyclic bool - - // Defer acyclic check on provide until Invoke. - deferAcyclicVerification bool - - // invokerFn calls a function with arguments provided to Provide or Invoke. - invokerFn invokerFn - - gh *graphHolder + // this is the "root" Scope that represents the + // root of the scope tree. + scope *Scope } // containerWriter provides write access to the Container's underlying data @@ -108,8 +85,8 @@ type containerWriter interface { type containerStore interface { containerWriter - // Adds a new graph node to the Container and returns its order. - newGraphNode(w interface{}) int + // Adds a new graph node to the Container + newGraphNode(w interface{}, orders map[*Scope]int) // Returns a slice containing all known types. knownTypes() []reflect.Type @@ -130,6 +107,12 @@ type containerStore interface { // type. getGroupProviders(name string, t reflect.Type) []provider + // Returns the providers that can produce a value with the given name and + // type across all the Scopes that are in effect of this containerStore. + getAllValueProviders(name string, t reflect.Type) []provider + + getStoresFromRoot() []containerStore + createGraph() *dot.Graph // Returns invokerFn function to use when calling arguments. @@ -138,15 +121,8 @@ type containerStore interface { // New constructs a Container. func New(opts ...Option) *Container { - c := &Container{ - providers: make(map[key][]*constructorNode), - values: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - invokerFn: defaultInvoker, - } - - c.gh = newGraphHolder(c) + s := newScope() + c := &Container{scope: s} for _, opt := range opts { opt.applyOption(c) @@ -172,7 +148,7 @@ func (deferAcyclicVerificationOption) String() string { } func (deferAcyclicVerificationOption) applyOption(c *Container) { - c.deferAcyclicVerification = true + c.scope.deferAcyclicVerification = true } // Changes the source of randomness for the container. @@ -189,7 +165,7 @@ func (o setRandOption) String() string { } func (o setRandOption) applyOption(c *Container) { - c.rand = o.r + c.scope.rand = o.r } // DryRun is an Option which, when set to true, disables invocation of functions supplied to @@ -206,9 +182,9 @@ func (o dryRunOption) String() string { func (o dryRunOption) applyOption(c *Container) { if o { - c.invokerFn = dryInvoker + c.scope.invokerFn = dryInvoker } else { - c.invokerFn = defaultInvoker + c.scope.invokerFn = defaultInvoker } } @@ -230,105 +206,14 @@ func dryInvoker(fn reflect.Value, _ []reflect.Value) []reflect.Value { return results } -func (c *Container) knownTypes() []reflect.Type { - typeSet := make(map[reflect.Type]struct{}, len(c.providers)) - for k := range c.providers { - typeSet[k.t] = struct{}{} - } - - types := make([]reflect.Type, 0, len(typeSet)) - for t := range typeSet { - types = append(types, t) - } - sort.Sort(byTypeName(types)) - return types -} - -func (c *Container) getValue(name string, t reflect.Type) (v reflect.Value, ok bool) { - v, ok = c.values[key{name: name, t: t}] - return -} - -func (c *Container) setValue(name string, t reflect.Type, v reflect.Value) { - c.values[key{name: name, t: t}] = v -} - -func (c *Container) getValueGroup(name string, t reflect.Type) []reflect.Value { - items := c.groups[key{group: name, t: t}] - // shuffle the list so users don't rely on the ordering of grouped values - return shuffledCopy(c.rand, items) -} - -func (c *Container) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { - k := key{group: name, t: t} - c.groups[k] = append(c.groups[k], v) -} - -func (c *Container) getValueProviders(name string, t reflect.Type) []provider { - return c.getProviders(key{name: name, t: t}) -} - -func (c *Container) getGroupProviders(name string, t reflect.Type) []provider { - return c.getProviders(key{group: name, t: t}) -} - -func (c *Container) getProviders(k key) []provider { - nodes := c.providers[k] - providers := make([]provider, len(nodes)) - for i, n := range nodes { - providers[i] = n - } - return providers -} - -// invokerFn return a function to run when calling function provided to Provide or Invoke. Used for -// running container in dry mode. -func (c *Container) invoker() invokerFn { - return c.invokerFn -} - -func (c *Container) newGraphNode(wrapped interface{}) int { - return c.gh.NewNode(wrapped) -} - -func (c *Container) cycleDetectedError(cycle []int) error { - var path []cycleErrPathEntry - for _, n := range cycle { - if n, ok := c.gh.Lookup(n).(*constructorNode); ok { - path = append(path, cycleErrPathEntry{ - Key: key{ - t: n.CType(), - }, - Func: n.Location(), - }) - } - } - return errCycleDetected{Path: path} -} - // String representation of the entire Container func (c *Container) String() string { - b := &bytes.Buffer{} - fmt.Fprintln(b, "nodes: {") - for k, vs := range c.providers { - for _, v := range vs { - fmt.Fprintln(b, "\t", k, "->", v) - } - } - fmt.Fprintln(b, "}") - - fmt.Fprintln(b, "values: {") - for k, v := range c.values { - fmt.Fprintln(b, "\t", k, "=>", v) - } - for k, vs := range c.groups { - for _, v := range vs { - fmt.Fprintln(b, "\t", k, "=>", v) - } - } - fmt.Fprintln(b, "}") + return c.scope.String() +} - return b.String() +// Scope creates a child scope of the Container with the given name. +func (c *Container) Scope(name string, opts ...ScopeOption) *Scope { + return c.scope.Scope(name, opts...) } type byTypeName []reflect.Type diff --git a/cycle_error.go b/cycle_error.go index d6916bbf..c1d41abf 100644 --- a/cycle_error.go +++ b/cycle_error.go @@ -33,7 +33,8 @@ type cycleErrPathEntry struct { } type errCycleDetected struct { - Path []cycleErrPathEntry + Path []cycleErrPathEntry + scope *Scope } func (e errCycleDetected) Error() string { @@ -46,6 +47,7 @@ func (e errCycleDetected) Error() string { // b := new(bytes.Buffer) + fmt.Fprintf(b, "In Scope %s: \n", e.scope.name) for i, entry := range e.Path { if i > 0 { b.WriteString("\n\tdepends on ") diff --git a/dig_test.go b/dig_test.go index ad6e70ba..2e3dd8d6 100644 --- a/dig_test.go +++ b/dig_test.go @@ -3105,14 +3105,14 @@ func TestNodeAlreadyCalled(t *testing.T) { type type1 struct{} f := func() type1 { return type1{} } - n, err := newConstructorNode(f, New(), constructorOptions{}) + n, err := newConstructorNode(f, newScope(), constructorOptions{}) require.NoError(t, err, "failed to build node") require.False(t, n.called, "node must not have been called") c := New() - require.NoError(t, n.Call(c), "invoke failed") + require.NoError(t, n.Call(c.scope), "invoke failed") require.True(t, n.called, "node must be called") - require.NoError(t, n.Call(c), "calling again should be okay") + require.NoError(t, n.Call(c.scope), "calling again should be okay") } func TestFailingFunctionDoesNotCreateInvalidState(t *testing.T) { diff --git a/graph.go b/graph.go index 755470b5..6c8e7e13 100644 --- a/graph.go +++ b/graph.go @@ -33,13 +33,13 @@ type graphNode struct { // It saves constructorNodes and paramGroupedSlice (value groups) // as nodes in the graph. // It implements the graph interface defined by internal/graph. -// It has 1-1 correspondence with the Container whose graph it represents. +// It has 1-1 correspondence with the Scope whose graph it represents. type graphHolder struct { // all the nodes defined in the graph. nodes []*graphNode - // Container whose graph this holder contains. - c *Container + // Scope whose graph this holder contains. + s *Scope // Number of nodes in the graph at last snapshot. // -1 if no snapshot has been taken. @@ -48,9 +48,8 @@ type graphHolder struct { var _ graph.Graph = (*graphHolder)(nil) -func newGraphHolder(c *Container) *graphHolder { - return &graphHolder{c: c, snap: -1} - +func newGraphHolder(s *Scope) *graphHolder { + return &graphHolder{s: s, snap: -1} } func (gh *graphHolder) Order() int { return len(gh.nodes) } @@ -72,9 +71,9 @@ func (gh *graphHolder) EdgesFrom(u int) []int { orders = append(orders, getParamOrder(gh, param)...) } case *paramGroupedSlice: - providers := gh.c.getGroupProviders(w.Group, w.Type.Elem()) + providers := gh.s.getAllGroupProviders(w.Group, w.Type.Elem()) for _, provider := range providers { - orders = append(orders, provider.Order()) + orders = append(orders, provider.Order(gh.s)) } } return orders diff --git a/invoke.go b/invoke.go index 0729f4c7..acfc25af 100644 --- a/invoke.go +++ b/invoke.go @@ -43,6 +43,18 @@ type InvokeOption interface { // The function may return an error to indicate failure. The error will be // returned to the caller as-is. func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { + return c.scope.Invoke(function, opts...) +} + +// Invoke runs the given function after instantiating its dependencies. +// +// Any arguments that the function has are treated as its dependencies. The +// dependencies are instantiated in an unspecified order along with any +// dependencies that they might have. +// +// The function may return an error to indicate failure. The error will be +// returned to the caller as-is. +func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error { ftype := reflect.TypeOf(function) if ftype == nil { return errors.New("can't invoke an untyped nil") @@ -51,33 +63,33 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return errf("can't invoke non-function %v (type %v)", function, ftype) } - pl, err := newParamList(ftype, c) + pl, err := newParamList(ftype, s) if err != nil { return err } - if err := shallowCheckDependencies(c, pl); err != nil { + if err := shallowCheckDependencies(s, pl); err != nil { return errMissingDependencies{ Func: digreflect.InspectFunc(function), Reason: err, } } - if !c.isVerifiedAcyclic { - if ok, cycle := graph.IsAcyclic(c.gh); !ok { - return errf("cycle detected in dependency graph", c.cycleDetectedError(cycle)) + if !s.isVerifiedAcyclic { + if ok, cycle := graph.IsAcyclic(s.gh); !ok { + return errf("cycle detected in dependency graph", s.cycleDetectedError(cycle)) } - c.isVerifiedAcyclic = true + s.isVerifiedAcyclic = true } - args, err := pl.BuildList(c) + args, err := pl.BuildList(s) if err != nil { return errArgumentsFailed{ Func: digreflect.InspectFunc(function), Reason: err, } } - returned := c.invokerFn(reflect.ValueOf(function), args) + returned := s.invokerFn(reflect.ValueOf(function), args) if len(returned) == 0 { return nil } @@ -112,7 +124,7 @@ func findMissingDependencies(c containerStore, params ...param) []paramSingle { for _, param := range params { switch p := param.(type) { case paramSingle: - if ns := c.getValueProviders(p.Name, p.Type); len(ns) == 0 && !p.Optional { + if ns := c.getAllValueProviders(p.Name, p.Type); len(ns) == 0 && !p.Optional { missingDeps = append(missingDeps, p) } case paramObject: diff --git a/param.go b/param.go index 942edfe5..bf9d4272 100644 --- a/param.go +++ b/param.go @@ -146,10 +146,35 @@ func (pl paramList) Build(containerStore) (reflect.Value, error) { // to the underlying constructor. func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { args := make([]reflect.Value, len(pl.Params)) + argsBuilt := make([]bool, len(pl.Params)) + allContainers := c.getStoresFromRoot() for i, p := range pl.Params { var err error - args[i], err = p.Build(c) - if err != nil { + var arg reflect.Value + // iterate through the tree path of scopes. + for _, c := range allContainers { + if arg, err = p.Build(c); err == nil { + args[i] = arg + argsBuilt[i] = true + } + } + + // If an argument failed to build, that means none of the + // scopes had the type. This should be reported. + if !argsBuilt[i] { + return nil, err + } + + // If argument has successfully been built, it's possible + // for these errors to occur in child scopes that don't + // contain the given parameter type. We can safely ignore + // these. + // If it's an error other than missing types/dependencies, + // this means some constructor returned an error that must + // be reported. + _, isErrMissingTypes := err.(errMissingTypes) + _, isErrMissingDeps := err.(errMissingDependencies) + if err != nil && !isErrMissingTypes && !isErrMissingDeps { return nil, err } } @@ -264,14 +289,14 @@ func getParamOrder(gh *graphHolder, param param) []int { var orders []int switch p := param.(type) { case paramSingle: - providers := gh.c.getValueProviders(p.Name, p.Type) + providers := gh.s.getAllValueProviders(p.Name, p.Type) for _, provider := range providers { - orders = append(orders, provider.Order()) + orders = append(orders, provider.Order(gh.s)) } case paramGroupedSlice: // value group parameters have nodes of their own. // We can directly return that here. - orders = append(orders, p.order) + orders = append(orders, p.orders[gh.s]) case paramObject: for _, pf := range p.Fields { orders = append(orders, getParamOrder(gh, pf.Param)...) @@ -410,7 +435,7 @@ type paramGroupedSlice struct { // Type of the slice. Type reflect.Type - order int + orders map[*Scope]int } func (pt paramGroupedSlice) String() string { @@ -438,7 +463,7 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped if err != nil { return paramGroupedSlice{}, err } - pg := paramGroupedSlice{Group: g.Name, Type: f.Type} + pg := paramGroupedSlice{Group: g.Name, Type: f.Type, orders: make(map[*Scope]int)} name := f.Tag.Get(_nameTag) optional, _ := isFieldOptional(f) @@ -457,7 +482,7 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped case optional: return pg, errors.New("value groups cannot be optional") } - pg.order = c.newGraphNode(&pg) + c.newGraphNode(&pg, pg.orders) return pg, nil } diff --git a/param_test.go b/param_test.go index f1e82545..7a1f41ed 100644 --- a/param_test.go +++ b/param_test.go @@ -30,10 +30,10 @@ import ( ) func TestParamListBuild(t *testing.T) { - p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), New()) + p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), newScope()) require.NoError(t, err) assert.Panics(t, func() { - p.Build(New()) + p.Build(newScope()) }) } @@ -57,7 +57,7 @@ func TestParamObjectSuccess(t *testing.T) { } `name:"bar"` } - po, err := newParamObject(reflect.TypeOf(in{}), New()) + po, err := newParamObject(reflect.TypeOf(in{}), newScope()) require.NoError(t, err) require.Len(t, po.Fields, 4) @@ -114,7 +114,7 @@ func TestParamObjectWithUnexportedFieldsSuccess(t *testing.T) { _ = in{}.t2 // unused - po, err := newParamObject(reflect.TypeOf(in{}), New()) + po, err := newParamObject(reflect.TypeOf(in{}), newScope()) require.NoError(t, err) require.Len(t, po.Fields, 1) @@ -138,7 +138,7 @@ func TestParamObjectFailure(t *testing.T) { _ = in{}.a2 // unused but needed - _, err := newParamObject(reflect.TypeOf(in{}), New()) + _, err := newParamObject(reflect.TypeOf(in{}), newScope()) require.Error(t, err) assert.Contains(t, err.Error(), `bad field "a2" of dig.in: unexported fields not allowed in dig.In, did you mean to export "a2" (dig.A)`) @@ -155,7 +155,7 @@ func TestParamObjectFailure(t *testing.T) { _ = in{}.a2 // unused but needed - _, err := newParamObject(reflect.TypeOf(in{}), New()) + _, err := newParamObject(reflect.TypeOf(in{}), newScope()) require.Error(t, err) assert.Contains(t, err.Error(), `bad field "a2" of dig.in: unexported fields not allowed in dig.In, did you mean to export "a2" (dig.A)`) @@ -172,7 +172,7 @@ func TestParamObjectFailure(t *testing.T) { _ = in{}.a2 // unused but needed - _, err := newParamObject(reflect.TypeOf(in{}), New()) + _, err := newParamObject(reflect.TypeOf(in{}), newScope()) require.Error(t, err) assert.Contains(t, err.Error(), `invalid value "foo" for "ignore-unexported" tag on field In: strconv.ParseBool: parsing "foo": invalid syntax`) @@ -227,7 +227,7 @@ func TestParamGroupSliceErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - _, err := newParamObject(reflect.TypeOf(tt.shape), New()) + _, err := newParamObject(reflect.TypeOf(tt.shape), newScope()) require.Error(t, err, "expected failure") assert.Contains(t, err.Error(), tt.wantErr) }) diff --git a/provide.go b/provide.go index 40d3d867..bcf28af5 100644 --- a/provide.go +++ b/provide.go @@ -308,7 +308,7 @@ type provider interface { // Order reports the order of this provider in the graphHolder. // This value is usually returned by the graphHolder.NewNode method. - Order() int + Order(*Scope) int // Location returns where this constructor was defined. Location() *digreflect.Func @@ -343,11 +343,32 @@ type provider interface { // same types are requested multiple times, the previously produced value will // be reused. // -// In addition to accepting constructors that accept dependencies as separate -// arguments and produce results as separate return values, Provide also -// accepts constructors that specify dependencies as dig.In structs and/or -// specify results as dig.Out structs. +// Provide accepts argument types or dig.In structs as dependencies, and +// separate return values or dig.Out structs for results. func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) error { + return c.scope.Provide(constructor, opts...) +} + +// Provide teaches the Scope how to build values of one or more types and +// expresses their dependencies. +// +// The first argument of Provide is a function that accepts zero or more +// parameters and returns one or more results. The function may optionally +// return an error to indicate that it failed to build the value. This +// function will be treated as the constructor for all the types it returns. +// This function will be called AT MOST ONCE when a type produced by it, or a +// type that consumes this function's output, is requested via Invoke. If the +// same types are requested multiple times, the previously produced value will +// be reused. +// +// Provide accepts argument types or dig.In structs as dependencies, and +// separate return values or dig.Out structs for results. +// +// When a constructor is Provided to a Scope, it will propagate this to any +// Scopes that are descendents, but not ancestors of this Scope. +// To provide a constructor to all the Scopes available, provide it to +// Container, which is the root Scope. +func (s *Scope) Provide(constructor interface{}, opts ...ProvideOption) error { ctype := reflect.TypeOf(constructor) if ctype == nil { return errors.New("can't provide an untyped nil") @@ -364,7 +385,7 @@ func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) erro return err } - if err := c.provide(constructor, options); err != nil { + if err := s.provide(constructor, options); err != nil { return errProvide{ Func: digreflect.InspectFunc(constructor), Reason: err, @@ -373,20 +394,25 @@ func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) erro return nil } -func (c *Container) provide(ctor interface{}, opts provideOptions) (err error) { +func (s *Scope) provide(ctor interface{}, opts provideOptions) (err error) { + // For all scopes affected by this change, // take a snapshot of the current graph state before // we start making changes to it as we may need to // undo them upon encountering errors. - c.gh.Snapshot() - defer func() { - if err != nil { - c.gh.Rollback() - } - }() + allScopes := s.appendLeafScopes(nil) + for _, s := range allScopes { + s := s + s.gh.Snapshot() + defer func() { + if err != nil { + s.gh.Rollback() + } + }() + } n, err := newConstructorNode( ctor, - c, + s, constructorOptions{ ResultName: opts.Name, ResultGroup: opts.Group, @@ -398,7 +424,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) (err error) { return err } - keys, err := c.findAndValidateResults(n) + keys, err := s.findAndValidateResults(n) if err != nil { return err } @@ -411,25 +437,29 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) (err error) { oldProviders := make(map[key][]*constructorNode) for k := range keys { // Cache old providers before running cycle detection. - oldProviders[k] = c.providers[k] - c.providers[k] = append(c.providers[k], n) + oldProviders[k] = s.providers[k] + s.providers[k] = append(s.providers[k], n) } - c.isVerifiedAcyclic = false - if !c.deferAcyclicVerification { - if ok, cycle := graph.IsAcyclic(c.gh); !ok { + for _, s := range allScopes { + s.isVerifiedAcyclic = false + if s.deferAcyclicVerification { + continue + } + if ok, cycle := graph.IsAcyclic(s.gh); !ok { // When a cycle is detected, recover the old providers to reset // the providers map back to what it was before this node was // introduced. for k, ops := range oldProviders { - c.providers[k] = ops + s.providers[k] = ops } - return errf("this function introduces a cycle", c.cycleDetectedError(cycle)) + return errf("this function introduces a cycle", s.cycleDetectedError(cycle)) } - c.isVerifiedAcyclic = true + s.isVerifiedAcyclic = true } - c.nodes = append(c.nodes, n) + + s.nodes = append(s.nodes, n) // Record introspection info for caller if Info option is specified if info := opts.Info; info != nil { @@ -461,11 +491,11 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) (err error) { } // Builds a collection of all result types produced by this constructor. -func (c *Container) findAndValidateResults(n *constructorNode) (map[key]struct{}, error) { +func (s *Scope) findAndValidateResults(n *constructorNode) (map[key]struct{}, error) { var err error keyPaths := make(map[key]string) walkResult(n.ResultList(), connectionVisitor{ - c: c, + s: s, n: n, err: &err, keyPaths: keyPaths, @@ -485,7 +515,7 @@ func (c *Container) findAndValidateResults(n *constructorNode) (map[key]struct{} // Visits the results of a node and compiles a collection of all the keys // produced by that node. type connectionVisitor struct { - c *Container + s *Scope n *constructorNode // If this points to a non-nil value, we've already encountered an error @@ -570,7 +600,7 @@ func (cv connectionVisitor) checkKey(k key, path string) error { "already provided by %v", conflict, ) } - if ps := cv.c.providers[k]; len(ps) > 0 { + if ps := cv.s.providers[k]; len(ps) > 0 { cons := make([]string, len(ps)) for i, p := range ps { cons[i] = fmt.Sprint(p.Location()) diff --git a/scope.go b/scope.go new file mode 100644 index 00000000..9ea08343 --- /dev/null +++ b/scope.go @@ -0,0 +1,270 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package dig + +import ( + "bytes" + "fmt" + "math/rand" + "reflect" + "sort" + "time" +) + +// A ScopeOption modifies the default behavior of Scope; currently, +// there are no implementations. +type ScopeOption interface { + noScopeOption() //yet +} + +// Scope is a scoped DAG of types and their dependencies. +// A Scope may also have one or more child Scopes that inherit +// from it. +type Scope struct { + // This implements containerStore interface. + + // Name of the Scope + name string + // Mapping from key to all the constructor node that can provide a value for that + // key. + providers map[key][]*constructorNode + + // constructorNodes provided directly to this Scope. i.e. it does not include + // any nodes that were provided to the parent Scope this inherited from. + nodes []*constructorNode + + // Values that generated directly in the Scope. + values map[key]reflect.Value + + // Values groups that generated directly in the Scope. + groups map[key][]reflect.Value + + // Source of randomness. + rand *rand.Rand + + // Flag indicating whether the graph has been checked for cycles. + isVerifiedAcyclic bool + + // Defer acyclic check on provide until Invoke. + deferAcyclicVerification bool + + // invokerFn calls a function with arguments provided to Provide or Invoke. + invokerFn invokerFn + + // graph of this Scope. Note that this holds the dependency graph of all the + // nodes that affect this Scope, not just the ones provided directly to this Scope. + gh *graphHolder + + // Parent of this Scope. + parentScope *Scope + + // All the child scopes of this Scope. + childScopes []*Scope +} + +func newScope() *Scope { + s := &Scope{ + name: "container", + providers: make(map[key][]*constructorNode), + values: make(map[key]reflect.Value), + groups: make(map[key][]reflect.Value), + invokerFn: defaultInvoker, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + } + s.gh = newGraphHolder(s) + return s +} + +// Scope creates a new Scope with the given name and options from current Scope. +// Any constructors that the current Scope knows about, as well as any modifications +// made to it in the future will be propagated to the child scope. +// However, no modifications made to the child scope being created will be propagated +// to the parent Scope. +func (s *Scope) Scope(name string, opts ...ScopeOption) *Scope { + child := newScope() + child.name = name + child.parentScope = s + child.invokerFn = s.invokerFn + child.deferAcyclicVerification = s.deferAcyclicVerification + + // child copies the parent's graph nodes. + child.gh.nodes = append(child.gh.nodes, s.gh.nodes...) + + for _, opt := range opts { + opt.noScopeOption() + } + + s.childScopes = append(s.childScopes, child) + return child +} + +// getScopesFromRoot returns a list of Scopes from the root Container +// until the current Scope. +func (s *Scope) getScopesFromRoot() []*Scope { + var scopes []*Scope + for s := s; s != nil; s = s.parentScope { + scopes = append(scopes, s) + } + for i, j := 0, len(scopes)-1; i < j; i, j = i+1, j-1 { + scopes[i], scopes[j] = scopes[j], scopes[i] + } + return scopes +} + +func (s *Scope) appendLeafScopes(dest []*Scope) []*Scope { + dest = append(dest, s) + for _, cs := range s.childScopes { + dest = cs.appendLeafScopes(dest) + } + return dest +} + +func (s *Scope) getStoresFromRoot() []containerStore { + var stores []containerStore + for s := s; s != nil; s = s.parentScope { + stores = append(stores, s) + } + for i, j := 0, len(stores)-1; i < j; i, j = i+1, j-1 { + stores[i], stores[j] = stores[j], stores[i] + } + return stores +} + +func (s *Scope) knownTypes() []reflect.Type { + typeSet := make(map[reflect.Type]struct{}, len(s.providers)) + for k := range s.providers { + typeSet[k.t] = struct{}{} + } + + types := make([]reflect.Type, 0, len(typeSet)) + for t := range typeSet { + types = append(types, t) + } + sort.Sort(byTypeName(types)) + return types +} + +func (s *Scope) getValue(name string, t reflect.Type) (v reflect.Value, ok bool) { + v, ok = s.values[key{name: name, t: t}] + return +} + +func (s *Scope) setValue(name string, t reflect.Type, v reflect.Value) { + s.values[key{name: name, t: t}] = v +} + +func (s *Scope) getValueGroup(name string, t reflect.Type) []reflect.Value { + items := s.groups[key{group: name, t: t}] + // shuffle the list so users don't rely on the ordering of grouped values + return shuffledCopy(s.rand, items) +} + +func (s *Scope) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { + k := key{group: name, t: t} + s.groups[k] = append(s.groups[k], v) +} + +func (s *Scope) getValueProviders(name string, t reflect.Type) []provider { + return s.getProviders(key{name: name, t: t}) +} + +func (s *Scope) getGroupProviders(name string, t reflect.Type) []provider { + return s.getProviders(key{group: name, t: t}) +} + +func (s *Scope) getProviders(k key) []provider { + nodes := s.providers[k] + providers := make([]provider, len(nodes)) + for i, n := range nodes { + providers[i] = n + } + return providers +} + +func (s *Scope) getAllGroupProviders(name string, t reflect.Type) []provider { + return s.getAllProviders(key{group: name, t: t}) +} + +func (s *Scope) getAllValueProviders(name string, t reflect.Type) []provider { + return s.getAllProviders(key{name: name, t: t}) +} + +func (s *Scope) getAllProviders(k key) []provider { + allScopes := s.getScopesFromRoot() + var providers []provider + for _, scope := range allScopes { + providers = append(providers, scope.getProviders(k)...) + } + return providers +} + +func (s *Scope) invoker() invokerFn { + return s.invokerFn +} + +// adds a new graphNode to this Scope and all of its descendent +// scope. +func (s *Scope) newGraphNode(wrapped interface{}, orders map[*Scope]int) { + orders[s] = s.gh.NewNode(wrapped) + for _, cs := range s.childScopes { + cs.newGraphNode(wrapped, orders) + } +} + +func (s *Scope) cycleDetectedError(cycle []int) error { + var path []cycleErrPathEntry + for _, n := range cycle { + if n, ok := s.gh.Lookup(n).(*constructorNode); ok { + path = append(path, cycleErrPathEntry{ + Key: key{ + t: n.CType(), + }, + Func: n.Location(), + }) + } + } + return errCycleDetected{Path: path, scope: s} +} + +// String representation of the entire Scope +func (s *Scope) String() string { + b := &bytes.Buffer{} + fmt.Fprintln(b, "nodes: {") + for k, vs := range s.providers { + for _, v := range vs { + fmt.Fprintln(b, "\t", k, "->", v) + } + } + fmt.Fprintln(b, "}") + + fmt.Fprintln(b, "values: {") + for k, v := range s.values { + fmt.Fprintln(b, "\t", k, "=>", v) + } + for k, vs := range s.groups { + for _, v := range vs { + fmt.Fprintln(b, "\t", k, "=>", v) + } + } + fmt.Fprintln(b, "}") + + return b.String() +} diff --git a/scope_test.go b/scope_test.go new file mode 100644 index 00000000..bd8f5106 --- /dev/null +++ b/scope_test.go @@ -0,0 +1,218 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package dig + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScopedOperations(t *testing.T) { + t.Parallel() + + t.Run("getStores/ScopesFromRoot returns scopes from root in order of distance from root", func(t *testing.T) { + c := New() + s1 := c.Scope("child1") + s2 := s1.Scope("child2") + s3 := s2.Scope("child2") + + assert.Equal(t, []containerStore{c.scope, s1, s2, s3}, s3.getStoresFromRoot()) + assert.Equal(t, []*Scope{c.scope, s1, s2, s3}, s3.getScopesFromRoot()) + }) + + t.Run("private provides", func(t *testing.T) { + c := New() + s := c.Scope("child") + type A struct{} + + f := func(a *A) { + assert.NotEqual(t, nil, a) + } + + require.NoError(t, s.Provide(func() *A { return &A{} })) + assert.NoError(t, s.Invoke(f)) + assert.Error(t, c.Invoke(f)) + }) + + t.Run("private provides inherits", func(t *testing.T) { + type A struct{} + type B struct{} + + useA := func(a *A) { + assert.NotEqual(t, nil, a) + } + useB := func(b *B) { + assert.NotEqual(t, nil, b) + } + + c := New() + require.NoError(t, c.Provide(func() *A { return &A{} })) + + child := c.Scope("child") + require.NoError(t, child.Provide(func() *B { return &B{} })) + assert.NoError(t, child.Invoke(useA)) + assert.NoError(t, child.Invoke(useB)) + + grandchild := child.Scope("grandchild") + + assert.NoError(t, grandchild.Invoke(useA)) + assert.NoError(t, grandchild.Invoke(useB)) + assert.Error(t, c.Invoke(useB)) + }) + + t.Run("provides to top-level Container propogates to all scopes", func(t *testing.T) { + type A struct{} + + // Scope tree: + // root <-- Provide(func() *A) + // / \ + // c1 c2 + // | / \ + // gc1 gc2 gc3 + var allScopes []*Scope + root := New() + + allScopes = append(allScopes, root.Scope("child 1"), root.Scope("child 2")) + allScopes = append(allScopes, allScopes[0].Scope("grandchild 1"), allScopes[1].Scope("grandchild 2"), allScopes[1].Scope("grandchild 3")) + + require.NoError(t, root.Provide(func() *A { + return &A{} + })) + + // top-level provide should be available in all the scopes. + for _, scope := range allScopes { + assert.NoError(t, scope.Invoke(func(a *A) {})) + } + }) +} + +func TestScopeFailures(t *testing.T) { + t.Parallel() + + t.Run("introduce a cycle with child", func(t *testing.T) { + // what root sees: + // A <- B C + // | ^ + // |_________| + // + // what child sees: + // A <- B <- C + // | ^ + // |_________| + type A struct{} + type B struct{} + type C struct{} + newA := func(*C) *A { return &A{} } + newB := func(*A) *B { return &B{} } + newC := func(*B) *C { return &C{} } + + // Create a child Scope, and introduce a cycle + // in the child only. + check := func(c *Container, fails bool) { + s := c.Scope("child") + assert.NoError(t, c.Provide(newA)) + assert.NoError(t, s.Provide(newB)) + err := c.Provide(newC) + + if fails { + assert.Error(t, err, "expected a cycle to be introduced in the child") + assert.Contains(t, err.Error(), "In Scope child") + } else { + assert.NoError(t, err) + } + } + + // Same as check, but this time child should inherit + // parent-provided constructors upon construction. + checkWithInheritance := func(c *Container, fails bool) { + assert.NoError(t, c.Provide(newA)) + s := c.Scope("child") + assert.NoError(t, s.Provide(newB)) + err := c.Provide(newC) + if fails { + assert.Error(t, err, "expected a cycle to be introduced in the child") + assert.Contains(t, err.Error(), "In Scope child") + } else { + assert.NoError(t, err) + } + } + + // Test using different permutations + nodeferContainers := []func() *Container{ + func() *Container { return New() }, + func() *Container { return New(DryRun(true)) }, + func() *Container { return New(DryRun(false)) }, + } + // Container permutations with DeferAcyclicVerification. + deferredContainers := []func() *Container{ + func() *Container { return New(DeferAcyclicVerification()) }, + func() *Container { return New(DeferAcyclicVerification(), DryRun(true)) }, + func() *Container { return New(DeferAcyclicVerification(), DryRun(false)) }, + } + + for _, c := range nodeferContainers { + check(c(), true) + checkWithInheritance(c(), true) + } + + // with deferAcyclicVerification, these should not + // error on Provides. + for _, c := range deferredContainers { + check(c(), false) + checkWithInheritance(c(), false) + } + }) + + t.Run("private provides do not propagate upstream", func(t *testing.T) { + type A struct{} + + root := New() + c := root.Scope("child") + gc := c.Scope("grandchild") + require.NoError(t, gc.Provide(func() *A { return &A{} })) + + assert.Error(t, root.Invoke(func(a *A) {}), "invoking on grandchild's private-provided type should fail") + assert.Error(t, c.Invoke(func(a *A) {}), "invoking on child's private-provided type should fail") + }) + + t.Run("private provides to child should be available to grandchildren, but not root", func(t *testing.T) { + type A struct{} + // Scope tree: + // root + // | + // child <-- Provide(func() *A) + // / \ + // gc1 gc2 + root := New() + c := root.Scope("child") + gc := c.Scope("grandchild") + + require.NoError(t, c.Provide(func() *A { return &A{} })) + + err := root.Invoke(func(a *A) {}) + assert.Error(t, err, "expected Invoke in root container on child's private-provided type to fail") + assert.Contains(t, err.Error(), "missing type: *dig.A") + + assert.NoError(t, gc.Invoke(func(a *A) {}), "expected Invoke in grandchild container on child's private-provided type to fail") + }) +} diff --git a/visualize.go b/visualize.go index 372712fa..d73048a0 100644 --- a/visualize.go +++ b/visualize.go @@ -167,9 +167,13 @@ func CanVisualizeError(err error) bool { } func (c *Container) createGraph() *dot.Graph { + return c.scope.createGraph() +} + +func (s *Scope) createGraph() *dot.Graph { dg := dot.NewGraph() - for _, n := range c.nodes { + for _, n := range s.nodes { dg.AddCtor(newDotCtor(n), n.paramList.DotParam(), n.resultList.DotResult()) } diff --git a/visualize_test.go b/visualize_test.go index d87380f1..128ac7ac 100644 --- a/visualize_test.go +++ b/visualize_test.go @@ -377,7 +377,7 @@ func TestNewDotCtor(t *testing.T) { type t1 struct{} type t2 struct{} - n, err := newConstructorNode(func(A t1) t2 { return t2{} }, New(), constructorOptions{}) + n, err := newConstructorNode(func(A t1) t2 { return t2{} }, newScope(), constructorOptions{}) require.NoError(t, err) n.location = &digreflect.Func{