diff --git a/src/primitives.jl b/src/primitives.jl index bdfb048..546cbb2 100644 --- a/src/primitives.jl +++ b/src/primitives.jl @@ -134,6 +134,59 @@ Returns the last entry of `X`. """ -> Base.endof(X::NullableArray) = endof(X.values) # -> Int +@doc """ +`==(A::NullableArray, B::NullableArray)` +`==(A::NullableArray, B::AbstractArray)` +`==(A::AbstractArray, B::NullableArray)` + +Returns `Nullable(true)` if all array elements of the same rank +are equal and none of the arrays contain missing values, `Nullable(false)` +if two (non-missing) elements of the same rank differ, and `Nullable{Bool}()` +otherwise. +""" -> +function Base.(:(==))(A::NullableArray, B::NullableArray) + if size(A) != size(B) + return Nullable(false) + end + anynull = false + for i in eachindex(A,B) + if A.isnull[i] || B.isnull[i] + anynull = true + elseif A.values[i] != B.values[i] + return Nullable(false) + end + end + if anynull + return Nullable{Bool}() + else + return Nullable(true) + end +end + +function Base.(:(==))(A::NullableArray, B::AbstractArray) + if size(A) != size(B) + return Nullable(false) + end + if isa(B, Range) + return Nullable(false) + end + anynull = false + for i in eachindex(A,B) + if A.isnull[i] + anynull = true + elseif A.values[i] != B[i] + return Nullable(false) + end + end + if anynull + return Nullable{Bool}() + else + return Nullable(true) + end +end + +Base.(:(==))(A::AbstractArray, B::NullableArray) = B == A + @doc """ """ -> diff --git a/test/primitives.jl b/test/primitives.jl index f13d340..bb8e12f 100644 --- a/test/primitives.jl +++ b/test/primitives.jl @@ -149,6 +149,27 @@ module TestPrimitives @test endof(NullableArray(collect(1:10))) == 10 @test endof(NullableArray([1, 2, nothing, 4, nothing])) == 5 +# ----- test Base.== ------------------------------------------------------# + x = collect(1:3) + xn = NullableArray(x) + y = [1.0, 2.0, 3.0] + yn = NullableArray(y, [false, false, true]) + z = [1.1, 2.0, 3.0] + zn = NullableArray(z, [false, true, false]) + @test get(x == NullableArray(x)) + @test get(y == NullableArray(y)) + @test get(x != NullableArray(z)) + @test get(x != yn[1:2]) + @test get(x != zn[1:2]) + + @test isnull(xn == yn) + @test get(xn == y) + @test isnull(x == yn) + + @test get(xn != zn) + @test get(xn != z) + @test get(x != zn) + # ----- test Base.find -------------------------------------------------------# z = NullableArray(rand(Bool, 10))