Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal error while using Zygote #180

Open
NiklasGustafsson opened this issue May 7, 2019 · 1 comment
Open

Internal error while using Zygote #180

NiklasGustafsson opened this issue May 7, 2019 · 1 comment

Comments

@NiklasGustafsson
Copy link

The attached repro file generates an internal error.

Exception: EXCEPTION_ACCESS_VIOLATION at 0x5cf13ed55 -- _ZN4llvm13LiveRangeCalc9calculateERNS_12LiveIntervalEb at C:\dd\Julia-1.1.0\bin\LLVM.dll (unknown line)
…

This was discovered while playing around with Zygote, trying to adapt a benchmark that works with ForwardDiff and Flux.Tracker. It was run on Windows 10.

This was initially opened against Julia, but Keno requested that it be moved to Zygote for further triage.

@KristofferC
Copy link
Contributor

Here is the repro code inline, so there is no need to open the attached file:

#
# Repro case code was based on code from: https://github.com/awf/autodiff
#

using Pkg
using Printf
using SpecialFunctions
using LinearAlgebra
using Zygote
using Zygote: @adjoint

"""
    Wishart(gamma, m)

Parameters of the prior term evaluated by `log_wishart_prior`, which reads both
fields: `gamma` (stored as `Float64`) and `m` (stored as `Int`).
"""
struct Wishart
    gamma::Float64
    m::Int
end

"""
    pack(alphas, means, icf)

Flatten `alphas`, `means` and `icf` (each via `[:]`, i.e. column-major order) and
concatenate them, in that order, into a single vector.
"""
function pack(alphas, means, icf)
    return vcat(alphas[:], means[:], icf[:])
end

"""
    unpack(d, k, packed)

Split the flat vector `packed` back into `(alphas, means, icf)`:

- `alphas`: the first `k` entries, reshaped to `1 x k`;
- `means`: the next `d*k` entries, reshaped to `d x k`;
- `icf`: all remaining entries, reshaped to `div(d*(d+1), 2) x k`.
"""
function unpack(d, k, packed)
    n_means = d * k
    icf_rows = div(d * (d + 1), 2)

    alphas = reshape(packed[1:k], 1, k)
    means = reshape(packed[k+1:k+n_means], d, k)
    icf = reshape(packed[k+n_means+1:end], icf_rows, k)

    return (alphas, means, icf)
end

"""
    sumsq(v)

Return the sum of `abs2` applied to every element of `v`.
"""
function sumsq(v)
    return sum(abs2, v)
end

"""
    ltri_unpack(D, LT)

Assemble a `length(D) x length(D)` lower-triangular matrix with `D` on the
diagonal, zeros above it, and the strictly-lower entries taken row by row from
`LT` (row `r` uses the `r - 1` entries of `LT` that follow offset
`div((r-1)*(r-2), 2)`).

The matrix is built by concatenation only (no in-place writes).
"""
function ltri_unpack(D, LT)
    n = length(D)

    # Number of `LT` entries consumed by the rows above row `r`.
    offset(r) = div((r - 1) * (r - 2), 2)

    # Row `r` as a 1 x n matrix: [strictly-lower entries, diagonal entry, zero padding].
    function build_row(r)
        below = LT[offset(r) .+ (1:r-1)]
        left = reshape([below[i] for i in 1:r-1], 1, r - 1)
        return hcat(left, D[r], zeros(1, n - r))
    end

    rows = [build_row(r) for r in 1:n]
    return vcat(rows...)
end

"""
    get_Q(d, icf)

Build a lower-triangular matrix from `icf` via `ltri_unpack`: the diagonal is
`exp.(icf[1:d])` and the strictly-lower entries come from `icf[d+1:end]`.
"""
function get_Q(d, icf)
    diagonal = exp.(icf[1:d])
    strictly_lower = icf[d+1:end]
    return ltri_unpack(diagonal, strictly_lower)
end

"""
    log_gamma_distrib(a, p)

Return `0.25 * p * (p - 1) * log(pi)` plus the sum over `j = 1:p` of
`lgamma(a + 0.5 * (1 - j))`.
"""
function log_gamma_distrib(a, p)
    # log(pi), written out as a Float64 literal.
    log_pi = 1.1447298858494002
    total = 0.25 * p * (p - 1) * log_pi
    for j in 1:p
        total += lgamma(a + 0.5 * (1 - j))
    end
    return total
end

"""
    log_wishart_prior(wishart::Wishart, sum_qs, Qs)

Evaluate the prior term

    0.5 * gamma^2 * sum(abs2, Qs) - m * sum(sum_qs) - k * C

where `gamma` and `m` come from `wishart`, `p = size(Qs, 1)`, `k = size(Qs, 3)`,
`n = p + m + 1`, and
`C = n * p * (log(gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p)`.

# Arguments
- `wishart`: prior parameters (`gamma`, `m`).
- `sum_qs`: collection whose entries are summed into the `m`-weighted term.
- `Qs`: numeric array indexed as `Qs[:, :, ik]` by the caller, one matrix per
  component along the third dimension.
"""
function log_wishart_prior(wishart::Wishart, sum_qs, Qs)
    # `Qs` is a 3-D array, so `Qs[1]` is a scalar and `size(Qs[1], 1)` is always 1.
    # Read the matrix dimension and the component count from the array itself
    # instead (the component count used to come from an untyped global `k`).
    p = size(Qs, 1)
    k = size(Qs, 3)

    n = p + wishart.m + 1
    C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p)

    # Squared Frobenius norm accumulated over every slice of `Qs`.
    frobenius = sum(abs2, Qs)

    return 0.5 * wishart.gamma^2 * frobenius - wishart.m * sum(sum_qs) - k * C
end


"""
    logsumexp(x)

Return `log(sum(exp.(x)))`, computed after subtracting `maximum(x)` from every
entry and adding it back at the end. The input is expected to be 1-dimensional.
"""
function logsumexp(x)
    peak = maximum(x)
    shifted = x .- peak
    return log(sum(exp.(shifted))) + peak
end

"""
    diagsums(Qs)

For every slice of `Qs` taken over dimensions 1 and 2, compute the sum of the
slice's diagonal. The result is whatever `mapslices` produces for those `dims`.
"""
function diagsums(Qs)
    return mapslices(Qs; dims=[1, 2]) do block
        sum(diag(block))
    end
end

# Custom Zygote adjoint for `diagsums`: the pullback returns an array shaped like
# `Qs` that is zero everywhere except on the diagonal of each slice, where it
# holds that slice's incoming sensitivity.
@adjoint function diagsums(Qs)
    function diagsums_pullback(Δ)
        grad = zero(Qs)
        # `slice` counts the entries of Δ in iteration order; `j` walks the diagonal.
        for (slice, δ) in enumerate(Δ), j in 1:size(Qs, 1)
            grad[j, j, slice] = δ
        end
        return (grad,)
    end
    diagsums(Qs), diagsums_pullback
end

"""
    expdiags(Qs)

Apply `exp` to the diagonal entries of every slice of `Qs` taken over
dimensions 1 and 2, leaving the off-diagonal entries as they are, and return
the array assembled by `mapslices`.
"""
function expdiags(Qs)
    # Overwrites the diagonal of the slice handed over by `mapslices` and returns it.
    function exp_on_diagonal(block)
        on_diag = diagind(block)
        block[on_diag] .= exp.(block[on_diag])
        return block
    end
    return mapslices(exp_on_diagonal, Qs; dims=[1, 2])
end

# Custom Zygote adjoint for `expdiags`: the pullback copies the incoming
# sensitivity into an array shaped like `Qs` and scales each diagonal entry by
# `exp` of the matching input entry; off-diagonal entries pass through unchanged.
@adjoint function expdiags(Qs)
    function expdiags_pullback(Δ)
        grad = zero(Qs)
        grad .= Δ
        for slice in 1:size(Qs, 3), j in 1:size(Qs, 1)
            grad[j, j, slice] *= exp(Qs[j, j, slice])
        end
        return (grad,)
    end
    expdiags(Qs), expdiags_pullback
end

# Makes `Float64 * nothing` evaluate to `nothing`.
# NOTE(review): this extends a Base function on Base types only (type piracy) and
# affects every `Float64 * Nothing` call in the session; presumably a workaround
# for `nothing` gradients — confirm it is still needed before reusing this code.
function Base.:*(::Float64, ::Nothing)
    return nothing
end

"""
    gmm_objective(alphas, means, Qs, x, wishart::Wishart)

Evaluate the objective

    CONSTANT + slse - n * logsumexp(alphas) + log_wishart_prior(wishart, sum_qs, expQs)

where, with `d, n = size(x)` and `k = size(Qs, 3)`:

- `CONSTANT = -n * d * 0.5 * log(2 * pi)`;
- `sum_qs` holds `diagsums(Qs)` reshaped to `1 x k`;
- `expQs = expdiags(Qs)`;
- `slse` is the sum over columns `ix` of `x` of
  `log(sum_ik exp(-0.5 * |expQs[:, :, ik] * (x[:, ix] - means[:, ik])|^2
  + alphas[ik] + sum_qs[ik]))`.

# Arguments
- `alphas`: indexed as `alphas[ik]` for `ik` in `1:k`.
- `means`: matrix with one column per component (`means[:, ik]`).
- `Qs`: 3-D array with one matrix per component (`Qs[:, :, ik]`).
- `x`: data matrix with one column per observation (`x[:, ix]`).
- `wishart`: prior parameters forwarded to `log_wishart_prior`.
"""
function gmm_objective(alphas, means, Qs, x, wishart::Wishart)
    d = size(x, 1)
    n = size(x, 2)
    # Component count comes from the argument, not from a global `k`.
    k = size(Qs, 3)
    CONSTANT = -n * d * 0.5 * log(2 * pi)

    sum_qs = reshape(diagsums(Qs), 1, k)
    # Bound to a fresh name so no reassigned variable is captured or boxed.
    expQs = expdiags(Qs)

    slse = 0.0
    for ix in 1:n
        sumexp = 0.0
        for ik in 1:k
            quad = -0.5 * sum(abs2, expQs[:, :, ik] * (x[:, ix] .- means[:, ik]))
            sumexp += exp(quad + alphas[ik] + sum_qs[ik])
        end
        slse += log(sumexp)
    end

    prior = log_wishart_prior(wishart, sum_qs, expQs)
    return CONSTANT + slse - n * logsumexp(alphas) + prior
end

# Random problem instance (drawn in this order): 10 components in 10 dimensions,
# 10000 observations, and 55 = div(10 * 11, 2) packed triangular entries per component.
alphas = randn(1, 10)
means = rand(10, 10)
icf = randn(55, 10)
x = randn(10, 10000)
wishart = Wishart(1.0, 0)

# Problem sizes, kept as globals under these names because later code reads them.
d, k = size(means)
n = size(x, 2)

# One lower-triangular matrix per component, stacked along the third dimension.
const Qs = cat(map(ik -> get_Q(d, icf[:, ik]), 1:k)...; dims=[3])

# Objective
# Evaluate once up front so compilation happens before the gradient call.
err = gmm_objective(alphas, means, Qs, x, wishart)

"""
    wrapper_gmm_objective(alphas, means, Qs)

Call `gmm_objective` with the given `alphas`, `means` and `Qs`, taking the data
`x` and the prior `wishart` from the enclosing global scope.
"""
wrapper_gmm_objective(alphas, means, Qs) = gmm_objective(alphas, means, Qs, x, wishart)

# Gradient of the wrapped objective with respect to all three arguments.
# The lambda's parameters are named so they do not shadow the globals above.
g = (a, mu, q) -> Zygote.gradient(wrapper_gmm_objective, a, mu, q)

# This call is the one that differentiates the objective.
J = g(alphas, means, Qs)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants