Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

purego: add float32 and float64 support to callbacks #120

Merged
merged 8 commits into from
Apr 3, 2023
Merged
54 changes: 54 additions & 0 deletions callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,57 @@ func buildSharedLib(libFile string, sources ...string) error {

return nil
}

func TestNewCallback(t *testing.T) {
// This tests the maximum number of arguments a function to NewCallback can take
const (
expectCbTotal = -3
expectedCbTotalF = float64(36)
)
var cbTotal int
var cbTotalF float64
imp := purego.NewCallback(func(a1, a2, a3, a4, a5, a6, a7, a8, a9 int,
f1, f2, f3, f4, f5, f6, f7, f8 float64) {
cbTotal = a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9
cbTotalF = f1 + f2 + f3 + f4 + f5 + f6 + f7 + f8
})
var fn func(a1, a2, a3, a4, a5, a6, a7, a8, a9 int,
f1, f2, f3, f4, f5, f6, f7, f8 float64)
purego.RegisterFunc(&fn, imp)
fn(1, 2, -3, 4, -5, 6, -7, 8, -9,
1, 2, 3, 4, 5, 6, 7, 8)

if cbTotal != expectCbTotal {
t.Fatalf("cbTotal not correct got %d but wanted %d", cbTotal, expectCbTotal)
}
if cbTotalF != expectedCbTotalF {
t.Fatalf("cbTotalF not correct got %f but wanted %f", cbTotalF, expectedCbTotalF)
}
}

func TestNewCallback32(t *testing.T) {
// This tests the maximum number of float32 arguments a function to NewCallback
const (
expectCbTotal = 6
expectedCbTotalF = float32(45)
)
var cbTotal int
var cbTotalF float32
imp := purego.NewCallback(func(a1, a2, a3, a4, a5, a6, a7, a8 int,
f1, f2, f3, f4, f5, f6, f7, f8, f9 float32) {
cbTotal = a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8
cbTotalF = f1 + f2 + f3 + f4 + f5 + f6 + f7 + f8 + f9
})
var fn func(a1, a2, a3, a4, a5, a6, a7, a8 int,
f1, f2, f3, f4, f5, f6, f7, f8, f9 float32)
purego.RegisterFunc(&fn, imp)
fn(1, 2, -3, 4, -5, 6, -7, 8,
1, 2, 3, 4, 5, 6, 7, 8, 9)

if cbTotal != expectCbTotal {
t.Fatalf("cbTotal not correct got %d but wanted %d", cbTotal, expectCbTotal)
}
if cbTotalF != expectedCbTotalF {
t.Fatalf("cbTotalF not correct got %f but wanted %f", cbTotalF, expectedCbTotalF)
}
}
18 changes: 13 additions & 5 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func RegisterFunc(fptr interface{}, cfn uintptr) {
stack++
}
case reflect.Float32, reflect.Float64:
if floats < 8 {
if floats < numOfFloats {
floats++
} else {
stack++
Expand All @@ -108,7 +108,8 @@ func RegisterFunc(fptr interface{}, cfn uintptr) {
panic("purego: unsupported kind " + arg.Kind().String())
}
}
if ints+stack > maxArgs || floats+stack > maxArgs {
sizeOfStack := maxArgs - numOfIntegerRegisters()
if stack > sizeOfStack {
panic("purego: too many arguments")
}
}
Expand All @@ -127,7 +128,7 @@ func RegisterFunc(fptr interface{}, cfn uintptr) {
}
var sysargs [maxArgs]uintptr
var stack = sysargs[numOfIntegerRegisters():]
var floats [8]float64
var floats [numOfFloats]uintptr
var numInts int
var numFloats int
var numStack int
Expand Down Expand Up @@ -168,9 +169,16 @@ func RegisterFunc(fptr interface{}, cfn uintptr) {
} else {
addInt(0)
}
case reflect.Float32, reflect.Float64:
case reflect.Float32:
if numFloats < len(floats) {
floats[numFloats] = uintptr(math.Float32bits(float32(v.Float())))
numFloats++
} else {
addStack(uintptr(math.Float32bits(float32(v.Float()))))
}
case reflect.Float64:
if numFloats < len(floats) {
floats[numFloats] = v.Float()
floats[numFloats] = uintptr(math.Float64bits(v.Float()))
numFloats++
} else {
addStack(uintptr(math.Float64bits(v.Float())))
Expand Down
28 changes: 18 additions & 10 deletions sys_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,23 @@ TEXT callbackasm1(SB), NOSPLIT, $0

MOVQ 0(SP), R10 // get the return SP so that we can align register args with stack args

// make space for first six arguments below the frame
ADJSP $6*8, SP
MOVQ DI, 8(SP)
MOVQ SI, 16(SP)
MOVQ DX, 24(SP)
MOVQ CX, 32(SP)
MOVQ R8, 40(SP)
MOVQ R9, 48(SP)
LEAQ 8(SP), R8 // R8 = address of args vector
// make space for first six int and 8 float arguments below the frame
ADJSP $14*8, SP
MOVSD X0, (1*8)(SP)
MOVSD X1, (2*8)(SP)
MOVSD X2, (3*8)(SP)
MOVSD X3, (4*8)(SP)
MOVSD X4, (5*8)(SP)
MOVSD X5, (6*8)(SP)
MOVSD X6, (7*8)(SP)
MOVSD X7, (8*8)(SP)
MOVQ DI, (9*8)(SP)
MOVQ SI, (10*8)(SP)
MOVQ DX, (11*8)(SP)
MOVQ CX, (12*8)(SP)
MOVQ R8, (13*8)(SP)
MOVQ R9, (14*8)(SP)
LEAQ 8(SP), R8 // R8 = address of args vector

MOVQ R10, 0(SP) // push the stack pointer below registers

Expand Down Expand Up @@ -128,7 +136,7 @@ TEXT callbackasm1(SB), NOSPLIT, $0

MOVQ 0(SP), R10 // get the SP back

ADJSP $-6*8, SP // remove arguments
ADJSP $-14*8, SP // remove arguments

MOVQ R10, 0(SP)

Expand Down
32 changes: 20 additions & 12 deletions sys_arm64.s
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,32 @@ TEXT callbackasm1(SB), NOSPLIT|NOFRAME, $0

// Save callback register arguments R0-R7.
// We do this at the top of the frame so they're contiguous with stack arguments.
SUB $(8*8), RSP, R14
STP (R0, R1), (0*8)(R14)
STP (R2, R3), (2*8)(R14)
STP (R4, R5), (4*8)(R14)
STP (R6, R7), (6*8)(R14)
SUB $(16*8), RSP, R14
FMOVD F0, (0*8)(R14)
FMOVD F1, (1*8)(R14)
FMOVD F2, (2*8)(R14)
FMOVD F3, (3*8)(R14)
FMOVD F4, (4*8)(R14)
FMOVD F5, (5*8)(R14)
FMOVD F6, (6*8)(R14)
FMOVD F7, (7*8)(R14)
STP (R0, R1), (8*8)(R14)
STP (R2, R3), (10*8)(R14)
STP (R4, R5), (12*8)(R14)
STP (R6, R7), (14*8)(R14)

// Adjust SP by frame size.
// crosscall2 clobbers FP in the frame record so only save/restore SP.
SUB $(28*8), RSP
SUB $(28*8), RSP
MOVD R30, (RSP)

// Create a struct callbackArgs on our stack.
ADD $(callbackArgs__size + 3*8), RSP, R13
MOVD R12, callbackArgs_index(R13) // callback index
ADD $(callbackArgs__size + 3*8), RSP, R13
MOVD R12, callbackArgs_index(R13) // callback index
MOVD R14, R0
MOVD R0, callbackArgs_args(R13) // address of args vector
MOVD R0, callbackArgs_args(R13) // address of args vector
MOVD $0, R0
MOVD R0, callbackArgs_result(R13) // result
MOVD R0, callbackArgs_result(R13) // result

// Move parameters into registers
// Get the ABIInternal function pointer
Expand All @@ -101,11 +109,11 @@ TEXT callbackasm1(SB), NOSPLIT|NOFRAME, $0
BL crosscall2(SB)

// Get callback result.
ADD $(callbackArgs__size + 3*8), RSP, R13
ADD $(callbackArgs__size + 3*8), RSP, R13
MOVD callbackArgs_result(R13), R0

// Restore SP
MOVD (RSP), R30
ADD $(28*8), RSP
ADD $(28*8), RSP

RET
5 changes: 4 additions & 1 deletion syscall.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

package purego

const maxArgs = 9
const (
maxArgs = 9
numOfFloats = 8 // arm64 and amd64 both have 8 float registers
)

// SyscallN takes fn, a C function pointer and a list of arguments as uintptr.
// There is an internal maximum number of arguments that SyscallN can take. It panics
Expand Down
2 changes: 1 addition & 1 deletion syscall_cgo_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var syscall9XABI0 = uintptr(cgo.Syscall9XABI0)
// this is only here to make the assembly files happy :)
type syscall9Args struct {
fn, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr
f1, f2, f3, f4, f5, f6, f7, f8 float64
f1, f2, f3, f4, f5, f6, f7, f8 uintptr
r1, r2, err uintptr
}

Expand Down
53 changes: 36 additions & 17 deletions syscall_sysv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package purego

import (
"math"
"reflect"
"runtime"
"sync"
Expand All @@ -17,16 +16,14 @@ var syscall9XABI0 uintptr

type syscall9Args struct {
fn, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr
f1, f2, f3, f4, f5, f6, f7, f8 float64
f1, f2, f3, f4, f5, f6, f7, f8 uintptr
r1, r2, err uintptr
}

//go:nosplit
func syscall_syscall9X(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2, err uintptr) {
args := syscall9Args{fn, a1, a2, a3, a4, a5, a6, a7, a8, a9,
math.Float64frombits(uint64(a1)), math.Float64frombits(uint64(a2)), math.Float64frombits(uint64(a3)),
math.Float64frombits(uint64(a4)), math.Float64frombits(uint64(a5)), math.Float64frombits(uint64(a6)),
math.Float64frombits(uint64(a7)), math.Float64frombits(uint64(a8)),
a1, a2, a3, a4, a5, a6, a7, a8,
r1, r2, err}
runtime_cgocall(syscall9XABI0, unsafe.Pointer(&args))
return args.r1, args.r2, args.err
Expand Down Expand Up @@ -61,15 +58,13 @@ type callbackArgs struct {
index uintptr
// args points to the argument block.
//
// For cdecl and stdcall, all arguments are on the stack.
// The structure of the arguments goes
// float registers followed by the
// integer registers followed by the stack.
//
// For fastcall, the trampoline spills register arguments to
// the reserved spill slots below the stack arguments,
// resulting in a layout equivalent to stdcall.
//
// For arm, the trampoline stores the register arguments just
// below the stack arguments, so again we can treat it as one
// big stack arguments frame.
// This variable is treated as a continuous
// block of memory containing all of the arguments
// for this callback.
args unsafe.Pointer
// Below are out-args from callbackWrap
result uintptr
Expand All @@ -84,8 +79,7 @@ func compileCallback(fn interface{}) uintptr {
for i := 0; i < ty.NumIn(); i++ {
in := ty.In(i)
switch in.Kind() {
case reflect.Struct, reflect.Float32, reflect.Float64,
reflect.Interface, reflect.Func, reflect.Slice,
case reflect.Struct, reflect.Interface, reflect.Func, reflect.Slice,
reflect.Chan, reflect.Complex64, reflect.Complex128,
reflect.String, reflect.Map, reflect.Invalid:
panic("purego: unsupported argument type: " + in.Kind().String())
Expand Down Expand Up @@ -136,9 +130,34 @@ func callbackWrap(a *callbackArgs) {
fnType := fn.Type()
args := make([]reflect.Value, fnType.NumIn())
frame := (*[callbackMaxFrame]uintptr)(a.args)
var floatsN int
// intsN is offset by the integer position by the number
// of floatsN because in the frame it starts with the float
// registers followed by the integer and then the stack after that.
var intsN int = numOfFloats
// the stack is located in the frame after the floats and integers
var stack = numOfIntegerRegisters() + numOfFloats
for i := range args {
//TODO: support float32 and float64
args[i] = reflect.NewAt(fnType.In(i), unsafe.Pointer(&frame[i])).Elem()
var pos int
switch fnType.In(i).Kind() {
case reflect.Float32, reflect.Float64:
if floatsN >= numOfFloats {
pos = stack
stack++
} else {
pos = floatsN
}
floatsN++
default:
if intsN >= numOfIntegerRegisters()+numOfFloats {
pos = stack
stack++
} else {
pos = intsN
}
intsN++
}
args[i] = reflect.NewAt(fnType.In(i), unsafe.Pointer(&frame[pos])).Elem()
}
ret := fn.Call(args)
if len(ret) > 0 {
Expand Down