Add BatchNorm1D #476
Started working on it here:
FYI, after merging #482, you may need to change your implementation a little.
@nkoppel Thanks for letting me know, I will sync my fork. I think one unclear point is where and how to place the batchnorm layer in the syntax of an nn in dfdx. The traditional way would be Linear -> BatchNorm -> Activation, but I have seen some research from Chollet arguing that BatchNorm -> Linear -> Activation might actually be better. The latter might also be much simpler to implement in dfdx, e.g. with something like:
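As a hedged illustration only: one minimal reading of the "batchnorm first" idea uses dfdx's existing `normalize` tensor op before the linear layer, which gives batch normalization without any learned scale/bias or running statistics (a rough analogue, not a real BatchNorm layer):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // A batch of 32 flattened inputs.
    let x: Tensor<Rank2<32, 784>, f32, _> = dev.sample_normal();
    // "BatchNorm first": normalize over the batch axis before Linear.
    // This reuses the existing `normalize` op, so there is no learned
    // scale/bias and no running statistics -- a rough analogue only.
    let x = x.normalize::<Axis<0>>(1e-5);
    // ...then Linear -> Activation as usual.
    let _ = x;
}
```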
The other way to do it would be to add a BatchNorm option to the tuple syntax directly, so one could write:
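A minimal sketch of what that could look like, assuming the BatchNorm1D type this issue proposes existed (it does not yet, and the widths here are invented):

```rust
use dfdx::prelude::*;

// Hypothetical: BatchNorm1D is the type this issue proposes and does
// not exist in dfdx yet. dfdx models are tuples of modules applied in
// order, so the layer would slot straight into the type-level description.
type Mlp = (
    (Linear<784, 128>, BatchNorm1D<128>, ReLU),
    (Linear<128, 32>, BatchNorm1D<32>, ReLU),
    Linear<32, 10>,
);
```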
I find this much nicer, but it might need more tinkering with the code, which is why I wanted to take a deeper look at the 2D version and follow what was done there (I haven't done that yet). Let me know what you both think (@coreylowman too, I'm curious what your view on this is).
Ok, I checked the code of BatchNorm2D and everything works out of the box for 1D by adding inference and training implementations to BatchNorm2D as follows:
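Roughly, this amounts to giving BatchNorm2D Module/ModuleMut impls for rank-2 [Batch, C] input. A sketch of the shape of such an impl follows; the bounds and the `infer_fwd` helper call are modeled on the existing rank-3/rank-4 impls, so treat the exact signatures as assumptions:

```rust
use dfdx::prelude::*;

// Sketch: an inference-mode Module impl for rank-2 [Batch, C] input,
// mirroring the existing rank-3/rank-4 impls on BatchNorm2D.
// `infer_fwd` is the struct's internal helper; its signature is assumed.
impl<B: Dim, const C: usize, E: Dtype, D: Device<E>>
    Module<Tensor<(B, Const<C>), E, D>> for BatchNorm2D<C, E, D>
{
    type Output = Tensor<(B, Const<C>), E, D>;
    fn forward(&self, x: Tensor<(B, Const<C>), E, D>) -> Self::Output {
        self.infer_fwd(x) // inference path: use running statistics
    }
}
```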
I also realized, after understanding the code in a bit more depth, that what I was saying above was practically nonsense, since Modules can already be stacked together any way we like. I confirmed it with the MNIST example:
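Something along these lines, interleaving the new layer into the example's MLP (a sketch; the exact widths in the MNIST example are assumptions):

```rust
use dfdx::prelude::*;

// Sketch: the MNIST example's MLP with the proposed BatchNorm1D inserted
// after each hidden Linear layer. Widths here are assumptions.
type MlpWithNorm = (
    (Linear<784, 512>, BatchNorm1D<512>, ReLU),
    (Linear<512, 128>, BatchNorm1D<128>, ReLU),
    (Linear<128, 32>, BatchNorm1D<32>, ReLU),
    Linear<32, 10>,
);
```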
Each epoch now takes about 1.5x longer on CPU (I guess that's to be expected with the batchnorm layers), but the nn also converges faster. I would rename the current BatchNorm2D to BatchNorm (or _BatchNorm) and have BatchNorm1D and BatchNorm2D "inherit" from it, if that is ok. If you think there is a better way, suggestions are very welcome.
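Since Rust has no inheritance, "inherit" here would presumably mean a shared core struct wrapped by newtypes, so each dimension can carry its own Module impls. A purely hypothetical sketch of that shape (field types simplified, not actual dfdx code):

```rust
// Hypothetical sketch of the "rename and inherit" idea in Rust terms:
// one shared struct holding the parameters, wrapped by per-dimension
// newtypes. Field types are simplified placeholders.
pub struct BatchNorm<const C: usize> {
    pub scale: [f32; C],
    pub bias: [f32; C],
    pub running_mean: [f32; C],
    pub running_var: [f32; C],
}

// The dimension-specific types would differ only in which tensor
// shapes they implement Module for, not in the data they carry.
pub struct BatchNorm1D<const C: usize>(pub BatchNorm<C>);
pub struct BatchNorm2D<const C: usize>(pub BatchNorm<C>);
```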
I'm fairly certain the bottleneck is the binary operations' backward pass when dealing with broadcasted arrays. I think sum()'s backward is a big slowdown as well. I think #491 is probably the most relevant issue for this.
Gotcha. I think we might actually be able to get away with just moving train_fwd and infer_fwd to being functions that accept all the tensors instead of structs, and then the different structs can just call them directly:

```rust
fn train_fwd<...>(
    x: Tensor<S, E, D, T>,
    running_mean: &mut Tensor<Rank1<C>, E, D>,
    running_bias: &mut Tensor,
    scale: Tensor,
    bias: Tensor,
    epsilon: E,
)
```
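A minimal plain-Rust sketch of that deduplication pattern (not dfdx's actual code; the math is reduced to centering so the example stays self-contained):

```rust
// The shared training-mode math lives in one free function...
fn train_fwd(x: &mut [f32], running_mean: &mut f32, momentum: f32) {
    let mean = x.iter().sum::<f32>() / x.len() as f32;
    // Update the running statistic, then center the batch.
    *running_mean = (1.0 - momentum) * *running_mean + momentum * mean;
    for v in x.iter_mut() {
        *v -= mean; // scale/bias/variance omitted for brevity
    }
}

struct BatchNorm1D { running_mean: f32 }
struct BatchNorm2D { running_mean: f32 }

// ...and both structs delegate to it instead of duplicating the code.
impl BatchNorm1D {
    fn forward_mut(&mut self, x: &mut [f32]) {
        train_fwd(x, &mut self.running_mean, 0.1);
    }
}

impl BatchNorm2D {
    fn forward_mut(&mut self, x: &mut [f32]) {
        train_fwd(x, &mut self.running_mean, 0.1);
    }
}
```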
Ok, this sounds interesting, I'll try to have a look at it once I'm done with the current issue. Since there is already a more specific issue, there is no point in opening a new one for the moment, I guess.
Great, sounds good, let's do it like that, it keeps things simple. On it!
Just checking in, since it has been a couple of days. I have been struggling a bit to make it work with both approaches to deduplicating the code:
Restricting the bound as suggested has created other problems elsewhere. I have also tried to add the implementations myself, but not successfully (it also feels like I shouldn't have to). I will check it again in the next few days, and if I don't sort it out by Friday, I might just copy BatchNorm2D, convert it to BatchNorm1D, and leave it to a better Rust programmer or someone more knowledgeable about all the hidden details of dfdx to deduplicate the code. Maybe it's a dumb thing I am doing and I'll sort it out in the next few days (I still want to fight through it a bit).
Okay, sounds good. This is useful to know though, as I really want to make it as easy as possible for people to contribute at all levels of Rust experience. So if there are any other things you got stuck on, sharing them would be really useful!
An implementation of BatchNorm2D exists, but an implementation of BatchNorm1D does not.