-
-
Notifications
You must be signed in to change notification settings - Fork 66
/
utilities.jl
35 lines (30 loc) · 881 Bytes
/
utilities.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
    _seconddimmean(x)

Average `x` over its second dimension, then index away the resulting singleton axis
with `[:, 1, :]`. Used by the classifier head of vision-transformer-like models.

For a 3-D array of size `(a, b, c)` the result has size `(a, c)`.
"""
function _seconddimmean(x)
    pooled = mean(x; dims = 2)
    return pooled[:, 1, :]
end
"""
weights(model)
Load the pre-trained weights for `model` using the stored artifacts.
"""
function weights(model)
try
path = joinpath(@artifact_str(model), "$model.bson")
return BSON.load(path, @__MODULE__)[:weights]
catch e
throw(ArgumentError("No pre-trained weights available for $model."))
end
end
"""
loadpretrain!(model, name)
Load the pre-trained weight artifacts matching `<name>.bson` into `model`.
"""
loadpretrain!(model, name) = Flux.loadparams!(model, weights(name))
"""
    _maybe_big_show(io, model)

Print `model` to `io`. Flux's multi-line `Flux._big_show` layout is used when two
conditions hold:
- that internal function exists in the loaded Flux version;
- `io` carries no `:typeinfo` entry (e.g. at the top level of the REPL).

Otherwise, fall back to plain `show(io, model)`.
"""
function _maybe_big_show(io, model)
    # `_big_show` is Flux-internal and may be missing. A `:typeinfo` entry in the
    # IOContext typically means `model` is being printed inside an enclosing object.
    use_big_show = isdefined(Flux, :_big_show) && get(io, :typeinfo, nothing) === nothing
    return use_big_show ? Flux._big_show(io, model) : show(io, model)
end