diff --git a/test/destructure.jl b/test/destructure.jl index fee24368..2cffcf77 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -106,16 +106,15 @@ end end @testset "using Yota" begin - @test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1)) - # These are all broken! + @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) - @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent()) - @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0]) - @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) + @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1] @test g5.a[1].x == [0,0,1] - @test g5.a[2] === ZeroTangent() + @test g5.a[2] === nothing g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1] @test g6.a == [0,0,0] @@ -128,7 +127,7 @@ end @test g8[3] == [[10.0]] g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1] - @test g9.c === ZeroTangent() + @test g9.c ===nothing end end @@ -199,11 +198,11 @@ end @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] v8, re8 = destructure(m8) - @test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any}) - @test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr) + @test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] + @test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] re9 = destructure(m9)[2] - @test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array}) + @test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] end end