diff --git a/dig.go b/dig.go index e99cc9eb..e51d7bc9 100644 --- a/dig.go +++ b/dig.go @@ -65,6 +65,7 @@ type provideOptions struct { Info *ProvideInfo As []interface{} Location *digreflect.Func + Names []string } func (o *provideOptions) Validate() error { @@ -282,10 +283,40 @@ func LocationForPC(pc uintptr) ProvideOption { }) } +type invokeOptions struct { + Names []string +} + +func (*invokeOptions) Validate() error { + return nil +} + // An InvokeOption modifies the default behavior of Invoke. It's included for // future functionality; currently, there are no concrete implementations. type InvokeOption interface { - unimplemented() + applyInvokeOption(*invokeOptions) +} + +type invokeOptionFunc func(*invokeOptions) + +func (f invokeOptionFunc) applyInvokeOption(opts *invokeOptions) { f(opts) } + +type InvokeAndProvideOption interface { + InvokeOption + ProvideOption +} + +type namesOption []string + +func (n namesOption) applyInvokeOption(opts *invokeOptions) { + opts.Names = n +} +func (n namesOption) applyProvideOption(opts *provideOptions) { + opts.Names = n +} + +func Names(names ...string) InvokeAndProvideOption { + return namesOption(names) } // Container is a directed acyclic graph of types and their dependencies. @@ -566,7 +597,15 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return errf("can't invoke non-function %v (type %v)", function, ftype) } - pl, err := newParamList(ftype) + var options invokeOptions + for _, o := range opts { + o.applyInvokeOption(&options) + } + if err := options.Validate(); err != nil { + return err + } + + pl, err := newParamList(ftype, options.Names) if err != nil { return err } @@ -624,6 +663,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error { ResultGroup: opts.Group, ResultAs: opts.As, Location: opts.Location, + ParamNames: opts.Names, }, ) if err != nil { @@ -842,6 +882,7 @@ type nodeOptions struct { ResultGroup string ResultAs []interface{} Location *digreflect.Func + ParamNames []string } func newNode(ctor interface{}, opts nodeOptions) (*node, error) { @@ -849,7 +890,7 @@ func newNode(ctor interface{}, opts nodeOptions) (*node, error) { ctype := cval.Type() cptr := cval.Pointer() - params, err := newParamList(ctype) + params, err := newParamList(ctype, opts.ParamNames) if err != nil { return nil, err } diff --git a/dig_test.go b/dig_test.go index a66ae0c6..74511de5 100644 --- a/dig_test.go +++ b/dig_test.go @@ -536,6 +536,45 @@ func TestEndToEndSuccess(t *testing.T) { }), "invoke should succeed, pulling out two named instances") }) + t.Run("named instances can be used to Provide another instance", func(t *testing.T) { + c := New() + + type A struct{ idx int } + + buildConstructor := func(idx int) func() A { + return func() A { return A{idx: idx} } + } + + require.NoError(t, c.Provide(buildConstructor(1), Name("first"))) + require.NoError(t, c.Provide(buildConstructor(2), Name("second"))) + require.NoError(t, c.Provide(func(a A) int { + return a.idx + 5 + }, Names("first"))) + + require.NoError(t, c.Invoke(func(i int) { + assert.Equal(t, 6, i) + }), "invoke should succeed, pulling out one named instances") + }) + + t.Run("named instances can be invoked Name option", func(t *testing.T) { + c := New() + + type A struct{ idx int } + + buildConstructor := func(idx int) func() A { + return func() A { return A{idx: idx} } + } + + require.NoError(t, c.Provide(buildConstructor(1), Name("first"))) + require.NoError(t, c.Provide(buildConstructor(2), Name("second"))) + require.NoError(t, c.Provide(buildConstructor(3), Name("third"))) + + require.NoError(t, c.Invoke(func(a1 A, a3 A) { + assert.Equal(t, 1, a1.idx) + assert.Equal(t, 3, a3.idx) + }, Names("first", "third")), "invoke should succeed, using two named instances") + }) + t.Run("named and unnamed instances coexist", func(t *testing.T) { c := New() type A struct{ idx int } @@ -561,6 +600,25 @@ func TestEndToEndSuccess(t *testing.T) { })) }) + t.Run("named and unnamed instances can be invoked with Names option", func(t *testing.T) { + c := New() + + type A struct{ idx int } + + buildConstructor := func(idx int) func() A { + return func() A { return A{idx: idx} } + } + + require.NoError(t, c.Provide(buildConstructor(1), Name("first"))) + require.NoError(t, c.Provide(buildConstructor(2), Name("second"))) + require.NoError(t, c.Provide(buildConstructor(3))) + + require.NoError(t, c.Invoke(func(a1 A, a3 A) { + assert.Equal(t, 1, a1.idx) + assert.Equal(t, 3, a3.idx) + }, Names("first")), "invoke should succeed, using two named instances") + }) + t.Run("named instances recurse", func(t *testing.T) { c := New() type A struct{ idx int } diff --git a/param.go b/param.go index df7868f3..3b35f344 100644 --- a/param.go +++ b/param.go @@ -62,11 +62,14 @@ var ( // newParam builds a param from the given type. If the provided type is a // dig.In struct, an paramObject will be returned. -func newParam(t reflect.Type) (param, error) { +func newParam(t reflect.Type, paramName string) (param, error) { switch { case IsOut(t) || (t.Kind() == reflect.Ptr && IsOut(t.Elem())) || embedsType(t, _outPtrType): return nil, errf("cannot depend on result objects", "%v embeds a dig.Out", t) case IsIn(t): + if paramName != "" { + return nil, errf("cannot have a paramName (%s) with a struct that has dig.In", paramName) + } return newParamObject(t) case embedsType(t, _inPtrType): return nil, errf( @@ -77,7 +80,7 @@ func newParam(t reflect.Type) (param, error) { "cannot depend on a pointer to a parameter object, use a value instead", "%v is a pointer to a struct that embeds dig.In", t) default: - return paramSingle{Type: t}, nil + return paramSingle{Type: t, Name: paramName}, nil } } @@ -158,7 +161,7 @@ func (pl paramList) DotParam() []*dot.Param { // // Variadic arguments of a constructor are ignored and not included as // dependencies. -func newParamList(ctype reflect.Type) (paramList, error) { +func newParamList(ctype reflect.Type, names []string) (paramList, error) { numArgs := ctype.NumIn() if ctype.IsVariadic() { // NOTE: If the function is variadic, we skip the last argument @@ -171,8 +174,16 @@ func newParamList(ctype reflect.Type) (paramList, error) { Params: make([]param, 0, numArgs), } + if numArgs < len(names) { + return pl, errf("can't create a constructor with more names=%s than args=%s", names, ctype) + } + for i := 0; i < numArgs; i++ { - p, err := newParam(ctype.In(i)) + name := "" + if i < len(names) { + name = names[i] + } + p, err := newParam(ctype.In(i), name) if err != nil { return pl, errf("bad argument %d", i+1, err) } @@ -370,7 +381,7 @@ func newParamObjectField(idx int, f reflect.StructField) (paramObjectField, erro default: var err error - p, err = newParam(f.Type) + p, err = newParam(f.Type, "") if err != nil { return pof, err } diff --git a/param_test.go b/param_test.go index 68f0cde2..e17e3ea5 100644 --- a/param_test.go +++ b/param_test.go @@ -30,7 +30,7 @@ import ( ) func TestParamListBuild(t *testing.T) { - p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil })) + p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), []string{}) require.NoError(t, err) assert.Panics(t, func() { p.Build(New()) @@ -238,7 +238,7 @@ func TestParamVisitorChecksEverything(t *testing.T) { pl, err := newParamList(reflect.TypeOf(func(io.Reader, params, io.Writer) { t.Fatalf("this function should not be called") - })) + }), []string{}) require.NoError(t, err) idx := 0