Allow zero-arrays #58

Closed
mcabbott commented Feb 20, 2022

This wants to check that zero-dimensional arrays are allowed, and preserved.

They aren't right now, because subtract! returns a number unless it writes in-place. Using broadcast_preserving_zero_d fixes that, which (edit) this PR now does:

julia> m = (a = fill(1.0), b = SArray{Tuple{}}(fill(1.0)), c = PermutedDimsArray(fill(1.0), ()));

# normal arrays

julia> m.a .- m.a
0.0

julia> Broadcast.broadcast_preserving_zero_d(-, m.a, m.a)
0-dimensional Array{Float64, 0}:
0.0
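
The difference can be wrapped in a small non-mutating helper. This is only a minimal sketch: the name `subtract_sketch` is hypothetical and is not the package's `subtract!`; it assumes plain `Array` inputs and relies on `Base.Broadcast.broadcast_preserving_zero_d`, which has a docstring but is not in the manual.

```julia
# Hypothetical helper (not the package's subtract!): subtract without mutating,
# keeping a zero-dimensional input zero-dimensional.
function subtract_sketch(x::AbstractArray, dx)
    return Base.Broadcast.broadcast_preserving_zero_d(-, x, dx)
end

x0 = fill(1.0)                          # 0-dimensional Array{Float64, 0}
y0 = subtract_sketch(x0, fill(0.25))    # stays 0-dimensional
v1 = subtract_sketch([1.0, 2.0], [0.5, 0.5])  # ordinary arrays are unaffected
```

Here `x0` is left untouched and `y0` is a new 0-dimensional array rather than a bare `Float64`.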

broadcast_preserving_zero_d doesn't work for StaticArrays; maybe that's their problem, not ours:

julia> m.b .- m.b
Scalar{Float64}((0.0,))

julia> Broadcast.broadcast_preserving_zero_d(-, m.b, m.b)
ERROR: MethodError: no method matching similar(::Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{0}, Nothing, typeof(-), Tuple{Scalar{Float64}, Scalar{Float64}}}, ::Type{Scalar{Float64}}, ::Tuple{})
Closest candidates are:
  similar(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}, ::Type{T}, ::Tuple{Vararg{Int64, N}}) where {T, N} at ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:212
  similar(::Base.ReinterpretArray, ::Type, ::Tuple{Vararg{Int64, N}} where N) at ~/.julia/dev/julia/usr/share/julia/base/reinterpretarray.jl:185
  similar(::UpperHessenberg, ::Type{T}, ::Tuple{Vararg{Int64, N}}) where {T, N} at ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/hessenberg.jl:61
  ...
Stacktrace:
 [1] similar(bc::Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{0}, Nothing, typeof(-), Tuple{Scalar{Float64}, Scalar{Float64}}}, #unused#::Type{Scalar{Float64}})
   @ Base.Broadcast ./broadcast.jl:211

Alternatively, we could ban zero-arrays as being too weird.

mcabbott commented Feb 23, 2022

Should this allow Ref as a mutable zero-dim container in which to store scalar parameters?

One way to do so would be to make it a scalar as far as the rules are concerned, by unwrapping, something like this:

isnumeric(x::Ref{<:Number}) = true  # these aren't leaf according to functors
isnumeric(x::Ref{<:Integer}) = false

iswriteable(::Ref) = true  # needed not for subtract!, but for update's fmap(copy)
init(o, x::Ref) = init(o, x.x)  # treat like a scalar for what state to store

update!(ℓ::Leaf, x::Ref, ::Zero, ::Zero...) = ℓ, x
function update!(ℓ::Leaf, x::Ref, x̄s...)
  s′, x̄′ = apply!(ℓ.rule, ℓ.state, x.x, map(x̄ -> base(x̄).x, x̄s)...)
  x[] = x[] .- x̄′
  Leaf(ℓ.rule, s′), x
end

Tests here have a few examples, but they are regarded as non-trainable -- it's really just testing that they don't give errors.

mcabbott marked this pull request as ready for review February 23, 2022 05:38
darsnack left a comment

I can see how one might accidentally get a 0-dim array, but a Ref is a conscious choice. What's the value over using a 0-dim array or single element array? I think ignoring them is fine.

if iswriteable(x)
  x .= x .- x̄
else
  broadcast_preserving_zero_d(eltype(x), broadcasted(-, x, x̄))
Member

Is this a public function? Meaning we can expect it to be stable?

Member Author

I'm not entirely sure. It has a docstring but isn't in the manual. It's used quite a bit, e.g. to implement conj at https://github.com/JuliaLang/julia/blob/4c8c5153a566b25ef8c7b7091b5126328812d287/base/abstractarraymath.jl#L145 .

Member

I doubt this is the only use of it in the wild either: https://juliahub.com/ui/Search?q=broadcast_preserving_zero_d&type=code

Member

Good enough for me!

@mcabbott
Member Author

For zero-arrays, someone will try, and I think it should either work or be an error. Seems bad if update! can silently replace them with numbers & then stop training, which is what falls out of broadcasting. They work in Flux now, and mutable ones will similarly work here, I think.
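
A minimal Base-only sketch of the two behaviours being contrasted, using plain `Array{Float64, 0}` values (the variable names are illustrative only):

```julia
x = fill(1.0)      # 0-dimensional, mutable parameter
dx = fill(0.1)     # 0-dimensional update

# Out-of-place broadcasting drops the container and yields a plain number.
y = x .- dx

# In-place broadcasting writes into x, which stays a 0-dimensional array.
x .= x .- dx
```

After this, `y` is a `Float64` while `x` is still an `Array{Float64, 0}`, which is the case where a non-mutating update would silently change the parameter's type.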

Maybe Ref ought to be a separate issue. It's the obvious mutable one-element container. But perhaps telling people to use a vector is fine. Currently ignored by Flux.

darsnack commented Feb 24, 2022

> For zero-arrays, someone will try, and I think it should either work or be an error. Seems bad if update! can silently replace them with numbers & then stop training, which is what falls out of broadcasting. They work in Flux now, and mutable ones will similarly work here, I think.

👍🏾

> Maybe Ref ought to be a separate issue. It's the obvious mutable one-element container. But perhaps telling people to use a vector is fine. Currently ignored by Flux.

Certainly not contentious enough on my end that it can't be resolved now, but Flux's advice for scalar params has been a 1-element vector for quite some time. I don't see the practical reason someone would want Ref, so it seems safer to just not support it (by treating it as non-trainable). That said, if you think it warrants later discussion, we can just move on.
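
For comparison, a Base-only sketch of the two containers under discussion for holding a single scalar parameter. It only illustrates how each is mutated and is not tied to any package's update rules:

```julia
# 1-element vector: an ordinary array, updated with in-place broadcasting.
v = [1.0]
v .-= 0.1

# Ref: a mutable one-element container that is not an AbstractArray,
# read and written with r[].
r = Ref(1.0)
r[] -= 0.1
```

Because `v` is an `AbstractArray`, generic array code handles it without special cases, whereas `r` would need dedicated methods.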

@ToucheSir
Member

While looking for JuliaLang/julia#35778 (comment) as background reading for this discussion, I found https://juliaarrays.github.io/StaticArrays.jl/latest/pages/api/#Scalar. Perhaps it's fine to ask users to use a type like this if they want 0-d arrays?

@mcabbott
Member Author

That plays badly with Functors right now:

julia> fmap(println, (x = Fill(1.0), y = Fill(1.0)))
Fill(1.0)
(x = nothing, y = nothing)

And if we fix it, we can fix scalars too.

Really you should use a 1-element vector, [1.0], for now. But someone will use an Array{T,0}, and this package should either work, or not work (and tell you), I think.

This PR votes for "work" now. broadcast_preserving_zero_d is a little ugly. I was concerned that more broadcasts within the rules would go wrong, by producing a scalar, but they seem to work OK.
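
As an illustration of why broadcasts inside a rule can be fine, here is a hypothetical momentum-style step (`momentum_step!` is not one of the package's rules) that uses only in-place broadcasts. It assumes mutable 0-dimensional `Array`s for the parameter, gradient and state:

```julia
# Hypothetical momentum-style step, written only with in-place broadcasts.
function momentum_step!(x, m, g; eta = 0.1, rho = 0.9)
    @. m = rho * m + g      # state update, written into m
    @. x = x - eta * m      # parameter update, written into x
    return x, m
end

x = fill(1.0)   # 0-dimensional parameter
g = fill(0.5)   # 0-dimensional gradient
m = zero(x)     # 0-dimensional state, same shape as x
momentum_step!(x, m, g)
```

Since every broadcast has a destination, nothing is collapsed to a scalar. An out-of-place rule would not have that protection.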

@ToucheSir
Member

I thought the consternation with using broadcast_preserving_zero_d was that certain array types in the wild don't work correctly with it? Or has that been resolved since the last top post update?

@mcabbott
Member Author

Yes, I don't know what to do there. These don't follow Base:

julia> using FillArrays

julia> Fill(1) .+ Fill(2)
0-dimensional Fill{Int64}, with entry equal to 3

julia> using StaticArrays

julia> s0 = SArray{Tuple{}}(fill(1.0))
Scalar{Float64}((1.0,))

julia> s0 isa AbstractArray{Float64, 0}
true

julia> s0 .+ s0
Scalar{Float64}((2.0,))

which results in weird behaviour or errors:

julia> Broadcast.broadcast_preserving_zero_d(-, Fill(1), Fill(2))
0-dimensional Array{Fill{Int64, 0, Tuple{}}, 0}:
Fill(-1)

julia> Broadcast.broadcast_preserving_zero_d(-, s0, s0)
ERROR: MethodError: no method matching similar(::Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{0}, Nothing, typeof(-), Tuple{Scalar{Float64}, Scalar{Float64}}}, ::Type{Scalar{Float64}}, ::Tuple{})

xref JuliaArrays/FillArrays.jl#145; a search of StaticArrays.jl finds nothing: https://github.com/JuliaArrays/StaticArrays.jl/search?q=broadcast_preserving_zero_d

They also don't really work with Functors.jl, so perhaps they aren't usable anyway.

Maybe this PR is a bit random then, since Array actually works (by mutation) and ImmutableArrays haven't landed yet.

And wrappers like this aren't preserved by broadcasting:

julia> using ReadOnlyArrays

julia> ReadOnlyArray(fill(0))
0-dimensional ReadOnlyArray{Int64, 0, Array{Int64, 0}}:
0

julia> ans .+ ans
0

julia> Broadcast.broadcast_preserving_zero_d(-, ReadOnlyArray(fill(0)), ReadOnlyArray(fill(0)))
0-dimensional Array{Int64, 0}:
0

julia> ReadOnlyArray([1,2]) .+ 3
2-element Vector{Int64}:
 4
 5

mcabbott closed this Feb 24, 2022