Skip to content

Commit

Permalink
push batch test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 24, 2023
1 parent b0d228d commit c2ad2db
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,20 @@ b1 = rand(n);
db1 = zeros(n);
db12 = zeros(n);

@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))
function fbatch(y, A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
y[1] = norm(s1)
nothing
end

y = [0.0]
dy1 = [1.0]
dy2 = [1.0]
Enzyme.autodiff(Reverse, fbatch, BatchDuplicated(y, (dy1, dy2)), BatchDuplicated(copy(A), (dA, dA2)), BatchDuplicated(copy(b1), (db1, db12)))

dA2 = ForwardDiff.gradient(x->f(x,eltype(x).(b1)), copy(A))
db12 = ForwardDiff.gradient(x->f(eltype(x).(A),x), copy(b1))
Expand Down Expand Up @@ -92,7 +105,7 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
function f3(A, b1, b2; alg = KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = solve!(cache).u
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
Expand Down

0 comments on commit c2ad2db

Please sign in to comment.