Skip to content

Commit

Permalink
Merge the commits from #2069 into master
Browse files Browse the repository at this point in the history
  commit 5c1e646
    Chain methods for gemm and gemv to gemm! and gemv!
    Add gemm and gemv methods that have an implicit 1.0 multiplier
  commit 11c8f36
    Chain more foo methods to foo!, add blas tests.
    Reformat the linalg tests.

Relevant comments are in #2062.

Conflicts:
	test/Makefile
  • Loading branch information
ViralBShah committed Feb 9, 2013
1 parent 95f88b7 commit fead965
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 186 deletions.
254 changes: 124 additions & 130 deletions base/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,16 @@ function axpy!{T,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range
return axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
end


# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
# REAL ALPHA,BETA
# INTEGER K,LDA,LDC,N
# CHARACTER TRANS,UPLO
# * ..
# * .. Array Arguments ..
# REAL A(LDA,*),C(LDC,*)
for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32),
(:zsyrk_,:Complex128), (:csyrk_,:Complex64))
@eval begin
# SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# * .. Scalar Arguments ..
# REAL ALPHA,BETA
# INTEGER K,LDA,LDC,N
# CHARACTER TRANS,UPLO
# * .. Array Arguments ..
# REAL A(LDA,*),C(LDC,*)
function syrk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
Expand All @@ -193,14 +191,9 @@ for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32),
end
function syrk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
n = size(A, trans == 'N' ? 1 : 2)
k = size(A, trans == 'N' ? 2 : 1)
C = Array($elty, (n, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2))
C
syrk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n, n)))
end
syrk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = syrk(uplo, trans, one($elty), A)
end
end

Expand All @@ -214,7 +207,7 @@ end
# COMPLEX A(LDA,*),C(LDC,*)
for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
@eval begin
function herk!(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty},
function herk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
if m != n error("syrk!: matrix C must be square") end
Expand All @@ -227,16 +220,11 @@ for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &beta, C, &stride(C,2))
C
end
function herk(uplo::BlasChar, trans, alpha::($elty), A::StridedVecOrMat{$elty})
function herk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
n = size(A, trans == 'N' ? 1 : 2)
k = size(A, trans == 'N' ? 2 : 1)
C = Array($elty, (n, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &0., C, &stride(C,2))
C
herk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n,n)))
end
herk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = herk(uplo, trans, one($elty), A)
end
end

Expand Down Expand Up @@ -266,16 +254,12 @@ for (fname, elty) in ((:dgbmv_,:Float64), (:sgbmv_,:Float32),
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
n = stride(A,2)
y = Array($elty, n)
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &m, &n, &kl, &ku, &alpha, A, &stride(A,2),
x, &stride(x,1), &0., y, &1)
y
gbmv!(trans, m, kl, ku, alpha, A, x, zero($elty), Array($elty, n))
end
function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
A::StridedMatrix{$elty}, x::StridedVector{$elty})
gbmv(trans, m, kl, ku, one($elty), A, x)
end

end
end

Expand Down Expand Up @@ -303,127 +287,109 @@ for (fname, elty) in ((:dsbmv_,:Float64), (:ssbmv_,:Float32),
function sbmv(uplo::BlasChar, k::Integer, alpha::($elty), A::StridedMatrix{$elty},
x::StridedVector{$elty})
n = size(A,2)
y = Array($elty, n)
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &size(A,2), &k, &alpha, A, &stride(A,2), x, &stride(x,1), &0., y, &1)
y
sbmv!(uplo, k, alpha, A, x, zero($elty), Array($elty, n))
end
function sbmv(uplo::BlasChar, k::Integer, A::StridedMatrix{$elty}, x::StridedVector{$elty})
sbmv(uplo, k, one($elty), A, x)
end
end
end

# (GE) general matrix-matrix multiplication
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
for (fname, elty) in ((:dgemm_,:Float64), (:sgemm_,:Float32),
(:zgemm_,:Complex128), (:cgemm_,:Complex64))
# (GE) general matrix-matrix and matrix-vector multiplication
for (gemm, gemv, elty) in
((:dgemm_,:dgemv_,:Float64),
(:sgemm_,:sgemv_,:Float32),
(:zgemm_,:zgemv_,:Complex128),
(:cgemm_,:cgemv_,:Complex64))
@eval begin
function gemm!(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
# SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# * .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER K,LDA,LDB,LDC,M,N
# CHARACTER TRANSA,TRANSB
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function gemm!(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty},
beta::($elty), C::StridedMatrix{$elty})
# if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1)
# error("gemm!: BLAS module requires contiguous matrix columns")
# end # should this be checked on every call?
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
n = size(B, transB == 'N' ? 2 : 1)
if m != size(C,1) || n != size(C,2) error("gemm!: mismatched dimensions") end
ccall(($(string(fname)),libblas), Void,
if m != size(C,1) || n != size(C,2)
error("gemm!: mismatched dimensions")
end
ccall(($(string(gemm)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &beta, C, &stride(C,2))
C
end
function gemm(transA::BlasChar, transB::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
if k != size(B, transB == 'N' ? 1 : 2) error("gemm!: mismatched dimensions") end
n = size(B, transB == 'N' ? 2 : 1)
C = Array($elty, (m, n))
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
B, &stride(B,2), &0., C, &stride(C,2))
C
function gemm(transA::BlasChar, transB::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty})
gemm!(transA, transB, alpha, A, B, zero($elty),
Array($elty, (size(A, transA == 'N' ? 1 : 2),
size(B, transB == 'N' ? 2 : 1))))
end
end
end

#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)

for (fname, elty) in ((:dgemv_,:Float64), (:sgemv_,:Float32),
(:zgemv_,:Complex128), (:cgemv_,:Complex64))
@eval begin
function gemv!(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
ccall(($(string(fname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
function gemm(transA::BlasChar, transB::BlasChar,
A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
gemm(transA, transB, one($elty), A, B)
end
#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gemv!(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty},
beta::($elty), Y::StridedVector{$elty})
ccall(($(string(gemv)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &size(A,1), &size(A,2), &alpha, A, &stride(A,2),
X, &stride(X,1), &beta, Y, &stride(Y,1))
Y
end
function gemv(trans::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty})
Y = Array($elty, size(A,1))
gemv!(trans, alpha, A, X, zero($elty), Y)
Y
function gemv(trans::BlasChar,
alpha::($elty), A::StridedMatrix{$elty},
X::StridedVector{$elty})
gemv!(trans, alpha, A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
function gemv(trans::BlasChar, A::StridedMatrix{$elty}, X::StridedVector{$elty})
gemv!(trans, one($elty), A, X, zero($elty),
Array($elty, size(A, (trans == 'N' ? 1 : 2))))
end
end
end

# (SY) symmetric matrix-matrix and matrix-vector multiplication

# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER LDA,LDB,LDC,M,N
# CHARACTER SIDE,UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)

# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,N
# CHARACTER UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)

for (vfname, mfname, elty) in
((:dsymv_,:dsymm_,:Float64),
(:ssymv_,:ssymm_,:Float32),
(:zsymv_,:zsymm_,:Complex128),
(:csymv_,:csymm_,:Complex64))
for (mfname, vfname, elty) in
((:dsymm_,:dsymv_,:Float64),
(:ssymm_,:ssymv_,:Float32),
(:zsymm_,:zsymv_,:Complex128),
(:csymm_,:csymv_,:Complex64))
@eval begin
function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty},
beta::($elty), Y::StridedVector{$elty})
m, n = size(A)
if m != n error("symm!: matrix A is $m by $n but must be square") end
if m != length(X) || m != length(Y) error("symm!: dimension mismatch") end
ccall(($(string(vfname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &alpha, A, &stride(A,2), X, &stride(X,1), &beta, Y, &stride(Y,1))
Y
end
function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, X::StridedVector{$elty})
symv!(uplo, alpha, A, X, zero($elty), similar(X))
end
function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty},
beta::($elty), C::StridedMatrix{$elty})
side = uppercase(convert(Char, side))
# SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER LDA,LDB,LDC,M,N
# CHARACTER SIDE,UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
k, j = size(A)
if k != j error("symm!: matrix A is $k by $j but must be square") end
Expand All @@ -435,9 +401,37 @@ for (vfname, mfname, elty) in
&beta, C, &stride(C,2))
C
end
function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
B::StridedMatrix{$elty})
symm!(side, uplo, alpha, A, B, zero($elty), similar(B))
end
function symm(side::BlasChar, uplo::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
symm(side, uplo, one($elty), A, B)
end
# SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,N
# CHARACTER UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty},
beta::($elty), y::StridedVector{$elty})
m, n = size(A)
if m != n error("symm!: matrix A is $m by $n but must be square") end
if m != length(x) || m != length(y) error("symm!: dimension mismatch") end
ccall(($(string(vfname)),libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &alpha, A, &stride(A,2), x, &stride(x,1), &beta, y, &stride(y,1))
Y
end
function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
symv!(uplo, alpha, A, x, zero($elty), similar(x))
end
function symv(uplo::BlasChar, A::StridedMatrix{$elty}, x::StridedVector{$elty})
symv(uplo, one($elty), A, x)
end
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default all extra unicode gzip::

TESTS = default all extra \
core numbers strings unicode corelib hashing remote \
arrayops linalg fft dct sparse bitarray suitesparse arpack \
arrayops blas linalg fft dct sparse bitarray suitesparse arpack \
random math functional bigint bigfloat sorting \
statistics poly file Rmath remote zlib image \
iostring gzip integers spawn ccall parallel
Expand Down
Loading

0 comments on commit fead965

Please sign in to comment.