Skip to content

Commit

Permalink
base: make diff() use views and broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
stev47 committed Oct 29, 2018
1 parent 1717adb commit 2f9fd06
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
29 changes: 14 additions & 15 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,17 +659,15 @@ end
end
end

function diff(a::AbstractVector)
@assert !has_offset_axes(a)
[ a[i+1] - a[i] for i=1:length(a)-1 ]
end
diff(a::AbstractVector) = diff(a, dims=1)

"""
diff(A::AbstractVector)
diff(A::AbstractMatrix; dims::Integer)
diff(A::AbstractArray; dims::Integer)
Finite difference operator of matrix or vector `A`. If `A` is a matrix,
specify the dimension over which to operate with the `dims` keyword argument.
Finite difference operator on a vector or a multidimensional array `A`. In the
latter case the dimension to operate on needs to be specified with the `dims`
keyword argument.
# Examples
```jldoctest
Expand All @@ -690,14 +688,15 @@ julia> diff(vec(a))
12
```
"""
function diff(A::AbstractMatrix; dims::Integer)
if dims == 1
[A[i+1,j] - A[i,j] for i=1:size(A,1)-1, j=1:size(A,2)]
elseif dims == 2
[A[i,j+1] - A[i,j] for i=1:size(A,1), j=1:size(A,2)-1]
else
throw(ArgumentError("dimension must be 1 or 2, got $dims"))
end
function diff(a::AbstractArray{T,N}; dims::Integer) where {T,N}
has_offset_axes(a) && throw(ArgumentError("offset axes unsupported"))
1 <= dims <= N || throw(ArgumentError("dimension $dims out of range (1:$N)"))

r = axes(a)
r0 = ntuple(i -> i == dims ? UnitRange(1, last(r[i]) - 1) : UnitRange(r[i]), N)
r1 = ntuple(i -> i == dims ? UnitRange(2, last(r[i])) : UnitRange(r[i]), N)

return view(a, r1...) .- view(a, r0...)
end

### from abstractarray.jl
Expand Down
6 changes: 6 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2283,6 +2283,9 @@ end

@testset "diff" begin
# test diff, throw ArgumentError for invalid dimension argument
v = [7, 3, 5, 1, 9]
@test diff(v) == [-4, 2, -4, 8]
@test diff(v,dims=1) == [-4, 2, -4, 8]
X = [3 9 5;
7 4 2;
2 1 10]
Expand All @@ -2292,6 +2295,9 @@ end
@test diff(view(X, 1:2, 1:2),dims=2) == reshape([6; -3], (2,1))
@test diff(view(X, 2:3, 2:3),dims=1) == [-3 8]
@test diff(view(X, 2:3, 2:3),dims=2) == reshape([-2; 9], (2,1))
Y = cat([1 3; 4 3], [6 5; 1 4], dims=3)
@test diff(Y, dims=3) == reshape([5 2; -3 1], (2, 2, 1))
@test_throws UndefKeywordError diff(X)
@test_throws ArgumentError diff(X,dims=3)
@test_throws ArgumentError diff(X,dims=-1)
end
Expand Down

0 comments on commit 2f9fd06

Please sign in to comment.