Skip to content

Commit

Permalink
fix(test): use DifferentiableFlatten.@constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
ven-k committed Jun 27, 2023
1 parent c20bdff commit 1c54d40
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Zygote = "0.5, 0.6"
julia = "1"

[extras]
DifferentiableFlatten = "c78775a3-ee38-4681-b694-0504db4f5dc7"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
NonconvexIpopt = "bf347577-a06d-49ad-a669-8c0e005493b8"
Expand All @@ -41,4 +42,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"]
test = ["DifferentiableFlatten", "NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"]
4 changes: 2 additions & 2 deletions test/forwarddiff_frule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# The @constructor macro takes the type (first) and constructor function (second)
# The constructor function takes input the fields generated from ntfromstruct (as multiple positional arguments)
# The ntfromstruct function can be overloaded for your type
NonconvexCore.@constructor MyStruct MyStruct
DifferentiableFlatten.@constructor MyStruct MyStruct

f2(x::MyStruct, y::MyStruct) = MyStruct(x.a + y.a, x.b + y.b)
function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f2), x1::MyStruct, x2::MyStruct)
Expand All @@ -53,7 +53,7 @@
end

# I recommend creating your own type to avoid piracy
NonconvexCore.@constructor Symmetric Symmetric
DifferentiableFlatten.@constructor Symmetric Symmetric
import NamedTupleTools: ntfromstruct, structfromnt
ntfromstruct(a::Symmetric) = (data = a.data,)
structfromnt(::Type{Symmetric}, x::NamedTuple) = Symmetric(x.data, :U)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using NonconvexUtils, ForwardDiff, ReverseDiff, Tracker, Zygote
using Test, LinearAlgebra, SparseArrays, NLsolve, IterativeSolvers
using StableRNGs, ChainRulesCore, NonconvexCore, NonconvexIpopt
using DifferentiableFlatten

include("forwarddiff_frule.jl")
include("abstractdiff.jl")
Expand Down

0 comments on commit 1c54d40

Please sign in to comment.