
Type cannot be inferred with structured model function #2000

Closed · deahhh opened this issue Jun 15, 2022 · 8 comments

Comments

@deahhh commented Jun 15, 2022

using Flux

struct DConv
    conv::Conv
end

function (m::DConv)(x)
    m.conv(x)
end

Flux.@functor DConv

# Note: `m` was never constructed in the original snippet; the 3×3 kernel and
# 5 => 5 channels below are hypothetical, chosen to match the input shape.
m = DConv(Conv((3, 3), 5 => 5))

@code_warntype m(rand(Float32, 10, 10, 5, 1))

The output is:

MethodInstance for (::DConv)(::Array{Float32, 4})
from (m::DConv)(x) in Main at REPL[59]:1
Arguments
m::DConv
x::Array{Float32, 4}
Body::Any
1 ─ %1 = Base.getproperty(m, :conv)::Conv
│ %2 = (%1)(x)::Any
└── return %2

With BenchmarkTools you can see that the sum of the execution times of the individual statements inside the function is much less than the time to execute the function as a whole. The gap is large: in practice (on my real model, not this simple example) I measured 222 ms for the whole call versus 9 ms for the statements combined.

@ToucheSir (Member) commented

struct DConv{C<:Conv}
    conv::C
end

Conv is a parametric type, so including it in a struct field without filling out the parameters is inherently type unstable. Any further overhead you see is likely due to not interpolating properly when calling BenchmarkTools functions.
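
For illustration, a minimal sketch of both points (in a fresh session; the Conv sizes are hypothetical, chosen to match the earlier 10×10×5×1 input):

using Flux, BenchmarkTools

struct DConv{C<:Conv}   # field type is concrete once an instance is constructed
    conv::C
end
(m::DConv)(x) = m.conv(x)

Flux.@functor DConv

m = DConv(Conv((3, 3), 5 => 5))
x = rand(Float32, 10, 10, 5, 1)

@code_warntype m(x)   # the body now infers Array{Float32, 4} rather than Any
@btime $m($x)         # interpolate globals with $ so @btime times only the call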

@deahhh (Author) commented Jun 15, 2022

@ToucheSir, thanks a lot!

deahhh closed this as completed Jun 15, 2022
@deahhh (Author) commented Jun 16, 2022

@ToucheSir
I found that cat is not type stable, and there is nothing else I can use in its place.

@ToucheSir (Member) commented

cat is only type stable if you use dims=Val(...). That's a Base feature and has nothing to do with Flux. I see it isn't mentioned in the manual, so you may want to take it up there as a documentation issue.
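
For concreteness, a minimal sketch of the difference (each call is wrapped in a function so @code_warntype can inspect it; the array shapes are arbitrary):

a = rand(Float32, 2, 2)
b = rand(Float32, 2, 2)

unstable(a, b) = cat(a, b; dims=3)        # dims is a runtime value
stable(a, b)   = cat(a, b; dims=Val(3))   # dims is lifted into the type domain

@code_warntype unstable(a, b)   # returns Array{Float32}: dimension count not inferred
@code_warntype stable(a, b)     # returns Array{Float32, 3}: fully inferred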

@deahhh (Author) commented Jun 17, 2022

@ToucheSir
Great!

@deahhh (Author) commented Jun 17, 2022

@ToucheSir
Would you help me with another question?
With @btime in front of the whole call I measured 222 ms, while putting @btime in front of almost every line inside the function gives about 9 ms in total. I thought that was caused by type instability, but I know it is not now, since I have made the code type stable.

Any suggestions?

First experiment:

209.548 ms (26226 allocations: 1.44 MiB)

Second experiment:

  2.452 ns (0 allocations: 0 bytes)
  68.748 μs (279 allocations: 16.20 KiB)
  69.494 μs (279 allocations: 16.20 KiB)
  1.265 ms (4283 allocations: 254.20 KiB)
  18.716 ns (0 allocations: 0 bytes)
  19.152 ns (0 allocations: 0 bytes)
  7.785 ns (0 allocations: 0 bytes)
  7.819 ns (0 allocations: 0 bytes)
  1.091 ms (3489 allocations: 191.00 KiB)
  7.784 ns (0 allocations: 0 bytes)
  2.531 ns (0 allocations: 0 bytes)
  69.095 μs (282 allocations: 16.30 KiB)
  69.612 μs (282 allocations: 16.30 KiB)
  1.354 ms (4278 allocations: 254.02 KiB)
  18.726 ns (0 allocations: 0 bytes)
  18.546 ns (0 allocations: 0 bytes)
  8.783 ns (0 allocations: 0 bytes)
  7.671 ns (0 allocations: 0 bytes)
  1.143 ms (3515 allocations: 191.59 KiB)
  7.824 ns (0 allocations: 0 bytes)
  2.282 ns (0 allocations: 0 bytes)
  73.834 μs (281 allocations: 16.30 KiB)
  73.659 μs (281 allocations: 16.30 KiB)
  1.385 ms (4346 allocations: 256.08 KiB)
  18.680 ns (0 allocations: 0 bytes)
  18.621 ns (0 allocations: 0 bytes)
  7.814 ns (0 allocations: 0 bytes)
  8.580 ns (0 allocations: 0 bytes)
  1.190 ms (3584 allocations: 193.45 KiB)
  8.609 ns (0 allocations: 0 bytes)
  28.334 μs (113 allocations: 6.34 KiB)
  8.843 ns (0 allocations: 0 bytes)

The model is very complex:

struct MBEBlock{C<:Conv,BI<:Union{BatchNorm,InstanceNorm}, CH<:Chain,F<:Function}
    up_conv::C
    bn::Vector{BI}
    norm0::BI
    norm1::BI
    conv1::C
    conv2::Vector{CH}
    conv3::Vector{C}
    act::F
    concat::Bool
    residual::Bool
end
Flux.@functor MBEBlock

struct CoarseDecoder{C<:Conv, MBE<:MBEBlock, SMR<:SMRBlock, ECA<:ECABlocks}
    up_convs_bg::Vector{MBE}
    up_convs_mask::Vector{SMR}
    atts_mask::Vector{ECA}
    atts_bg::Vector{ECA}
    conv_final_bg::C
    use_att::Bool
end
Flux.@functor CoarseDecoder

The function is:

function (m::CoarseDecoder)(bg::T, fg, mask::T, encoder_outs=nothing) where{T}
    bg_x = bg
    mask_x = mask
    mask_outs = Vector{T}()
    bg_outs = Vector{T}()
    for (i, (up_bg, up_mask)) in enumerate(zip(m.up_convs_bg, m.up_convs_mask))
        @btime before_pool = $encoder_outs[end-($i-1)]
        before_pool = encoder_outs[end-(i-1)] #encoder_outs===nothing ? nothing : encoder_outs[end-(i-1)]
        if m.use_att
            @btime $m.atts_mask[$i]($before_pool)
            mask_before_pool = m.atts_mask[i](before_pool)
            @btime $m.atts_bg[$i]($before_pool)
            bg_before_pool = m.atts_bg[i](before_pool)
        end
        @btime $up_mask($mask_x,$mask_before_pool)
        # @code_warntype up_mask(mask_x, mask_before_pool)
        smr_outs = up_mask(mask_x, mask_before_pool)
        @btime mask_x = $smr_outs["feats"][1]
        mask_x = smr_outs["feats"][1]
        @btime primary_map, self_calibrated_map = $smr_outs["attn_maps"]
        primary_map, self_calibrated_map = smr_outs["attn_maps"]
        @btime push!($mask_outs, $primary_map)
        push!(mask_outs, primary_map)
        @btime push!($mask_outs, $self_calibrated_map)
        push!(mask_outs, self_calibrated_map)
        @btime $up_bg($bg_x, $bg_before_pool, $self_calibrated_map)
        bg_x = up_bg(bg_x, bg_before_pool, self_calibrated_map) # this might be where the problem is
        @btime push!($bg_outs, $bg_x)
        push!(bg_outs, bg_x)
    end
    if m.conv_final_bg !== nothing
        @btime $m.conv_final_bg($bg_x)
        bg_x = m.conv_final_bg(bg_x)
        @btime push!($bg_outs, $bg_x)
        push!(bg_outs, bg_x)
    end
    #@show length(bg_outs)
    #@show length(mask_outs)
    return bg_outs, mask_outs, nothing
end

@ToucheSir (Member) commented

@btime runs each expression multiple times, so I do not think it's doing what you think it's doing. This kind of question is better suited for a help forum like Discourse, Slack or Zulip, so I'd recommend moving it there. We reserve the issue tracker for bugs and feature requests.
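
As a minimal sketch of the alternative (model, bg, fg, mask and encoder_outs are placeholders for the actual objects above): remove the inline @btime calls from the function body and benchmark the whole forward pass once, interpolating every argument:

using BenchmarkTools

# Each inline @btime re-runs its expression many times, so the per-line
# numbers do not add up to the cost of a single call. Time the call as a whole:
@btime $model($bg, $fg, $mask, $encoder_outs)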

@deahhh (Author) commented Jun 20, 2022

There is a huge difference in time consumption between timing a function statement by statement and timing the entire function at once.
@ToucheSir
Could you help me? Thanks!
I find the problem very challenging. The call should take less than 60 ms, as the PyTorch version does.
