Skip to content

Commit

Permalink
fixes for parallel stabilize; call update handlers in parallel, inter…
Browse files Browse the repository at this point in the history
…locked update handler register, observe update handler takes value
  • Loading branch information
wcharczuk committed Feb 20, 2024
1 parent 8558351 commit 42a6ef1
Show file tree
Hide file tree
Showing 7 changed files with 273 additions and 22 deletions.
71 changes: 71 additions & 0 deletions examples/diagram/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package main

import (
"bytes"
"context"
"fmt"
"math/rand"
"os"

"github.com/wcharczuk/go-incr"
)

const (
SIZE = 8
ROUNDS = 3
)

func concat(a, b string) string {
return a + b
}

func main() {
ctx := context.Background()
graph := incr.New()

nodes := make([]incr.Incr[string], SIZE)
vars := make([]incr.VarIncr[string], 0, SIZE)
for x := 0; x < SIZE; x++ {
v := incr.Var(graph, fmt.Sprintf("var_%d", x))
v.Node().SetLabel(fmt.Sprintf("var-%d", x))
vars = append(vars, v)
nodes[x] = v
}

var cursor int
for x := SIZE; x > 0; x >>= 1 {
for y := 0; y < x-1; y += 2 {
n := incr.Map2(graph, nodes[cursor+y], nodes[cursor+y+1], concat)
n.Node().SetLabel(fmt.Sprintf("map-%d", cursor))
nodes = append(nodes, n)
}
cursor += x
}

if os.Getenv("DEBUG") != "" {
ctx = incr.WithTracing(ctx)
}
_ = incr.MustObserve(graph, nodes[len(nodes)-1])

var err error
for n := 0; n < ROUNDS; n++ {
err = graph.Stabilize(ctx)
if err != nil {
fatal(err)
}
vars[rand.Intn(len(vars))].Set(fmt.Sprintf("set_%d", n))
err = graph.Stabilize(ctx)
if err != nil {
fatal(err)
}
}

buf := new(bytes.Buffer)
_ = incr.Dot(buf, graph)
fmt.Print(buf.String())
}

func fatal(err error) {
fmt.Fprintf(os.Stderr, "%+v\n", err)
os.Exit(1)
}
134 changes: 134 additions & 0 deletions examples/naive/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package main

import (
"context"
"fmt"
"math/big"
"math/rand"
"strings"
"time"

"github.com/wcharczuk/go-incr"
)

const (
SIZE = 8192 << 3
ROUNDS = 256
)

type NaiveNode[A any] interface {
Value(context.Context) A
}

type Node[A, B any] struct {
Children []NaiveNode[A]
Action func(context.Context, ...A) B
}

func (n Node[A, B]) Value(ctx context.Context) B {
inputs := make([]A, len(n.Children))
for x := 0; x < len(n.Children); x++ {
inputs[x] = n.Children[x].Value(ctx)
}
return n.Action(ctx, inputs...)
}

func Var[A any](v A) NaiveNode[A] {
return &Node[any, A]{
Action: func(ctx context.Context, _ ...any) A {
return v
},
}
}

func Map[A, B any](child0, child1 NaiveNode[A], fn func(context.Context, ...A) B) NaiveNode[B] {
return &Node[A, B]{
Children: []NaiveNode[A]{child0, child1},
Action: fn,
}
}

func concatN(_ context.Context, values ...string) string {
return strings.Join(values, "")
}

func main() {
ctx := context.Background()
naiiveVars, naiiveNodes := makeNaiiveNodes()
var naiiveResults []time.Duration
for n := 0; n < ROUNDS; n++ {
start := time.Now()
randomNode := naiiveVars[rand.Intn(len(naiiveVars))].(*Node[any, string])
randomNode.Action = func(ctx context.Context, _ ...any) string {
return fmt.Sprintf("set_%d", n)
}
_ = naiiveNodes[len(naiiveNodes)-1].Value(ctx)
naiiveResults = append(naiiveResults, time.Since(start))
}

graph := incr.New()
incrVars, incrNodes := makeIncrNodes(ctx, graph)
incr.MustObserve(graph, incrNodes[0])

var incrResults []time.Duration
for n := 0; n < ROUNDS; n++ {
start := time.Now()
incrVars[rand.Intn(len(incrVars))].Set(fmt.Sprintf("set_%d", n))
_ = graph.Stabilize(ctx)
incrResults = append(incrResults, time.Since(start))
}

fmt.Println("results!")
fmt.Printf("naiive: %v\n", avgDurations(naiiveResults).Round(time.Microsecond))
fmt.Printf("incr: %v\n", avgDurations(incrResults).Round(time.Microsecond))
}

func avgDurations(values []time.Duration) time.Duration {
accum := new(big.Int)
for _, v := range values {
accum.Add(accum, big.NewInt(int64(v)))
}
return time.Duration(accum.Div(accum, big.NewInt(int64(len(values)))).Int64())
}

func makeNaiiveNodes() (vars []NaiveNode[string], nodes []NaiveNode[string]) {
nodes = make([]NaiveNode[string], SIZE)
vars = make([]NaiveNode[string], SIZE)
for x := 0; x < SIZE; x++ {
v := Var(fmt.Sprintf("var_%d", x))
nodes[x] = v
vars[x] = v
}

var cursor int
for x := SIZE; x > 0; x >>= 1 {
for y := 0; y < x-1; y += 2 {
n := Map[string, string](nodes[cursor+y], nodes[cursor+y+1], concatN)
nodes = append(nodes, n)
}
cursor += x
}
return
}

func makeIncrNodes(_ context.Context, graph *incr.Graph) (vars []incr.VarIncr[string], nodes []incr.Incr[string]) {
nodes = make([]incr.Incr[string], SIZE)
vars = make([]incr.VarIncr[string], SIZE)
for x := 0; x < SIZE; x++ {
v := incr.Var(graph, fmt.Sprintf("var_%d", x))
vars[x] = v
nodes[x] = v
}

var cursor int
for x := SIZE; x > 0; x >>= 1 {
for y := 0; y < x-1; y += 2 {
n := incr.Map2(graph, nodes[cursor+y], nodes[cursor+y+1], func(a, b string) string {
return concatN(context.TODO(), a, b)
})
nodes = append(nodes, n)
}
cursor += x
}
return
}
34 changes: 24 additions & 10 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ func (graph *Graph) stabilizeStart(ctx context.Context) context.Context {
return ctx
}

func (graph *Graph) stabilizeEnd(ctx context.Context, err error) {
func (graph *Graph) stabilizeEnd(ctx context.Context, err error, parallel bool) {
defer func() {
graph.stabilizationStarted = time.Time{}
atomic.StoreInt32(&graph.status, StatusNotStabilizing)
Expand All @@ -534,7 +534,7 @@ func (graph *Graph) stabilizeEnd(ctx context.Context, err error) {
} else {
TracePrintf(ctx, "stabilization complete (%v elapsed)", time.Since(graph.stabilizationStarted).Round(time.Microsecond))
}
graph.stabilizeEndRunUpdateHandlers(ctx)
graph.stabilizeEndRunUpdateHandlers(ctx, parallel)
graph.stabilizationNum++
graph.stabilizeEndHandleSetDuringStabilization(ctx)
}
Expand All @@ -549,7 +549,7 @@ func (graph *Graph) stabilizeEndHandleSetDuringStabilization(ctx context.Context
clear(graph.setDuringStabilization)
}

func (graph *Graph) stabilizeEndRunUpdateHandlers(ctx context.Context) {
func (graph *Graph) stabilizeEndRunUpdateHandlers(ctx context.Context, parallel bool) {
graph.handleAfterStabilizationMu.Lock()
defer graph.handleAfterStabilizationMu.Unlock()

Expand All @@ -560,9 +560,23 @@ func (graph *Graph) stabilizeEndRunUpdateHandlers(ctx context.Context) {
TracePrintln(ctx, "stabilization calling user update handlers complete")
}()
}
for _, uhGroup := range graph.handleAfterStabilization {
for _, uh := range uhGroup {
uh(ctx)
if parallel {
for _, uhGroup := range graph.handleAfterStabilization {
for _, uh := range uhGroup {
graph.workerPool.Go(func(handler func(context.Context)) func() error {
return func() error {
handler(ctx)
return nil
}
}(uh))
}
}
_ = graph.workerPool.Wait()
} else {
for _, uhGroup := range graph.handleAfterStabilization {
for _, uh := range uhGroup {
uh(ctx)
}
}
}
clear(graph.handleAfterStabilization)
Expand Down Expand Up @@ -601,9 +615,9 @@ func (graph *Graph) recompute(ctx context.Context, n INode, parallel bool) (err

nn.changedAt = graph.stabilizationNum
if len(nn.onUpdateHandlers) > 0 {
// graph.handleAfterStabilizationMu.Lock()
graph.handleAfterStabilizationMu.Lock()
graph.handleAfterStabilization[nn.id] = nn.onUpdateHandlers
// graph.handleAfterStabilizationMu.Unlock()
graph.handleAfterStabilizationMu.Unlock()
}

if parallel {
Expand All @@ -624,9 +638,9 @@ func (graph *Graph) recompute(ctx context.Context, n INode, parallel bool) (err
// children of this node but will not have any children themselves.
for _, o := range nn.observers {
if len(o.Node().onUpdateHandlers) > 0 {
// graph.handleAfterStabilizationMu.Lock()
graph.handleAfterStabilizationMu.Lock()
graph.handleAfterStabilization[nn.id] = o.Node().onUpdateHandlers
// graph.handleAfterStabilizationMu.Unlock()
graph.handleAfterStabilizationMu.Unlock()
}
}
return
Expand Down
19 changes: 9 additions & 10 deletions observe.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ func Observe[A any](g *Graph, observed Incr[A]) (ObserveIncr[A], error) {
// of incrementals starting a given input.
type ObserveIncr[A any] interface {
IObserver
// OnUpdate lets you register an update handler for the observer node.
//
// This handler is called when the observed node is recomputed (and
// not strictly if the node has changed.)
OnUpdate(func(context.Context, A))
// Value returns the observed node value.
Value() A
}
Expand All @@ -42,11 +47,6 @@ type ObserveIncr[A any] interface {
type IObserver interface {
INode

// OnUpdate lets you register an update handler for the observer node.
//
// This handler is called when the observed node is recomputed (and
// not strictly if the node has changed.)
OnUpdate(func(context.Context))
// Unobserve effectively removes a given node from the observed ref count for a graph.
//
// As well, it unlinks the observer from its parent nodes, and as a result
Expand All @@ -64,11 +64,12 @@ var (
type observeIncr[A any] struct {
n *Node
observed Incr[A]
value A
}

func (o *observeIncr[A]) OnUpdate(fn func(context.Context)) {
o.n.OnUpdate(fn)
func (o *observeIncr[A]) OnUpdate(fn func(context.Context, A)) {
o.n.OnUpdate(func(ctx context.Context) {
fn(ctx, o.Value())
})
}

func (o *observeIncr[A]) Node() *Node { return o.n }
Expand All @@ -81,8 +82,6 @@ func (o *observeIncr[A]) Node() *Node { return o.n }
// To observe parts of a graph again, use the `MustObserve(...)` helper.
func (o *observeIncr[A]) Unobserve(ctx context.Context) {
GraphForNode(o).unobserveNode(o, o.observed)
var value A
o.value = value
o.observed = nil
}

Expand Down
33 changes: 33 additions & 0 deletions observe_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package incr

import (
"context"
"fmt"
"testing"

Expand Down Expand Up @@ -184,3 +185,35 @@ func Test_Observe_alreadyNecessary(t *testing.T) {

testutil.Equal(t, "foo", o2.Value())
}

func Test_Observe_onUpdate(t *testing.T) {
g := New()
v := Var(g, "foo")
m0 := Map(g, v, ident)
o, err := Observe(g, m0)
testutil.NoError(t, err)

var gotValues []string
var updateCalls int
o.OnUpdate(func(ctx context.Context, value string) {
testutil.BlueDye(ctx, t)
gotValues = append(gotValues, value)
updateCalls++
})
ctx := testContext()
err = g.Stabilize(ctx)
testutil.NoError(t, err)

testutil.Equal(t, "foo", o.Value())
testutil.Equal(t, 1, updateCalls)
testutil.Equal(t, []string{"foo"}, gotValues)

v.Set("not-foo")

err = g.Stabilize(ctx)
testutil.NoError(t, err)

testutil.Equal(t, "not-foo", o.Value())
testutil.Equal(t, 2, updateCalls)
testutil.Equal(t, []string{"foo", "not-foo"}, gotValues)
}
2 changes: 1 addition & 1 deletion parallel_stabilize.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (graph *Graph) ParallelStabilize(ctx context.Context) (err error) {
}
ctx = graph.stabilizeStart(ctx)
defer func() {
graph.stabilizeEnd(ctx, err)
graph.stabilizeEnd(ctx, err, true /*parallel*/)
}()
err = graph.parallelStabilize(ctx)
return
Expand Down
2 changes: 1 addition & 1 deletion stabilize.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (graph *Graph) Stabilize(ctx context.Context) (err error) {
}
ctx = graph.stabilizeStart(ctx)
defer func() {
graph.stabilizeEnd(ctx, err)
graph.stabilizeEnd(ctx, err, false /*parallel*/)
}()

var immediateRecompute []INode
Expand Down

0 comments on commit 42a6ef1

Please sign in to comment.