diff --git a/src/zygote.jl b/src/zygote.jl index ac356fee..3c8d0d9e 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -83,6 +83,14 @@ function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}, y:: RectangularBoundary(x.side_lengths .+ y; check_positive=false) end +function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{3, T}}}, y::SizedVector{3, T, Vector{T}}) where T + CubicBoundary(x.side_lengths .+ y; check_positive=false) +end + +function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}, y::SizedVector{2, T, Vector{T}}) where T + RectangularBoundary(x.side_lengths .+ y; check_positive=false) +end + function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SizedVector{3, T, Vector{T}}}}, y::SVector{3, T}) where T CubicBoundary(SVector{3, T}(x.side_lengths .+ y); check_positive=false) end @@ -99,6 +107,14 @@ function Base.:+(x::NamedTuple{(:side_lengths,), Tuple{SizedVector{2, T, Vector{ RectangularBoundary(SVector{2, T}(x.side_lengths .+ y.side_lengths); check_positive=false) end +function Base.:+(x::CubicBoundary{T}, y::NamedTuple{(:side_lengths,), Tuple{SVector{3, T}}}) where T + CubicBoundary(SVector{3, T}(x.side_lengths .+ y.side_lengths); check_positive=false) +end + +function Base.:+(x::RectangularBoundary{T}, y::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}) where T + RectangularBoundary(SVector{2, T}(x.side_lengths .+ y.side_lengths); check_positive=false) +end + function Base.:+(x::SVector{3, T}, y::CubicBoundary{T}) where T CubicBoundary(x .+ y.side_lengths; check_positive=false) end