Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Create statically compiled Swift closure wrapper to avoid C-style function pointers #440

Merged
merged 13 commits into from
Dec 20, 2024
108 changes: 40 additions & 68 deletions packages/nitrogen/src/syntax/swift/SwiftCxxBridgedType.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import { createSwiftVariant, getSwiftVariantCaseName } from './SwiftVariant.js'
import { VoidType } from '../types/VoidType.js'
import { NamedWrappingType } from '../types/NamedWrappingType.js'
import { ErrorType } from '../types/ErrorType.js'
import { createSwiftFunctionBridge } from './SwiftFunction.js'

// TODO: Remove enum bridge once Swift fixes bidirectional enums crashing the `-Swift.h` header.

Expand Down Expand Up @@ -176,10 +177,24 @@ export class SwiftCxxBridgedType implements BridgedType<'swift', 'c++'> {
files.push(extensionFile)
break
}
case 'function': {
const functionType = getTypeAs(this.type, FunctionType)
const extensionFile = createSwiftFunctionBridge(functionType)
files.push(extensionFile)
break
}
case 'promise': {
// Promise needs resolver and rejecter funcs in Swift
const promiseType = getTypeAs(this.type, PromiseType)
files.push(createSwiftFunctionBridge(promiseType.resolverFunction))
files.push(createSwiftFunctionBridge(promiseType.rejecterFunction))
break
}
case 'variant': {
const variant = getTypeAs(this.type, VariantType)
const file = createSwiftVariant(variant)
files.push(file)
break
}
}

Expand Down Expand Up @@ -345,23 +360,27 @@ export class SwiftCxxBridgedType implements BridgedType<'swift', 'c++'> {
const promise = getTypeAs(this.type, PromiseType)
switch (language) {
case 'swift': {
const bridge = this.getBridgeOrThrow()
if (promise.resultingType.kind === 'void') {
// It's void - resolve()
const resolverFunc = new FunctionType(new VoidType(), [])
const rejecterFunc = new FunctionType(new VoidType(), [
new NamedWrappingType('error', new ErrorType()),
])
const resolverFuncBridge = new SwiftCxxBridgedType(resolverFunc)
const rejecterFuncBridge = new SwiftCxxBridgedType(rejecterFunc)
return `
{ () -> ${promise.getCode('swift')} in
let __promise = ${promise.getCode('swift')}()
let __resolver = SwiftClosure { __promise.resolve(withResult: ()) }
let __resolver = { __promise.resolve(withResult: ()) }
let __rejecter = { (__error: Error) in
__promise.reject(withError: __error)
}
let __resolverCpp = __resolver.getFunctionCopy()
let __resolverCpp = ${indent(resolverFuncBridge.parseFromSwiftToCpp('__resolver', 'swift'), ' ')}
let __rejecterCpp = ${indent(rejecterFuncBridge.parseFromSwiftToCpp('__rejecter', 'swift'), ' ')}
${cppParameterName}.addOnResolvedListener(__resolverCpp)
${cppParameterName}.addOnRejectedListener(__rejecterCpp)
let __promiseHolder = bridge.wrap_${bridge.specializationName}(${cppParameterName})
__promiseHolder.addOnResolvedListener(__resolverCpp)
__promiseHolder.addOnRejectedListener(__rejecterCpp)
return __promise
}()`.trim()
} else {
Expand Down Expand Up @@ -389,8 +408,9 @@ export class SwiftCxxBridgedType implements BridgedType<'swift', 'c++'> {
}
let __resolverCpp = ${indent(resolverFuncBridge.parseFromSwiftToCpp('__resolver', 'swift'), ' ')}
let __rejecterCpp = ${indent(rejecterFuncBridge.parseFromSwiftToCpp('__rejecter', 'swift'), ' ')}
${cppParameterName}.${resolverFuncName}(__resolverCpp)
${cppParameterName}.addOnRejectedListener(__rejecterCpp)
let __promiseHolder = bridge.wrap_${bridge.specializationName}(${cppParameterName})
__promiseHolder.${resolverFuncName}(__resolverCpp)
__promiseHolder.addOnRejectedListener(__rejecterCpp)
return __promise
}()`.trim()
}
Expand Down Expand Up @@ -532,18 +552,21 @@ case ${i}:
if (funcType.returnType.kind === 'void') {
return `
{ () -> ${swiftClosureType} in
let __sharedClosure = bridge.share_${bridge.specializationName}(${cppParameterName})
let __wrappedFunction = bridge.wrap_${bridge.specializationName}(${cppParameterName})
return { ${signature} in
__sharedClosure.pointee.call(${indent(paramsForward.join(', '), ' ')})
__wrappedFunction.call(${indent(paramsForward.join(', '), ' ')})
}
}()`.trim()
} else {
const resultBridged = new SwiftCxxBridgedType(funcType.returnType)
const resultBridged = new SwiftCxxBridgedType(
funcType.returnType,
true
)
return `
{ () -> ${swiftClosureType} in
let __sharedClosure = bridge.share_${bridge.specializationName}(${cppParameterName})
let __wrappedFunction = bridge.wrap_${bridge.specializationName}(${cppParameterName})
return { ${signature} in
let __result = __sharedClosure.pointee.call(${indent(paramsForward.join(', '), ' ')})
let __result = __wrappedFunction.call(${indent(paramsForward.join(', '), ' ')})
return ${indent(resultBridged.parseFromCppToSwift('__result', 'swift'), ' ')}
}
}()`.trim()
Expand Down Expand Up @@ -661,30 +684,20 @@ case ${i}:
true
)
switch (language) {
case 'c++':
if (this.isBridgingToDirectCppTarget) {
return swiftParameterName
} else {
return `${swiftParameterName}.getPromise()`
}
case 'swift':
const arg =
promise.resultingType.kind === 'void'
? ''
: resolvingType.parseFromSwiftToCpp('__result', 'swift')
const code = `
return `
{ () -> bridge.${bridge.specializationName} in
let __promise = ${makePromise}()
let __promiseHolder = bridge.wrap_${bridge.specializationName}(__promise)
${swiftParameterName}
.then({ __result in __promise.resolve(${indent(arg, ' ')}) })
.catch({ __error in __promise.reject(__error.toCpp()) })
.then({ __result in __promiseHolder.resolve(${indent(arg, ' ')}) })
.catch({ __error in __promiseHolder.reject(__error.toCpp()) })
return __promise
}()`.trim()
if (this.isBridgingToDirectCppTarget) {
return `${code}.getPromise()`
} else {
return code
}
default:
return swiftParameterName
}
Expand Down Expand Up @@ -772,52 +785,11 @@ case ${i}:
switch (language) {
case 'swift': {
const bridge = this.getBridgeOrThrow()
const func = getTypeAs(this.type, FunctionType)
const cFuncParamsForward = func.parameters
.map((p) => {
const bridged = new SwiftCxxBridgedType(p)
return bridged.parseFromCppToSwift(
`__${p.escapedName}`,
'swift'
)
})
.join(', ')
const paramsSignature = func.parameters
.map((p) => `_ __${p.escapedName}: ${p.getCode('swift')}`)
.join(', ')
const paramsForward = func.parameters
.map((p) => `__${p.escapedName}`)
.join(', ')
const cFuncParamsSignature = [
'__closureHolder: UnsafeMutableRawPointer',
...func.parameters.map((p) => {
const bridged = new SwiftCxxBridgedType(p)
return `__${p.escapedName}: ${bridged.getTypeCode('swift')}`
}),
].join(', ')
const createFunc = `bridge.${bridge.funcName}`
return `
{ () -> bridge.${bridge.specializationName} in
final class ClosureHolder {
let closure: ${func.getCode('swift')}
init(wrappingClosure closure: @escaping ${func.getCode('swift')}) {
self.closure = closure
}
func invoke(${paramsSignature}) {
self.closure(${indent(paramsForward, ' ')})
}
}

let __closureHolder = Unmanaged.passRetained(ClosureHolder(wrappingClosure: ${swiftParameterName})).toOpaque()
func __callClosure(${cFuncParamsSignature}) -> Void {
let closure = Unmanaged<ClosureHolder>.fromOpaque(__closureHolder).takeUnretainedValue()
closure.invoke(${indent(cFuncParamsForward, ' ')})
}
func __destroyClosure(_ __closureHolder: UnsafeMutableRawPointer) -> Void {
Unmanaged<ClosureHolder>.fromOpaque(__closureHolder).release()
}

return ${createFunc}(__closureHolder, __callClosure, __destroyClosure)
let __closureWrapper = ${bridge.specializationName}(${swiftParameterName})
return ${createFunc}(__closureWrapper.toUnsafe())
}()
`.trim()
}
Expand Down
98 changes: 51 additions & 47 deletions packages/nitrogen/src/syntax/swift/SwiftCxxTypeHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,58 +303,47 @@ function createCxxFunctionSwiftHelper(type: FunctionType): SwiftCxxHelper {
return `${p.getCode('c++')} ${p.escapedName}`
}
})
const callCppFuncParamsSignature = type.parameters.map((p) => {
const paramsForward = type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
const cppType = bridge.getTypeCode('c++')
return `${cppType} ${p.escapedName}`
return bridge.parseFromCppToSwift(p.escapedName, 'c++')
})
const name = type.specializationName
const wrapperName = `${name}_Wrapper`
const swiftClassName = `${NitroConfig.getIosModuleName()}::${type.specializationName}`

const callParamsForward = type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
return bridge.parseFromSwiftToCpp(p.escapedName, 'c++')
})
const paramsForward = [
'sharedClosureHolder.get()',
...type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
return bridge.parseFromCppToSwift(p.escapedName, 'c++')
}),
]
const callFuncReturnType = returnBridge.getTypeCode('c++')
const callFuncParams = [
'void* _Nonnull /* closureHolder */',
...type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
return bridge.getTypeCode('c++')
}),
]
const functionPointerParam = `${callFuncReturnType}(* _Nonnull call)(${callFuncParams.join(', ')})`
const name = type.specializationName
const wrapperName = `${name}_Wrapper`

const callFuncReturnType = returnBridge.getTypeCode('c++')
const callCppFuncParamsSignature = type.parameters.map((p) => {
const bridge = new SwiftCxxBridgedType(p)
const cppType = bridge.getTypeCode('c++')
return `${cppType} ${p.escapedName}`
})
let callCppFuncBody: string
if (returnBridge.hasType) {
callCppFuncBody = `
auto __result = _function(${callParamsForward.join(', ')});
auto __result = _function->operator()(${callParamsForward.join(', ')});
return ${returnBridge.parseFromCppToSwift('__result', 'c++')};
`.trim()
} else {
callCppFuncBody = `_function(${callParamsForward.join(', ')});`
callCppFuncBody = `_function->operator()(${callParamsForward.join(', ')});`
}

let callSwiftFuncBody: string
if (returnBridge.hasType) {
callSwiftFuncBody = `
auto __result = call(${paramsForward.join(', ')});
let body: string
if (type.returnType.kind === 'void') {
body = `
swiftClosure.call(${paramsForward.join(', ')});
`.trim()
} else {
body = `
auto __result = swiftClosure.call(${paramsForward.join(', ')});
return ${returnBridge.parseFromSwiftToCpp('__result', 'c++')};
`.trim()
} else {
callSwiftFuncBody = `call(${paramsForward.join(', ')});`
}

// TODO: Remove shared_Func_void(...) function that returns a std::shared_ptr<std::function<...>>
// once Swift fixes the bug where a regular std::function cannot be captured.
// https://github.com/swiftlang/swift/issues/76143

return {
cxxType: actualType,
funcName: `create_${name}`,
Expand All @@ -370,22 +359,16 @@ using ${name} = ${actualType};
*/
class ${wrapperName} final {
public:
explicit ${wrapperName}(const ${actualType}& func): _function(func) {}
explicit ${wrapperName}(${actualType}&& func): _function(std::move(func)) {}
explicit ${wrapperName}(${actualType}&& func): _function(std::make_shared<${actualType}>(std::move(func))) {}
inline ${callFuncReturnType} call(${callCppFuncParamsSignature.join(', ')}) const {
${indent(callCppFuncBody, ' ')}
}
private:
${actualType} _function;
} SWIFT_NONCOPYABLE;
inline ${name} create_${name}(void* _Nonnull closureHolder, ${functionPointerParam}, void(* _Nonnull destroy)(void* _Nonnull)) {
std::shared_ptr<void> sharedClosureHolder(closureHolder, destroy);
return ${name}([sharedClosureHolder = std::move(sharedClosureHolder), call](${paramsSignature.join(', ')}) -> ${type.returnType.getCode('c++')} {
${indent(callSwiftFuncBody, ' ')}
});
}
inline std::shared_ptr<${wrapperName}> share_${name}(const ${name}& value) {
return std::make_shared<${wrapperName}>(value);
std::shared_ptr<${actualType}> _function;
};
${name} create_${name}(void* _Nonnull swiftClosureWrapper);
inline ${wrapperName} wrap_${name}(${name} value) {
return ${wrapperName}(std::move(value));
}
`.trim(),
requiredIncludes: [
Expand All @@ -402,6 +385,24 @@ inline std::shared_ptr<${wrapperName}> share_${name}(const ${name}& value) {
...bridgedType.getRequiredImports(),
],
},
cxxImplementation: {
code: `
${name} create_${name}(void* _Nonnull swiftClosureWrapper) {
auto swiftClosure = ${swiftClassName}::fromUnsafe(swiftClosureWrapper);
return [swiftClosure = std::move(swiftClosure)](${paramsSignature.join(', ')}) mutable -> ${type.returnType.getCode('c++')} {
${indent(body, ' ')}
};
}
`.trim(),
requiredIncludes: [
{
language: 'c++',
// Swift umbrella header
name: getUmbrellaHeaderName(),
space: 'user',
},
],
},
dependencies: [],
}
}
Expand Down Expand Up @@ -566,7 +567,7 @@ ${functions.join('\n')}
function createCxxPromiseSwiftHelper(type: PromiseType): SwiftCxxHelper {
const resultingType = type.resultingType.getCode('c++')
const bridgedType = new SwiftCxxBridgedType(type)
const actualType = `PromiseHolder<${resultingType}>`
const actualType = `std::shared_ptr<Promise<${resultingType}>>`

const resolverArgs: NamedType[] = []
if (type.resultingType.kind !== 'void') {
Expand All @@ -589,7 +590,10 @@ function createCxxPromiseSwiftHelper(type: PromiseType): SwiftCxxHelper {
*/
using ${name} = ${actualType};
inline ${actualType} create_${name}() {
return PromiseHolder<${resultingType}>::create();
return Promise<${resultingType}>::create();
}
inline PromiseHolder<${resultingType}> wrap_${name}(${actualType} promise) {
return PromiseHolder<${resultingType}>(std::move(promise));
}
`.trim(),
requiredIncludes: [
Expand Down
Loading
Loading