From 86a23bbe9ff67decfd18315e95f51b7ca962dbb3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 17 Oct 2022 14:40:21 -0400 Subject: [PATCH] add again changes made on website which got lost in a local rebase without checking first because I forgot about this for ages --- Project.toml | 2 +- test/destructure.jl | 19 +++++++++---------- test/runtests.jl | 5 ++++- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 7cf05fad..ef4a65da 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" Functors = "0.3, 0.4" -Yota = "0.7.3" +Yota = "0.8.1" Zygote = "0.6.40" julia = "1.6" diff --git a/test/destructure.jl b/test/destructure.jl index fee24368..592b9be9 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -106,16 +106,15 @@ end end @testset "using Yota" begin - @test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1)) - # These are all broken! + @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) - @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent()) - @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0]) - @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) + @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) + @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1] @test g5.a[1].x == [0,0,1] - @test g5.a[2] === ZeroTangent() + @test g5.a[2] === nothing g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1] @test g6.a == [0,0,0] @@ -128,7 +127,7 @@ end @test g8[3] == [[10.0]] g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1] - @test g9.c === ZeroTangent() + @test g9.c === nothing end end @@ -199,11 +198,11 @@ end @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] v8, re8 = destructure(m8) - @test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any}) - @test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr) + @test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] + @test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] re9 = destructure(m9)[2] - @test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array}) + @test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] end end diff --git a/test/runtests.jl b/test/runtests.jl index 294f5a6f..23d474c0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,7 +39,10 @@ end # Make Yota's output look like Zygote's: -Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2]) +Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2])) +y2z(::AbstractZero) = nothing # we don't care about different flavours of zero +y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples! +y2z(x) = x @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin