Skip to content

Commit

Permalink
boundary gradient functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Jun 27, 2024
1 parent c3a74cd commit 49d92d9
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}, y::
RectangularBoundary(x.side_lengths .+ y; check_positive=false)
end

function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{3, T}}}, y::SizedVector{3, T, Vector{T}}) where T
CubicBoundary(x.side_lengths .+ y; check_positive=false)
end

function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}, y::SizedVector{2, T, Vector{T}}) where T
RectangularBoundary(x.side_lengths .+ y; check_positive=false)
end

function Zygote.accum(x::NamedTuple{(:side_lengths,), Tuple{SizedVector{3, T, Vector{T}}}}, y::SVector{3, T}) where T
CubicBoundary(SVector{3, T}(x.side_lengths .+ y); check_positive=false)
end
Expand All @@ -99,6 +107,14 @@ function Base.:+(x::NamedTuple{(:side_lengths,), Tuple{SizedVector{2, T, Vector{
RectangularBoundary(SVector{2, T}(x.side_lengths .+ y.side_lengths); check_positive=false)
end

function Base.:+(x::CubicBoundary{T}, y::NamedTuple{(:side_lengths,), Tuple{SVector{3, T}}}) where T
CubicBoundary(SVector{3, T}(x.side_lengths .+ y.side_lengths); check_positive=false)
end

function Base.:+(x::RectangularBoundary{T}, y::NamedTuple{(:side_lengths,), Tuple{SVector{2, T}}}) where T
RectangularBoundary(SVector{2, T}(x.side_lengths .+ y.side_lengths); check_positive=false)
end

function Base.:+(x::SVector{3, T}, y::CubicBoundary{T}) where T
CubicBoundary(x .+ y.side_lengths; check_positive=false)
end
Expand Down

0 comments on commit 49d92d9

Please sign in to comment.