Skip to content

Commit

Permalink
bugfix: Reorder was duplicating injectors (#46)
Browse files Browse the repository at this point in the history
* bugfix: Reorder was duplicating injectors

- bugfix: Reorder was duplicating injectors
- reorder now bombs out if the count of injectors doesn't match
- added a regression test
- small improvement to regression test generation

* try to get back a couple lines of coverage
  • Loading branch information
muir authored Aug 7, 2022
1 parent 18541fc commit 1b600d6
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 24 deletions.
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ linters:
- exportloopref
- errcheck
- exhaustive
- typecheck
enable-all: false
disable:
- maligned
Expand Down
18 changes: 12 additions & 6 deletions debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ func elem(i interface{}) reflect.Type {

func generateReproduce(funcs []*provider, invokeF *provider, initF *provider) string {
subs := make(map[typeCode]string)
t := ""
f := "func TestRegression(t *testing.T) {\n"
var t string
var f string
f += "\twrapTest(t, func(t *testing.T) {\n"
f += "\t\tcalled := make(map[string]int)\n"
f += "\t\tvar invoker " + funcSig(subs, &t, elem(invokeF.fn)) + "\n"
Expand Down Expand Up @@ -292,10 +292,16 @@ func generateReproduce(funcs []*provider, invokeF *provider, initF *provider) st
f += " {\n"
f += fmt.Sprintf("%s\t\t\t\t\tcalled[%q]++\n", extraIndent, n)
f += extraIndent + "\t\t\t\t\tinner(" + strings.Join(substituteDefaults(subs, typesIn(typ.In(0))), ", ") + ")\n"
f += extraIndent + "\t\t\t\t\treturn " + strings.Join(substituteDefaults(subs, out), ", ") + "\n"
if len(out) > 0 {
f += extraIndent + "\t\t\t\t\treturn " + strings.Join(substituteDefaults(subs, out), ", ") + "\n"
}
f += extraIndent + "\t\t\t\t}"
} else {
f += fmt.Sprintf(" { called[%q]++; return %s }", n, strings.Join(substituteDefaults(subs, out), ", "))
if len(out) > 0 {
f += fmt.Sprintf(" { called[%q]++; return %s }", n, strings.Join(substituteDefaults(subs, out), ", "))
} else {
f += fmt.Sprintf(" { called[%q]++ }", n)
}
}
f += close + ","
if fm.include {
Expand All @@ -318,7 +324,7 @@ func generateReproduce(funcs []*provider, invokeF *provider, initF *provider) st
f += "\t\tinvoker(" + strings.Join(substituteDefaults(subs, typesIn(elem(invokeF.fn))), ", ") + ")\n"
f += "\t})\n"
f += "}\n"
return t + "\n" + f
return "func TestRegression(t *testing.T) {\n" + t + "\n" + f
}

// TODO: take note of which interfaces implement each other and new interfaces that
Expand Down Expand Up @@ -348,7 +354,7 @@ func substituteTypes(subs map[typeCode]string, defineTypes *string, types []refl
}
} else {
subs[tc] = fmt.Sprintf("s%03d", tc)
*defineTypes += fmt.Sprintf("// %s\ntype s%03d int\n", tc, tc)
*defineTypes += fmt.Sprintf("\t// %s\n\ttype s%03d int\n", tc, tc)
}
}
replacements = append(replacements, subs[tc])
Expand Down
1 change: 1 addition & 0 deletions debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func TestDetailedError(t *testing.T) {
Reorder(func() time.Time { return time.Now() }),
NotCacheable(func(i int) int32 { return int32(i) }),
),
func(_ MyType1, _ MyType3) {},
// CallsInner(func(i func()) { i() }),
Memoize(func(i int32) int32 { return i }),
OverridesError(func(i func()) error { return nil }),
Expand Down
13 changes: 9 additions & 4 deletions include.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ type includeWorkingData struct {

func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*provider, error) {
var err error
funcs = reorder(funcs, initF)
funcs, err = reorder(funcs, initF)
if err != nil {
return nil, err
}
for i, fm := range funcs {
fm.chainPosition = i
}
Expand Down Expand Up @@ -115,6 +118,7 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
return nil, err
}

debugln("eliminating providers that cannot be included")
for _, fm := range funcs {
if fm.cannotInclude != nil {
debugf("Excluding %s: %s", fm, fm.cannotInclude)
Expand All @@ -123,6 +127,7 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
}
}

debugln("eliminate unused providers")
eliminateUnused(funcs)

tryWithout := func(without ...*provider) bool {
Expand Down Expand Up @@ -164,7 +169,7 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
return err == nil
}

// Attempt to eliminate providers
debugln("attempt to eliminate additional providers")
for _, fm := range proposeEliminations(funcs) {
if fm.d.excluded != nil {
continue
Expand Down Expand Up @@ -193,12 +198,12 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
debugln("final calculate flows")
err = providesReturns(funcs, initF)
if err != nil {
return nil, fmt.Errorf("internal error: uh oh")
return nil, fmt.Errorf("internal error: uh oh: %w", err)
}
debugf("final check chain validity")
err = validateChainMarkIncludeExclude(funcs, true)
if err != nil {
return nil, fmt.Errorf("internal error: uh oh #2")
return nil, fmt.Errorf("internal error: uh oh #2: %w", err)
}

return funcs, nil
Expand Down
151 changes: 151 additions & 0 deletions regressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,3 +761,154 @@ func TestRegression7642(t *testing.T) {
assert.Equal(t, 0, called["ReassembleQuote"])
})
}

func TestRegression9(t *testing.T) {
type i003 interface {
x003()
}
type s004 int
type i005 interface {
x005()
}
type s006 int
type s009 int
type i011 interface {
x011()
}
type i013 interface {
x013()
}
type i015 interface {
x015()
}
type s017 int
type s018 int
type s020 int
type i021 interface {
x021()
}
type s022 int
type s019 int
type s023 int
type s024 int
type s025 int
type s026 int
type s027 int
type s028 int
type s029 int
type s030 int
type s031 int
type s032 int
type s033 int
type s034 int
type s035 int
type i036 interface {
x036()
}
type s037 int
type s038 int
type s039 int
type s040 int
type s041 int
type s043 int
type s044 int
type s045 int
type s046 int
type s047 int
type s048 int
type s049 int
type s050 int
type s051 int
type s007 int
type s008 int
type s010 int
type s012 int
type s014 int
type s016 int
type i052 interface {
x052()
}
type s054 int

wrapTest(t, func(t *testing.T) {
called := make(map[string]int)
var invoker func() error
err := Sequence("regression",
Provide("Run()error", func() TerminalError { called["Run()error"]++; return nil }),
Provide("TCP-0", func() i003 { called["TCP-0"]++; return nil }),
Provide("TCP-1", func(inner func() error, _ s004) {
called["TCP-1"]++
inner()
}),
Provide("TCP-2", func() i005 { called["TCP-2"]++; return nil }),
Provide("TCP-3", func(_ i005) i005 { called["TCP-3"]++; return nil }),
Shun(Provide("integration-before-user-0", func() s006 { called["integration-before-user-0"]++; return 0 })),
Shun(Provide("integration-before-user-1", func() s009 { called["integration-before-user-1"]++; return 0 })),
Provide("base-chain-0", func(_ i003) i011 { called["base-chain-0"]++; return nil }),
Provide("base-chain-1", func(_ i003) i013 { called["base-chain-1"]++; return nil }),
Provide("base-chain-2", func() i015 { called["base-chain-2"]++; return nil }),
Provide("base-chain-3", func(_ i003, _ s006, _ i011) s017 { called["base-chain-3"]++; return 0 }),
Provide("base-chain-4", func(_ i003, _ i011, _ s018, _ i015) s020 { called["base-chain-4"]++; return 0 }),
Provide("server-chain-0", func(_ i003, _ i011) i021 { called["server-chain-0"]++; return nil }),
Provide("server-chain-1", func() s022 { called["server-chain-1"]++; return 0 }),
Provide("server-chain-2", func() s019 { called["server-chain-2"]++; return 0 }),
Shun(Provide("server-chain-3", func() s023 { called["server-chain-3"]++; return 0 })),
Provide("server-chain-4", func(_ i003, _ s023, _ i015, _ i011) s024 { called["server-chain-4"]++; return 0 }),
Provide("server-chain-5", func(_ i011) s025 { called["server-chain-5"]++; return 0 }),
Shun(Provide("server-chain-6", func() s026 { called["server-chain-6"]++; return 0 })),
Provide("server-chain-7", func(_ i003, _ i015, _ s026, _ i011) s027 { called["server-chain-7"]++; return 0 }),
Shun(Provide("server-chain-8", func() s028 { called["server-chain-8"]++; return 0 })),
Provide("server-chain-9", func(_ i003, _ i015, _ s028, _ i011) s029 { called["server-chain-9"]++; return 0 }),
Shun(Provide("server-chain-10", func() s030 { called["server-chain-10"]++; return 0 })),
Provide("server-chain-11", func(_ i011) s031 { called["server-chain-11"]++; return 0 }),
Provide("server-chain-12", func(_ i011) s032 { called["server-chain-12"]++; return 0 }),
Shun(Provide("server-chain-13", func() s033 { called["server-chain-13"]++; return 0 })),
Provide("server-chain-14", func(_ s020, _ i011, _ s033) s034 { called["server-chain-14"]++; return 0 }),
Shun(Provide("server-chain-15", func() s035 { called["server-chain-15"]++; return 0 })),
Provide("server-chain-16", func(_ i003, _ i011, _ i021, _ s035) i036 { called["server-chain-16"]++; return nil }),
Provide("server-chain-17", func(_ i005, _ i003, _ s037, _ s022, _ i013) s038 { called["server-chain-17"]++; return 0 }),
Provide("server-chain-18", func(_ i003) s039 { called["server-chain-18"]++; return 0 }),
Provide("server-chain-19", func(_ i003) s040 { called["server-chain-19"]++; return 0 }),
Provide("server-chain-20", func(_ s017, _ i013) s041 { called["server-chain-20"]++; return 0 }),
Provide("server-chain-21", func(_ s041, _ s020, _ s039, _ s040) s043 { called["server-chain-21"]++; return 0 }),
Provide("server-chain-22", func(_ s020, _ s039, _ s040, _ i011, _ s017) s044 { called["server-chain-22"]++; return 0 }),
Provide("integration-before-user-4", func(_ i011, _ s044, _ i013, _ s017, _ s038, _ s025, _ i036, _ s020, _ s024, _ s034, _ s045, _ s029, _ s030, _ s031, _ s032) s046 {
called["integration-before-user-4"]++
return 0
}),
Provide("integration-before-user-5", func(_ i011, _ s044, _ i013, _ s017, _ s038, _ s025, _ i036, _ s020, _ s024, _ s045, _ s034, _ s029, _ s030, _ s031, _ s032) s047 {
called["integration-before-user-5"]++
return 0
}),
Provide("environments-0", func(_ i003, _ i011) s048 { called["environments-0"]++; return 0 }),
Provide("environments-1", func(_ i005, _ s017, _ s046, _ s047, _ s048, _ s006) s049 { called["environments-1"]++; return 0 }),
Provide("environments-2", func(_ s048, _ s019, _ s020) s050 { called["environments-2"]++; return 0 }),
Provide("environments-3", func(_ i005, _ s048, _ s049, _ s050) s051 { called["environments-3"]++; return 0 }),
Provide("integration-before-user-7", func() s019 { called["integration-before-user-7"]++; return 0 }),
Provide("integration-user-and-more-0", func(_ i005) s007 { called["integration-user-and-more-0"]++; return 0 }),
Provide("public-client-0", func(_ i005, _ s007) s008 { called["public-client-0"]++; return 0 }),
Provide("public-client-1", func(_ s008) s010 { called["public-client-1"]++; return 0 }),
Provide("integration-user-and-more-2", func(_ i005, _ i003, _ s010, _ s009) s012 { called["integration-user-and-more-2"]++; return 0 }),
Provide("integration-user-and-more-3", func(_ s012) s014 { called["integration-user-and-more-3"]++; return 0 }),
Provide("integration-user-and-more-4", func(_ i005, _ i003, _ s008) s016 { called["integration-user-and-more-4"]++; return 0 }),
Provide("cluster-0", func(_ i003) i052 { called["cluster-0"]++; return nil }),
Provide("revs-0", func(_ i003) { called["revs-0"]++ }),
Reorder(Provide("revs-1", func() s009 { called["revs-1"]++; return 0 })),
Provide("revs-2", func(_ i005) (s045, error) { called["revs-2"]++; return 0, nil }),
Required(Provide("revs-3", func(_ i005, _ i003, _ s008, _ s045, _ s012) { called["revs-3"]++ })),
Provide("revs-4", func(_ i005, _ i003, _ s019) s054 { called["revs-4"]++; return 0 }),
Provide("user-chain-1", func(_ i005, _ i003, _ s019, _ s054) { called["user-chain-1"]++ }),
Provide("user-chain-2", func(_ i005, _ i003, _ s019, _ s054) { called["user-chain-2"]++ }),
Shun(NonFinal(Provide("TCP-6", func(inner func()) error {
called["TCP-6"]++
inner()
return nil
}))),
Provide("user-chain-3", func(_ i005, _ s054, _ s014, _ s019, _ s012) { called["user-chain-3"]++ }),
).Bind(&invoker, nil)
if !assert.NoError(t, err, "bind error") {
t.Log(DetailedError(err))
}
// invoker()
})
}
28 changes: 14 additions & 14 deletions reorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nject

import (
"container/heap"
"fmt"
)

// Reorder annotates a provider to say that its position in the injection
Expand Down Expand Up @@ -54,7 +55,7 @@ func Reorder(fn interface{}) Provider {
//

// generateCheckers must be called before reorder()
func reorder(funcs []*provider, initF *provider) []*provider {
func reorder(funcs []*provider, initF *provider) ([]*provider, error) {
debugln("begin reorder ----------------------------------------------------------")
var someReorder bool
for i, fm := range funcs {
Expand All @@ -64,7 +65,7 @@ func reorder(funcs []*provider, initF *provider) []*provider {
}
}
if !someReorder {
return funcs
return funcs, nil
}

availableDown := make(interfaceMap)
Expand Down Expand Up @@ -258,14 +259,17 @@ func reorder(funcs []*provider, initF *provider) []*provider {
debugln("\t\t", i, fm)
}
debugln("------------------")
return x.reorderedFuncs
if len(funcs) != len(x.reorderedFuncs) {
return nil, fmt.Errorf("internal error: count of funcs changed during reorder")
}
return x.reorderedFuncs, nil
}

type node struct {
before map[int]struct{}
after map[int]struct{}
weakBefore map[int]struct{}
weakAfter map[int]struct{}
before map[int]struct{} // set of nodes that must be released before this node (dependent node is required)
after map[int]struct{} // set of nodes that must be released after this node
weakBefore map[int]struct{} // set of nodes that must be released before this node (dependent node is desired)
weakAfter map[int]struct{} // set of nodes that must be released after this node
}

// topo is the working data for a toplogical sort
Expand Down Expand Up @@ -315,11 +319,9 @@ func (x *topo) releaseNode(i int) {
func (x *topo) run() {
for {
if x.unblocked.Len() > 0 {
//nolint:errcheck // cast is safe
i := pop(x.unblocked)
x.processOne(i, true)
} else if x.weakBlocked.Len() > 0 {
//nolint:errcheck // cast is safe
i := pop(x.weakBlocked)
x.processOne(i, true)
} else if len(x.cannotReorder) > 0 {
Expand All @@ -343,12 +345,10 @@ func (x *topo) run() {

func (x *topo) processOne(i int, release bool) {
debugln("\tpopped", i, release)
if release {
if x.done[i] {
return
}
x.done[i] = true
if x.done[i] {
return
}
x.done[i] = true
if i > len(x.funcs) {
if release {
x.releaseNode(i)
Expand Down

0 comments on commit 1b600d6

Please sign in to comment.