
import PyTorch weights for VGG11/13/16/19 and ResNet50/101/152 #2

Merged
merged 1 commit into FluxML:main on Jun 1, 2022

Conversation

Alexander-Barth
Contributor

As discussed, here are the weights for VGG11/13/16/19 and ResNet50/101/152 converted from PyTorch.

The models are saved as:

BSON.@save "mymodel.bson" model

If I save the model this way, I am wondering whether the non-trainable parameters (e.g. the running statistics in batch norm) are saved as well. Let me know if this is OK.
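
For comparison, here is a minimal sketch contrasting the two saving strategies (the tiny model is hypothetical, just for illustration):

using Flux, BSON

# A toy model with both trainable weights and non-trainable state:
model = Chain(Conv((3, 3), 3 => 8, relu), BatchNorm(8))

# Saving the whole struct keeps the trainable weights *and* the
# non-trainable state such as BatchNorm's running mean/variance:
BSON.@save "mymodel.bson" model

# Saving only the trainable parameters would drop that running state:
ps = collect(Flux.params(model))
BSON.@save "myweights.bson" ps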

@darsnack
Member

darsnack commented Feb 9, 2022

This looks good. Do you mind posting the validation numbers as a PR comment?

You are right that saving params(model) would exclude the non-trainable state. Saving the model struct itself will be okay, though. We can probably adjust this line to load either a key labeled "weights" or "model", depending on how the model was saved, and then copy over either the weights or the weights + state accordingly. This is probably something we should address more concretely in Flux itself.

@Alexander-Barth
Contributor Author

Thank you for having a look!

It is not too clear to me how we can adjust Metalhead.weights for this case. What about changing the model constructors like this:

function ResNet(depth::Int = 50; pretrain = false, nclasses = 1000)
    if pretrain
        return BSON.load(some_path)[:model]
    else
        # construct model
    end
end
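
(With this scheme, ResNet(50; pretrain = true) would simply return the deserialized model.)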

(I don't mind if you want to leave this PR open until this discussion is settled.)

Here are the top 5 results for this image for the networks in question:

vgg11
Flux:
acoustic guitar: 0.6216891
electric guitar: 0.22196834
stage: 0.12096002
banjo: 0.017143413
pick: 0.008767363
PyTorch:
acoustic guitar: 0.6216894
electric guitar: 0.22196802
stage: 0.12096031
banjo: 0.017143423
pick: 0.008767367
[ Info: saving model vgg11

vgg13
Flux:
electric guitar: 0.42012087
acoustic guitar: 0.35307145
stage: 0.19994916
banjo: 0.013066366
sax: 0.002884374
PyTorch:
electric guitar: 0.42012057
acoustic guitar: 0.35307187
stage: 0.19994901
banjo: 0.013066382
sax: 0.0028843747
[ Info: saving model vgg13

vgg16
Flux:
acoustic guitar: 0.52722526
stage: 0.28869343
electric guitar: 0.17026302
banjo: 0.005473086
pick: 0.0025982172
PyTorch:
acoustic guitar: 0.5272254
stage: 0.28869352
electric guitar: 0.17026275
banjo: 0.0054730875
pick: 0.0025982202
[ Info: saving model vgg16

vgg19
Flux:
acoustic guitar: 0.51687706
electric guitar: 0.2988484
stage: 0.12674169
pick: 0.021229278
banjo: 0.01773064
PyTorch:
acoustic guitar: 0.5168766
electric guitar: 0.29884866
stage: 0.12674181
pick: 0.021229278
banjo: 0.01773064
[ Info: saving model vgg19

resnet50
Flux:
acoustic guitar: 0.88006437
stage: 0.02751653
electric guitar: 0.02026861
pick: 0.01966484
banjo: 0.015501627
PyTorch:
acoustic guitar: 0.88006437
stage: 0.027516453
electric guitar: 0.020268533
pick: 0.019664822
banjo: 0.015501627
[ Info: saving model resnet50

resnet101
Flux:
acoustic guitar: 0.93985564
electric guitar: 0.0155377425
banjo: 0.0148489
stage: 0.013608559
pick: 0.0070351795
PyTorch:
acoustic guitar: 0.9398558
electric guitar: 0.015537745
banjo: 0.014848876
stage: 0.013608537
pick: 0.0070351744
[ Info: saving model resnet101

resnet152
Flux:
acoustic guitar: 0.9387497
banjo: 0.01910968
electric guitar: 0.017470127
stage: 0.011316595
pick: 0.008845404
PyTorch:
acoustic guitar: 0.9387497
banjo: 0.019109717
electric guitar: 0.017470127
stage: 0.011316595
pick: 0.008845378
[ Info: saving model resnet152

@darsnack
Member

I was thinking that weights would do something like

function weights(model)
  try
    # Resolve the artifact directory for this model and load its BSON file.
    path = joinpath(@artifact_str(model), "$model.bson")
    artifact = BSON.load(path, @__MODULE__)
    # Recent artifacts store the whole model struct; older ones only the weights.
    if haskey(artifact, :model)
      return artifact[:model]
    elseif haskey(artifact, :weights)
      return artifact[:weights]
    else
      throw(ArgumentError("No pre-trained weights available for $model."))
    end
  catch e
    # Missing artifact or failed load: report it the same way.
    throw(ArgumentError("No pre-trained weights available for $model."))
  end
end

We'd want Flux.loadparams!(m, x) to do what it currently does when x is a vector, but otherwise it would use Functors.jl to walk m and x structurally and copy over the weights/state. That way, instead of just returning whatever BSON.load spits out (which could be wrong), we actually make sure that the model we built and the model that was saved are structurally equivalent.
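
A minimal sketch of that structural copy (hypothetical loadweights! helper, assuming Functors.fmap's multi-tree form; the eventual Flux implementation may differ):

using Flux, Functors

# Vector input: keep the current loadparams! behavior.
loadweights!(m, xs::AbstractVector) = Flux.loadparams!(m, xs)

# Otherwise walk the built model `m` and the saved model `x` in lockstep,
# copying over both trainable weights and non-trainable state arrays.
# fmap errors out if the two structures differ, which is exactly the
# safety check described above.
function loadweights!(m, x)
  fmap(m, x) do p, q
    p isa AbstractArray && copyto!(p, q)
    p
  end
  return m
end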

@Alexander-Barth
Contributor Author

OK, this is indeed a good idea!

@adrhill

adrhill commented Mar 22, 2022

Is there an issue in Flux tracking the new Flux.loadparams! functionality that would be required?

@Alexander-Barth
Contributor Author

I think that this PR would solve the issue here:
FluxML/Flux.jl#1875

The issue was originally reported in:
FluxML/Flux.jl#1027

@darsnack
Member

darsnack commented Apr 6, 2022

Okay, let's think about merging this now that the Flux PR is merged. Do the VGG weights here include batch norm or not?

@Alexander-Barth
Contributor Author

I used the models in their "default" configuration: no batch norm for VGG, batch norm only in ResNet.

@darsnack
Member

darsnack commented Apr 6, 2022

I think this PR is pretty much ready then. Could you separate the testing code from the pytorch2flux.jl script? It would be nice to have a consistent script for evaluating all the models.
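
For reference, a shared evaluation helper could be as small as this sketch (hypothetical top5 function; the label vector and the preprocessed WHCN image tensor are left to the caller and are not names from the repo's scripts):

using Flux  # softmax comes from NNlib, re-exported by Flux

# Top-5 labels and probabilities for a single preprocessed image tensor.
function top5(model, img, labels)
  probs = softmax(vec(model(img)))
  idx = partialsortperm(probs, 1:5; rev = true)
  return [(labels[i], probs[i]) for i in idx]
end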

@Alexander-Barth
Contributor Author

Great! I am currently out of office, but I should be able to work on this Tuesday or Wednesday next week.

@Alexander-Barth
Contributor Author

I am struggling to load the BSON files generated with Flux v0.12.8 / Metalhead v0.6.1 using the new versions, Flux v0.13.0 and Metalhead v0.7.0 (while I can still load them with the old versions of the packages).

This is the error I get:

julia> saved_model = BSON.load(joinpath(weightsdir,"$(modelname).bson"))
ERROR: MethodError: Cannot `convert` an object of type Float64 to an object of type Vector{Any}
Closest candidates are:
  convert(::Type{Array{T, N}}, ::AxisArray{T, N}) where {T, N} at ~/.julia/packages/AxisArrays/FWWEV/src/core.jl:304
  convert(::Type{Array{T, N}}, ::FillArrays.Ones{V, N}) where {T, V, N} at ~/.julia/packages/FillArrays/5Arin/src/FillArrays.jl:440
  convert(::Type{Array{T, N}}, ::StaticArrays.SizedArray{S, T, N, N, Array{T, N}}) where {S, T, N} at ~/.julia/packages/StaticArrays/12k3X/src/SizedArray.jl:118
  ...
Stacktrace:
  [1] newstruct!(::IdDict{Any, Any}, ::Float64, ::Function, ::Bool)
    @ BSON ~/.julia/packages/BSON/rOaki/src/extensions.jl:107
  [2] newstruct_raw(cache::IdDict{Any, Any}, T::Type, d::Dict{Symbol, Any}, init::Module)
    @ BSON ~/.julia/packages/BSON/rOaki/src/extensions.jl:154
  [3] (::BSON.var"#49#50")(d::Dict{Symbol, Any}, cache::IdDict{Any, Any}, init::Module)
...

It seems to work when I recreate the BSON files with the new versions of Flux/Metalhead.
Do you know if the BSON files are specific to a Flux/Metalhead version?

@Alexander-Barth
Contributor Author

Alexander-Barth commented Apr 14, 2022

Also adding resnet18 and resnet34:

resnet18
  Flux:
    acoustic guitar: 0.6355375
    stage: 0.2761564
    electric guitar: 0.06121754
    pick: 0.01084846
    banjo: 0.0056079887
  PyTorch:
    acoustic guitar: 0.635536
    stage: 0.27615786
    electric guitar: 0.06121728
    pick: 0.010848445
    banjo: 0.0056079812

resnet34
  Flux:
    acoustic guitar: 0.31706628
    stage: 0.22608073
    microphone: 0.09899674
    electric guitar: 0.09198513
    pick: 0.05222302
  PyTorch:
    acoustic guitar: 0.31706595
    stage: 0.22608049
    microphone: 0.09899692
    electric guitar: 0.09198507
    pick: 0.052222993

Unfortunately, it seems that I have used up all my LFS bandwidth:

$ git push
batch response: This repository is over its data quota. Account responsible for LFS bandwidth should purchase more data packs to restore access.

It seems that I have to wait one month for a new quota (or pay):
https://stackoverflow.com/a/62905819/3801401

I was able to push the code changes but not the BSON files.

@darsnack
Member

darsnack commented Apr 14, 2022

Okay, after discussing this LFS issue with other folks, it seems like our best option will be for someone (e.g. me) to make releases with the BSON files in a tarball, as opposed to right now, where they are additionally stored as LFS objects in the repo.

I'll need to adjust the contributing guide to document this new procedure. For this particular set of weights, I'll just run your scripts to generate the files. Can you strip the PR of any changes to the weight files so that I can merge?

@Alexander-Barth
Contributor Author

OK, this PR now includes only the scripts.

If it is easier for you, I put the updated weights (temporarily) here:
https://data-assimilation.net/upload/Alex/MetalheadWeights_tmp/
(I will need to remove the files in a week or so because the web server is also close to full.)


# loop over all parameters
for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(state, pytorch_pp)
    if size(flux_param) == size(pytorch_param)
Member

This seems a bit risky. What if the input size is equal to the output size but a permutation is needed? Maybe we should have specific rules for each PyTorch layer.
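
For conv layers, such a rule might look like this sketch (assuming the usual layouts: PyTorch Conv2d weights are (out, in, kh, kw), Flux Conv weights are (kw, kh, in, out), and Flux's conv flips the kernel whereas PyTorch computes a cross-correlation):

# Permute to Flux's layout, then flip the spatial dimensions to account
# for convolution vs cross-correlation (a sketch, not the script's code):
pytorch_conv_to_flux(w) = permutedims(w, (4, 3, 2, 1))[end:-1:1, end:-1:1, :, :]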

Contributor Author

Indeed, but the scope of this script (so far at least) is just the tested VGG and ResNet models, not arbitrary models.

@CarloLucibello
Member

Was there anything stopping this?

@darsnack
Member

The Git LFS issues, and a lack of bandwidth on my part for doing this manually.

@CarloLucibello
Member

CarloLucibello commented Jun 1, 2022

Merging this as it is inert code that can be useful to people (I want to experiment a bit with it).

@CarloLucibello CarloLucibello merged commit ab10732 into FluxML:main Jun 1, 2022
@Alexander-Barth
Contributor Author

Thank you for merging it! @CarloLucibello, let me know how it goes!

@adrhill

adrhill commented Jun 1, 2022

Could we open another issue to keep track of progress on the merging of the weights?

@darsnack
Member

darsnack commented Jun 9, 2022

You can follow the PR on Metalhead: FluxML/Metalhead.jl#164
