diff --git a/src/optim.jl b/src/optim.jl
index cd561a5c..2195a902 100644
--- a/src/optim.jl
+++ b/src/optim.jl
@@ -3,7 +3,7 @@
 function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=1)
     tm = llvm_machine(job.config.target)
 
-    global current_job
+    global current_job # ScopedValue?
     current_job = job
 
     @dispose pb=NewPMPassBuilder() begin
@@ -24,6 +24,18 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
     return
 end
 
+# TODO: Priority heap to provide order between different plugins
+const PIPELINE_CALLBACKS = Dict{String, Any}()
+
+# Register `plugin` under `name`. Every registered entry is wrapped in a
+# `CallbackPass` by `buildNewPMPipeline!`; reusing a name is an error.
+function register_plugin!(name::String, plugin)
+    if haskey(PIPELINE_CALLBACKS, name)
+        error("GPUCompiler plugin with name $name is already registered")
+    end
+    PIPELINE_CALLBACKS[name] = plugin
+end
+
 function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
     buildEarlySimplificationPipeline(mpm, job, opt_level)
     add!(mpm, AlwaysInlinerPass())
@@ -41,6 +53,10 @@ function buildNewPMPipeline!(mpm, @nospecialize(job::CompilerJob), opt_level)
             add!(fpm, WarnMissedTransformationsPass())
         end
     end
+    # run every registered plugin callback ahead of intrinsic lowering
+    for (name, callback) in PIPELINE_CALLBACKS
+        add!(mpm, CallbackPass(name, callback))
+    end
     buildIntrinsicLoweringPipeline(mpm, job, opt_level)
     buildCleanupPipeline(mpm, job, opt_level)
 end
@@ -423,3 +439,20 @@ function lower_ptls!(mod::LLVM.Module)
     return changed
 end
 LowerPTLSPass() = NewPMModulePass("GPULowerPTLS", lower_ptls!)
+
+
+# Module pass body for a registered plugin: when `mod` contains a function called
+# `name`, invoke `callback(job, marker, mod)` and report its result as `changed`;
+# otherwise leave the module untouched and return `false`.
+function callback_pass!(name, callback::F, mod::LLVM.Module) where F
+    job = current_job::CompilerJob
+    changed = false
+
+    if haskey(functions(mod), name)
+        marker = functions(mod)[name]
+        changed = callback(job, marker, mod)
+    end
+    return changed
+end
+
+CallbackPass(name, callback) = NewPMModulePass("CallbackPass<$name>", (mod)->callback_pass!(name, callback, mod))
diff --git a/test/ptx_tests.jl b/test/ptx_tests.jl
index 6caa6c71..145a4ace 100644
--- a/test/ptx_tests.jl
+++ b/test/ptx_tests.jl
@@ -276,6 +276,15 @@ end
     @test "We did not crash!" != ""
 end
 
+@testset "Pipeline callbacks" begin
+    function kernel(x)
+        PTX.mark(x)
+        return
+    end
+    ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Int}))
+    @test !occursin("gpucompiler.mark", ir)
+end
+
 @testset "exception arguments" begin
     function kernel(a)
         unsafe_store!(a, trunc(Int, unsafe_load(a)))
diff --git a/test/ptx_testsetup.jl b/test/ptx_testsetup.jl
index ed5026f1..e05eaa8f 100644
--- a/test/ptx_testsetup.jl
+++ b/test/ptx_testsetup.jl
@@ -16,6 +16,31 @@ end
 GPUCompiler.kernel_state_type(@nospecialize(job::PTXCompilerJob)) = PTXKernelState
 @inline @generated kernel_state() = GPUCompiler.kernel_state_value(PTXKernelState)
 
+# call the `gpucompiler.mark` marker, which the plugin registered below strips again
+function mark(x)
+    ccall("gpucompiler.mark", llvmcall, Nothing, (Int,), x)
+end
+
+# plugin callback: delete users of `intrinsic` that are themselves unused,
+# returning whether anything was deleted
+function remove_mark!(@nospecialize(job), intrinsic, mod::LLVM.Module)
+    changed = false
+
+    for use in uses(intrinsic)
+        val = user(use)
+        if isempty(uses(val))
+            unsafe_delete!(LLVM.parent(val), val)
+            changed = true
+        else
+            # the validator will detect this
+        end
+    end
+
+    return changed
+end
+
+GPUCompiler.register_plugin!("gpucompiler.mark", remove_mark!)
+
 # a version of the test runtime that has some side effects, loading the kernel state
 # (so that we can test if kernel state arguments are appropriately optimized away)
 module PTXTestRuntime