Skip to content

Commit

Permalink
Testing: Tests of gradients of push/pull, numerical vs analytic
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnAshburner committed Mar 14, 2024
1 parent f9a7883 commit fce0b8d
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,33 @@ function operator_consistency(d::NTuple{3,Int64}, reg::Vector{<:AbstractFloat},
return sum((u1[:].-u0[:]).^2)./sum(u0[:].^2)
end

using Flux
function test_grad(fun,ϕ)

function numerical_grad(fun, θ₀)
E₀ = fun(θ₀)
g = zero(θ₀)
θ = deepcopy(θ₀)
ϵ = 0.01
for i=1:length(θ₀)
# Used i:i because scalar indexing of CUDA.jl
# arrays is problematic.
θ[i:i] .+= ϵ
E₊ = fun(θ)
θ[i:i] .-= 2ϵ
E₋ = fun(θ)
g[i:i] .= (E₊ - E₋)/(2ϵ)
θ[i:i] .= θ₀[i:i]
end
return g
end

g0 = gradient(fun,ϕ)[1]
return sum((g0-numerical_grad(fun,ϕ)).^2)/sum(g0.^2)
end



@testset "PushPull.jl" begin
# Write your tests here.

Expand Down Expand Up @@ -103,5 +130,81 @@ end

d = (2,3,4)
@test operator_consistency(d,[1e-3, 1.,9.,1.], false) < tol

d = (8,7,1)
c = 2

f1 = randn(Float32,(d...,c))
f2 = randn(Float32,(d...,c))
phi = randn(Float32,(d...,3))
phi .+= PushPull.id(d,false)

tol = 1e-3
sett = PushPull.Settings((1,1,1),(0,1,2),1)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < 1e-2
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < 5e-2
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < 5e-2

sett = PushPull.Settings((2,2,2),(0,1,2),1)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < tol
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < tol
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < tol

sett = PushPull.Settings((3,3,3),(0,1,2),1)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < tol
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < tol
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < tol

g2 = pull_grad(f2, phi, sett)
@test test_grad-> sum((pull_grad(f1, θ, sett) .- g2).^2),phi) < tol

f1 = CuArray(f1)
f2 = CuArray(f2)
phi = CuArray(phi)

sett = PushPull.Settings((1,1,1),(0,1,2),1)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < 1e-2
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < 5e-2
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < 5e-2

sett = PushPull.Settings((2,2,2),(0,1,2),1)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < tol
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < tol
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < tol

sett = PushPull.Settings((3,3,3),(0,1,2),1)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < tol
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < tol
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < tol

bs = 2
f1 = randn(Float32,(d...,c,bs))
f2 = randn(Float32,(d...,c,bs))
phi = randn(Float32,(d...,3,bs))
phi .+= PushPull.id(d,false)

sett = PushPull.Settings((3,3,3),(0,1,2),1)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < tol
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < tol
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < tol

f1 = CuArray(f1)
f2 = CuArray(f2)
phi = CuArray(phi)
@test test_grad-> sum((pull(f1, θ, sett) .- f2).^2),phi) < tol
@test test_grad-> sum((pull(θ,phi, sett) .- f2).^2),f1) < tol
@test test_grad-> sum((push(f1, θ,d,sett) .- f2).^2),phi) < tol
@test test_grad-> sum((push(θ,phi,d,sett) .- f2).^2),f1) < tol

g2 = pull_grad(f2, phi, sett)
@test test_grad-> sum((pull_grad(f1, θ, sett) .- g2).^2),phi) < tol
end
nothing

0 comments on commit fce0b8d

Please sign in to comment.