Skip to content

Commit

Permalink
Add json resolvers for sympy relational operators (#4767)
Browse files Browse the repository at this point in the history
  • Loading branch information
daxfohl authored Dec 20, 2021
1 parent 1d8d2f3 commit e7892c3
Show file tree
Hide file tree
Showing 15 changed files with 130 additions and 9 deletions.
34 changes: 26 additions & 8 deletions cirq-core/cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,33 @@ def proper_repr(value: Any) -> str:
"""Overrides sympy and numpy returning repr strings that don't parse."""

if isinstance(value, sympy.Basic):
result = sympy.srepr(value)

# HACK: work around https://github.com/sympy/sympy/issues/16074
# (only handles a few cases)
fixed_tokens = ['Symbol', 'pi', 'Mul', 'Pow', 'Add', 'Mod', 'Integer', 'Float', 'Rational']
for token in fixed_tokens:
result = result.replace(token, 'sympy.' + token)

return result
fixed_tokens = [
'Symbol',
'pi',
'Mul',
'Pow',
'Add',
'Mod',
'Integer',
'Float',
'Rational',
'GreaterThan',
'StrictGreaterThan',
'LessThan',
'StrictLessThan',
'Equality',
'Unequality',
]

class Printer(sympy.printing.repr.ReprPrinter):
def _print(self, expr, **kwargs):
s = super()._print(expr, **kwargs)
if any(s.startswith(t) for t in fixed_tokens):
return 'sympy.' + s
return s

return Printer().doprint(value)

if isinstance(value, np.ndarray):
if np.issubdtype(value.dtype, np.datetime64):
Expand Down
6 changes: 6 additions & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def _parallel_gate_op(gate, qubits):
'sympy.Add': lambda args: sympy.Add(*args),
'sympy.Mul': lambda args: sympy.Mul(*args),
'sympy.Pow': lambda args: sympy.Pow(*args),
'sympy.GreaterThan': lambda args: sympy.GreaterThan(*args),
'sympy.StrictGreaterThan': lambda args: sympy.StrictGreaterThan(*args),
'sympy.LessThan': lambda args: sympy.LessThan(*args),
'sympy.StrictLessThan': lambda args: sympy.StrictLessThan(*args),
'sympy.Equality': lambda args: sympy.Equality(*args),
'sympy.Unequality': lambda args: sympy.Unequality(*args),
'sympy.Float': lambda approx: sympy.Float(approx),
'sympy.Integer': sympy.Integer,
'sympy.Rational': sympy.Rational,
Expand Down
15 changes: 14 additions & 1 deletion cirq-core/cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,20 @@ def default(self, o):
if isinstance(o, sympy.Symbol):
return {'cirq_type': 'sympy.Symbol', 'name': o.name}

if isinstance(o, (sympy.Add, sympy.Mul, sympy.Pow)):
if isinstance(
o,
(
sympy.Add,
sympy.Mul,
sympy.Pow,
sympy.GreaterThan,
sympy.StrictGreaterThan,
sympy.LessThan,
sympy.StrictLessThan,
sympy.Equality,
sympy.Unequality,
),
):
return {'cirq_type': f'sympy.{o.__class__.__name__}', 'args': o.args}

if isinstance(o, sympy.Integer):
Expand Down
13 changes: 13 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/sympy.Equality.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"cirq_type": "sympy.Equality",
"args": [
{
"cirq_type": "sympy.Symbol",
"name": "s"
},
{
"cirq_type": "sympy.Symbol",
"name": "t"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sympy.Equality(sympy.Symbol('s'), sympy.Symbol('t'))
13 changes: 13 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/sympy.GreaterThan.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"cirq_type": "sympy.GreaterThan",
"args": [
{
"cirq_type": "sympy.Symbol",
"name": "s"
},
{
"cirq_type": "sympy.Symbol",
"name": "t"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sympy.GreaterThan(sympy.Symbol('s'), sympy.Symbol('t'))
13 changes: 13 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/sympy.LessThan.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"cirq_type": "sympy.LessThan",
"args": [
{
"cirq_type": "sympy.Symbol",
"name": "s"
},
{
"cirq_type": "sympy.Symbol",
"name": "t"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sympy.LessThan(sympy.Symbol('s'), sympy.Symbol('t'))
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"cirq_type": "sympy.StrictGreaterThan",
"args": [
{
"cirq_type": "sympy.Symbol",
"name": "s"
},
{
"cirq_type": "sympy.Symbol",
"name": "t"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sympy.StrictGreaterThan(sympy.Symbol('s'), sympy.Symbol('t'))
13 changes: 13 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/sympy.StrictLessThan.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"cirq_type": "sympy.StrictLessThan",
"args": [
{
"cirq_type": "sympy.Symbol",
"name": "s"
},
{
"cirq_type": "sympy.Symbol",
"name": "t"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sympy.StrictLessThan(sympy.Symbol('s'), sympy.Symbol('t'))
13 changes: 13 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/sympy.Unequality.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"cirq_type": "sympy.Unequality",
"args": [
{
"cirq_type": "sympy.Symbol",
"name": "s"
},
{
"cirq_type": "sympy.Symbol",
"name": "t"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sympy.Unequality(sympy.Symbol('s'), sympy.Symbol('t'))

0 comments on commit e7892c3

Please sign in to comment.