-
-
Notifications
You must be signed in to change notification settings - Fork 611
/
train.jl
123 lines (100 loc) · 2.51 KB
/
train.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
using Juno
import Zygote: Params, gradient
"""
update!(x, x̄)
Update the array `x` according to `x .-= x̄`.
"""
function update!(x::AbstractArray, x̄)
x .-= x̄
end
"""
update!(opt, p, g)
update!(opt, ps::Params, gs)
Perform an update step of the parameters `ps` (or the single parameter `p`)
according to optimizer `opt` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
"""
function update!(opt, x, x̄)
x .-= apply!(opt, x, x̄)
end
# Apply `update!(opt, x, gs[x])` to every parameter `x` in `xs`.
# Entries of `gs` that are `nothing` are skipped; presumably these are parameters
# no gradient flowed to (TODO confirm against the gradient provider).
function update!(opt, xs::Params, gs)
    for x in xs
        g = gs[x]
        # `===` is the idiomatic `nothing` test and avoids a generic `==` call
        # on arbitrary gradient objects.
        g === nothing && continue
        update!(opt, x, g)
    end
    return nothing
end
# Callback niceties

# Invoke `f` with the given arguments and return its result.
call(f, xs...) = f(xs...)

# A single callback is used unchanged.
runall(f) = f

# A vector of callbacks is merged into one zero-argument closure that invokes
# each callback in order and returns `nothing`.
function runall(fs::AbstractVector)
    return function ()
        for f in fs
            call(f)
        end
        return nothing
    end
end
# Thrown by `stop()`; `train!` catches it and exits its loop.
struct StopException <: Exception
end

"""
    stop()

Call `Flux.stop()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to stop and exit.

# Examples
```julia
cb = function ()
    accuracy() > 0.9 && Flux.stop()
end
```
"""
stop() = throw(StopException())
"""
train!(loss, params, data, opt; cb)
For each datapoint `d` in `data` compute the gradient of `loss(d...)` through
backpropagation and call the optimizer `opt`.
In case datapoints `d` are of numeric array type, assume no splatting is needed
and compute the gradient of `loss(d)`.
A callback is given with the keyword argument `cb`. For example, this will print
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
"""
function train!(loss, ps, data, opt; cb = () -> ())
ps = Params(ps)
cb = runall(cb)
@progress for d in data
try
if d isa AbstractArray{<:Number}
gs = gradient(ps) do
loss(d)
end
else
gs = gradient(ps) do
loss(d...)
end
end
update!(opt, ps, gs)
cb()
catch ex
if ex isa StopException
break
else
rethrow(ex)
end
end
end
end
"""
@epochs N body
Run `body` `N` times. Mainly useful for quickly doing multiple epochs of
training in a REPL.
# Examples
```jldoctest
julia> Flux.@epochs 2 println("hello")
[ Info: Epoch 1
hello
[ Info: Epoch 2
hello
```
"""
macro epochs(n, ex)
:(@progress for i = 1:$(esc(n))
@info "Epoch $i"
$(esc(ex))
end)
end