Why do I get '2' as batch size? #168

Open

Flock1 opened this issue Mar 12, 2021 · 4 comments

Flock1 commented Mar 12, 2021

Hey,

This is a really great tool for visualizing a model. I was trying to see how the decoder in my VAE works; the input to the decoder is the latent space (dim = (2, 2)). However, when I get the output, I see an extra 2 in the shape, like this:
summary(decoder, (2,2))
Output is:

DECODER
torch.Size([2, 2, 2])

My decoder is initialized like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        c = capacity  # capacity, latent_dims, input_len, axis_transfer are defined elsewhere
        self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)
        self.adapt = nn.AdaptiveMaxPool1d(input_len)

    def forward(self, x):
        print("DECODER")
        print(x.shape)  # this is the print that shows torch.Size([2, 2, 2])
        x = self.fc(x)
        x = x.reshape(-1, x.shape[0], x.shape[1])
        x = self.adapt(x)
        x = x.view(x.size(0), capacity*2, axis_transfer, axis_transfer)  # unflatten the batch of feature vectors into a batch of multi-channel feature maps
        x = F.relu(self.conv2(x))
        x = torch.sigmoid(self.conv1(x))  # last layer is sigmoid, since we are using BCE as the reconstruction loss
        return x

Do let me know.

@cainmagi

torchsummary uses a tensor with batch size 2 to test the network and collect the information for each layer.
See the code here:

# batch_size of 2 for batchnorm

Even if you configure batch_size in the input arguments, that value is only used for calculating the flow sizes; the network is still tested with the batch size 2 tensor.
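
For reference, the relevant lines in the linked source look roughly like this (a sketch from memory, not an exact copy; dtype comes from torchsummary's own setup):

if isinstance(input_size, tuple):
    input_size = [input_size]

# batch_size of 2 for batchnorm
x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]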

This behavior may cause errors when the network requires the input batch to have a specific size. To fix the problem, I modified the code so that the test tensor uses batch_size whenever that value is not None; see
https://github.com/sksq96/pytorch-summary/pull/165/files#diff-ebda1cc7f304708e45ef4e19fb0484036eff8eb3c4b47a2886ca1cf0f731c0bbR118

Actually, it seems that the author has not maintained this package for a long time. I recommend trying an alternative such as torchinfo.
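
For example, a minimal torchinfo call might look like this (a sketch; unlike torchsummary, torchinfo's input_size includes the batch dimension, so the test batch size is explicit):

from torchinfo import summary

# batch size 1, latent input of shape (2, 2)
summary(decoder, input_size=(1, 2, 2))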


Flock1 commented Mar 17, 2021

Thanks a lot. I wanted to ask: why does it take 1 as the batch size when I input an image-like shape such as (3, 28, 28)? In that case, I don't see 2 as the batch size.

I will definitely check out torchinfo.


cainmagi commented Mar 17, 2021

> Thanks a lot. I wanted to ask: why does it take 1 as the batch size when I input an image-like shape such as (3, 28, 28)? In that case, I don't see 2 as the batch size.
>
> I will definitely check out torchinfo.

I do not understand your question; in your previous posts, you have not mentioned any batch size of 1.

By the way, I also do not understand

> I don't see 2 as the batch size.

because you have mentioned that your output is

DECODER
torch.Size([2, 2, 2])

Why do you say you do not see 2 as the batch size? The first element of the returned shape is clearly 2.


Here is a tip: if you are using

torchsummary.summary(..., input_size=...)

you should not let input_size be something like [3, 28, 28]; that would cause errors. Instead, use ((3, 28, 28),) or (3, 28, 28). The official implementation is quite unstable in some cases.
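
For instance, calls along these lines should be safe (a sketch, assuming a model taking 3×28×28 inputs):

from torchsummary import summary

summary(model, input_size=(3, 28, 28))     # a plain tuple works
summary(model, input_size=((3, 28, 28),))  # so does a tuple of tuples, for multi-input models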

@letsgo247

@cainmagi Thanks! torchinfo works!
