From 1687c76c2195320b83397044fa16b00036e2b252 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Wed, 18 Dec 2013 13:53:23 -0500 Subject: [PATCH] add generic fallback for Blas.LinAlg.axpy\!, make sure it always returns y rather than a pointer --- base/linalg.jl | 1 + base/linalg/blas.jl | 7 ++++--- base/linalg/generic.jl | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/base/linalg.jl b/base/linalg.jl index 28e7bdeeb32da..7c1d6fa34a1a9 100644 --- a/base/linalg.jl +++ b/base/linalg.jl @@ -35,6 +35,7 @@ export Diagonal, # Functions + axpy!, bkfact, bkfact!, check_blas, diff --git a/base/linalg/blas.jl b/base/linalg/blas.jl index 77233c057277d..5a6d3d4a7c5a8 100644 --- a/base/linalg/blas.jl +++ b/base/linalg/blas.jl @@ -37,7 +37,7 @@ export const libblas = Base.libblas_name -import ..LinAlg: BlasFloat, BlasChar, BlasInt, blas_int, DimensionMismatch, chksquare +import ..LinAlg: BlasFloat, BlasChar, BlasInt, blas_int, DimensionMismatch, chksquare, axpy! # Level 1 ## copy @@ -163,12 +163,12 @@ for (fname, elty) in ((:daxpy_,:Float64), end end end -function axpy!{T,Ta<:Number}(alpha::Ta, x::Array{T}, y::Array{T}) +function axpy!{T<:BlasFloat,Ta<:Number}(alpha::Ta, x::Array{T}, y::Array{T}) length(x)==length(y) || throw(DimensionMismatch("")) axpy!(length(x), convert(T,alpha), x, 1, y, 1) end -function axpy!{T,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range1{Ti},Range{Ti}), +function axpy!{T<:BlasFloat,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range1{Ti},Range{Ti}), y::Array{T}, ry::Union(Range1{Ti},Range{Ti})) length(rx)==length(ry) || throw(DimensionMismatch("")) @@ -177,6 +177,7 @@ function axpy!{T,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range throw(BoundsError()) end axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry)) + y end ## iamax diff --git a/base/linalg/generic.jl b/base/linalg/generic.jl index ee20586a4516b..775489ca9e2e3 100644 --- a/base/linalg/generic.jl +++ b/base/linalg/generic.jl @@ -203,3 +203,24 @@ function peakflops(n::Integer=2000; parallel::Bool=false) parallel ? sum(pmap(peakflops, [ n for i in 1:nworkers()])) : (2*n^3/t) end +# BLAS-like in-place y=alpha*x+y function (see also the version in blas.jl +# for BlasFloat Arrays) +function axpy!(alpha, x::AbstractArray, y::AbstractArray) + n = length(x) + n==length(y) || throw(DimensionMismatch("")) + for i = 1:n + @inbounds y[i] += alpha * x[i] + end + y +end +function axpy!{Ti<:Integer,Tj<:Integer}(alpha, x::AbstractArray, rx::AbstractArray{Ti}, y::AbstractArray, ry::AbstractArray{Tj}) + length(x)==length(y) || throw(DimensionMismatch("")) + if minimum(rx) < 1 || maximum(rx) > length(x) || minimum(ry) < 1 || maximum(ry) > length(y) || length(rx) != length(ry) + throw(BoundsError()) + end + for i = 1:length(rx) + @inbounds y[ry[i]] += alpha * x[rx[i]] + end + y +end +