
refactor: use inference execution provider
For detailed changes, refer to the related PR: MaaXYZ/MaaFramework#430.
dongwlin committed Dec 5, 2024
1 parent 1a216eb commit 55c3ee8
Showing 2 changed files with 87 additions and 14 deletions.
45 changes: 44 additions & 1 deletion internal/maa/framework.go
@@ -39,6 +39,42 @@ type MaaCustomRecognitionCallback func(context uintptr, taskId int64, currentTas

type MaaCustomActionCallback func(context uintptr, taskId int64, currentTaskName, customActionName, customActionParam *byte, recoId int64, box uintptr, transArg uintptr) uint64

+type MaaInferenceDevice int32
+
+const (
+	MaaInferenceDevice_CPU  MaaInferenceDevice = -2
+	MaaInferenceDevice_Auto MaaInferenceDevice = -1
+	MaaInferenceDevice_0    MaaInferenceDevice = 0
+	MaaInferenceDevice_1    MaaInferenceDevice = 1
+	// and more GPU IDs or flags...
+)
+
+type MaaInferenceExecutionProvider int32
+
+const (
+	// Setting MaaResOption_InferenceDevice is not recommended in this case,
+	// because you cannot know which EP will be used on different user devices.
+	MaaInferenceExecutionProvider_Auto MaaInferenceExecutionProvider = 0
+
+	// MaaResOption_InferenceDevice will not work.
+	MaaInferenceExecutionProvider_CPU MaaInferenceExecutionProvider = 1
+
+	// MaaResOption_InferenceDevice will be used to set the adapter ID,
+	// as enumerated by the Win32 API `EnumAdapters1`.
+	MaaInferenceExecutionProvider_DirectML MaaInferenceExecutionProvider = 2
+
+	// MaaResOption_InferenceDevice will be used to set coreml_flag.
+	// See
+	// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
+	// Note the onnxruntime version in use: the latest flags may not be supported.
+	MaaInferenceExecutionProvider_CoreML MaaInferenceExecutionProvider = 3
+
+	// MaaResOption_InferenceDevice will be used to set the NVIDIA GPU ID.
+	// TODO!
+	MaaInferenceExecutionProvider_CUDA MaaInferenceExecutionProvider = 4
+)

type MaaResOption int32

const (
@@ -49,7 +85,14 @@ const (
	///
	/// value: MaaInferenceDevice, eg: 0; val_size: sizeof(MaaInferenceDevice)
	/// default value is MaaInferenceDevice_Auto
-	MaaResOption_InterfaceDevice MaaResOption = 1
+	MaaResOption_InferenceDevice MaaResOption = 1
+
+	/// Use the specified inference execution provider.
+	/// Set this option before loading the model.
+	///
+	/// value: MaaInferenceExecutionProvider, eg: 0; val_size: sizeof(MaaInferenceExecutionProvider)
+	/// default value is MaaInferenceExecutionProvider_Auto
+	MaaResOption_InferenceExecutionProvider MaaResOption = 2
)

var (
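The two options act as a pair: the execution provider selects the backend, and the meaning of the device value depends on that backend (a DXGI adapter ID for DirectML, coreml_flag bits for CoreML, a GPU ID for CUDA). A minimal sketch of that pairing, using only the constants added above; the helper pickInferenceOptions is hypothetical and not part of this commit:

// pickInferenceOptions is a hypothetical helper illustrating how an
// execution provider and its matching device value go together.
func pickInferenceOptions(onWindows, onMacOS bool) (MaaInferenceExecutionProvider, MaaInferenceDevice) {
	switch {
	case onWindows:
		// For DirectML, the device value is a DXGI adapter ID (EnumAdapters1).
		return MaaInferenceExecutionProvider_DirectML, MaaInferenceDevice_0
	case onMacOS:
		// For CoreML, the device value carries coreml_flag bits; Auto keeps the default.
		return MaaInferenceExecutionProvider_CoreML, MaaInferenceDevice_Auto
	default:
		// With the Auto EP, the backend chosen at runtime is unknown,
		// so the device is best left on Auto as well.
		return MaaInferenceExecutionProvider_Auto, MaaInferenceDevice_Auto
	}
}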
56 changes: 43 additions & 13 deletions resource.go
@@ -70,25 +70,55 @@ func (r *Resource) setOption(key maa.MaaResOption, value unsafe.Pointer, valSize
	)
}

-type InterfaceDevice int32
-
-// InterfaceDevice
-const (
-	InterfaceDeviceCPU  InterfaceDevice = -2
-	InterfaceDeviceAuto InterfaceDevice = -1
-	InterfaceDeviceGPU0 InterfaceDevice = 0
-	InterfaceDeviceGPU1 InterfaceDevice = 1
-	// and more gpu id...
-)
-
-func (r *Resource) SetInterfaceDevice(device InterfaceDevice) bool {
+func (r *Resource) setInferenceDevice(device maa.MaaInferenceDevice) bool {
	return r.setOption(
-		maa.MaaResOption_InterfaceDevice,
+		maa.MaaResOption_InferenceDevice,
		unsafe.Pointer(&device),
		unsafe.Sizeof(device),
	)
}
+
+func (r *Resource) setInferenceExecutionProvider(ep maa.MaaInferenceExecutionProvider) bool {
+	return r.setOption(
+		maa.MaaResOption_InferenceExecutionProvider,
+		unsafe.Pointer(&ep),
+		unsafe.Sizeof(ep),
+	)
+}
+
+func (r *Resource) setInference(ep maa.MaaInferenceExecutionProvider, deviceID maa.MaaInferenceDevice) bool {
+	return r.setInferenceExecutionProvider(ep) && r.setInferenceDevice(deviceID)
+}
+
+// UseCPU sets the CPU execution provider for inference.
+func (r *Resource) UseCPU() bool {
+	return r.setInference(maa.MaaInferenceExecutionProvider_CPU, maa.MaaInferenceDevice_CPU)
+}
+
+type InferenceDevice = maa.MaaInferenceDevice
+
+const (
+	InferenceDeviceAuto InferenceDevice = -1
+	InferenceDevice0    InferenceDevice = 0
+	InferenceDevice1    InferenceDevice = 1
+	// and more GPU IDs or flags...
+)
+
+// UseDirectml sets the DirectML execution provider for inference;
+// deviceID is the adapter ID from the Win32 API `EnumAdapters1`.
+func (r *Resource) UseDirectml(deviceID InferenceDevice) bool {
+	return r.setInference(maa.MaaInferenceExecutionProvider_DirectML, deviceID)
+}
+
+// UseCoreml sets the CoreML execution provider for inference;
+// coremlFlag follows onnxruntime's coreml_provider_factory.h.
+func (r *Resource) UseCoreml(coremlFlag InferenceDevice) bool {
+	return r.setInference(maa.MaaInferenceExecutionProvider_CoreML, coremlFlag)
+}
+
+// UseAutoExecutionProvider lets the framework choose the execution provider.
+func (r *Resource) UseAutoExecutionProvider() bool {
+	return r.setInference(maa.MaaInferenceExecutionProvider_Auto, maa.MaaInferenceDevice_Auto)
+}

// RegisterCustomRecognition registers a custom recognition to the resource.
func (r *Resource) RegisterCustomRecognition(name string, recognition CustomRecognition) bool {
	id := registerCustomRecognition(recognition)
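Taken together, the exported helpers give callers a one-line way to pick a backend before the model is loaded. A hedged usage sketch follows; NewResource, Destroy, and PostPath are assumed from the rest of the binding and may differ in name or signature, while the Use* methods are the ones added in this commit:

package main

import (
	"runtime"

	maa "github.com/MaaXYZ/maa-framework-go" // assumed module path of this binding
)

func main() {
	res := maa.NewResource(nil) // assumption: resource constructor
	defer res.Destroy()         // assumption: matching cleanup call

	// Select an execution provider before loading the model,
	// as MaaResOption_InferenceExecutionProvider requires.
	switch runtime.GOOS {
	case "windows":
		res.UseDirectml(maa.InferenceDevice0) // adapter 0 via EnumAdapters1
	case "darwin":
		res.UseCoreml(0) // default coreml_flag
	default:
		res.UseAutoExecutionProvider()
	}

	res.PostPath("path/to/resource") // assumption: load entry point
}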
