From ff117d5c4e079d36794760474a528e70bdfcb89c Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Wed, 20 Jun 2018 08:49:31 +0200 Subject: [PATCH 1/2] improve performance of generic dot products --- stdlib/LinearAlgebra/src/generic.jl | 62 +++++++++-------------------- stdlib/LinearAlgebra/test/matmul.jl | 12 +++--- 2 files changed, 25 insertions(+), 49 deletions(-) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index bed2605cb75a2..847d0b902f3d7 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -621,21 +621,6 @@ opnorm(v::TransposeAbsVec) = norm(v.parent) norm(v::Union{TransposeAbsVec,AdjointAbsVec}, p::Real) = norm(v.parent, p) -function dot(x::AbstractArray, y::AbstractArray) - lx = _length(x) - if lx != _length(y) - throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(_length(y)).")) - end - if lx == 0 - return dot(zero(eltype(x)), zero(eltype(y))) - end - s = zero(dot(first(x), first(y))) - for (Ix, Iy) in zip(eachindex(x), eachindex(y)) - @inbounds s += dot(x[Ix], y[Iy]) - end - s -end - """ dot(x, y) x ⋅ y @@ -663,29 +648,23 @@ julia> dot(x, y) function dot(x, y) # arbitrary iterables ix = iterate(x) iy = iterate(y) - if ix === nothing - if iy !== nothing + if ix == nothing + if iy != nothing throw(DimensionMismatch("x and y are of different lengths!")) end return dot(zero(eltype(x)), zero(eltype(y))) end - if iy === nothing + if iy == nothing throw(DimensionMismatch("x and y are of different lengths!")) end - (vx, xs) = ix - (vy, ys) = iy - s = dot(vx, vy) - while true - ix = iterate(x, xs) - iy = iterate(y, ys) - if (ix == nothing) || (iy == nothing) - break - end - (vx, xs), (vy, ys) = ix, iy - s += dot(vx, vy) + s = dot(ix[1], iy[1]) + ix, iy = iterate(x, ix[2]), iterate(y, iy[2]) + while ix != nothing && iy != nothing + s += dot(ix[1], iy[1]) + ix, iy = iterate(x, ix[2]), iterate(y, iy[2]) end if !(iy == nothing && ix == nothing) - throw(DimensionMismatch("x and y are of different lengths!")) + throw(DimensionMismatch("x and y are of different lengths!")) end return s end @@ -709,24 +688,19 @@ julia> dot([im; im], [1; 1]) 0 - 2im ``` """ -function dot(x::AbstractVector, y::AbstractVector) - if length(LinearIndices(x)) != length(LinearIndices(y)) - throw(DimensionMismatch("dot product arguments have unequal lengths $(length(LinearIndices(x))) and $(length(LinearIndices(y)))")) +function dot(x::AbstractArray, y::AbstractArray) + lx = _length(x) + if lx != _length(y) + throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(_length(y)).")) end - ix = iterate(x) - if ix === nothing - # we only need to check the first vector, since equal lengths have been asserted + if lx == 0 return dot(zero(eltype(x)), zero(eltype(y))) end - iy = iterate(y) - s = dot(ix[1], iy[1]) - ix, iy = iterate(x, ix[2]), iterate(y, iy[2]) - while ix != nothing - s += dot(ix[1], iy[1]) - ix = iterate(x, ix[2]) - iy = iterate(y, iy[2]) + s = zero(dot(first(x), first(y))) + for (Ix, Iy) in zip(eachindex(x), eachindex(y)) + @inbounds s += dot(x[Ix], y[Iy]) end - return s + s end diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index db906303c89ac..7ac8924e63447 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -232,16 +232,18 @@ end @test dot(Z, Z) == convert(elty, 34.0) end -dot_(x,y) = invoke(dot, Tuple{Any,Any}, x,y) +dot1(x,y) = invoke(dot, Tuple{Any,Any}, x,y) +dot2(x,y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x,y) @testset "generic dot" begin AA = [1+2im 3+4im; 5+6im 7+8im] BB = [2+7im 4+1im; 3+8im 6+5im] for A in (copy(AA), view(AA, 1:2, 1:2)), B in (copy(BB), view(BB, 1:2, 1:2)) - @test dot(A,B) == dot(vec(A),vec(B)) == dot_(A,B) == dot(float.(A),float.(B)) - @test dot(Int[], Int[]) == 0 == dot_(Int[], Int[]) + @test dot(A,B) == dot(vec(A),vec(B)) == dot1(A,B) == dot2(A,B) == dot(float.(A),float.(B)) + @test dot(Int[], Int[]) == 0 == dot1(Int[], Int[]) == dot2(Int[], Int[]) @test_throws MethodError dot(Any[], Any[]) - @test_throws MethodError dot_(Any[], Any[]) - for n1 = 0:2, n2 = 0:2, d in (dot, dot_) + @test_throws MethodError dot1(Any[], Any[]) + @test_throws MethodError dot2(Any[], Any[]) + for n1 = 0:2, n2 = 0:2, d in (dot, dot1, dot2) if n1 != n2 @test_throws DimensionMismatch d(1:n1, 1:n2) else From 6c89de2d579df0568ea8ce46a48525366fe872e2 Mon Sep 17 00:00:00 2001 From: Hendrik Ranocha Date: Thu, 21 Jun 2018 09:25:17 +0200 Subject: [PATCH 2/2] update generic dot as suggested by @haampie --- stdlib/LinearAlgebra/src/generic.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 847d0b902f3d7..a66789a17b5fd 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -648,22 +648,27 @@ julia> dot(x, y) function dot(x, y) # arbitrary iterables ix = iterate(x) iy = iterate(y) - if ix == nothing - if iy != nothing + if ix === nothing + if iy !== nothing throw(DimensionMismatch("x and y are of different lengths!")) end return dot(zero(eltype(x)), zero(eltype(y))) end - if iy == nothing + if iy === nothing throw(DimensionMismatch("x and y are of different lengths!")) end - s = dot(ix[1], iy[1]) - ix, iy = iterate(x, ix[2]), iterate(y, iy[2]) - while ix != nothing && iy != nothing - s += dot(ix[1], iy[1]) - ix, iy = iterate(x, ix[2]), iterate(y, iy[2]) + (vx, xs) = ix + (vy, ys) = iy + s = dot(vx, vy) + while true + ix = iterate(x, xs) + iy = iterate(y, ys) + ix === nothing && break + iy === nothing && break + (vx, xs), (vy, ys) = ix, iy + s += dot(vx, vy) end - if !(iy == nothing && ix == nothing) + if !(iy === nothing && ix === nothing) throw(DimensionMismatch("x and y are of different lengths!")) end return s