Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more conjugates #113

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft

Add more conjugates #113

wants to merge 2 commits into from

Conversation

xjing76
Copy link
Contributor

@xjing76 xjing76 commented Feb 24, 2023

Trying to add more conjugates.

will be addition with current commit Uniform -Pareto.

@brandonwillard @rlouf Please let me know which of the conjugates should be more prioritized.

@xjing76 xjing76 force-pushed the conjugates branch 2 times, most recently from af45283 to 089e8b3 Compare February 24, 2023 16:12
)
Y_et = etuple(etuplize(at.random.uniform), var(), var(), var(), 1, theta_et)

# new_x_et = at.max(observed_val)
Copy link
Contributor Author

@xjing76 xjing76 Feb 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am having most trouble right here, where I am not sure how to get the at.max(obs, x_m) @brandonwillard

Copy link
Member

@brandonwillard brandonwillard Feb 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does etuple(at.max, observed_val, x_lv) work?

@brandonwillard
Copy link
Member

@brandonwillard @rlouf Please let me know which of the conjugates should be more prioritized.

It would probably be best for us to get all the exponential family models out of the way first; otherwise, I have no particular order/preference.

y_vv = at.iscalar("y")
n_tt = at.scalar("n")

Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the test graph looks like this:

>>> aesara.dprint(Y_rv)
pareto_rv{0, (0, 0), floatX, False}.1 [id A]
 |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FC4162BFF20>) [id B]
 |TensorConstant{[]} [id C]
 |TensorConstant{11} [id D]
 |MaxAndArgmax{axis=()}.0 [id E] 'max'
 | |y [id F]
 |Elemwise{add,no_inplace} [id G]
   |k [id H]
   |n [id I]

and that MaxAndArgmax Op isn't same as the at.math.max used in the etuple graph. at.math.max is a function that constructs a MaxAndArgmax Op and uses it to further construct a graph for the max of its argument. In other words, we need an etuple form/"pattern" that matches the types of graphs output by the helper function at.math.max.

Often the easiest way to find etuple forms for the graphs constructed by helper functions is to etuplize said graphs and spot their generalities.
For example:

>>> from etuples import etuplize
>>> etuplize(at.math.max(at.vector("x")))
e(e(aesara.tensor.math.MaxAndArgmax, (0,)), x)
>>> etuplize(at.math.max(at.matrix("x")))
e(e(aesara.tensor.math.MaxAndArgmax, (0, 1)), x)

As we can see, the axis property in the MaxAndArgmax Op will change according to the dimensions of the input (i.e. it computes the max across all dimensions), so we don't want to use a very specific value for the matching form. Instead, we can use another logic variable in place of those values.

Here's a general testing setup for that part of the problem:

import aesara
import aesara.tensor as at

from etuples import etuplize, etuple


srng = at.random.RandomStream(0)

k_tt = at.scalar("k")
y_vv = at.iscalar("y")
n_tt = at.scalar("n")

Y_rv = srng.pareto(at.max(y_vv), k_tt + n_tt)

# This is what we need to match/unify:
etuplize(Y_rv)
# e(
#     e(aesara.tensor.random.basic.ParetoRV, 'pareto', 0, (0, 0), 'floatX', False),
#     RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FA1F9B3D9E0>),
#     TensorConstant{[]},
#     TensorConstant{11},
#     e(e(aesara.tensor.math.MaxAndArgmax, ()), y),
#     e(
#         e(
#             aesara.tensor.elemwise.Elemwise,
#             <aesara.scalar.basic.Add at 0x7fa1fd3823d0>,
#             <frozendict {}>),
#         k,
#         n))

from unification import var
from kanren import run, eq
from aesara.tensor.math import MaxAndArgmax


observed_val = var()
axis_lv = var()
new_x_et = etuple(etuple(MaxAndArgmax, axis_lv), observed_val)

k_lv, n_lv = var(), var()
new_k_et = etuple(etuplize(at.add), k_lv, n_lv)

theta_rng_lv = var()
theta_size_lv = var()
theta_type_idx_lv = var()
theta_posterior_et = etuple(
    etuplize(at.random.pareto),
    theta_rng_lv,
    theta_size_lv,
    theta_type_idx_lv,
    new_x_et,
    new_k_et,
)


run(0, (new_x_et, new_k_et), eq(Y_rv, theta_posterior_et))
# ((e(e(aesara.tensor.math.MaxAndArgmax, ()), y),
#   e(
#       e(
#           aesara.tensor.elemwise.Elemwise,
#           <aesara.scalar.basic.Add at 0x7fa1fd3823d0>,
#           <frozendict {}>),
#       k,
#       n)),)

@codecov
Copy link

codecov bot commented Mar 24, 2023

Codecov Report

Patch coverage: 50.00% and project coverage change: -3.76 ⚠️

Comparison is base (64b0e50) 98.50% compared to head (ae0af70) 94.75%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #113      +/-   ##
==========================================
- Coverage   98.50%   94.75%   -3.76%     
==========================================
  Files          10       11       +1     
  Lines         737      820      +83     
  Branches       63       69       +6     
==========================================
+ Hits          726      777      +51     
- Misses          4       33      +29     
- Partials        7       10       +3     
Impacted Files Coverage Δ
aemcmc/conjugates.py 81.69% <50.00%> (-18.31%) ⬇️

... and 2 files with indirect coverage changes

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@xjing76
Copy link
Contributor Author

xjing76 commented Mar 24, 2023

Sorry, have not been on top of this. (After Covid and rounds of strep ), I will try to add at least one conjugate a day from now on.

@brandonwillard
Copy link
Member

brandonwillard commented Mar 24, 2023

Sorry, have not been on top of this. (After Covid and rounds of strep ), I will try to add at least one conjugate a day from now on.

Ouch! No problem at all; we really appreciate what you've done already. I still need to provide a solution for the default updates on the re-used RNG states anyway, so no rush.

N.B. This comment describes the issue.

@xjing76
Copy link
Contributor Author

xjing76 commented Mar 25, 2023

Great! I just realized that the coverage testing suggesting that I have entire local_uniform_pareto_posterior and local_beta_bernoulli_posterior missed. I probably got something wrong.. https://app.codecov.io/gh/aesara-devs/aemcmc/pull/113/blob/aemcmc/conjugates.py#L332

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants