Skip to content

Commit

Permalink
NCR-14625 add Apply
Browse files Browse the repository at this point in the history
  • Loading branch information
solokirrik committed Jan 22, 2024
1 parent 05be24a commit 54ee6f9
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 1 deletion.
40 changes: 40 additions & 0 deletions apply.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package pipeline

import "context"

type apply[A, B, C any] struct {
a Processor[A, []B]
b Processor[B, C]
}

func (j *apply[A, B, C]) Process(ctx context.Context, a A) ([]C, error) {
bs, err := j.a.Process(ctx, a)
if err != nil {
j.a.Cancel(a, err)
return []C{}, err
}

cs := make([]C, 0, len(bs))

for i := range bs {
c, err := j.b.Process(ctx, bs[i])
if err != nil {
j.b.Cancel(bs[i], err)
return cs, err
}

cs = append(cs, c)
}

return cs, nil
}

func (j *apply[A, B, C]) Cancel(_ A, _ error) {}

// Apply connects two processes, applying the second to each item of the first output
func Apply[A, B, C any](
a Processor[A, []B],
b Processor[B, C],
) Processor[A, []C] {
return &apply[A, B, C]{a, b}
}
57 changes: 57 additions & 0 deletions apply_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package pipeline

import (
"context"
"strings"
"testing"
)

func TestLoopApply(t *testing.T) {
transform := NewProcessor(func(_ context.Context, s string) ([]string, error) {
return strings.Split(s, ","), nil
}, nil)

double := NewProcessor(func(_ context.Context, s string) (string, error) {
return s + s, nil
}, nil)

addLeadingZero := NewProcessor(func(_ context.Context, s string) (string, error) {
return "0" + s, nil
}, nil)

looper := Apply(
transform,
Sequence(
double,
addLeadingZero,
double,
),
)

gotCount := 0
input := "1,2,3,4,5"
want := []string{"011011", "022022", "033033", "044044", "055055"}

for out := range Process(context.Background(), looper, Emit(input)) {
for j := range out {
gotCount++
if !contains(want, out[j]) {
t.Errorf("does not contains got=%v, want=%v", out[j], want)
}
}
}

if gotCount != len(want) {
t.Errorf("total results got=%v, want=%v", gotCount, len(want))
}
}

func contains(s []string, e string) bool {
for i := range s {
if s[i] == e {
return true
}
}

return false
}
1 change: 0 additions & 1 deletion collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ func TestCollect(t *testing.T) {
if !reflect.DeepEqual(test.want.out, outs) {
t.Errorf("out = %v, want %v", outs, test.want.out)
}

})
}
}

0 comments on commit 54ee6f9

Please sign in to comment.