diff --git a/ext/LinearSolveEnzymeExt.jl b/ext/LinearSolveEnzymeExt.jl
index f38cf56e2..b87f45c85 100644
--- a/ext/LinearSolveEnzymeExt.jl
+++ b/ext/LinearSolveEnzymeExt.jl
@@ -8,31 +8,90 @@ using Enzyme
 
 using EnzymeCore
 
+# Forward pass of the `init` rule: build the primal cache and one shadow cache
+# per shadow problem. No tape is stored (the third AugmentedReturn slot is `nothing`).
+function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
+    res = func.val(prob.val, alg.val; kwargs...)
+    dres = if EnzymeRules.width(config) == 1
+        func.val(prob.dval, alg.val; kwargs...)
+    else
+        # Materialize one shadow cache per lane. A lazy generator here would
+        # re-run `init` on every traversal instead of yielding fixed shadows.
+        # Assumes `prob.dval` is a tuple of shadows in batch mode.
+        map(dval -> func.val(dval, alg.val; kwargs...), prob.dval)
+    end
+    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing)
+end
+
+# Reverse pass of the `init` rule: returns `nothing` for both `prob` and `alg`.
+function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
+    return (nothing, nothing)
+end
+
 # y=inv(A) B
 # dA −= z y^T
 # dB += z, where z = inv(A^T) dy
-function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
-    res = func.val(prob.val, alg.val; kwargs...)
-    dres = deepcopy(res)
-    dres.u .= 0
-    cache = (copy(prob.val.A), res, dres.u)
-    return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, typeof(cache)}(res, dres, cache)
+function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
+    # Snapshot `A` before solving, in case the solve overwrites the cache's
+    # matrix in place (NOTE(review): confirm which algorithms do so).
+    A = copy(linsolve.val.A)
+    res = func.val(linsolve.val; kwargs...)
+
+    # One shadow solution per lane. These must be materialized (not lazy
+    # generators): the arrays zeroed here are the ones handed back to Enzyme and
+    # the ones whose accumulated adjoints the reverse pass reads.
+    dres = if EnzymeRules.width(config) == 1
+        deepcopy(res)
+    else
+        map(_ -> deepcopy(res), linsolve.dval)
+    end
+
+    if EnzymeRules.width(config) == 1
+        dres.u .= 0
+    else
+        for dr in dres
+            dr.u .= 0
+        end
+    end
+
+    resvals = if EnzymeRules.width(config) == 1
+        dres.u
+    else
+        map(dr -> dr.u, dres)
+    end
+
+    cache = (A, res, resvals)
+    return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
 end
 
-function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(LinearSolve.solve)}, ::Type{Duplicated{RT}}, cache, prob::Duplicated{LP}, alg::Const; kwargs...) where {RT, LP <: LinearProblem}
-    A, y, dy = cache
-
-    dA = prob.dval.A
-    db = prob.dval.b
-
-    invprob = LinearProblem(transpose(A), dy)
-
-    z = func.val(invprob, alg; kwargs...)
-
-    dA .-= z * transpose(y)
-    db .+= z
-    dy .= 0
-    return (nothing, nothing)
+# Reverse pass of the `solve!` rule. For each lane, solves the transposed
+# system z = inv(A^T) dy, accumulates dA -= z y^T and db += z into the shadow
+# cache, and then resets dy to zero.
+function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
+    A, y, dys = cache
+
+    @assert !(typeof(linsolve) <: Const)
+    @assert !(typeof(linsolve) <: Active)
+
+    # Normalize single-lane and batch mode to tuples so one loop handles both.
+    nlanes = EnzymeRules.width(config)
+    lane_dys = nlanes == 1 ? (dys,) : dys
+    dAs = nlanes == 1 ? (linsolve.dval.A,) : map(dval -> dval.A, linsolve.dval)
+    dbs = nlanes == 1 ? (linsolve.dval.b,) : map(dval -> dval.b, linsolve.dval)
+
+    for (dA, db, dy) in zip(dAs, dbs, lane_dys)
+        invprob = LinearSolve.LinearProblem(transpose(A), dy)
+        z = solve(invprob;
+            abstol = linsolve.val.abstol,
+            reltol = linsolve.val.reltol,
+            verbose = linsolve.val.verbose)
+
+        dA .-= z * transpose(y)
+        db .+= z
+        fill!(dy, zero(eltype(dy)))
+    end
+
+    return (nothing,)
 end
 
 end
\ No newline at end of file