Add RNG to Dropout #1618
Conversation
-  function dropout_mask(x, p; dims=:)
-    y = rand!(similar(x, _dropout_shape(x, dims)))
+  function dropout_mask(rng::AbstractRNG, x, p; dims=:)
+    y = rand!(rng, similar(x, _dropout_shape(x, dims)))
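For context, a minimal sketch of how the updated signature could be exercised on the CPU. This is illustrative only: `dropout_mask` is an internal helper, so the exact call may differ from what the PR ends up exposing.

```julia
using Flux, Random

# With the RNG passed explicitly, the mask is reproducible for a fixed seed.
rng = MersenneTwister(0)
x = rand(Float32, 3, 4)
mask1 = Flux.dropout_mask(rng, x, 0.5; dims = :)
mask2 = Flux.dropout_mask(MersenneTwister(0), x, 0.5; dims = :)
mask1 == mask2  # expected to hold: both masks come from identically seeded RNGs
```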
I don't think this dispatches to the GPU intrinsic when `rng` is specified. @maleadt?
We only override `default_rng`; I haven't looked into getting `GLOBAL_RNG` to work (maybe it does).
We can switch it for `default_rng` too, of course.
But you have to call it on the GPU. You can't call `default_rng` in the Dropout ctor and forward that RNG object to the GPU (unless Dropout objects are created in device code?).
I guess it depends on whom we give the responsibility of providing a valid RNG to (the user or Flux) - a proverbial `gpu(rng)`.
I don't really follow the relevance of the linked comment. AFAIK there are no GPU-compatible RNGs as such; there are simply the RNGs provided by CUDA.jl/GPUArrays.jl. `to_gpu` is similar to `trainable`: it allows us to define how to map a leaf in the structure to the GPU without defining `CUDA.cu` on types it shouldn't be defined on. It exists out of necessity, but if you can think of an alternative, go for it (I'm not tied to `to_gpu` at all).
I was referring to exactly the ones provided by CUDA/CURAND. Sorry if that wasn't clear.
So I don't understand your comment then. This provides a mapping from `GLOBAL_RNG` to `CUDA.CURAND.default_rng()`, and when a user does `Dropout |> gpu` with any other CPU RNG, it throws an error instead of silently ignoring it.
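A rough sketch of that mapping, with a hypothetical helper name (this is not Flux's actual code) and assuming CUDA.jl's CURAND RNG types: the global CPU RNG maps to the CUDA default RNG, and anything else errors rather than being silently ignored.

```julia
using Random, CUDA

# Hypothetical mapping applied during `m |> gpu` conversion (names are illustrative).
to_gpu_rng(::typeof(Random.GLOBAL_RNG)) = CUDA.CURAND.default_rng()
to_gpu_rng(rng::CUDA.CURAND.RNG) = rng          # already a GPU RNG: pass through
to_gpu_rng(rng::Random.AbstractRNG) =
    error("$(typeof(rng)) has no GPU equivalent; use the global RNG or a CUDA RNG")
```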
I think the most sensible place to error is when the user tries `m |> gpu` on a model that contains `Dropout` with CPU RNGs that have no GPU equivalent. It would be weird for `m |> gpu` to run successfully when the model isn't executable as-is on the GPU. Either way, I don't care how we do it, but as long as we throw an error instead of silently using a CUDA RNG when `x isa CuArray`, it's fine by me.
Right - it will error anyway; it's just clearer when we can say that CUDA needs a way to dispatch to a CUDA RNG. I agree on that.
LGTM, just needs tests and addressing the GPU case.
See also #1372.
Switched out.

Manually adding a dispatch seems to go against the "write your kernel once and run anywhere" approach.
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end

 """
-    Dropout(p; dims=:)
+    Dropout([rng = default_rng()], p; dims=:)
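For reference, a stripped-down sketch of the constructor pattern this docstring line implies. The type and field names below are assumptions for illustration, not necessarily the PR's exact code: the RNG is an optional first positional argument that forwards to the main constructor.

```julia
using Random

# Toy stand-in for the real layer, showing only the RNG plumbing.
struct MyDropout{F,D,R<:Random.AbstractRNG}
    p::F
    dims::D
    rng::R
end

# Default: use the task-local default RNG.
MyDropout(p; dims = :) = MyDropout(Random.default_rng(), p; dims = dims)
# Explicit RNG as the first positional argument.
MyDropout(rng::Random.AbstractRNG, p; dims = :) = MyDropout(p, dims, rng)
```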
`DataLoader` takes `rng` as a keyword argument (Flux.jl/src/data/dataloader.jl, Line 70 in 335286a):

function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)

We may want to do the same here for consistency. For the functional form `dropout` instead, it is fine to take `rng` as the first positional argument.
The PR seems easier to parse and cleaner API-wise. Good point on the data loaders; we should revisit them separately.
> The PR seems easier to parse and cleaner API-wise.

I don't agree with either point. Having `rng` as the first positional argument actually forces you to duplicate the constructor, and I generally prefer keyword arguments in these cases since user code becomes more self-explanatory.
Merging this PR as it is would mean introducing an inconsistency now and breaking things later (if we then change `DataLoader` as well), all for no added benefit in my opinion.
Let's hear whether @darsnack and @ToucheSir have an opinion on this.
I don't think the constructor was duplicated without reason (it's not duplication at all; it forwards to the other constructor). Flux hasn't really been kwarg-forward where it doesn't add value, so we can keep this for now.
In general, I think kwargs add clarity over positional arguments precisely because their position is irrelevant; you can assume the user is not going to remember the placement of arguments beyond the third. That being said, having the RNG first is a standard pattern throughout Julia.
I could see `dropout([rng], ...)` following this pattern, but it seems really odd for `Dropout`. For example, in Distributions you don't construct `Normal([rng], ...)`. I don't think I've ever seen this pattern used in a non-functional form (i.e. constructors). I would also suggest using a keyword argument for the constructor.
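For comparison, the Distributions.jl convention referenced above: the RNG is supplied to the functional sampling call, not stored in (or passed to) the distribution's constructor.

```julia
using Random, Distributions

rng = MersenneTwister(0)
d = Normal(0.0, 1.0)   # no RNG in the constructor
x = rand(rng, d)       # RNG as the first positional argument of the functional call
```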
Right, we wouldn't want that. More generally, not everything inside that do block may run on the GPU, and it's awkward to wrap one part of the problem in such a block and not the others just for RNGs. It would make sense for offloading everything to the GPU, for example. I don't think we have the seed API for CUDA? Even then, that's only relevant for global effects; it may be desirable to have different RNGs for Dropout vs. (say) initialisation. Personally, I would prefer it if we could handle it outside too.
> one has enough access to model construction to thread through an RNG to a subset of layers

There are limits to this, but if all someone wants is to pass a specific RNG to certain types of layers, that can be done with `fmap` and the `exclude` keyword specified appropriately.
Personally, I prefer the solution where layer-specific arguments are possible. It would match how RNGs are passed into most Julia functions, and conceptually, the RNG/RNG state is very much an "input" to the model. I am struggling to see how we make that happen though. How does Flax handle it? Do you have to write a custom forward pass to utilize a non-default RNG?
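To make the `fmap`/`exclude` suggestion above concrete, here is a rough sketch. It assumes the RNG-carrying `Dropout` constructor proposed in this PR, plus the layer's `p` and `dims` fields; treat it as illustrative rather than as Flux's recommended API.

```julia
using Flux, Random
using Functors: fmap, isleaf

rng = MersenneTwister(42)
model = Chain(Dense(10, 10), Dropout(0.5))

# Treat Dropout layers (and ordinary leaves) as leaves, and rebuild each
# Dropout with the desired RNG while leaving everything else untouched.
model′ = fmap(model; exclude = x -> x isa Dropout || isleaf(x)) do x
    x isa Dropout ? Dropout(rng, x.p; dims = x.dims) : x
end
```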
Wait, so you are advocating for storing the RNG in the layer struct? My intent was to figure out what scenarios would require doing so, and thus far it doesn't seem like there are many. Using different RNGs for the forward pass vs initialization isn't one of them IMO, because those don't happen at the same time (and thus you can seed/scope the RNG for one and not the other).
No, I also prefer not to store the RNG in the struct, but I don't see an alternative. If seeding is the only goal, then this PR is not really needed. If different RNG types are required, then I feel you need scoping, but that can get tricky with different devices, like Dhairya mentioned. The only option left seems to be the ability to pass through layer-specific arguments on the forward pass. IMO this would be a nice addition if we figured out a good way to do it; it could be used for more than RNGs.
Re `fmap`: that was basically me pointing out that if we are forced to store the RNG in the struct, swapping it out doesn't require fine-grained access to the model building.
What I was trying to get at (and probably muddled, apologies) is that putting the RNG in the struct isn't necessary to fix #1372. If we instead provided, say, a top-level function to seed all RNGs like PyTorch does, then all this discussion about device compatibility and conversion would be moot. This is to say nothing of edge cases that haven't been discussed: I can already foresee saving/loading behaviour and how it interacts with RNG "tying" across different layers being a headache.
TL;DR adding an RNG struct parameter feels like YAGNI in context and isn't worth the implementation complexity/headaches it'll bring IMO. If we really want to look into this, a more principled investigation (e.g. can we co-opt something like JAX's RNG key design) would be prudent.
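For illustration, a minimal sketch of the "seed everything once" alternative mentioned above. The helper name is hypothetical, and the CUDA call is an assumption about CUDA.jl's seeding API, so it is left commented out.

```julia
using Random
# using CUDA

# One top-level call that seeds the default CPU RNG (and, if a GPU is in use,
# the CUDA default RNG), with no RNG stored inside any layer.
function seed_all!(seed::Integer)
    Random.seed!(seed)
    # CUDA.seed!(seed)  # assumption: uncomment when CUDA.jl is loaded
    return nothing
end

seed_all!(42)
```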
No, we have device overrides now. EDIT: I misunderstood; that mechanism works only for device code. You'll need additional dispatch to get an RNG suitable for CuArrays.
src/layers/normalise.jl (outdated)
 end

 function (a::Dropout)(x)
   _isactive(a) || return x
-  return dropout(x, a.p; dims=a.dims, active=true)
+  return dropout(x, a.p; dims = a.dims, active = true)
This ignores the RNG right now
Right - I did that to check what we can do to ensure the kernel can actually run with a correct RNG; I don't intend to ignore the RNG here when merging.
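For the record, a sketch of how the forward pass might read once the RNG is threaded through. This assumes an `rng` field on `Dropout` and a `dropout` method taking the RNG as its first positional argument, as discussed elsewhere in this PR; it is shown as it would sit inside src/layers/normalise.jl rather than as standalone code.

```julia
# Forward the layer's stored RNG to the functional `dropout` call instead of ignoring it.
function (a::Dropout)(x)
    _isactive(a) || return x
    return dropout(a.rng, x, a.p; dims = a.dims, active = true)
end
```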
This is more intended to provide concrete code for weighing the different approaches.
@DhairyaLGandhi Right, so what is the status of #1617, then?
See #1849 now.
Closes #1617