Skip to content

Commit

Permalink
Add ==(::NullableArray, ::NullableArray) and ==(::NullableArray, ::Ab…
Browse files Browse the repository at this point in the history
…stractArray)

Fixes JuliaStats#82.
  • Loading branch information
nalimilan committed Oct 12, 2015
1 parent 6cf63d0 commit bb9a213
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,59 @@ Returns the last entry of `X`.
""" ->
Base.endof(X::NullableArray) = endof(X.values) # -> Int

@doc """
`==(A::NullableArray, B::NullableArray)`
`==(A::NullableArray, B::AbstractArray)`
`==(A::AbstractArray, B::NullableArray)`
Returns `Nullable(true)` if all array elements of the same rank
are equal and none of the arrays contain missing values, `Nullable(false)`
if two (non-missing) elements of the same rank differ, and `Nullable{Bool}()`
otherwise.
""" ->
function Base.(:(==))(A::NullableArray, B::NullableArray)
if size(A) != size(B)
return Nullable(false)
end
anynull = false
for i in eachindex(A,B)
if A.isnull[i] || B.isnull[i]
anynull = true
elseif A.values[i] != B.values[i]
return Nullable(false)
end
end
if anynull
return Nullable{Bool}()
else
return Nullable(true)
end
end

function Base.(:(==))(A::NullableArray, B::AbstractArray)
if size(A) != size(B)
return Nullable(false)
end
if isa(B, Range)
return Nullable(false)
end
anynull = false
for i in eachindex(A,B)
if A.isnull[i]
anynull = true
elseif A.values[i] != B[i]
return Nullable(false)
end
end
if anynull
return Nullable{Bool}()
else
return Nullable(true)
end
end

Base.(:(==))(A::AbstractArray, B::NullableArray) = B == A

@doc """
""" ->
Expand Down
21 changes: 21 additions & 0 deletions test/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ module TestPrimitives
@test endof(NullableArray(collect(1:10))) == 10
@test endof(NullableArray([1, 2, nothing, 4, nothing])) == 5

# ----- test Base.== ------------------------------------------------------#
x = collect(1:3)
xn = NullableArray(x)
y = [1.0, 2.0, 3.0]
yn = NullableArray(y, [false, false, true])
z = [1.1, 2.0, 3.0]
zn = NullableArray(z, [false, true, false])
@test get(x == NullableArray(x))
@test get(y == NullableArray(y))
@test get(x != NullableArray(z))
@test get(x != yn[1:2])
@test get(x != zn[1:2])

@test isnull(xn == yn)
@test get(xn == y)
@test isnull(x == yn)

@test get(xn != zn)
@test get(xn != z)
@test get(x != zn)

# ----- test Base.find -------------------------------------------------------#

z = NullableArray(rand(Bool, 10))
Expand Down

0 comments on commit bb9a213

Please sign in to comment.