Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add RNG to Dropout #1618
Add RNG to Dropout #1618
Changes from 4 commits
4ce53a3
c63ded6
1121e7c
0399c37
d9f4927
3a2fcea
3d3fc1e
2996490
0cded38
efc7de3
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this dispatches to the GPU intrinsic when
rng
is specified. @maleadt?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only override
default_rng
, I haven't looked into getting GLOBAL_RNG to work (maybe it does).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can switch it for
default_rng
too of courseThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But you have to call it on the GPU. You can't call
default_rng
in the Dropout ctor and forward that RNG object to the GPU (unless Dropout objects are created in device code?).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it depends on who we give the responsibility to provide a valid rng to (the user or flux) - a proverbial
gpu(rng)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really follow the relevance of the linked comment. AFAIK there are no GPU compatible RNGs. There are simply the RNGs provided by CUDA.jl/GPUArrays.jl.
to_gpu
is similar totrainable
— it allows us to define how to map leaf in the structure to the GPU without definingCUDA.cu
on types that it shouldn't be defined on. It's out of necessity, but if you can think of an alternative, go for it (I'm not tied toto_gpu
at all).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was referring to the exactly the ones provided by CUDA/CURAND. Sorry if that wasn't clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I don't understand your comment then. This provides a mapping from
GLOBAL_RNG
toCUDA.CURAND.default_rng()
. And when a user doesDropout |> gpu
with any other CPU RNG, it throws an error instead of silently ignoring it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the most sensible place to error is when the user tries
m |> gpu
on a model that containsDropout
with CPU RNGs that have no GPU equivalent. It would be weird form |> gpu
to run successfully when the model isn't executable as is on the GPU. Either way, I don't care how we do it, but as long as we don't silently use a CUDA RNG whenx isa CuArray
and throw an error instead, then it's fine by me.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right - it will error anyway, its more clear when we can say that CUDA needs a way to dispatch to a CUDA RNG. I agree on that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DataLoader
takesrng
as a keyword argumentFlux.jl/src/data/dataloader.jl
Line 70 in 335286a
we may want to do the same here for consistency.
For the functional form
dropout
instead, it is ok to takerng
as the first positional argumentThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The pr seems easier to parse and cleaner api wise. Good point on the data loaders. We should revisit them separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't agree with any of those. Actually having rng as the first positional argument forced you to duplicate the Constructor. And I generally prefer keyword arguments in these cases since user code becomes more self-explanatory.
Merging this PR as it is would mean introducing an inconsistency now and break things later (if we then change DataLoader as well), all of which for no added benefit in my opinion.
Let's hear if @darsnack and @ToucheSir have an opinion on this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the constructor was duplicated without reason (its not duplication at all - its forwarded to the other constructor). Flux hasn't really been kwarg forward where it doesn't add value, so we can keep this for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, I think kwargs add clarity over positional arguments precisely because their position is irrelevant. In general, you can assume the user is not going to remember the placement of arguments > 3. That being said, having the RNG first is a standard pattern throughout Julia.
I could see
dropout([rng], ...)
following this pattern, but it seems really odd forDropout
. For example, in Distributions, you don't constructNormal([rng], ...)
. I don't think I've ever seen this pattern used in a non-functional form (i.e. constructors). I would also suggest using a keyword argument for the constructor.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, we wouldn't want that. More generally, not everything inside that do block may run on the gpu, and it's more annoying to write out one part of the problem with this block and others not just for rngs. It would make sense for off loading everything to the GPU for example. I don't think we have the seed api for CUDA? Even then that's only relevant for global effects, it may be desirable to have different rngs for Dropout vs (say) initialisation. Personally I would prefer it if we could handle it outside too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are limits to this, but if all someone wanted is to pass a specific RNG to certain types of layers, then that can be done with
fmap
and theexclude
keyword specified appropriately.Personally, I prefer the solution where layer-specific arguments are possible. It would match how RNGs are passed into most Julia functions, and conceptually, the RNG/RNG state is very much an "input" to the model. I am struggling to see how we make that happen though. How does Flax handle it? Do you have to write a custom forward pass to utilize a non-default RNG?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, so you are advocating for storing the RNG in the layer struct? My intent was to figure out what scenarios would require doing so, and thus far it doesn't seem like there are many. Using different RNGs for the forward pass vs initialization isn't one of them IMO, because those don't happen at the same time (and thus you can seed/scope the RNG for one and not the other).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I also prefer not to store the RNG in the struct, but I don't see an alternative. If seeding is the only goal, then this PR is not really needed. If different RNG types are required, then I feel that you need scoping, but that can get tricky with different devices like Dhairya mentioned. Only option left seems like the ability to pass through layer-specific arguments on the forward pass. IMO this would be a nice addition if we figured out a good way to do it. It can be used for more than RNGs.
Re
fmap
: that was basically me pointing out that if we are forced to store the RNG in the struct, swapping it out doesn't require fine-grained access to the model building.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I was trying to get at (and probably muddled, apologies) is that putting the RNG in the struct isn't necessary to fix #1372. If we instead provided, say, a top-level function to seed all RNGs like PyTorch does, then all this discussion about device compatibility and conversion would be moot. This is to say nothing of edge cases that haven't been discussed: I can already foresee saving/loading behaviour and how it interacts with RNG "tying" across different layers being a headache.
TL;DR adding an RNG struct parameter feels like YAGNI in context and isn't worth the implementation complexity/headaches it'll bring IMO. If we really want to look into this, a more principled investigation (e.g. can we co-opt something like JAX's RNG key design) would be prudent.