Skip to content

Commit

Permalink
Improve performance of generic dot products (#27678)
Browse files Browse the repository at this point in the history
* improve performance of generic dot products

* update generic dot as suggested by @haampie
  • Loading branch information
ranocha authored and andreasnoack committed Jul 11, 2018
1 parent 3791357 commit f3ad067
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 39 deletions.
47 changes: 13 additions & 34 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,21 +621,6 @@ opnorm(v::TransposeAbsVec) = norm(v.parent)

norm(v::Union{TransposeAbsVec,AdjointAbsVec}, p::Real) = norm(v.parent, p)

function dot(x::AbstractArray, y::AbstractArray)
lx = _length(x)
if lx != _length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(_length(y))."))
end
if lx == 0
return dot(zero(eltype(x)), zero(eltype(y)))
end
s = zero(dot(first(x), first(y)))
for (Ix, Iy) in zip(eachindex(x), eachindex(y))
@inbounds s += dot(x[Ix], y[Iy])
end
s
end

"""
dot(x, y)
x ⋅ y
Expand Down Expand Up @@ -678,14 +663,13 @@ function dot(x, y) # arbitrary iterables
while true
ix = iterate(x, xs)
iy = iterate(y, ys)
if (ix == nothing) || (iy == nothing)
break
end
ix === nothing && break
iy === nothing && break
(vx, xs), (vy, ys) = ix, iy
s += dot(vx, vy)
end
if !(iy == nothing && ix == nothing)
throw(DimensionMismatch("x and y are of different lengths!"))
if !(iy === nothing && ix === nothing)
throw(DimensionMismatch("x and y are of different lengths!"))
end
return s
end
Expand All @@ -709,24 +693,19 @@ julia> dot([im; im], [1; 1])
0 - 2im
```
"""
function dot(x::AbstractVector, y::AbstractVector)
if length(LinearIndices(x)) != length(LinearIndices(y))
throw(DimensionMismatch("dot product arguments have unequal lengths $(length(LinearIndices(x))) and $(length(LinearIndices(y)))"))
function dot(x::AbstractArray, y::AbstractArray)
lx = _length(x)
if lx != _length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(_length(y))."))
end
ix = iterate(x)
if ix === nothing
# we only need to check the first vector, since equal lengths have been asserted
if lx == 0
return dot(zero(eltype(x)), zero(eltype(y)))
end
iy = iterate(y)
s = dot(ix[1], iy[1])
ix, iy = iterate(x, ix[2]), iterate(y, iy[2])
while ix != nothing
s += dot(ix[1], iy[1])
ix = iterate(x, ix[2])
iy = iterate(y, iy[2])
s = zero(dot(first(x), first(y)))
for (Ix, Iy) in zip(eachindex(x), eachindex(y))
@inbounds s += dot(x[Ix], y[Iy])
end
return s
s
end


Expand Down
12 changes: 7 additions & 5 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,18 @@ end
@test dot(Z, Z) == convert(elty, 34.0)
end

dot_(x,y) = invoke(dot, Tuple{Any,Any}, x,y)
dot1(x,y) = invoke(dot, Tuple{Any,Any}, x,y)
dot2(x,y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x,y)
@testset "generic dot" begin
AA = [1+2im 3+4im; 5+6im 7+8im]
BB = [2+7im 4+1im; 3+8im 6+5im]
for A in (copy(AA), view(AA, 1:2, 1:2)), B in (copy(BB), view(BB, 1:2, 1:2))
@test dot(A,B) == dot(vec(A),vec(B)) == dot_(A,B) == dot(float.(A),float.(B))
@test dot(Int[], Int[]) == 0 == dot_(Int[], Int[])
@test dot(A,B) == dot(vec(A),vec(B)) == dot1(A,B) == dot2(A,B) == dot(float.(A),float.(B))
@test dot(Int[], Int[]) == 0 == dot1(Int[], Int[]) == dot2(Int[], Int[])
@test_throws MethodError dot(Any[], Any[])
@test_throws MethodError dot_(Any[], Any[])
for n1 = 0:2, n2 = 0:2, d in (dot, dot_)
@test_throws MethodError dot1(Any[], Any[])
@test_throws MethodError dot2(Any[], Any[])
for n1 = 0:2, n2 = 0:2, d in (dot, dot1, dot2)
if n1 != n2
@test_throws DimensionMismatch d(1:n1, 1:n2)
else
Expand Down

0 comments on commit f3ad067

Please sign in to comment.