Skip to content

Commit

Permalink
fix |> gpu bug in autosize
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 15, 2022
1 parent c7ed5fe commit a4825d4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
7 changes: 4 additions & 3 deletions src/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ julia> @autosize (img..., 1, 32) Chain( # size is only needed at ru
Dense(_ => _÷4, relu, init=Flux.rand32), # can calculate output size _÷4
SkipConnection(Dense(_ => _, relu), +),
Dense(_ => 10),
) |> gpu # moves to GPU after initialisation
)
Chain(
Chain(
c = Conv((3, 3), 1 => 5, pad=1, stride=2), # 50 parameters
Expand Down Expand Up @@ -290,8 +290,6 @@ mutable struct LazyLayer
layer
end

@functor LazyLayer

function (l::LazyLayer)(x::AbstractArray, ys::AbstractArray...)
l.layer === nothing || return l.layer(x, ys...)
made = l.make(x) # for something like `Bilinear((_,__) => 7)`, perhaps need `make(xy...)`, later.
Expand Down Expand Up @@ -320,6 +318,9 @@ function ChainRulesCore.rrule(::typeof(striplazy), m)
end

params!(p::Params, x::LazyLayer, seen = IdSet()) = error("LazyLayer should never be used within params(m). Call striplazy(m) first.")

Functors.functor(::Type{<:LazyLayer}, x) = error("LazyLayer should not be walked with Functors.jl, as the arrays which Flux.gpu wants to move may not exist yet.")

function Base.show(io::IO, l::LazyLayer)
printstyled(io, "LazyLayer(", color=:light_black)
if l.layer == nothing
Expand Down
5 changes: 4 additions & 1 deletion test/outputsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ end
Dense(_ => _÷4, relu, init=Flux.rand32), # can calculate output size _÷4
SkipConnection(Dense(_ => _, relu), +),
Dense(_ => 10),
) |> gpu # moves to GPU after initialisation
)
@test randn(Float32, img..., 1, 32) |> gpu |> m |> size == (10, 32)
end

Expand All @@ -241,4 +241,7 @@ end
@test_throws Exception Flux.params(lm)
@test_throws Exception gradient(x -> sum(abs2, lm(x)), [1,2])
@test_throws Exception gradient(m -> sum(abs2, Flux.striplazy(m)([1,2])), ld)

# Can't let |> gpu act before the arrays are materialized... so it's an error:
@test_throws ErrorException @eval @autosize (1,2,3) Dense(_=>2) |> f64
end

0 comments on commit a4825d4

Please sign in to comment.