In-place the RandomVariables within a Scan #543

Open
Tracked by #1426
brandonwillard opened this issue Jul 31, 2021 · 1 comment
Labels
enhancement New feature or request graph rewriting help wanted Extra attention is needed random variables Involves random variables and/or sampling Scan Involves the `Scan` `Op`

Comments

@brandonwillard
Member

brandonwillard commented Jul 31, 2021

RandomVariable Ops inside of Scans don't appear to be in-placed:

import aesara
import aesara.tensor as at


srng = at.random.RandomStream(23)

res, updates = aesara.scan(
    lambda: srng.normal(),
    n_steps=10,
    strict=True,
)

fn = aesara.function([], res, updates=updates, mode="FAST_RUN")

aesara.dprint(fn)
# for{cpu,scan_fn}.0 [id A] ''   0
#  |TensorConstant{10} [id B]
#  |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F82A7A1A320>) [id C]
#  |TensorConstant{10} [id B]
# for{cpu,scan_fn}.1 [id A] ''   0
#
# Inner graphs of the scan ops:
#
# for{cpu,scan_fn}.0 [id A] ''
#  >normal_rv{0, (0, 0), floatX, False}.1 [id D] ''
#  > |<RandomGeneratorType> [id E] -> [id C]
#  > |TensorConstant{[]} [id F]
#  > |TensorConstant{11} [id G]
#  > |TensorConstant{0.0} [id H]
#  > |TensorConstant{1.0} [id I]
#  >normal_rv{0, (0, 0), floatX, False}.0 [id D] ''
#
# for{cpu,scan_fn}.1 [id A] ''
#  >normal_rv{0, (0, 0), floatX, False}.1 [id D] ''
#  >normal_rv{0, (0, 0), floatX, False}.0 [id D] ''

The False values in the normal_rv string outputs are the Ops' boolean inplace flags; they indicate that the random_make_inplace optimization hasn't been applied to those RandomVariables.

When inplace is False, the RNG state is intentionally copied before sampling (see here), which is rather wasteful when the copy isn't required/desired. The advanced copy is still returned as the Op's RNG output, so the copy adds overhead on every step rather than changing the sampled values.
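
To make the cost of inplace=False concrete, here is a minimal NumPy-only sketch of the copy-then-sample pattern. It is not Aesara's implementation: draw_normal is a hypothetical helper, and deepcopy is used as an assumption so that the copy never shares state with the original generator.

```python
from copy import deepcopy

import numpy as np


def draw_normal(rng, inplace=False):
    """Hypothetical stand-in for a RandomVariable's sampling step."""
    if not inplace:
        # Copy the whole generator so the caller's RNG is left untouched.
        rng = deepcopy(rng)
    value = rng.normal(0.0, 1.0)
    # The (possibly copied) generator is returned so the caller can carry
    # the advanced state forward, e.g. through a shared-variable update.
    return rng, value


rng = np.random.default_rng(23)
state_before = rng.bit_generator.state

# Not in-place: a new generator is returned and the original is unchanged.
copied_rng, copied_value = draw_normal(rng, inplace=False)
original_untouched = rng.bit_generator.state == state_before

# In-place: the same generator object is advanced and returned.
same_rng, inplace_value = draw_normal(rng, inplace=True)
original_advanced = rng.bit_generator.state != state_before
```

Both calls start from the same state, so they produce the same draw; the only difference is whether a full generator copy is made on each step.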

@brandonwillard brandonwillard added enhancement New feature or request help wanted Extra attention is needed graph rewriting random variables Involves random variables and/or sampling labels Jul 31, 2021
@brandonwillard
Member Author

The reason it's not being in-placed: the Supervisor Feature says that the RandomGeneratorType instance generated for the srng.normal() is "protected".

The Supervisor instances are created here, and "protected" values are the fgraph.inputs that are neither marked mutable (as, e.g., an updated shared variable would be) nor already have destroyers.
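
As a rough model of that selection rule (a plain-Python sketch, not the actual Aesara code; InputSpec and protected_inputs are hypothetical names), an input ends up protected only when it is not flagged mutable and nothing already destroys it:

```python
from dataclasses import dataclass


@dataclass
class InputSpec:
    """Hypothetical stand-in for a compiled function's input spec."""

    name: str
    mutable: bool = False


def protected_inputs(specs, destroyed=frozenset()):
    """Return the names of inputs that in-place rewrites may not overwrite.

    `destroyed` is an assumed set of input names that some Op in the graph
    already overwrites.
    """
    return [
        spec.name
        for spec in specs
        if not (spec.mutable or spec.name in destroyed)
    ]


# The dummy inner-graph RNG input is not flagged mutable, so it is protected
# and the in-place rewrite is refused for it.
default_specs = [InputSpec("rng_dummy"), InputSpec("x", mutable=True)]
protected_by_default = protected_inputs(default_specs)

# Flagging the same input as mutable removes the protection.
tagged_specs = [InputSpec("rng_dummy", mutable=True), InputSpec("x", mutable=True)]
protected_after_tagging = protected_inputs(tagged_specs)
```

Under this model, changing the mutable flag on the dummy RNG input is all it takes to lift the protection.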

The relevant graph compilation is happening in Scan.make_thunk, when the inner-graph of the Scan is being compiled. The input variable that's being "protected" is the dummy input variable that represents the actual shared RandomGeneratorType created by srng. Since the actual shared variable is mutable, this seems like a very unnecessary restriction.

It looks like we can work around this issue by setting mutable = True for the appropriate wrapped_inputs here. We could probably tag the appropriate inner-inputs somewhere around here using the mutable property already associated with each input in dummy_f.maker.expanded_inputs, and then we can use that tag later in Scan.make_thunk.

This whole thing does call into question the role of the updates returned by Scan in such a case, because it seems like we wouldn't need the updates if the shared RNGs are being updated in-place...
