Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add (across) and (for), for consuming/looping over streams in parallel #148

Merged
merged 2 commits into from
Apr 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions pkg/bass/across.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package bass

import (
"context"
"fmt"
"reflect"
)

// CrossSource is a pipe source that combines several sources into one. Each
// value it yields is a list holding the most recent value received from every
// underlying source.
type CrossSource struct {
	// sources are the sources being combined, in the order they were given.
	// They are what String reports.
	sources []*Source

	// cases holds one receive case per source, in source order, so that Next
	// can select over every channel at once via reflect.Select.
	cases []reflect.SelectCase

	// chans holds the channel that each source's values are delivered on, in
	// source order.
	chans []<-chan Value

	// next holds the latest value received from each source, in source order.
	// An entry is nil until the first value from that source has been received.
	next []Value
}

// Across returns a source that reads from all of the given sources at once,
// yielding a list with the latest value from each source.
//
// One goroutine is started per source to forward its values over an
// unbuffered channel. Each goroutine runs until its source returns an error or
// ctx is done.
func Across(ctx context.Context, sources ...*Source) *Source {
	count := len(sources)

	cross := &CrossSource{
		sources: sources,
		cases:   make([]reflect.SelectCase, 0, count),
		chans:   make([]<-chan Value, 0, count),
		next:    make([]Value, count),
	}

	for _, source := range sources {
		values := make(chan Value)

		cross.chans = append(cross.chans, values)
		cross.cases = append(cross.cases, reflect.SelectCase{
			Dir:  reflect.SelectRecv,
			Chan: reflect.ValueOf(values),
		})

		go cross.update(ctx, source.PipeSource, values)
	}

	return &Source{cross}
}

// String renders the cross source together with the sources it combines.
func (cross *CrossSource) String() string {
	return "<cross: " + fmt.Sprint(cross.sources) + ">"
}

// update forwards values from stream to ch, one at a time, until the stream
// stops or ctx is done.
//
// Any error returned by the stream closes ch, so readers observe the end of
// the channel regardless of which error occurred. When ctx is done while a
// value is waiting to be sent, update returns without closing ch.
func (cross *CrossSource) update(ctx context.Context, stream PipeSource, ch chan<- Value) {
	for {
		val, err := stream.Next(ctx)
		if err != nil {
			// the error itself is not propagated; closing the channel is the
			// only signal readers get
			close(ch)
			return
		}

		select {
		case <-ctx.Done():
			return
		case ch <- val:
			// delivered; go fetch the next value
		}
	}
}

// Next returns a list holding the latest value from every source.
//
// It works in two phases:
//
//  1. Any source that has not produced a value yet is waited on, in order. If
//     one of them closes before producing a value, the whole cross source
//     ends. If anything was filled in during this phase the list is returned
//     right away.
//
//  2. Otherwise every source is polled without blocking, taking at most one
//     new value from each. If at least one source had a new value, the list is
//     returned. If none did, Next blocks until any single source produces a
//     value and returns as soon as it does.
//
// Next returns ErrEndOfSource when there are no sources or once every source
// has been closed, and ErrInterrupted when ctx is done.
func (cross *CrossSource) Next(ctx context.Context) (Value, error) {
	if len(cross.chans) == 0 {
		return nil, ErrEndOfSource
	}

	// phase 1: make sure every source has an initial value
	filled := false
	for i := range cross.chans {
		if cross.next[i] != nil {
			continue
		}

		select {
		case <-ctx.Done():
			return nil, ErrInterrupted

		case first, open := <-cross.chans[i]:
			if !open {
				return nil, ErrEndOfSource
			}

			cross.next[i] = first
			filled = true
		}
	}

	if filled {
		return NewList(cross.next...), nil
	}

	// phase 2: select over a private copy of the per-source cases, followed by
	// ctx.Done() and a default case, so cross.cases itself is never modified
	sourceCount := len(cross.cases)
	doneIdx, defaultIdx := sourceCount, sourceCount+1

	cases := make([]reflect.SelectCase, 0, sourceCount+2)
	cases = append(cases, cross.cases...)
	cases = append(cases,
		reflect.SelectCase{
			Dir:  reflect.SelectRecv,
			Chan: reflect.ValueOf(ctx.Done()),
		},
		reflect.SelectCase{
			Dir: reflect.SelectDefault,
		},
	)

	// a receive case whose channel is the zero Value is ignored by
	// reflect.Select; it takes the place of sources we no longer want to hear
	// from during this call
	ignored := reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(nil),
	}

	polling := true   // whether the default case is still present
	received := false // whether any source gave us a new value while polling
	closed := 0       // how many sources were found closed during this call

	for {
		chosen, recv, ok := reflect.Select(cases)

		switch {
		case chosen == doneIdx:
			return nil, ErrInterrupted

		case polling && chosen == defaultIdx:
			if received {
				return NewList(cross.next...), nil
			}

			// nothing new; drop the default case (always the last one) so the
			// next select blocks until something happens
			cases = cases[:defaultIdx]
			polling = false

		case !ok:
			closed++
			cases[chosen] = ignored

			if closed == sourceCount {
				// all sources have run dry
				return nil, ErrEndOfSource
			}

		case polling:
			cross.next[chosen] = recv.Interface().(Value)
			received = true

			// stop listening to this source so a second value from it can't
			// overwrite the first while the other sources are being collected
			cases[chosen] = ignored

		default:
			// we were blocked waiting for any update; one is enough
			cross.next[chosen] = recv.Interface().(Value)
			return NewList(cross.next...), nil
		}
	}
}
90 changes: 90 additions & 0 deletions pkg/bass/across_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package bass_test

import (
"context"
"testing"

"github.com/vito/bass/pkg/bass"
"github.com/vito/is"
)

// TestAcross checks that every value from every source shows up, in order, in
// the lists yielded by bass.Across.
//
// The lists repeat the latest value of sources that did not change, so the
// values seen per source are collected with consecutive repeats dropped. This
// relies on the example sources never containing the same value twice in a
// row.
func TestAcross(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	type example struct {
		Name    string
		Sources [][]bass.Value
	}

	examples := []example{
		{
			Name: "empty",
		},
		{
			Name: "two sources",
			Sources: [][]bass.Value{
				{bass.Int(0), bass.Int(2), bass.Int(4)},
				{bass.Int(1), bass.Int(3), bass.Int(5)},
			},
		},
		{
			Name: "two sources, imbalanced",
			Sources: [][]bass.Value{
				{
					bass.Int(0),
					bass.Int(2),
					bass.Int(4),
					bass.Int(6),
					bass.Int(8),
					bass.Int(10),
					bass.Int(12),
				},
				{bass.Int(1), bass.Int(3), bass.Int(5)},
			},
		},
		{
			Name: "three sources",
			Sources: [][]bass.Value{
				{bass.Int(0), bass.Int(2), bass.Int(4)},
				{bass.Int(1), bass.Int(3), bass.Int(5)},
				{bass.Symbol("one"), bass.Symbol("three"), bass.Symbol("five")},
			},
		},
	}

	for _, tc := range examples {
		t.Run(tc.Name, func(t *testing.T) {
			is := is.New(t)

			srcs := make([]*bass.Source, 0, len(tc.Sources))
			for _, vs := range tc.Sources {
				srcs = append(srcs, bass.NewSource(bass.NewInMemorySource(vs...)))
			}

			combined := bass.Across(ctx, srcs...)

			// seen[i] accumulates the distinct consecutive values observed
			// for source i
			seen := make([][]bass.Value, len(tc.Sources))
			for {
				val, err := combined.PipeSource.Next(ctx)
				t.Logf("next: %v %v", val, err)
				if err == bass.ErrEndOfSource {
					break
				}

				is.NoErr(err)

				vals, err := bass.ToSlice(val.(bass.List))
				is.NoErr(err)

				for i, v := range vals {
					if n := len(seen[i]); n > 0 && seen[i][n-1] == v {
						// same value as last time; this source didn't advance
						continue
					}

					seen[i] = append(seen[i], v)
				}
			}

			for i, want := range tc.Sources {
				t.Logf("saw from %d: %v = %v", i, want, seen[i])
				is.Equal(want, seen[i])
			}
		})
	}
}
13 changes: 11 additions & 2 deletions pkg/bass/ground.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,18 @@ func init() {
Func("list->source", "[list]", func(list []Value) Value {
return &Source{NewInMemorySource(list...)}
}),
"creates a stream source from a list of values in chronological order",
"creates a pipe source from a list of values in chronological order",
`=> (list->source [1 2 3])`)

Ground.Set("across",
Func("across", "sources", Across),
"returns a pipe source that yields a list of values across all the given sources",
`Each list has the last value for each source. Values from each source are never skipped, but not every combination will be produced.`,
`=> (def evens (list->source [0 2 4]))`,
`=> (def odds (list->source [1 3 5]))`,
`=> (def combined (across evens odds))`,
`=> [(next combined) (next combined)]`)

Ground.Set("emit",
Func("emit", "[val sink]", func(val Value, sink PipeSink) error {
return sink.Emit(val)
Expand All @@ -420,7 +429,7 @@ func init() {
return val, nil
}),
`receive the next value from a source`,
`If the stream has ended, no value will be available. A default value may be provided, otherwise an error is raised.`,
`If the source has ended, no value will be available. A default value may be provided, otherwise an error is raised.`,
`=> (next (list->source [1]) :eof)`,
`=> (next *stdin* :eof)`)

Expand Down
22 changes: 19 additions & 3 deletions pkg/bass/ground_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1456,7 +1456,7 @@ func TestGroundPipes(t *testing.T) {
Stdin []bass.Value
Err error
Result bass.Value
Stdout []bass.Value
Sink []bass.Value
}

for _, test := range []example{
Expand All @@ -1469,7 +1469,7 @@ func TestGroundPipes(t *testing.T) {
Name: "emit",
Bass: "(emit 42 sink)",
Result: bass.Null{},
Stdout: []bass.Value{bass.Int(42)},
Sink: []bass.Value{bass.Int(42)},
},
{
Name: "next",
Expand Down Expand Up @@ -1518,6 +1518,22 @@ func TestGroundPipes(t *testing.T) {
Bass: "(take 2 (list->source [1 2 3]))",
Result: bass.NewList(bass.Int(1), bass.Int(2)),
},
{
Name: "across",
Bass: "(next (across (list->source [0 2 4]) (list->source [1 3 5])))",
Result: bass.NewList(bass.Int(0), bass.Int(1)),
},
{
Name: "for",
// NB: cheating here a bit by not going over all of them, but it's not
// worth the complexity as there is nondeterminism here and it's more
// thoroughly tested in (across) already
Bass: "(for [even (list->source [0]) odd (list->source [1])] (emit [even odd] sink))",
Result: bass.Null{},
Sink: []bass.Value{
bass.NewList(bass.Int(0), bass.Int(1)),
},
},
} {
t.Run(test.Name, func(t *testing.T) {
is := is.New(t)
Expand Down Expand Up @@ -1547,7 +1563,7 @@ func TestGroundPipes(t *testing.T) {

stdoutSource := bass.NewJSONSource("test", sinkBuf)

for _, val := range test.Stdout {
for _, val := range test.Sink {
next, err := stdoutSource.Next(context.Background())
is.NoErr(err)
Equal(t, next, val)
Expand Down
19 changes: 19 additions & 0 deletions std/streams.bass
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@
(f n)
(each source f)))))

; loops over values from sources
;
; Takes a list alternating bindings and their sources, similar to (let).
; Combines the sources with (across) and evaluates the body once for each list
; of values read from the combined source, with each binding bound to the value
; from its corresponding source.
;
; Returns null when the combined source reaches its end.
;
; => (def evens (list->source [0 2 4]))
;
; => (def odds (list->source [1 3 5]))
;
; => (for [a evens b odds] (logf "nums: %d %d" a b))
^:indent
(defop for [bindings & body] scope
(let [sources (map-pairs (fn [_ src] src) bindings)
args (map-pairs (fn [arg _] arg) bindings)]
(eval [each [across & sources] [fn [args] & body]] scope)))

; reads the next n values from the source into a list
;
; => (take 2 (list->source [1 2 3]))
Expand Down