Haiku integration #5

Answered by C-J-Cundy
MilesCranmer asked this question in Q&A
Mar 18, 2021 · 2 comments · 2 replies

Hi Miles,
Giving my two cents as a user, this is what I use to integrate a flow from NuX into a bigger model:

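    import jax.random as rnd  # assumed: the snippet uses `rnd` without showing its import
    import nux

    # create_flow, rand_key and train_inputs are defined elsewhere
    flow = nux.Flow(create_flow, rand_key, train_inputs, batch_axes=(0,))

    def sample_flow(params, state, rng_key, n):
        # Use separate keys for the base samples and for the flow call
        sample_key, apply_key = rnd.split(rng_key)
        # Sample from base distribution, i.e. normal
        # (input_shape is assumed to be the integer feature dimension)
        samples = rnd.normal(sample_key, shape=(n, input_shape))
        # Note: `state` is accepted but not forwarded; None is passed in its place
        output = flow.stateful_apply(
            apply_key, {"x": samples}, params, None, sample=True, reconstruction=True
        )
        return output[0]["x"], output[0]["log_px"], output[1]  # I *think* this is right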

Then sample_flow is a pure function mapping params, state, rng_key, and the sample count n to samples and log prob…
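To illustrate why the pure-function shape is convenient, here is a minimal sketch in plain JAX that does not use NuX at all. `toy_flow_apply`, `sample_toy_flow`, and `INPUT_DIM` are hypothetical stand-ins: the "flow" is just an elementwise affine map, chosen so the example runs on its own. The point is only that a function of (params, state, rng_key, n) with no hidden state can be jitted and gives the same output for the same key.

```python
import functools

import jax
import jax.numpy as jnp
import jax.random as rnd
from jax.scipy.stats import norm

INPUT_DIM = 2  # hypothetical feature dimension


def toy_flow_apply(params, z):
    """Stand-in for a real flow: x = z * exp(log_scale) + shift."""
    x = z * jnp.exp(params["log_scale"]) + params["shift"]
    # Change of variables: log p(x) = log N(z; 0, I) - sum(log_scale)
    log_pz = jnp.sum(norm.logpdf(z), axis=-1)
    log_px = log_pz - jnp.sum(params["log_scale"])
    return x, log_px


@functools.partial(jax.jit, static_argnums=3)  # n fixes the output shape, so it is static
def sample_toy_flow(params, state, rng_key, n):
    # Draw base samples, push them through the toy flow
    z = rnd.normal(rng_key, shape=(n, INPUT_DIM))
    x, log_px = toy_flow_apply(params, z)
    return x, log_px, state  # state is passed through unchanged in this sketch


params = {"log_scale": jnp.zeros(INPUT_DIM), "shift": jnp.ones(INPUT_DIM)}
key = rnd.PRNGKey(0)

x, log_px, state = sample_toy_flow(params, None, key, 5)
x_again, _, _ = sample_toy_flow(params, None, key, 5)  # same key, same samples
```

Because everything the function needs is passed in explicitly, the same pattern should slot into a larger model that threads params, state, and keys itself, though how NuX's own state should be forwarded is not something this sketch settles.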

Answer selected by MilesCranmer
3 participants