reporting error from parallel computing using blackjax: RWState(position='ShapedArray(float32[1])', logdensity='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'). #695

Answered by junpenglao
yyang97 asked this question in Q&A

The error refers to the output of the logdensity function:

import jax.numpy as jnp
import jax.scipy.stats

def normal_post(theta, xobs):
    return jnp.sum(jax.scipy.stats.norm.logpdf(xobs, loc=theta))

It returns float32 initially and float64 after one step, so the dtype of logdensity in the state changes between iterations.
Have you also tried casting xobs to np.float64?
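
To illustrate the point about dtypes, here is a minimal sketch of how the output dtype of normal_post follows the dtypes of its inputs. It assumes 64-bit mode is enabled via jax_enable_x64 (otherwise JAX keeps everything in float32), and the values of theta and xobs are made up for the example, not taken from the original question.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm

# Assumption: 64-bit mode is on; without it JAX downcasts float64 inputs to float32.
jax.config.update("jax_enable_x64", True)

def normal_post(theta, xobs):
    return jnp.sum(norm.logpdf(xobs, loc=theta))

# Hypothetical inputs, for illustration only.
theta32 = jnp.zeros(1, dtype=jnp.float32)   # float32 position, as in the error message
xobs64 = np.array([0.1, -0.3, 0.7])         # NumPy arrays default to float64
xobs32 = xobs64.astype(np.float32)

# float32 position with float64 data: type promotion gives a float64 log density.
mixed = normal_post(theta32, xobs64)

# Casting everything to one dtype keeps the log density dtype stable.
all32 = normal_post(theta32, xobs32)
all64 = normal_post(theta32.astype(jnp.float64), xobs64)

print(mixed.dtype, all32.dtype, all64.dtype)
```

Either consistent choice should avoid the mismatch, as long as the initial position, the data, and any proposal noise all share that dtype.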

Answer selected by yyang97