Skip to content

Commit

Permalink
Drop BroadcastStyle back down to the value domain
Browse files Browse the repository at this point in the history
This is a subtle change for broadcast implementors -- instead of defining `BroadcastStyle(::Type{<:MyType})`, they now need to define `BroadcastStyle(::MyType)` directly. This is a breaking change to a new API in 0.7; so it's not technically breaking over 0.6, but any libraries that have adapted to the new API will need to shift their definitions. The primary purpose of this change is becasue the `BroadcastStyle` needs to know about the `ndims` of its passed argument, but some arguments (like PyArrays) don't encode ndims in the type domain. Asking the value directly allows them to participate more fully in the broadcast interface.
  • Loading branch information
mbauman committed May 18, 2018
1 parent 627173b commit 7b5d9b4
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 46 deletions.
31 changes: 15 additions & 16 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dot

"""
`BroadcastStyle` is an abstract type and trait-function used to determine behavior of
objects under broadcasting. `BroadcastStyle(typeof(x))` returns the style associated
objects under broadcasting. `BroadcastStyle(x)` returns the style associated
with `x`. To customize the broadcasting behavior of a type, one can declare a style
by defining a type/method pair
struct MyContainerStyle <: BroadcastStyle end
Base.BroadcastStyle(::Type{<:MyContainer}) = MyContainerStyle()
Base.BroadcastStyle(::MyContainer) = MyContainerStyle()
One then writes method(s) (at least [`similar`](@ref)) operating on
`Broadcasted{MyContainerStyle}`. There are also several pre-defined subtypes of `BroadcastStyle`
Expand All @@ -36,14 +36,13 @@ abstract type BroadcastStyle end
parameter `C`. You can use this as an alternative to creating custom subtypes of `BroadcastStyle`,
for example
Base.BroadcastStyle(::Type{<:MyContainer}) = Broadcast.Style{MyContainer}()
Base.BroadcastStyle(::MyContainer) = Broadcast.Style{MyContainer}()
"""
struct Style{T} <: BroadcastStyle end

BroadcastStyle(::Type{<:Tuple}) = Style{Tuple}()
BroadcastStyle(::Tuple) = Style{Tuple}()

struct Unknown <: BroadcastStyle end
BroadcastStyle(::Type{Union{}}) = Unknown() # ambiguity resolution

"""
`Broadcast.AbstractArrayStyle{N} <: BroadcastStyle` is the abstract supertype for any style
Expand All @@ -52,12 +51,12 @@ The `N` parameter is the dimensionality, which can be handy for AbstractArray ty
that only support specific dimensionalities:
struct SparseMatrixStyle <: Broadcast.AbstractArrayStyle{2} end
Base.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatrixStyle()
Base.BroadcastStyle(::SparseMatrixCSC) = SparseMatrixStyle()
For AbstractArray types that support arbitrary dimensionality, `N` can be set to `Any`:
struct MyArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
Base.BroadcastStyle(::Type{<:MyArray}) = MyArrayStyle()
Base.BroadcastStyle(::MyArray) = MyArrayStyle()
In cases where you want to be able to mix multiple `AbstractArrayStyle`s and keep track
of dimensionality, your style needs to support a `Val` constructor:
Expand Down Expand Up @@ -96,9 +95,9 @@ DefaultArrayStyle(::Val{N}) where N = DefaultArrayStyle{N}()
DefaultArrayStyle{M}(::Val{N}) where {N,M} = DefaultArrayStyle{N}()
const DefaultVectorStyle = DefaultArrayStyle{1}
const DefaultMatrixStyle = DefaultArrayStyle{2}
BroadcastStyle(::Type{<:AbstractArray{T,N}}) where {T,N} = DefaultArrayStyle{N}()
BroadcastStyle(::Type{<:Ref}) = DefaultArrayStyle{0}()
BroadcastStyle(::Type{T}) where {T} = DefaultArrayStyle{ndims(T)}()
BroadcastStyle(::AbstractArray{T,N}) where {T,N} = DefaultArrayStyle{N}()
BroadcastStyle(::Ref) = DefaultArrayStyle{0}()
BroadcastStyle(v::Any) = DefaultArrayStyle{ndims(v)}()

# `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle`
# objects were supplied as arguments, and that no rule was defined for resolving the
Expand Down Expand Up @@ -221,9 +220,9 @@ _axes(::Broadcasted, axes::Tuple) = axes
_axes(bc::Broadcasted{Style{Tuple}}, ::Nothing) = (Base.OneTo(length(longest_tuple(nothing, bc.args))),)
_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = ()

BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned"))
BroadcastStyle(::Broadcasted{Style}) where {Style} = Style()
BroadcastStyle(::Broadcasted{S}) where {S<:Union{Nothing,Unknown}} =
throw(ArgumentError("Broadcasted{$S} wrappers do not have a style assigned"))

argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args
argtype(bc::Broadcasted) = argtype(typeof(bc))
Expand Down Expand Up @@ -400,7 +399,7 @@ longest(::Tuple{}, ::Tuple{}) = ()

# combine_styles operates on values (arbitrarily many)
combine_styles() = DefaultArrayStyle{0}()
combine_styles(c) = result_style(BroadcastStyle(typeof(c)))
combine_styles(c) = result_style(BroadcastStyle(c))
combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2))
@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))

Expand Down Expand Up @@ -589,12 +588,12 @@ Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()
"""
Broadcast.broadcastable(x)
Return either `x` or an object like `x` such that it supports `axes`, indexing, and its type supports `ndims`.
Return either `x` or an object like `x` such that it supports `axes`, indexing, and `ndims`.
If `x` supports iteration, the returned value should have the same `axes` and indexing
behaviors as [`collect(x)`](@ref).
If `x` is not an `AbstractArray` but it supports `axes`, indexing, and its type supports
If `x` is not an `AbstractArray` but it supports `axes`, indexing, and
`ndims`, then `broadcastable(::typeof(x))` may be implemented to just return itself.
Further, if `x` defines its own [`BroadcastStyle`](@ref), then it must define its
`broadcastable` method to return itself for the custom style to have any effect.
Expand Down
16 changes: 8 additions & 8 deletions doc/src/manual/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ V = view(A, [1,2,4], :) # is not strided, as the spacing between rows is not f

| Methods to implement | Brief description |
|:-------------------- |:----------------- |
| `Base.BroadcastStyle(::Type{SrcType}) = SrcStyle()` | Broadcasting behavior of `SrcType` |
| `Base.BroadcastStyle(::SrcType) = SrcStyle()` | Broadcasting behavior of `SrcType` |
| `Base.similar(bc::Broadcasted{DestStyle}, ::Type{ElType})` | Allocation of output container |
| **Optional methods** | | |
| `Base.BroadcastStyle(::Style1, ::Style2) = Style12()` | Precedence rules for mixing styles |
Expand Down Expand Up @@ -483,15 +483,15 @@ To override these defaults, you can define a custom `BroadcastStyle` for your ob

```julia
struct MyStyle <: Broadcast.BroadcastStyle end
Base.BroadcastStyle(::Type{<:MyType}) = MyStyle()
Base.BroadcastStyle(::MyType) = MyStyle()
```

In some cases it might be convenient not to have to define `MyStyle`, in which case you can
leverage one of the general broadcast wrappers:

- `Base.BroadcastStyle(::Type{<:MyType}) = Broadcast.Style{MyType}()` can be
- `Base.BroadcastStyle(::MyType) = Broadcast.Style{MyType}()` can be
used for arbitrary types.
- `Base.BroadcastStyle(::Type{<:MyType}) = Broadcast.ArrayStyle{MyType}()` is preferred
- `Base.BroadcastStyle(::MyType) = Broadcast.ArrayStyle{MyType}()` is preferred
if `MyType` is an `AbstractArray`.
- For `AbstractArrays` that only support a certain dimensionality, create a subtype of `Broadcast.AbstractArrayStyle{N}` (see below).

Expand Down Expand Up @@ -541,7 +541,7 @@ Base.showarg(io::IO, A::ArrayAndChar, toplevel) = print(io, typeof(A), " with ch
You might want broadcasting to preserve the `char` "metadata." First we define

```jldoctest ArrayAndChar
Base.BroadcastStyle(::Type{<:ArrayAndChar}) = Broadcast.ArrayStyle{ArrayAndChar}()
Base.BroadcastStyle(::ArrayAndChar) = Broadcast.ArrayStyle{ArrayAndChar}()
# output
```
Expand Down Expand Up @@ -702,13 +702,13 @@ rules unless you want to establish precedence for
two or more non-`DefaultArrayStyle` types.

If your array type does have fixed dimensionality requirements, then you should
subtype `AbstractArrayStyle`. For example, the sparse array code has the following definitions:
subtype `AbstractArrayStyle`. For example, the SparseArrays standard library has the following definitions:

```julia
struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Base.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
Base.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
Base.BroadcastStyle(::SparseVector) = SparseVecStyle()
Base.BroadcastStyle(::SparseMatrixCSC) = SparseMatStyle()
```

Whenever you subtype `AbstractArrayStyle`, you also need to define rules for combining
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()

const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix} = StructuredMatrixStyle{T}()
Broadcast.BroadcastStyle(S::StructuredMatrix) = StructuredMatrixStyle{typeof(S)}()

# Promotion of broadcasts between structured matrices. This is slightly unusual
# as we define them symmetrically. This allows us to have a fallback to DefaultArrayStyle{2}().
Expand Down
5 changes: 0 additions & 5 deletions stdlib/SparseArrays/src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,6 @@ using LinearAlgebra: Adjoint, Transpose
\(::Adjoint{<:Any,<:SparseMatrixCSC}, ::RowVector) = throw(DimensionMismatch("Cannot left-divide matrix by transposed vector"))
\(::Transpose{<:Any,<:SparseMatrixCSC}, ::RowVector) = throw(DimensionMismatch("Cannot left-divide matrix by transposed vector"))

# methods involving RowVector from base/sparse/higherorderfns.jl, to deprecate
@eval SparseArrays.HigherOrderFns begin
BroadcastStyle(::Type{<:RowVector{T,<:Vector}}) where T = Broadcast.MatrixStyle()
end

import Base: asyncmap
@deprecate asyncmap(f, s::AbstractSparseArray...; kwargs...) sparse(asyncmap(f, map(Array, s)...; kwargs...))

Expand Down
8 changes: 4 additions & 4 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}
# broadcast container type promotion for combinations of sparse arrays and other types
struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
Broadcast.BroadcastStyle(::SparseVector) = SparseVecStyle()
Broadcast.BroadcastStyle(::SparseMatrixCSC) = SparseMatStyle()
const SPVM = Union{SparseVecStyle,SparseMatStyle}

# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
Expand Down Expand Up @@ -66,8 +66,8 @@ PromoteToSparse(::Val{2}) = PromoteToSparse()
PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
Broadcast.BroadcastStyle(::Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T) = PromoteToSparse()
Broadcast.BroadcastStyle(::Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T) = PromoteToSparse()

Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
Expand Down
24 changes: 12 additions & 12 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,27 +436,27 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{A}}, ::Type{T}) wher
struct Array19745{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:Array19745} = Broadcast.ArrayStyle{Array19745}()
Base.BroadcastStyle(::Array19745) = Broadcast.ArrayStyle{Array19745}()

# Two specialized broadcast rules with no declared precedence
struct AD1{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD1} = Broadcast.ArrayStyle{AD1}()
Base.BroadcastStyle(::AD1) = Broadcast.ArrayStyle{AD1}()
struct AD2{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD2} = Broadcast.ArrayStyle{AD2}()
Base.BroadcastStyle(::AD2) = Broadcast.ArrayStyle{AD2}()

# Two specialized broadcast rules with explicit precedence
struct AD1P{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD1P} = Broadcast.ArrayStyle{AD1P}()
Base.BroadcastStyle(::AD1P) = Broadcast.ArrayStyle{AD1P}()
struct AD2P{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD2P} = Broadcast.ArrayStyle{AD2P}()
Base.BroadcastStyle(::AD2P) = Broadcast.ArrayStyle{AD2P}()

Base.BroadcastStyle(a1::Broadcast.ArrayStyle{AD1P}, ::Broadcast.ArrayStyle{AD2P}) = a1

Expand All @@ -465,11 +465,11 @@ Base.BroadcastStyle(a1::Broadcast.ArrayStyle{AD1P}, ::Broadcast.ArrayStyle{AD2P}
struct AD1B{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD1B} = Broadcast.ArrayStyle{AD1B}()
Base.BroadcastStyle(::AD1B) = Broadcast.ArrayStyle{AD1B}()
struct AD2B{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD2B} = Broadcast.ArrayStyle{AD2B}()
Base.BroadcastStyle(::AD2B) = Broadcast.ArrayStyle{AD2B}()

Base.BroadcastStyle(a1::Broadcast.ArrayStyle{AD1B}, a2::Broadcast.ArrayStyle{AD2B}) = a1
Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2B}, a1::Broadcast.ArrayStyle{AD1B}) = a1
Expand All @@ -478,11 +478,11 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2B}, a1::Broadcast.ArrayStyle{AD1
struct AD1C{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD1C} = Broadcast.ArrayStyle{AD1C}()
Base.BroadcastStyle(::AD1C) = Broadcast.ArrayStyle{AD1C}()
struct AD2C{T,N} <: ArrayData{T,N}
data::Array{T,N}
end
Base.BroadcastStyle(::Type{T}) where {T<:AD2C} = Broadcast.ArrayStyle{AD2C}()
Base.BroadcastStyle(::AD2C) = Broadcast.ArrayStyle{AD2C}()

Base.BroadcastStyle(a1::Broadcast.ArrayStyle{AD1C}, a2::Broadcast.ArrayStyle{AD2C}) = a1
Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1C}) = a2
Expand All @@ -496,7 +496,7 @@ AD2DimStyle(::Val{2}) = AD2DimStyle()
AD2DimStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
Base.similar(bc::Broadcast.Broadcasted{AD2DimStyle}, ::Type{T}) where {T} =
AD2Dim(Array{T}(undef, length.(axes(bc))))
Base.BroadcastStyle(::Type{T}) where {T<:AD2Dim} = AD2DimStyle()
Base.BroadcastStyle(::AD2Dim) = AD2DimStyle()

@testset "broadcasting for custom AbstractArray" begin
a = randn(10)
Expand Down Expand Up @@ -622,7 +622,7 @@ struct Foo26601{T}
end
Base.axes(f::Foo26601) = axes(f.data)
Base.getindex(f::Foo26601, i...) = getindex(f.data, i...)
Base.ndims(::Type{Foo26601{T}}) where {T} = ndims(T)
Base.ndims(::Foo26601{T}) where {T} = ndims(T)
Base.Broadcast.broadcastable(f::Foo26601) = f
@testset "barebones custom object broadcasting" begin
for d in (rand(Float64, ()), rand(5), rand(5,5), rand(5,5,5))
Expand Down Expand Up @@ -695,7 +695,7 @@ end
struct T22053
t
end
Broadcast.BroadcastStyle(::Type{T22053}) = Broadcast.Style{T22053}()
Broadcast.BroadcastStyle(::T22053) = Broadcast.Style{T22053}()
Broadcast.broadcast_axes(::T22053) = ()
Broadcast.broadcastable(t::T22053) = t
function Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{T22053}})
Expand Down

0 comments on commit 7b5d9b4

Please sign in to comment.