diff --git a/src/interface.jl b/src/interface.jl index 42f839e3..fda6a49f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -182,7 +182,8 @@ function process_entry!(@nospecialize(job::CompilerJob), mod::LLVM.Module, if job.source.kernel # pass all bitstypes by value; by default Julia passes aggregates by reference # (this improves performance, and is mandated by certain back-ends like SPIR-V). - args = classify_arguments(job, eltype(llvmtype(entry))) + source_sig = Base.signature_type(job.source.f, job.source.tt)::Type + args = classify_arguments(source_sig, eltype(llvmtype(entry))) for arg in args if arg.cc == BITS_REF attr = if LLVM.version() >= v"12" diff --git a/src/irgen.jl b/src/irgen.jl index 4d527aa6..5184df76 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -287,10 +287,8 @@ end GHOST # not passed end -function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType) - source_sig = Base.signature_type(job.source.f, job.source.tt)::Type +function classify_arguments(source_sig::Type, codegen_ft::LLVM.FunctionType) source_types = [source_sig.parameters...] - codegen_types = parameters(codegen_ft) args = [] diff --git a/src/spirv.jl b/src/spirv.jl index 632b6c51..1ebcfec4 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -239,7 +239,8 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F else ft end - args = classify_arguments(job, orig_ft) + source_sig = Base.signature_type(job.source.f, job.source.tt)::Type + args = classify_arguments(source_sig, entry_f) filter!(args) do arg arg.cc != GHOST end