diff --git a/src/Initialization/init_SymEngine.jl b/src/Initialization/init_SymEngine.jl index 3485773471..328039d01a 100644 --- a/src/Initialization/init_SymEngine.jl +++ b/src/Initialization/init_SymEngine.jl @@ -69,6 +69,21 @@ julia> free_symbols(:(x1 + x2 <= 2*x4 + 6), HalfSpace) """ function free_symbols(::Expr, ::Type{<:LazySet}) end # COV_EXCL_LINE +# parse `a` and `b` from `a1 x1 + ... + an xn + K [cmp] 0` for [cmp] in {<, <=, =, >, >=} +function _parse_linear_expression(linexpr::Basic, vars::Vector{Basic}, N) + if isempty(vars) + vars = SymEngine.free_symbols(linexpr) + end + b = SymEngine.subs(linexpr, [vi => zero(N) for vi in vars]...) + a = convert(Basic, linexpr - b) + + # convert to correct numeric type + a = convert(Vector{N}, diff.(a, vars)) + b = convert(N, b) + + return a, b +end + # Note: this convenience function is not used anywhere function _free_symbols(expr::Expr) if _is_hyperplane(expr) diff --git a/src/Interfaces/AbstractPolyhedron.jl b/src/Interfaces/AbstractPolyhedron.jl index 1f1100f9f2..3e75a73137 100644 --- a/src/Interfaces/AbstractPolyhedron.jl +++ b/src/Interfaces/AbstractPolyhedron.jl @@ -75,11 +75,10 @@ function _linear_map_hrep_helper(M::AbstractMatrix, P::LazySet, return HPolyhedron(constraints) end -# internal function; defined here due to dependency SymEngine and submodules -function _is_halfspace() end - -# internal function; defined here due to dependency SymEngine and submodules -function _is_hyperplane() end +# internal functions; defined here due to dependency SymEngine and submodules +function _is_halfspace end +function _is_hyperplane end +function _parse_linear_expression end # To account for the compilation order, other functions are defined in the file # AbstractPolyhedron_functions.jl diff --git a/src/Sets/HalfSpace/convert.jl b/src/Sets/HalfSpace/convert.jl index 7c5ce7ad71..92d353f0bc 100644 --- a/src/Sets/HalfSpace/convert.jl +++ b/src/Sets/HalfSpace/convert.jl @@ -7,6 +7,7 @@ end function load_SymEngine_convert_HalfSpace() return quote using .SymEngine: Basic + using ..LazySets: _parse_linear_expression """ convert(::Type{HalfSpace{N}}, expr::Expr; vars=nothing) where {N} @@ -53,31 +54,18 @@ function load_SymEngine_convert_HalfSpace() ``` """ function convert(::Type{HalfSpace{N}}, expr::Expr; vars::Vector{Basic}=Basic[]) where {N} - @assert _is_halfspace(expr) "the expression :(expr) does not correspond to a half-space" - - # check sense of the inequality, assuming < or <= by default - got_geq = expr.args[1] in [:(>=), :(>)] - - # get sides of the inequality - lhs, rhs = convert(Basic, expr.args[2]), convert(Basic, expr.args[3]) - - # a1 x1 + ... + an xn + K [cmp] 0 for cmp in <, <=, >, >= - eq = lhs - rhs - if isempty(vars) - vars = SymEngine.free_symbols(eq) - end - K = SymEngine.subs(eq, [vi => zero(N) for vi in vars]...) - a = convert(Basic, eq - K) - - # convert to numeric types - K = convert(N, K) - a = convert(Vector{N}, diff.(a, vars)) - - if got_geq - return HalfSpace(-a, K) - else - return HalfSpace(a, -K) - end + @assert _is_halfspace(expr) "the expression $expr does not correspond to a half-space" + + # convert to SymEngine expressions + linexpr, cmp = _parse_halfspace(expr) + + # check sense of the inequality, assuming < or <= by default (checked before) + got_geq = cmp ∈ (:(>=), :(>)) + + # `a1 x1 + ... + an xn + b [cmp] 0` for [cmp] ∈ {<, <=, >, >=} + a, b = _parse_linear_expression(linexpr, vars, N) + + return got_geq ? HalfSpace(-a, b) : HalfSpace(a, -b) end # type-less default half-space conversion diff --git a/src/Sets/HalfSpace/init_SymEngine.jl b/src/Sets/HalfSpace/init_SymEngine.jl index 30402768f6..d4b84c73ac 100644 --- a/src/Sets/HalfSpace/init_SymEngine.jl +++ b/src/Sets/HalfSpace/init_SymEngine.jl @@ -1,9 +1,15 @@ import .SymEngine: free_symbols +function _parse_halfspace(expr::Expr) + lhs = convert(SymEngine.Basic, expr.args[2]) + rhs = convert(SymEngine.Basic, expr.args[3]) + cmp = expr.args[1] + return (lhs - rhs, cmp) +end + function free_symbols(expr::Expr, ::Type{<:HalfSpace}) - # get sides of the inequality - lhs, rhs = convert(SymEngine.Basic, expr.args[2]), convert(SymEngine.Basic, expr.args[3]) - return SymEngine.free_symbols(lhs - rhs) + linexpr, _ = _parse_halfspace(expr) + return SymEngine.free_symbols(linexpr) end eval(load_SymEngine_ishalfspace()) diff --git a/src/Sets/HalfSpace/ishalfspace.jl b/src/Sets/HalfSpace/ishalfspace.jl index 34502a0592..25ad2ddc8c 100644 --- a/src/Sets/HalfSpace/ishalfspace.jl +++ b/src/Sets/HalfSpace/ishalfspace.jl @@ -1,7 +1,5 @@ function load_SymEngine_ishalfspace() return quote - using .SymEngine: Basic - import .SymEngine: free_symbols using ..LazySets: _is_linearcombination """ @@ -45,23 +43,22 @@ function load_SymEngine_ishalfspace() ``` """ function _is_halfspace(expr::Expr)::Bool - - # check that there are three arguments - # these are the comparison symbol, the left hand side and the right hand side + # check that there are three arguments: + # the comparison symbol, the left-hand side and the right-hand side if (length(expr.args) != 3) || !(expr.head == :call) return false end + # convert to SymEngine expression + linexpr, cmp = _parse_halfspace(expr) + # check that this is an inequality - if !(expr.args[1] in [:(<=), :(<), :(>=), :(>)]) + if cmp ∉ [:(<=), :(<), :(>=), :(>)] return false end - # convert to symengine expressions - lhs, rhs = convert(Basic, expr.args[2]), convert(Basic, expr.args[3]) - # check if the expression defines a half-space - return _is_linearcombination(lhs) && _is_linearcombination(rhs) + return _is_linearcombination(linexpr) end end end # load_SymEngine_ishalfspace diff --git a/src/Sets/Hyperplane/convert.jl b/src/Sets/Hyperplane/convert.jl index b81658db09..c954a136fb 100644 --- a/src/Sets/Hyperplane/convert.jl +++ b/src/Sets/Hyperplane/convert.jl @@ -1,6 +1,7 @@ function load_SymEngine_convert_Hyperplane() return quote using .SymEngine: Basic + using ..LazySets: _parse_linear_expression """ convert(::Type{Hyperplane{N}}, expr::Expr; vars=nothing) where {N} @@ -41,28 +42,15 @@ function load_SymEngine_convert_Hyperplane() ``` """ function convert(::Type{Hyperplane{N}}, expr::Expr; vars::Vector{Basic}=Basic[]) where {N} - @assert _is_hyperplane(expr) "the expression :(expr) does not correspond to a Hyperplane" + @assert _is_hyperplane(expr) "the expression $expr does not correspond to a Hyperplane" - # get sides of the inequality - lhs = convert(Basic, expr.args[1]) + # convert to SymEngine expression + linexpr = _parse_hyperplane(expr) - # treats the 4 in :(2*x1 = 4) - rhs = :args in fieldnames(typeof(expr.args[2])) ? convert(Basic, expr.args[2].args[2]) : - convert(Basic, expr.args[2]) + # a1 x1 + ... + an xn + b = 0 + a, b = _parse_linear_expression(linexpr, vars, N) - # a1 x1 + ... + an xn + K = 0 - eq = lhs - rhs - if isempty(vars) - vars = SymEngine.free_symbols(eq) - end - K = SymEngine.subs(eq, [vi => zero(N) for vi in vars]...) - a = convert(Basic, eq - K) - - # convert to numeric types - K = convert(N, K) - a = convert(Vector{N}, diff.(a, vars)) - - return Hyperplane(a, -K) + return Hyperplane(a, -b) end # type-less default Hyperplane conversion diff --git a/src/Sets/Hyperplane/init_SymEngine.jl b/src/Sets/Hyperplane/init_SymEngine.jl index c0926b6506..61d9d0aa59 100644 --- a/src/Sets/Hyperplane/init_SymEngine.jl +++ b/src/Sets/Hyperplane/init_SymEngine.jl @@ -1,14 +1,16 @@ import .SymEngine: free_symbols -function free_symbols(expr::Expr, ::Type{<:Hyperplane}) - # get sides of the equality +function _parse_hyperplane(expr::Expr) lhs = convert(SymEngine.Basic, expr.args[1]) - - # treats the 4 in :(2*x1 = 4) rhs = :args in fieldnames(typeof(expr.args[2])) ? convert(SymEngine.Basic, expr.args[2].args[2]) : convert(SymEngine.Basic, expr.args[2]) - return SymEngine.free_symbols(lhs - rhs) + return lhs - rhs +end + +function free_symbols(expr::Expr, ::Type{<:Hyperplane}) + linexpr = _parse_hyperplane(expr) + return SymEngine.free_symbols(linexpr) end eval(load_SymEngine_ishyperplanar()) diff --git a/src/Sets/Hyperplane/is_hyperplanar.jl b/src/Sets/Hyperplane/is_hyperplanar.jl index 758474fa1b..ab8577c1b8 100644 --- a/src/Sets/Hyperplane/is_hyperplanar.jl +++ b/src/Sets/Hyperplane/is_hyperplanar.jl @@ -4,7 +4,6 @@ end function load_SymEngine_ishyperplanar() return quote - using .SymEngine: Basic using ..LazySets: _is_linearcombination """ @@ -48,25 +47,17 @@ function load_SymEngine_ishyperplanar() ``` """ function _is_hyperplane(expr::Expr)::Bool - - # check that there are three arguments - # these are the comparison symbol, the left hand side and the right hand side + # check that the head is `=` and there are two arguments: + # the left-hand side and the right-hand side if (length(expr.args) != 2) || !(expr.head == :(=)) return false end - # convert to symengine expressions - lhs = convert(Basic, expr.args[1]) - - if :args in fieldnames(typeof(expr.args[2])) - # treats the 4 in :(2*x1 = 4) - rhs = convert(Basic, expr.args[2].args[2]) - else - rhs = convert(Basic, expr.args[2]) - end + # convert to SymEngine expression + linexpr = _parse_hyperplane(expr) # check if the expression defines a hyperplane - return _is_linearcombination(lhs) && _is_linearcombination(rhs) + return _is_linearcombination(linexpr) end end end # load_SymEngine_ishyperplanar