From 9e31e5384ea3521b1ec05bd55b9b1aa1430d67b7 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 19 Jul 2020 16:51:39 -0600 Subject: [PATCH 01/31] Add CTC loss and tests --- src/losses/ctc-gpu.jl | 314 ++++++++++++++++++++++++++++++++++++++++++ src/losses/ctc.jl | 197 ++++++++++++++++++++++++++ test/ctc-gpu.jl | 68 +++++++++ test/ctc.jl | 54 ++++++++ test/runtests.jl | 2 + 5 files changed, 635 insertions(+) create mode 100644 src/losses/ctc-gpu.jl create mode 100644 src/losses/ctc.jl create mode 100644 test/ctc-gpu.jl create mode 100644 test/ctc.jl diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl new file mode 100644 index 0000000000..e9b7c6f275 --- /dev/null +++ b/src/losses/ctc-gpu.jl @@ -0,0 +1,314 @@ +# GPU impelmentation + +# a port of the GPU kernels from Baidu's C++ warp-ctc package +# GitHub: https://github.com/baidu-research/warp-ctc/ +# paper: https://arxiv.org/pdf/1512.02595.pdf + +using Flux +using Statistics +using CuArrays +using CUDAnative + +function log_plus_f(p1, p2) + + isinf(p1) && return p2 + isinf(p2) && return p1 + + if p1 < p2 + p1, p2 = p2, p1 + end + + return p1 + CUDAnative.log(1+CUDAnative.exp(p2 - p1)) +end + +function countRepeats(A) + repeats = 0 + for (i,elem) in enumerate(A) + if i > 1 && A[i] == A[i-1] + repeats += 1 + end + end + return repeats +end + +function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel) + + tid = threadIdx().x + L = labelSize + T = uttLength + S = length(labelsWithBlanks) + + if L + repeats > T + return nothing + end + + labels = labelsWithBlanks + + # Corner-case checking + start = (L + repeats <= T) ? 0 : 1 + last = S > 1 ? 2 : 1 + + # Fill in first column (time step) + i = tid + while i <= last - start + alpha[start + i] = probs[labels[start + i]] + i += blockDim().x + end + + sync_threads() + + # Fill in coefficients for each time step + for t=2:T + startCurCol = (t-1) * S + startPrevCol = (t-2) * S + startProbCol = (t-1) * div(length(probs), T) + + # Corner-case checking + if tid == 1 && !(1 < S - 2*(T-t) - 1) + if start == 0 + alpha[startCurCol + 1] = probs[startProbCol + blankLabel] + alpha[startPrevCol + 1] + elseif start == 1 + alpha[startCurCol + 1] = alpha[startPrevCol + 1] + end + end + + sync_threads() + + # Fill in coefficients for each label class in the target output sequence; + # each thread will process the calculations for one class + idx = tid+1 + while idx <= S + + prevSum = log_plus_f(alpha[startPrevCol + idx], alpha[startPrevCol + idx-1]) + + if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2] + prevSum = log_plus_f(prevSum, alpha[startPrevCol + idx-2]) + end + + if idx < S - 2*(T-t) - 1 + alpha[idx + startCurCol] = -Inf32 + else + alpha[startCurCol + idx] = prevSum + probs[startProbCol + labels[idx]] + end + + idx += blockDim().x + end + + sync_threads() + end + return nothing +end + +function computeBetasAndGradKernel(probs, labelSize, uttLength, + repeatsInLabel, labelsWithBlanks, + alphas, beta, output, accum, + grad, blankLabel) + + tid = threadIdx().x + L = labelSize + T = uttLength + S = 2*L + 1 + repeats = repeatsInLabel + + labels = labelsWithBlanks + + if (L+repeats) > T + return nothing + end + + # Corner-case checking + start = S > 1 ? S-2 : 0 + last = L + repeats < T ? 
S : S-1 + + sync_threads() + + + startCurCol = (T-1)*S + startProbCol = (T-1) * div(length(probs), T) + + i = tid + + # Calculate coefficients for last column (time step) + # then determine alpha and beta product + while i <= last - start + 1 + + beta[startCurCol + i + start] = 0 + output[startCurCol + i + start] = beta[startCurCol + i + start] + alphas[startCurCol + i + start] + i += blockDim().x + end + + sync_threads() + + # Fill in `accum` for last column (time step) + if tid == 1 + startAccCol = startProbCol + startOutputCol = startCurCol + + for i=1:S + labelIdx = labels[i] + accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) + end + end + + sync_threads() + + # Fill in `grad` for last column (time step) + idx = tid + while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) +# + startProbCol = (T - 1) * div(length(probs), T) + startOutputCol = (T - 1) * S + + s = -Inf32 + for i=1:S + s = log_plus_f(s, output[startOutputCol + i]) + end + + # ∂L/∂a (where a is activation before logsoftmax) + grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) + idx += blockDim().x + end + + sync_threads() + + # Fill in the rest of the coefficients + t = T-1 + while t >= 1 + + startCurCol = (t-1)*S + startNextCol = t*S + startProbCol = t * div(length(probs), T) + + if t < T + + idx = tid + while idx <= S-1 + + nextSum = log_plus_f(beta[startNextCol + idx] + probs[startProbCol + labels[idx]], + beta[startNextCol + idx+1] + probs[startProbCol + labels[idx+1]]) + + if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2] + nextSum = log_plus_f(nextSum, + beta[startNextCol + idx + 2] + probs[startProbCol + labels[idx+2]]) + end + + if idx > 2*t + beta[idx + startCurCol] = -Inf32 + else + beta[idx + startCurCol] = nextSum + + end + + idx += blockDim().x + end + + sync_threads() + + if tid == 1 && last == S + beta[startCurCol + S] = beta[startNextCol + S] + probs[startProbCol + blankLabel] + end + + sync_threads() + + idx = tid + while idx <= S + output[startCurCol + idx] = alphas[idx+startCurCol] + beta[startCurCol + idx] + idx += blockDim().x + end + + sync_threads() + end + + + sync_threads() + + # Calculate accumulated alpha-beta products for each label class for + # each time step; used in calculating gradients + if tid == 1 + + startAccCol = (t-1) * div(length(accum), T) + startOutputCol = (t-1) * S + + for i=1:S + labelIdx = labels[i] + accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) + end + end + + sync_threads() + + idx = tid + + # Calculate gradients + while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) +# + startProbCol = (t - 1) * div(length(probs), T) + startOutputCol = (t - 1) * S + + s = -Inf32 + for i=1:S + s = log_plus_f(s, output[startOutputCol + i]) + end + + # ∂L/∂a (where a is activation before logsoftmax) + grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) + idx += blockDim().x + end + + sync_threads() + + t -= 1 + sync_threads() + # because of course, it wouldn't work without this earlier return statement + # otherwise, some of the gradient values become 0 + t == 0 && return + end + + return nothing +end + +ctc(ŷ::CuArrays.CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean + +ctc(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y)[1] |> mean + +ctc(ŷ::CuArrays.CuArray, y::CuArrays.CuArray) = ctc_(ŷ, y)[1] |> mean + 
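+# Note: as in the CPU implementation, `ŷ` is a classes-by-time matrix of raw
+# (pre-logsoftmax) activations and `y` a one-hot matrix of the same shape; the
+# last row of `ŷ` is treated as the blank label.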
+# methods for `ctc_` helper function +ctc_(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y) + +function ctc_(ŷ::CuArrays.CuArray, y) + + ŷ = logsoftmax(ŷ) + + blank = size(ŷ, 1) + labels = vec(mapslices(Base.argmax, y, dims=1)) + z = F(labels, blank) + z′ = [blank] + for label in z + push!(z′, label) + push!(z′, blank) + end + T = size(ŷ, 2) + U′ = 2*length(z) + 1 + alphas = CuArrays.fill(log(zero(ŷ[1])), T * U′) + betas = copy(alphas) + output = copy(alphas) + + nRepeats = countRepeats(labels) + + # 1 block with `U′` threads + @cuda blocks=1 threads=U′ computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) + + grads = CuArrays.fill(log(zero(ŷ[1])), length(ŷ)) + accum = copy(grads) + + @cuda blocks=1 threads=U′ computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) + + ls = reshape(collect(output), U′, T) + ls = -1 .* mapslices(logsum, ls, dims=1) |> vec + + gs = reshape(grads, size(ŷ,1), size(ŷ,2)) + + ŷ = alphas = betas = output = accum = grads = nothing + return ls, gs +end diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl new file mode 100644 index 0000000000..4524c0f54f --- /dev/null +++ b/src/losses/ctc.jl @@ -0,0 +1,197 @@ +using Flux +using Zygote: @adjoint +using Statistics + +# CPU implementation + +""" + logadd(a, b) + +Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))` +""" +function logadd(a, b) + isinf(a) && return b + isinf(b) && return a + + # always want the greater number on the left in the exponentiation; + # the magnitude difference may end up making the number very positive + # which will cause exp() to return Inf + # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be + # Inf for Float32 values + if a < b + a, b = b, a + end + + return a + log(1+exp(b-a)) +end + +""" + logsum(a) + +Sums the elements in `a` such that the result equals `log(sum(exp.(a)))` +""" +function logsum(a::AbstractArray) + local s + s = a[1] + for item in a[2:end] + s = logadd(s, item) + end + return s +end + +""" + F(A, blank) + +Removes blanks and repetitions in the sequence `A` + +This is the function `F` as defined in Graves (2012) +""" +function F(A, blank) + prev = A[1] + z = [prev] + for curr in A[2:end] + if curr != prev && curr != blank + push!(z, curr) +`` end + prev = curr + end + return z +end + +""" + addBlanks(z) + +Adds blanks to the start and end of `z`, and between item in `z` +""" +function addBlanks(z, blank) + + z′ = [blank] + for label in z + push!(z′, label) + push!(z′, blank) + end + return z′ +end + +function ctc_(ŷ, y) + + ŷ = logsoftmax(ŷ) + blank = size(ŷ, 1) + + z = F(Base.argmax.([y[:,i] for i=1:size(y,2)]), blank) + z′ = addBlanks(z, blank) + T = size(ŷ, 2) + U = length(z) + U′ = length(z′) + + # Calculate α coefficients, from the upper-left, to the bottom-right + α = zeros(Float64, T, U′) + for t=1:T + for u=1:U′ + if t == u == 1 +# α[t,u] = ŷ[t, blank] + α[t,u] = ŷ[blank, t] + elseif t == 1 && u == 2 +# α[t,u] = ŷ[t, z′[2]] + α[t,u] = ŷ[z′[2], t] + elseif t == 1 && u > 2 + α[t,u] = -Inf + elseif u < U′ - 2(T - t) - 1 + α[t,u] = -Inf + else + idx = u - 2 + idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) + idx = max(1, idx) + + α[t,u] = ŷ[z′[u], t] + logsum(α[t-1, idx:u]) + end + end + end + + # Calculate beta coefficients, from the bottom-right, to the upper-left + β = zeros(Float64, T, U′) + for i=1:length(β) + β[i] = -Inf + end + + # Fill bottom-right corner so bounding errors can be avoided + # by starting `u` 
at `U′-1` + β[T,U′] = 0.0 + + for t=T:-1:1 + for u=(U′-1):-1:1 + if t == T && u >= U′ - 1 + β[t,u] = 0.0 + elseif t == T && u < U′ - 1 + continue + elseif u > 2t || u > U′ + 1 + continue + else + idx = u+2 + idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) + idx = min(idx, U′) + + v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] + β[t, u] = logsum(v) + end + end + if t < T-1 + β[t, U′] = β[t+1, U′] + ŷ[blank, t] + end + end + + # Loss at each time t is taken as the sum of the product of the α and β coefficients for + # all the label classes at time t + losses = Vector() + for t=1:T + v = [α[t,u] + β[t,u] for u in 1:U′] + push!(losses, -logsum(v)) + end + + # `accum` will hold the sum of the α and β coefficients for + # each label class at time t; used in calculating gradients + accum = fill(-Inf, size(ŷ)) + grads = fill(-Inf, size(ŷ)) + + for t=1:T + for u=1:U′ + accum[z′[u], t] = logadd(accum[z′[u], t], α[t,u] + β[t,u]) + end + for u=1:size(grads, 1) + grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -losses[t]) + end + end + + losses = [x for x in losses] + + return losses, grads +end + +""" + ctc(ŷ, y) + +Computes the connectionist temporal classification loss between `ŷ` +and `y`. + +Both `ŷ` and `y` must be classes-by-time matrices, i.e., each row +represents a class and each column represents a time step. +Additionally, the `logsoftmax` function will be applied to `ŷ`, so +it must be the raw activation values from the neural network and +not, for example, the activations after being passed through a +`softmax` activation function. + +Used for sequence to sequence classification problems such as +speech recognition and handwriting recognition where the exact +time-alignment of the output (e.g., letters) is not needed to +solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf) +or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7) +for mathematical details. +""" +function ctc(ŷ::Array, y::Array) + return ctc_(ŷ, y)[1] |> mean +end + +@adjoint function ctc_(ŷ, y) + ls, gs = ctc_(ŷ, y) + return mean(ls), Δ -> (Δ .* gs, Δ) +end diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl new file mode 100644 index 0000000000..82bb809d83 --- /dev/null +++ b/test/ctc-gpu.jl @@ -0,0 +1,68 @@ +using Test +using Flux +using Flux: ctc_ +using Zygote: gradient +using LinearAlgebra +using CuArrays +using Statistics + +# Custom function to check numerical gradient of ctc loss, +# based on `ngradient` in `Tracker.jl` +# +# Needs to check loss as defined at a particular time step +# related to the change in x because slight deviations in +# input propagate through further time steps, intrinsically +# causing the gradients to change and thus not be comparable +# between the numeric and analytical definitions +function ctc_ngradient(xs...) 
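+    # Central finite differences: each x[i] is perturbed by ±δ/2 and the loss at
+    # the time step t = div(i-1, size(x, 1)) + 1 (the column x[i] belongs to) is differenced.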
+ f = ctc_ + grads = zero.(xs) + for (x, Δ) in zip(xs, grads), i in 1:length(x) + δ = sqrt(eps()) + t = div(i-1, size(x, 1)) + 1 + tmp = x[i] + x[i] = tmp - δ/2 + y1 = f(xs...)[1][t] + x[i] = tmp + δ/2 + y2 = f(xs...)[1][t] + x[i] = tmp + Δ[i] = (y2-y1)/δ + end + return grads +end + +@testset "ctc-gpu" begin + + x = rand(10, 50) + y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) + + x_cu = CuArray(x) + y_cu = CuArray(y) + + g1 = gradient(ctc, x_cu, y_cu)[1] + g1 = g1 |> collect + + g2 = ctc_ngradient(x, y)[1] + + @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) + + # test that GPU loss matches CPU implementation + + l1 = Flux.ctc_(x_cu, y_cu)[1] + l2 = Flux.ctc_(x, y)[1] + + @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) + + # tests using hand-calculated values + + x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray + y_cu = [1 1 0; 0 0 1; 0 0 0] |> CuArray + + @test mean(ctc(x_cu, y_cu)) ≈ 3.6990738275138035 + + g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] + ghat = gradient(ctc, x_cu, y_cu)[1] |> collect + + @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + +end diff --git a/test/ctc.jl b/test/ctc.jl new file mode 100644 index 0000000000..1771c7dff2 --- /dev/null +++ b/test/ctc.jl @@ -0,0 +1,54 @@ +using Test +using Flux +using Flux: ctc_ +using Zygote: gradient +using LinearAlgebra + +# Custom function to check numerical gradient of ctc loss, +# based on `ngradient` in `Tracker.jl` +# +# Needs to check loss as defined at a particular time step +# related to the change in x because slight deviations in +# input propagate through further time steps, intrinsically +# causing the gradients to change and thus not be comparable +# between the numeric and analytical definitions +function ctc_ngradient(xs...) + f = ctc_ + grads = zero.(xs) + for (x, Δ) in zip(xs, grads), i in 1:length(x) + δ = sqrt(eps()) + t = div(i-1, size(x, 1)) + 1 + tmp = x[i] + x[i] = tmp - δ/2 + y1 = f(xs...)[1][t] + x[i] = tmp + δ/2 + y2 = f(xs...)[1][t] + x[i] = tmp + Δ[i] = (y2-y1)/δ + end + return grads +end + +@testset "ctc" begin + + x = rand(10, 50) + y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) + + g1 = gradient(ctc, x, y)[1] + g1 = g1 + g2 = ctc_ngradient(x, y)[1] + + @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) + + # tests using hand-calculated values + + x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] 
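+    # 3 classes (row 3 is the blank) over 3 time steps; `y` below encodes the
+    # collapsed label sequence (1, 2)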
+ y = [1 1 0; 0 0 1; 0 0 0] + + @test ctc(x, y) ≈ 3.6990738275138035 + g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] + ghat = gradient(ctc, x, y)[1] + + @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + +end diff --git a/test/runtests.jl b/test/runtests.jl index c5861cd25c..c723f7aade 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,8 @@ end @testset "Losses" begin include("losses.jl") + include("ctc.jl") + if Flux.use_cuda[] include("ctc-gpu.jl") end end @testset "Layers" begin From 37efaa0165c75c2d88f24ab8e94e797d24fa051f Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 19 Jul 2020 22:35:03 -0600 Subject: [PATCH 02/31] Add ctc to Losses module --- src/losses/Losses.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index c4e5fb0e4b..2289f0c9f8 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -16,9 +16,11 @@ export mse, mae, msle, tversky_loss, dice_coeff_loss, poisson_loss, - hinge_loss, squared_hinge_loss + hinge_loss, siquared_hinge_loss, + ctc include("utils.jl") include("functions.jl") +include("ctc.jl") -end #module \ No newline at end of file +end #module From f471337055c97d04f0582a429106af42faade07f Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Mon, 12 Oct 2020 18:50:54 -0600 Subject: [PATCH 03/31] General updates General updates after changing example networks to use map instead of dot broadcasting. --- src/losses/ctc-gpu.jl | 143 ++++++++++++++++++------------------------ src/losses/ctc.jl | 69 +++++++++----------- test/runtests.jl | 7 ++- 3 files changed, 92 insertions(+), 127 deletions(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index e9b7c6f275..48dc7a5f28 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -6,8 +6,9 @@ using Flux using Statistics -using CuArrays -using CUDAnative +using CUDA + +const MAX_THREADS = 256 function log_plus_f(p1, p2) @@ -18,7 +19,7 @@ function log_plus_f(p1, p2) p1, p2 = p2, p1 end - return p1 + CUDAnative.log(1+CUDAnative.exp(p2 - p1)) + return p1 + CUDA.log(1+CUDA.exp(p2 - p1)) end function countRepeats(A) @@ -51,7 +52,7 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB # Fill in first column (time step) i = tid while i <= last - start - alpha[start + i] = probs[labels[start + i]] + alpha[start+i, 1] = probs[labels[start+i], 1] i += blockDim().x end @@ -59,16 +60,13 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB # Fill in coefficients for each time step for t=2:T - startCurCol = (t-1) * S - startPrevCol = (t-2) * S - startProbCol = (t-1) * div(length(probs), T) # Corner-case checking if tid == 1 && !(1 < S - 2*(T-t) - 1) if start == 0 - alpha[startCurCol + 1] = probs[startProbCol + blankLabel] + alpha[startPrevCol + 1] + alpha[1, t] = probs[blankLabel, t] + alpha[1, t-1] elseif start == 1 - alpha[startCurCol + 1] = alpha[startPrevCol + 1] + alpha[1, t] = alpha[1, t-1] end end @@ -79,16 +77,16 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB idx = tid+1 while idx <= S - prevSum = log_plus_f(alpha[startPrevCol + idx], alpha[startPrevCol + idx-1]) + prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1]) if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2] - prevSum = log_plus_f(prevSum, alpha[startPrevCol + idx-2]) + prevSum = log_plus_f(prevSum, alpha[idx-2, t-1]) end if idx < S - 2*(T-t) - 1 - alpha[idx + startCurCol] = -Inf32 + 
alpha[idx, t] = -Inf32 else - alpha[startCurCol + idx] = prevSum + probs[startProbCol + labels[idx]] + alpha[idx, t] = prevSum + probs[labels[idx], t] end idx += blockDim().x @@ -122,31 +120,23 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, sync_threads() - - startCurCol = (T-1)*S - startProbCol = (T-1) * div(length(probs), T) - i = tid # Calculate coefficients for last column (time step) # then determine alpha and beta product while i <= last - start + 1 - - beta[startCurCol + i + start] = 0 - output[startCurCol + i + start] = beta[startCurCol + i + start] + alphas[startCurCol + i + start] + beta[i+start, T] = 0 + output[i+start, T] = beta[i+start, T] + alphas[i+start, T] i += blockDim().x end sync_threads() # Fill in `accum` for last column (time step) - if tid == 1 - startAccCol = startProbCol - startOutputCol = startCurCol - + if tid == 1 for i=1:S labelIdx = labels[i] - accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) + accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T]) end end @@ -154,18 +144,16 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Fill in `grad` for last column (time step) idx = tid - while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) -# - startProbCol = (T - 1) * div(length(probs), T) - startOutputCol = (T - 1) * S + while idx <= size(grad, 1) s = -Inf32 + for i=1:S - s = log_plus_f(s, output[startOutputCol + i]) + s = log_plus_f(s, output[i, T]) end # ∂L/∂a (where a is activation before logsoftmax) - grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) + grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s) idx += blockDim().x end @@ -174,28 +162,29 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Fill in the rest of the coefficients t = T-1 while t >= 1 - - startCurCol = (t-1)*S - startNextCol = t*S - startProbCol = t * div(length(probs), T) - if t < T idx = tid - while idx <= S-1 + # while idx <= S-1 + while idx <= S - nextSum = log_plus_f(beta[startNextCol + idx] + probs[startProbCol + labels[idx]], - beta[startNextCol + idx+1] + probs[startProbCol + labels[idx+1]]) + nextSum = beta[idx, t+1] + probs[labels[idx], t+1] + + if idx < S + + nextSum = log_plus_f(nextSum, + beta[idx+1, t+1] + probs[labels[idx+1], t+1]) + end if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2] nextSum = log_plus_f(nextSum, - beta[startNextCol + idx + 2] + probs[startProbCol + labels[idx+2]]) + beta[idx + 2, t+1] + probs[labels[idx+2], t+1]) end if idx > 2*t - beta[idx + startCurCol] = -Inf32 + beta[idx, t] = -Inf32 else - beta[idx + startCurCol] = nextSum + beta[idx, t] = nextSum end @@ -205,14 +194,14 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, sync_threads() if tid == 1 && last == S - beta[startCurCol + S] = beta[startNextCol + S] + probs[startProbCol + blankLabel] + beta[S, t] = beta[S, t] + probs[blankLabel, t+1] end sync_threads() idx = tid while idx <= S - output[startCurCol + idx] = alphas[idx+startCurCol] + beta[startCurCol + idx] + output[idx, t] = alphas[idx, t] + beta[idx, t] idx += blockDim().x end @@ -224,14 +213,10 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Calculate accumulated alpha-beta products for each label class for # each time step; used in calculating gradients - if tid == 1 - - startAccCol = (t-1) * div(length(accum), T) - startOutputCol = (t-1) * S - + if tid == 1 for i=1:S 
labelIdx = labels[i] - accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) + accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t]) end end @@ -240,18 +225,16 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, idx = tid # Calculate gradients - while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) -# - startProbCol = (t - 1) * div(length(probs), T) - startOutputCol = (t - 1) * S + while idx <= size(grad, 1) s = -Inf32 + for i=1:S - s = log_plus_f(s, output[startOutputCol + i]) + s = log_plus_f(s, output[i, t]) end # ∂L/∂a (where a is activation before logsoftmax) - grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) + grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s) idx += blockDim().x end @@ -259,56 +242,50 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, t -= 1 sync_threads() - # because of course, it wouldn't work without this earlier return statement - # otherwise, some of the gradient values become 0 - t == 0 && return end return nothing end -ctc(ŷ::CuArrays.CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean - -ctc(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y)[1] |> mean - -ctc(ŷ::CuArrays.CuArray, y::CuArrays.CuArray) = ctc_(ŷ, y)[1] |> mean - # methods for `ctc_` helper function -ctc_(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y) +ctc(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean +ctc(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean +ctc(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean +ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y)) -function ctc_(ŷ::CuArrays.CuArray, y) +function ctc_(ŷ::CuArray, y) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - labels = vec(mapslices(Base.argmax, y, dims=1)) + labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)] z = F(labels, blank) z′ = [blank] for label in z push!(z′, label) push!(z′, blank) end + T = size(ŷ, 2) U′ = 2*length(z) + 1 - alphas = CuArrays.fill(log(zero(ŷ[1])), T * U′) - betas = copy(alphas) - output = copy(alphas) + + alphas = CUDA.fill(log(zero(ŷ[1])), U′, T) + betas = CUDA.fill(log(zero(ŷ[1])), U′, T) + output = CUDA.fill(log(zero(ŷ[1])), U′, T) nRepeats = countRepeats(labels) + nThreads = min(U′, MAX_THREADS) - # 1 block with `U′` threads - @cuda blocks=1 threads=U′ computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) + @cuda blocks=1 threads=nThreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) - grads = CuArrays.fill(log(zero(ŷ[1])), length(ŷ)) - accum = copy(grads) + grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) + accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) - @cuda blocks=1 threads=U′ computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) + @cuda blocks=1 threads=nThreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) - ls = reshape(collect(output), U′, T) - ls = -1 .* mapslices(logsum, ls, dims=1) |> vec - - gs = reshape(grads, size(ŷ,1), size(ŷ,2)) - - ŷ = alphas = betas = output = accum = grads = nothing - return ls, gs + ls = collect(output) + ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)]) + + ŷ = alphas = betas = output = accum = nothing + return ls, grads end diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 4524c0f54f..d7c9e9b7d3 100644 --- a/src/losses/ctc.jl +++ 
b/src/losses/ctc.jl @@ -52,7 +52,7 @@ function F(A, blank) for curr in A[2:end] if curr != prev && curr != blank push!(z, curr) -`` end + end prev = curr end return z @@ -75,6 +75,8 @@ end function ctc_(ŷ, y) + typedZero = zero(ŷ[1]) + ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) @@ -85,19 +87,17 @@ function ctc_(ŷ, y) U′ = length(z′) # Calculate α coefficients, from the upper-left, to the bottom-right - α = zeros(Float64, T, U′) + α = fill(typedZero, T, U′) for t=1:T for u=1:U′ if t == u == 1 -# α[t,u] = ŷ[t, blank] α[t,u] = ŷ[blank, t] elseif t == 1 && u == 2 -# α[t,u] = ŷ[t, z′[2]] α[t,u] = ŷ[z′[2], t] elseif t == 1 && u > 2 - α[t,u] = -Inf + α[t,u] = log(typedZero) elseif u < U′ - 2(T - t) - 1 - α[t,u] = -Inf + α[t,u] = log(typedZero) else idx = u - 2 idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) @@ -109,49 +109,37 @@ function ctc_(ŷ, y) end # Calculate beta coefficients, from the bottom-right, to the upper-left - β = zeros(Float64, T, U′) - for i=1:length(β) - β[i] = -Inf - end + β = fill(log(typedZero), T, U′) # Fill bottom-right corner so bounding errors can be avoided # by starting `u` at `U′-1` - β[T,U′] = 0.0 - - for t=T:-1:1 - for u=(U′-1):-1:1 - if t == T && u >= U′ - 1 - β[t,u] = 0.0 - elseif t == T && u < U′ - 1 + β[T,U′] = typedZero + β[T,U′-1] = typedZero + + # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 + for t=(T-1):-1:1 + for u=U′:-1:1 + if u > 2t || u > U′ + 1 continue - elseif u > 2t || u > U′ + 1 - continue - else - idx = u+2 - idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) - idx = min(idx, U′) - - v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] - β[t, u] = logsum(v) end - end - if t < T-1 - β[t, U′] = β[t+1, U′] + ŷ[blank, t] + + idx = u+2 + idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) + idx = min(idx, U′) + + v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] + β[t, u] = logsum(v) end end + - # Loss at each time t is taken as the sum of the product of the α and β coefficients for - # all the label classes at time t - losses = Vector() - for t=1:T - v = [α[t,u] + β[t,u] for u in 1:U′] - push!(losses, -logsum(v)) - end + # Loss at each time t is taken as the sum of the product (sum in log space) of the + # α and β coefficients for all the label classes at time t + αβ = α + β + losses = -1 .* mapslices(logsum, αβ, dims=2) - # `accum` will hold the sum of the α and β coefficients for - # each label class at time t; used in calculating gradients - accum = fill(-Inf, size(ŷ)) - grads = fill(-Inf, size(ŷ)) + accum = fill(log(typedZero), size(ŷ)) + grads = fill(log(typedZero), size(ŷ)) for t=1:T for u=1:U′ @@ -163,7 +151,6 @@ function ctc_(ŷ, y) end losses = [x for x in losses] - return losses, grads end diff --git a/test/runtests.jl b/test/runtests.jl index c723f7aade..8d1f98d995 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using Random, Statistics, LinearAlgebra using IterTools: ncycle Random.seed!(0) - +#= @testset "Utils" begin include("utils.jl") end @@ -21,13 +21,13 @@ end @testset "Data" begin include("data.jl") end - +=# @testset "Losses" begin include("losses.jl") include("ctc.jl") if Flux.use_cuda[] include("ctc-gpu.jl") end end - +#= @testset "Layers" begin include("layers/basic.jl") include("layers/normalisation.jl") @@ -50,3 +50,4 @@ end doctest(Flux) end end +=# From b19b88cf8d0485b61324b6e9a24ed3deb2d1f38c Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Mon, 12 Oct 2020 20:08:17 -0600 Subject: [PATCH 04/31] Reverting bad merge --- test/runtests.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff 
--git a/test/runtests.jl b/test/runtests.jl index 8d1f98d995..c723f7aade 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using Random, Statistics, LinearAlgebra using IterTools: ncycle Random.seed!(0) -#= + @testset "Utils" begin include("utils.jl") end @@ -21,13 +21,13 @@ end @testset "Data" begin include("data.jl") end -=# + @testset "Losses" begin include("losses.jl") include("ctc.jl") if Flux.use_cuda[] include("ctc-gpu.jl") end end -#= + @testset "Layers" begin include("layers/basic.jl") include("layers/normalisation.jl") @@ -50,4 +50,3 @@ end doctest(Flux) end end -=# From da3564b7e6014a4d5699210d3df9ef6c09085b73 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Mon, 12 Oct 2020 20:11:13 -0600 Subject: [PATCH 05/31] Revert "General updates" This reverts commit f471337055c97d04f0582a429106af42faade07f. --- src/losses/ctc-gpu.jl | 143 ++++++++++++++++++++++++------------------ src/losses/ctc.jl | 69 +++++++++++--------- 2 files changed, 124 insertions(+), 88 deletions(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index 48dc7a5f28..e9b7c6f275 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -6,9 +6,8 @@ using Flux using Statistics -using CUDA - -const MAX_THREADS = 256 +using CuArrays +using CUDAnative function log_plus_f(p1, p2) @@ -19,7 +18,7 @@ function log_plus_f(p1, p2) p1, p2 = p2, p1 end - return p1 + CUDA.log(1+CUDA.exp(p2 - p1)) + return p1 + CUDAnative.log(1+CUDAnative.exp(p2 - p1)) end function countRepeats(A) @@ -52,7 +51,7 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB # Fill in first column (time step) i = tid while i <= last - start - alpha[start+i, 1] = probs[labels[start+i], 1] + alpha[start + i] = probs[labels[start + i]] i += blockDim().x end @@ -60,13 +59,16 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB # Fill in coefficients for each time step for t=2:T + startCurCol = (t-1) * S + startPrevCol = (t-2) * S + startProbCol = (t-1) * div(length(probs), T) # Corner-case checking if tid == 1 && !(1 < S - 2*(T-t) - 1) if start == 0 - alpha[1, t] = probs[blankLabel, t] + alpha[1, t-1] + alpha[startCurCol + 1] = probs[startProbCol + blankLabel] + alpha[startPrevCol + 1] elseif start == 1 - alpha[1, t] = alpha[1, t-1] + alpha[startCurCol + 1] = alpha[startPrevCol + 1] end end @@ -77,16 +79,16 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB idx = tid+1 while idx <= S - prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1]) + prevSum = log_plus_f(alpha[startPrevCol + idx], alpha[startPrevCol + idx-1]) if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2] - prevSum = log_plus_f(prevSum, alpha[idx-2, t-1]) + prevSum = log_plus_f(prevSum, alpha[startPrevCol + idx-2]) end if idx < S - 2*(T-t) - 1 - alpha[idx, t] = -Inf32 + alpha[idx + startCurCol] = -Inf32 else - alpha[idx, t] = prevSum + probs[labels[idx], t] + alpha[startCurCol + idx] = prevSum + probs[startProbCol + labels[idx]] end idx += blockDim().x @@ -120,23 +122,31 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, sync_threads() + + startCurCol = (T-1)*S + startProbCol = (T-1) * div(length(probs), T) + i = tid # Calculate coefficients for last column (time step) # then determine alpha and beta product while i <= last - start + 1 - beta[i+start, T] = 0 - output[i+start, T] = beta[i+start, T] + alphas[i+start, T] + + beta[startCurCol + i + start] = 0 + output[startCurCol + i + start] = beta[startCurCol + i + start] + 
alphas[startCurCol + i + start] i += blockDim().x end sync_threads() # Fill in `accum` for last column (time step) - if tid == 1 + if tid == 1 + startAccCol = startProbCol + startOutputCol = startCurCol + for i=1:S labelIdx = labels[i] - accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T]) + accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) end end @@ -144,16 +154,18 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Fill in `grad` for last column (time step) idx = tid - while idx <= size(grad, 1) + while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) +# + startProbCol = (T - 1) * div(length(probs), T) + startOutputCol = (T - 1) * S s = -Inf32 - for i=1:S - s = log_plus_f(s, output[i, T]) + s = log_plus_f(s, output[startOutputCol + i]) end # ∂L/∂a (where a is activation before logsoftmax) - grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s) + grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) idx += blockDim().x end @@ -162,29 +174,28 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Fill in the rest of the coefficients t = T-1 while t >= 1 + + startCurCol = (t-1)*S + startNextCol = t*S + startProbCol = t * div(length(probs), T) + if t < T idx = tid - # while idx <= S-1 - while idx <= S + while idx <= S-1 - nextSum = beta[idx, t+1] + probs[labels[idx], t+1] - - if idx < S - - nextSum = log_plus_f(nextSum, - beta[idx+1, t+1] + probs[labels[idx+1], t+1]) - end + nextSum = log_plus_f(beta[startNextCol + idx] + probs[startProbCol + labels[idx]], + beta[startNextCol + idx+1] + probs[startProbCol + labels[idx+1]]) if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2] nextSum = log_plus_f(nextSum, - beta[idx + 2, t+1] + probs[labels[idx+2], t+1]) + beta[startNextCol + idx + 2] + probs[startProbCol + labels[idx+2]]) end if idx > 2*t - beta[idx, t] = -Inf32 + beta[idx + startCurCol] = -Inf32 else - beta[idx, t] = nextSum + beta[idx + startCurCol] = nextSum end @@ -194,14 +205,14 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, sync_threads() if tid == 1 && last == S - beta[S, t] = beta[S, t] + probs[blankLabel, t+1] + beta[startCurCol + S] = beta[startNextCol + S] + probs[startProbCol + blankLabel] end sync_threads() idx = tid while idx <= S - output[idx, t] = alphas[idx, t] + beta[idx, t] + output[startCurCol + idx] = alphas[idx+startCurCol] + beta[startCurCol + idx] idx += blockDim().x end @@ -213,10 +224,14 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Calculate accumulated alpha-beta products for each label class for # each time step; used in calculating gradients - if tid == 1 + if tid == 1 + + startAccCol = (t-1) * div(length(accum), T) + startOutputCol = (t-1) * S + for i=1:S labelIdx = labels[i] - accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t]) + accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) end end @@ -225,16 +240,18 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, idx = tid # Calculate gradients - while idx <= size(grad, 1) + while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) +# + startProbCol = (t - 1) * div(length(probs), T) + startOutputCol = (t - 1) * S s = -Inf32 - for i=1:S - s = log_plus_f(s, output[i, t]) + s = log_plus_f(s, output[startOutputCol + i]) end # ∂L/∂a (where a is activation before logsoftmax) - 
grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s) + grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) idx += blockDim().x end @@ -242,50 +259,56 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, t -= 1 sync_threads() + # because of course, it wouldn't work without this earlier return statement + # otherwise, some of the gradient values become 0 + t == 0 && return end return nothing end +ctc(ŷ::CuArrays.CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean + +ctc(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y)[1] |> mean + +ctc(ŷ::CuArrays.CuArray, y::CuArrays.CuArray) = ctc_(ŷ, y)[1] |> mean + # methods for `ctc_` helper function -ctc(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean -ctc(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean -ctc(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean -ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y)) +ctc_(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y) -function ctc_(ŷ::CuArray, y) +function ctc_(ŷ::CuArrays.CuArray, y) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)] + labels = vec(mapslices(Base.argmax, y, dims=1)) z = F(labels, blank) z′ = [blank] for label in z push!(z′, label) push!(z′, blank) end - T = size(ŷ, 2) U′ = 2*length(z) + 1 - - alphas = CUDA.fill(log(zero(ŷ[1])), U′, T) - betas = CUDA.fill(log(zero(ŷ[1])), U′, T) - output = CUDA.fill(log(zero(ŷ[1])), U′, T) + alphas = CuArrays.fill(log(zero(ŷ[1])), T * U′) + betas = copy(alphas) + output = copy(alphas) nRepeats = countRepeats(labels) - nThreads = min(U′, MAX_THREADS) - @cuda blocks=1 threads=nThreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) + # 1 block with `U′` threads + @cuda blocks=1 threads=U′ computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) - grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) - accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) + grads = CuArrays.fill(log(zero(ŷ[1])), length(ŷ)) + accum = copy(grads) - @cuda blocks=1 threads=nThreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) + @cuda blocks=1 threads=U′ computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) - ls = collect(output) - ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)]) - - ŷ = alphas = betas = output = accum = nothing - return ls, grads + ls = reshape(collect(output), U′, T) + ls = -1 .* mapslices(logsum, ls, dims=1) |> vec + + gs = reshape(grads, size(ŷ,1), size(ŷ,2)) + + ŷ = alphas = betas = output = accum = grads = nothing + return ls, gs end diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index d7c9e9b7d3..4524c0f54f 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -52,7 +52,7 @@ function F(A, blank) for curr in A[2:end] if curr != prev && curr != blank push!(z, curr) - end +`` end prev = curr end return z @@ -75,8 +75,6 @@ end function ctc_(ŷ, y) - typedZero = zero(ŷ[1]) - ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) @@ -87,17 +85,19 @@ function ctc_(ŷ, y) U′ = length(z′) # Calculate α coefficients, from the upper-left, to the bottom-right - α = fill(typedZero, T, U′) + α = zeros(Float64, T, U′) for t=1:T for u=1:U′ if t == u == 1 +# α[t,u] = ŷ[t, blank] α[t,u] = ŷ[blank, t] elseif t == 1 && u == 2 +# α[t,u] = ŷ[t, z′[2]] α[t,u] = ŷ[z′[2], t] elseif t == 1 && u > 2 - α[t,u] = log(typedZero) + 
α[t,u] = -Inf elseif u < U′ - 2(T - t) - 1 - α[t,u] = log(typedZero) + α[t,u] = -Inf else idx = u - 2 idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) @@ -109,37 +109,49 @@ function ctc_(ŷ, y) end # Calculate beta coefficients, from the bottom-right, to the upper-left - β = fill(log(typedZero), T, U′) + β = zeros(Float64, T, U′) + for i=1:length(β) + β[i] = -Inf + end # Fill bottom-right corner so bounding errors can be avoided # by starting `u` at `U′-1` - β[T,U′] = typedZero - β[T,U′-1] = typedZero - - # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 - for t=(T-1):-1:1 - for u=U′:-1:1 - if u > 2t || u > U′ + 1 + β[T,U′] = 0.0 + + for t=T:-1:1 + for u=(U′-1):-1:1 + if t == T && u >= U′ - 1 + β[t,u] = 0.0 + elseif t == T && u < U′ - 1 continue - end - - idx = u+2 - idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) - idx = min(idx, U′) + elseif u > 2t || u > U′ + 1 + continue + else + idx = u+2 + idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) + idx = min(idx, U′) - v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] - β[t, u] = logsum(v) + v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] + β[t, u] = logsum(v) + end + end + if t < T-1 + β[t, U′] = β[t+1, U′] + ŷ[blank, t] end end - - # Loss at each time t is taken as the sum of the product (sum in log space) of the - # α and β coefficients for all the label classes at time t - αβ = α + β - losses = -1 .* mapslices(logsum, αβ, dims=2) + # Loss at each time t is taken as the sum of the product of the α and β coefficients for + # all the label classes at time t + losses = Vector() + for t=1:T + v = [α[t,u] + β[t,u] for u in 1:U′] + push!(losses, -logsum(v)) + end - accum = fill(log(typedZero), size(ŷ)) - grads = fill(log(typedZero), size(ŷ)) + # `accum` will hold the sum of the α and β coefficients for + # each label class at time t; used in calculating gradients + accum = fill(-Inf, size(ŷ)) + grads = fill(-Inf, size(ŷ)) for t=1:T for u=1:U′ @@ -151,6 +163,7 @@ function ctc_(ŷ, y) end losses = [x for x in losses] + return losses, grads end From d8242c09e433f7690f09629ba9450c3f2b666aaf Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Mon, 12 Oct 2020 20:14:11 -0600 Subject: [PATCH 06/31] General ctc updates --- src/losses/ctc-gpu.jl | 143 ++++++++++++++++++------------------------ src/losses/ctc.jl | 69 +++++++++----------- 2 files changed, 88 insertions(+), 124 deletions(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index e9b7c6f275..48dc7a5f28 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -6,8 +6,9 @@ using Flux using Statistics -using CuArrays -using CUDAnative +using CUDA + +const MAX_THREADS = 256 function log_plus_f(p1, p2) @@ -18,7 +19,7 @@ function log_plus_f(p1, p2) p1, p2 = p2, p1 end - return p1 + CUDAnative.log(1+CUDAnative.exp(p2 - p1)) + return p1 + CUDA.log(1+CUDA.exp(p2 - p1)) end function countRepeats(A) @@ -51,7 +52,7 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB # Fill in first column (time step) i = tid while i <= last - start - alpha[start + i] = probs[labels[start + i]] + alpha[start+i, 1] = probs[labels[start+i], 1] i += blockDim().x end @@ -59,16 +60,13 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB # Fill in coefficients for each time step for t=2:T - startCurCol = (t-1) * S - startPrevCol = (t-2) * S - startProbCol = (t-1) * div(length(probs), T) # Corner-case checking if tid == 1 && !(1 < S - 2*(T-t) - 1) if start == 0 - alpha[startCurCol + 1] = probs[startProbCol + blankLabel] + 
alpha[startPrevCol + 1] + alpha[1, t] = probs[blankLabel, t] + alpha[1, t-1] elseif start == 1 - alpha[startCurCol + 1] = alpha[startPrevCol + 1] + alpha[1, t] = alpha[1, t-1] end end @@ -79,16 +77,16 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB idx = tid+1 while idx <= S - prevSum = log_plus_f(alpha[startPrevCol + idx], alpha[startPrevCol + idx-1]) + prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1]) if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2] - prevSum = log_plus_f(prevSum, alpha[startPrevCol + idx-2]) + prevSum = log_plus_f(prevSum, alpha[idx-2, t-1]) end if idx < S - 2*(T-t) - 1 - alpha[idx + startCurCol] = -Inf32 + alpha[idx, t] = -Inf32 else - alpha[startCurCol + idx] = prevSum + probs[startProbCol + labels[idx]] + alpha[idx, t] = prevSum + probs[labels[idx], t] end idx += blockDim().x @@ -122,31 +120,23 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, sync_threads() - - startCurCol = (T-1)*S - startProbCol = (T-1) * div(length(probs), T) - i = tid # Calculate coefficients for last column (time step) # then determine alpha and beta product while i <= last - start + 1 - - beta[startCurCol + i + start] = 0 - output[startCurCol + i + start] = beta[startCurCol + i + start] + alphas[startCurCol + i + start] + beta[i+start, T] = 0 + output[i+start, T] = beta[i+start, T] + alphas[i+start, T] i += blockDim().x end sync_threads() # Fill in `accum` for last column (time step) - if tid == 1 - startAccCol = startProbCol - startOutputCol = startCurCol - + if tid == 1 for i=1:S labelIdx = labels[i] - accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) + accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T]) end end @@ -154,18 +144,16 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Fill in `grad` for last column (time step) idx = tid - while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) -# - startProbCol = (T - 1) * div(length(probs), T) - startOutputCol = (T - 1) * S + while idx <= size(grad, 1) s = -Inf32 + for i=1:S - s = log_plus_f(s, output[startOutputCol + i]) + s = log_plus_f(s, output[i, T]) end # ∂L/∂a (where a is activation before logsoftmax) - grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) + grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s) idx += blockDim().x end @@ -174,28 +162,29 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Fill in the rest of the coefficients t = T-1 while t >= 1 - - startCurCol = (t-1)*S - startNextCol = t*S - startProbCol = t * div(length(probs), T) - if t < T idx = tid - while idx <= S-1 + # while idx <= S-1 + while idx <= S - nextSum = log_plus_f(beta[startNextCol + idx] + probs[startProbCol + labels[idx]], - beta[startNextCol + idx+1] + probs[startProbCol + labels[idx+1]]) + nextSum = beta[idx, t+1] + probs[labels[idx], t+1] + + if idx < S + + nextSum = log_plus_f(nextSum, + beta[idx+1, t+1] + probs[labels[idx+1], t+1]) + end if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2] nextSum = log_plus_f(nextSum, - beta[startNextCol + idx + 2] + probs[startProbCol + labels[idx+2]]) + beta[idx + 2, t+1] + probs[labels[idx+2], t+1]) end if idx > 2*t - beta[idx + startCurCol] = -Inf32 + beta[idx, t] = -Inf32 else - beta[idx + startCurCol] = nextSum + beta[idx, t] = nextSum end @@ -205,14 +194,14 @@ function computeBetasAndGradKernel(probs, labelSize, 
uttLength, sync_threads() if tid == 1 && last == S - beta[startCurCol + S] = beta[startNextCol + S] + probs[startProbCol + blankLabel] + beta[S, t] = beta[S, t] + probs[blankLabel, t+1] end sync_threads() idx = tid while idx <= S - output[startCurCol + idx] = alphas[idx+startCurCol] + beta[startCurCol + idx] + output[idx, t] = alphas[idx, t] + beta[idx, t] idx += blockDim().x end @@ -224,14 +213,10 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, # Calculate accumulated alpha-beta products for each label class for # each time step; used in calculating gradients - if tid == 1 - - startAccCol = (t-1) * div(length(accum), T) - startOutputCol = (t-1) * S - + if tid == 1 for i=1:S labelIdx = labels[i] - accum[startAccCol + labelIdx] = log_plus_f(accum[startAccCol + labelIdx], output[startOutputCol + i]) + accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t]) end end @@ -240,18 +225,16 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, idx = tid # Calculate gradients - while idx <= CUDAnative.div_fast(Float32(length(grad)), Float32(T)) -# - startProbCol = (t - 1) * div(length(probs), T) - startOutputCol = (t - 1) * S + while idx <= size(grad, 1) s = -Inf32 + for i=1:S - s = log_plus_f(s, output[startOutputCol + i]) + s = log_plus_f(s, output[i, t]) end # ∂L/∂a (where a is activation before logsoftmax) - grad[startProbCol + idx] = CUDAnative.exp(probs[startProbCol + idx]) - CUDAnative.exp(accum[startProbCol + idx] - s) + grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s) idx += blockDim().x end @@ -259,56 +242,50 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, t -= 1 sync_threads() - # because of course, it wouldn't work without this earlier return statement - # otherwise, some of the gradient values become 0 - t == 0 && return end return nothing end -ctc(ŷ::CuArrays.CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean - -ctc(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y)[1] |> mean - -ctc(ŷ::CuArrays.CuArray, y::CuArrays.CuArray) = ctc_(ŷ, y)[1] |> mean - # methods for `ctc_` helper function -ctc_(ŷ::Array, y::CuArrays.CuArray) = ctc_(CuArray(ŷ), y) +ctc(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean +ctc(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean +ctc(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean +ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y)) -function ctc_(ŷ::CuArrays.CuArray, y) +function ctc_(ŷ::CuArray, y) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - labels = vec(mapslices(Base.argmax, y, dims=1)) + labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)] z = F(labels, blank) z′ = [blank] for label in z push!(z′, label) push!(z′, blank) end + T = size(ŷ, 2) U′ = 2*length(z) + 1 - alphas = CuArrays.fill(log(zero(ŷ[1])), T * U′) - betas = copy(alphas) - output = copy(alphas) + + alphas = CUDA.fill(log(zero(ŷ[1])), U′, T) + betas = CUDA.fill(log(zero(ŷ[1])), U′, T) + output = CUDA.fill(log(zero(ŷ[1])), U′, T) nRepeats = countRepeats(labels) + nThreads = min(U′, MAX_THREADS) - # 1 block with `U′` threads - @cuda blocks=1 threads=U′ computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) + @cuda blocks=1 threads=nThreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) - grads = CuArrays.fill(log(zero(ŷ[1])), length(ŷ)) - accum = copy(grads) + grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) + accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) - @cuda blocks=1 threads=U′ computeBetasAndGradKernel(ŷ, length(z), 
size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) + @cuda blocks=1 threads=nThreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) - ls = reshape(collect(output), U′, T) - ls = -1 .* mapslices(logsum, ls, dims=1) |> vec - - gs = reshape(grads, size(ŷ,1), size(ŷ,2)) - - ŷ = alphas = betas = output = accum = grads = nothing - return ls, gs + ls = collect(output) + ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)]) + + ŷ = alphas = betas = output = accum = nothing + return ls, grads end diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 4524c0f54f..d7c9e9b7d3 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -52,7 +52,7 @@ function F(A, blank) for curr in A[2:end] if curr != prev && curr != blank push!(z, curr) -`` end + end prev = curr end return z @@ -75,6 +75,8 @@ end function ctc_(ŷ, y) + typedZero = zero(ŷ[1]) + ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) @@ -85,19 +87,17 @@ function ctc_(ŷ, y) U′ = length(z′) # Calculate α coefficients, from the upper-left, to the bottom-right - α = zeros(Float64, T, U′) + α = fill(typedZero, T, U′) for t=1:T for u=1:U′ if t == u == 1 -# α[t,u] = ŷ[t, blank] α[t,u] = ŷ[blank, t] elseif t == 1 && u == 2 -# α[t,u] = ŷ[t, z′[2]] α[t,u] = ŷ[z′[2], t] elseif t == 1 && u > 2 - α[t,u] = -Inf + α[t,u] = log(typedZero) elseif u < U′ - 2(T - t) - 1 - α[t,u] = -Inf + α[t,u] = log(typedZero) else idx = u - 2 idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) @@ -109,49 +109,37 @@ function ctc_(ŷ, y) end # Calculate beta coefficients, from the bottom-right, to the upper-left - β = zeros(Float64, T, U′) - for i=1:length(β) - β[i] = -Inf - end + β = fill(log(typedZero), T, U′) # Fill bottom-right corner so bounding errors can be avoided # by starting `u` at `U′-1` - β[T,U′] = 0.0 - - for t=T:-1:1 - for u=(U′-1):-1:1 - if t == T && u >= U′ - 1 - β[t,u] = 0.0 - elseif t == T && u < U′ - 1 + β[T,U′] = typedZero + β[T,U′-1] = typedZero + + # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 + for t=(T-1):-1:1 + for u=U′:-1:1 + if u > 2t || u > U′ + 1 continue - elseif u > 2t || u > U′ + 1 - continue - else - idx = u+2 - idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) - idx = min(idx, U′) - - v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] - β[t, u] = logsum(v) end - end - if t < T-1 - β[t, U′] = β[t+1, U′] + ŷ[blank, t] + + idx = u+2 + idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) + idx = min(idx, U′) + + v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] + β[t, u] = logsum(v) end end + - # Loss at each time t is taken as the sum of the product of the α and β coefficients for - # all the label classes at time t - losses = Vector() - for t=1:T - v = [α[t,u] + β[t,u] for u in 1:U′] - push!(losses, -logsum(v)) - end + # Loss at each time t is taken as the sum of the product (sum in log space) of the + # α and β coefficients for all the label classes at time t + αβ = α + β + losses = -1 .* mapslices(logsum, αβ, dims=2) - # `accum` will hold the sum of the α and β coefficients for - # each label class at time t; used in calculating gradients - accum = fill(-Inf, size(ŷ)) - grads = fill(-Inf, size(ŷ)) + accum = fill(log(typedZero), size(ŷ)) + grads = fill(log(typedZero), size(ŷ)) for t=1:T for u=1:U′ @@ -163,7 +151,6 @@ function ctc_(ŷ, y) end losses = [x for x in losses] - return losses, grads end From 5bf2635d11da99763b935ee38bec96231b1ba9ea Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Wed, 14 Oct 2020 21:45:15 -0600 Subject: [PATCH 07/31] Get test 
cases working --- src/losses/Losses.jl | 3 ++- test/ctc-gpu.jl | 21 +++++++++++++++------ test/ctc.jl | 13 +++++++++++-- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 2289f0c9f8..4f738c0202 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -16,11 +16,12 @@ export mse, mae, msle, tversky_loss, dice_coeff_loss, poisson_loss, - hinge_loss, siquared_hinge_loss, + hinge_loss, squared_hinge_loss, ctc include("utils.jl") include("functions.jl") include("ctc.jl") +if CUDA.functional() include("ctc-gpu.jl") end end #module diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index 82bb809d83..8cb2da3277 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -1,9 +1,9 @@ using Test using Flux -using Flux: ctc_ +using Flux.Losses: ctc using Zygote: gradient using LinearAlgebra -using CuArrays +using CUDA using Statistics # Custom function to check numerical gradient of ctc loss, @@ -15,7 +15,7 @@ using Statistics # causing the gradients to change and thus not be comparable # between the numeric and analytical definitions function ctc_ngradient(xs...) - f = ctc_ + f = Flux.Losses.ctc_ grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) @@ -35,7 +35,7 @@ end x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) - + x_cu = CuArray(x) y_cu = CuArray(y) @@ -48,8 +48,8 @@ end # test that GPU loss matches CPU implementation - l1 = Flux.ctc_(x_cu, y_cu)[1] - l2 = Flux.ctc_(x, y)[1] + l1 = ctc(x_cu, y_cu) + l2 = ctc(x, y) @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) @@ -64,5 +64,14 @@ end ghat = gradient(ctc, x_cu, y_cu)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + + x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray + y_cu = [1 1 0 0; 0 0 1 1; 0 0 0 0] |> CuArray + @test ctc(x_cu, y_cu) ≈ 8.02519869363453 + + g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] + + ghat = gradient(ctc, x_cu, y_cu)[1] |> collect + @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) end diff --git a/test/ctc.jl b/test/ctc.jl index 1771c7dff2..a52a13a1a3 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -1,6 +1,6 @@ using Test using Flux -using Flux: ctc_ +using Flux.Losses: ctc using Zygote: gradient using LinearAlgebra @@ -13,7 +13,7 @@ using LinearAlgebra # causing the gradients to change and thus not be comparable # between the numeric and analytical definitions function ctc_ngradient(xs...) - f = ctc_ + f = Flux.Losses.ctc_ grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) @@ -50,5 +50,14 @@ end ghat = gradient(ctc, x, y)[1] @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + + x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
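+    # a second hand-worked case: 3 classes (row 3 is the blank) over 4 time steps;
+    # the target `y` below again collapses to the label sequence (1, 2)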
+ y = [1 1 0 0; 0 0 1 1; 0 0 0 0] + @test ctc(x, y) ≈ 8.02519869363453 + + g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] + + ghat = gradient(ctc, x, y)[1] + @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) end From e4123ab16f3c7a34da80ff62b779ed4529a10855 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sat, 17 Oct 2020 14:44:29 -0600 Subject: [PATCH 08/31] Update NEWS.md --- NEWS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/NEWS.md b/NEWS.md index f4de21de4a..fe73395528 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# v0.11.2 + +* Add [CTC loss function](https://github.com/FluxML/Flux.jl/pull/1287) to Losses module + # v0.11 * Moved CUDA compatibility to use [CUDA.jl instead of CuArrays.jl](https://github.com/FluxML/Flux.jl/pull/1204) From 46898ed30ef0b523fee59f5d27262b363b0d6407 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 13 Dec 2020 17:53:05 -0700 Subject: [PATCH 09/31] Change ctc to ctc_loss --- src/losses/Losses.jl | 2 +- src/losses/ctc-gpu.jl | 6 +++--- src/losses/ctc.jl | 4 ++-- test/ctc-gpu.jl | 16 ++++++++-------- test/ctc.jl | 14 +++++++------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 4f738c0202..96d86f4cae 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -17,7 +17,7 @@ export mse, mae, msle, dice_coeff_loss, poisson_loss, hinge_loss, squared_hinge_loss, - ctc + ctc_loss include("utils.jl") include("functions.jl") diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index 48dc7a5f28..7aa0aabacd 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -248,9 +248,9 @@ function computeBetasAndGradKernel(probs, labelSize, uttLength, end # methods for `ctc_` helper function -ctc(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean -ctc(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean -ctc(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean +ctc_loss(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean +ctc_loss(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean +ctc_loss(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y)) function ctc_(ŷ::CuArray, y) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index d7c9e9b7d3..666d733546 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -155,7 +155,7 @@ function ctc_(ŷ, y) end """ - ctc(ŷ, y) + ctc_loss(ŷ, y) Computes the connectionist temporal classification loss between `ŷ` and `y`. @@ -174,7 +174,7 @@ solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7) for mathematical details. 
""" -function ctc(ŷ::Array, y::Array) +function ctc_loss(ŷ::Array, y::Array) return ctc_(ŷ, y)[1] |> mean end diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index 8cb2da3277..06c0f7eaf5 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -1,6 +1,6 @@ using Test using Flux -using Flux.Losses: ctc +using Flux.Losses: ctc_loss using Zygote: gradient using LinearAlgebra using CUDA @@ -39,7 +39,7 @@ end x_cu = CuArray(x) y_cu = CuArray(y) - g1 = gradient(ctc, x_cu, y_cu)[1] + g1 = gradient(ctc_loss, x_cu, y_cu)[1] g1 = g1 |> collect g2 = ctc_ngradient(x, y)[1] @@ -48,8 +48,8 @@ end # test that GPU loss matches CPU implementation - l1 = ctc(x_cu, y_cu) - l2 = ctc(x, y) + l1 = ctc_loss(x_cu, y_cu) + l2 = ctc_loss(x, y) @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) @@ -58,20 +58,20 @@ end x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray y_cu = [1 1 0; 0 0 1; 0 0 0] |> CuArray - @test mean(ctc(x_cu, y_cu)) ≈ 3.6990738275138035 + @test mean(ctc_loss(x_cu, y_cu)) ≈ 3.6990738275138035 g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] - ghat = gradient(ctc, x_cu, y_cu)[1] |> collect + ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray y_cu = [1 1 0 0; 0 0 1 1; 0 0 0 0] |> CuArray - @test ctc(x_cu, y_cu) ≈ 8.02519869363453 + @test ctc_loss(x_cu, y_cu) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] - ghat = gradient(ctc, x_cu, y_cu)[1] |> collect + ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) end diff --git a/test/ctc.jl b/test/ctc.jl index a52a13a1a3..c65944ea48 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -1,6 +1,6 @@ using Test using Flux -using Flux.Losses: ctc +using Flux.Losses: ctc_loss using Zygote: gradient using LinearAlgebra @@ -29,12 +29,12 @@ function ctc_ngradient(xs...) return grads end -@testset "ctc" begin +@testset "ctc_loss" begin x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) - g1 = gradient(ctc, x, y)[1] + g1 = gradient(ctc_loss, x, y)[1] g1 = g1 g2 = ctc_ngradient(x, y)[1] @@ -45,19 +45,19 @@ end x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] y = [1 1 0; 0 0 1; 0 0 0] - @test ctc(x, y) ≈ 3.6990738275138035 + @test ctc_loss(x, y) ≈ 3.6990738275138035 g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] - ghat = gradient(ctc, x, y)[1] + ghat = gradient(ctc_loss, x, y)[1] @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
y = [1 1 0 0; 0 0 1 1; 0 0 0 0] - @test ctc(x, y) ≈ 8.02519869363453 + @test ctc_loss(x, y) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] - ghat = gradient(ctc, x, y)[1] + ghat = gradient(ctc_loss, x, y)[1] @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) end From 110a608cbeabdf85fea8aee0a66da1626d9f3621 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 13 Dec 2020 17:54:17 -0700 Subject: [PATCH 10/31] Fix typo --- src/losses/ctc-gpu.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index 7aa0aabacd..fe381ae4e9 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -1,4 +1,4 @@ -# GPU impelmentation +# GPU implementation # a port of the GPU kernels from Baidu's C++ warp-ctc package # GitHub: https://github.com/baidu-research/warp-ctc/ From 5145222c20d5cc9640a6f35954053e1dc29f16a0 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 13 Dec 2020 18:16:58 -0700 Subject: [PATCH 11/31] Remove camel-casing --- src/losses/ctc-gpu.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index fe381ae4e9..1975a49701 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -22,7 +22,7 @@ function log_plus_f(p1, p2) return p1 + CUDA.log(1+CUDA.exp(p2 - p1)) end -function countRepeats(A) +function count_repeats(A) repeats = 0 for (i,elem) in enumerate(A) if i > 1 && A[i] == A[i-1] @@ -32,7 +32,7 @@ function countRepeats(A) return repeats end -function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel) +function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithoutBlanks, labelsWithBlanks, alpha, blankLabel) tid = threadIdx().x L = labelSize @@ -97,7 +97,7 @@ function computeAlphaKernel(probs, labelSize, uttLength, repeats, labelsWithoutB return nothing end -function computeBetasAndGradKernel(probs, labelSize, uttLength, +function compute_betas_and_grad_kernel(probs, labelSize, uttLength, repeatsInLabel, labelsWithBlanks, alphas, beta, output, accum, grad, blankLabel) @@ -273,15 +273,15 @@ function ctc_(ŷ::CuArray, y) betas = CUDA.fill(log(zero(ŷ[1])), U′, T) output = CUDA.fill(log(zero(ŷ[1])), U′, T) - nRepeats = countRepeats(labels) + nRepeats = count_repeats(labels) nThreads = min(U′, MAX_THREADS) - @cuda blocks=1 threads=nThreads computeAlphaKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) + @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) - @cuda blocks=1 threads=nThreads computeBetasAndGradKernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) + @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) ls = collect(output) ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)]) From e002027c98bd7cdc5c3948f84b5196735b81f41b Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 13 Dec 2020 18:32:20 -0700 Subject: [PATCH 12/31] Remove some whitespace from functions --- src/losses/ctc-gpu.jl | 48 
++----------------------------------------- src/losses/ctc.jl | 11 ---------- 2 files changed, 2 insertions(+), 57 deletions(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index 1975a49701..4f33c71062 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -11,14 +11,11 @@ using CUDA const MAX_THREADS = 256 function log_plus_f(p1, p2) - isinf(p1) && return p2 isinf(p2) && return p1 - if p1 < p2 p1, p2 = p2, p1 end - return p1 + CUDA.log(1+CUDA.exp(p2 - p1)) end @@ -42,7 +39,6 @@ function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithou if L + repeats > T return nothing end - labels = labelsWithBlanks # Corner-case checking @@ -55,12 +51,10 @@ function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithou alpha[start+i, 1] = probs[labels[start+i], 1] i += blockDim().x end - sync_threads() # Fill in coefficients for each time step for t=2:T - # Corner-case checking if tid == 1 && !(1 < S - 2*(T-t) - 1) if start == 0 @@ -69,35 +63,29 @@ function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithou alpha[1, t] = alpha[1, t-1] end end - sync_threads() # Fill in coefficients for each label class in the target output sequence; # each thread will process the calculations for one class idx = tid+1 while idx <= S - prevSum = log_plus_f(alpha[idx, t-1], alpha[idx-1, t-1]) - if labels[idx] != blankLabel && idx != 2 && labels[idx] != labels[idx-2] prevSum = log_plus_f(prevSum, alpha[idx-2, t-1]) end - if idx < S - 2*(T-t) - 1 alpha[idx, t] = -Inf32 else alpha[idx, t] = prevSum + probs[labels[idx], t] end - idx += blockDim().x end - sync_threads() end return nothing end -function compute_betas_and_grad_kernel(probs, labelSize, uttLength, +function compute_beta_and_grad_kernel(probs, labelSize, uttLength, repeatsInLabel, labelsWithBlanks, alphas, beta, output, accum, grad, blankLabel) @@ -107,7 +95,6 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, T = uttLength S = 2*L + 1 repeats = repeatsInLabel - labels = labelsWithBlanks if (L+repeats) > T @@ -117,9 +104,8 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, # Corner-case checking start = S > 1 ? S-2 : 0 last = L + repeats < T ? 
S : S-1 - sync_threads() - + i = tid # Calculate coefficients for last column (time step) @@ -129,7 +115,6 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, output[i+start, T] = beta[i+start, T] + alphas[i+start, T] i += blockDim().x end - sync_threads() # Fill in `accum` for last column (time step) @@ -145,9 +130,7 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, # Fill in `grad` for last column (time step) idx = tid while idx <= size(grad, 1) - s = -Inf32 - for i=1:S s = log_plus_f(s, output[i, T]) end @@ -156,47 +139,36 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, grad[idx, T] = CUDA.exp(probs[idx, T]) - CUDA.exp(accum[idx, T] - s) idx += blockDim().x end - sync_threads() # Fill in the rest of the coefficients t = T-1 while t >= 1 if t < T - idx = tid # while idx <= S-1 while idx <= S - nextSum = beta[idx, t+1] + probs[labels[idx], t+1] - if idx < S - nextSum = log_plus_f(nextSum, beta[idx+1, t+1] + probs[labels[idx+1], t+1]) end - if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2] nextSum = log_plus_f(nextSum, beta[idx + 2, t+1] + probs[labels[idx+2], t+1]) end - if idx > 2*t beta[idx, t] = -Inf32 else beta[idx, t] = nextSum - end - idx += blockDim().x end - sync_threads() if tid == 1 && last == S beta[S, t] = beta[S, t] + probs[blankLabel, t+1] end - sync_threads() idx = tid @@ -204,11 +176,8 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, output[idx, t] = alphas[idx, t] + beta[idx, t] idx += blockDim().x end - sync_threads() end - - sync_threads() # Calculate accumulated alpha-beta products for each label class for @@ -219,16 +188,12 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, accum[labelIdx, t] = log_plus_f(accum[labelIdx, t], output[i, t]) end end - sync_threads() - idx = tid # Calculate gradients while idx <= size(grad, 1) - s = -Inf32 - for i=1:S s = log_plus_f(s, output[i, t]) end @@ -239,11 +204,9 @@ function compute_betas_and_grad_kernel(probs, labelSize, uttLength, end sync_threads() - t -= 1 sync_threads() end - return nothing end @@ -254,9 +217,7 @@ ctc_loss(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y)) function ctc_(ŷ::CuArray, y) - ŷ = logsoftmax(ŷ) - blank = size(ŷ, 1) labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)] z = F(labels, blank) @@ -268,24 +229,19 @@ function ctc_(ŷ::CuArray, y) T = size(ŷ, 2) U′ = 2*length(z) + 1 - alphas = CUDA.fill(log(zero(ŷ[1])), U′, T) betas = CUDA.fill(log(zero(ŷ[1])), U′, T) output = CUDA.fill(log(zero(ŷ[1])), U′, T) - nRepeats = count_repeats(labels) nThreads = min(U′, MAX_THREADS) @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) - grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) - ls = collect(output) ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)]) - ŷ = alphas = betas = output = accum = nothing return ls, grads end diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 666d733546..3246f4c35f 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -3,7 +3,6 @@ using Zygote: @adjoint using Statistics # CPU implementation - """ logadd(a, b) @@ -21,7 +20,6 @@ function logadd(a, b) if a < b a, b = b, a end - return a + log(1+exp(b-a)) end @@ -74,12 
+72,9 @@ function addBlanks(z, blank) end function ctc_(ŷ, y) - typedZero = zero(ŷ[1]) - ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - z = F(Base.argmax.([y[:,i] for i=1:size(y,2)]), blank) z′ = addBlanks(z, blank) T = size(ŷ, 2) @@ -102,7 +97,6 @@ function ctc_(ŷ, y) idx = u - 2 idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) idx = max(1, idx) - α[t,u] = ŷ[z′[u], t] + logsum(α[t-1, idx:u]) end end @@ -122,11 +116,9 @@ function ctc_(ŷ, y) if u > 2t || u > U′ + 1 continue end - idx = u+2 idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) idx = min(idx, U′) - v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] β[t, u] = logsum(v) end @@ -137,10 +129,8 @@ function ctc_(ŷ, y) # α and β coefficients for all the label classes at time t αβ = α + β losses = -1 .* mapslices(logsum, αβ, dims=2) - accum = fill(log(typedZero), size(ŷ)) grads = fill(log(typedZero), size(ŷ)) - for t=1:T for u=1:U′ accum[z′[u], t] = logadd(accum[z′[u], t], α[t,u] + β[t,u]) @@ -149,7 +139,6 @@ function ctc_(ŷ, y) grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -losses[t]) end end - losses = [x for x in losses] return losses, grads end From d0bd3bdbe9bb18d7ef3728b5dcd65dbab04d4343 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 13 Dec 2020 18:35:37 -0700 Subject: [PATCH 13/31] Adding info to comply with Apache license --- src/losses/ctc-gpu.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index 4f33c71062..54bf511156 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -1,6 +1,10 @@ # GPU implementation -# a port of the GPU kernels from Baidu's C++ warp-ctc package +# a port of the GPU kernels from Baidu's C++ warp-ctc package, +# which itself is Copyright 2015-2016 Baidu USA LLC +# and available under the Apache 2.0 license +# +# Apache 2.0 license: https://www.apache.org/licenses/LICENSE-2.0 # GitHub: https://github.com/baidu-research/warp-ctc/ # paper: https://arxiv.org/pdf/1512.02595.pdf From 5dafa05f53d11c7f1d0924471d3bed860fdb3655 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Thu, 17 Dec 2020 23:45:11 -0700 Subject: [PATCH 14/31] Use logsumexp in CPU CTC --- src/losses/ctc.jl | 53 +++++++++++------------------------------------ 1 file changed, 12 insertions(+), 41 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 3246f4c35f..6a2b7ea0f3 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -1,42 +1,9 @@ using Flux using Zygote: @adjoint using Statistics +using NNlib # CPU implementation -""" - logadd(a, b) - -Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))` -""" -function logadd(a, b) - isinf(a) && return b - isinf(b) && return a - - # always want the greater number on the left in the exponentiation; - # the magnitude difference may end up making the number very positive - # which will cause exp() to return Inf - # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be - # Inf for Float32 values - if a < b - a, b = b, a - end - return a + log(1+exp(b-a)) -end - -""" - logsum(a) - -Sums the elements in `a` such that the result equals `log(sum(exp.(a)))` -""" -function logsum(a::AbstractArray) - local s - s = a[1] - for item in a[2:end] - s = logadd(s, item) - end - return s -end - """ F(A, blank) @@ -57,11 +24,11 @@ function F(A, blank) end """ - addBlanks(z) + add_blanks(z) Adds blanks to the start and end of `z`, and between item in `z` """ -function addBlanks(z, blank) +function add_blanks(z, blank) z′ = [blank] for label in z @@ -76,7 +43,7 @@ function ctc_(ŷ, y) ŷ = 
logsoftmax(ŷ) blank = size(ŷ, 1) z = F(Base.argmax.([y[:,i] for i=1:size(y,2)]), blank) - z′ = addBlanks(z, blank) + z′ = add_blanks(z, blank) T = size(ŷ, 2) U = length(z) U′ = length(z′) @@ -97,7 +64,11 @@ function ctc_(ŷ, y) idx = u - 2 idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) idx = max(1, idx) - α[t,u] = ŷ[z′[u], t] + logsum(α[t-1, idx:u]) + α[t,u] = ŷ[z′[u], t] + logsumexp(α[t-1, idx:u]) + if isnan(α[t,u]) + println(α[t-1, idx:u]) + exit() + end end end end @@ -120,7 +91,7 @@ function ctc_(ŷ, y) idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) idx = min(idx, U′) v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] - β[t, u] = logsum(v) + β[t, u] = logsumexp(v) end end @@ -128,12 +99,12 @@ function ctc_(ŷ, y) # Loss at each time t is taken as the sum of the product (sum in log space) of the # α and β coefficients for all the label classes at time t αβ = α + β - losses = -1 .* mapslices(logsum, αβ, dims=2) + losses = -1 .* mapslices(logsumexp, αβ, dims=2) accum = fill(log(typedZero), size(ŷ)) grads = fill(log(typedZero), size(ŷ)) for t=1:T for u=1:U′ - accum[z′[u], t] = logadd(accum[z′[u], t], α[t,u] + β[t,u]) + accum[z′[u], t] = logsumexp([accum[z′[u], t], α[t,u] + β[t,u]]) end for u=1:size(grads, 1) grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -losses[t]) From 39504648a5f5b1930aeaf90487290f39cf942f0b Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Fri, 18 Dec 2020 00:06:42 -0700 Subject: [PATCH 15/31] Change logsum to logsumexp --- src/losses/ctc-gpu.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index 54bf511156..e424cc4464 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -11,6 +11,7 @@ using Flux using Statistics using CUDA +using NNlib const MAX_THREADS = 256 @@ -245,7 +246,7 @@ function ctc_(ŷ::CuArray, y) @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) ls = collect(output) - ls = vec(-1 .* [logsum(ls[:,i]) for i in 1:size(ls, 2)]) + ls = vec(-1 .* [logsumexp(ls[:,i]) for i in 1:size(ls, 2)]) ŷ = alphas = betas = output = accum = nothing return ls, grads end From 282cb2339cd2efd25fa2928c8114b214d75ee31f Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sat, 19 Dec 2020 14:59:37 -0700 Subject: [PATCH 16/31] Re-add logaddexp function to CPU ctc --- Manifest.toml | 78 +++++++++++++++++++++++------------------------ Project.toml | 2 +- src/losses/ctc.jl | 31 ++++++++++++++----- 3 files changed, 63 insertions(+), 48 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 1a6659f7b0..8751460ce6 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -20,9 +20,9 @@ version = "2.3.0" [[ArrayLayouts]] deps = ["Compat", "FillArrays", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "d4fc05297de9d8cf41a59162ebfaeee8fb6eb344" +git-tree-sha1 = "8f6af27c33b766f19fa6cfe46e629775cda81f88" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "0.4.10" +version = "0.4.11" [[Artifacts]] deps = ["Pkg"] @@ -52,21 +52,21 @@ version = "2.1.0" [[ChainRules]] deps = ["ChainRulesCore", "ChainRulesTestUtils", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "f6265a7e32d2c82e9462a06e0d5066284c29b89e" +git-tree-sha1 = "097722a98537a738e3e42bec069c63663292f991" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.32" +version = "0.7.40" [[ChainRulesCore]] deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] -git-tree-sha1 = 
"aebbda0a7c644bd8739b34f2a1b1e48f114aab49" +git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.17" +version = "0.9.24" [[ChainRulesTestUtils]] deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] -git-tree-sha1 = "157e00b3c05e63c3482aa2179abd3d2a54ad4a57" +git-tree-sha1 = "89cb6ebdae4010b8024b2ac22fcb9e316ac9b82c" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.5.3" +version = "0.5.9" [[CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -94,9 +94,9 @@ version = "0.3.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "a706ff10f1cd8dab94f59fd09c0e657db8e77ff0" +git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.23.0" +version = "3.25.0" [[CompilerSupportLibraries_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -131,15 +131,15 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" [[DiffResults]] deps = ["StaticArrays"] -git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc" +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.2" +version = "1.0.3" [[DiffRules]] deps = ["NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1" +git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.0.1" +version = "1.0.2" [[Distributed]] deps = ["Random", "Serialization", "Sockets"] @@ -164,9 +164,9 @@ version = "0.9.7" [[FiniteDifferences]] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson"] -git-tree-sha1 = "5ef56eea8b65c2ed2d11bd394629892201369b23" +git-tree-sha1 = "9bc8327853f21c9c53eac6935e51b65bcc28c492" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.11.2" +version = "0.11.5" [[FixedPointNumbers]] deps = ["Statistics"] @@ -176,9 +176,9 @@ version = "0.8.4" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "1d090099fb82223abc48f7ce176d3f7696ede36d" +git-tree-sha1 = "8de2519a83c6c1c2442c2f481dd9a8364855daf4" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.12" +version = "0.10.14" [[Functors]] deps = ["MacroTools"] @@ -192,9 +192,9 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[GPUArrays]] deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "ba423e79aa337a2051fa466a7a42f43d9c81a132" +git-tree-sha1 = "2c1dd57bca7ba0b3b4bf81d9332aeb81b154ef4c" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "6.1.1" +version = "6.1.2" [[GPUCompiler]] deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Scratch", "Serialization", "TimerOutputs", "UUIDs"] @@ -204,9 +204,9 @@ version = "0.8.3" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "a8d88c05a23b44b4da6cf4fb5659e13ff95e0f47" +git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.1" +version = "0.4.2" [[InteractiveUtils]] deps = ["Markdown"] @@ -225,9 +225,9 @@ version = "0.8.4" [[LLVM]] deps = ["CEnum", "Libdl", 
"Printf", "Unicode"] -git-tree-sha1 = "4eb5a1e702fee0d81c15ab673d7c77ef9023d509" +git-tree-sha1 = "a2101830a761d592b113129887fda626387f68d4" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "3.3.0" +version = "3.5.1" [[LibGit2]] deps = ["Printf"] @@ -280,21 +280,21 @@ uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" version = "0.2.2" [[NNlib]] -deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "a8180fd1445e31c0b1add98dae8da694ac2c23fd" +deps = ["Compat", "Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "bc78eb23ba38a1b7bdac20fb748b383fc73d6a8b" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.6" +version = "0.7.8" [[NaNMath]] -git-tree-sha1 = "c84c576296d0e2fbb3fc134d3e09086b3ea617cd" +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.4" +version = "0.3.5" [[OffsetArrays]] deps = ["Adapt"] -git-tree-sha1 = "9db93b990af57b3a56dca38476832f60d58f777b" +git-tree-sha1 = "45d5e495ab559357aee8cb1dfb8c12b0787d4545" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.4.0" +version = "1.4.1" [[OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -335,15 +335,15 @@ version = "0.2.0" [[Requires]] deps = ["UUIDs"] -git-tree-sha1 = "28faf1c963ca1dc3ec87f166d92982e3c4a1f66d" +git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.0" +version = "1.1.2" [[Richardson]] deps = ["LinearAlgebra"] -git-tree-sha1 = "776e0fdd3da5ad52067b60310ea8f3150d794c2f" +git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949" uuid = "708f8203-808e-40c0-ba2d-98a6953ed40d" -version = "1.2.0" +version = "1.4.0" [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -394,9 +394,9 @@ version = "0.10.3" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "da4cf579416c81994afd6322365d00916c79b8ae" +git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.12.5" +version = "1.0.1" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -462,6 +462,6 @@ version = "0.5.9" [[ZygoteRules]] deps = ["MacroTools"] -git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" +git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.0" +version = "0.2.1" diff --git a/Project.toml b/Project.toml index ce57272177..072fbd4888 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Colors = "0.12" Functors = "0.1" Juno = "0.8" MacroTools = "0.5" -NNlib = "0.7" +NNlib = "0.7.8" Reexport = "0.2" StatsBase = "0.33" ZipFile = "0.9" diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 6a2b7ea0f3..c6175ba314 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -4,6 +4,25 @@ using Statistics using NNlib # CPU implementation +""" + logaddexp(a, b) +Adds log-space `a` and `b` such that the result equals `log(exp(a)+exp(b))` +""" +function logaddexp(a, b) + isinf(a) && return b + isinf(b) && return a + + # always want the greater number on the left in the exponentiation; + # the magnitude difference may end up making the number very positive + # which will cause exp() to return Inf + # E.g., a = -900, b = -800, will give exp(-800 - -900), which will be + # Inf for Float32 values + if a < b + a, b = b, a + end + return a + log(1+exp(b-a)) +end + """ F(A, blank) @@ -64,11 +83,7 @@ function ctc_(ŷ, y) idx = u - 
2 idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) idx = max(1, idx) - α[t,u] = ŷ[z′[u], t] + logsumexp(α[t-1, idx:u]) - if isnan(α[t,u]) - println(α[t-1, idx:u]) - exit() - end + α[t,u] = ŷ[z′[u], t] + foldl(logaddexp, α[t-1, idx:u]) end end end @@ -91,7 +106,7 @@ function ctc_(ŷ, y) idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) idx = min(idx, U′) v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] - β[t, u] = logsumexp(v) + β[t, u] = foldl(logaddexp, v) end end @@ -99,12 +114,12 @@ function ctc_(ŷ, y) # Loss at each time t is taken as the sum of the product (sum in log space) of the # α and β coefficients for all the label classes at time t αβ = α + β - losses = -1 .* mapslices(logsumexp, αβ, dims=2) + losses = -1 .* logsumexp(αβ, dims=2) accum = fill(log(typedZero), size(ŷ)) grads = fill(log(typedZero), size(ŷ)) for t=1:T for u=1:U′ - accum[z′[u], t] = logsumexp([accum[z′[u], t], α[t,u] + β[t,u]]) + accum[z′[u], t] = logaddexp(accum[z′[u], t], α[t,u] + β[t,u]) end for u=1:size(grads, 1) grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -losses[t]) From d9caac44eb972d4d0f8d844ffe4d5ae0fee5be9f Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sat, 19 Dec 2020 16:07:14 -0700 Subject: [PATCH 17/31] Regnerate Manifest.toml --- Manifest.toml | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 2f328bd8e0..1bb2f75f27 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -20,15 +20,15 @@ version = "2.3.0" [[ArrayInterface]] deps = ["LinearAlgebra", "Requires", "SparseArrays"] -git-tree-sha1 = "de4bb46df3f67769356e737f2c7ce1d67da3ae49" +git-tree-sha1 = "b7898df8dff4098db4a9494d2451c5c2edd4cb2c" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "2.14.4" +version = "2.14.5" [[ArrayLayouts]] deps = ["Compat", "FillArrays", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "8f6af27c33b766f19fa6cfe46e629775cda81f88" +git-tree-sha1 = "a577e27915fdcb3f6b96118b56655b38e3b466f2" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "0.4.11" +version = "0.4.12" [[Artifacts]] deps = ["Pkg"] @@ -100,6 +100,9 @@ version = "0.3.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.25.0" [[CompilerSupportLibraries_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -213,6 +216,11 @@ git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" version = "0.4.2" +[[IfElse]] +git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.0" + [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -250,9 +258,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[LoopVectorization]] deps = ["ArrayInterface", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "SLEEFPirates", "UnPack", "VectorizationBase"] -git-tree-sha1 = "787e481682c5ef24734bcc4b6390f4fa9b8d3473" +git-tree-sha1 = "9a0145feae3fc55b86e3d1d3a5b3c83c6c05e445" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.9.6" +version = "0.9.8" [[MacroTools]] deps = ["Markdown", "Random"] @@ -355,9 +363,9 @@ uuid = 
"ea8e919c-243c-51af-8825-aaa63cd721ce" [[SLEEFPirates]] deps = ["IfElse", "Libdl", "VectorizationBase"] -git-tree-sha1 = "6ae40418987449f040abe6b517244ff5541171ff" +git-tree-sha1 = "d82dffab8f9e50d5110c5650c25fdf9e00dec316" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.0" +version = "0.6.1" [[Scratch]] deps = ["Dates"] @@ -386,10 +394,10 @@ deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["OpenSpecFun_jll"] -git-tree-sha1 = "7286f31f27e3335cba31c618ac344a35eceac060" +deps = ["ChainRulesCore", "OpenSpecFun_jll"] +git-tree-sha1 = "75394dbe2bd346beeed750fb02baa6445487b862" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.1.0" +version = "1.2.1" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] @@ -437,9 +445,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[VectorizationBase]] deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra"] -git-tree-sha1 = "b2e9c80c584f74b547e0ebd7afb52d65aaf65362" +git-tree-sha1 = "fa6ef8980ee738089ef69298e9aa824cc0a86c25" uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -version = "0.13.10" +version = "0.14.2" [[ZipFile]] deps = ["Libdl", "Printf", "Zlib_jll"] @@ -455,9 +463,9 @@ version = "1.2.11+18" [[Zygote]] deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "0a03c99ec000a89f5d17d3477c8c7367ed4367f3" +git-tree-sha1 = "18f758f28ca2c236e449be64e366e201965129a7" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.5.16" +version = "0.5.17" [[ZygoteRules]] deps = ["MacroTools"] From c04385592c0467a2ccb43209c4567557ae5c1a7c Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 20 Dec 2020 12:19:41 -0700 Subject: [PATCH 18/31] Remove time indexing from ctc gradchecks Also remove some extra whitespace from ctc tests. --- test/ctc-gpu.jl | 24 +++--------------------- test/ctc.jl | 22 ++++------------------ 2 files changed, 7 insertions(+), 39 deletions(-) diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index 06c0f7eaf5..174328256d 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -8,23 +8,16 @@ using Statistics # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` -# -# Needs to check loss as defined at a particular time step -# related to the change in x because slight deviations in -# input propagate through further time steps, intrinsically -# causing the gradients to change and thus not be comparable -# between the numeric and analytical definitions function ctc_ngradient(xs...) - f = Flux.Losses.ctc_ + f = ctc_loss grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) - t = div(i-1, size(x, 1)) + 1 tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...)[1][t] + y1 = f(xs...) x[i] = tmp + δ/2 - y2 = f(xs...)[1][t] + y2 = f(xs...) x[i] = tmp Δ[i] = (y2-y1)/δ end @@ -32,37 +25,28 @@ function ctc_ngradient(xs...) 
end @testset "ctc-gpu" begin - x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) - x_cu = CuArray(x) y_cu = CuArray(y) g1 = gradient(ctc_loss, x_cu, y_cu)[1] g1 = g1 |> collect - g2 = ctc_ngradient(x, y)[1] - @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # test that GPU loss matches CPU implementation - l1 = ctc_loss(x_cu, y_cu) l2 = ctc_loss(x, y) - @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values - x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray y_cu = [1 1 0; 0 0 1; 0 0 0] |> CuArray - @test mean(ctc_loss(x_cu, y_cu)) ≈ 3.6990738275138035 g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray @@ -70,8 +54,6 @@ end @test ctc_loss(x_cu, y_cu) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] - ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) - end diff --git a/test/ctc.jl b/test/ctc.jl index c65944ea48..8e41314f72 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -6,23 +6,16 @@ using LinearAlgebra # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` -# -# Needs to check loss as defined at a particular time step -# related to the change in x because slight deviations in -# input propagate through further time steps, intrinsically -# causing the gradients to change and thus not be comparable -# between the numeric and analytical definitions function ctc_ngradient(xs...) - f = Flux.Losses.ctc_ + f = ctc_loss grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) - t = div(i-1, size(x, 1)) + 1 tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...)[1][t] + y1 = f(xs...) x[i] = tmp + δ/2 - y2 = f(xs...)[1][t] + y2 = f(xs...) x[i] = tmp Δ[i] = (y2-y1)/δ end @@ -30,25 +23,20 @@ function ctc_ngradient(xs...) end @testset "ctc_loss" begin - x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) - g1 = gradient(ctc_loss, x, y)[1] g1 = g1 g2 = ctc_ngradient(x, y)[1] - @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values - x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] y = [1 1 0; 0 0 1; 0 0 0] - @test ctc_loss(x, y) ≈ 3.6990738275138035 + g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x, y)[1] - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
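For illustration only (this snippet is not part of the diff): the `y` built with `reduce(hcat, repeat(...))` above is a one-hot classes-by-time matrix whose per-column argmax gives each label repeated over five consecutive frames, with the final class (10) acting as the blank. A smaller analogue, three classes with the third as the blank and each label held for two frames, makes the shape easier to see:

    using LinearAlgebra
    y_small = reduce(hcat, repeat([Array{Float64}(I, 3, 3)[min(i, 2), :] for i in 1:2], inner=2))
    # 3×4 Matrix{Float64}; columns are one-hot for the frame labels 1, 1, 2, 2.
    # Taking argmax of each column, collapsing repeats, and dropping the blank
    # (class 3) recovers the target sequence [1, 2].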
@@ -56,8 +44,6 @@ end @test ctc_loss(x, y) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] - ghat = gradient(ctc_loss, x, y)[1] @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) - end From 3eb9e51579a4fdd65ec8cdf2542c2406407c711f Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 20 Dec 2020 12:38:40 -0700 Subject: [PATCH 19/31] Revert "Remove time indexing from ctc gradchecks" This reverts commit c04385592c0467a2ccb43209c4567557ae5c1a7c. --- test/ctc-gpu.jl | 24 +++++++++++++++++++++--- test/ctc.jl | 22 ++++++++++++++++++---- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index 174328256d..06c0f7eaf5 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -8,16 +8,23 @@ using Statistics # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` +# +# Needs to check loss as defined at a particular time step +# related to the change in x because slight deviations in +# input propagate through further time steps, intrinsically +# causing the gradients to change and thus not be comparable +# between the numeric and analytical definitions function ctc_ngradient(xs...) - f = ctc_loss + f = Flux.Losses.ctc_ grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) + t = div(i-1, size(x, 1)) + 1 tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...) + y1 = f(xs...)[1][t] x[i] = tmp + δ/2 - y2 = f(xs...) + y2 = f(xs...)[1][t] x[i] = tmp Δ[i] = (y2-y1)/δ end @@ -25,28 +32,37 @@ function ctc_ngradient(xs...) end @testset "ctc-gpu" begin + x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) + x_cu = CuArray(x) y_cu = CuArray(y) g1 = gradient(ctc_loss, x_cu, y_cu)[1] g1 = g1 |> collect + g2 = ctc_ngradient(x, y)[1] + @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # test that GPU loss matches CPU implementation + l1 = ctc_loss(x_cu, y_cu) l2 = ctc_loss(x, y) + @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values + x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray y_cu = [1 1 0; 0 0 1; 0 0 0] |> CuArray + @test mean(ctc_loss(x_cu, y_cu)) ≈ 3.6990738275138035 g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect + @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
|> CuArray @@ -54,6 +70,8 @@ end @test ctc_loss(x_cu, y_cu) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] + ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + end diff --git a/test/ctc.jl b/test/ctc.jl index 8e41314f72..c65944ea48 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -6,16 +6,23 @@ using LinearAlgebra # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` +# +# Needs to check loss as defined at a particular time step +# related to the change in x because slight deviations in +# input propagate through further time steps, intrinsically +# causing the gradients to change and thus not be comparable +# between the numeric and analytical definitions function ctc_ngradient(xs...) - f = ctc_loss + f = Flux.Losses.ctc_ grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) + t = div(i-1, size(x, 1)) + 1 tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...) + y1 = f(xs...)[1][t] x[i] = tmp + δ/2 - y2 = f(xs...) + y2 = f(xs...)[1][t] x[i] = tmp Δ[i] = (y2-y1)/δ end @@ -23,20 +30,25 @@ function ctc_ngradient(xs...) end @testset "ctc_loss" begin + x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) + g1 = gradient(ctc_loss, x, y)[1] g1 = g1 g2 = ctc_ngradient(x, y)[1] + @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values + x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] y = [1 1 0; 0 0 1; 0 0 0] + @test ctc_loss(x, y) ≈ 3.6990738275138035 - g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x, y)[1] + @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
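A small sanity sketch (not from the patch) of the index arithmetic the restored `ctc_ngradient` relies on: a linear, column-major index into a classes-by-time matrix maps back to its time column via `div(i-1, size(x, 1)) + 1`, which is how the perturbed entry is matched to the per-time-step loss it affects:

    x = rand(10, 50)
    i = 27                            # linear (column-major) index into x
    t = div(i - 1, size(x, 1)) + 1    # time column containing x[i]; here t == 3
    @assert CartesianIndices(x)[i][2] == t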
@@ -44,6 +56,8 @@ end @test ctc_loss(x, y) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] + ghat = gradient(ctc_loss, x, y)[1] @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + end From bccef7a515915a03dfebb84708f6d81e50ccee46 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sun, 20 Dec 2020 13:07:30 -0700 Subject: [PATCH 20/31] Change typedZero to typed_zero --- src/losses/ctc.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index c6175ba314..13ef8f9cc7 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -58,7 +58,7 @@ function add_blanks(z, blank) end function ctc_(ŷ, y) - typedZero = zero(ŷ[1]) + typed_zero = zero(ŷ[1]) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) z = F(Base.argmax.([y[:,i] for i=1:size(y,2)]), blank) @@ -68,7 +68,7 @@ function ctc_(ŷ, y) U′ = length(z′) # Calculate α coefficients, from the upper-left, to the bottom-right - α = fill(typedZero, T, U′) + α = fill(typed_zero, T, U′) for t=1:T for u=1:U′ if t == u == 1 @@ -76,9 +76,9 @@ function ctc_(ŷ, y) elseif t == 1 && u == 2 α[t,u] = ŷ[z′[2], t] elseif t == 1 && u > 2 - α[t,u] = log(typedZero) + α[t,u] = log(typed_zero) elseif u < U′ - 2(T - t) - 1 - α[t,u] = log(typedZero) + α[t,u] = log(typed_zero) else idx = u - 2 idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) @@ -89,12 +89,12 @@ function ctc_(ŷ, y) end # Calculate beta coefficients, from the bottom-right, to the upper-left - β = fill(log(typedZero), T, U′) + β = fill(log(typed_zero), T, U′) # Fill bottom-right corner so bounding errors can be avoided # by starting `u` at `U′-1` - β[T,U′] = typedZero - β[T,U′-1] = typedZero + β[T,U′] = typed_zero + β[T,U′-1] = typed_zero # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 for t=(T-1):-1:1 @@ -115,8 +115,8 @@ function ctc_(ŷ, y) # α and β coefficients for all the label classes at time t αβ = α + β losses = -1 .* logsumexp(αβ, dims=2) - accum = fill(log(typedZero), size(ŷ)) - grads = fill(log(typedZero), size(ŷ)) + accum = fill(log(typed_zero), size(ŷ)) + grads = fill(log(typed_zero), size(ŷ)) for t=1:T for u=1:U′ accum[z′[u], t] = logaddexp(accum[z′[u], t], α[t,u] + β[t,u]) From 00d41252bf25a2efb0abcc9d3022883631fde17a Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Tue, 22 Dec 2020 16:14:13 -0700 Subject: [PATCH 21/31] Update gradcheck comment; remove some whitespace --- test/ctc-gpu.jl | 24 +++++------------------- test/ctc.jl | 20 ++++++-------------- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index 06c0f7eaf5..395fbd3000 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -4,16 +4,14 @@ using Flux.Losses: ctc_loss using Zygote: gradient using LinearAlgebra using CUDA -using Statistics # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` # -# Needs to check loss as defined at a particular time step -# related to the change in x because slight deviations in -# input propagate through further time steps, intrinsically -# causing the gradients to change and thus not be comparable -# between the numeric and analytical definitions +# Checks loss at the particular time step related to change +# in input value because the gradient for that changed +# value is calculated from the loss value associated with +# that 
time step in the analytical gradient calculation. function ctc_ngradient(xs...) f = Flux.Losses.ctc_ grads = zero.(xs) @@ -32,37 +30,27 @@ function ctc_ngradient(xs...) end @testset "ctc-gpu" begin - x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) - x_cu = CuArray(x) y_cu = CuArray(y) - g1 = gradient(ctc_loss, x_cu, y_cu)[1] g1 = g1 |> collect - g2 = ctc_ngradient(x, y)[1] - @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # test that GPU loss matches CPU implementation - l1 = ctc_loss(x_cu, y_cu) l2 = ctc_loss(x, y) - @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values - x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray y_cu = [1 1 0; 0 0 1; 0 0 0] |> CuArray - - @test mean(ctc_loss(x_cu, y_cu)) ≈ 3.6990738275138035 + @test ctc_loss(x_cu, y_cu) ≈ 3.6990738275138035 g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray @@ -70,8 +58,6 @@ end @test ctc_loss(x_cu, y_cu) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] - ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) - end diff --git a/test/ctc.jl b/test/ctc.jl index c65944ea48..aebf519a5b 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -6,12 +6,11 @@ using LinearAlgebra # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` -# -# Needs to check loss as defined at a particular time step -# related to the change in x because slight deviations in -# input propagate through further time steps, intrinsically -# causing the gradients to change and thus not be comparable -# between the numeric and analytical definitions +# +# Checks loss at the particular time step related to change +# in input value because the gradient for that changed +# value is calculated from the loss value associated with +# that time step in the analytical gradient calculation. function ctc_ngradient(xs...) f = Flux.Losses.ctc_ grads = zero.(xs) @@ -30,25 +29,20 @@ function ctc_ngradient(xs...) end @testset "ctc_loss" begin - x = rand(10, 50) y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) - g1 = gradient(ctc_loss, x, y)[1] g1 = g1 g2 = ctc_ngradient(x, y)[1] - @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values - x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] y = [1 1 0; 0 0 1; 0 0 0] - @test ctc_loss(x, y) ≈ 3.6990738275138035 + g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x, y)[1] - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
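The hand-calculated reference values used in these tests can be cross-checked by brute force on tiny inputs. A sketch of such a check (independent of the patch, with its own local `logadd` helper rather than the package's `logaddexp`): enumerate every length-T path over the classes, collapse repeats, drop blanks, and sum the probabilities of the paths that reduce to the target labelling.

    using NNlib: logsoftmax

    logadd(a, b) = isinf(a) ? b : isinf(b) ? a : max(a, b) + log1p(exp(-abs(a - b)))

    function ctc_bruteforce(ŷ, labels, blank)
        lp = logsoftmax(ŷ)                    # log-probabilities, classes × time
        C, T = size(lp)
        total = -Inf
        for path in Iterators.product(ntuple(_ -> 1:C, T)...)
            collapsed = [path[1]; [path[t] for t in 2:T if path[t] != path[t-1]]]
            collapsed = filter(c -> c != blank, collapsed)
            if collapsed == labels
                total = logadd(total, sum(lp[path[t], t] for t in 1:T))
            end
        end
        return -total
    end

    # ctc_bruteforce([1. 2. 3.; 2. 1. 1.; 3. 3. 2.], [1, 2], 3) ≈ 3.6990738275138035,
    # matching the ctc_loss value asserted above.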
@@ -56,8 +50,6 @@ end @test ctc_loss(x, y) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] - ghat = gradient(ctc_loss, x, y)[1] @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) - end From 9b37c8fc1da128a92ff9bc57c6d0f6f0a624e9e9 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Thu, 24 Dec 2020 12:57:08 -0700 Subject: [PATCH 22/31] Reduce allocations for better performance --- src/losses/ctc.jl | 61 +++++++++++++++++++++++------------------------ test/ctc.jl | 12 +++------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 13ef8f9cc7..dbb506cff7 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -68,23 +68,22 @@ function ctc_(ŷ, y) U′ = length(z′) # Calculate α coefficients, from the upper-left, to the bottom-right - α = fill(typed_zero, T, U′) - for t=1:T + α = fill(log(typed_zero), T, U′) + α[1,1] = ŷ[blank, 1] + α[1,2] = ŷ[z′[2], 1] + for t=2:T + bound = U′ - 2(T - t) - 1 for u=1:U′ - if t == u == 1 - α[t,u] = ŷ[blank, t] - elseif t == 1 && u == 2 - α[t,u] = ŷ[z′[2], t] - elseif t == 1 && u > 2 - α[t,u] = log(typed_zero) - elseif u < U′ - 2(T - t) - 1 - α[t,u] = log(typed_zero) - else - idx = u - 2 - idx += z′[u] == blank || (u > 2 && z′[u-2] == z′[u]) - idx = max(1, idx) - α[t,u] = ŷ[z′[u], t] + foldl(logaddexp, α[t-1, idx:u]) - end + if u < bound continue end + if u == 1 + α[t,u] = α[t-1, u] + else + α[t,u] = logaddexp(α[t-1, u], α[t-1, u-1]) + if z′[u] != blank && u != 2 && z′[u] != z′[u-2] + α[t,u] = logaddexp(α[t,u], α[t-1,u-2]) + end + end + α[t,u] += ŷ[z′[u], t] end end @@ -98,23 +97,24 @@ function ctc_(ŷ, y) # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 for t=(T-1):-1:1 + bound1 = 2t + bound2 = U′ + 1 for u=U′:-1:1 - if u > 2t || u > U′ + 1 - continue - end - idx = u+2 - idx -= z′[u] == blank || (idx < U′ && z′[u+2] == z′[u]) - idx = min(idx, U′) - v = [β[t+1,i] + ŷ[z′[i], t+1] for i=u:idx] - β[t, u] = foldl(logaddexp, v) + if u > bound1 || u > bound2 continue end + if u == U′ + β[t, u] = ŷ[z′[u], t+1] + β[t+1, u] + else + β[t, u] = logaddexp(ŷ[z′[u], t+1] + β[t+1, u], ŷ[z′[u+1], t+1] + β[t+1,u+1]) + if z′[u] != blank && u != U′-1 && z′[u] != z′[u+2] + β[t, u] = logaddexp(β[t, u], ŷ[z′[u+2], t+1] + β[t+1, u+2]) + end + end end end - # Loss at each time t is taken as the sum of the product (sum in log space) of the # α and β coefficients for all the label classes at time t - αβ = α + β - losses = -1 .* logsumexp(αβ, dims=2) + loss = -1 * logaddexp(α[T,end], α[T, end-1]) accum = fill(log(typed_zero), size(ŷ)) grads = fill(log(typed_zero), size(ŷ)) for t=1:T @@ -122,11 +122,10 @@ function ctc_(ŷ, y) accum[z′[u], t] = logaddexp(accum[z′[u], t], α[t,u] + β[t,u]) end for u=1:size(grads, 1) - grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -losses[t]) + grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -loss) end end - losses = [x for x in losses] - return losses, grads + return loss, grads end """ @@ -150,7 +149,7 @@ or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7) for mathematical details. 
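A minimal illustrative call, assuming three classes with the third reserved as the blank and five time steps:

    ŷ = rand(3, 5)       # raw scores, classes × time (logsoftmax is applied internally)
    y = zeros(3, 5)
    y[1, 1:2] .= 1       # frames 1 and 2 labelled class 1
    y[2, 3:5] .= 1       # frames 3 to 5 labelled class 2
    ctc_loss(ŷ, y)       # scalar loss for the collapsed labelling [1, 2]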
""" function ctc_loss(ŷ::Array, y::Array) - return ctc_(ŷ, y)[1] |> mean + return ctc_(ŷ, y)[1] end @adjoint function ctc_(ŷ, y) diff --git a/test/ctc.jl b/test/ctc.jl index aebf519a5b..391077e83f 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -6,22 +6,16 @@ using LinearAlgebra # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` -# -# Checks loss at the particular time step related to change -# in input value because the gradient for that changed -# value is calculated from the loss value associated with -# that time step in the analytical gradient calculation. function ctc_ngradient(xs...) - f = Flux.Losses.ctc_ + f = Flux.Losses.ctc_loss grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) - t = div(i-1, size(x, 1)) + 1 tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...)[1][t] + y1 = f(xs...) x[i] = tmp + δ/2 - y2 = f(xs...)[1][t] + y2 = f(xs...) x[i] = tmp Δ[i] = (y2-y1)/δ end From 75bb3c1fec4d15839bc2760c36c11b40d433515b Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Thu, 24 Dec 2020 13:42:05 -0700 Subject: [PATCH 23/31] Transpose alpha and beta to match GPU kernel --- src/losses/ctc.jl | 48 +++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index dbb506cff7..07a20ffe1e 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -48,7 +48,6 @@ end Adds blanks to the start and end of `z`, and between item in `z` """ function add_blanks(z, blank) - z′ = [blank] for label in z push!(z′, label) @@ -61,65 +60,66 @@ function ctc_(ŷ, y) typed_zero = zero(ŷ[1]) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - z = F(Base.argmax.([y[:,i] for i=1:size(y,2)]), blank) + z = F(Base.argmax.(eachcol(y)), blank) z′ = add_blanks(z, blank) T = size(ŷ, 2) - U = length(z) U′ = length(z′) - # Calculate α coefficients, from the upper-left, to the bottom-right - α = fill(log(typed_zero), T, U′) + α = fill(log(typed_zero), U′, T) α[1,1] = ŷ[blank, 1] - α[1,2] = ŷ[z′[2], 1] + α[2,1] = ŷ[z′[2], 1] for t=2:T bound = U′ - 2(T - t) - 1 for u=1:U′ - if u < bound continue end + u < bound && continue if u == 1 - α[t,u] = α[t-1, u] + α[u,t] = α[u, t-1] else - α[t,u] = logaddexp(α[t-1, u], α[t-1, u-1]) + α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) + + # array bounds check and f(u) function from Eq. 7.9 if z′[u] != blank && u != 2 && z′[u] != z′[u-2] - α[t,u] = logaddexp(α[t,u], α[t-1,u-2]) + α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) end end - α[t,u] += ŷ[z′[u], t] + α[u,t] += ŷ[z′[u], t] end end # Calculate beta coefficients, from the bottom-right, to the upper-left - β = fill(log(typed_zero), T, U′) + β = fill(log(typed_zero), U′, T) # Fill bottom-right corner so bounding errors can be avoided # by starting `u` at `U′-1` - β[T,U′] = typed_zero - β[T,U′-1] = typed_zero + β[U′, T] = typed_zero + β[U′-1, T] = typed_zero # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 for t=(T-1):-1:1 - bound1 = 2t - bound2 = U′ + 1 + bound = min(U′ + 1, 2t) for u=U′:-1:1 - if u > bound1 || u > bound2 continue end + u > bound && continue if u == U′ - β[t, u] = ŷ[z′[u], t+1] + β[t+1, u] + β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] else - β[t, u] = logaddexp(ŷ[z′[u], t+1] + β[t+1, u], ŷ[z′[u+1], t+1] + β[t+1,u+1]) + β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) + + # array bounds check and g(u) function from Eq. 
7.16 if z′[u] != blank && u != U′-1 && z′[u] != z′[u+2] - β[t, u] = logaddexp(β[t, u], ŷ[z′[u+2], t+1] + β[t+1, u+2]) + β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) end end end end - # Loss at each time t is taken as the sum of the product (sum in log space) of the - # α and β coefficients for all the label classes at time t - loss = -1 * logaddexp(α[T,end], α[T, end-1]) + # Loss is taken as the product (sum in log space) of the last two + # cells in the last column in α + loss = -1 * logaddexp(α[end,T], α[end-1, T]) accum = fill(log(typed_zero), size(ŷ)) grads = fill(log(typed_zero), size(ŷ)) for t=1:T for u=1:U′ - accum[z′[u], t] = logaddexp(accum[z′[u], t], α[t,u] + β[t,u]) + accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t]) end for u=1:size(grads, 1) grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -loss) From b00487aab81e7e14f12edc03b338a2140fbe134b Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Thu, 24 Dec 2020 15:29:06 -0700 Subject: [PATCH 24/31] Update add_blanks to use fill --- src/losses/ctc.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 07a20ffe1e..c92654258d 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -48,11 +48,8 @@ end Adds blanks to the start and end of `z`, and between item in `z` """ function add_blanks(z, blank) - z′ = [blank] - for label in z - push!(z′, label) - push!(z′, blank) - end + z′ = fill(blank, 2*length(z) + 1) + z′[2 .* eachindex(z)] = z return z′ end @@ -122,7 +119,7 @@ function ctc_(ŷ, y) accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t]) end for u=1:size(grads, 1) - grads[u,t] = exp(ŷ[u, t]) - exp(accum[u, t] - -loss) + grads[u,t] = exp(ŷ[u,t]) - exp(accum[u,t] + loss) end end return loss, grads From d96a53fd315a1e2431db1d8b3274b45a53c8f55a Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Thu, 31 Dec 2020 00:15:16 -0700 Subject: [PATCH 25/31] Split CPU loss and gradient calculation --- src/losses/ctc.jl | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index c92654258d..8e15753743 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -68,20 +68,27 @@ function ctc_(ŷ, y) for t=2:T bound = U′ - 2(T - t) - 1 for u=1:U′ - u < bound && continue if u == 1 α[u,t] = α[u, t-1] - else + elseif u >= bound α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) # array bounds check and f(u) function from Eq. 7.9 - if z′[u] != blank && u != 2 && z′[u] != z′[u-2] + if u > 2 && z′[u] != blank && z′[u] != z′[u-2] α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) end end α[u,t] += ŷ[z′[u], t] end end + return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ) +end + +@adjoint function ctc_(ŷ, y) + loss, α, z′, ŷ = ctc_(ŷ, y) + U′, T = size(α) + blank = U′ + typed_zero = zero(first(α)) # Calculate beta coefficients, from the bottom-right, to the upper-left β = fill(log(typed_zero), U′, T) @@ -95,34 +102,29 @@ function ctc_(ŷ, y) for t=(T-1):-1:1 bound = min(U′ + 1, 2t) for u=U′:-1:1 - u > bound && continue if u == U′ β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] - else + elseif u <= bound β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) # array bounds check and g(u) function from Eq. 
7.16 - if z′[u] != blank && u != U′-1 && z′[u] != z′[u+2] + if u < U′-1 && z′[u] != blank && z′[u] != z′[u+2] β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) end end end end - # Loss is taken as the product (sum in log space) of the last two - # cells in the last column in α - loss = -1 * logaddexp(α[end,T], α[end-1, T]) + # Accumulate alpha-beta products for each category, + # then calculate gradients accum = fill(log(typed_zero), size(ŷ)) - grads = fill(log(typed_zero), size(ŷ)) for t=1:T for u=1:U′ accum[z′[u], t] = logaddexp(accum[z′[u], t], α[u,t] + β[u,t]) end - for u=1:size(grads, 1) - grads[u,t] = exp(ŷ[u,t]) - exp(accum[u,t] + loss) - end end - return loss, grads + grads = exp.(ŷ) .- exp.(accum .+ loss) + return loss, g -> (g .* grads, nothing) end """ @@ -148,8 +150,3 @@ for mathematical details. function ctc_loss(ŷ::Array, y::Array) return ctc_(ŷ, y)[1] end - -@adjoint function ctc_(ŷ, y) - ls, gs = ctc_(ŷ, y) - return mean(ls), Δ -> (Δ .* gs, Δ) -end From 6ca07e2df2de53b4f5610c4eb6d957863cbd94ea Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sat, 2 Jan 2021 17:11:20 -0700 Subject: [PATCH 26/31] Rejig CPU CTC API --- src/losses/ctc.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 8e15753743..c79cf0f809 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -34,11 +34,12 @@ function F(A, blank) prev = A[1] z = [prev] for curr in A[2:end] - if curr != prev && curr != blank + if curr != prev push!(z, curr) end prev = curr end + filter!(x -> x != blank, z) return z end @@ -53,7 +54,7 @@ function add_blanks(z, blank) return z′ end -function ctc_(ŷ, y) +function ctc_alpha(ŷ::AbstractArray, y) typed_zero = zero(ŷ[1]) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) @@ -84,8 +85,8 @@ function ctc_(ŷ, y) return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ) end -@adjoint function ctc_(ŷ, y) - loss, α, z′, ŷ = ctc_(ŷ, y) +function ∇ctc_loss(ŷ::AbstractArray, y, out) + loss, α, z′, ŷ = out U′, T = size(α) blank = U′ typed_zero = zero(first(α)) @@ -124,7 +125,7 @@ end end end grads = exp.(ŷ) .- exp.(accum .+ loss) - return loss, g -> (g .* grads, nothing) + return grads end """ @@ -147,6 +148,10 @@ solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves or [Graves (2012)](https://www.cs.toronto.edu/~graves/preprint.pdf#chapter.7) for mathematical details. """ -function ctc_loss(ŷ::Array, y::Array) - return ctc_(ŷ, y)[1] +ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss + +@adjoint function ctc_loss(ŷ, y) + out = ctc_alpha(ŷ, y) + ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) + return out.loss, ctc_loss_pullback end From bc7ab03b57c8e3f60a9f8bd1462a904819f22bbe Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sat, 2 Jan 2021 17:11:50 -0700 Subject: [PATCH 27/31] Split GPU CTC kernel and update API --- src/losses/ctc-gpu.jl | 51 +++++++++++++++++++------------------------ test/ctc-gpu.jl | 12 +++------- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index e424cc4464..600a317704 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -110,7 +110,6 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, start = S > 1 ? S-2 : 0 last = L + repeats < T ? 
S : S-1 sync_threads() - i = tid # Calculate coefficients for last column (time step) @@ -129,7 +128,6 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, accum[labelIdx, T] = log_plus_f(accum[labelIdx, T], output[i, T]) end end - sync_threads() # Fill in `grad` for last column (time step) @@ -175,7 +173,6 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, beta[S, t] = beta[S, t] + probs[blankLabel, t+1] end sync_threads() - idx = tid while idx <= S output[idx, t] = alphas[idx, t] + beta[idx, t] @@ -207,7 +204,6 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s) idx += blockDim().x end - sync_threads() t -= 1 sync_threads() @@ -215,38 +211,35 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, return nothing end -# methods for `ctc_` helper function -ctc_loss(ŷ::CuArray, y::Array) = ctc_(ŷ, y)[1] |> mean -ctc_loss(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y))[1] |> mean -ctc_loss(ŷ::CuArray, y::CuArray) = ctc_(ŷ, collect(y))[1] |> mean -ctc_(ŷ::Array, y::CuArray) = ctc_(CuArray(ŷ), collect(y)) - -function ctc_(ŷ::CuArray, y) +function ctc_alpha(ŷ::CuArray, y) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - labels = [Base.argmax(y[:,i]) for i in 1:size(y, 2)] + labels = Base.argmax.(eachcol(y)) z = F(labels, blank) - z′ = [blank] - for label in z - push!(z′, label) - push!(z′, blank) - end - + z′ = fill(blank, 2 * length(z) + 1) + z′[eachindex(z) .* 2] = z T = size(ŷ, 2) U′ = 2*length(z) + 1 - alphas = CUDA.fill(log(zero(ŷ[1])), U′, T) - betas = CUDA.fill(log(zero(ŷ[1])), U′, T) - output = CUDA.fill(log(zero(ŷ[1])), U′, T) + alphas = CUDA.fill(log(zero(ŷ[1])), U′,T) nRepeats = count_repeats(labels) nThreads = min(U′, MAX_THREADS) + @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(z), T, nRepeats, CuArray(z), CuArray(z′), alphas, blank) + return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z=z, z′=z′, yhat=ŷ, nRepeats=nRepeats) +end - @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z), CuArray(z′), alphas, blank) - grads = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) - accum = CUDA.fill(log(zero(ŷ[1])), size(ŷ)) - - @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(z), size(ŷ,2), nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) - ls = collect(output) - ls = vec(-1 .* [logsumexp(ls[:,i]) for i in 1:size(ls, 2)]) +ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss + +function ∇ctc_loss(ŷ::CuArray, y, out) + loss, alphas, z, z′, ŷ, nRepeats = out + U′, T = size(alphas) + blank = size(ŷ, 1) + typed_zero = zero(first(ŷ)) + betas = CUDA.fill(log(typed_zero), U′, T) + output = CUDA.fill(log(typed_zero), U′, T) + nThreads = min(U′, MAX_THREADS) + grads = CUDA.fill(log(typed_zero), size(ŷ)) + accum = CUDA.fill(log(typed_zero), size(ŷ)) + @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(z), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) ŷ = alphas = betas = output = accum = nothing - return ls, grads + return grads end diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index 395fbd3000..045854940e 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -7,22 +7,16 @@ using CUDA # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` -# -# Checks loss at the particular time step related to change -# in input value because the gradient for that changed -# value is calculated from 
the loss value associated with -# that time step in the analytical gradient calculation. function ctc_ngradient(xs...) - f = Flux.Losses.ctc_ + f = Flux.Losses.ctc_loss grads = zero.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) δ = sqrt(eps()) - t = div(i-1, size(x, 1)) + 1 tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...)[1][t] + y1 = f(xs...)[1] x[i] = tmp + δ/2 - y2 = f(xs...)[1][t] + y2 = f(xs...)[1] x[i] = tmp Δ[i] = (y2-y1)/δ end From 807aefa7199220895b1c705bdbcaa1d207dc69ca Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Wed, 13 Jan 2021 22:39:44 -0700 Subject: [PATCH 28/31] Move to onecold representation for ctc input --- src/losses/ctc-gpu.jl | 45 +++++++++++++---------------------- src/losses/ctc.jl | 55 ++++++++++++++----------------------------- test/ctc-gpu.jl | 46 +++++++++++++++++++++++------------- test/ctc.jl | 34 +++++++++++++++++--------- 4 files changed, 86 insertions(+), 94 deletions(-) diff --git a/src/losses/ctc-gpu.jl b/src/losses/ctc-gpu.jl index 600a317704..48934d1cae 100644 --- a/src/losses/ctc-gpu.jl +++ b/src/losses/ctc-gpu.jl @@ -63,7 +63,7 @@ function compute_alpha_kernel(probs, labelSize, uttLength, repeats, labelsWithou # Corner-case checking if tid == 1 && !(1 < S - 2*(T-t) - 1) if start == 0 - alpha[1, t] = probs[blankLabel, t] + alpha[1, t-1] + alpha[1, t] = alpha[1, t-1] + probs[blankLabel, t] elseif start == 1 alpha[1, t] = alpha[1, t-1] end @@ -93,7 +93,7 @@ end function compute_beta_and_grad_kernel(probs, labelSize, uttLength, repeatsInLabel, labelsWithBlanks, alphas, beta, output, accum, - grad, blankLabel) + grad, blankLabel, loss) tid = threadIdx().x L = labelSize @@ -114,7 +114,7 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, # Calculate coefficients for last column (time step) # then determine alpha and beta product - while i <= last - start + 1 + while i <= last - start beta[i+start, T] = 0 output[i+start, T] = beta[i+start, T] + alphas[i+start, T] i += blockDim().x @@ -149,16 +149,15 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, while t >= 1 if t < T idx = tid - # while idx <= S-1 while idx <= S - nextSum = beta[idx, t+1] + probs[labels[idx], t+1] + nextSum = probs[labels[idx], t+1] + beta[idx, t+1] if idx < S nextSum = log_plus_f(nextSum, - beta[idx+1, t+1] + probs[labels[idx+1], t+1]) + probs[labels[idx+1], t+1] + beta[idx+1, t+1]) end if labels[idx] != blankLabel && idx != S-1 && labels[idx] != labels[idx+2] nextSum = log_plus_f(nextSum, - beta[idx + 2, t+1] + probs[labels[idx+2], t+1]) + probs[labels[idx+2], t+1] + beta[idx + 2, t+1]) end if idx > 2*t beta[idx, t] = -Inf32 @@ -168,11 +167,6 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, idx += blockDim().x end sync_threads() - - if tid == 1 && last == S - beta[S, t] = beta[S, t] + probs[blankLabel, t+1] - end - sync_threads() idx = tid while idx <= S output[idx, t] = alphas[idx, t] + beta[idx, t] @@ -195,13 +189,9 @@ function compute_beta_and_grad_kernel(probs, labelSize, uttLength, # Calculate gradients while idx <= size(grad, 1) - s = -Inf32 - for i=1:S - s = log_plus_f(s, output[i, t]) - end # ∂L/∂a (where a is activation before logsoftmax) - grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] - s) + grad[idx, t] = CUDA.exp(probs[idx, t]) - CUDA.exp(accum[idx, t] + loss) idx += blockDim().x end sync_threads() @@ -214,23 +204,21 @@ end function ctc_alpha(ŷ::CuArray, y) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - labels = Base.argmax.(eachcol(y)) - z = F(labels, blank) - z′ = fill(blank, 2 * length(z) + 1) - 
z′[eachindex(z) .* 2] = z + z′ = fill(blank, 2 * length(y) + 1) + z′[eachindex(y) .* 2] = y T = size(ŷ, 2) - U′ = 2*length(z) + 1 + U′ = 2*length(y) + 1 alphas = CUDA.fill(log(zero(ŷ[1])), U′,T) - nRepeats = count_repeats(labels) + nRepeats = count_repeats(y) nThreads = min(U′, MAX_THREADS) - @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(z), T, nRepeats, CuArray(z), CuArray(z′), alphas, blank) - return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z=z, z′=z′, yhat=ŷ, nRepeats=nRepeats) + @cuda blocks=1 threads=nThreads compute_alpha_kernel(ŷ, length(y), T, nRepeats, CuArray(y), CuArray(z′), alphas, blank) + return (loss=-1 * logsumexp(alphas[end-1:end]), alpha=alphas, z′=z′, yhat=ŷ, nRepeats=nRepeats) end ctc_loss(ŷ::CuArray, y) = ctc_alpha(ŷ::CuArray, y).loss function ∇ctc_loss(ŷ::CuArray, y, out) - loss, alphas, z, z′, ŷ, nRepeats = out + loss, alphas, z′, ŷ, nRepeats = out U′, T = size(alphas) blank = size(ŷ, 1) typed_zero = zero(first(ŷ)) @@ -238,8 +226,7 @@ function ∇ctc_loss(ŷ::CuArray, y, out) output = CUDA.fill(log(typed_zero), U′, T) nThreads = min(U′, MAX_THREADS) grads = CUDA.fill(log(typed_zero), size(ŷ)) - accum = CUDA.fill(log(typed_zero), size(ŷ)) - @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(z), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank) - ŷ = alphas = betas = output = accum = nothing + accum = CUDA.fill(log(typed_zero), size(ŷ)) + @cuda blocks=1 threads=nThreads compute_beta_and_grad_kernel(ŷ, length(y), T, nRepeats, CuArray(z′), alphas, betas, output, accum, grads, blank, loss) return grads end diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index c79cf0f809..5d5a67597e 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -23,30 +23,10 @@ function logaddexp(a, b) return a + log(1+exp(b-a)) end -""" - F(A, blank) - -Removes blanks and repetitions in the sequence `A` - -This is the function `F` as defined in Graves (2012) -""" -function F(A, blank) - prev = A[1] - z = [prev] - for curr in A[2:end] - if curr != prev - push!(z, curr) - end - prev = curr - end - filter!(x -> x != blank, z) - return z -end - """ add_blanks(z) -Adds blanks to the start and end of `z`, and between item in `z` +Adds blanks to the start and end of `z`, and between items in `z` """ function add_blanks(z, blank) z′ = fill(blank, 2*length(z) + 1) @@ -58,8 +38,7 @@ function ctc_alpha(ŷ::AbstractArray, y) typed_zero = zero(ŷ[1]) ŷ = logsoftmax(ŷ) blank = size(ŷ, 1) - z = F(Base.argmax.(eachcol(y)), blank) - z′ = add_blanks(z, blank) + z′ = add_blanks(y, blank) T = size(ŷ, 2) U′ = length(z′) @@ -67,15 +46,15 @@ function ctc_alpha(ŷ::AbstractArray, y) α[1,1] = ŷ[blank, 1] α[2,1] = ŷ[z′[2], 1] for t=2:T - bound = U′ - 2(T - t) - 1 - for u=1:U′ + bound = max(1, U′ - 2(T - t) - 1) + for u=bound:U′ if u == 1 α[u,t] = α[u, t-1] - elseif u >= bound + else α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) # array bounds check and f(u) function from Eq. 
7.9 - if u > 2 && z′[u] != blank && z′[u] != z′[u-2] + if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u]) α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) end end @@ -88,7 +67,7 @@ end function ∇ctc_loss(ŷ::AbstractArray, y, out) loss, α, z′, ŷ = out U′, T = size(α) - blank = U′ + blank = size(ŷ, 1) typed_zero = zero(first(α)) # Calculate beta coefficients, from the bottom-right, to the upper-left @@ -101,15 +80,15 @@ function ∇ctc_loss(ŷ::AbstractArray, y, out) # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 for t=(T-1):-1:1 - bound = min(U′ + 1, 2t) - for u=U′:-1:1 + bound = min(U′, 2t) + for u=bound:-1:1 if u == U′ β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] - elseif u <= bound + else β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) - + # array bounds check and g(u) function from Eq. 7.16 - if u < U′-1 && z′[u] != blank && z′[u] != z′[u+2] + if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2] β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) end end @@ -134,14 +113,16 @@ end Computes the connectionist temporal classification loss between `ŷ` and `y`. -Both `ŷ` and `y` must be classes-by-time matrices, i.e., each row +`ŷ` must be a classes-by-time matrices, i.e., each row represents a class and each column represents a time step. Additionally, the `logsoftmax` function will be applied to `ŷ`, so -it must be the raw activation values from the neural network and +`ŷ` must be the raw activation values from the neural network and not, for example, the activations after being passed through a -`softmax` activation function. +`softmax` activation function. `y` must be a 1D `Array` of the labels +associated with `ŷ`. The blank label is assumed to be the last label +category in `ŷ`, so it is equivalent to `size(ŷ, 1)`. -Used for sequence to sequence classification problems such as +Used for sequence-to-sequence classification problems such as speech recognition and handwriting recognition where the exact time-alignment of the output (e.g., letters) is not needed to solve the problem. See [Graves et al. (2006)](https://www.cs.toronto.edu/~graves/icml_2006.pdf) diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index 045854940e..f6c05522e5 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -7,51 +7,63 @@ using CUDA # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` -function ctc_ngradient(xs...) 
+function ctc_ngradient(x, y) f = Flux.Losses.ctc_loss - grads = zero.(xs) - for (x, Δ) in zip(xs, grads), i in 1:length(x) + grads = zero(x) + for i in 1:length(x) δ = sqrt(eps()) tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...)[1] + y1 = f(x, y) x[i] = tmp + δ/2 - y2 = f(xs...)[1] + y2 = f(x, y) x[i] = tmp - Δ[i] = (y2-y1)/δ + grads[i] = (y2-y1)/δ end return grads end +function F(A, blank) + prev = A[1] + z = [prev] + for curr in A[2:end] + if curr != prev + push!(z, curr) + end + prev = curr + end + filter!(x -> x != blank, z) + return z +end + @testset "ctc-gpu" begin x = rand(10, 50) - y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) + y = F(rand(1:9, 30), 10) x_cu = CuArray(x) - y_cu = CuArray(y) - g1 = gradient(ctc_loss, x_cu, y_cu)[1] + g1 = gradient(ctc_loss, x_cu, y)[1] g1 = g1 |> collect - g2 = ctc_ngradient(x, y)[1] + g2 = ctc_ngradient(x, y) @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # test that GPU loss matches CPU implementation - l1 = ctc_loss(x_cu, y_cu) + l1 = ctc_loss(x_cu, y) l2 = ctc_loss(x, y) @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray - y_cu = [1 1 0; 0 0 1; 0 0 0] |> CuArray - @test ctc_loss(x_cu, y_cu) ≈ 3.6990738275138035 + y = [1, 2] + @test ctc_loss(x_cu, y) ≈ 3.6990738275138035 g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] - ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect + ghat = gradient(ctc_loss, x_cu, y)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray - y_cu = [1 1 0 0; 0 0 1 1; 0 0 0 0] |> CuArray - @test ctc_loss(x_cu, y_cu) ≈ 8.02519869363453 + y = [1, 2] |> CuArray + @test ctc_loss(x_cu, y) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] - ghat = gradient(ctc_loss, x_cu, y_cu)[1] |> collect + ghat = gradient(ctc_loss, x_cu, y)[1] |> collect @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) end diff --git a/test/ctc.jl b/test/ctc.jl index 391077e83f..6ca63e527f 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -6,33 +6,45 @@ using LinearAlgebra # Custom function to check numerical gradient of ctc loss, # based on `ngradient` in `Tracker.jl` -function ctc_ngradient(xs...) +function ctc_ngradient(x, y) f = Flux.Losses.ctc_loss - grads = zero.(xs) - for (x, Δ) in zip(xs, grads), i in 1:length(x) + grads = zero(x) + for i in 1:length(x) δ = sqrt(eps()) tmp = x[i] x[i] = tmp - δ/2 - y1 = f(xs...) + y1 = f(x, y) x[i] = tmp + δ/2 - y2 = f(xs...) + y2 = f(x, y) x[i] = tmp - Δ[i] = (y2-y1)/δ + grads[i] = (y2-y1)/δ end return grads end +function F(A, blank) + prev = A[1] + z = [prev] + for curr in A[2:end] + if curr != prev + push!(z, curr) + end + prev = curr + end + filter!(x -> x != blank, z) + return z +end + @testset "ctc_loss" begin x = rand(10, 50) - y = reduce(hcat, repeat([Array{Float64}(I, 10, 10)[min(i, 9),:] for i in 1:10], inner=5)) + y = F(rand(1:9, 30), 10) g1 = gradient(ctc_loss, x, y)[1] - g1 = g1 - g2 = ctc_ngradient(x, y)[1] + g2 = ctc_ngradient(x, y) @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) # tests using hand-calculated values x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] 
- y = [1 1 0; 0 0 1; 0 0 0] + y = [1, 2] @test ctc_loss(x, y) ≈ 3.6990738275138035 g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] @@ -40,7 +52,7 @@ end @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] - y = [1 1 0 0; 0 0 1 1; 0 0 0 0] + y = [1, 2] @test ctc_loss(x, y) ≈ 8.02519869363453 g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] From e1e8cc850d47d596825ceb7f084294730a40919d Mon Sep 17 00:00:00 2001 From: "Matthew C. Kelley" Date: Sat, 16 Jan 2021 15:53:38 -0700 Subject: [PATCH 29/31] Apply suggestions from code review Co-authored-by: Carlo Lucibello --- src/losses/ctc.jl | 2 +- test/ctc.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 5d5a67597e..230d04add3 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -118,7 +118,7 @@ represents a class and each column represents a time step. Additionally, the `logsoftmax` function will be applied to `ŷ`, so `ŷ` must be the raw activation values from the neural network and not, for example, the activations after being passed through a -`softmax` activation function. `y` must be a 1D `Array` of the labels +`softmax` activation function. `y` must be a 1D array of the labels associated with `ŷ`. The blank label is assumed to be the last label category in `ŷ`, so it is equivalent to `size(ŷ, 1)`. diff --git a/test/ctc.jl b/test/ctc.jl index 6ca63e527f..6240343a6e 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -40,7 +40,7 @@ end y = F(rand(1:9, 30), 10) g1 = gradient(ctc_loss, x, y)[1] g2 = ctc_ngradient(x, y) - @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) + @test g1 ≈ g2 rtol=1e-5 atol=1e-5 # tests using hand-calculated values x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] @@ -49,7 +49,7 @@ end g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x, y)[1] - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + @test g ≈ ghat rtol=1e-5 atol=1e-5 x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
y = [1, 2] @@ -57,5 +57,5 @@ end g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] ghat = gradient(ctc_loss, x, y)[1] - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + @test g ≈ ghat rtol=1e-5 atol=1e-5 end From 6e5fb17bd5d0f50e1f8f35ca0161156065e6610c Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Sat, 16 Jan 2021 16:17:00 -0700 Subject: [PATCH 30/31] Remove F in ctc tests; update ctc-gpu test syntax --- test/ctc-gpu.jl | 23 +++++------------------ test/ctc.jl | 21 ++++----------------- 2 files changed, 9 insertions(+), 35 deletions(-) diff --git a/test/ctc-gpu.jl b/test/ctc-gpu.jl index f6c05522e5..d7ff1bdf9d 100644 --- a/test/ctc-gpu.jl +++ b/test/ctc-gpu.jl @@ -23,32 +23,19 @@ function ctc_ngradient(x, y) return grads end -function F(A, blank) - prev = A[1] - z = [prev] - for curr in A[2:end] - if curr != prev - push!(z, curr) - end - prev = curr - end - filter!(x -> x != blank, z) - return z -end - @testset "ctc-gpu" begin x = rand(10, 50) - y = F(rand(1:9, 30), 10) + y = rand(1:9, 30) x_cu = CuArray(x) g1 = gradient(ctc_loss, x_cu, y)[1] g1 = g1 |> collect g2 = ctc_ngradient(x, y) - @test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5)) + @test g1 ≈ g2 rtol=1e-5 atol=1e-5 # test that GPU loss matches CPU implementation l1 = ctc_loss(x_cu, y) l2 = ctc_loss(x, y) - @test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5)) + @test l1 ≈ l2 # tests using hand-calculated values x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray @@ -57,7 +44,7 @@ end g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x_cu, y)[1] |> collect - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + @test g ≈ ghat rtol=1e-5 atol=1e-5 x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray y = [1, 2] |> CuArray @@ -65,5 +52,5 @@ end g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] ghat = gradient(ctc_loss, x_cu, y)[1] |> collect - @test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5)) + @test g ≈ ghat rtol=1e-5 atol=1e-5 end diff --git a/test/ctc.jl b/test/ctc.jl index 6240343a6e..6fa33c4b99 100644 --- a/test/ctc.jl +++ b/test/ctc.jl @@ -22,25 +22,12 @@ function ctc_ngradient(x, y) return grads end -function F(A, blank) - prev = A[1] - z = [prev] - for curr in A[2:end] - if curr != prev - push!(z, curr) - end - prev = curr - end - filter!(x -> x != blank, z) - return z -end - @testset "ctc_loss" begin x = rand(10, 50) - y = F(rand(1:9, 30), 10) + y = rand(1:9, 30) g1 = gradient(ctc_loss, x, y)[1] g2 = ctc_ngradient(x, y) - @test g1 ≈ g2 rtol=1e-5 atol=1e-5 + @test g1 ≈ g2 rtol=1e-5 atol=1e-5 # tests using hand-calculated values x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] @@ -49,7 +36,7 @@ end g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457] ghat = gradient(ctc_loss, x, y)[1] - @test g ≈ ghat rtol=1e-5 atol=1e-5 + @test g ≈ ghat rtol=1e-5 atol=1e-5 x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] 
y = [1, 2] @@ -57,5 +44,5 @@ end g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07] ghat = gradient(ctc_loss, x, y)[1] - @test g ≈ ghat rtol=1e-5 atol=1e-5 + @test g ≈ ghat rtol=1e-5 atol=1e-5 end From bc94a16534047715c19ad7234c793334dc7b3455 Mon Sep 17 00:00:00 2001 From: Matt Kelley Date: Tue, 19 Jan 2021 16:05:02 -0700 Subject: [PATCH 31/31] Fix indentation in ctc.jl --- src/losses/ctc.jl | 50 +++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index 230d04add3..ff1183e99f 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -46,19 +46,19 @@ function ctc_alpha(ŷ::AbstractArray, y) α[1,1] = ŷ[blank, 1] α[2,1] = ŷ[z′[2], 1] for t=2:T - bound = max(1, U′ - 2(T - t) - 1) + bound = max(1, U′ - 2(T - t) - 1) for u=bound:U′ - if u == 1 - α[u,t] = α[u, t-1] - else - α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) - - # array bounds check and f(u) function from Eq. 7.9 - if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u]) - α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) - end - end - α[u,t] += ŷ[z′[u], t] + if u == 1 + α[u,t] = α[u, t-1] + else + α[u,t] = logaddexp(α[u, t-1], α[u-1, t-1]) + + # array bounds check and f(u) function from Eq. 7.9 + if u > 2 && !(z′[u] == blank || z′[u-2] == z′[u]) + α[u,t] = logaddexp(α[u,t], α[u-2,t-1]) + end + end + α[u,t] += ŷ[z′[u], t] end end return (loss=-1 * logaddexp(α[end,T], α[end-1, T]), alpha=α, zprime=z′, logsoftyhat=ŷ) @@ -80,18 +80,18 @@ function ∇ctc_loss(ŷ::AbstractArray, y, out) # start at T-1 so that β(T, u) = log(0) for all u < U′ - 1 for t=(T-1):-1:1 - bound = min(U′, 2t) + bound = min(U′, 2t) for u=bound:-1:1 - if u == U′ - β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] - else - β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) + if u == U′ + β[u,t] = ŷ[z′[u], t+1] + β[u, t+1] + else + β[u,t] = logaddexp(ŷ[z′[u], t+1] + β[u, t+1], ŷ[z′[u+1], t+1] + β[u+1,t+1]) - # array bounds check and g(u) function from Eq. 7.16 - if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2] - β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) - end - end + # array bounds check and g(u) function from Eq. 7.16 + if u+2 <= U′ && z′[u] != blank && z′[u] != z′[u+2] + β[u,t] = logaddexp(β[u,t], ŷ[z′[u+2], t+1] + β[u+2, t+1]) + end + end end end @@ -132,7 +132,7 @@ for mathematical details. ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss @adjoint function ctc_loss(ŷ, y) - out = ctc_alpha(ŷ, y) - ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) - return out.loss, ctc_loss_pullback + out = ctc_alpha(ŷ, y) + ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) + return out.loss, ctc_loss_pullback end
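The patches above converge on a single entry point, Flux.Losses.ctc_loss(ŷ, y), where ŷ is a classes-by-time matrix of raw (pre-logsoftmax) activations and y is a 1-D vector of integer labels, with the blank implicitly being size(ŷ, 1). A minimal usage sketch, reusing the hand-calculated values from test/ctc.jl; the `using` line that brings ctc_loss into scope is an assumption (the tests qualify it as Flux.Losses.ctc_loss), and the GPU lines assume a working CUDA setup:

    using Flux
    using Flux.Losses: ctc_loss

    # 3 output classes over 3 time steps; class 3 == size(ŷ, 1) acts as the blank
    ŷ = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]
    y = [1, 2]                       # target label sequence, drawn from 1:size(ŷ, 1)-1

    ctc_loss(ŷ, y)                   # ≈ 3.6990738275138035, the value in test/ctc.jl
    g = gradient(ctc_loss, ŷ, y)[1]  # ∂loss/∂ŷ via the custom @adjoint

    # On the GPU only the activations need to move; the label vector can stay
    # on the CPU, as in test/ctc-gpu.jl:
    # using CUDA
    # ctc_loss(CuArray(ŷ), y)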
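Both the CPU and GPU paths first build the blank-augmented sequence z′ by interleaving the blank around and between the target labels, so a length-U target becomes a length-2U+1 sequence. A small worked sketch of that step (the variable names mirror add_blanks above; nothing here is exported):

    z = [1, 2, 2, 4]                  # target labels
    blank = 5                         # blank label == number of output classes
    z′ = fill(blank, 2*length(z) + 1)
    z′[2 .* eachindex(z)] = z
    # z′ == [5, 1, 5, 2, 5, 2, 5, 4, 5]

Because repeats in y are kept as given (the caller supplies the exact target sequence), a doubled label such as the two 2s above stays doubled and is separated by a blank in z′.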
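Both kernels evaluate the CTC recursions in log space, using logaddexp / log_plus_f (log(e^a + e^b) = a + log(1 + e^(b-a)) for a ≥ b) to avoid underflow. In linear-domain form, writing y_{k,t} for the softmax probability of class k at time t, a_{k,t} for the pre-logsoftmax activation, b for the blank, and z′ for the blank-augmented target, the recursions as implemented here are (the skip conditions and bounds follow the f(u)/g(u) checks of Graves 2012, Eqs. 7.9 and 7.16, cited in the comments; entries outside the reachable band controlled by the `bound` variables stay at log(0)):

    \alpha_t(u) = y_{z'_u,\,t}\,\Big(\alpha_{t-1}(u) + \alpha_{t-1}(u-1) + [\,z'_u \neq b \wedge z'_u \neq z'_{u-2}\,]\,\alpha_{t-1}(u-2)\Big)

    \beta_t(u) = y_{z'_u,\,t+1}\,\beta_{t+1}(u) + y_{z'_{u+1},\,t+1}\,\beta_{t+1}(u+1) + [\,z'_u \neq b \wedge z'_u \neq z'_{u+2}\,]\,y_{z'_{u+2},\,t+1}\,\beta_{t+1}(u+2)

    \mathcal{L} = -\log\big(\alpha_T(U') + \alpha_T(U'-1)\big), \qquad
    \frac{\partial \mathcal{L}}{\partial a_{k,t}} = y_{k,t} - e^{\mathcal{L}} \sum_{u\,:\,z'_u = k} \alpha_t(u)\,\beta_t(u)

The last expression is exactly what the accum/grads code computes: accum gathers the log-space α·β products per label class, and exp.(ŷ) .- exp.(accum .+ loss) turns that into the gradient with respect to the raw activations.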
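The analytic gradient in both test files is checked against a central finite difference: ctc_ngradient perturbs each activation by ±δ/2 with δ = √eps() and uses the resulting loss difference as the reference gradient entry,

    \frac{\partial \mathcal{L}}{\partial x_i} \approx \frac{\mathcal{L}(x + \tfrac{\delta}{2} e_i) - \mathcal{L}(x - \tfrac{\delta}{2} e_i)}{\delta}

with e_i the unit vector selecting entry i, before the rtol=1e-5 / atol=1e-5 comparison against the Zygote gradient.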