Skip to content

Commit

Permalink
Remove F in ctc tests; update ctc-gpu test syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
maetshju committed Jan 16, 2021
1 parent e1e8cc8 commit 6e5fb17
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 35 deletions.
23 changes: 5 additions & 18 deletions test/ctc-gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,19 @@ function ctc_ngradient(x, y)
return grads
end

function F(A, blank)
prev = A[1]
z = [prev]
for curr in A[2:end]
if curr != prev
push!(z, curr)
end
prev = curr
end
filter!(x -> x != blank, z)
return z
end

@testset "ctc-gpu" begin
x = rand(10, 50)
y = F(rand(1:9, 30), 10)
y = rand(1:9, 30)
x_cu = CuArray(x)
g1 = gradient(ctc_loss, x_cu, y)[1]
g1 = g1 |> collect
g2 = ctc_ngradient(x, y)
@test all(isapprox.(g1, g2, rtol=1e-5, atol=1e-5))
@test g1 g2 rtol=1e-5 atol=1e-5

# test that GPU loss matches CPU implementation
l1 = ctc_loss(x_cu, y)
l2 = ctc_loss(x, y)
@test all(isapprox.(l1, l2, rtol=1e-5, atol=1e-5))
@test l1 l2

# tests using hand-calculated values
x_cu = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.] |> CuArray
Expand All @@ -57,13 +44,13 @@ end

g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
@test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5))
@test g ghat rtol=1e-5 atol=1e-5

x_cu = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.] |> CuArray
y = [1, 2] |> CuArray
@test ctc_loss(x_cu, y) 8.02519869363453

g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
ghat = gradient(ctc_loss, x_cu, y)[1] |> collect
@test all(isapprox.(g, ghat, rtol=1e-5, atol=1e-5))
@test g ghat rtol=1e-5 atol=1e-5
end
21 changes: 4 additions & 17 deletions test/ctc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,12 @@ function ctc_ngradient(x, y)
return grads
end

function F(A, blank)
prev = A[1]
z = [prev]
for curr in A[2:end]
if curr != prev
push!(z, curr)
end
prev = curr
end
filter!(x -> x != blank, z)
return z
end

@testset "ctc_loss" begin
x = rand(10, 50)
y = F(rand(1:9, 30), 10)
y = rand(1:9, 30)
g1 = gradient(ctc_loss, x, y)[1]
g2 = ctc_ngradient(x, y)
@test g1 g2 rtol=1e-5 atol=1e-5
@test g1 g2 rtol=1e-5 atol=1e-5

# tests using hand-calculated values
x = [1. 2. 3.; 2. 1. 1.; 3. 3. 2.]
Expand All @@ -49,13 +36,13 @@ end

g = [-0.317671 -0.427729 0.665241; 0.244728 -0.0196172 -0.829811; 0.0729422 0.447346 0.16457]
ghat = gradient(ctc_loss, x, y)[1]
@test g ghat rtol=1e-5 atol=1e-5
@test g ghat rtol=1e-5 atol=1e-5

x = [-3. 12. 8. 15.; 4. 20. -2. 20.; 8. -33. 6. 5.]
y = [1, 2]
@test ctc_loss(x, y) 8.02519869363453

g = [-2.29294774655333e-06 -0.999662657278862 1.75500863563993e-06 0.00669284889063; 0.017985914969696 0.999662657278861 -1.9907078755387e-06 -0.006693150917307; -0.01798362202195 -2.52019580677916e-20 2.35699239251042e-07 3.02026677058789e-07]
ghat = gradient(ctc_loss, x, y)[1]
@test g ghat rtol=1e-5 atol=1e-5
@test g ghat rtol=1e-5 atol=1e-5
end

0 comments on commit 6e5fb17

Please sign in to comment.