import PyTorch weights for VGG11/13/16/19 and ResNet50/101/152 #2
Conversation
This looks good. Do you mind posting the validation numbers as a PR comment? You are right that saving |
Thank you for having a look! It is not too clear to me how we can adjust the `ResNet` constructor, e.g.

```julia
function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000)
    if pretrain
        return BSON.load(some_path)[:model]
    else
        # construct model
    end
end
```

(I don't mind if you want to leave this PR open until this discussion is settled.) Here are the top 5 results for this image for the networks in question: |
I was thinking that we could have something like

```julia
function weights(model)
    try
        path = joinpath(@artifact_str(model), "$model.bson")
        artifact = BSON.load(path, @__MODULE__)
        if haskey(artifact, :model)
            return artifact[:model]
        elseif haskey(artifact, :weights)
            return artifact[:weights]
        else
            throw(ArgumentError("No pre-trained weights available for $model."))
        end
    catch e
        throw(ArgumentError("No pre-trained weights available for $model."))
    end
end
```

We'd want |
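A minimal sketch of how a constructor could tie into such a `weights` helper, assuming Flux's `loadmodel!` (Flux ≥ 0.13); `resnet_impl` is a hypothetical builder name, not the repo's actual API:

```julia
using Flux

# Minimal sketch, assuming the weights(model) helper above.
# resnet_impl is a hypothetical stand-in for the actual architecture code.
function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000)
    model = resnet_impl(depth; nclasses)
    if pretrain
        # copy the stored parameters into the freshly constructed model
        Flux.loadmodel!(model, weights("resnet$depth"))
    end
    return model
end
```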
OK, this is indeed a good idea! |
Is there an issue in Flux tracking the new |
I think that this PR would solve the issue here: The issue is originally reported as: |
Okay let's think about merging this now that the Flux PR is merged. Are the VGG weights here with batch norm or not? |
I used the models in their "default" configuration: so no batchnorm for VGG, only for ResNet. |
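For reference, a small hedged sketch of how one could confirm that configuration on a loaded model, using only standard Flux utilities:

```julia
using Flux

# Returns true if any submodule is a BatchNorm layer; useful to verify that the
# converted VGG models carry no batchnorm while the ResNets do.
has_batchnorm(m) = any(l -> l isa BatchNorm, Flux.modules(m))
```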
I think this PR is pretty much ready then. Could you separate the testing code from the |
Great! I am currently out of office, but I should be able to work on this Tuesday or Wednesday next week. |
I am struggling to load the BSON files generated with Flux v0.12.8 / Metalhead v0.6.1 using the new Flux v0.13.0 / Metalhead v0.7.0 (while I can still load them with the old versions of the packages). This is the error I get:
It seems to work when I recreate the BSON files with the new version of Flux/Metalhead. |
Also adding ResNet18 and ResNet34.
Unfortunately, it seems that I used up all my LFS bandwidth; I have to wait one month for a new quota (or pay). I was able to push the code changes but not the BSON files. |
Okay, after discussing this LFS issue with other folks, it seems like our best option will be for someone (e.g. me) to make releases with the BSON files in a tarball, as opposed to right now, where they are additionally stored as LFS objects in the repo. I'll need to adjust the contributing guide to document this new procedure. For this particular set of weights, I'll just run your scripts to generate the files. Can you strip the PR of any changes to the weight files so that I can merge? |
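As a rough sketch of that release procedure, assuming the standard Pkg.Artifacts workflow; the file paths and URL below are placeholders, not the repo's actual release tooling:

```julia
using Pkg.Artifacts

# Pack a BSON file into an artifact directory (placeholder path).
hash = create_artifact() do dir
    cp("weights/resnet50.bson", joinpath(dir, "resnet50.bson"))
end

# Archive it as a tarball that can be attached to a GitHub release (placeholder name).
tarball_sha = archive_artifact(hash, "resnet50.tar.gz")

# Record the artifact and its download location in Artifacts.toml (placeholder URL).
bind_artifact!("Artifacts.toml", "resnet50", hash;
               download_info = [("https://example.com/releases/resnet50.tar.gz", tarball_sha)],
               lazy = true)
```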
OK, this PR now includes only the scripts. If it is easier for you, I have put the updated weights (temporarily) here: |
```julia
# loop over all parameters
for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(state, pytorch_pp)
    if size(flux_param) == size(pytorch_param)
```
This seems a bit risky. What if the input size is equal to the output size but a permutation is needed? Maybe we should have specific rules for each PyTorch layer.
Indeed, but the scope of this script (so far at least) is just the tested VGG and ResNet models, not arbitrary models.
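For context, a hedged sketch of what such per-layer rules might look like; the Conv permutation (and whether a spatial flip is also needed) is an assumption to be verified against the validation images, not the script's actual logic:

```julia
using Flux

# Hedged sketch: dispatch on the Flux layer type instead of relying on a size match.
# Assumes PyTorch Conv weights arrive as (out, in, kH, kW) while Flux expects
# (kW, kH, in, out); whether a spatial flip is also required should be checked
# against the top-5 predictions of the converted models.
function convert_param(flux_layer, pytorch_param::AbstractArray)
    if flux_layer isa Conv && ndims(pytorch_param) == 4
        return permutedims(pytorch_param, (4, 3, 2, 1))
    else
        # Dense weights/biases and BatchNorm vectors keep their layout
        return pytorch_param
    end
end
```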
Was there anything stopping this? |
The Git LFS issues, and a lack of bandwidth on my part for doing this manually. |
Merging this, as it is inert code that can be useful to people (I want to experiment a bit with it). |
Thank you for merging it! @CarloLucibello let me know how it goes! |
Could we open another issue to keep track of progress on the merging of the weights? |
You can follow the PR on Metalhead: FluxML/Metalhead.jl#164 |
As discussed, here are the weights for VGG11/13/16/19 and ResNet50/101/152 converted from PyTorch.
The models are saved as:
If I use this, I am wondering if the non-trainable parameters (in batchnorm) are saved. Let me know if this is OK.
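On the batchnorm question: a minimal sketch suggesting that the running statistics should survive a whole-model BSON round trip, since they are ordinary fields of the `BatchNorm` layer (a toy model is used here, not the actual save call from this PR):

```julia
using BSON, Flux

# Toy model with a BatchNorm layer; its running mean/variance (μ, σ²) are
# plain struct fields, so saving the whole model object also stores these
# non-trainable parameters.
model = Chain(Conv((3, 3), 3 => 16), BatchNorm(16))
BSON.@save "model.bson" model

loaded = BSON.load("model.bson", @__MODULE__)[:model]
@assert loaded[2].μ == model[2].μ    # running mean survives the round trip
```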