Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RNG to Dropout #1618

Closed
wants to merge 10 commits into from
Closed

Conversation

DhairyaLGandhi
Copy link
Member

Closes #1617

function dropout_mask(x, p; dims=:)
y = rand!(similar(x, _dropout_shape(x, dims)))
function dropout_mask(rng::AbstractRNG, x, p; dims=:)
y = rand!(rng, similar(x, _dropout_shape(x, dims)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this dispatches to the GPU intrinsic when rng is specified. @maleadt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only override default_rng, I haven't looked into getting GLOBAL_RNG to work (maybe it does).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can switch it for default_rng too of course

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you have to call it on the GPU. You can't call default_rng in the Dropout ctor and forward that RNG object to the GPU (unless Dropout objects are created in device code?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it depends on who we give the responsibility to provide a valid rng to (the user or flux) - a proverbial gpu(rng)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really follow the relevance of the linked comment. AFAIK there are no GPU compatible RNGs. There are simply the RNGs provided by CUDA.jl/GPUArrays.jl. to_gpu is similar to trainable — it allows us to define how to map leaf in the structure to the GPU without defining CUDA.cu on types that it shouldn't be defined on. It's out of necessity, but if you can think of an alternative, go for it (I'm not tied to to_gpu at all).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was referring to the exactly the ones provided by CUDA/CURAND. Sorry if that wasn't clear.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I don't understand your comment then. This provides a mapping from GLOBAL_RNG to CUDA.CURAND.default_rng(). And when a user does Dropout |> gpu with any other CPU RNG, it throws an error instead of silently ignoring it.

Copy link
Member

@darsnack darsnack Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the most sensible place to error is when the user tries m |> gpu on a model that contains Dropout with CPU RNGs that have no GPU equivalent. It would be weird for m |> gpu to run successfully when the model isn't executable as is on the GPU. Either way, I don't care how we do it, but as long as we don't silently use a CUDA RNG when x isa CuArray and throw an error instead, then it's fine by me.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - it will error anyway, its more clear when we can say that CUDA needs a way to dispatch to a CUDA RNG. I agree on that.

@CarloLucibello
Copy link
Member

LGTM, just needs tests and addressing the gpu case.
A slightly different interface could be given by using a keyword argument for the layer constructor, Dropout(p; rng=GLOBAL_RNG).

@ablaom
Copy link

ablaom commented Jun 15, 2021

See also #1372

@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented Jun 16, 2021

Switched out GLOBAL_RNG for default_rng, @maleadt do we need to do something differently for it to pick up CUDA.CURAND.default_rng()

Manually adding a dispatch seems to go against the "write your kernel once and run anywhere" process.

y .= _dropout_kernel.(y, p, 1 - p)
return y
end

"""
Dropout(p; dims=:)
Dropout([rng = default_rng()], p; dims=:)
Copy link
Member

@CarloLucibello CarloLucibello Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataLoader takes rng as a keyword argument

function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)

we may want to do the same here for consistency.
For the functional form dropout instead, it is ok to take rng as the first positional argument

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pr seems easier to parse and cleaner api wise. Good point on the data loaders. We should revisit them separately.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pr seems easier to parse and cleaner api wise.

I don't agree with any of those. Actually having rng as the first positional argument forced you to duplicate the Constructor. And I generally prefer keyword arguments in these cases since user code becomes more self-explanatory.
Merging this PR as it is would mean introducing an inconsistency now and break things later (if we then change DataLoader as well), all of which for no added benefit in my opinion.
Let's hear if @darsnack and @ToucheSir have an opinion on this

Copy link
Member Author

@DhairyaLGandhi DhairyaLGandhi Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the constructor was duplicated without reason (its not duplication at all - its forwarded to the other constructor). Flux hasn't really been kwarg forward where it doesn't add value, so we can keep this for now.

Copy link
Member

@darsnack darsnack Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I think kwargs add clarity over positional arguments precisely because their position is irrelevant. In general, you can assume the user is not going to remember the placement of arguments > 3. That being said, having the RNG first is a standard pattern throughout Julia.

I could see dropout([rng], ...) following this pattern, but it seems really odd for Dropout. For example, in Distributions, you don't construct Normal([rng], ...). I don't think I've ever seen this pattern used in a non-functional form (i.e. constructors). I would also suggest using a keyword argument for the constructor.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, we wouldn't want that. More generally, not everything inside that do block may run on the gpu, and it's more annoying to write out one part of the problem with this block and others not just for rngs. It would make sense for off loading everything to the GPU for example. I don't think we have the seed api for CUDA? Even then that's only relevant for global effects, it may be desirable to have different rngs for Dropout vs (say) initialisation. Personally I would prefer it if we could handle it outside too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one has enough access to model construction to thread through an RNG to a subset of layers

There are limits to this, but if all someone wanted is to pass a specific RNG to certain types of layers, then that can be done with fmap and the exclude keyword specified appropriately.

Personally, I prefer the solution where layer-specific arguments are possible. It would match how RNGs are passed into most Julia functions, and conceptually, the RNG/RNG state is very much an "input" to the model. I am struggling to see how we make that happen though. How does Flax handle it? Do you have to write a custom forward pass to utilize a non-default RNG?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, so you are advocating for storing the RNG in the layer struct? My intent was to figure out what scenarios would require doing so, and thus far it doesn't seem like there are many. Using different RNGs for the forward pass vs initialization isn't one of them IMO, because those don't happen at the same time (and thus you can seed/scope the RNG for one and not the other).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I also prefer not to store the RNG in the struct, but I don't see an alternative. If seeding is the only goal, then this PR is not really needed. If different RNG types are required, then I feel that you need scoping, but that can get tricky with different devices like Dhairya mentioned. Only option left seems like the ability to pass through layer-specific arguments on the forward pass. IMO this would be a nice addition if we figured out a good way to do it. It can be used for more than RNGs.

Re fmap: that was basically me pointing out that if we are forced to store the RNG in the struct, swapping it out doesn't require fine-grained access to the model building.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I was trying to get at (and probably muddled, apologies) is that putting the RNG in the struct isn't necessary to fix #1372. If we instead provided, say, a top-level function to seed all RNGs like PyTorch does, then all this discussion about device compatibility and conversion would be moot. This is to say nothing of edge cases that haven't been discussed: I can already foresee saving/loading behaviour and how it interacts with RNG "tying" across different layers being a headache.

TL;DR adding an RNG struct parameter feels like YAGNI in context and isn't worth the implementation complexity/headaches it'll bring IMO. If we really want to look into this, a more principled investigation (e.g. can we co-opt something like JAX's RNG key design) would be prudent.

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
@maleadt
Copy link
Collaborator

maleadt commented Jun 16, 2021

do we need to do something differently for it to pick up CUDA.CURAND.default_rng()

No, we have device overrides now:

https://github.com/JuliaGPU/CUDA.jl/blob/ee70b71b620edad627f7dc8aa7e3e385a63f8bb8/src/device/random.jl#L95

EDIT: I misunderstood; that mechanism works only for device code. You'll need additional dispatch to get a RNG suitable for CuArrays.

end

function (a::Dropout)(x)
_isactive(a) || return x
return dropout(x, a.p; dims=a.dims, active=true)
return dropout(x, a.p; dims = a.dims, active = true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ignores the RNG right now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right - I did it to check out what we can do to ensure the kernel can actually run with a correct rng - I don't intend to ignore the rng here when merging.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more intended for having concrete code to weigh out the different approaches

@ablaom
Copy link

ablaom commented Jan 25, 2022

@DhairyaLGandhi Right, so what is the status of #1617, then?

@darsnack
Copy link
Member

See #1849 now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow specification of RNG in Dropout
6 participants