diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 13556cd..47f284f 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -1,6 +1,6 @@ name: Test and coverage -on: [push, pull_request] +on: [push] jobs: build: diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 8215c37..f8e7b58 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -14,9 +14,6 @@ name: "CodeQL" on: push: branches: [ main ] - pull_request: - # The branches below must be a subset of the branches above - branches: [ main ] schedule: - cron: '20 3 * * 4' diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 474e174..59061e3 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,4 +1,4 @@ -on: [push, pull_request] +on: [push] name: Unit Tests jobs: test: diff --git a/debug.go b/debug.go index 18e8716..2f33dda 100644 --- a/debug.go +++ b/debug.go @@ -33,10 +33,7 @@ func debugln(stuff ...interface{}) { if debuglnHook != nil { debuglnHook(stuff...) } else { - for _, s := range stuff { - debugOutput += fmt.Sprint(s) - } - debugOutput += "\n" + debugOutput += fmt.Sprintln(stuff...) } debugOutputMu.Unlock() } diff --git a/intheap.go b/intheap.go index 02c930e..a8ada84 100644 --- a/intheap.go +++ b/intheap.go @@ -1,23 +1,41 @@ -package nject +package nject + +import ( + "container/heap" +) // Code below originated with the container/heap documentation -type IntHeap []int +type IntsHeap [][2]int -func (h IntHeap) Len() int { return len(h) } -func (h IntHeap) Less(i, j int) bool { return h[i] < h[j] } -func (h IntHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h IntsHeap) Len() int { return len(h) } +func (h IntsHeap) Less(i, j int) bool { return h[i][0] < h[j][0] } +func (h IntsHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } -func (h *IntHeap) Push(x interface{}) { +func (h *IntsHeap) Push(x interface{}) { // Push and Pop use pointer receivers because they modify the slice's length, // not just its contents. - *h = append(*h, x.(int)) + *h = append(*h, x.([2]int)) } -func (h *IntHeap) Pop() interface{} { +func (h *IntsHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] *h = old[0 : n-1] return x } + +func push(h *IntsHeap, funcs []*provider, i int) { + priority := i + if i < len(funcs) && funcs[i].reorder { + priority -= len(funcs) + } + heap.Push(h, [2]int{priority, i}) +} + +func pop(h *IntsHeap) int { + //nolint:errcheck // we trust the type + x := heap.Pop(h).([2]int) + return x[1] +} diff --git a/reorder.go b/reorder.go index 989dcd7..0ea4443 100644 --- a/reorder.go +++ b/reorder.go @@ -218,21 +218,29 @@ func reorder(funcs []*provider, initF *provider) []*provider { nodes[pair[1]].weakBefore[pair[0]] = struct{}{} nodes[pair[0]].weakAfter[pair[1]] = struct{}{} } + for _, pair := range weakPairs { + if _, ok := nodes[pair[0]].weakBefore[pair[1]]; ok { + debugln("\tremove mutual weak", pair) + delete(nodes[pair[1]].weakBefore, pair[0]) + delete(nodes[pair[0]].weakBefore, pair[0]) + delete(nodes[pair[0]].weakAfter, pair[1]) + delete(nodes[pair[1]].weakAfter, pair[1]) + } + } - unblocked := &IntHeap{} + unblocked := &IntsHeap{} heap.Init(unblocked) - weakBlocked := &IntHeap{} + weakBlocked := &IntsHeap{} heap.Init(weakBlocked) if initF != nil { for _, t := range noNoType(initF.flows[outputParams]) { if num, ok := downTypes[t]; ok { debugln("\trelease down for InitF", t) - heap.Push(unblocked, num) + push(unblocked, funcs, num) } } } - x := topo{ funcs: funcs, nodes: nodes, @@ -265,9 +273,9 @@ type topo struct { funcs []*provider nodes []node cannotReorder []int - unblocked *IntHeap // no weak or strong blocks - weakBlocked *IntHeap // only weak blocks - done []bool // TODO: use https://pkg.go.dev/github.com/boljen/go-bitmap#Bitmap instead + unblocked *IntsHeap // no weak or strong blocks + weakBlocked *IntsHeap // only weak blocks + done []bool // TODO: use https://pkg.go.dev/github.com/boljen/go-bitmap#Bitmap instead reorderedFuncs []*provider upTypes map[typeCode]int downTypes map[typeCode]int @@ -277,16 +285,17 @@ func (x *topo) release(n, i int) { if n >= len(x.funcs) { // types only have strong relationships debugln("\treleased", n) - heap.Push(x.unblocked, n) + push(x.unblocked, x.funcs, n) } else { delete(x.nodes[n].after, i) delete(x.nodes[n].weakAfter, i) if len(x.nodes[n].after) == 0 { - debugln("\treleased", n) if len(x.nodes[n].weakAfter) == 0 { - heap.Push(x.unblocked, n) + debugln("\treleased", n) + push(x.unblocked, x.funcs, n) } else { - heap.Push(x.weakBlocked, n) + debugln("\treleased (weak)", n, x.nodes[n].weakAfter) + push(x.weakBlocked, x.funcs, n) } } else { debugln("\tcannot release", n, x.nodes[n].after) @@ -294,15 +303,24 @@ func (x *topo) release(n, i int) { } } +func (x *topo) releaseNode(i int) { + for n := range x.nodes[i].weakBefore { + delete(x.nodes[n].weakAfter, i) + } + for n := range x.nodes[i].before { + x.release(n, i) + } +} + func (x *topo) run() { for { if x.unblocked.Len() > 0 { //nolint:errcheck // cast is safe - i := heap.Pop(x.unblocked).(int) + i := pop(x.unblocked) x.processOne(i, true) } else if x.weakBlocked.Len() > 0 { //nolint:errcheck // cast is safe - i := heap.Pop(x.weakBlocked).(int) + i := pop(x.weakBlocked) x.processOne(i, true) } else if len(x.cannotReorder) > 0 { i := x.cannotReorder[0] @@ -323,12 +341,6 @@ func (x *topo) run() { } } -func (x *topo) releaseNode(i int) { - for n := range x.nodes[i].before { - x.release(n, i) - } -} - func (x *topo) processOne(i int, release bool) { debugln("\tpopped", i, release) if release { diff --git a/reorder_test.go b/reorder_test.go index c954ff9..b6ab284 100644 --- a/reorder_test.go +++ b/reorder_test.go @@ -130,6 +130,7 @@ func TestReorderUnused(t *testing.T) { } func TestReorderOverride(t *testing.T) { + t.Parallel() var dd *Debugging seq1 := Sequence("example", Shun(func() string { @@ -154,3 +155,32 @@ func TestReorderOverride(t *testing.T) { t.Log(dd.Trace) } } + +func TestReorderInOut(t *testing.T) { + t.Parallel() + type r string + var final string + var dd *Debugging + assert.NoError(t, Run(t.Name(), + func() string { + return "start" + }, + func(s string) r { + return r(s) + }, + Reorder(func(s string) string { + return s + " reordered1" + }), + Reorder(func(s string) string { + return s + " reordered2" + }), + func(r r, _ string, d *Debugging) { + final = string(r) + dd = d + }, + )) + assert.Equal(t, "start reordered1 reordered2", final) + if t.Failed() { + t.Log(dd.Trace) + } +}