diff --git a/pytorch2flux.jl b/pytorch2flux.jl
new file mode 100644
index 0000000..50fbcbf
--- /dev/null
+++ b/pytorch2flux.jl
@@ -0,0 +1,99 @@
+# Converts the weights of a pretrained PyTorch model to a Flux model from Metalhead
+
+# PyTorch needs to be installed (it is accessed via PyCall)
+
+# Tested on ResNet and VGG models
+
+using Flux
+import Metalhead
+using DataStructures
+using Statistics
+using BSON
+using PyCall
+using Images
+using Test
+
+torchvision = pyimport("torchvision")
+torch = pyimport("torch")
+
+modellib = [
+    ("vgg11",     () -> Metalhead.VGG(11),     torchvision.models.vgg11),
+    ("vgg13",     () -> Metalhead.VGG(13),     torchvision.models.vgg13),
+    ("vgg16",     () -> Metalhead.VGG(16),     torchvision.models.vgg16),
+    ("vgg19",     () -> Metalhead.VGG(19),     torchvision.models.vgg19),
+    ("resnet18",  () -> Metalhead.ResNet(18),  torchvision.models.resnet18),
+    ("resnet34",  () -> Metalhead.ResNet(34),  torchvision.models.resnet34),
+    ("resnet50",  () -> Metalhead.ResNet(50),  torchvision.models.resnet50),
+    ("resnet101", () -> Metalhead.ResNet(101), torchvision.models.resnet101),
+    ("resnet152", () -> Metalhead.ResNet(152), torchvision.models.resnet152),
+]
+
+
+function _list_state(node::Flux.BatchNorm,channel,prefix)
+    # emit the parameters in the same order as PyTorch's state_dict
+    put!(channel, (prefix * ".γ", node.γ))   # weight (learnable scale)
+    put!(channel, (prefix * ".β", node.β))   # bias (learnable shift)
+    put!(channel, (prefix * ".μ", node.μ))   # running mean
+    put!(channel, (prefix * ".σ²", node.σ²)) # running variance
+end
+
+function _list_state(node::Union{Flux.Conv,Flux.Dense},channel,prefix)
+    put!(channel, (prefix * ".weight", node.weight))
+
+    if node.bias isa AbstractVector # skip layers created with bias=false
+        put!(channel, (prefix * ".bias", node.bias))
+    end
+end
+
+_list_state(node,channel,prefix) = nothing # layers without parameters (activations, pooling, ...)
+
+function _list_state(node::Union{Flux.Chain,Flux.Parallel},channel,prefix)
+    for (i,n) in enumerate(node.layers)
+        _list_state(n,channel,prefix * ".layers[$i]")
+    end
+end
+
+function list_state(node; prefix = "model")
+    Channel() do channel
+        _list_state(node,channel,prefix)
+    end
+end
+
+for (modelname,jlmodel,pymodel) in modellib
+
+    model = jlmodel()
+    pytorchmodel = pymodel(pretrained=true)
+
+    state = OrderedDict(list_state(model.layers))
+
+    # keep state_dict() as a PyObject: converting it to a Julia Dict would lose the parameter order
+    state_dict = OrderedDict(pycall(pytorchmodel.state_dict,PyObject).items())
+    pytorch_pp = OrderedDict((k,v.numpy()) for (k,v) in state_dict if !occursin("num_batches_tracked",k))
+
+
+    # loop over all parameters and copy them into the Flux model
+    for ((flux_key,flux_param),(pytorch_key,pytorch_param)) in zip(state,pytorch_pp)
+        if size(flux_param) == size(pytorch_param)
+            # same shape: dense-layer weights and 1-D parameter vectors are copied directly
+            flux_param .= pytorch_param
+        elseif size(flux_param) == reverse(size(pytorch_param))
+            tmp = pytorch_param
+            tmp = permutedims(tmp,ndims(tmp):-1:1) # reverse the dimension order, e.g. (out,in,kH,kW) -> (kW,kH,in,out)
+
+            if ndims(flux_param) == 4
+                # convolutional kernels: PyTorch computes a cross-correlation, Flux a true convolution, so flip the spatial dimensions
+                flux_param .= reverse(tmp,dims=(1,2))
+            else
+                flux_param .= tmp
+            end
+        else
+            @debug begin
+                @show size(flux_param), size(pytorch_param)
+            end
+            error("incompatible shape $flux_key $pytorch_key")
+        end
+    end
+
+    @info "saving model $modelname"
+    BSON.@save joinpath(@__DIR__,"weights","$(modelname).bson") model # the weights/ directory must already exist
+end
diff --git a/test/compare_pytorch.jl b/test/compare_pytorch.jl
new file mode 100644
index 0000000..76e3de1
--- /dev/null
+++ b/test/compare_pytorch.jl
@@ -0,0 +1,96 @@
+# Compare a Flux model from Metalhead to the corresponding PyTorch model
+# on a sample image
+
+# PyTorch needs to be installed (it is accessed via PyCall)
+
+# Tested on ResNet and VGG models
+
+using Flux
+import Metalhead
+using DataStructures
+using Statistics
+using BSON
+using PyCall
+using Images
+using Test
+
+using MLUtils
+using Random
+
+torchvision = pyimport("torchvision")
+torch = pyimport("torch")
+
+modellib = [
+    ("vgg11",     () -> Metalhead.VGG(11),     torchvision.models.vgg11),
+    ("vgg13",     () -> Metalhead.VGG(13),     torchvision.models.vgg13),
+    ("vgg16",     () -> Metalhead.VGG(16),     torchvision.models.vgg16),
+    ("vgg19",     () -> Metalhead.VGG(19),     torchvision.models.vgg19),
+    ("resnet18",  () -> Metalhead.ResNet(18),  torchvision.models.resnet18),
+    ("resnet34",  () -> Metalhead.ResNet(34),  torchvision.models.resnet34),
+    ("resnet50",  () -> Metalhead.ResNet(50),  torchvision.models.resnet50),
+    ("resnet101", () -> Metalhead.ResNet(101), torchvision.models.resnet101),
+    ("resnet152", () -> Metalhead.ResNet(152), torchvision.models.resnet152),
+]
+
+
+tr(tmp) = permutedims(tmp,ndims(tmp):-1:1) # reverse the dimension order (WHCN -> NCHW)
+
+
+function normalize(data) # ImageNet mean/std normalization
+    cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1))
+    cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1))
+    return (data .- cmean) ./ cstd
+end
+
+# test image
+guitar_path = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")
+
+# ImageNet class labels
+labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
+
+weightsdir = joinpath(@__DIR__,"..","weights")
+
+for (modelname,jlmodel,pymodel) in modellib
+    println(modelname)
+
+    model = jlmodel()
+
+    saved_model = BSON.load(joinpath(weightsdir,"$(modelname).bson"))
+    Flux.loadmodel!(model,saved_model[:model])
+
+    pytorchmodel = pymodel(pretrained=true)
+
+    Flux.testmode!(model)
+
+    sz = (224, 224)
+    img = Images.load(guitar_path);
+    img = imresize(img, sz);
+    # CHW -> WHC
+    data = permutedims(convert(Array{Float32}, channelview(img)), (3,2,1))
+    data = normalize(data[:,:,:,1:1]) # add a singleton batch dimension (WHCN)
+
+    out = model(data) |> softmax;
+    out = out[:,1]
+
+    println("  Flux:")
+
+    for i in sortperm(out,rev=true)[1:5]
+        println("    $(labels[i]): $(out[i])")
+    end
+
+
+    pytorchmodel.eval()
+    output = pytorchmodel(torch.Tensor(tr(data))); # WHCN -> NCHW
+    probabilities = torch.nn.functional.softmax(output[0], dim=0).detach().numpy();
+
+    println("  PyTorch:")
+
+    for i in sortperm(probabilities[:,1],rev=true)[1:5]
+        println("    $(labels[i]): $(probabilities[i])")
+    end
+
+    @test maximum(out) ≈ maximum(probabilities)
+    @test argmax(out) == argmax(probabilities)
+
+    println()
+end
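
A note on the kernel flip in pytorch2flux.jl: PyTorch's Conv2d computes a cross-correlation, while Flux's Conv applies a true convolution, which is why the converter both reverses the dimension order of convolutional kernels and flips their spatial dimensions with reverse(tmp, dims=(1,2)). A minimal sketch (not part of the patch; it assumes Flux's CrossCor layer and the array-based layer constructors) illustrating the equivalence:

    using Flux

    w = randn(Float32, 3, 3, 2, 4)   # Flux kernel layout: (kW, kH, in, out)
    x = randn(Float32, 8, 8, 2, 1)   # WHCN input

    crosscor = Flux.CrossCor(w, false)                   # what PyTorch's Conv2d computes
    conv     = Flux.Conv(reverse(w, dims=(1,2)), false)  # flipped kernel + true convolution

    @assert crosscor(x) ≈ conv(x)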
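
Once the converter has been run, the saved weights can be used without PyTorch. A minimal usage sketch (assuming the converter above has already produced weights/resnet18.bson), mirroring what test/compare_pytorch.jl does:

    using Flux, BSON
    import Metalhead

    model = Metalhead.ResNet(18)            # the architecture must match the saved weights
    saved = BSON.load(joinpath("weights", "resnet18.bson"))
    Flux.loadmodel!(model, saved[:model])   # copy the converted parameters in place
    Flux.testmode!(model)                   # use the stored BatchNorm statistics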