Add RNG to Dropout #1618
Changes from 7 commits
@@ -1,3 +1,5 @@
+using Random: default_rng
+
 istraining() = false

 @adjoint istraining() = true, _ -> nothing

@@ -10,7 +12,7 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

 """
-    dropout(x, p; dims=:, active=true)
+    dropout([rng = default_rng()], x, p; dims=:, active=true)

 The dropout function. If `active` is `true`,
 for each input, either sets that input to `0` (with probability

@@ -28,26 +30,36 @@ automatically managed using the [`Dropout`](@ref) layer instead of the

 The [`Dropout`](@ref) layer is what you should use in most scenarios.
 """
-function dropout(x, p; dims=:, active::Bool=true)
+function dropout(rng::AbstractRNG, x, p; dims = :, active::Bool = true)
   active || return x
-  y = dropout_mask(x, p, dims=dims)
+  y = dropout_mask(rng, x, p, dims=dims)
   return x .* y
 end

-@adjoint function dropout(x, p; dims=:, active::Bool=true)
+function dropout(x, p; dims = :, active::Bool = true)
+  dropout(default_rng(), x, p, dims = dims, active = active)
+end
+
+# CUDA currently needs a manual dispatch to avoid
+# calling a non-GPU RNG with a CuArray
+function dropout(x::CUDA.CuArray, p; dims = :, active::Bool = true)
+  dropout(CUDA.CURAND.default_rng(), x, p, dims = dims, active = active)
+end
+
+@adjoint function dropout(rng, x, p; dims = :, active::Bool = true)
   active || return x, Δ -> (Δ, nothing)
-  y = dropout_mask(x, p, dims=dims)
-  return x .* y, Δ -> (Δ .* y, nothing)
+  y = dropout_mask(rng, x, p, dims = dims)
+  return x .* y, Δ -> (nothing, Δ .* y, nothing)
 end

-function dropout_mask(x, p; dims=:)
-  y = rand!(similar(x, _dropout_shape(x, dims)))
+function dropout_mask(rng::AbstractRNG, x, p; dims=:)
+  y = rand!(rng, similar(x, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end

 """
-    Dropout(p; dims=:)
+    Dropout([rng = default_rng()], p; dims=:)
Flux.jl/src/data/dataloader.jl Line 70 in 335286a

We may want to do the same here for consistency. For the functional form `dropout` instead, it is ok to take `rng` as the first positional argument.
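For concreteness, a minimal sketch of the two call styles under discussion; the keyword form is only the suggestion made above and is not implemented in this diff:

```julia
using Flux, Random

# Keyword style (the consistency suggestion above; hypothetical, not in this diff):
# Dropout(0.5; rng = MersenneTwister(1))

# Positional style, as implemented by this diff's Dropout(rng, p; dims = :) constructor:
Dropout(MersenneTwister(1), 0.5)
```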
The PR seems easier to parse and cleaner API-wise. Good point on the data loaders. We should revisit them separately.
I don't agree with any of those. Actually, having `rng` as the first positional argument forced you to duplicate the constructor. And I generally prefer keyword arguments in these cases since user code becomes more self-explanatory.

I don't think the constructor was duplicated without reason (it's not duplication at all - it's forwarded to the other constructor). Flux hasn't really been kwarg-forward where it doesn't add value, so we can keep this for now.

In general, I think kwargs add clarity over positional arguments precisely because their position is irrelevant; you can assume the user is not going to remember the placement of arguments > 3. That being said, having the RNG first is a standard pattern throughout Julia. I could see …

Right, we wouldn't want that. More generally, not everything inside that `do` block may run on the GPU, and it's more annoying to write out one part of the problem inside such a block and the rest outside it, and not just for RNGs. It would make sense for offloading everything to the GPU, for example. I don't think we have the seed API for CUDA? Even then, that's only relevant for global effects; it may be desirable to have different RNGs for Dropout vs (say) initialisation. Personally I would prefer it if we could handle it outside too.
There are limits to this, but if all someone wanted is to pass a specific RNG to certain types of layers, then that can be done with … Personally, I prefer the solution where layer-specific arguments are possible. It would match how RNGs are passed into most Julia functions, and conceptually, the RNG/RNG state is very much an "input" to the model. I am struggling to see how we make that happen though. How does Flax handle it? Do you have to write a custom forward pass to utilize a non-default RNG?

Wait, so you are advocating for storing the RNG in the layer struct? My intent was to figure out what scenarios would require doing so, and thus far it doesn't seem like there are many. Using different RNGs for the forward pass vs initialization isn't one of them IMO, because those don't happen at the same time (and thus you can seed/scope the RNG for one and not the other).

No, I also prefer not to store the RNG in the struct, but I don't see an alternative. If seeding is the only goal, then this PR is not really needed. If different RNG types are required, then I feel that you need scoping, but that can get tricky with different devices like Dhairya mentioned. The only option left seems to be the ability to pass through layer-specific arguments on the forward pass. IMO this would be a nice addition if we figured out a good way to do it. It can be used for more than RNGs. Re …

What I was trying to get at (and probably muddled, apologies) is that putting the RNG in the struct isn't necessary to fix #1372. If we instead provided, say, a top-level function to seed all RNGs like PyTorch does, then all this discussion about device compatibility and conversion would be moot. This is to say nothing of edge cases that haven't been discussed: I can already foresee saving/loading behaviour and how it interacts with RNG "tying" across different layers being a headache. TL;DR: adding an RNG struct parameter feels like YAGNI in context and isn't worth the implementation complexity/headaches it'll bring IMO. If we really want to look into this, a more principled investigation (e.g. can we co-opt something like JAX's RNG key design) would be prudent.
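For comparison, a minimal sketch of the seeding-only alternative raised above, assuming `Random.seed!` on the CPU side and a `CUDA.seed!` counterpart on the GPU side (the GPU seeding API is flagged as an open question earlier in this thread); whether this alone is enough to close #1372 is exactly what is being debated:

```julia
using Random, CUDA

# Seed the default RNGs once, up front, instead of storing an RNG in each layer.
Random.seed!(42)   # default CPU RNG
CUDA.seed!(42)     # default CUDA RNG (assumed API)
```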

 Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.

@@ -60,20 +72,25 @@ Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
 mutable struct Dropout{F,D}
   p::F
   dims::D
+  rng::AbstractRNG
   active::Union{Bool, Nothing}
 end

-function Dropout(p; dims=:)
+function Dropout(p; dims = :)
+  Dropout(default_rng(), p; dims)
+end
+
+function Dropout(rng, p; dims = :)
   @assert 0 ≤ p ≤ 1
-  Dropout(p, dims, nothing)
+  Dropout(p, dims, rng, nothing)
 end

 function (a::Dropout)(x)
   _isactive(a) || return x
-  return dropout(x, a.p; dims=a.dims, active=true)
+  return dropout(x, a.p; dims = a.dims, active = true)
This ignores the RNG right now.

Right - I did it to check out what we can do to ensure the kernel can actually run with a correct rng - I don't intend to ignore the rng here when merging.

This is more intended for having concrete code to weigh out the different approaches.
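A minimal sketch of the fix implied by this thread, i.e. forwarding the layer's stored `rng` to the functional form added earlier in the diff; this is not what the current diff does:

```julia
# Inside Flux: forward a.rng instead of falling back to default_rng().
function (a::Dropout)(x)
  _isactive(a) || return x
  return dropout(a.rng, x, a.p; dims = a.dims, active = true)
end
```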
 end

-testmode!(m::Dropout, mode=true) =
+testmode!(m::Dropout, mode = true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

 function Base.show(io::IO, d::Dropout)
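A usage note on the functional form added in this diff: passing an explicitly seeded RNG makes the dropout mask reproducible, which is the kind of control this PR is after. A small sketch, assuming the `dropout(rng, x, p)` signature from the diff above:

```julia
using Flux, Random

x = rand(Float32, 3, 4)

# Same seed, same mask: the two calls produce identical outputs.
y1 = Flux.dropout(MersenneTwister(0), x, 0.5)
y2 = Flux.dropout(MersenneTwister(0), x, 0.5)
@assert y1 == y2
```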
I don't think this dispatches to the GPU intrinsic when `rng` is specified. @maleadt?

We only override `default_rng`; I haven't looked into getting `GLOBAL_RNG` to work (maybe it does).

We can switch it for `default_rng` too, of course.

But you have to call it on the GPU. You can't call `default_rng` in the Dropout ctor and forward that RNG object to the GPU (unless Dropout objects are created in device code?).

I guess it depends on who we give the responsibility of providing a valid RNG to (the user or Flux) - a proverbial `gpu(rng)`.

I don't really follow the relevance of the linked comment. AFAIK there are no GPU compatible RNGs. There are simply the RNGs provided by CUDA.jl/GPUArrays.jl. `to_gpu` is similar to `trainable` — it allows us to define how to map a leaf in the structure to the GPU without defining `CUDA.cu` on types that it shouldn't be defined on. It's out of necessity, but if you can think of an alternative, go for it (I'm not tied to `to_gpu` at all).

I was referring to exactly the ones provided by CUDA/CURAND. Sorry if that wasn't clear.

So I don't understand your comment then. This provides a mapping from `GLOBAL_RNG` to `CUDA.CURAND.default_rng()`. And when a user does `Dropout |> gpu` with any other CPU RNG, it throws an error instead of silently ignoring it.

I think the most sensible place to error is when the user tries `m |> gpu` on a model that contains `Dropout` with CPU RNGs that have no GPU equivalent. It would be weird for `m |> gpu` to run successfully when the model isn't executable as-is on the GPU. Either way, I don't care how we do it, but as long as we don't silently use a CUDA RNG when `x isa CuArray` and throw an error instead, then it's fine by me.

Right - it will error anyway; it's clearer when we can say that CUDA needs a way to dispatch to a CUDA RNG. I agree on that.
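To illustrate the behaviour agreed on above, a hypothetical sketch; the helper name `rng_to_gpu` is made up, and the PR's actual `gpu`/`to_gpu` wiring is not shown in this excerpt. The default CPU RNGs map to the CURAND default RNG, and anything else errors rather than being silently swapped:

```julia
using Random, CUDA

# Hypothetical helper: map a CPU RNG to its CUDA counterpart when moving a
# Dropout layer to the GPU, erroring when there is no GPU equivalent.
function rng_to_gpu(rng)
  if rng === Random.GLOBAL_RNG || rng === Random.default_rng()
    return CUDA.CURAND.default_rng()
  else
    error("no CUDA-compatible equivalent for $(typeof(rng)); construct the layer with a GPU-compatible RNG")
  end
end
```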