diff --git a/Project.toml b/Project.toml index 3269975..993092f 100644 --- a/Project.toml +++ b/Project.toml @@ -26,11 +26,12 @@ LinearMaps = "3" MacroTools = "0.5" NonconvexCore = "1.1" SparseDiffTools = "1.24" -Symbolics = "4.6" +Symbolics = "5" Zygote = "0.5, 0.6" julia = "1" [extras] +DifferentiableFlatten = "c78775a3-ee38-4681-b694-0504db4f5dc7" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" NonconvexIpopt = "bf347577-a06d-49ad-a669-8c0e005493b8" @@ -41,4 +42,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"] +test = ["DifferentiableFlatten", "NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"] diff --git a/test/forwarddiff_frule.jl b/test/forwarddiff_frule.jl index c6361e1..5b185b6 100644 --- a/test/forwarddiff_frule.jl +++ b/test/forwarddiff_frule.jl @@ -33,7 +33,7 @@ # The @constructor macro takes the type (first) and constructor function (second) # The constructor function takes input the fields generated from ntfromstruct (as multiple positional arguments) # The ntfromstruct function can be overloaded for your type - NonconvexCore.@constructor MyStruct MyStruct + DifferentiableFlatten.@constructor MyStruct MyStruct f2(x::MyStruct, y::MyStruct) = MyStruct(x.a + y.a, x.b + y.b) function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f2), x1::MyStruct, x2::MyStruct) @@ -53,7 +53,7 @@ end # I recommend creating your own type to avoid piracy - NonconvexCore.@constructor Symmetric Symmetric + DifferentiableFlatten.@constructor Symmetric Symmetric import NamedTupleTools: ntfromstruct, structfromnt ntfromstruct(a::Symmetric) = (data = a.data,) structfromnt(::Type{Symmetric}, x::NamedTuple) = Symmetric(x.data, :U) diff --git a/test/runtests.jl b/test/runtests.jl index 660ba39..e6f9d12 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using NonconvexUtils, ForwardDiff, ReverseDiff, Tracker, Zygote using Test, LinearAlgebra, SparseArrays, NLsolve, IterativeSolvers using StableRNGs, ChainRulesCore, NonconvexCore, NonconvexIpopt +using DifferentiableFlatten include("forwarddiff_frule.jl") include("abstractdiff.jl")