From 491ebbaa6968e10f5a7d74ccb000b19fd9695f6c Mon Sep 17 00:00:00 2001 From: Venkateshprasad <32921645+ven-k@users.noreply.github.com> Date: Wed, 28 Jun 2023 04:46:00 +0530 Subject: [PATCH] chore: bump Symbolics to 5 (#13) * chore: bump Symbolics to 5 * fix(test): use `DifferentiableFlatten.@constructor` --- Project.toml | 5 +++-- test/forwarddiff_frule.jl | 4 ++-- test/runtests.jl | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 3269975..993092f 100644 --- a/Project.toml +++ b/Project.toml @@ -26,11 +26,12 @@ LinearMaps = "3" MacroTools = "0.5" NonconvexCore = "1.1" SparseDiffTools = "1.24" -Symbolics = "4.6" +Symbolics = "5" Zygote = "0.5, 0.6" julia = "1" [extras] +DifferentiableFlatten = "c78775a3-ee38-4681-b694-0504db4f5dc7" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" NonconvexIpopt = "bf347577-a06d-49ad-a669-8c0e005493b8" @@ -41,4 +42,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"] +test = ["DifferentiableFlatten", "NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"] diff --git a/test/forwarddiff_frule.jl b/test/forwarddiff_frule.jl index c6361e1..5b185b6 100644 --- a/test/forwarddiff_frule.jl +++ b/test/forwarddiff_frule.jl @@ -33,7 +33,7 @@ # The @constructor macro takes the type (first) and constructor function (second) # The constructor function takes input the fields generated from ntfromstruct (as multiple positional arguments) # The ntfromstruct function can be overloaded for your type - NonconvexCore.@constructor MyStruct MyStruct + DifferentiableFlatten.@constructor MyStruct MyStruct f2(x::MyStruct, y::MyStruct) = MyStruct(x.a + y.a, x.b + y.b) function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f2), x1::MyStruct, x2::MyStruct) @@ -53,7 +53,7 @@ end # I recommend creating your own type to avoid piracy - NonconvexCore.@constructor Symmetric Symmetric + DifferentiableFlatten.@constructor Symmetric Symmetric import NamedTupleTools: ntfromstruct, structfromnt ntfromstruct(a::Symmetric) = (data = a.data,) structfromnt(::Type{Symmetric}, x::NamedTuple) = Symmetric(x.data, :U) diff --git a/test/runtests.jl b/test/runtests.jl index 660ba39..e6f9d12 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using NonconvexUtils, ForwardDiff, ReverseDiff, Tracker, Zygote using Test, LinearAlgebra, SparseArrays, NLsolve, IterativeSolvers using StableRNGs, ChainRulesCore, NonconvexCore, NonconvexIpopt +using DifferentiableFlatten include("forwarddiff_frule.jl") include("abstractdiff.jl")