Skip to content

Commit

Permalink
New feature: Added a deformation inverse prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnAshburner committed Apr 17, 2024
1 parent 7230aac commit f65bcc5
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ If you are brave enouth to try using it, then the following may work (in Julia)


Note that the automated githib actions fail because CUDA drivers are missing, which leads on to several other problems.
Multi-dimensionsional ffts on GPU (CUFFT) can also be [problematic](https://github.com/JuliaGPU/CUDA.jl/issues/119) with older Julia versions (fixed somewhere between Julia 1.7.2 and 1.9.3).

[![Build Status](https://github.com/spm/PushPull.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/spm/PushPull.jl/actions/workflows/CI.yml?query=branch%3Amain)

104 changes: 87 additions & 17 deletions src/multigrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,66 @@ Multiplying `v` by the Hessian (`H`).
function Hv!(v::VolType, H::VolType, u::VolType=zero(v))::VolType
@assert(all(dim(H) .== dim(v)))
@assert(all(size(v) .== size(u)))
@assert(size(H,4) == 6)
@assert(size(v,4) == 3)
h11 = view(H,:,:,:,1)
h22 = view(H,:,:,:,2)
h33 = view(H,:,:,:,3)
h12 = view(H,:,:,:,4)
h13 = view(H,:,:,:,5)
h23 = view(H,:,:,:,6)
v1 = view(v,:,:,:,1)
v2 = view(v,:,:,:,2)
v3 = view(v,:,:,:,3)
u1 = view(u,:,:,:,1)
u2 = view(u,:,:,:,2)
u3 = view(u,:,:,:,3)
u1 .+= h11.*v1 .+ h12.*v2 .+ h13.*v3
u2 .+= h12.*v1 .+ h22.*v2 .+ h23.*v3
u3 .+= h13.*v1 .+ h23.*v2 .+ h33.*v3
@assert(ndims(v)==4 && ndims(H)==4)
dv = size(v,4)
dh = size(H,4)
@assert(dh==1 || dh==dv || dh == Int((dv+1)*dv/2))

if false #size(v,4) == 3 # Special case
v1 = view(v,:,:,:,1)
v2 = view(v,:,:,:,2)
v3 = view(v,:,:,:,3)
u1 = view(u,:,:,:,1)
u2 = view(u,:,:,:,2)
u3 = view(u,:,:,:,3)
if size(H,4) >= 3
h11 = view(H,:,:,:,1)
h22 = view(H,:,:,:,2)
h33 = view(H,:,:,:,3)
if size(H,4) == 6
h12 = view(H,:,:,:,4)
h13 = view(H,:,:,:,5)
h23 = view(H,:,:,:,6)
u1 .+= h11.*v1 .+ h12.*v2 .+ h13.*v3
u2 .+= h12.*v1 .+ h22.*v2 .+ h23.*v3
u3 .+= h13.*v1 .+ h23.*v2 .+ h33.*v3
elseif size(H,4) == 3
u1 .+= h11.*v1
u2 .+= h22.*v2
u3 .+= h33.*v3
else
error()
end
elseif size(H,4) == 1
h11 = view(H,:,:,:,1)
u1 .+= h11.*v1
u2 .+= h11.*v2
u3 .+= h11.*v3
else
error()
end
return u
else # General case
if dh==1
h = view(H,:,:,:,1)
for i=1:dv
view(u,:,:,:,i) .+= h.*view(v,:,:,:,i)
end
elseif dh==dv || dh==Int((dv+1)*dv/2)
for i=1:dv
view(u,:,:,:,i) .+= view(H,:,:,:,i).*view(v,:,:,:,i)
end
if dh==Int((dv+1)*dv/2)
ii = dv
for i=1:dv, j=i+1:dv
ii += 1
h = view(H,:,:,:,ii)
view(u,:,:,:,i) .+= h.*view(v,:,:,:,j)
view(u,:,:,:,j) .+= h.*view(v,:,:,:,i)
end
end
end
end
return u
end

Expand Down Expand Up @@ -213,3 +256,30 @@ function fcycle!(v::VolType, g::VolType, HL::PyramidType; nit_pre::Integer=4, ni
return v
end


function invert_def(phi::T)::T where T<:VolType
d = size(phi)
if length(d)>4
iphi = zero(phi)
for i in CartesianIndices(d[5:end])
iphi[:,:,:,:,i] .= invert_def(phi[:,:,:,:,i])
end
return iphi
else
Id = id(d[1:3]; gpu=~isa(phi,Array))
sett = Settings(1, [2 1 1;1 2 1;1 1 2], 0)
g = push(Id, phi, d[1:3], sett)
o = typeof(g)(undef,(d[1:3]...,1))
o .= 1
h = push(o, phi, d[1:3], Settings(1,1,0))
g .-= h.*Id
H = typeof(g)(undef,(d[1:3]...,3))
H[:,:,:,1:3].=h
HL = hessian_pyramid(H,[1f0,1f0,1f0],[0., 0.01, 0.1, 0.01])
iphi = zero(g)
vcycle!(iphi, g, HL; nit_pre=2, nit_post=2)
iphi .+= Id
return iphi
end
end

0 comments on commit f65bcc5

Please sign in to comment.