Skip to content

Commit

Permalink
Extend
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 22, 2023
1 parent a08386d commit 9273a20
Showing 1 changed file with 69 additions and 16 deletions.
85 changes: 69 additions & 16 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,84 @@ using Enzyme

using EnzymeCore

function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
func.val(prob.dval, alg.val; kwargs...)

Check warning on line 14 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L11-L14

Added lines #L11 - L14 were not covered by tests
else
(func.val(dval, alg.val; kwargs...) for dval in prob.dval)

Check warning on line 16 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L16

Added line #L16 was not covered by tests
end
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing)

Check warning on line 18 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L18

Added line #L18 was not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
return (nothing, nothing)

Check warning on line 22 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
end

# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = deepcopy(res)
dres.u .= 0
cache = (copy(prob.val.A), res, dres.u)
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache)
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
res = func.val(linsolve.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
deepcopy(res)

Check warning on line 31 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L28-L31

Added lines #L28 - L31 were not covered by tests
else
(deepcopy(res) for dval in linsolve.dval)

Check warning on line 33 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L33

Added line #L33 was not covered by tests
end

if EnzymeRules.width(config) == 1
dres.u .= 0

Check warning on line 37 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L36-L37

Added lines #L36 - L37 were not covered by tests
else
for dr in dres
dr.u .= 0
end

Check warning on line 41 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L39-L41

Added lines #L39 - L41 were not covered by tests
end

resvals = if EnzymeRules.width(config) == 1
dres.u

Check warning on line 45 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L44-L45

Added lines #L44 - L45 were not covered by tests
else
(dr.u for dr in dres)

Check warning on line 47 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L47

Added line #L47 was not covered by tests
end

cache = (copy(linsolve.val.A), res, resvals)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)

Check warning on line 51 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
end

function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
A, y, dy = cache
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
A, y, dys = cache

Check warning on line 55 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L54-L55

Added lines #L54 - L55 were not covered by tests

dA = prob.dval.A
db = prob.dval.b
@assert !(typeof(linsolve) <: Const)
@assert !(typeof(linsolve) <: Active)

Check warning on line 58 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L57-L58

Added lines #L57 - L58 were not covered by tests

invprob = LinearProblem(transpose(A), dy)
if EnzymeRules.width(config) == 1
dys = (dys,)

Check warning on line 61 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
end

z = func.val(invprob, alg; kwargs...)
dAs = if EnzymeRules.width(config) == 1
(linsolve.dval.A,)

Check warning on line 65 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L64-L65

Added lines #L64 - L65 were not covered by tests
else
(dval.A for dval in linsolve.dval)

Check warning on line 67 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L67

Added line #L67 was not covered by tests
end

dA .-= z * transpose(y)
db .+= z
dy .= 0
return (nothing, nothing)
dbs = if EnzymeRules.width(config) == 1
(linsolve.dval.b,)

Check warning on line 71 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
else
(dval.b for dval in linsolve.dval)

Check warning on line 73 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L73

Added line #L73 was not covered by tests
end

for (dA, db, dy) in zip(dAs, dbs, dys)
invprob = LinearSolve.LinearProblem(transpose(A), dy)
z = solve(invprob;

Check warning on line 78 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L76-L78

Added lines #L76 - L78 were not covered by tests
abstol = linsolve.val.abstol,
reltol = linsolve.val.reltol,
verbose = linsolve.val.verbose)

dA .-= z * transpose(y)
db .+= z
dy .= eltype(dy)(0)
end

Check warning on line 86 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L83-L86

Added lines #L83 - L86 were not covered by tests

return (nothing,)

Check warning on line 88 in ext/LinearSolveEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveEnzymeExt.jl#L88

Added line #L88 was not covered by tests
end

end

0 comments on commit 9273a20

Please sign in to comment.