diff --git a/base/blas.jl b/base/blas.jl index 9cf881e4bd05d..b02178c17e129 100644 --- a/base/blas.jl +++ b/base/blas.jl @@ -166,18 +166,16 @@ function axpy!{T,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range return axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry)) end - -# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC) -# * .. Scalar Arguments .. -# REAL ALPHA,BETA -# INTEGER K,LDA,LDC,N -# CHARACTER TRANS,UPLO -# * .. -# * .. Array Arguments .. -# REAL A(LDA,*),C(LDC,*) for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32), (:zsyrk_,:Complex128), (:csyrk_,:Complex64)) @eval begin + # SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC) + # * .. Scalar Arguments .. + # REAL ALPHA,BETA + # INTEGER K,LDA,LDC,N + # CHARACTER TRANS,UPLO + # * .. Array Arguments .. + # REAL A(LDA,*),C(LDC,*) function syrk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}, beta::($elty), C::StridedMatrix{$elty}) m, n = size(C) @@ -193,14 +191,9 @@ for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32), end function syrk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}) n = size(A, trans == 'N' ? 1 : 2) - k = size(A, trans == 'N' ? 
2 : 1) - C = Array($elty, (n, n)) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), - &uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2)) - C + syrk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n, n))) end + syrk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = syrk(uplo, trans, one($elty), A) end end @@ -214,7 +207,7 @@ end # COMPLEX A(LDA,*),C(LDC,*) for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64)) @eval begin - function herk!(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty}, + function herk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}, beta::($elty), C::StridedMatrix{$elty}) m, n = size(C) if m != n error("syrk!: matrix C must be square") end @@ -227,16 +220,11 @@ for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64)) &uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &beta, C, &stride(C,2)) C end - function herk(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty}) + function herk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty}) n = size(A, trans == 'N' ? 1 : 2) - k = size(A, trans == 'N' ? 
2 : 1) - C = Array($elty, (n, n)) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), - &uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2)) - C + herk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n,n))) end + herk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = herk(uplo, trans, one($elty), A) end end @@ -266,16 +254,12 @@ for (fname, elty) in ((:dgbmv_,:Float64), (:sgbmv_,:Float32), function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty}) n = stride(A,2) - y = Array($elty, n) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), - &trans, &m, &n, &kl, &ku, &alpha, A, &stride(A,2), - x, &stride(x,1), &0., y, &1) - y + gbmv!(trans, m, kl, ku, alpha, A, x, zero($elty), Array($elty, n)) + end + function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer, + A::StridedMatrix{$elty}, x::StridedVector{$elty}) + gbmv(trans, m, kl, ku, one($elty), A, x) end - end end @@ -303,34 +287,42 @@ for (fname, elty) in ((:dsbmv_,:Float64), (:ssbmv_,:Float32), function sbmv(uplo::BlasChar, k::Integer, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty}) n = size(A,2) - y = Array($elty, n) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), - &uplo, &size(A,2), &k, &alpha, A, &stride(A,2), x, &stride(x,1), &0., y, &1) - y + sbmv!(uplo, k, alpha, A, x, zero($elty), Array($elty, n)) + end + function sbmv(uplo::BlasChar, k::Integer, A::StridedMatrix{$elty}, x::StridedVector{$elty}) + sbmv(uplo, k, one($elty), A, x) end end end -# (GE) 
general matrix-matrix multiplication -# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) -# * .. Scalar Arguments .. -# DOUBLE PRECISION ALPHA,BETA -# INTEGER K,LDA,LDB,LDC,M,N -# CHARACTER TRANSA,TRANSB -# * .. Array Arguments .. -# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*) -for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32), - (:zgemm_,:Complex128), (:cgemm_,:Complex64)) +# (GE) general matrix-matrix and matrix-vector multiplication +for (gemm, gemv, elty) in + ((:dgemm_,:dgemv_,:Float64), + (:sgemm_,:sgemv_,:Float32), + (:zgemm_,:zgemv_,:Complex128), + (:cgemm_,:cgemv_,:Complex64)) @eval begin - function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, - B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty}) + # SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) + # * .. Scalar Arguments .. + # DOUBLE PRECISION ALPHA,BETA + # INTEGER K,LDA,LDB,LDC,M,N + # CHARACTER TRANSA,TRANSB + # * .. Array Arguments .. + # DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*) + function gemm!(transA::BlasChar, transB::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + B::StridedMatrix{$elty}, + beta::($elty), C::StridedMatrix{$elty}) +# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1) +# error("gemm!: BLAS module requires contiguous matrix columns") +# end # should this be checked on every call? m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) n = size(B, transB == 'N' ? 
2 : 1) - if m != size(C,1) || n != size(C,2) error("gemm!: mismatched dimensions") end - ccall(($(string(fname)),libblas), Void, + if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2) + error("gemm!: mismatched dimensions") + end + ccall(($(string(gemm)),libblas), Void, (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), @@ -338,92 +330,66 @@ for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32), B, &stride(B,2), &beta, C, &stride(C,2)) C end - function gemm(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty}) - m = size(A, transA == 'N' ? 1 : 2) - k = size(A, transA == 'N' ? 2 : 1) - if k != size(B, transB == 'N' ? 1 : 2) error("gemm!: mismatched dimensions") end - n = size(B, transB == 'N' ? 2 : 1) - C = Array($elty, (m, n)) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), - &transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2), - B, &stride(B,2), &0., C, &stride(C,2)) - C + function gemm(transA::BlasChar, transB::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + B::StridedMatrix{$elty}) + gemm!(transA, transB, alpha, A, B, zero($elty), + Array($elty, (size(A, transA == 'N' ? 1 : 2), + size(B, transB == 'N' ? 2 : 1)))) end - end -end - -#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) -#* .. Scalar Arguments .. -# DOUBLE PRECISION ALPHA,BETA -# INTEGER INCX,INCY,LDA,M,N -# CHARACTER TRANS -#* .. Array Arguments .. 
-# DOUBLE PRECISION A(LDA,*),X(*),Y(*) - -for (fname, elty) in ((:dgemv_,:Float64), (:sgemv_,:Float32), - (:zgemv_,:Complex128), (:cgemv_,:Complex64)) - @eval begin - function gemv!(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, - X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty}) - ccall(($(string(fname)),libblas), Void, - (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), + function gemm(transA::BlasChar, transB::BlasChar, + A::StridedMatrix{$elty}, B::StridedMatrix{$elty}) + gemm(transA, transB, one($elty), A, B) + end + #SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) + #* .. Scalar Arguments .. + # DOUBLE PRECISION ALPHA,BETA + # INTEGER INCX,INCY,LDA,M,N + # CHARACTER TRANS + #* .. Array Arguments .. + # DOUBLE PRECISION A(LDA,*),X(*),Y(*) + function gemv!(trans::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + X::StridedVector{$elty}, + beta::($elty), Y::StridedVector{$elty}) + ccall(($(string(gemv)),libblas), Void, + (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), &trans, &size(A,1), &size(A,2), &alpha, A, &stride(A,2), X, &stride(X,1), &beta, Y, &stride(Y,1)) Y end - function gemv(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty}) - Y = Array($elty, size(A,1)) - gemv!(trans, alpha, A, X, zero($elty), Y) - Y + function gemv(trans::BlasChar, + alpha::($elty), A::StridedMatrix{$elty}, + X::StridedVector{$elty}) + gemv!(trans, alpha, A, X, zero($elty), + Array($elty, size(A, (trans == 'N' ? 1 : 2)))) + end + function gemv(trans::BlasChar, A::StridedMatrix{$elty}, X::StridedVector{$elty}) + gemv!(trans, one($elty), A, X, zero($elty), + Array($elty, size(A, (trans == 'N' ? 
1 : 2)))) end end end # (SY) symmetric matrix-matrix and matrix-vector multiplication - -# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC) -# .. Scalar Arguments .. -# DOUBLE PRECISION ALPHA,BETA -# INTEGER LDA,LDB,LDC,M,N -# CHARACTER SIDE,UPLO -# .. Array Arguments .. -# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*) - -# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) -# .. Scalar Arguments .. -# DOUBLE PRECISION ALPHA,BETA -# INTEGER INCX,INCY,LDA,N -# CHARACTER UPLO -# .. Array Arguments .. -# DOUBLE PRECISION A(LDA,*),X(*),Y(*) - -for (vfname, mfname, elty) in - ((:dsymv_,:dsymm_,:Float64), - (:ssymv_,:ssymm_,:Float32), - (:zsymv_,:zsymm_,:Complex128), - (:csymv_,:csymm_,:Complex64)) +for (mfname, vfname, elty) in + ((:dsymm_,:dsymv_,:Float64), + (:ssymm_,:ssymv_,:Float32), + (:zsymm_,:zsymv_,:Complex128), + (:csymm_,:csymv_,:Complex64)) @eval begin - function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty}, - beta::($elty), Y::StridedVector{$elty}) - m, n = size(A) - if m != n error("symm!: matrix A is $m by $n but must be square") end - if m != length(X) || m != length(Y) error("symm!: dimension mismatch") end - ccall(($(string(vfname)),libblas), Void, - (Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, - Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), - &uplo, &n, &alpha, A, &stride(A,2), X, &stride(X,1), &beta, Y, &stride(Y,1)) - Y - end - function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty}) - symv!(uplo, alpha, A, X, zero($elty), similar(X)) - end - function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty}, - beta::($elty), C::StridedMatrix{$elty}) - side = uppercase(convert(Char, side)) + # SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC) + # .. Scalar Arguments .. + # DOUBLE PRECISION ALPHA,BETA + # INTEGER LDA,LDB,LDC,M,N + # CHARACTER SIDE,UPLO + # .. 
Array Arguments .. + # DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*) + function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, + B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty}) m, n = size(C) k, j = size(A) if k != j error("symm!: matrix A is $k by $j but must be square") end @@ -435,9 +401,37 @@ for (vfname, mfname, elty) in &beta, C, &stride(C,2)) C end - function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty}) + function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, + B::StridedMatrix{$elty}) symm!(side, uplo, alpha, A, B, zero($elty), similar(B)) end + function symm(side::BlasChar, uplo::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty}) + symm(side, uplo, one($elty), A, B) + end + # SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY) + # .. Scalar Arguments .. + # DOUBLE PRECISION ALPHA,BETA + # INTEGER INCX,INCY,LDA,N + # CHARACTER UPLO + # .. Array Arguments .. 
+ # DOUBLE PRECISION A(LDA,*),X(*),Y(*) + function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty}, + beta::($elty), y::StridedVector{$elty}) + m, n = size(A) + if m != n error("symv!: matrix A is $m by $n but must be square") end + if m != length(x) || m != length(y) error("symv!: dimension mismatch") end + ccall(($(string(vfname)),libblas), Void, + (Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, + Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}), + &uplo, &n, &alpha, A, &stride(A,2), x, &stride(x,1), &beta, y, &stride(y,1)) + y + end + function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty}) + symv!(uplo, alpha, A, x, zero($elty), similar(x)) + end + function symv(uplo::BlasChar, A::StridedMatrix{$elty}, x::StridedVector{$elty}) + symv(uplo, one($elty), A, x) + end end end diff --git a/test/Makefile b/test/Makefile index 97a562b30bbbc..073590a380af8 100644 --- a/test/Makefile +++ b/test/Makefile @@ -5,7 +5,7 @@ default all extra unicode gzip:: TESTS = default all extra \ core numbers strings unicode corelib hashing remote \ -arrayops linalg fft dct sparse bitarray suitesparse arpack \ +arrayops blas linalg fft dct sparse bitarray suitesparse arpack \ random math functional bigint bigfloat sorting \ statistics poly file Rmath remote zlib image \ iostring gzip integers spawn ccall parallel diff --git a/test/blas.jl b/test/blas.jl new file mode 100644 index 0000000000000..7c58a753ca312 --- /dev/null +++ b/test/blas.jl @@ -0,0 +1,78 @@ +## BLAS tests - testing the interface code to BLAS routines +for elty in (Float32, Float64, Complex64, Complex128) + + o4 = ones(elty, 4) + z4 = zeros(elty, 4) + + I4 = eye(elty, 4) + L4 = tril(ones(elty, (4,4))) + U4 = triu(ones(elty, (4,4))) + Z4 = zeros(elty, (4,4)) + + elm1 = convert(elty, -1) + el2 = convert(elty, 2) + v14 = convert(Vector{elty}, [1:4]) + v41 = convert(Vector{elty}, [4:-1:1]) + # gemv + @assert 
all(BLAS.gemv('N', I4, o4) .== o4) + @assert all(BLAS.gemv('T', I4, o4) .== o4) + @assert all(BLAS.gemv('N', el2, I4, o4) .== el2 * o4) + @assert all(BLAS.gemv('T', el2, I4, o4) .== el2 * o4) + o4cp = copy(o4) + @assert all(BLAS.gemv!('N', one(elty), I4, o4, elm1, o4cp) .== z4) + @assert all(o4cp .== z4) + o4cp[:] = o4 + @assert all(BLAS.gemv!('T', one(elty), I4, o4, elm1, o4cp) .== z4) + @assert all(o4cp .== z4) + @assert all(BLAS.gemv('N', U4, o4) .== v41) + @assert all(BLAS.gemv('T', U4, o4) .== v14) + # gemm + @assert all(BLAS.gemm('N', 'N', I4, I4) .== I4) + @assert all(BLAS.gemm('N', 'T', I4, I4) .== I4) + @assert all(BLAS.gemm('T', 'N', I4, I4) .== I4) + @assert all(BLAS.gemm('T', 'T', I4, I4) .== I4) + @assert all(BLAS.gemm('N', 'N', el2, I4, I4) .== el2 * I4) + @assert all(BLAS.gemm('N', 'T', el2, I4, I4) .== el2 * I4) + @assert all(BLAS.gemm('T', 'N', el2, I4, I4) .== el2 * I4) + @assert all(BLAS.gemm('T', 'T', el2, I4, I4) .== el2 * I4) + I4cp = copy(I4) + @assert all(BLAS.gemm!('N', 'N', one(elty), I4, I4, elm1, I4cp) .== Z4) + @assert all(I4cp .== Z4) + I4cp[:] = I4 + @assert all(BLAS.gemm!('N', 'T', one(elty), I4, I4, elm1, I4cp) .== Z4) + @assert all(I4cp .== Z4) + I4cp[:] = I4 + @assert all(BLAS.gemm!('T', 'N', one(elty), I4, I4, elm1, I4cp) .== Z4) + @assert all(I4cp .== Z4) + I4cp[:] = I4 + @assert all(BLAS.gemm!('T', 'T', one(elty), I4, I4, elm1, I4cp) .== Z4) + @assert all(I4cp .== Z4) + @assert all(BLAS.gemm('N', 'N', I4, U4) .== U4) + @assert all(BLAS.gemm('N', 'T', I4, U4) .== L4) + # gemm compared to (sy)(he)rk + if iscomplex(elm1) + @assert all(triu(BLAS.herk('U', 'N', U4)) .== triu(BLAS.gemm('N', 'T', U4, U4))) + @assert all(tril(BLAS.herk('L', 'N', U4)) .== tril(BLAS.gemm('N', 'T', U4, U4))) + @assert all(triu(BLAS.herk('U', 'N', L4)) .== triu(BLAS.gemm('N', 'T', L4, L4))) + @assert all(tril(BLAS.herk('L', 'N', L4)) .== tril(BLAS.gemm('N', 'T', L4, L4))) + @assert all(triu(BLAS.herk('U', 'C', U4)) .== triu(BLAS.gemm('T', 'N', U4, U4))) + 
@assert all(tril(BLAS.herk('L', 'C', U4)) .== tril(BLAS.gemm('T', 'N', U4, U4))) + @assert all(triu(BLAS.herk('U', 'C', L4)) .== triu(BLAS.gemm('T', 'N', L4, L4))) + @assert all(tril(BLAS.herk('L', 'C', L4)) .== tril(BLAS.gemm('T', 'N', L4, L4))) + ans = similar(L4) + @assert all(tril(BLAS.herk('L','C', L4)) .== tril(BLAS.herk!('L', 'C', one(elty), L4, zero(elty), ans))) + @assert all(symmetrize!(ans, 'L') .== BLAS.gemm('T', 'N', L4, L4)) + else + @assert all(triu(BLAS.syrk('U', 'N', U4)) .== triu(BLAS.gemm('N', 'T', U4, U4))) + @assert all(tril(BLAS.syrk('L', 'N', U4)) .== tril(BLAS.gemm('N', 'T', U4, U4))) + @assert all(triu(BLAS.syrk('U', 'N', L4)) .== triu(BLAS.gemm('N', 'T', L4, L4))) + @assert all(tril(BLAS.syrk('L', 'N', L4)) .== tril(BLAS.gemm('N', 'T', L4, L4))) + @assert all(triu(BLAS.syrk('U', 'T', U4)) .== triu(BLAS.gemm('T', 'N', U4, U4))) + @assert all(tril(BLAS.syrk('L', 'T', U4)) .== tril(BLAS.gemm('T', 'N', U4, U4))) + @assert all(triu(BLAS.syrk('U', 'T', L4)) .== triu(BLAS.gemm('T', 'N', L4, L4))) + @assert all(tril(BLAS.syrk('L', 'T', L4)) .== tril(BLAS.gemm('T', 'N', L4, L4))) + ans = similar(L4) + @assert all(tril(BLAS.syrk('L','T', L4)) .== tril(BLAS.syrk!('L', 'T', one(elty), L4, zero(elty), ans))) + @assert all(symmetrize!(ans, 'L') .== BLAS.gemm('T', 'N', L4, L4)) + end +end diff --git a/test/default.jl b/test/default.jl index 06eb626c75ace..3843bffbc8061 100644 --- a/test/default.jl +++ b/test/default.jl @@ -11,6 +11,7 @@ runtests("iostring") # array/matrix tests runtests("arrayops") runtests("linalg") +runtests("blas") runtests("fft") runtests("dct") runtests("sparse") diff --git a/test/linalg.jl b/test/linalg.jl index 07f108126e8ae..8c468ecb53170 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -4,35 +4,35 @@ a = rand(n,n) b = rand(n) for elty in (Float32, Float64, Complex64, Complex128) a = convert(Matrix{elty}, a) - asym = a' + a # symmetric indefinite - apd = a'*a # symmetric positive-definite + asym = a' + a # symmetric indefinite 
+ apd = a'*a # symmetric positive-definite b = convert(Vector{elty}, b) - capd = chold(apd) # upper Cholesky factor + capd = chold(apd) # upper Cholesky factor r = factors(capd) @assert_approx_eq r'*r apd @assert_approx_eq b apd * (capd\b) @assert_approx_eq apd * inv(capd) eye(elty, n) - @assert_approx_eq a*(capd\(a'*b)) b # least squares soln for square a + @assert_approx_eq a*(capd\(a'*b)) b # least squares soln for square a @assert_approx_eq det(capd) det(apd) - l = factors(chold(apd, 'L')) # lower Cholesky factor + l = factors(chold(apd, 'L')) # lower Cholesky factor @assert_approx_eq l*l' apd - cpapd = cholpd(apd) # pivoted Choleksy decomposition + cpapd = cholpd(apd) # pivoted Cholesky decomposition @test rank(cpapd) == n - @test all(diff(diag(real(cpapd.LR))).<=0.) # diagonal show be non-increasing + @test all(diff(diag(real(cpapd.LR))).<=0.) # diagonal should be non-increasing @assert_approx_eq b apd * (cpapd\b) @assert_approx_eq apd * inv(cpapd) eye(elty, n) - bc1 = BunchKaufman(asym) # Bunch-Kaufman factor of indefinite matrix + bc1 = BunchKaufman(asym) # Bunch-Kaufman factor of indefinite matrix @assert_approx_eq inv(bc1) * asym eye(elty, n) @assert_approx_eq asym * (bc1\b) b - bc2 = BunchKaufman(apd) # Bunch-Kaufman factors of a pos-def matrix + bc2 = BunchKaufman(apd) # Bunch-Kaufman factors of a pos-def matrix @assert_approx_eq inv(bc2) * apd eye(elty, n) @assert_approx_eq apd * (bc2\b) b - lua = lud(a) # LU decomposition + lua = lud(a) # LU decomposition l,u,p = lu(a) L,U,P = factors(lua) @test l == L && u == U && p == P @@ -41,7 +41,7 @@ for elty in (Float32, Float64, Complex64, Complex128) @assert_approx_eq a * inv(lua) eye(elty, n) @assert_approx_eq a*(lua\b) b - qra = qrd(a) # QR decomposition + qra = qrd(a) # QR decomposition q,r = factors(qra) @assert_approx_eq q'*q eye(elty, n) @assert_approx_eq q*q' eye(elty, n) @@ -52,7 +52,7 @@ for elty in (Float32, Float64, Complex64, Complex128) @assert_approx_eq qra'*b Q'*b @assert_approx_eq a*(qra\b) b 
- qrpa = qrpd(a) # pivoted QR decomposition + qrpa = qrpd(a) # pivoted QR decomposition q,r,p = factors(qrpa) @assert_approx_eq q'*q eye(elty, n) @assert_approx_eq q*q' eye(elty, n) @@ -62,20 +62,20 @@ for elty in (Float32, Float64, Complex64, Complex128) @assert_approx_eq q*r[:,invperm(p)] a @assert_approx_eq a*(qrpa\b) b - d,v = eig(asym) # symmetric eigen-decomposition + d,v = eig(asym) # symmetric eigen-decomposition @assert_approx_eq asym*v[:,1] d[1]*v[:,1] @assert_approx_eq v*diagmm(d,v') asym - d,v = eig(a) # non-symmetric eigen decomposition + d,v = eig(a) # non-symmetric eigen decomposition for i in 1:size(a,2) @assert_approx_eq a*v[:,i] d[i]*v[:,i] end - u, q, v = schur(a) # Schur + u, q, v = schur(a) # Schur @assert_approx_eq q*u*q' a @assert_approx_eq sort(real(v)) sort(real(d)) @assert_approx_eq sort(imag(v)) sort(imag(d)) @test istriu(u) || isreal(a) - u,s,vt = svdt(a) # singular value decomposition + u,s,vt = svdt(a) # singular value decomposition @assert_approx_eq u*diagmm(s,vt) a gsvd = factors(svd(a,a[1:5,:])) # Generalized svd @@ -91,63 +91,60 @@ for elty in (Float32, Float64, Complex64, Complex128) x = tril(a)\b @assert_approx_eq tril(a)*x b - # Test null + # Test null a15null = null(a[:,1:5]') @assert_approx_eq_eps norm(a[:,1:5]'a15null) zero(elty) n*eps(one(elty)) @assert_approx_eq_eps norm(a15null'a[:,1:5]) zero(elty) n*eps(one(elty)) @test size(null(b), 2) == 0 - # Test pinv + # Test pinv pinva15 = pinv(a[:,1:5]) @assert_approx_eq a[:,1:5]*pinva15*a[:,1:5] a[:,1:5] @assert_approx_eq pinva15*a[:,1:5]*pinva15 pinva15 - # Complex complex rhs real lhs + # Complex vector rhs x = a\complex(b) @assert_approx_eq a*x complex(b) - # Test cond + # Test cond @assert_approx_eq_eps cond(a, 1) 4.837320054554436e+02 0.01 @assert_approx_eq_eps cond(a, 2) 1.960057871514615e+02 0.01 @assert_approx_eq_eps cond(a, Inf) 3.757017682707787e+02 0.01 @assert_approx_eq_eps cond(a[:,1:5]) 10.233059337453463 0.01 - # Matrix square root + # Matrix square root asq = 
sqrtm(a) @assert_approx_eq asq*asq a end + +## Least squares solutions a = [ones(20) 1:20 1:20] b = reshape(eye(8, 5), 20, 2) for elty in (Float32, Float64, Complex64, Complex128) a = convert(Matrix{elty}, a) b = convert(Matrix{elty}, b) - # Matrix and vector multiplication - @assert_approx_eq b'b convert(Matrix{elty}, [3 0; 0 2]) - @assert_approx_eq b'b[:,1] convert(Vector{elty}, [3, 0]) - @assert_approx_eq dot(b[:,1], b[:,1]) convert(elty, 3.0) - - # Least squares - x = a[:,1:2]\b[:,1] # Vector rhs + x = a[:,1:2]\b[:,1] # Vector rhs @assert_approx_eq ((a[:,1:2]*x-b[:,1])'*(a[:,1:2]*x-b[:,1]))[1] convert(elty, 2.546616541353384) - x = a[:,1:2]\b # Matrix rhs + x = a[:,1:2]\b # Matrix rhs @assert_approx_eq det((a[:,1:2]*x-b)'*(a[:,1:2]*x-b)) convert(elty, 4.437969924812031) - x = a\b # Rank deficient + x = a\b # Rank deficient @assert_approx_eq det((a*x-b)'*(a*x-b)) convert(elty, 4.437969924812031) - x = convert(Matrix{elty}, [1 0 0; 0 1 -1]) \ convert(Vector{elty}, [1,1]) # Underdetermined minimum norm + # Underdetermined minimum norm + x = convert(Matrix{elty}, [1 0 0; 0 1 -1]) \ convert(Vector{elty}, [1,1]) @assert_approx_eq x convert(Vector{elty}, [1, 0.5, -0.5]) - # symmetric, positive definite + # symmetric, positive definite @assert_approx_eq inv(convert(Matrix{elty}, [6. 2; 2 1])) convert(Matrix{elty}, [0.5 -1; -1 3]) - # symmetric, negative definite + # symmetric, negative definite @assert_approx_eq inv(convert(Matrix{elty}, [1. 2; 2 1])) convert(Matrix{elty}, [-1. 
2; 2 -1]/3) end ## Test Julia fallbacks to BLAS routines -# matrices with zero dimensions + # matrices with zero dimensions @test ones(0,5)*ones(5,3) == zeros(0,3) @test ones(3,5)*ones(5,0) == zeros(3,0) @test ones(3,0)*ones(0,4) == zeros(3,4) @@ -155,7 +152,7 @@ end @test ones(0,0)*ones(0,4) == zeros(0,4) @test ones(3,0)*ones(0,0) == zeros(3,0) @test ones(0,0)*ones(0,0) == zeros(0,0) -# 2x2 + # 2x2 A = [1 2; 3 4] B = [5 6; 7 8] @test A*B == [19 22; 43 50] @@ -168,7 +165,7 @@ Bi = B+(2.5*im).*A[[2,1],[2,1]] @test Ac_mul_B(Ai, Bi) == [68.5-12im 57.5-28im; 88-3im 76.5-25im] @test A_mul_Bc(Ai, Bi) == [64.5+5.5im 43+31.5im; 104-18.5im 80.5+31.5im] @test Ac_mul_Bc(Ai, Bi) == [-28.25-66im 9.75-58im; -26-89im 21-73im] -# 3x3 + # 3x3 A = [1 2 3; 4 5 6; 7 8 9]-5 B = [1 0 5; 6 -10 3; 2 -4 -1] @test A*B == [-26 38 -27; 1 -4 -6; 28 -46 15] @@ -181,7 +178,7 @@ Bi = B+(2.5*im).*A[[2,1,3],[2,3,1]] @test Ac_mul_B(Ai, Bi) == [-21+2im -1.75+49im -51.25+19.5im; 25.5+56.5im -7-35.5im 22+35.5im; -3+12im -32.25+43im -34.75-2.5im] @test A_mul_Bc(Ai, Bi) == [-20.25+15.5im -28.75-54.5im 22.25+68.5im; -12.25+13im -15.5+75im -23+27im; 18.25+im 1.5+94.5im -27-54.5im] @test Ac_mul_Bc(Ai, Bi) == [1+2im 20.75+9im -44.75+42im; 19.5+17.5im -54-36.5im 51-14.5im; 13+7.5im 11.25+31.5im -43.25-14.5im] -# Generic integer matrix multiplication + # Generic integer matrix multiplication A = [1 2 3; 4 5 6] - 3 B = [2 -2; 3 -5; -4 7] @test A*B == [-7 9; -4 9] @@ -193,13 +190,13 @@ A = rand(1:20, 5, 5) - 10 B = rand(1:20, 5, 5) - 10 @test At_mul_B(A, B) == A'*B @test A_mul_Bt(A, B) == A*B' -# Preallocated + # Preallocated C = Array(Int, size(A, 1), size(B, 2)) @test A_mul_B(C, A, B) == A*B @test At_mul_B(C, A, B) == A'*B @test A_mul_Bt(C, A, B) == A*B' @test At_mul_Bt(C, A, B) == A'*B' -# matrix algebra with subarrays of floats (stride != 1) + # matrix algebra with subarrays of floats (stride != 1) A = reshape(float64(1:20),5,4) Aref = A[1:2:end,1:2:end] Asub = sub(A, 1:2:5, 1:2:4) @@ -212,7 +209,7 @@ Aref = 
Ai[1:2:end,1:2:end] Asub = sub(Ai, 1:2:5, 1:2:4) @test Ac_mul_B(Asub, Asub) == Ac_mul_B(Aref, Aref) @test A_mul_Bc(Asub, Asub) == A_mul_Bc(Aref, Aref) -# syrk & herk + # syrk & herk A = reshape(1:1503, 501, 3)-750.0 res = float64([135228751 9979252 -115270247; 9979252 10481254 10983256; -115270247 10983256 137236759]) @test At_mul_B(A, A) == res @@ -227,7 +224,7 @@ Asub = sub(Ai, 1:2:2*cutoff, 1:3) Aref = Ai[1:2:2*cutoff, 1:3] @test Ac_mul_B(Asub, Asub) == Ac_mul_B(Aref, Aref) -# Matrix exponential and Hessenberg + # Matrix exponential for elty in (Float32, Float64, Complex64, Complex128) A1 = convert(Matrix{elty}, [4 2 0; 1 4 1; 1 1 4]) eA1 = convert(Matrix{elty}, [147.866622446369 127.781085523181 127.781085523182; @@ -251,27 +248,27 @@ for elty in (Float32, Float64, Complex64, Complex128) 0.135335281175235 0.406005843524598 0.541341126763207]') @assert_approx_eq expm(A3) eA3 - # Hessenberg + # Hessenberg @assert_approx_eq hess(A1) convert(Matrix{elty}, [4.000000000000000 -1.414213562373094 -1.414213562373095 -1.414213562373095 4.999999999999996 -0.000000000000000 0 -0.000000000000002 3.000000000000000]) end -# matmul for types w/o sizeof (issue #1282) + # matmul for types w/o sizeof (issue #1282) A = Array(ComplexPair{Int},10,10) A[:] = complex(1,1) A2 = A^2 @test A2[1,1] == 20im -# basic tridiagonal operations + # basic tridiagonal operations n = 5 d = 1 + rand(n) dl = -rand(n-1) du = -rand(n-1) v = randn(n) B = randn(n,2) -# Woodbury + # Woodbury U = randn(n,2) V = randn(2,n) C = randn(2,2) @@ -289,8 +286,7 @@ for elty in (Float32, Float64, Complex64, Complex128) F[i+1,i] = dl[i] end @test full(T) == F - - # tridiagonal linear algebra + # tridiagonal linear algebra v = convert(Vector{elty}, v) @assert_approx_eq T*v F*v invFv = F\v @@ -302,24 +298,21 @@ for elty in (Float32, Float64, Complex64, Complex128) x = Tlu\v @assert_approx_eq x invFv @assert_approx_eq det(T) det(F) - - # symmetric tridiagonal + # symmetric tridiagonal Ts = SymTridiagonal(d, dl) Fs = 
full(Ts) invFsv = Fs\v Tldlt = ldltd(Ts) x = Tldlt\v @assert_approx_eq x invFsv - - # eigenvalues/eigenvectors of symmetric tridiagonal + # eigenvalues/eigenvectors of symmetric tridiagonal if elty === Float32 || elty === Float64 DT, VT = eig(Ts) D, Vecs = eig(Fs) @assert_approx_eq DT D @assert_approx_eq abs(VT'Vecs) eye(elty, n) end - - # Woodbury + # Woodbury U = convert(Matrix{elty}, U) V = convert(Matrix{elty}, V) C = convert(Matrix{elty}, C) @@ -359,10 +352,10 @@ for elty in (Float32, Float64, Complex64, Complex128) @assert_approx_eq_eps det(ones(elty, 3,3)) zero(elty) 3*eps(one(elty)) end -# LAPACK tests + # LAPACK tests Ainit = randn(5,5) for elty in (Float32, Float64, Complex64, Complex128) - # syevr! + # syevr! A = convert(Array{elty, 2}, Ainit) Asym = A'A Z = Array(elty, 5, 5)