Skip to content

Commit

Permalink
codegen: Handle a few more cases, populate some mps structs
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc committed Sep 10, 2023
1 parent 645e013 commit 102ecf9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
2 changes: 2 additions & 0 deletions generate/codegen/gen_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ func (f *Function) WriteGoCallCode(currentModule *modules.Module, cw *CodeWriter
sb.WriteString(cw.IndentStr + fmt.Sprintf(" (*C.%s)(unsafe.Pointer(&%s))", tt.CName(), p.GoName()))
case *typing.IDType:
sb.WriteString(cw.IndentStr + fmt.Sprintf(" %s.Ptr()", p.GoName()))
case *typing.ClassType, *typing.ProtocolType:
sb.WriteString(cw.IndentStr + fmt.Sprintf(" unsafe.Pointer(&%s)", p.GoName()))
default:
sb.WriteString(cw.IndentStr + p.GoName())
}
Expand Down
14 changes: 7 additions & 7 deletions macos/mps/functions.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func StateBatchResourceSize(batch *foundation.Array) uint {
func HintTemporaryMemoryHighWaterMark(cmdBuf metal.CommandBufferWrapper, bytes uint) {
C.HintTemporaryMemoryHighWaterMark(
// *typing.ProtocolType
cmdBuf,
unsafe.Pointer(&cmdBuf),
// *typing.PrimitiveType
C.uint(bytes),
)
Expand All @@ -73,7 +73,7 @@ func ImageBatchResourceSize(batch *foundation.Array) uint {
func SetHeapCacheDuration(cmdBuf metal.CommandBufferWrapper, seconds float64) {
C.SetHeapCacheDuration(
// *typing.ProtocolType
cmdBuf,
unsafe.Pointer(&cmdBuf),
// *typing.PrimitiveType
C.double(seconds),
)
Expand Down Expand Up @@ -101,7 +101,7 @@ func StateBatchSynchronize(batch *foundation.Array, cmdBuf metal.CommandBufferWr
// *typing.PointerType
(*C.MPSStateBatch)(unsafe.Pointer(&batch)),
// *typing.ProtocolType
cmdBuf,
unsafe.Pointer(&cmdBuf),
)
}

Expand Down Expand Up @@ -180,10 +180,10 @@ func GetCustomKernelBroadcastSourceIndex(c CustomKernelArgumentCount, sourceInde
func GetImageType(image Image) ImageType {
rv := C.GetImageType(
// *typing.ClassType
image,
unsafe.Pointer(&image),
)
// *typing.AliasType
return ImageType(rv)
return *(*ImageType)(unsafe.Pointer(&rv))
}

// Returns the integer division parameters for a specified divisor. [Full Topic]
Expand Down Expand Up @@ -222,7 +222,7 @@ func ImageBatchSynchronize(batch *foundation.Array, cmdBuf metal.CommandBufferWr
// *typing.PointerType
(*C.MPSImageBatch)(unsafe.Pointer(&batch)),
// *typing.ProtocolType
cmdBuf,
unsafe.Pointer(&cmdBuf),
)
}

Expand All @@ -232,7 +232,7 @@ func ImageBatchSynchronize(batch *foundation.Array, cmdBuf metal.CommandBufferWr
func SupportsMTLDevice(device metal.DeviceWrapper) bool {
rv := C.SupportsMTLDevice(
// *typing.ProtocolType
device,
unsafe.Pointer(&device),
)
// *typing.PrimitiveType
return bool(rv)
Expand Down

0 comments on commit 102ecf9

Please sign in to comment.