Skip to content

Commit

Permalink
refactor: Extract commonly used function wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
siyul-park committed Jan 13, 2025
1 parent f9ec77e commit 2ed522a
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 68 deletions.
18 changes: 1 addition & 17 deletions ext/pkg/control/if.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package control

import (
"context"
"reflect"
"time"

"github.com/siyul-park/uniflow/ext/pkg/language"
Expand Down Expand Up @@ -36,22 +35,7 @@ func NewIfNodeCodec(compiler language.Compiler) scheme.Codec {
if err != nil {
return nil, err
}
return NewIfNode(func(ctx context.Context, env any) (bool, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

res, err := program.Run(ctx, []any{env})
if err != nil {
return false, err
}
if len(res) == 0 {
return false, nil
}
return !reflect.ValueOf(res[0]).IsZero(), nil
}), nil
return NewIfNode(language.Predicate[any](language.Timeout(program, spec.Timeout))), nil
})
}

Expand Down
18 changes: 1 addition & 17 deletions ext/pkg/control/reduce.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,7 @@ func NewReduceNodeCodec(compiler language.Compiler) scheme.Codec {
if err != nil {
return nil, err
}

return NewReduceNode(func(ctx context.Context, acc, cur any, index int) (any, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

res, err := program.Run(ctx, []any{acc, cur, index})
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, nil
}
return res[0], nil
}, spec.Init), nil
return NewReduceNode(language.TriFunction[any, any, int, any](language.Timeout(program, spec.Timeout)), spec.Init), nil
})
}

Expand Down
17 changes: 1 addition & 16 deletions ext/pkg/control/snippet.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,7 @@ func NewSnippetNodeCodec(module *language.Module) scheme.Codec {
return nil, err
}

return NewSnippetNode(func(ctx context.Context, arg any) (any, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

res, err := program.Run(ctx, []any{arg})
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, nil
}
return res[0], nil
}), nil
return NewSnippetNode(language.Function[any, any](language.Timeout(program, spec.Timeout))), nil
})
}

Expand Down
18 changes: 1 addition & 17 deletions ext/pkg/control/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package control

import (
"context"
"reflect"
"sync"
"time"

Expand Down Expand Up @@ -48,22 +47,7 @@ func NewSwitchNodeCodec(compiler language.Compiler) scheme.Codec {
return nil, err
}

conditions[i] = func(ctx context.Context, env any) (bool, error) {
if spec.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, spec.Timeout)
defer cancel()
}

res, err := program.Run(ctx, []any{env})
if err != nil {
return false, err
}
if len(res) == 0 {
return false, nil
}
return !reflect.ValueOf(res[0]).IsZero(), nil
}
conditions[i] = language.Predicate[any](language.Timeout(program, spec.Timeout))
}

n := NewSwitchNode()
Expand Down
65 changes: 64 additions & 1 deletion ext/pkg/language/program.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package language

import "context"
import (
"context"
"reflect"
"time"
)

// Program represents an interface for running a compiled program with a given environment.
type Program interface {
Expand All @@ -12,6 +16,65 @@ type RunFunc func(context.Context, []any) ([]any, error)

var _ Program = RunFunc(nil)

// Timeout returns a Program that runs the given program with a specified timeout.
func Timeout(program Program, timeout time.Duration) Program {
return RunFunc(func(ctx context.Context, args []any) ([]any, error) {
if timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
return program.Run(ctx, args)
})
}

// Predicate returns a function that runs the program and checks if the first result is non-zero.
func Predicate[T any](program Program) func(context.Context, T) (bool, error) {
return func(ctx context.Context, input T) (bool, error) {
res, err := program.Run(ctx, []any{input})
if err != nil || len(res) == 0 {
return false, err
}
return !reflect.ValueOf(res[0]).IsZero(), nil
}
}

// Function returns a function that runs the program and returns the first result cast to type R.
func Function[T any, R any](program Program) func(context.Context, T) (R, error) {
return func(ctx context.Context, input T) (R, error) {
res, err := program.Run(ctx, []any{input})
if err != nil || len(res) == 0 {
var zero R
return zero, err
}
return res[0].(R), nil
}
}

// BiFunction returns a function that runs the program with two inputs and returns the first result cast to type R.
func BiFunction[T any, U any, R any](program Program) func(context.Context, T, U) (R, error) {
return func(ctx context.Context, t T, u U) (R, error) {
res, err := program.Run(ctx, []any{t, u})
if err != nil || len(res) == 0 {
var zero R
return zero, err
}
return res[0].(R), nil
}
}

// TriFunction returns a function that runs the program with three inputs and returns the first result cast to type R.
func TriFunction[T any, U any, V any, R any](program Program) func(context.Context, T, U, V) (R, error) {
return func(ctx context.Context, t T, u U, v V) (R, error) {
res, err := program.Run(ctx, []any{t, u, v})
if err != nil || len(res) == 0 {
var zero R
return zero, err
}
return res[0].(R), nil
}
}

// Run executes the program with the provided environment using the RunFunc.
func (f RunFunc) Run(ctx context.Context, args []any) ([]any, error) {
return f(ctx, args)
Expand Down
65 changes: 65 additions & 0 deletions ext/pkg/language/program_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package language

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestTimeout(t *testing.T) {
program := RunFunc(func(ctx context.Context, args []any) ([]any, error) {
_, ok := ctx.Deadline()
assert.True(t, ok)
return nil, nil
})
timeout := Timeout(program, 1*time.Second)

_, err := timeout.Run(context.Background(), nil)
assert.NoError(t, err)
}

func TestPredicate(t *testing.T) {
program := RunFunc(func(ctx context.Context, args []any) ([]any, error) {
return []any{1}, nil
})
predicate := Predicate[int](program)

result, err := predicate(context.Background(), 1)
assert.NoError(t, err)
assert.True(t, result)
}

func TestFunction(t *testing.T) {
program := RunFunc(func(ctx context.Context, args []any) ([]any, error) {
return []any{"result"}, nil
})
function := Function[int, string](program)

result, err := function(context.Background(), 1)
assert.NoError(t, err)
assert.Equal(t, "result", result)
}

func TestBiFunction(t *testing.T) {
program := RunFunc(func(ctx context.Context, args []any) ([]any, error) {
return []any{"result"}, nil
})
biFunction := BiFunction[int, int, string](program)

result, err := biFunction(context.Background(), 1, 2)
assert.NoError(t, err)
assert.Equal(t, "result", result)
}

func TestTriFunction(t *testing.T) {
program := RunFunc(func(ctx context.Context, args []any) ([]any, error) {
return []any{"result"}, nil
})
triFunction := TriFunction[int, int, int, string](program)

result, err := triFunction(context.Background(), 1, 2, 3)
assert.NoError(t, err)
assert.Equal(t, "result", result)
}

0 comments on commit 2ed522a

Please sign in to comment.