Skip to content

Commit

Permalink
chore: bump Symbolics to 5 (#13)
Browse files Browse the repository at this point in the history
* chore: bump Symbolics to 5

* fix(test): use `DifferentiableFlatten.@constructor`
  • Loading branch information
ven-k authored Jun 27, 2023
1 parent 7be2c34 commit 491ebba
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ LinearMaps = "3"
MacroTools = "0.5"
NonconvexCore = "1.1"
SparseDiffTools = "1.24"
Symbolics = "4.6"
Symbolics = "5"
Zygote = "0.5, 0.6"
julia = "1"

[extras]
DifferentiableFlatten = "c78775a3-ee38-4681-b694-0504db4f5dc7"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
NonconvexIpopt = "bf347577-a06d-49ad-a669-8c0e005493b8"
Expand All @@ -41,4 +42,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"]
test = ["DifferentiableFlatten", "NamedTupleTools", "NonconvexIpopt", "NLsolve", "ReverseDiff", "SparseArrays", "StableRNGs", "Test", "Tracker"]
4 changes: 2 additions & 2 deletions test/forwarddiff_frule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# The @constructor macro takes the type (first) and constructor function (second)
# The constructor function takes input the fields generated from ntfromstruct (as multiple positional arguments)
# The ntfromstruct function can be overloaded for your type
NonconvexCore.@constructor MyStruct MyStruct
DifferentiableFlatten.@constructor MyStruct MyStruct

f2(x::MyStruct, y::MyStruct) = MyStruct(x.a + y.a, x.b + y.b)
function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f2), x1::MyStruct, x2::MyStruct)
Expand All @@ -53,7 +53,7 @@
end

# I recommend creating your own type to avoid piracy
NonconvexCore.@constructor Symmetric Symmetric
DifferentiableFlatten.@constructor Symmetric Symmetric
import NamedTupleTools: ntfromstruct, structfromnt
ntfromstruct(a::Symmetric) = (data = a.data,)
structfromnt(::Type{Symmetric}, x::NamedTuple) = Symmetric(x.data, :U)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using NonconvexUtils, ForwardDiff, ReverseDiff, Tracker, Zygote
using Test, LinearAlgebra, SparseArrays, NLsolve, IterativeSolvers
using StableRNGs, ChainRulesCore, NonconvexCore, NonconvexIpopt
using DifferentiableFlatten

include("forwarddiff_frule.jl")
include("abstractdiff.jl")
Expand Down

0 comments on commit 491ebba

Please sign in to comment.