Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Scope to Dig #305

Merged
merged 16 commits into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions constructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ type constructorNode struct {
// Type information about constructor results.
resultList resultList

order int // order of this node in graphHolder
// order of this node in each Scopes' graphHolders.
orders map[*Scope]int

// scope this node was originally provided to.
s *Scope
}

type constructorOptions struct {
Expand All @@ -65,12 +69,12 @@ type constructorOptions struct {
Location *digreflect.Func
}

func newConstructorNode(ctor interface{}, c containerStore, opts constructorOptions) (*constructorNode, error) {
func newConstructorNode(ctor interface{}, s *Scope, opts constructorOptions) (*constructorNode, error) {
cval := reflect.ValueOf(ctor)
ctype := cval.Type()
cptr := cval.Pointer()

params, err := newParamList(ctype, c)
params, err := newParamList(ctype, s)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -99,8 +103,10 @@ func newConstructorNode(ctor interface{}, c containerStore, opts constructorOpti
id: dot.CtorID(cptr),
paramList: params,
resultList: results,
orders: make(map[*Scope]int),
s: s,
}
n.order = c.newGraphNode(n)
s.newGraphNode(n, n.orders)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing n.orders into newGraphNode feels like a leak of internal state here,
but I can't think of an obvious alternative to this just yet.

return n, nil
}

Expand All @@ -109,7 +115,7 @@ func (n *constructorNode) ParamList() paramList { return n.paramList }
func (n *constructorNode) ResultList() resultList { return n.resultList }
func (n *constructorNode) ID() dot.CtorID { return n.id }
func (n *constructorNode) CType() reflect.Type { return n.ctype }
func (n *constructorNode) Order() int { return n.order }
func (n *constructorNode) Order(s *Scope) int { return n.orders[s] }

func (n *constructorNode) String() string {
return fmt.Sprintf("deps: %v, ctor: %v", n.paramList, n.ctype)
Expand Down
163 changes: 24 additions & 139 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@
package dig

import (
"bytes"
"fmt"
"math/rand"
"reflect"
"sort"
"time"

"go.uber.org/dig/internal/dot"
)
Expand Down Expand Up @@ -63,32 +60,12 @@ type Option interface {
}

// Container is a directed acyclic graph of types and their dependencies.
// A Container is the root Scope that represents the top-level scoped
// directed acyclic graph of the dependencies.
type Container struct {
// Mapping from key to all the constructor node that can provide a value for that
// key.
providers map[key][]*constructorNode

nodes []*constructorNode

// Values that have already been generated in the container.
values map[key]reflect.Value

// Values groups that have already been generated in the container.
groups map[key][]reflect.Value

// Source of randomness.
rand *rand.Rand

// Flag indicating whether the graph has been checked for cycles.
isVerifiedAcyclic bool

// Defer acyclic check on provide until Invoke.
deferAcyclicVerification bool

// invokerFn calls a function with arguments provided to Provide or Invoke.
invokerFn invokerFn

gh *graphHolder
// this is the "root" Scope that represents the
// root of the scope tree.
scope *Scope
}

// containerWriter provides write access to the Container's underlying data
Expand All @@ -108,8 +85,8 @@ type containerWriter interface {
type containerStore interface {
containerWriter

// Adds a new graph node to the Container and returns its order.
newGraphNode(w interface{}) int
// Adds a new graph node to the Container
newGraphNode(w interface{}, orders map[*Scope]int)

// Returns a slice containing all known types.
knownTypes() []reflect.Type
Expand All @@ -130,6 +107,12 @@ type containerStore interface {
// type.
getGroupProviders(name string, t reflect.Type) []provider

// Returns the providers that can produce a value with the given name and
// type across all the Scopes that are in effect of this containerStore.
getAllValueProviders(name string, t reflect.Type) []provider

getStoresFromRoot() []containerStore

createGraph() *dot.Graph

// Returns invokerFn function to use when calling arguments.
Expand All @@ -138,15 +121,8 @@ type containerStore interface {

// New constructs a Container.
func New(opts ...Option) *Container {
c := &Container{
providers: make(map[key][]*constructorNode),
values: make(map[key]reflect.Value),
groups: make(map[key][]reflect.Value),
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
invokerFn: defaultInvoker,
}

c.gh = newGraphHolder(c)
s := newScope()
c := &Container{scope: s}

for _, opt := range opts {
opt.applyOption(c)
Expand All @@ -172,7 +148,7 @@ func (deferAcyclicVerificationOption) String() string {
}

func (deferAcyclicVerificationOption) applyOption(c *Container) {
c.deferAcyclicVerification = true
c.scope.deferAcyclicVerification = true
}

// Changes the source of randomness for the container.
Expand All @@ -189,7 +165,7 @@ func (o setRandOption) String() string {
}

func (o setRandOption) applyOption(c *Container) {
c.rand = o.r
c.scope.rand = o.r
}

// DryRun is an Option which, when set to true, disables invocation of functions supplied to
Expand All @@ -206,9 +182,9 @@ func (o dryRunOption) String() string {

func (o dryRunOption) applyOption(c *Container) {
if o {
c.invokerFn = dryInvoker
c.scope.invokerFn = dryInvoker
} else {
c.invokerFn = defaultInvoker
c.scope.invokerFn = defaultInvoker
}
}

Expand All @@ -230,105 +206,14 @@ func dryInvoker(fn reflect.Value, _ []reflect.Value) []reflect.Value {
return results
}

func (c *Container) knownTypes() []reflect.Type {
typeSet := make(map[reflect.Type]struct{}, len(c.providers))
for k := range c.providers {
typeSet[k.t] = struct{}{}
}

types := make([]reflect.Type, 0, len(typeSet))
for t := range typeSet {
types = append(types, t)
}
sort.Sort(byTypeName(types))
return types
}

func (c *Container) getValue(name string, t reflect.Type) (v reflect.Value, ok bool) {
v, ok = c.values[key{name: name, t: t}]
return
}

func (c *Container) setValue(name string, t reflect.Type, v reflect.Value) {
c.values[key{name: name, t: t}] = v
}

func (c *Container) getValueGroup(name string, t reflect.Type) []reflect.Value {
items := c.groups[key{group: name, t: t}]
// shuffle the list so users don't rely on the ordering of grouped values
return shuffledCopy(c.rand, items)
}

func (c *Container) submitGroupedValue(name string, t reflect.Type, v reflect.Value) {
k := key{group: name, t: t}
c.groups[k] = append(c.groups[k], v)
}

func (c *Container) getValueProviders(name string, t reflect.Type) []provider {
return c.getProviders(key{name: name, t: t})
}

func (c *Container) getGroupProviders(name string, t reflect.Type) []provider {
return c.getProviders(key{group: name, t: t})
}

func (c *Container) getProviders(k key) []provider {
nodes := c.providers[k]
providers := make([]provider, len(nodes))
for i, n := range nodes {
providers[i] = n
}
return providers
}

// invokerFn return a function to run when calling function provided to Provide or Invoke. Used for
// running container in dry mode.
func (c *Container) invoker() invokerFn {
return c.invokerFn
}

func (c *Container) newGraphNode(wrapped interface{}) int {
return c.gh.NewNode(wrapped)
}

func (c *Container) cycleDetectedError(cycle []int) error {
var path []cycleErrPathEntry
for _, n := range cycle {
if n, ok := c.gh.Lookup(n).(*constructorNode); ok {
path = append(path, cycleErrPathEntry{
Key: key{
t: n.CType(),
},
Func: n.Location(),
})
}
}
return errCycleDetected{Path: path}
}

// String representation of the entire Container
func (c *Container) String() string {
b := &bytes.Buffer{}
fmt.Fprintln(b, "nodes: {")
for k, vs := range c.providers {
for _, v := range vs {
fmt.Fprintln(b, "\t", k, "->", v)
}
}
fmt.Fprintln(b, "}")

fmt.Fprintln(b, "values: {")
for k, v := range c.values {
fmt.Fprintln(b, "\t", k, "=>", v)
}
for k, vs := range c.groups {
for _, v := range vs {
fmt.Fprintln(b, "\t", k, "=>", v)
}
}
fmt.Fprintln(b, "}")
return c.scope.String()
}

return b.String()
// Scope creates a child scope of the Container with the given name.
func (c *Container) Scope(name string, opts ...ScopeOption) *Scope {
return c.scope.Scope(name, opts...)
}

type byTypeName []reflect.Type
Expand Down
4 changes: 3 additions & 1 deletion cycle_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ type cycleErrPathEntry struct {
}

type errCycleDetected struct {
Path []cycleErrPathEntry
Path []cycleErrPathEntry
scope *Scope
}

func (e errCycleDetected) Error() string {
Expand All @@ -46,6 +47,7 @@ func (e errCycleDetected) Error() string {
//
b := new(bytes.Buffer)

fmt.Fprintf(b, "In Scope %s: \n", e.scope.name)
sywhang marked this conversation as resolved.
Show resolved Hide resolved
sywhang marked this conversation as resolved.
Show resolved Hide resolved
for i, entry := range e.Path {
if i > 0 {
b.WriteString("\n\tdepends on ")
Expand Down
6 changes: 3 additions & 3 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3105,14 +3105,14 @@ func TestNodeAlreadyCalled(t *testing.T) {
type type1 struct{}
f := func() type1 { return type1{} }

n, err := newConstructorNode(f, New(), constructorOptions{})
n, err := newConstructorNode(f, newScope(), constructorOptions{})
require.NoError(t, err, "failed to build node")
require.False(t, n.called, "node must not have been called")

c := New()
require.NoError(t, n.Call(c), "invoke failed")
require.NoError(t, n.Call(c.scope), "invoke failed")
require.True(t, n.called, "node must be called")
require.NoError(t, n.Call(c), "calling again should be okay")
require.NoError(t, n.Call(c.scope), "calling again should be okay")
}

func TestFailingFunctionDoesNotCreateInvalidState(t *testing.T) {
Expand Down
15 changes: 7 additions & 8 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ type graphNode struct {
// It saves constructorNodes and paramGroupedSlice (value groups)
// as nodes in the graph.
// It implements the graph interface defined by internal/graph.
// It has 1-1 correspondence with the Container whose graph it represents.
// It has 1-1 correspondence with the Scope whose graph it represents.
type graphHolder struct {
// all the nodes defined in the graph.
nodes []*graphNode

// Container whose graph this holder contains.
c *Container
// Scope whose graph this holder contains.
s *Scope

// Number of nodes in the graph at last snapshot.
// -1 if no snapshot has been taken.
Expand All @@ -48,9 +48,8 @@ type graphHolder struct {

var _ graph.Graph = (*graphHolder)(nil)

func newGraphHolder(c *Container) *graphHolder {
return &graphHolder{c: c, snap: -1}

func newGraphHolder(s *Scope) *graphHolder {
return &graphHolder{s: s, snap: -1}
}

func (gh *graphHolder) Order() int { return len(gh.nodes) }
Expand All @@ -72,9 +71,9 @@ func (gh *graphHolder) EdgesFrom(u int) []int {
orders = append(orders, getParamOrder(gh, param)...)
}
case *paramGroupedSlice:
providers := gh.c.getGroupProviders(w.Group, w.Type.Elem())
providers := gh.s.getAllGroupProviders(w.Group, w.Type.Elem())
for _, provider := range providers {
orders = append(orders, provider.Order())
orders = append(orders, provider.Order(gh.s))
}
}
return orders
Expand Down
Loading