Error exporting a model with Flux.GlobalMeanPool() operation #87

Closed
lambe opened this issue Oct 26, 2023 · 6 comments

lambe commented Oct 26, 2023

Here's a simple script to test exporting and importing a model:

using Flux
using ONNXNaiveNASflux

test_model = Chain(
    Conv((1, 1), 10 => 10),
    GlobalMeanPool(),
    Flux.MLUtils.flatten,
    Dense(10, 10, relu),
    Dense(10, 10, relu),
)

test_tensor = rand(Float32, 5, 5, 10, 4)
out_tensor = test_model(test_tensor)
println("out_tensor size: ", size(out_tensor))  # Expected output size (10, 4)

# Export the model
ONNXNaiveNASflux.save("test.onnx", test_model)

# Import the model
new_model = ONNXNaiveNASflux.load("test.onnx")
@assert new_model == test_model

Error returned:

ERROR: LoadError: MethodError: no method matching size(::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})

Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:582
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:581
  size(::Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRPackedQ})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.3+0.x64.linux.gnu/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:585
  ...

Stacktrace:
  [1] (::GlobalMeanPool)(x::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/conv.jl:631
  [2] macro expansion
    @ ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53 [inlined]
  [3] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}, x::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:53
  [4] (::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}})(x::ONNXNaiveNASflux.ProtoProbe{String, ONNXNaiveNASflux.var"#306#307"{typeof(ONNXNaiveNASflux.genname), Set{String}}, ONNXNaiveNASflux.BaseOnnx.GraphProto, Tuple{Missing, Missing, Int64, Missing}})
    @ Flux ~/.julia/packages/Flux/uCLgc/src/layers/basic.jl:51
  [5] graphproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}, indata::Pair{String, Tuple{Missing, Missing, Int64, Missing}}; namestrat::Function)
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:200
  [6] modelproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}, indata::Pair{String, Tuple{Missing, Missing, Int64, Missing}}; modelname::String, namestrat::Function, posthook::typeof(ONNXNaiveNASflux.validate), kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:45
  [7] modelproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}, inshapes::Tuple{Missing, Missing, Int64, Missing}; kwargs::Base.Pairs{Symbol, String, Tuple{Symbol}, NamedTuple{(:modelname,), Tuple{String}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:42
  [8] modelproto(f::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}}; kwargs::Base.Pairs{Symbol, String, Tuple{Symbol}, NamedTuple{(:modelname,), Tuple{String}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:41
  [9] modelproto
    @ ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:41 [inlined]
 [10] #save#308
    @ ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:10 [inlined]
 [11] save(::String, ::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, GlobalMeanPool, typeof(MLUtils.flatten), Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}})
    @ ONNXNaiveNASflux ~/.julia/packages/ONNXNaiveNASflux/EEDqS/src/serialize/serialize.jl:10
 [12] top-level scope
    @ ~/toolpath/AlphaToolpath/examples/globalmeanpool_test.jl:19
 [13] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [14] top-level scope
    @ REPL[1]:1
 [15] top-level scope
    @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
in expression starting at /home/ablambe/toolpath/AlphaToolpath/examples/globalmeanpool_test.jl:19

GlobalAveragePool is listed as a supported operation in the docs, but it doesn't seem to work for me. Is it a name change issue?

DrChainsaw (Owner) commented

I had to check the code myself, but it seems that serialization of Flux's built-in global pools is not implemented. IIRC, the deserialization implementation, as well as the statement that it is supported, dates from before Flux had its own global pools, when users were told to build their own (e.g. from MaxPool/MeanPool).

It should be pretty easy to add, though. Just add the methods somewhere here and directly call the functions below. Be mindful of any difference between what Flux does and what ONNX specifies the global pools to do (described somewhere here), for example w.r.t. dropping the spatial dimensions.

The second argument to globalmeanpool and globalmaxpool can be used to account for such a difference (identity should work if there is none).
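
For concreteness, a minimal sketch of what such methods could look like. The probe type name (AbstractProbe) and using identity as the wrapper are assumptions based on this comment, not the merged implementation:

import Flux
import ONNXNaiveNASflux: AbstractProbe, globalmeanpool, globalmaxpool

# Route Flux's built-in global pools to the existing ONNX serialization
# helpers. identity is assumed to be a valid second argument on the premise
# that Flux and ONNX agree on keeping the size-1 spatial dimensions;
# verify this against the ONNX operator spec before relying on it.
(l::Flux.GlobalMeanPool)(pp::AbstractProbe) = globalmeanpool(pp, identity)
(l::Flux.GlobalMaxPool)(pp::AbstractProbe) = globalmaxpool(pp, identity)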

lambe commented Oct 27, 2023

Ok, I've added the appropriate methods, and the script in the description can now save and load the model. See #88.

There's still a warning thrown:

┌ Warning: No valid input sizes provided. Shape inference could not be done. Either provide Integer insizes manually or use load(...; infer_shapes=false) to disable. If disabled, graph mutation might not work.
└ @ ONNXNaiveNASflux ~/toolpath/ONNXNaiveNASflux.jl/src/deserialize/infershape.jl:47

and the @assert statement fails, so I'm going to see if updating deserialize/ops.jl is needed.


lambe commented Oct 27, 2023

Update: the @assert statement is expected to fail, since test_model is a Flux model type but new_model is an ONNX-style computational graph.
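
So the meaningful check is that the two models compute the same outputs, not that the objects are equal; a sketch, assuming the loaded graph is callable on an array like a Flux model:

# Compare outputs instead of model objects (sketch; assumes new_model is callable)
@assert new_model(test_tensor) ≈ out_tensor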

Some good news: running test_tensor through the exported model does give an approximately identical answer, here verified via ONNXRunTime. Add the following code block to the test script in the description:

import ONNXRunTime as ORT

function run_ort(nn_ort, x)
    # ONNX models expect row-major NCHW layout, so reverse Julia's WHCN dims
    x_t = permutedims(x, (4, 3, 2, 1))
    inp = Dict(only(nn_ort.input_names) => x_t)
    out = only(values(nn_ort(inp)))
    # Reverse the output dims back for comparison with the Flux output
    permutedims(out, reverse(1:ndims(out)))
end

ort_new = ORT.load_inference("test.onnx")
new_out_tensor = run_ort(ort_new, test_tensor)
@assert isapprox(new_out_tensor, out_tensor)

I'm happy with this solution, but I'm interested to know whether that warning can be addressed before merging the PR.


DrChainsaw commented Oct 27, 2023

I took a look at the warning, and it is correct to warn here, so no action is needed.

The reason for the warning is this:

  1. The height and width of the input can't be inferred from the first layer type alone when saving the model (e.g. Conv((1, 1), 10 => 10) does not need the first two dimensions to have any particular size). The number of channels is correctly inferred, though.
  2. When loading, ONNXNaiveNASflux uses Flux.outputsize to infer all input sizes when they are not given as input. This function throws an exception if the size of any dimension is missing or 0. It will also throw if sizes don't line up (e.g. if someone does flatten or reshape without global pooling first), so we can't just guess a size. Instead, it is checked that all sizes are > 0; if not, the attempt at shape inference is abandoned and the warning is printed (see the small illustration after this list).
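
As a small illustration of point 2, using plain Flux on the model from the description (a sketch, not the package internals):

# With all input dimensions known, shape inference succeeds:
Flux.outputsize(test_model, (5, 5, 10, 4))  # -> (10, 4)
# With the height/width unknown there is no safe guess to make, so
# ONNXNaiveNASflux abandons inference and prints the warning instead.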

To avoid the warning, you can just supply the input sizes when loading the model:

ONNXNaiveNASflux.load("test.onnx", size(test_tensor))

or when saving:

ONNXNaiveNASflux.save("test.onnx", test_model, size(test_tensor))

The input sizes are only used by the NaiveNASlib features for parameter pruning and other NAS-like things, which is why the model seems to work just fine for inference (and should work fine for training too).

There are only a few op types which need the size info, so chances are the NaiveNASlib stuff would work even if input sizes are not inferred. However, since it would be quite difficult to understand the error you'd get if you tried to change the dimension of some parameter while this info is missing, I decided to always print the warning by default.


lambe commented Oct 31, 2023

Great, thanks for the context! I'll update my calling code with a size parameter.

DrChainsaw (Owner) commented

Fixed in #88
