Skip to content

Commit

Permalink
add again changes made on website which got lost in a local rebase wi…
Browse files Browse the repository at this point in the history
…thout checking first because I forgot about this for ages
  • Loading branch information
mcabbott committed Oct 17, 2022
1 parent 2bb637a commit e451d15
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "1"
Functors = "0.3"
Yota = "0.7.3"
Yota = "0.8.1"
Zygote = "0.6.40"
julia = "1.6"

Expand Down
19 changes: 9 additions & 10 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,15 @@ end
end

@testset "using Yota" begin
@test_broken Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] # Unexpected expression: $(Expr(:static_parameter, 1))
# These are all broken!
@test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
@test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], ZeroTangent())
@test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = ZeroTangent(), z = [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = ZeroTangent(), z = [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
@test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])

g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1]
@test g5.a[1].x == [0,0,1]
@test g5.a[2] === ZeroTangent()
@test g5.a[2] === nothing

g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1]
@test g6.a == [0,0,0]
Expand All @@ -128,7 +127,7 @@ end
@test g8[3] == [[10.0]]

g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
@test g9.c === ZeroTangent()
@test g9.c === nothing
end
end

Expand Down Expand Up @@ -199,11 +198,11 @@ end
@test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]

v8, re8 = destructure(m8)
@test_broken Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] # MethodError: no method matching zero(::Type{Any})
@test_broken Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] # MethodError: no method matching !(::Expr)
@test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
@test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

re9 = destructure(m9)[2]
@test_broken Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] # MethodError: no method matching zero(::Type{Array})
@test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
end
end

Expand Down
5 changes: 4 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ end

# Make Yota's output look like Zygote's:

Yota_gradient(f, xs...) = Base.tail(Yota.grad(f, xs...)[2])
Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
y2z(x) = x

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin
Expand Down

0 comments on commit e451d15

Please sign in to comment.