diff --git a/test/enzyme.jl b/test/enzyme.jl index 8f6d213c0..dbacb70f1 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -27,4 +27,18 @@ db12 = FiniteDiff.finite_difference_gradient(x->f(A,x, b2), copy(b1)) @test dA ≈ dA2 @test db1 ≈ db12 -@test db2 == zeros(4) \ No newline at end of file +@test db2 == zeros(4) + +A = rand(n, n); +dA = zeros(n, n); +dA2 = zeros(n, n); +b1 = rand(n); +db1 = zeros(n); +db12 = zeros(n); + +b2 = rand(n); +db2 = zeros(n); +db22 = zeros(n); + +@test_broken Enzyme.autodiff(Reverse, f, Duplicated(copy(A), dA), BatchDuplicated(copy(b1), (db1, db12)), BatchDuplicated(copy(b2), (db2, db22))) +@test_broken Enzyme.autodiff(Reverse, f, BatchDuplicated(copy(A), (dA, dA2)), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2)) \ No newline at end of file