
Training classifier on bottleneck #15

Open
tsaxena opened this issue Dec 28, 2020 · 6 comments

Comments

tsaxena commented Dec 28, 2020

@alexandre01 I am trying to train a classifier on the icons using the bottleneck embeddings that I get from inference with the pretrained model. In some cases, though, it doesn't seem to work. I used the code in your latent_ops notebook to encode each icon, and for some icons I get this error:
`The size of tensor a (10) must match the size of tensor b (8) at non-singleton dimension 0`
Am I missing something?

tsaxena (Author) commented Dec 28, 2020

I am using the following methods to encode icons, but they do not seem to work for some of the icon indices:

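```python
import torch

# `model`, `cfg`, `dataset`, `device` and `batchify` are defined in the
# latent_ops notebook setup.


def encode(data):
    # Gather the model inputs listed in the config and batch them on `device`
    model_args = batchify((data[key] for key in cfg.model_args), device)
    with torch.no_grad():
        # encode_mode=True returns the bottleneck embedding z
        z = model(*model_args, encode_mode=True)
    return z


def encode_icon(idx):
    # Load the icon without random augmentation
    data = dataset.get(id=idx, random_aug=False)
    return encode(data)
```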

alexandre01 (Owner) commented Dec 28, 2020

Hello @tsaxena!
Yes, unfortunately this happens because the pretrained model was trained with a maximum of 8 paths per SVG. Since the model uses an index embedding, one cannot perform inference on SVGs with more paths than that.
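To illustrate the point about the index embedding, here is a minimal pure-Python sketch (no PyTorch) of a fixed-size embedding table with one vector per path slot, added elementwise to per-path features. The names `MAX_NUM_GROUPS`, `pad_paths` and `add_index_embedding` are hypothetical, and this is not DeepSVG's actual implementation; it only shows why a table sized for 8 paths cannot absorb an icon with more.

```python
# Minimal pure-Python sketch (no PyTorch) of a fixed-size index embedding.
# All names here are hypothetical; this is not DeepSVG's implementation.

MAX_NUM_GROUPS = 8  # path slots the "pretrained" embedding table covers
EMBED_DIM = 4       # toy embedding width

# One learned vector per path index 0..7 (dummy values: row i is all i's).
index_embedding = [[float(i)] * EMBED_DIM for i in range(MAX_NUM_GROUPS)]


def pad_paths(path_features, max_num_groups=MAX_NUM_GROUPS):
    """Pad an icon's per-path features with zero vectors up to max_num_groups.

    Icons that already have more paths than max_num_groups are left unchanged.
    """
    missing = max(0, max_num_groups - len(path_features))
    return path_features + [[0.0] * EMBED_DIM for _ in range(missing)]


def add_index_embedding(path_features):
    """Add the index embedding to each path's features, elementwise."""
    if len(path_features) != len(index_embedding):
        raise ValueError(
            f"icon has {len(path_features)} path slots but the index "
            f"embedding only has {len(index_embedding)} rows"
        )
    return [
        [f + e for f, e in zip(feat, emb)]
        for feat, emb in zip(path_features, index_embedding)
    ]
```

An icon with 3 paths is padded to 8 rows and works; an icon with 10 paths cannot be padded down and raises, which likely mirrors the 10-versus-8 mismatch in the reported error. A larger table would contain rows the pretrained weights never learned, which is consistent with the advice to train a new model.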

You'd need to either train a model that supports more paths, or filter your classification dataset to icons with at most eight paths.
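A minimal sketch of the second option, filtering the classification set before encoding, assuming you can count the paths of each icon. The `icons` mapping and `filter_by_num_paths` are hypothetical; with DeepSVG you would take the path count from your own data or from the dataset's metadata.

```python
# Hypothetical sketch: split icons into those the pretrained model
# (trained with at most 8 paths per SVG) can encode and those it cannot.

MAX_NUM_GROUPS = 8  # maximum paths per SVG seen by the pretrained model


def filter_by_num_paths(icons, max_num_groups=MAX_NUM_GROUPS):
    """Split icon ids by path count.

    icons: dict mapping an icon id to its list of paths (any objects).
    Returns (kept_ids, skipped_ids), preserving the dict's order.
    """
    kept, skipped = [], []
    for icon_id, paths in icons.items():
        if len(paths) <= max_num_groups:
            kept.append(icon_id)
        else:
            skipped.append(icon_id)
    return kept, skipped


icons = {"star": ["p"] * 3, "map": ["p"] * 10, "cart": ["p"] * 8}
kept, skipped = filter_by_num_paths(icons)
print(kept, skipped)  # ['star', 'cart'] ['map']
```

Only the kept ids would then be passed to `encode_icon`, so every embedding in the classifier's training set comes from an icon within the model's 8-path limit.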

tsaxena (Author) commented Dec 28, 2020

Thanks for the prompt reply, @alexandre01. So does that mean you did not use all 100k icons to train the pretrained model?

alexandre01 (Owner) commented Dec 28, 2020

Yes. The dataset contains about 100k icons, but because of time constraints the pretrained model was trained only on a filtered subset, and I no longer have access to the GPU server I used.

tsaxena (Author) commented Dec 28, 2020

I eventually want to fine-tune the network on SVGs that have more than 8 paths. Would you suggest training from scratch? From what I understand, the maximum number of SVG paths is a configuration parameter that can be changed.

pwichmann commented

I am not @alexandre01, but I think what you say is correct.

The maximum number of paths is a config parameter and can be changed in `deepsvg/model/config.py`:

```python
self.max_num_groups = 8          # Number of paths (N_P)
```

There are also individual configs in `configs/deepsvg/` that override this variable, so you may need to increase the number of paths there as well.
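A minimal sketch of why both places matter, assuming the per-experiment configs subclass the default config and assign `max_num_groups` again in their own `__init__`. The class names below are hypothetical stand-ins, not DeepSVG's actual config classes, and 16 is just an example value.

```python
# Hypothetical stand-ins for a default config and a per-experiment config.


class DefaultModelConfig:
    def __init__(self, max_num_groups=8):
        self.max_num_groups = max_num_groups  # Number of paths (N_P)


class ExperimentConfig(DefaultModelConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Per-experiment override, applied after the default: it wins.
        self.max_num_groups = 8


class MorePathsConfig(ExperimentConfig):
    def __init__(self, max_num_groups=16):  # 16 is an example value
        super().__init__()
        self.max_num_groups = max_num_groups


# Raising only the default has no effect once the experiment overrides it.
print(ExperimentConfig(max_num_groups=16).max_num_groups)  # 8
print(MorePathsConfig().max_num_groups)                    # 16
```

Because the subclass assignment runs last, the value set in the experiment config is the one the model is built with; changing it changes the model's shape, which is why retraining is needed.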

You would need to retrain the model to handle a larger number of paths. If you have a powerful GPU, retraining does not take very long: on my RTX 3090, I can retrain from scratch within hours.

What is your use case?
