Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codegen cleanup #2 #76

Merged
merged 11 commits into from
Jan 3, 2023
79 changes: 78 additions & 1 deletion cmd/codegen/arguments_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,84 @@ package main
import "C"
import "fmt"

type typeWrapper func(arg ArgDef) (argType string, def string, varName string)
type argumentWrapper func(arg ArgDef) (argType string, def string, varName string)

func argWrapper(argType string) (wrapper argumentWrapper, err error) {
argWrapperMap := map[string]argumentWrapper{
"char*": constCharW,
"const char*": constCharW,
"const char**": charPtrPtrW,
"const char* const[]": charPtrPtrW,
"unsigned char": ucharW,
"unsigned char**": uCharPtrW,
"size_t": sizeTW,
"size_t*": sizeTPtrW,
"float": floatW,
"float*": floatPtrW,
"const float*": floatArrayW,
"short": shortW,
"unsigned short": ushortW,
"ImU8": u8W,
"const ImU8*": u8SliceW,
"ImU16": u16W,
"const ImU16*": u16SliceW,
"ImU32": u32W,
"const ImU32*": u32SliceW,
"ImU64": u64W,
"const ImU64*": uint64ArrayW,
"ImS8": s8W,
"const ImS8*": s8SliceW,
"ImS16": s16W,
"const ImS16*": s16SliceW,
"ImS32": s32W,
"const ImS32*": s32SliceW,
"const ImS64*": int64ArrayW,
"int": intW,
"int*": intPtrW,
"unsigned int": uintW,
"unsigned int*": uintPtrW,
"double": doubleW,
"double*": doublePtrW,
"bool": boolW,
"bool*": boolPtrW,
"int[2]": int2W,
"int[3]": int3W,
"int[4]": int4W,
"float[2]": float2W,
"float[3]": float3W,
"float[4]": float4W,
"ImWchar": imWcharW,
"const ImWchar*": imWcharPtrW,
"ImGuiID": imGuiIDW,
"ImTextureID": imTextureIDW,
"ImDrawIdx": imDrawIdxW,
"ImGuiTableColumnIdx": imTableColumnIdxW,
"ImGuiTableDrawChannelIdx": imTableDrawChannelIdxW,
"void*": voidPtrW,
"const void*": voidPtrW,
"const ImVec2": imVec2W,
"const ImVec2*": imVec2PtrW,
"ImVec2": imVec2W,
"ImVec2*": imVec2PtrW,
"ImVec2[2]": imVec22W,
"const ImVec4": imVec4W,
"const ImVec4*": imVec4PtrW,
"ImVec4": imVec4W,
"ImVec4*": imVec4PtrW,
"ImColor*": imColorPtrW,
"ImRect": imRectW,
"ImRect*": imRectPtrW,
"ImPlotPoint": imPlotPointW,
"const ImPlotPoint": imPlotPointW,
"ImPlotPoint*": imPlotPointPtrW,
}

if wrapper, ok := argWrapperMap[argType]; ok {
return wrapper, nil
}

return nil, fmt.Errorf("no wrapper for type %s", argType)
}

func constCharW(arg ArgDef) (argType string, def string, varName string) {
argType = "string"
Expand Down
181 changes: 31 additions & 150 deletions cmd/codegen/gengo.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,6 @@ import (
"github.com/thoas/go-funk"
)

type TypeMap struct {
GoType string
CgoWrapper string
}

func tm(goType string, cgoWrapper string) *TypeMap {
return &TypeMap{
GoType: goType,
CgoWrapper: cgoWrapper,
}
}

var structMemberTypeMap = map[string]*TypeMap{
"unsigned int": tm("uint32", "C.uint(%s)"),
"float": tm("float32", "C.float(%s)"),
"int": tm("int32", "C.int(%s)"),
}

func trimImGuiPrefix(id string) string {
// don't trim prefixes for implot's ImAxis - it conflicts with ImGuIAxis (from imgui_internal.h)
if strings.HasPrefix(id, "ImAxis") {
Expand Down Expand Up @@ -141,113 +123,6 @@ import "unsafe"

`, prefix))

argWrapperMap := map[string]typeWrapper{
"char*": constCharW,
"const char*": constCharW,
"const char**": charPtrPtrW,
"const char* const[]": charPtrPtrW,
"unsigned char": ucharW,
"unsigned char**": uCharPtrW,
"size_t": sizeTW,
"size_t*": sizeTPtrW,
"float": floatW,
"float*": floatPtrW,
"const float*": floatArrayW,
"short": shortW,
"unsigned short": ushortW,
"ImU8": u8W,
"const ImU8*": u8SliceW,
"ImU16": u16W,
"const ImU16*": u16SliceW,
"ImU32": u32W,
"const ImU32*": u32SliceW,
"ImU64": u64W,
"const ImU64*": uint64ArrayW,
"ImS8": s8W,
"const ImS8*": s8SliceW,
"ImS16": s16W,
"const ImS16*": s16SliceW,
"ImS32": s32W,
"const ImS32*": s32SliceW,
"const ImS64*": int64ArrayW,
"int": intW,
"int*": intPtrW,
"unsigned int": uintW,
"unsigned int*": uintPtrW,
"double": doubleW,
"double*": doublePtrW,
"bool": boolW,
"bool*": boolPtrW,
"int[2]": int2W,
"int[3]": int3W,
"int[4]": int4W,
"float[2]": float2W,
"float[3]": float3W,
"float[4]": float4W,
"ImWchar": imWcharW,
"const ImWchar*": imWcharPtrW,
"ImGuiID": imGuiIDW,
"ImTextureID": imTextureIDW,
"ImDrawIdx": imDrawIdxW,
"ImGuiTableColumnIdx": imTableColumnIdxW,
"ImGuiTableDrawChannelIdx": imTableDrawChannelIdxW,
"void*": voidPtrW,
"const void*": voidPtrW,
"const ImVec2": imVec2W,
"const ImVec2*": imVec2PtrW,
"ImVec2": imVec2W,
"ImVec2*": imVec2PtrW,
"ImVec2[2]": imVec22W,
"const ImVec4": imVec4W,
"const ImVec4*": imVec4PtrW,
"ImVec4": imVec4W,
"ImVec4*": imVec4PtrW,
"ImColor*": imColorPtrW,
"ImRect": imRectW,
"ImRect*": imRectPtrW,
"ImPlotPoint": imPlotPointW,
"const ImPlotPoint": imPlotPointW,
"ImPlotPoint*": imPlotPointPtrW,
}

returnWrapperMap := map[string]returnWrapper{
"bool": boolReturnW,
"char*": constCharReturnW,
"const char*": constCharReturnW,
"const ImWchar*": constWCharPtrReturnW,
"ImWchar": imWcharReturnW,
"float": floatReturnW,
"double": doubleReturnW,
"int": intReturnW,
"unsigned int": uintReturnW,
"short": intReturnW,
"ImS8": intReturnW,
"ImS16": intReturnW,
"ImS32": intReturnW,
"ImU8": uintReturnW,
"ImU16": uintReturnW,
"ImU32": u32ReturnW,
"ImU64": uint64ReturnW,
"ImVec4": imVec4ReturnW,
"const ImVec4*": imVec4PtrReturnW,
"ImGuiID": idReturnW,
"ImTextureID": textureIdReturnW,
"ImVec2": imVec2ReturnW,
"ImColor": imColorReturnW,
"ImPlotPoint": imPlotPointReturnW,
"ImRect": imRectReturnW,
"ImGuiTableColumnIdx": imTableColumnIdxReturnW,
"ImGuiTableDrawChannelIdx": imTableDrawChannelIdxReturnW,
"void*": voidPtrReturnW,
"size_t": doubleReturnW,
}

type argOutput struct {
ArgType string
ArgDef string
VarName string
}

isEnum := func(argType string) bool {
for _, en := range enumNames {
if argType == en {
Expand Down Expand Up @@ -299,7 +174,7 @@ import "unsafe"
continue
}

if v, ok := argWrapperMap[a.Type]; ok {
if v, err := argWrapper(a.Type); err == nil {
argType, argDef, varName := v(a)
if goEnumName := trimImGuiPrefix(argType); isEnum(goEnumName) {
argType = goEnumName
Expand Down Expand Up @@ -359,19 +234,6 @@ import "unsafe"
fmt.Printf("generated: %s%s\n", f.FuncName, f.Args)
}

// Generate function args
argStmtFunc := func() string {
var invokeStmt []string
for _, aw := range argWrappers {
invokeStmt = append(invokeStmt, aw.VarName)
if len(aw.ArgDef) > 0 {
sb.WriteString(fmt.Sprintf("%s\n\n", aw.ArgDef))
}
}

return strings.Join(invokeStmt, ",")
}

skipStructs := []string{
"ImVec1",
"ImVec2",
Expand Down Expand Up @@ -444,8 +306,8 @@ import "unsafe"
// find out the return type
outArg := f.ArgsT[0]
outArgT := strings.TrimSuffix(outArg.Type, "*")
returnWrapper, found := returnWrapperMap[outArgT]
if !found {
returnWrapper, err := getReturnTypeWrapperFunc(outArgT)
if err != nil {
fmt.Printf("Unknown return type \"%s\" in function %s\n", f.Ret, f.FuncName)
continue
}
Expand All @@ -457,7 +319,7 @@ import "unsafe"
// temporary out arg definition
sb.WriteString(fmt.Sprintf("%s := &%s{}\n", outArg.Name, returnType))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

// C function call
sb.WriteString(fmt.Sprintf("C.%s(%s)\n", f.FuncName, argInvokeStmt))
Expand All @@ -478,27 +340,27 @@ import "unsafe"

sb.WriteString(fmt.Sprintf("func (self %[1]s) %[2]s(%[3]s) {\n", funcParts[0], funcName, strings.Join(args, ",")))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

sb.WriteString(fmt.Sprintf("C.%s(self.handle(), %s)\n", f.FuncName, argInvokeStmt))
sb.WriteString("}\n\n")
} else {
sb.WriteString(funcSignatureFunc(f.FuncName, args, ""))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

sb.WriteString(fmt.Sprintf("C.%s(%s)\n", f.FuncName, argInvokeStmt))
sb.WriteString("}\n\n")
}

convertedFuncCount += 1
default:
if rf, ok := returnWrapperMap[f.Ret]; ok {
if rf, err := getReturnTypeWrapperFunc(f.Ret); err == nil {
returnType, returnStmt := rf()

sb.WriteString(funcSignatureFunc(f.FuncName, args, returnType))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

sb.WriteString(fmt.Sprintf(returnStmt, fmt.Sprintf("C.%s(%s)", f.FuncName, argInvokeStmt)))
sb.WriteString("}\n\n")
Expand All @@ -509,7 +371,7 @@ import "unsafe"

sb.WriteString(funcSignatureFunc(f.FuncName, args, returnType))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

sb.WriteString(fmt.Sprintf("return %s(%s)", returnType, fmt.Sprintf("C.%s(%s)", f.FuncName, argInvokeStmt)))
sb.WriteString("}\n\n")
Expand All @@ -522,7 +384,7 @@ import "unsafe"

sb.WriteString(funcSignatureFunc(f.FuncName, args, pureReturnType))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

sb.WriteString(fmt.Sprintf("return (%s)(unsafe.Pointer(%s))", pureReturnType, fmt.Sprintf("C.%s(%s)", f.FuncName, argInvokeStmt)))
sb.WriteString("}\n\n")
Expand All @@ -531,7 +393,7 @@ import "unsafe"
} else if f.StructGetter && funk.ContainsString(structNames, f.Ret) {
sb.WriteString(funcSignatureFunc(f.FuncName, args, f.Ret))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

sb.WriteString(fmt.Sprintf("return new%sFromC(C.%s(%s))", f.Ret, f.FuncName, argInvokeStmt))
sb.WriteString("}\n\n")
Expand Down Expand Up @@ -559,7 +421,7 @@ import "unsafe"

sb.WriteString(fmt.Sprintf("func %s(%s) %s {\n", newFuncName, strings.Join(args, ","), returnType))

argInvokeStmt := argStmtFunc()
argInvokeStmt := argStmtFunc(argWrappers, &sb)

sb.WriteString(fmt.Sprintf("return (%s)(unsafe.Pointer(C.%s(%s)))", returnType, f.FuncName, argInvokeStmt))

Expand All @@ -582,3 +444,22 @@ import "unsafe"

_, _ = goFile.WriteString(sb.String())
}

type argOutput struct {
ArgType string
ArgDef string
VarName string
}

// Generate function args
func argStmtFunc(argWrappers []argOutput, sb *strings.Builder) string {
var invokeStmt []string
for _, aw := range argWrappers {
invokeStmt = append(invokeStmt, aw.VarName)
if len(aw.ArgDef) > 0 {
sb.WriteString(fmt.Sprintf("%s\n\n", aw.ArgDef))
}
}

return strings.Join(invokeStmt, ",")
}
Loading