diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp index 28c211aa631e4d..a6a2f3595fe7db 100644 --- a/clang/lib/CodeGen/CGCall.cpp +++ b/clang/lib/CodeGen/CGCall.cpp @@ -1581,6 +1581,11 @@ bool CodeGenModule::ReturnTypeUsesSRet(const CGFunctionInfo &FI) { return RI.isIndirect() || (RI.isInAlloca() && RI.getInAllocaSRet()); } +bool CodeGenModule::ReturnTypeHasInReg(const CGFunctionInfo &FI) { + const auto &RI = FI.getReturnInfo(); + return RI.getInReg(); +} + bool CodeGenModule::ReturnSlotInterferesWithArgs(const CGFunctionInfo &FI) { return ReturnTypeUsesSRet(FI) && getTargetCodeGenInfo().doesReturnSlotInterfereWithArgs(); diff --git a/clang/lib/CodeGen/CGObjCGNU.cpp b/clang/lib/CodeGen/CGObjCGNU.cpp index a36b0cdddaf0af..05e3f8d4bfc2a3 100644 --- a/clang/lib/CodeGen/CGObjCGNU.cpp +++ b/clang/lib/CodeGen/CGObjCGNU.cpp @@ -2903,23 +2903,29 @@ CGObjCGNU::GenerateMessageSend(CodeGenFunction &CGF, break; case CodeGenOptions::Mixed: case CodeGenOptions::NonLegacy: + StringRef name = "objc_msgSend"; if (CGM.ReturnTypeUsesFPRet(ResultType)) { - imp = - CGM.CreateRuntimeFunction(llvm::FunctionType::get(IdTy, IdTy, true), - "objc_msgSend_fpret") - .getCallee(); + name = "objc_msgSend_fpret"; } else if (CGM.ReturnTypeUsesSRet(MSI.CallInfo)) { - // The actual types here don't matter - we're going to bitcast the - // function anyway - imp = - CGM.CreateRuntimeFunction(llvm::FunctionType::get(IdTy, IdTy, true), - "objc_msgSend_stret") - .getCallee(); - } else { - imp = CGM.CreateRuntimeFunction( - llvm::FunctionType::get(IdTy, IdTy, true), "objc_msgSend") - .getCallee(); + name = "objc_msgSend_stret"; + + // The address of the memory block is be passed in x8 for POD type, + // or in x0 for non-POD type (marked as inreg). + bool shouldCheckForInReg = + CGM.getContext() + .getTargetInfo() + .getTriple() + .isWindowsMSVCEnvironment() && + CGM.getContext().getTargetInfo().getTriple().isAArch64(); + if (shouldCheckForInReg && CGM.ReturnTypeHasInReg(MSI.CallInfo)) { + name = "objc_msgSend_stret2"; + } } + // The actual types here don't matter - we're going to bitcast the + // function anyway + imp = CGM.CreateRuntimeFunction(llvm::FunctionType::get(IdTy, IdTy, true), + name) + .getCallee(); } // Reset the receiver in case the lookup modified it diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h index ec34680fd3f7e6..d9ece4d98eecae 100644 --- a/clang/lib/CodeGen/CodeGenModule.h +++ b/clang/lib/CodeGen/CodeGenModule.h @@ -1239,6 +1239,9 @@ class CodeGenModule : public CodeGenTypeCache { /// Return true iff the given type uses 'sret' when used as a return type. bool ReturnTypeUsesSRet(const CGFunctionInfo &FI); + /// Return true iff the given type has `inreg` set. + bool ReturnTypeHasInReg(const CGFunctionInfo &FI); + /// Return true iff the given type uses an argument slot when 'sret' is used /// as a return type. bool ReturnSlotInterferesWithArgs(const CGFunctionInfo &FI); diff --git a/clang/test/CodeGenObjCXX/msabi-stret-arm64.mm b/clang/test/CodeGenObjCXX/msabi-stret-arm64.mm new file mode 100644 index 00000000000000..3bbdbebc5cb576 --- /dev/null +++ b/clang/test/CodeGenObjCXX/msabi-stret-arm64.mm @@ -0,0 +1,77 @@ +// RUN: %clang_cc1 -triple aarch64-pc-windows-msvc -fobjc-runtime=gnustep-2.2 -fobjc-dispatch-method=non-legacy -emit-llvm -o - %s | FileCheck %s + +// Pass and return for type size <= 8 bytes. +struct S1 { + int a[2]; +}; + +// Pass and return hfa <= 8 bytes +struct F1 { + float a[2]; +}; + +// Pass and return for type size > 16 bytes. +struct S2 { + int a[5]; +}; + +// Pass and return aggregate (of size < 16 bytes) with non-trivial destructor. +// Sret and inreg: Returned in x0 +struct S3 { + int a[3]; + ~S3(); +}; +S3::~S3() { +} + + +@interface MsgTest { id isa; } @end +@implementation MsgTest +- (S1) smallS1 { + S1 x; + x.a[0] = 0; + x.a[1] = 1; + return x; + +} +- (F1) smallF1 { + F1 x; + x.a[0] = 0.2f; + x.a[1] = 0.5f; + return x; +} +- (S2) stretS2 { + S2 x; + for (int i = 0; i < 5; i++) { + x.a[i] = i; + } + return x; +} +- (S3) stretInRegS3 { + S3 x; + for (int i = 0; i < 3; i++) { + x.a[i] = i; + } + return x; +} ++ (S3) msgTestStretInRegS3 { + S3 x; + for (int i = 0; i < 3; i++) { + x.a[i] = i; + } + return x; +} +@end + +void test0(MsgTest *t) { + // CHECK: call {{.*}} @objc_msgSend + S1 ret = [t smallS1]; + // CHECK: call {{.*}} @objc_msgSend + F1 ret2 = [t smallF1]; + // CHECK: call {{.*}} @objc_msgSend_stret + S2 ret3 = [t stretS2]; + // CHECK: call {{.*}} @objc_msgSend_stret2 + S3 ret4 = [t stretInRegS3]; + // CHECK: call {{.*}} @objc_msgSend_stret2 + S3 ret5 = [MsgTest msgTestStretInRegS3]; +}