Skip to content

Commit

Permalink
test: separate out the enzyme testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 26, 2024
1 parent 574b0d8 commit ef14bab
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 25 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
version: ['1']
group:
- Core
- Enzyme
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- "LinearSolveHYPRE"
- "LinearSolvePardiso"
- "LinearSolveBandedMatrices"
- "Enzyme"
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
group: "${{ matrix.group }}"
Expand Down
41 changes: 18 additions & 23 deletions ext/LinearSolveEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module LinearSolveEnzymeExt
using LinearSolve
using LinearSolve.LinearAlgebra
using EnzymeCore
using EnzymeCore: EnzymeRules

function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfigWidth{1},
function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@assert !(prob isa Const)
Expand All @@ -19,26 +20,20 @@ function EnzymeCore.EnzymeRules.forward(config::EnzymeCore.EnzymeRules.FwdConfig
dres = func.val(prob.dval, alg.val; kwargs...)
dres.b .= res.b == dres.b ? zero(dres.b) : dres.b
dres.A .= res.A == dres.A ? zero(dres.A) : dres.A
if RT <: DuplicatedNoNeed
return dres
elseif RT <: Duplicated
return Duplicated(res, dres)
end
error("Unsupported return type $RT")

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
Duplicated(res, dres)
return Duplicated(res, dres)
elseif EnzymeRules.needs_shadow(config)
dres
return dres
elseif EnzymeRules.needs_primal(config)
res
return res
else
nothing
return nothing
end
end

function EnzymeCore.EnzymeRules.forward(
config::EnzymeCore.EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1}, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
@assert !(linsolve isa Const)
Expand Down Expand Up @@ -66,17 +61,17 @@ function EnzymeCore.EnzymeRules.forward(
linsolve.val.b = b

if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
Duplicated(res, dres)
return Duplicated(res, dres)
elseif EnzymeRules.needs_shadow(config)
dres
return dres
elseif EnzymeRules.needs_primal(config)
res
return res
else
nothing
return nothing
end
end

function EnzymeCore.EnzymeRules.augmented_primal(
function EnzymeRules.augmented_primal(
config, func::Const{typeof(LinearSolve.init)},
::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down Expand Up @@ -111,10 +106,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
(dval.b for dval in prob.dval)
end

return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
return EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
end

function EnzymeCore.EnzymeRules.reverse(
function EnzymeRules.reverse(
config, func::Const{typeof(LinearSolve.init)}, ::Type{RT},
cache, prob::EnzymeCore.Annotation{LP}, alg::Const;
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
Expand Down Expand Up @@ -148,7 +143,7 @@ end
# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
function EnzymeCore.EnzymeRules.augmented_primal(
function EnzymeRules.augmented_primal(
config, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
Expand Down Expand Up @@ -201,10 +196,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
cachesolve = deepcopy(linsolve.val)

cache = (copy(res.u), resvals, cachesolve, dAs, dbs)
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
return EnzymeRules.AugmentedReturn(res, dres, cache)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP};
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys, _linsolve, dAs, dbs = cache
Expand Down
1 change: 0 additions & 1 deletion test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Enzyme, ForwardDiff
using LinearSolve, LinearAlgebra, Test
using FiniteDiff
using SafeTestsets

n = 4
A = rand(n, n);
Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ if GROUP == "All" || GROUP == "Core"
@time @safetestset "Non-Square Tests" include("nonsquare.jl")
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "Traits" include("traits.jl")
@time @safetestset "BandedMatrices" include("banded.jl")
@time @safetestset "Static Arrays" include("static_arrays.jl")
end

if GROUP == "All" || GROUP == "Enzyme"
@time @safetestset "Enzyme Derivative Rules" include("enzyme.jl")
end

if GROUP == "LinearSolveCUDA"
Pkg.activate("gpu")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Expand Down

0 comments on commit ef14bab

Please sign in to comment.