Skip to content

Commit

Permalink
Merge pull request #20147 from JuliaLang/tk/rowvector
Browse files Browse the repository at this point in the history
Make rowvector tests check for specific exception type
  • Loading branch information
tkelman authored Jan 21, 2017
2 parents a8734dc + f90f4f5 commit 45bc9e2
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 56 deletions.
49 changes: 33 additions & 16 deletions base/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ end


@inline check_types{T1,T2}(::Type{T1},::AbstractVector{T2}) = check_types(T1, T2)
@pure check_types{T1,T2}(::Type{T1},::Type{T2}) = T1 === transpose_type(T2) ? nothing : error("Element type mismatch. Tried to create a `RowVector{$T1}` from an `AbstractVector{$T2}`")
@pure check_types{T1,T2}(::Type{T1},::Type{T2}) = T1 === transpose_type(T2) ? nothing :
error("Element type mismatch. Tried to create a `RowVector{$T1}` from an `AbstractVector{$T2}`")

# The element type may be transformed as transpose is recursive
@inline transpose_type{T}(::Type{T}) = promote_op(transpose, T)
Expand All @@ -34,12 +35,17 @@ end

# Constructors that take a size and default to Array
@inline (::Type{RowVector{T}}){T}(n::Int) = RowVector{T}(Vector{transpose_type(T)}(n))
@inline (::Type{RowVector{T}}){T}(n1::Int, n2::Int) = n1 == 1 ? RowVector{T}(Vector{transpose_type(T)}(n2)) : error("RowVector expects 1×N size, got ($n1,$n2)")
@inline (::Type{RowVector{T}}){T}(n1::Int, n2::Int) = n1 == 1 ?
RowVector{T}(Vector{transpose_type(T)}(n2)) :
error("RowVector expects 1×N size, got ($n1,$n2)")
@inline (::Type{RowVector{T}}){T}(n::Tuple{Int}) = RowVector{T}(Vector{transpose_type(T)}(n[1]))
@inline (::Type{RowVector{T}}){T}(n::Tuple{Int,Int}) = n[1] == 1 ? RowVector{T}(Vector{transpose_type(T)}(n[2])) : error("RowVector expects 1×N size, got $n")
@inline (::Type{RowVector{T}}){T}(n::Tuple{Int,Int}) = n[1] == 1 ?
RowVector{T}(Vector{transpose_type(T)}(n[2])) :
error("RowVector expects 1×N size, got $n")

# Conversion of underlying storage
convert{T,V<:AbstractVector}(::Type{RowVector{T,V}}, rowvec::RowVector) = RowVector{T,V}(convert(V,rowvec.vec))
convert{T,V<:AbstractVector}(::Type{RowVector{T,V}}, rowvec::RowVector) =
RowVector{T,V}(convert(V,rowvec.vec))

# similar()
@inline similar(rowvec::RowVector) = RowVector(similar(rowvec.vec))
Expand Down Expand Up @@ -115,7 +121,7 @@ end
@propagate_inbounds setindex!(rowvec::RowVector, v, i::CartesianIndex{1}) = setindex!(rowvec, v, i.I[1])

@inline check_tail_indices(i1, i2) = true
@inline check_tail_indices(i1, i2, i3, is...) = i3 == 1 ? check_tail_indices(i1, i2, is...) : false
@inline check_tail_indices(i1, i2, i3, is...) = i3 == 1 ? check_tail_indices(i1, i2, is...) : false

# helper function for below
@inline to_vec(rowvec::RowVector) = transpose(rowvec)
Expand All @@ -126,15 +132,18 @@ end
@inline map(f, rowvecs::RowVector...) = RowVector(map(f, to_vecs(rowvecs...)...))

# broacast (other combinations default to higher-dimensional array)
@inline broadcast(f, rowvecs::Union{Number,RowVector}...) = RowVector(broadcast(f, to_vecs(rowvecs...)...))
@inline broadcast(f, rowvecs::Union{Number,RowVector}...) =
RowVector(broadcast(f, to_vecs(rowvecs...)...))

# Horizontal concatenation #

@inline hcat(X::RowVector...) = transpose(vcat(map(transpose, X)...))
@inline hcat(X::Union{RowVector,Number}...) = transpose(vcat(map(transpose, X)...))

@inline typed_hcat{T}(::Type{T}, X::RowVector...) = transpose(typed_vcat(T, map(transpose, X)...))
@inline typed_hcat{T}(::Type{T}, X::Union{RowVector,Number}...) = transpose(typed_vcat(T, map(transpose, X)...))
@inline typed_hcat{T}(::Type{T}, X::RowVector...) =
transpose(typed_vcat(T, map(transpose, X)...))
@inline typed_hcat{T}(::Type{T}, X::Union{RowVector,Number}...) =
transpose(typed_vcat(T, map(transpose, X)...))

# Multiplication #

Expand All @@ -145,7 +154,8 @@ end
sum(@inbounds(return rowvec[i]*vec[i]) for i = 1:length(vec))
end
@inline *(rowvec::RowVector, mat::AbstractMatrix) = transpose(mat.' * transpose(rowvec))
*(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply a matrix by a vector")) # Should become a deprecation
*(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply a matrix by a vector")) # Should become a deprecation
*(::RowVector, ::RowVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline *(vec::AbstractVector, rowvec::RowVector) = vec .* rowvec
*(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
Expand All @@ -154,28 +164,35 @@ end
# Transposed forms
A_mul_Bt(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline A_mul_Bt(rowvec::RowVector, mat::AbstractMatrix) = transpose(mat * transpose(rowvec))
A_mul_Bt(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply a matrix by a vector"))
A_mul_Bt(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply a matrix by a vector"))
@inline A_mul_Bt(rowvec1::RowVector, rowvec2::RowVector) = rowvec1*transpose(rowvec2)
A_mul_Bt(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
@inline A_mul_Bt(vec1::AbstractVector, vec2::AbstractVector) = vec1 * transpose(vec2)
@inline A_mul_Bt(mat::AbstractMatrix, rowvec::RowVector) = mat * transpose(rowvec)

@inline At_mul_Bt(rowvec::RowVector, vec::AbstractVector) = transpose(rowvec) * transpose(vec)
At_mul_Bt(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
At_mul_Bt(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply matrix by vector"))
@inline At_mul_Bt(vec::AbstractVector, mat::AbstractMatrix) = transpose(mat * vec)
At_mul_Bt(rowvec1::RowVector, rowvec2::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
@inline At_mul_Bt(vec::AbstractVector, rowvec::RowVector) = transpose(vec)*transpose(rowvec)
At_mul_Bt(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
At_mul_Bt(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch(
"Cannot multiply two transposed vectors"))
@inline At_mul_Bt(mat::AbstractMatrix, rowvec::RowVector) = mat.' * transpose(rowvec)

At_mul_B(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
At_mul_B(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
At_mul_B(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply matrix by vector"))
@inline At_mul_B(vec::AbstractVector, mat::AbstractMatrix) = transpose(At_mul_B(mat,vec))
@inline At_mul_B(rowvec1::RowVector, rowvec2::RowVector) = transpose(rowvec1) * rowvec2
At_mul_B(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline At_mul_B{T<:Real}(vec1::AbstractVector{T}, vec2::AbstractVector{T}) = reduce(+, map(At_mul_B, vec1, vec2)) # Seems to be overloaded...
At_mul_B(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch(
"Cannot multiply two transposed vectors"))
@inline At_mul_B{T<:Real}(vec1::AbstractVector{T}, vec2::AbstractVector{T}) =
reduce(+, map(At_mul_B, vec1, vec2)) # Seems to be overloaded...
@inline At_mul_B(vec1::AbstractVector, vec2::AbstractVector) = transpose(vec1) * vec2
At_mul_B(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
At_mul_B(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch(
"Cannot right-multiply matrix by transposed vector"))

# Conjugated forms
A_mul_Bc(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
Expand Down
80 changes: 40 additions & 40 deletions test/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@test size(RowVector{Int}(1,3)) === (1,3)
@test size(RowVector{Int}((3,))) === (1,3)
@test size(RowVector{Int}((1,3))) === (1,3)
@test_throws Exception RowVector{Float64, Vector{Int}}(v)
@test_throws ErrorException RowVector{Float64, Vector{Int}}(v)

@test (v.')::RowVector == [1 2 3]
@test (v')::RowVector == [1 2 3]
Expand Down Expand Up @@ -59,19 +59,19 @@ end
rv = v.'

@test (rv*d)::RowVector == [2,6,12].'
@test_throws Exception d*rv
@test_throws DimensionMismatch d*rv

@test (d*rv.')::Vector == [2,6,12]

@test_throws Exception rv.'*d
@test_throws DimensionMismatch rv.'*d

@test (d*rv')::Vector == [2,6,12]

@test_throws Exception rv'*d
@test_throws DimensionMismatch rv'*d

@test (rv/d)::RowVector [2/1 3/2 4/3]

@test_throws Exception d \ rv
@test_throws DimensionMismatch d \ rv
end

@testset "Bidiagonal ambiguity methods" begin
Expand All @@ -81,7 +81,7 @@ end

@test (rv/bd)::RowVector [2/1 3/2 4/3]

@test_throws Exception bd \ rv
@test_throws DimensionMismatch bd \ rv
end

@testset "hcat" begin
Expand All @@ -94,7 +94,7 @@ end
v = [2,3,4]
rv = v.'

@test_throws Exception mat \ rv
@test_throws DimensionMismatch mat \ rv
end

@testset "Multiplication" begin
Expand All @@ -104,65 +104,65 @@ end

@test (rv*v) === 14
@test (rv*mat)::RowVector == [1 4 9]
@test_throws Exception [1]*reshape([1],(1,1)) # no longer permitted
@test_throws Exception rv*rv
@test_throws DimensionMismatch [1]*reshape([1],(1,1)) # no longer permitted
@test_throws DimensionMismatch rv*rv
@test (v*rv)::Matrix == [1 2 3; 2 4 6; 3 6 9]
@test_throws Exception v*v # Was previously a missing method error, now an error message
@test_throws Exception mat*rv
@test_throws DimensionMismatch v*v # Was previously a missing method error, now an error message
@test_throws DimensionMismatch mat*rv

@test_throws Exception rv*v.'
@test_throws DimensionMismatch rv*v.'
@test (rv*mat.')::RowVector == [1 4 9]
@test_throws Exception [1]*reshape([1],(1,1)).' # no longer permitted
@test_throws DimensionMismatch [1]*reshape([1],(1,1)).' # no longer permitted
@test rv*rv.' === 14
@test_throws Exception v*rv.'
@test_throws DimensionMismatch v*rv.'
@test (v*v.')::Matrix == [1 2 3; 2 4 6; 3 6 9]
@test (mat*rv.')::Vector == [1,4,9]

@test (rv.'*v.')::Matrix == [1 2 3; 2 4 6; 3 6 9]
@test_throws Exception rv.'*mat.'
@test_throws DimensionMismatch rv.'*mat.'
@test (v.'*mat.')::RowVector == [1 4 9]
@test_throws Exception rv.'*rv.'
@test_throws DimensionMismatch rv.'*rv.'
@test v.'*rv.' === 14
@test_throws Exception v.'*v.'
@test_throws DimensionMismatch v.'*v.'
@test (mat.'*rv.')::Vector == [1,4,9]

@test_throws Exception rv.'*v
@test_throws Exception rv.'*mat
@test_throws DimensionMismatch rv.'*v
@test_throws DimensionMismatch rv.'*mat
@test (v.'*mat)::RowVector == [1 4 9]
@test (rv.'*rv)::Matrix == [1 2 3; 2 4 6; 3 6 9]
@test_throws Exception v.'*rv
@test_throws DimensionMismatch v.'*rv
@test v.'*v === 14
@test_throws Exception mat.'*rv
@test_throws DimensionMismatch mat.'*rv

z = [1+im,2,3]
cz = z'
mat = diagm([1+im,2,3])

@test cz*z === 15 + 0im

@test_throws Exception cz*z'
@test_throws DimensionMismatch cz*z'
@test (cz*mat')::RowVector == [-2im 4 9]
@test_throws Exception [1]*reshape([1],(1,1))' # no longer permitted
@test_throws DimensionMismatch [1]*reshape([1],(1,1))' # no longer permitted
@test cz*cz' === 15 + 0im
@test_throws Exception z*vz'
@test_throws DimensionMismatch z*cz'
@test (z*z')::Matrix == [2 2+2im 3+3im; 2-2im 4 6; 3-3im 6 9]
@test (mat*cz')::Vector == [2im,4,9]

@test (cz'*z')::Matrix == [2 2+2im 3+3im; 2-2im 4 6; 3-3im 6 9]
@test_throws Exception cz'*mat'
@test_throws DimensionMismatch cz'*mat'
@test (z'*mat')::RowVector == [-2im 4 9]
@test_throws Exception cz'*cz'
@test_throws DimensionMismatch cz'*cz'
@test z'*cz' === 15 + 0im
@test_throws Exception z'*z'
@test_throws DimensionMismatch z'*z'
@test (mat'*cz')::Vector == [2,4,9]

@test_throws Exception cz'*z
@test_throws Exception cz'*mat
@test_throws DimensionMismatch cz'*z
@test_throws DimensionMismatch cz'*mat
@test (z'*mat)::RowVector == [2 4 9]
@test (cz'*cz)::Matrix == [2 2+2im 3+3im; 2-2im 4 6; 3-3im 6 9]
@test_throws Exception z'*cz
@test_throws DimensionMismatch z'*cz
@test z'*z === 15 + 0im
@test_throws Exception mat'*cz
@test_throws DimensionMismatch mat'*cz
end

@testset "norm" begin
Expand Down Expand Up @@ -202,7 +202,7 @@ end

@test (rv/mat)::RowVector [2/1 3/2 4/3]

@test_throws Exception mat\rv
@test_throws DimensionMismatch mat\rv
end

@testset "AbstractTriangular ambiguity methods" begin
Expand All @@ -211,31 +211,31 @@ end
rv = v.'

@test (rv*ut)::RowVector == [2 6 12]
@test_throws Exception ut*rv
@test_throws DimensionMismatch ut*rv

@test (rv*ut.')::RowVector == [2 6 12]
@test (ut*rv.')::Vector == [2,6,12]

@test (ut.'*rv.')::Vector == [2,6,12]
@test_throws Exception rv.'*ut.'
@test_throws DimensionMismatch rv.'*ut.'

@test_throws Exception ut.'*rv
@test_throws Exception rv.'*ut
@test_throws DimensionMismatch ut.'*rv
@test_throws DimensionMismatch rv.'*ut

@test (rv*ut')::RowVector == [2 6 12]
@test (ut*rv')::Vector == [2,6,12]

@test_throws Exception rv'*ut'
@test_throws DimensionMismatch rv'*ut'
@test (ut'*rv')::Vector == [2,6,12]

@test_throws Exception ut'*rv
@test_throws Exception rv'*ut
@test_throws DimensionMismatch ut'*rv
@test_throws DimensionMismatch rv'*ut

@test (rv/ut)::RowVector [2/1 3/2 4/3]
@test (rv/ut.')::RowVector [2/1 3/2 4/3]
@test (rv/ut')::RowVector [2/1 3/2 4/3]

@test_throws Exception ut\rv
@test_throws DimensionMismatch ut\rv
end


Expand Down

0 comments on commit 45bc9e2

Please sign in to comment.