Add sparse initialization #1454
Conversation
```julia
end
rows, cols = dims
prop_zero = min(1.0, sparsity)
num_zeros = ceil(Integer, prop_zero * rows)
```
Use `÷` (`\div`) here.
I assume you mean something like `div(rows, 1/prop_zero)`? This returns a float since `prop_zero` is a float, so it would require further casting to an integer. I thought the above was a bit easier to follow, but I'm happy to go with whatever you think is best.
`÷` (typed as `\div<tab>`) should return an int.
I might be missing something. I'm finding `÷` to behave the same way as `div`, i.e. it's returning a float for float values of `prop_zero`.
```julia
julia> prop_zero = 0.11; rows = 50;

julia> ÷(rows, 1/prop_zero, RoundUp)
6.0
```
Using `÷` as an infix operator still returns a float, and it also doesn't allow specifying a `RoundingMode`. We need to round up to maintain consistency with PyTorch.
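For reference, a quick REPL check with the same example values showing the difference being discussed: the `ceil(Integer, ...)` form in the diff already rounds up and returns an `Int`, whereas `div`/`÷` with a float divisor returns a `Float64`.

```julia
julia> prop_zero = 0.11; rows = 50;

julia> ceil(Integer, prop_zero * rows)    # as in the diff: rounds up, returns an Int
6

julia> div(rows, 1/prop_zero, RoundUp)    # same value, but a Float64
6.0
```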
Thanks for looking into this! I've left a couple of thoughts in the implementation. We would need to use a different name though, since …
Current implementation does

```python
def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
```

We should follow them, swapping cols with rows, so …
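For illustration, here is a minimal Julia sketch of what a Flux-side version mirroring the PyTorch code could look like, built around the diff lines above. The name `sparse_init`, the `Float32` element type, and the per-column `shuffle` are assumptions for this sketch, not necessarily the PR's final code.

```julia
using Random: shuffle

# Sketch: fill with N(0, std) values, then zero out a fixed proportion of
# randomly chosen entries in each column (matching PyTorch's column-wise zeroing).
function sparse_init(dims...; sparsity, std = 0.01)
  length(dims) == 2 || throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
  rows, cols = dims
  prop_zero = min(1.0, sparsity)
  num_zeros = ceil(Integer, prop_zero * rows)
  sparse_array = randn(Float32, dims...) .* Float32(std)
  sparse_array[1:num_zeros, :] .= 0f0
  # Shuffle each column independently so the zeroed positions are random per column.
  return mapslices(shuffle, sparse_array; dims = 1)
end
```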
Sorry, now I see that you randomly permute with …
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
I'm not clear on why we need to swap rows and cols compared to PyTorch. I understand the batch dimension is different, but as far as I could tell Flux uses similar shapes for weights. E.g.

```python
In [1]: from torch import nn

In [2]: nn.Linear(1,2).weight.shape
Out[2]: torch.Size([2, 1])
```

```julia
julia> using Flux

julia> size(Dense(1,2).W)
(2, 1)
```
@atiyo you're right, I always thought that PyTorch applies the transform …

Another consideration is that maybe we could move the initialization functions to a submodule, but that doesn't necessarily have to be discussed here.
Let's not move it to a submodule; it doesn't seem worthwhile enough as a standalone to me.
bors r+
Build succeeded:
Add sparse initialization, documentation and tests. Trim whitespace in edited files.
This PR is intended to address one of the outstanding points in bringing Flux to parity with PyTorch's features, so it partially addresses #1431 and fully addresses #1450.
The implementation follows the method used in the PyTorch implementation: a normally-distributed array is created, then a fixed proportion of randomly chosen row indices is zeroed out for every column. Like the PyTorch version, it is restricted to 2-d arrays.
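For a rough sense of the intended behaviour, a hypothetical usage example, assuming the initializer ends up callable as `sparse_init(rows, cols; sparsity, std)` (one of the signatures discussed in the conversation above; the final name was still under discussion):

```julia
julia> W = sparse_init(5, 3; sparsity = 0.25);  # hypothetical final name/signature

julia> size(W)
(5, 3)

julia> count(iszero, W[:, 1])  # ceil(Integer, 0.25 * 5) == 2 zeros in each column
2
```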
PR Checklist
@dhairyagandhi96 (for API changes).