Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme extension #377

Merged
merged 20 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -30,6 +31,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Expand All @@ -42,6 +44,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2"
[extensions]
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand Down Expand Up @@ -78,6 +81,8 @@ julia = "1.6"

[extras]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -95,4 +100,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals"]
test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll", "BlockDiagonals", "Enzyme", "FiniteDiff"]
166 changes: 166 additions & 0 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
module LinearSolveEnzymeExt

using LinearSolve
using LinearSolve.LinearAlgebra
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)


using Enzyme

using EnzymeCore

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
func.val(prob.dval, alg.val; kwargs...)

Check warning on line 15 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L12-L15

Added lines #L12 - L15 were not covered by tests
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
func.val(prob.dval[i], alg.val; kwargs...)

Check warning on line 19 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L17-L19

Added lines #L17 - L19 were not covered by tests
end
end
d_A = if EnzymeRules.width(config) == 1
dres.A

Check warning on line 23 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L22-L23

Added lines #L22 - L23 were not covered by tests
else
(dval.A for dval in dres)

Check warning on line 25 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L25

Added line #L25 was not covered by tests
end
d_b = if EnzymeRules.width(config) == 1
dres.b

Check warning on line 28 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L27-L28

Added lines #L27 - L28 were not covered by tests
else
(dval.b for dval in dres)

Check warning on line 30 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L30

Added line #L30 was not covered by tests
end


prob_d_A = if EnzymeRules.width(config) == 1
prob.dval.A

Check warning on line 35 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
else
(dval.A for dval in prob.dval)

Check warning on line 37 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L37

Added line #L37 was not covered by tests
end
prob_d_b = if EnzymeRules.width(config) == 1
prob.dval.b

Check warning on line 40 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L39-L40

Added lines #L39 - L40 were not covered by tests
else
(dval.b for dval in prob.dval)

Check warning on line 42 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L42

Added line #L42 was not covered by tests
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))

Check warning on line 45 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L45

Added line #L45 was not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
d_A, d_b, prob_d_A, prob_d_b = cache

Check warning on line 49 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L48-L49

Added lines #L48 - L49 were not covered by tests

if EnzymeRules.width(config) == 1
if d_A !== prob_d_A
prob_d_A .+= d_A
d_A .= 0

Check warning on line 54 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L51-L54

Added lines #L51 - L54 were not covered by tests
end
if d_b !== prob_d_b
prob_d_b .+= d_b
d_b .= 0

Check warning on line 58 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L56-L58

Added lines #L56 - L58 were not covered by tests
end
else
for i in 1:EnzymeRules.width(config)
if d_A !== prob_d_A[i]
prob_d_A[i] .+= d_A[i]
d_A[i] .= 0

Check warning on line 64 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L61-L64

Added lines #L61 - L64 were not covered by tests
end
if d_b !== prob_d_b[i]
prob_d_b[i] .+= d_b[i]
d_b[i] .= 0

Check warning on line 68 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L66-L68

Added lines #L66 - L68 were not covered by tests
end
end

Check warning on line 70 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L70

Added line #L70 was not covered by tests
end

return (nothing, nothing)

Check warning on line 73 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L73

Added line #L73 was not covered by tests
end

# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
res = func.val(linsolve.val; kwargs...)

Check warning on line 80 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L79-L80

Added lines #L79 - L80 were not covered by tests

dres = if EnzymeRules.width(config) == 1
deepcopy(res)

Check warning on line 83 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L82-L83

Added lines #L82 - L83 were not covered by tests
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
deepcopy(res)

Check warning on line 87 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L85-L87

Added lines #L85 - L87 were not covered by tests
end
end

if EnzymeRules.width(config) == 1
dres.u .= 0

Check warning on line 92 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
else
for dr in dres
dr.u .= 0
end

Check warning on line 96 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L94-L96

Added lines #L94 - L96 were not covered by tests
end

resvals = if EnzymeRules.width(config) == 1
dres.u

Check warning on line 100 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L99-L100

Added lines #L99 - L100 were not covered by tests
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
dres[i].u

Check warning on line 104 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L102-L104

Added lines #L102 - L104 were not covered by tests
end
end

dAs = if EnzymeRules.width(config) == 1
(linsolve.dval.A,)

Check warning on line 109 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L108-L109

Added lines #L108 - L109 were not covered by tests
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
linsolve.dval[i].A

Check warning on line 113 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L111-L113

Added lines #L111 - L113 were not covered by tests
end
end

dbs = if EnzymeRules.width(config) == 1
(linsolve.dval.b,)

Check warning on line 118 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L117-L118

Added lines #L117 - L118 were not covered by tests
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
linsolve.dval[i].b

Check warning on line 122 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L120-L122

Added lines #L120 - L122 were not covered by tests
end
end

cachesolve = deepcopy(linsolve.val)

Check warning on line 126 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L126

Added line #L126 was not covered by tests

cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)

Check warning on line 129 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve, dAs, dbs = cache

Check warning on line 133 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L132-L133

Added lines #L132 - L133 were not covered by tests

@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

Check warning on line 136 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L135-L136

Added lines #L135 - L136 were not covered by tests

if EnzymeRules.width(config) == 1
dys = (dys,)

Check warning on line 139 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L138-L139

Added lines #L138 - L139 were not covered by tests
end

for (dA, db, dy) in zip(dAs, dbs, dys)
z = if _linsolve.cacheval isa Factorization
_linsolve.cacheval' \ dy
elseif _linsolve.cacheval isa Tuple && _linsolve.cacheval[1] isa Factorization
_linsolve.cacheval[1]' \ dy
elseif _linsolve.alg isa AbstractKrylovSubspaceMethod

Check warning on line 147 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L142-L147

Added lines #L142 - L147 were not covered by tests
# Doesn't modify `A`, so it's safe to just reuse it
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
solve(invprob;

Check warning on line 150 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L149-L150

Added lines #L149 - L150 were not covered by tests
abstol = _linsolve.val.abstol,
reltol = _linsolve.val.reltol,
verbose = _linsolve.val.verbose)
else
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")

Check warning on line 155 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L155

Added line #L155 was not covered by tests
end

dA .-= z * transpose(y)
db .+= z
dy .= eltype(dy)(0)
end

Check warning on line 161 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L158-L161

Added lines #L158 - L161 were not covered by tests

return (nothing,)

Check warning on line 163 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L163

Added line #L163 was not covered by tests
end

end
3 changes: 3 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
@require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin
include("../ext/LinearSolveMKLExt.jl")
end
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("../ext/LinearSolveEnzymeExt.jl")

Check warning on line 19 in src/init.jl

View check run for this annotation

Codecov / codecov/patch

src/init.jl#L19

Added line #L19 was not covered by tests
end
end
end
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ end
end
end

test_algs = if VERISON >= v"1.9"
test_algs = if VERSION >= v"1.9"
(LUFactorization(),
QRFactorization(),
SVDFactorization(),
Expand Down
125 changes: 125 additions & 0 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);

function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
norm(s1)
end

f(A, b1) # Uses BLAS

Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1))

dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))

@test dA ≈ dA2
@test db1 ≈ db12

A = rand(n, n);
dA = zeros(n, n);
dA2 = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);

#=
# Batch test fails
# Captured in MWE: https://github.com/EnzymeAD/Enzyme.jl/issues/1075

function fbatch(y, A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
y[1] = norm(s1)
nothing
end

y = [0.0]
dy1 = [1.0]
dy2 = [1.0]
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))

dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))

@test_broken dA ≈ dA_2
@test_broken dA2 ≈ dA_2
@test_broken db1 ≈ db1_2
@test_broken db12 ≈ db1_2
=#

function f(A, b1, b2; alg = LUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);

f(A, b1, b2)
Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1),eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x,eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x->f(eltype(x).(A),eltype(x).(b1),x), copy(b2))

@test dA ≈ dA2
@test db1 ≈ db12
@test db2 ≈ db22

function f2(A, b1, b2; alg = RFLUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f2(A, b1, b2)
dA = zeros(n, n);
db1 = zeros(n);
db2 = zeros(n);
Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

@test dA ≈ dA2
@test db1 ≈ db12
@test db2 ≈ db22

#=
function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))

@test dA ≈ dA2 atol=5e-5
@test db1 ≈ db12
@test db2 ≈ db22
=#
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ const HAS_EXTENSIONS = isdefined(Base, :get_extension)

if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests" include("basictests.jl")
@time @safetestset "Re-solve" include("resolve.jl")
VERSION >= v"1.9" && @time @safetestset "Re-solve" include("resolve.jl")
@time @safetestset "Zero Initialization Tests" include("zeroinittests.jl")
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Traits" include("traits.jl")
end

Expand Down
Loading