
test failed on v0.2 #11

Closed
yifita opened this issue Aug 21, 2017 · 18 comments

@yifita

yifita commented Aug 21, 2017

efficient_densenet_bottleneck_test.py failed in test_backward_computes_backward_pass:

>       assert(almost_equal(layer.conv.weight.grad.data, layer_efficient.conv_weight.grad.data))
E       assert False
E        +  where False = almost_equal(\n(0 ,0 ,.,.) = \n    0.3746\n\n(0 ,1 ,.,.) = \n   70.7402\n\n(0 ,2 ,.,.) = \n   68.3647\n\n(0 ,3 ,.,.) = \n    5.2501\n\n(0 ,4 ,.,...) = \n  101.7459\n\n(3 ,6 ,.,.) = \n   10.9038\n\n(3 ,7 ,.,.) = \n    0.0000\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n, \n(0 ,0 ,.,.) = \n  0.0000e+00\n\n(0 ,1 ,.,.) = \n -2.0594e+24\n\n(0 ,2 ,.,.) = \n -9.6653e+20\n\n(0 ,3 ,.,.) = \n  2.1138e+21\n\n(...-1.5375e+00\n\n(3 ,6 ,.,.) = \n -7.0127e-03\n\n(3 ,7 ,.,.) = \n  0.0000e+00\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n)
E        +    where \n(0 ,0 ,.,.) = \n    0.3746\n\n(0 ,1 ,.,.) = \n   70.7402\n\n(0 ,2 ,.,.) = \n   68.3647\n\n(0 ,3 ,.,.) = \n    5.2501\n\n(0 ,4 ,.,...) = \n  101.7459\n\n(3 ,6 ,.,.) = \n   10.9038\n\n(3 ,7 ,.,.) = \n    0.0000\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n = Variable containing:\n(0 ,0 ,.,.) = \n    0.3746\n\n(0 ,1 ,.,.) = \n   70.7402\n\n(0 ,2 ,.,.) = \n   68.3647\n\n(0 ,3 ,.,.) = \n ...) = \n  101.7459\n\n(3 ,6 ,.,.) = \n   10.9038\n\n(3 ,7 ,.,.) = \n    0.0000\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n.data
E        +      where Variable containing:\n(0 ,0 ,.,.) = \n    0.3746\n\n(0 ,1 ,.,.) = \n   70.7402\n\n(0 ,2 ,.,.) = \n   68.3647\n\n(0 ,3 ,.,.) = \n ...) = \n  101.7459\n\n(3 ,6 ,.,.) = \n   10.9038\n\n(3 ,7 ,.,.) = \n    0.0000\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n = Parameter containing:\n(0 ,0 ,.,.) = \n  0.0978\n\n(0 ,1 ,.,.) = \n  1.9624\n\n(0 ,2 ,.,.) = \n  2.4802\n\n(0 ,3 ,.,.) = \n  1.06...5 ,.,.) = \n  0.4832\n\n(3 ,6 ,.,.) = \n  1.0052\n\n(3 ,7 ,.,.) = \n  1.7624\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n.grad
E        +        where Parameter containing:\n(0 ,0 ,.,.) = \n  0.0978\n\n(0 ,1 ,.,.) = \n  1.9624\n\n(0 ,2 ,.,.) = \n  2.4802\n\n(0 ,3 ,.,.) = \n  1.06...5 ,.,.) = \n  0.4832\n\n(3 ,6 ,.,.) = \n  1.0052\n\n(3 ,7 ,.,.) = \n  1.7624\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n = Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), bias=False).weight
E        +          where Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), bias=False) = Sequential (\n  (norm): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True)\n  (relu): ReLU (inplace)\n  (conv): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)\n).conv
E        +    and   \n(0 ,0 ,.,.) = \n  0.0000e+00\n\n(0 ,1 ,.,.) = \n -2.0594e+24\n\n(0 ,2 ,.,.) = \n -9.6653e+20\n\n(0 ,3 ,.,.) = \n  2.1138e+21\n\n(...-1.5375e+00\n\n(3 ,6 ,.,.) = \n -7.0127e-03\n\n(3 ,7 ,.,.) = \n  0.0000e+00\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n = Variable containing:\n(0 ,0 ,.,.) = \n  0.0000e+00\n\n(0 ,1 ,.,.) = \n -2.0594e+24\n\n(0 ,2 ,.,.) = \n -9.6653e+20\n\n(0 ,3 ,.,....-1.5375e+00\n\n(3 ,6 ,.,.) = \n -7.0127e-03\n\n(3 ,7 ,.,.) = \n  0.0000e+00\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n.data
E        +      where Variable containing:\n(0 ,0 ,.,.) = \n  0.0000e+00\n\n(0 ,1 ,.,.) = \n -2.0594e+24\n\n(0 ,2 ,.,.) = \n -9.6653e+20\n\n(0 ,3 ,.,....-1.5375e+00\n\n(3 ,6 ,.,.) = \n -7.0127e-03\n\n(3 ,7 ,.,.) = \n  0.0000e+00\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n = Parameter containing:\n(0 ,0 ,.,.) = \n  0.0978\n\n(0 ,1 ,.,.) = \n  1.9624\n\n(0 ,2 ,.,.) = \n  2.4802\n\n(0 ,3 ,.,.) = \n  1.06...5 ,.,.) = \n  0.4832\n\n(3 ,6 ,.,.) = \n  1.0052\n\n(3 ,7 ,.,.) = \n  1.7624\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n.grad
E        +        where Parameter containing:\n(0 ,0 ,.,.) = \n  0.0978\n\n(0 ,1 ,.,.) = \n  1.9624\n\n(0 ,2 ,.,.) = \n  2.4802\n\n(0 ,3 ,.,.) = \n  1.06...5 ,.,.) = \n  0.4832\n\n(3 ,6 ,.,.) = \n  1.0052\n\n(3 ,7 ,.,.) = \n  1.7624\n[torch.cuda.FloatTensor of size 4x8x1x1 (GPU 0)]\n = _EfficientDensenetBottleneck (\n).conv_weight

I uncommented this line in densenet_efficient.py:

self.efficient_batch_norm.training = False

but the issue persists.
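
For reference, the test's almost_equal helper presumably does an element-wise comparison under a tolerance; a minimal sketch of such a check (the exact tolerance and form are assumptions, not necessarily what the repo uses):

import torch

def almost_equal(a, b, eps=1e-5):
    # Largest element-wise absolute difference must stay below eps;
    # the real helper may use a different eps or a relative tolerance.
    return torch.max(torch.abs(a - b)) < eps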

@yifita
Author

yifita commented Aug 21, 2017

Hi, I hope you could have a look at this; it seems like a pretty major problem.
So far, keeping L269 as-is produces at least the same relu_output, but the backward convolution is always wrong. Could it be that the _cudnn_info from L442 stores a pointer to the data from the forward pass, which is then overwritten during the recomputation?

def forward(self, weight, bias, input):
    # cuDNN state from the forward pass is saved on self...
    self._cudnn_info = torch._C._cudnn_convolution_full_forward(...)

def backward(self, weight, bias, input, grad_output):
    # ...and reused here during backward.
    torch._C._cudnn_convolution_backward_filter(grad_output, input, grad_weight,
                                                self._cudnn_info, cudnn.benchmark)
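
One way to probe this hypothesis (a sketch, not the repo's test; it assumes PyTorch >= 0.4 and uses torch.utils.checkpoint as a stand-in for the repo's shared-storage recomputation) is to compare the conv weight gradient from a plain backward against one where the forward is recomputed:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint  # stand-in for the recomputation scheme

torch.manual_seed(0)
# Same structure as the failing layer: BN -> ReLU -> 1x1 conv, 8 -> 4 channels.
layer = nn.Sequential(
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 4, kernel_size=1, bias=False),
).cuda()
x = torch.randn(4, 8, 16, 16, device='cuda', requires_grad=True)

# Plain forward/backward.
layer(x).sum().backward()
grad_plain = layer[2].weight.grad.clone()
layer.zero_grad()

# Backward with recomputation: the forward is re-run during backward,
# which is where stale saved cuDNN state could bite.
checkpoint(layer, x).sum().backward()
grad_recomputed = layer[2].weight.grad.clone()

print((grad_plain - grad_recomputed).abs().max())  # should be ~0

If the recomputed gradient diverges the way the failing assertion shows, that would point at state saved during the original forward being invalid by the time backward runs.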

@taineleau-zz
Collaborator

Hi, thanks for your info! I will look at it ASAP.

@gpleiss
Owner

gpleiss commented Aug 21, 2017

@yifita what version of PyTorch are you using?

@yifita
Author

yifita commented Aug 21, 2017

@gpleiss I'm using cuda 8.0, cudnn 5.1.5, python 3.6, pytorch 0.1.12

@gpleiss
Owner

gpleiss commented Aug 24, 2017

I ran this on cuda 8.0, cudnn 5.1.5, python 3.6, and pytorch 0.1.12_2, and all the tests pass for me...

@yifita
Author

yifita commented Aug 24, 2017

Oh, I'm sorry, I actually have PyTorch 0.2.0! I wonder how I got the previous version mixed up.
Indeed, after I downgraded my PyTorch to 0.1.12, the tests pass. I wonder what has changed?
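
For anyone hitting this, it's quick to confirm the installed versions (a trivial sketch; torch.backends.cudnn.version() may not be available in every release):

import torch
print(torch.__version__)               # e.g. '0.1.12_2' vs. '0.2.0'
print(torch.backends.cudnn.version())  # cuDNN build, e.g. 5105 for 5.1.x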

@taineleau-zz
Collaborator

@yifita I guess you'd like to have a look at the release notes: https://github.com/pytorch/pytorch/releases/tag/v0.2.0.

Basically, the APIs of some fundamental operations have changed.

@yifita
Author

yifita commented Aug 24, 2017

I had a look; it seems they focused on broadcasting and indexing while adding a few layers that I want to experiment with in my model, so I'd prefer to stick with 0.2.0.
Will this repository be catching up to 0.2.0?

@taineleau-zz
Collaborator

taineleau-zz commented Aug 24, 2017

I think you could do this yourself. Adding these lines and cleaning up all the warnings should make this code work under v0.2:

# insert this at the top of your scripts (usually main.py)
import sys, warnings, traceback, torch

def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    sys.stderr.write(warnings.formatwarning(message, category, filename, lineno, line))
    traceback.print_stack(sys._getframe(2))

warnings.showwarning = warn_with_traceback
warnings.simplefilter('always', UserWarning)
torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True

See the "Important Breakages and Workarounds" section in the release notes.

@taineleau-zz changed the title from "test failed" to "test failed on v0.2" on Aug 24, 2017
@gpleiss
Owner

gpleiss commented Aug 24, 2017

We will be catching this repo up soon! I'll try to get to it later today.

@yifita
Author

yifita commented Aug 24, 2017

@taineleau it doesn't seem to be that straightforward. The only warning I had was that nn.Container is deprecated, but changing it to nn.Module didn't solve the issue.

@taineleau-zz
Collaborator

@yifita hmmm... I've come to realize that they might have changed the backend a little bit...

@yifita
Author

yifita commented Aug 24, 2017

Yep :/ My guess is that it's related to double backpropagation.
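
If double backprop is the suspect, one quick check (a sketch, assuming a PyTorch version where conv supports double backward, which landed around v0.2) is to ask autograd for a grad-of-grad through a minimal convolution:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 8, 8, device='cuda', requires_grad=True)
w = torch.randn(4, 8, 1, 1, device='cuda', requires_grad=True)

y = F.conv2d(x, w)
# First-order gradient, keeping the graph so we can differentiate again.
(gx,) = torch.autograd.grad(y.sum(), x, create_graph=True)
# Second-order (double backward) through the convolution.
(gw,) = torch.autograd.grad(gx.sum(), w)
print(gw.shape)  # torch.Size([4, 8, 1, 1])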

@taineleau-zz
Collaborator

Sorry for the late response.
@yifita: Thanks for your feedback!

After looking at v0.2's backend, I guess convnd is broken mainly because they refactored the API a little bit. They have removed the Python-level code that calls conv.backward from the code base, so it's a little hard to learn the correct API from the C++ files. I have sent them an email to ask about the change, and hopefully we can get some hints from them.

@taineleau-zz
Collaborator

Update:

I sent an email to a PyTorch developer, and he said the API of torch._C._cudnn_convolution_ was not changed, so I am sorry that I cannot figure it out.

However, the good news is that I have run the models on PyTorch v0.2, and the performance (both final accuracy and speed) remains the same as on v0.1.12.

So I guess this issue shouldn't be a big concern for now.

@mingminzhen

For the torch._C._cudnn_convolution issue, I tried changing the call to:

torch._C._cudnn_convolution_full_forward(
    input, weight, bias, res,
    (self.padding, self.padding),
    (self.stride, self.stride),
    (self.dilation, self.dilation),
    self.groups, cudnn.benchmark, False
)

Additional "False" parameter is added. It works in the master version of PyTorch (0.4).
But I'm not sure about the meaning of the parameter.

@taineleau-zz
Collaborator

@mingminzhen This flag controls whether to use a deterministic convolution algorithm. That is not the issue here, though. Please check this topic.
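
As a side note, later PyTorch releases expose this determinism choice through global cuDNN flags rather than a positional argument:

import torch
# Prefer deterministic cuDNN convolution algorithms over benchmarked ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False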

@gpleiss
Owner

gpleiss commented Mar 13, 2018

With #28, this repo is now compatible with PyTorch 0.3.x. Closing this issue.

@gpleiss closed this as completed Mar 13, 2018