Constructing a Chain from a dictionary #2142

Closed
shivance opened this issue Dec 27, 2022 · 9 comments

@shivance

shivance commented Dec 27, 2022

Describe the potential feature

PyTorch allows users to create networks from dictionaries: nn.Sequential accepts an OrderedDict and maps each key (essentially the layer name) to the layer itself. See:

https://github.com/mateuszbuda/brain-segmentation-pytorch/blob/d45f8908ab2f0246ba204c702a6161c9eb25f902/unet.py#L68

Flux currently supports similar functionality through keyword arguments (stored as a NamedTuple):

julia> m = Chain(a=Dense(3 => 6), b = Dense(6=>9))
Chain(
  a = Dense(3 => 6),                    # 24 parameters
  b = Dense(6 => 9),                    # 63 parameters
)                   # Total: 4 arrays, 87 parameters, 604 bytes.
julia>
julia> m[:a]
Dense(3 => 6)       # 24 parameters
julia>
julia> model = Chain(d = Dense(1=>2), f = m)
Chain(
  d = Dense(1 => 2),                    # 4 parameters
  f = Chain(
    a = Dense(3 => 6),                  # 24 parameters
    b = Dense(6 => 9),                  # 63 parameters
  ),
)                   # Total: 6 arrays, 91 parameters, 812 bytes.

However, it would be nice to have an interface for Chain that allows creating layers from dicts!

@skyleaworlder
Contributor

This seems like a simple issue. I could add a straightforward function if it's really required.

@CarloLucibello
Member

You can just splat a dictionary with Symbol keys into the keyword arguments:

julia> d  = Dict(:a=> Dense(2 => 3), :b => Dense(3 => 2))
Dict{Symbol, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}} with 2 entries:
  :a => Dense(2 => 3)
  :b => Dense(3 => 2)

julia> Chain(; d...)
Chain(
  a = Dense(2 => 3),                    # 9 parameters
  b = Dense(3 => 2),                    # 8 parameters
)                   # Total: 4 arrays, 17 parameters, 324 bytes.
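A minimal base-Julia sketch of what the splat does, runnable without Flux. `kwnames` is a hypothetical stand-in for `Chain(; kw...)` that only reports the keyword names in the order it receives them, and the layer values are placeholder strings.

```julia
# Hypothetical stand-in for Chain(; kw...): return the keyword names
# in the order the callee receives them.
kwnames(; kw...) = collect(keys(kw))

# Keys must be Symbols, because keyword names are Symbols.
d = Dict(:a => "layer A", :b => "layer B")

# Splatting forwards the pairs in the Dict's own iteration order,
# i.e. the order of keys(d), which need not match the order written above.
order = kwnames(; d...)
println(order)
```

With two keys the result happens to look right here, but nothing in the Dict guarantees it.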

@CarloLucibello
Member

I don't know if it is worth it, but we could add support for dictionaries with string keys, since manipulating strings is more convenient than manipulating symbols when creating fancy layer names.
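Without any change to Flux, string-built names can already be converted on the user side. A sketch, assuming an ordered Vector of `String => layer` pairs; `kwnames` is a hypothetical stand-in for `Chain(; kw...)` and the values are placeholder strings.

```julia
# Hypothetical stand-in for Chain(; kw...): report keyword names in order.
kwnames(; kw...) = collect(keys(kw))

# Layer names built with ordinary string manipulation.
# A Vector (not a Dict) keeps them in the order they were created.
string_named = ["enc_$i" => "layer $i" for i in 1:3]

# Keyword names must be Symbols, so convert each key before splatting.
symbol_named = [Symbol(k) => v for (k, v) in string_named]
order = kwnames(; symbol_named...)
println(order)
```

With real layers as values, the same `symbol_named` vector would be passed as `Chain(; symbol_named...)`.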

@shivance
Author

@CarloLucibello can I take a crack at implementing this feature?

@ToucheSir
Member

I don't think we should support this natively because, unlike in Python, iteration order for Dicts is undefined in Julia. String keys seem more reasonable, but I think they'd add significant complexity to the internals of Chain for very little added benefit.

@mcabbott changed the title from "Improving flexibility for Chain" to "Constructing a Chain from a dictionary" Dec 27, 2022
@skyleaworlder
Contributor

> I don't think we should support this natively because, unlike in Python, iteration order for Dicts is undefined in Julia. String keys seem more reasonable, but I think they'd add significant complexity to the internals of Chain for very little added benefit.

Although it doesn't seem necessary, OrderedDict from OrderedCollections guarantees a well-defined (insertion) iteration order.

@ToucheSir
Member

Yes, but we're not taking that on as a dependency just for this. I think it's good that we don't make constructing from a Dict easy because it avoids the potential footgun of an unexpected iteration order. Those who are aware of the risk should also know how to do something like #2142 (comment).
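For someone who is aware of the risk and still starts from a Dict, one way to make the order explicit is to sort the pairs by name before splatting. This is only a sketch under the assumption that alphabetical order of the names is the intended layer order; `kwnames` is a hypothetical stand-in for `Chain(; kw...)` and the values are placeholder strings.

```julia
# Hypothetical stand-in for Chain(; kw...): report keyword names in order.
kwnames(; kw...) = collect(keys(kw))

d = Dict(:layer2 => "second", :layer1 => "first", :layer3 => "third")

# collect(d) yields a Vector of Pairs in the Dict's unspecified order;
# sorting by the key (`first` of each Pair) makes the order deterministic.
ordered = sort(collect(d); by = first)
order = kwnames(; ordered...)
println(order)
```

Note that the sort is lexicographic, so a name like `:layer10` would land before `:layer2`; this is the kind of hidden assumption that makes the Dict route fragile.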

@mcabbott
Member

If you want to push! to something in a loop, etc., you can make a vector of pairs. That seems like the obvious ordered container here.

d  = Dict(:one => Dense(2 => 3), :two => Dense(3 => 4));
Chain(; d...)  # for me at least, this cannot run (the Dict need not iterate as :one, :two)

v  = [:one => Dense(2 => 3), :two => Dense(3 => 4)];
Chain(; v...)
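A sketch of the loop pattern, runnable in base Julia. `kwnames` is a hypothetical stand-in for `Chain(; kw...)`, and a tuple of sizes stands in for each `Dense` layer so the snippet does not need Flux.

```julia
# Hypothetical stand-in for Chain(; kw...): report keyword names in order.
kwnames(; kw...) = collect(keys(kw))

widths = [2, 3, 4, 5]

# A Vector of Pairs preserves insertion order, unlike a Dict.
v = Pair{Symbol,Any}[]
for i in 1:length(widths)-1
    # With Flux, the value would be a layer such as Dense(widths[i] => widths[i+1]);
    # here a tuple of (in, out) sizes stands in for it.
    push!(v, Symbol("layer", i) => (widths[i], widths[i+1]))
end

order = kwnames(; v...)
println(order)
```

Because `v` is a Vector, the layers reach the constructor in exactly the order they were pushed.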

@CarloLucibello
Member

Closing, as this won't be implemented.

5 participants