World age issues when loading a BSON file containing a model with Flux utility functions #1769

Closed
ajaiantilal opened this issue Nov 18, 2021 · 2 comments

@ajaiantilal

I get a `MethodError` when trying to load the following model (the error is raised on the `BSON.load` line). I think it's due to the reshape/unsqueeze.

Not sure if it's a Flux.jl issue, a BSON.jl issue, or an error on my part.
Is there a way to properly save/load models that contain reshape/unsqueeze functions?

```julia
using Flux
using BSON

Generator_Conv(input_dim::Int, hidden_dim::Int, latent_dim::Int) =
    Chain(
        Dense(input_dim, hidden_dim),
        BatchNorm(hidden_dim),
        # x -> Flux.reshape(x, hidden_dim, 1, :),  # either reshape or unsqueeze triggers it
        x -> Flux.unsqueeze(x, 2),
    )

function test()
    gen = Generator_Conv(10, 10, 10) |> cpu
    BSON.@save "test.bson" gen
    # gen = BSON.load(joinpath(@__DIR__, "test.bson"), @__MODULE__)[:gen]
    gen = BSON.load("test.bson")[:gen]  # the MethodError is raised here
    a = rand(10, 10)
    gen(a)
end

test()
```

Running on Julia 1.6, BSON v0.3.3, Flux v0.12.6

@ToucheSir
Member

This is a duplicate of JuliaIO/BSON.jl#69. The culprit is the closure (the anonymous function in the `Chain`), as the MWE for that issue shows.

If loading only at the top level is not an option, I'd recommend one of the following:

1. Create an external function that encapsulates the same logic (assuming you don't need a closure):

```julia
using BSON

add2(x) = x + 2

function test()
    f = add2
    BSON.@save "test.bson" f
    f = BSON.load("test.bson")[:f]
    f(1)
end
```
2. Use a callable struct instead of an anonymous function. Helpfully, Julia provides a few built-ins for this:

```julia
using BSON

function test()
    f = Base.Fix2(+, 2)  # x + 2; or Base.Fix1(+, 2) for 2 + x
    BSON.@save "test.bson" f
    f = BSON.load("test.bson")[:f]
    f(1)
end
```
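The same `Base.Fix2` approach can be applied directly to the layer from the original report: `Base.Fix2(f, 2)` is a callable struct that stores `f` and `2` in its fields `f` and `x` and, when called with `y`, evaluates `f(y, 2)`. With Flux v0.12 loaded, the replacement for `x -> Flux.unsqueeze(x, 2)` would presumably be `Base.Fix2(Flux.unsqueeze, 2)`. To keep this sketch runnable without Flux, `insert_singleton_dim` below is a hypothetical stand-in that mimics the positional `unsqueeze(xs, dim)` form via `reshape`.

```julia
# Hypothetical stand-in for Flux.unsqueeze(xs, dim) (Flux v0.12 positional form):
# inserts a dimension of size 1 at position `dim`.
insert_singleton_dim(xs::AbstractArray, dim::Integer) =
    reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

# A callable struct in place of `x -> insert_singleton_dim(x, 2)`.
# With Flux loaded, the analogous layer would be Base.Fix2(Flux.unsqueeze, 2).
layer = Base.Fix2(insert_singleton_dim, 2)

x = rand(Float32, 10, 4)   # features × batch
y = layer(x)
size(y)                    # (10, 1, 4)
layer.x                    # 2, the stored argument
```

Because the wrapped function and its argument live in ordinary struct fields rather than in a compiler-generated closure type, nothing anonymous has to be reconstructed when the model is loaded.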

@ajaiantilal
Author

Thanks, this was extremely helpful. In case someone else runs into the same issue, these are the two ways I did it for unsqueeze/reshape:

```julia
using Flux

# Method 1: works, though it can't capture additional arguments
layer_reshape_add_middle_dim(x) = reshape(x, size(x, 1), 1, size(x, 2))

# Method 2: a callable struct stores the extra argument(s) as fields and passes them on
struct layer_reshape_struct
    dim::Int
end
(m::layer_reshape_struct)(x) = Flux.unsqueeze(x, m.dim)

Generator_Conv(input_dim::Int, hidden_dim::Int, latent_dim::Int) =
    Chain(
        Dense(input_dim, hidden_dim),
        BatchNorm(hidden_dim),
        # layer_reshape_add_middle_dim,
        layer_reshape_struct(2),
    )
```
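Method 2 also extends to layers that need several captured values, such as the commented-out `x -> Flux.reshape(x, hidden_dim, 1, :)` from the original report. A minimal sketch, using a hypothetical `ReshapeTo` struct (not part of Flux) that stores the whole target shape as a tuple; a `Colon` entry lets `reshape` infer the batch dimension:

```julia
# Hypothetical callable struct replacing `x -> reshape(x, hidden_dim, 1, :)`.
# The target shape (which may contain `:`) is stored as a field instead of
# being captured by an anonymous function.
struct ReshapeTo{T<:Tuple}
    dims::T
end

(r::ReshapeTo)(x) = reshape(x, r.dims)

hidden_dim = 10
layer = ReshapeTo((hidden_dim, 1, :))  # would sit in the Chain after BatchNorm(hidden_dim)

x = rand(Float32, hidden_dim, 4)       # features × batch
y = layer(x)
size(y)                                # (10, 1, 4)
```

Parameterizing on the tuple type keeps the field concretely typed, so calling the layer stays type-stable.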
