diff --git a/src/fft_convolve.jl b/src/fft_convolve.jl index 5ff47d4a..877aab42 100644 --- a/src/fft_convolve.jl +++ b/src/fft_convolve.jl @@ -195,3 +195,31 @@ function fft_conv_adj!( return output end + + +""" + fft_conv_adj2!(output, image2, ker3, plans) +In-place version of adjoint of convolving a 2D `image2` with a 3D kernel `ker3` +""" +function fft_conv_adj2!( + output::AbstractArray{<:RealU,3}, + image2::AbstractMatrix{<:RealU}, + ker3::AbstractArray{<:RealU,3}, + plans::Vector{<:PlanPSF}, +) + + size(output, 1) == size(image2, 1) || throw("size 1") + size(output, 3) == size(image2, 2) || throw("size 2") + + fun = y -> fft_conv_adj!( + (@view output[:, y, :]), + image2, + (@view ker3[:, :, y]), + plans[Threads.threadid()], + ) + + ntasks = length(plans) + Threads.foreach(fun, _setup(1:size(output, 2)); ntasks) + + return output +end diff --git a/test/adjoint-fftconv.jl b/test/adjoint-fftconv.jl index 0855cbd0..c35c1448 100644 --- a/test/adjoint-fftconv.jl +++ b/test/adjoint-fftconv.jl @@ -2,7 +2,7 @@ # test adjoint consistency for FFT convolution methods on very small case using SPECTrecon: plan_psf -using SPECTrecon: fft_conv!, fft_conv_adj! +using SPECTrecon: fft_conv!, fft_conv_adj!, fft_conv_adj2! using SPECTrecon: fft_conv, fft_conv_adj using LinearMapsAA: LinearMapAA using Test: @test, @testset @@ -29,6 +29,8 @@ end @test maximum(result) ≤ 1 fft_conv_adj!(result, image3, ker3, plan) @test maximum(result) ≤ 1.5 # boundary is the sum of replicate padding + fft_conv_adj2!(result, image3[:, 3, :], ker3, plan) + @test maximum(result) ≤ 1.5 # boundary is the sum of replicate padding end