#
# Repro case code was based on code from: https://github.com/awf/autodiff
#
using Pkg
using Printf
using SpecialFunctions
using LinearAlgebra
using Zygote
using Zygote: @adjoint

"""
    Wishart(gamma, m)

Hyper-parameters read by `log_wishart_prior` (`gamma` scales the quadratic
penalty, `m` weights the `sum_qs` term and enters the normalisation constant).
"""
struct Wishart
    gamma::Float64
    m::Int
end

"""
    pack(alphas, means, icf)

Flatten the three parameter arrays and concatenate them into one vector, in the
order `alphas`, `means`, `icf`.
"""
function pack(alphas, means, icf)
    return [alphas[:]; means[:]; icf[:]]
end

"""
    unpack(d, k, packed)

Inverse of `pack`: split `packed` into `alphas` (`1 x k`), `means` (`d x k`) and
`icf` (`d*(d+1)/2 x k`) and return them as a tuple.
"""
function unpack(d, k, packed)
    alphas = reshape(packed[1:k], 1, k)
    off = k
    means = reshape(packed[(1:d*k) .+ off], d, k)
    icf_sz = div(d * (d + 1), 2)
    off += d * k
    icf = reshape(packed[off+1:end], icf_sz, k)
    return (alphas, means, icf)
end

"""
    sumsq(v)

Return the sum of squared magnitudes of the elements of `v`.
"""
sumsq(v) = sum(abs2, v)

"""
    ltri_unpack(D, LT)

Build a `d x d` lower-triangular matrix (`d = length(D)`) with `D` on the
diagonal and the entries of `LT` filling the strict lower triangle row by row
(row `r` consumes `r - 1` consecutive entries of `LT`).
"""
function ltri_unpack(D, LT)
    d = length(D)
    # Number of strict-lower-triangle entries that precede row `r` in `LT`.
    row_start(r::Int) = div((r - 1) * (r - 2), 2)
    inds(r) = row_start(r) .+ (1:r-1)
    # Row r = [r-1 packed entries, diagonal entry, zero padding up to width d].
    make_row(r::Int, L) = hcat(reshape(L, 1, r - 1), D[r], zeros(1, d - r))
    return reduce(vcat, [make_row(r, LT[inds(r)]) for r in 1:d])
end

"""
    get_Q(d, icf)

Unpack one column of packed parameters into a lower-triangular matrix: the first
`d` entries are exponentiated to form the diagonal, the remaining entries fill
the strict lower triangle (see `ltri_unpack`).
"""
function get_Q(d, icf)
    return ltri_unpack(exp.(icf[1:d]), icf[d+1:end])
end

"""
    log_gamma_distrib(a, p)

Return `p*(p-1)/4 * log(pi) + sum(lgamma(a + (1 - j)/2) for j in 1:p)`, i.e. the
logarithm of the multivariate gamma function of order `p` evaluated at `a`.
"""
function log_gamma_distrib(a, p)
    out = 0.25 * p * (p - 1) * 1.1447298858494002  # log(pi) as a Float64 literal
    for j in 1:p
        out += lgamma(a + 0.5 * (1 - j))
    end
    return out
end

"""
    log_wishart_prior(wishart, sum_qs, Qs)

Evaluate the Wishart prior term of the objective.

`Qs` is a `d x d x k` array with one matrix per component along the third
dimension; `sum_qs` holds one value per component. Returns
`0.5*gamma^2*sum(abs2, Qs) - m*sum(sum_qs) - k*C`, where `C` is the
normalisation constant computed from `wishart` and the matrix dimension.
"""
function log_wishart_prior(wishart::Wishart, sum_qs, Qs)
    # `Qs` is a 3-D array, so `Qs[1]` is a scalar element and `size(Qs[1], 1)`
    # would always be 1; the matrix dimension and the component count must be
    # taken from the array's own axes.
    p = size(Qs, 1)
    k = size(Qs, 3)
    n = p + wishart.m + 1
    C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p)
    frobenius = sum(abs2, Qs)
    return 0.5 * wishart.gamma^2 * frobenius - wishart.m * sum(sum_qs) - k * C
end

"""
    logsumexp(x)

Return `log(sum(exp.(x)))`, shifting by `maximum(x)` before exponentiating.
The original note says the input should be 1-dimensional; the reduction runs
over all elements of `x`.
"""
function logsumexp(x)
    mx = maximum(x)
    return log(sum(exp.(x .- mx))) + mx
end

"""
    diagsums(Qs)

Return the sum of the diagonal of every `Qs[:, :, i]` slice, as produced by
`mapslices` over dimensions 1 and 2.
"""
function diagsums(Qs)
    return mapslices(slice -> sum(diag(slice)), Qs; dims=[1, 2])
end

# Custom pullback for `diagsums`: the cotangent of slice `i` is written onto the
# diagonal of slice `i`; every other entry stays zero.
@adjoint function diagsums(Qs)
    diagsums(Qs), function (Δ)
        Δ′ = zero(Qs)
        for (i, δ) in enumerate(Δ)
            for j in 1:size(Qs, 1)
                Δ′[j, j, i] = δ
            end
        end
        (Δ′,)
    end
end

"""
    expdiags(Qs)

Return a copy of `Qs` in which the diagonal of every `Qs[:, :, i]` slice has been
replaced by its element-wise exponential; off-diagonal entries are unchanged.
"""
function expdiags(Qs)
    return mapslices(Qs; dims=[1, 2]) do slice
        slice[diagind(slice)] .= exp.(slice[diagind(slice)])
        slice
    end
end

# Custom pullback for `expdiags`: cotangents pass through unchanged except on the
# diagonals, where they are scaled by `exp` of the corresponding input entry.
@adjoint function expdiags(Qs)
    expdiags(Qs), function (Δ)
        Δ′ = zero(Qs)
        Δ′ .= Δ
        for i in 1:size(Qs, 3)
            for j in 1:size(Qs, 1)
                Δ′[j, j, i] *= exp(Qs[j, j, i])
            end
        end
        (Δ′,)
    end
end

# NOTE(review): this is type piracy (neither `*` nor the argument types belong to
# this file). It is kept because it is part of the original repro — presumably a
# workaround for `nothing` cotangents showing up during the Zygote gradient;
# confirm it is still required before removing it.
Base.:*(::Float64, ::Nothing) = nothing

"""
    gmm_objective(alphas, means, Qs, x, wishart)

Evaluate the Gaussian-mixture objective.

# Arguments
- `alphas`: one weight parameter per component, indexed as `alphas[ik]`.
- `means`: `d x k` matrix, one column per component.
- `Qs`: `d x d x k` array, one matrix per component; the diagonals are summed
  (`diagsums`) and then exponentiated (`expdiags`) inside this function.
- `x`: `d x n` data matrix, one column per sample.
- `wishart`: hyper-parameters forwarded to `log_wishart_prior`.

# Returns
`CONSTANT + slse - n*logsumexp(alphas) + log_wishart_prior(...)`, where `slse`
accumulates, over all samples, the log of the summed per-component terms
`exp(-0.5*|Q_ik*(x_ix - mean_ik)|^2 + alphas[ik] + sum_qs[ik])`.
"""
function gmm_objective(alphas, means, Qs, x, wishart::Wishart)
    d = size(x, 1)
    n = size(x, 2)
    # Component count comes from the arguments rather than a script-level global.
    k = size(means, 2)
    CONSTANT = -n * d * 0.5 * log(2 * pi)
    sum_qs = reshape(diagsums(Qs), 1, size(Qs, 3))
    Qs = expdiags(Qs)
    slse = 0.
    for ix = 1:n
        formula(ik) = -0.5 * sum(abs2, Qs[:, :, ik] * (x[:, ix] .- means[:, ik]))
        sumexp = 0.
        for ik = 1:k
            sumexp += exp(formula(ik) + alphas[ik] + sum_qs[ik])
        end
        slse += log(sumexp)
    end
    return CONSTANT + slse - n * logsumexp(alphas) +
           log_wishart_prior(wishart, sum_qs, Qs)
end

alphas = randn(1, 10)
means = rand(10, 10)
icf = randn(55, 10)
x = randn(10, 10000)
wishart = Wishart(1.0, 0)
d = size(means, 1)
k = size(means, 2)
n = size(x, 2)
# NOTE(review): `get_Q` already exponentiates the diagonal, and `gmm_objective`
# applies `expdiags` to `Qs` again — confirm the double exponentiation is
# intended for this repro.
const Qs = cat([get_Q(d, icf[:, ik]) for ik in 1:k]...; dims=[3])

# Objective
# Call once in case of precompilation etc
err = gmm_objective(alphas, means, Qs, x, wishart)

"""
    wrapper_gmm_objective(alphas, means, Qs)

Close over the script-level `x` and `wishart` so that only the differentiated
parameters are passed to `Zygote.gradient`.
"""
function wrapper_gmm_objective(alphas, means, Qs)
    return gmm_objective(alphas, means, Qs, x, wishart)
end

# Gradient
g = (alphas, means, Qs) -> Zygote.gradient(wrapper_gmm_objective, alphas, means, Qs)
J = g(alphas, means, Qs)