Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme extension #377

Merged
merged 20 commits into from
Sep 24, 2023
Merged

Add Enzyme extension #377

merged 20 commits into from
Sep 24, 2023

Conversation

wsmoses
Copy link
Contributor

@wsmoses wsmoses commented Sep 22, 2023

requires current Enzyme main for a custom rules fix

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023

Sample call:

using Enzyme

using LinearSolve, LinearAlgebra

n = 4
A = rand(n, n);
dA = zeros(n, n);
b1 = rand(n);
db1 = zeros(n);
b2 = rand(n);
db2 = zeros(n);

function f(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)

    sol1 = solve(prob, alg)

    s1 = sol1.u
    norm(s1)
end

f(A, b1, b2) # Uses BLAS

Enzyme.autodiff(Reverse, f, Duplicated(A, dA), Duplicated(b1, db1), Duplicated(b2, db2))

@show dA, db1, db2

@codecov
Copy link

codecov bot commented Sep 22, 2023

Codecov Report

Merging #377 (89e10df) into main (5a25b7d) will increase coverage by 48.24%.
Report is 6 commits behind head on main.
The diff coverage is 1.11%.

@@             Coverage Diff             @@
##             main     #377       +/-   ##
===========================================
+ Coverage   20.01%   68.25%   +48.24%     
===========================================
  Files          14       24       +10     
  Lines        1444     1884      +440     
===========================================
+ Hits          289     1286      +997     
+ Misses       1155      598      -557     
Files Changed Coverage Δ
ext/LinearSolveEnzymeExt.jl 0.00% <0.00%> (ø)
src/init.jl 57.14% <50.00%> (-17.86%) ⬇️

... and 22 files with indirect coverage changes

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@ChrisRackauckas ChrisRackauckas force-pushed the master branch 2 times, most recently from 004eb9d to 70e8599 Compare September 22, 2023 13:16
@ChrisRackauckas
Copy link
Member

It looks like this only handles the case of solve, but not solve!. So I presume this case would still not work:

using LinearSolve, LinearAlgebra
# using MKL_jll

n = 100
A = rand(n, n)
b1 = rand(n);
b2 = rand(n);

function f(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)

    linsolve = init(prob, alg)
    sol1 = solve!(linsolve)

    s1 = copy(sol1.u)

    linsolve.b = b2
    sol2 = solve!(linsolve)

    s2 = copy(sol2.u)
    norm(s1 + s2)
end

f(A, b1, b2) # Uses BLAS
f(A, b1, b2; alg=RFLUFactorization()) # Uses loops
f(A, b1, b2; alg=MKLLUFactorization()) # Requires `using MKL_jll`

using Enzyme

dA = zero(A)
db1 = zero(b1)
db2 = zero(b2)
Enzyme.autodiff(Reverse, f, Duplicated(A,dA), 
                Duplicated(b1, db1), Duplicated(b2, db2))

which is EnzymeAD/Enzyme.jl#1065.

I at least added a test for the solve case, but the most common case is on solve! so it would be good to figure out how to do that. It's the same thing except solve!(cache) has cache.A and cache.b1, where cache.isfresh == true means A is already factorized. Is there a way to define the derivative w.r.t. fields of the mutable cache? Or should this be done with a solve!_up type thing?

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023

Pushed extension for solve! and init now.

While was at it, also added batch mode support.

Comment on lines 11 to 23
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
res = func.val(prob.val, alg.val; kwargs...)
dres = if EnzymeRules.width(config) == 1
func.val(prob.dval, alg.val; kwargs...)
else
(func.val(dval, alg.val; kwargs...) for dval in prob.dval)
end
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
return (nothing, nothing)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one required? It seems like it doesn't do much?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Init hits that global variable stuff, so we need a rule for corresponding shadow initialization.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see

@ChrisRackauckas
Copy link
Member

While was at it, also added batch mode support.

What in here was required for batch mode support?

(dr.u for dr in dres)
end

cache = (copy(linsolve.val.A), res, resvals)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this copy necessary?

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023

Not specializing to just duplicated but also supporting batchduplicated, which has dval as a tuple of shadows

@ChrisRackauckas
Copy link
Member

As a tuple, does that have an issue scaling to say batch of a 100 or 1000 things?

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023 via email

end

for (dA, db, dy) in zip(dAs, dbs, dys)
invprob = LinearSolve.LinearProblem(transpose(A), dy)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the forward pass the matrix A is factorized, so in theory we don't need to factorize it again, just transpose A from the forward pass. Is there a way to grab that?

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023 via email

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023 via email

@ChrisRackauckas
Copy link
Member

The key that I'm pointing out here is similar to the top of https://docs.sciml.ai/LinearSolve/stable/tutorials/caching_interface/. But here, what solve! is doing is solving:

_A = lu!(A)
_A \ b1

and then the backpass is:

_At = lu!(A')
_At \ db1

but we also have that (essentially) _At = _A', or at least it can be computed in O(n) time, whereas a factorization is O(n^3) and thus lu! is one of the most expensive operations.

So what I'm wondering is if it's safe to assume that linsolve is the same linsolve object from the forward pass, or if it may have been further mutated.

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023

The key that I'm pointing out here is similar to the top of https://docs.sciml.ai/LinearSolve/stable/tutorials/caching_interface/. But here, what solve! is doing is solving:

_A = lu!(A)

_A \ b1

and then the backpass is:

_At = lu!(A')

_At \ db1

but we also have that (essentially) _At = _A', or at least it can be computed in O(n) time, whereas a factorization is O(n^3) and thus lu! is one of the most expensive operations.

So what I'm wondering is if it's safe to assume that linsolve is the same linsolve object from the forward pass, or if it may have been further mutated.

It's the same Julia object, but it's possible it's fields may have been modified. If it's immutable, then it's the same.

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 22, 2023

Even if it's overwritten, however, you can still add whatever is relevant from he LU into the cache and use that as a starting point

@ChrisRackauckas
Copy link
Member

Awesome, I'll leave that as a follow-up, no need to handle it in this PR. But the tests do need to get fixed.

@ChrisRackauckas
Copy link
Member

The transpose of the factorization is the factorization of the transpose:

using LinearAlgebra
A = rand(4,4)
luA = lu(A)

At = transpose(A)
luAt = lu(At)

b = rand(4)

x  = A \ b
x2 = A' \ b
x3 = luA \ b
x4 = luAt \ b
x5 = luA' \ b

x  x3
x2  x4  x5

Confirmed from https://web.mit.edu/18.06/www/Spring17/Transposes.pdf. We can use this to generalize and optimize a bit.


function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
y, dys = cache
_linsolve = linsolve.val
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still wrong, because linsolve still couldve been overwritten from forward to reverse. You need to cache it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay was just about to ask that, thanks. I think with that this may be completed. Though check the batch syntax in the test: the test still errors with BatchDuplicated and I'm not sure what to do there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the error log from?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ERROR: TypeError: in ccall argument 6, expected Tuple{Float64, Float64}, got a value of type Float64
Stacktrace:
 [1] macro expansion
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9774 [inlined]
 [2] enzyme_call
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9452 [inlined]
 [3] CombinedAdjointThunk
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:9415 [inlined]
 [4] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:213 [inlined]
 [5] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:236 [inlined]
 [6] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::BatchDuplicated{Matrix{Float64}, 2}, ::BatchDuplicated{Vector{Float64}, 2})
   @ Enzyme C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:222
 [7] top-level scope
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh thats an easy one [which we sohuld fix]. You can't use an active return right now in batch mode (which also makes little sense here since you'd back propagate the same value to each). Just wrap that func in a closure that stores it to a vector or something

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, yeah the test was a bit dumb but just a quick sanity check 😓. Fixing that gives:

ERROR: Enzyme execution failed.
Enzyme: Augmented forward pass custom rule Tuple{EnzymeCore.EnzymeRules.ConfigWidth{2, true, true, (false, false, false)}, Const{typeof(init)}, Type{BatchDuplicated{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, 2}}, BatchDuplicated{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, 2}, Const{LUFactorization{RowMaximum}}} return type mismatch, expected EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Tuple{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}}, Any} found EnzymeCore.EnzymeRules.AugmentedReturn{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LUFactorization{RowMaximum}, LU{Float64, Matrix{Float64}, Vector{Int64}}, IdentityOperator, IdentityOperator, Float64, Bool}, Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, Tuple{Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#3#6"}, Base.Generator{Base.Generator{Tuple{LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, LinearProblem{Nothing, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}, LinearSolveEnzymeExt.var"#2#5"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Const{typeof(init)}, Const{LUFactorization{RowMaximum}}}}, LinearSolveEnzymeExt.var"#4#7"}}}
Stacktrace:
 [1] #solve#5
   @ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:193
 [2] solve
   @ C:\Users\accou\.julia\dev\LinearSolve\src\common.jl:190
 [3] #fbatch#207
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:39
 [4] fbatch
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:36
 [5] fbatch
   @ c:\Users\accou\.julia\dev\LinearSolve\test\enzyme.jl:0

Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\compiler.jl:3066

@ChrisRackauckas
Copy link
Member

The solving twice tests are a bit odd:

julia> db1
4-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0

julia> db2
4-element Vector{Float64}:
  2.1215949279204196
 -3.7095838683317943
 -1.2286715744423384
  5.967859589815037

It doubles db2 and has db1 = 0. I think it's because the solve!(linsolve).u aliases between the two. The forward pass is fine because of the copy, but the Enzyme rule likely needs to copy something as well?

@ChrisRackauckas
Copy link
Member

We can skip over that last test to merge, but do you know why that one algorithm would be treated so differently by Enzyme? I would've thought it didn't care if we're capturing stuff in rules, but it treats this algorithm particularly differently:

https://github.com/SciML/LinearSolve.jl/actions/runs/6290016689/job/17077077461?pr=377#step:6:807

@ChrisRackauckas ChrisRackauckas merged commit 1e6150e into SciML:main Sep 24, 2023
15 of 17 checks passed
@wsmoses wsmoses deleted the master branch September 24, 2023 16:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants