-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Internal error while using Zygote #180
Comments
Code without having to open a file
#
# Repro case code was based on code from: https://github.com/awf/autodiff
#
using Pkg
using Printf
using SpecialFunctions
using LinearAlgebra
using Zygote
using Zygote: @adjoint
"""
    Wishart(gamma, m)

Parameters of the Wishart prior consumed by `log_wishart_prior`.

# Fields
- `gamma::Float64`: read as `wishart.gamma` (inside a `log` and squared) by the prior.
- `m::Int`: read as `wishart.m`; added to the dimension when the prior forms its `n`
  and used as the weight on the summed log-diagonals.
"""
struct Wishart
    gamma::Float64
    m::Int
end
"""
    pack(alphas, means, icf)

Flatten `alphas`, `means` and `icf` (each in column-major order) and concatenate
them, in that order, into a single vector. `unpack` performs the reverse split.
"""
function pack(alphas, means, icf)
    return vcat(alphas[:], means[:], icf[:])
end
"""
    unpack(d, k, packed)

Split the flat vector `packed` into `(alphas, means, icf)`:

- `alphas`: the first `k` entries, reshaped to `1 × k`;
- `means`: the next `d * k` entries, reshaped to `d × k`;
- `icf`: everything that remains, reshaped to `div(d * (d + 1), 2) × k`.
"""
function unpack(d, k, packed)
    nmeans = d * k
    icf_sz = div(d * (d + 1), 2)

    alphas = reshape(packed[1:k], 1, k)
    means = reshape(packed[(k + 1):(k + nmeans)], d, k)
    icf = reshape(packed[(k + nmeans + 1):end], icf_sz, k)

    return (alphas, means, icf)
end
"""
    sumsq(v)

Return the sum of `abs2` over every element of `v`.
"""
function sumsq(v)
    return sum(abs2, v)
end
"""
    ltri_unpack(D, LT)

Assemble a `d × d` matrix (`d = length(D)`) whose diagonal is `D`, whose strictly
lower triangle is filled row by row from `LT`, and whose upper triangle is zero.

Row `r` takes its `r - 1` sub-diagonal entries from
`LT[div((r-1)*(r-2), 2) .+ (1:r-1)]`.
"""
function ltri_unpack(D, LT)
    d = length(D)

    # Number of LT entries consumed by the rows above row `r`.
    offset(r::Int) = div((r - 1) * (r - 2), 2)

    # Row `r` as a 1 × d matrix: sub-diagonal entries, the diagonal entry, zero padding.
    function build_row(r::Int)
        below = LT[offset(r) .+ (1:r-1)]
        left = reshape([below[i] for i in 1:r-1], 1, r - 1)
        return hcat(left, D[r], zeros(1, d - r))
    end

    return vcat([build_row(r) for r in 1:d]...)
end
"""
    get_Q(d, icf)

Build one matrix from the parameter vector `icf` via `ltri_unpack`: the first `d`
entries are exponentiated and used as the diagonal, the remaining entries fill the
strictly lower triangle.
"""
function get_Q(d, icf)
    diagonal = exp.(icf[1:d])
    strict_lower = icf[d+1:end]
    return ltri_unpack(diagonal, strict_lower)
end
"""
    log_gamma_distrib(a, p)

Return `0.25 * p * (p - 1) * log(pi)` plus the sum over `j = 1:p` of
`lgamma(a + 0.5 * (1 - j))`.
"""
function log_gamma_distrib(a, p)
    # Float64 value of log(pi), kept as a literal exactly as in the original code.
    log_pi = 1.1447298858494002
    total = 0.25 * p * (p - 1) * log_pi
    for j in 1:p
        total += lgamma(a + 0.5 * (1 - j))
    end
    return total
end
"""
    log_wishart_prior(wishart::Wishart, sum_qs, Qs)

Return the Wishart-prior term of the objective:

    0.5 * gamma^2 * sum(abs2, Qs) - m * sum(sum_qs) - k * C

where `gamma` and `m` come from `wishart`, and `C` is a constant built from the
matrix dimension `p`, `n = p + m + 1` and `log_gamma_distrib(0.5 * n, p)`.

# Arguments
- `wishart`: prior parameters.
- `sum_qs`: per-component sums; only their total is used here.
- `Qs`: numeric array of stacked matrices. `p` is taken from its first dimension
  and the component count `k` from its third, so a plain `p × p` matrix counts as
  a single component.
"""
function log_wishart_prior(wishart::Wishart, sum_qs, Qs)
    # `Qs` is a numeric array (it is passed to `sum(abs2, Qs)`), so `Qs[1]` is a
    # scalar and `size(Qs[1], 1)` is always 1. Read the dimensions from the array.
    p = size(Qs, 1)
    # Derived locally instead of reading an untyped global `k`.
    k = size(Qs, 3)
    n = p + wishart.m + 1
    C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p)
    # Squared Frobenius norm accumulated over every entry of `Qs`.
    frobenius = sum(abs2, Qs)
    return 0.5 * wishart.gamma^2 * frobenius - wishart.m * sum(sum_qs) - k * C
end
"""
    logsumexp(x)

Return `log(sum(exp.(x)))`, computed by subtracting `maximum(x)` before
exponentiating so that large entries do not overflow.

The original author's note says the input should be one-dimensional; the
computation itself reduces over every element of `x`.

If the maximum is not finite (`Inf`, `-Inf` or `NaN`) it is returned as-is, because
the shifted form would evaluate `Inf - Inf` and yield `NaN`.
"""
function logsumexp(x)
    mx = maximum(x)
    # Guard: all entries -Inf should give -Inf, any +Inf entry should give +Inf.
    isfinite(mx) || return mx
    return log(sum(exp.(x .- mx))) + mx
end
"""
    diagsums(Qs)

For every slice `Qs[:, :, i]`, return the sum of its diagonal. The result comes
from `mapslices` over dimensions 1 and 2, so those two dimensions collapse to
size 1.
"""
function diagsums(Qs)
    return mapslices(Qs; dims=[1, 2]) do block
        sum(diag(block))
    end
end
# Custom adjoint for `diagsums`: the incoming cotangent of slice `i` is written to
# every diagonal position `(j, j, i)`; all off-diagonal entries stay zero.
@adjoint function diagsums(Qs)
    function diagsums_pullback(Δ)
        grad = zero(Qs)
        ndiag = size(Qs, 1)
        # `enumerate` numbers the entries of Δ 1, 2, … in iteration order; that
        # counter is used as the slice index.
        for (slice, δ) in enumerate(Δ), j in 1:ndiag
            grad[j, j, slice] = δ
        end
        return (grad,)
    end
    diagsums(Qs), diagsums_pullback
end
"""
    expdiags(Qs)

Return an array like `Qs` in which the diagonal of every slice `Qs[:, :, i]` has
been replaced by its element-wise exponential; off-diagonal entries are unchanged.
"""
function expdiags(Qs)
    return mapslices(Qs; dims=[1, 2]) do block
        idx = diagind(block)
        block[idx] .= exp.(block[idx])
        block
    end
end
# Custom adjoint for `expdiags`: the cotangent passes through unchanged off the
# diagonal, and is scaled by `exp(Qs[j, j, i])` on the diagonal.
@adjoint function expdiags(Qs)
    function expdiags_pullback(Δ)
        grad = zero(Qs)
        grad .= Δ
        for slice in 1:size(Qs, 3), j in 1:size(Qs, 1)
            grad[j, j, slice] *= exp(Qs[j, j, slice])
        end
        return (grad,)
    end
    expdiags(Qs), expdiags_pullback
end
# Makes `Float64 * nothing` evaluate to `nothing`.
# NOTE(review): this is type piracy — neither `*`, `Float64` nor `Nothing` belongs
# to this script, so the method changes behaviour for every package in the session.
# Presumably a workaround for `nothing` gradients during differentiation; confirm
# it is still required before keeping it.
function Base.:*(::Float64, ::Nothing)
    return nothing
end
"""
    gmm_objective(alphas, means, Qs, x, wishart::Wishart)

Evaluate the objective

    CONSTANT + slse - n * logsumexp(alphas) + log_wishart_prior(wishart, sum_qs, Q)

where `Q = expdiags(Qs)`, `sum_qs` holds the per-slice diagonal sums of the
*un-exponentiated* `Qs`, and `slse` accumulates, over every column of `x`, the log
of the sum over components of
`exp(-0.5 * ‖Q[:, :, ik] * (x[:, ix] - means[:, ik])‖² + alphas[ik] + sum_qs[ik])`.

# Arguments
- `alphas`: per-component values, indexed linearly by component.
- `means`: matrix with one column per component; its column count sets `k`.
- `Qs`: three-dimensional array with one slice `Qs[:, :, ik]` per component.
- `x`: data matrix, one observation per column (`d × n`).
- `wishart`: prior parameters forwarded to `log_wishart_prior`.
"""
function gmm_objective(alphas, means, Qs, x, wishart::Wishart)
    d = size(x, 1)
    n = size(x, 2)
    # Component count taken from the arguments instead of an untyped global `k`.
    k = size(means, 2)
    CONSTANT = -n * d * 0.5 * log(2 * pi)

    sum_qs = reshape(diagsums(Qs), 1, size(Qs, 3))
    # Bound to a fresh name: reassigning `Qs` while the closure below captures it
    # would force the captured variable to be boxed.
    Qexp = expdiags(Qs)

    slse = 0.0
    for ix in 1:n
        formula(ik) = -0.5 * sum(abs2, Qexp[:, :, ik] * (x[:, ix] .- means[:, ik]))
        sumexp = 0.0
        for ik in 1:k
            sumexp += exp(formula(ik) + alphas[ik] + sum_qs[ik])
        end
        slse += log(sumexp)
    end

    prior = log_wishart_prior(wishart, sum_qs, Qexp)
    return CONSTANT + slse - n * logsumexp(alphas) + prior
end
# Random problem instance: 10 mixture components in 10 dimensions, 10000 data points.
# alphas is 1 × 10, means is 10 × 10 (one column per component).
alphas = randn(1,10)
means = rand(10,10)
# One column of 55 parameters per component; get_Q reads the first d entries as the
# log-diagonal and the rest as the strictly lower triangle.
icf = randn(55,10)
# Data matrix, one observation per column.
x = randn(10,10000)
wishart = Wishart(1.0,0)
d = size(means,1)
# NOTE: `k` is also read as a global inside gmm_objective and log_wishart_prior.
k = size(means,2)
n = size(x,2)
# Stack the per-component matrices from get_Q along the third dimension.
const Qs = cat([get_Q(d,icf[:,ik]) for ik in 1:k]...; dims=[3])
# Objective
# Call once in case of precompilation etc
err = gmm_objective(alphas,means,Qs,x,wishart)
# Objective as a function of the differentiated parameters only; the data `x` and
# the prior `wishart` are taken from the enclosing (global) scope.
wrapper_gmm_objective(alphas, means, Qs) = gmm_objective(alphas, means, Qs, x, wishart)
# Gradient
# `g` returns the gradients of the wrapped objective with respect to
# (alphas, means, Qs); `J` holds them for the instance defined above.
# The stray trailing `|` after the call was removed: it left the statement
# syntactically incomplete.
g = (alphas, means, Qs) -> Zygote.gradient(wrapper_gmm_objective, alphas, means, Qs)
J = g(alphas, means, Qs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The attached repro file generates an internal error.
This was discovered while playing around with Zygote, trying to adapt a benchmark that works with ForwardDiff and Flux.Tracker. It was run on Windows 10.
This was initially opened against Julia, but Keno requested that it be moved to Zygote for further triage.
The text was updated successfully, but these errors were encountered: