Skip to content

Commit

Permalink
feat: Add support for using slices as return types and arguments for …
Browse files Browse the repository at this point in the history
…closures in Go

Signed-off-by: Felicitas Pojtinger <felicitas@pojtinger.com>
  • Loading branch information
pojntfx committed Oct 3, 2024
1 parent 24005d7 commit b93809e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
15 changes: 6 additions & 9 deletions go/pkg/rpc/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var (
ErrNotAFunction = errors.New("not a function")

ErrInvalidArgsCount = errors.New("invalid argument count")
ErrInvalidArg = errors.New("invalid argument")
ErrInvalidArg = errors.New("invalid argument, either the type doesn't match or is too complex and can't be inspected")

ErrClosureDoesNotExist = errors.New("closure does not exist")
)
Expand Down Expand Up @@ -45,15 +45,12 @@ func createClosure(fn interface{}) (func(args ...interface{}) (interface{}, erro

in := make([]reflect.Value, len(args))
for i, arg := range args {
if argType := reflect.TypeOf(arg); argType != functionType.In(i) {
if argType.ConvertibleTo(functionType.In(i)) {
in[i] = reflect.ValueOf(arg).Convert(functionType.In(i))
} else {
return nil, ErrInvalidArg
}
} else {
in[i] = reflect.ValueOf(arg)
convertedArgVal, err := convertValue(reflect.ValueOf(arg), functionType.In(i))
if err != nil {
return nil, ErrInvalidArg
}

in[i] = convertedArgVal
}

out, err := utils.Call(reflect.ValueOf(fn), in)
Expand Down
36 changes: 35 additions & 1 deletion go/pkg/rpc/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (

ErrInvalidFunctionCallPath = errors.New("invalid or empty function call path")
ErrInvalidReturn = errors.New("invalid return, can only return an error or a value and an error")
ErrReturnValueTooComplex = errors.New("invalid return, either the type doesn't match or is too complex and can't be inspected")
ErrInvalidArgs = errors.New("invalid arguments, first argument needs to be a context.Context")

ErrCannotCallNonFunction = errors.New("can not call non function")
Expand Down Expand Up @@ -444,7 +445,12 @@ func (r Registry[R, T]) findLocalFunctionToCallRecursively(
errReturnValue := reflect.New(functionType.Out(1))

if el := rcpRv[0].Elem(); el.IsValid() {
valueReturnValue.Elem().Set(el.Convert(valueReturnValue.Type().Elem()))
convertedValueReturnType, err := convertValue(el, valueReturnValue.Type().Elem())
if err != nil {
panic(err)
}

valueReturnValue.Elem().Set(convertedValueReturnType)
}
errReturnValue.Elem().Set(rcpRv[1])

Expand Down Expand Up @@ -500,6 +506,34 @@ func findMethodByFunctionCallPathRecursively(root interface{}, functionCallPath
return function, nil
}

func convertValue(srcVal reflect.Value, dstType reflect.Type) (reflect.Value, error) {
for srcVal.Kind() == reflect.Interface {
srcVal = srcVal.Elem()
}

if srcVal.Type().ConvertibleTo(dstType) {
return srcVal.Convert(dstType), nil
}

// We can't convert slices directly, we have to convert each element individually
if srcVal.Kind() == reflect.Slice && dstType.Kind() == reflect.Slice {
srcLen := srcVal.Len()
dstSlice := reflect.MakeSlice(dstType, srcLen, srcLen)
for i := 0; i < srcLen; i++ {
elem, err := convertValue(srcVal.Index(i), dstType.Elem())
if err != nil {
return reflect.Value{}, err
}

dstSlice.Index(i).Set(elem)
}

return dstSlice, nil
}

return reflect.Value{}, ErrReturnValueTooComplex
}

// LinkMessage exposes local RPCs and implements remote RPCs via a message-based transport
func (r Registry[R, T]) LinkMessage(
ctx context.Context, // Context for read, write and in-flight RPC operations
Expand Down

0 comments on commit b93809e

Please sign in to comment.