Skip to content

Commit

Permalink
add support for param update in gate param expr
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Sep 18, 2024
1 parent 3ae9480 commit bdc221c
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 9 deletions.
65 changes: 57 additions & 8 deletions qbraid_qir/qasm3/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import numpy as np
from openqasm3.ast import (
BinaryExpression,
BooleanLiteral,
DiscreteSet,
Expression,
FloatLiteral,
Identifier,
IndexedIdentifier,
IndexExpression,
Expand Down Expand Up @@ -142,20 +143,68 @@ def transform_gate_qubits(
gate_op.qubits[i] = qubit_map[gate_qubit_name]

@staticmethod
def transform_gate_params(gate_op: QuantumGate, param_map: dict[str, Expression]) -> None:
def transform_expression(expression, variable_map: dict[str, Union[int, float, bool]]):
"""Transform an expression by replacing variables with their values.
Args:
expression (Any): The expression to transform.
variable_map (dict): The mapping of variables to their values.
Returns:
expression (Any): The transformed expression.
"""
if expression is None:
return None

if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral)):
return expression

if isinstance(expression, BinaryExpression):
lhs = Qasm3Transformer.transform_expression(expression.lhs, variable_map)
rhs = Qasm3Transformer.transform_expression(expression.rhs, variable_map)
expression.lhs = lhs
expression.rhs = rhs

if isinstance(expression, UnaryExpression):
operand = Qasm3Transformer.transform_expression(expression.expression, variable_map)
expression.expression = operand

if isinstance(expression, Identifier):
if expression.name in variable_map:
value = variable_map[expression.name]
if isinstance(value, int):
return IntegerLiteral(value)
if isinstance(value, float):
return FloatLiteral(value)
if isinstance(value, bool):
return BooleanLiteral(value)

return expression

@staticmethod
def transform_gate_params(
gate_op: QuantumGate, param_map: dict[str, Union[int, float, bool]]
) -> None:
"""Transform the parameters of a gate operation with a parameter map.
Args:
gate_op (QuantumGate): The gate operation to transform.
param_map (dict[str, Expression]): The parameter map to use for transformation.
param_map (dict[str, Union[int, float, bool]]): The parameter map to use
for transformation.
Returns:
None
None: arguments are transformed in place
"""
for i, param in enumerate(gate_op.arguments):
if isinstance(param, Identifier):
gate_op.arguments[i] = param_map[param.name]
# TODO : update the arg value in expressions not just SINGLE identifiers
# gate_op.arguments is a list of "actual" arguments used in the gate call inside body

# param map is a "global dict for this gate" which contains the binding of the params
# to the actual values used in the call
for i, actual_arg in enumerate(gate_op.arguments):
# recursively replace ALL instances of the parameter in the expression
# with the actual value
print("Before transformation: ", actual_arg)
gate_op.arguments[i] = Qasm3Transformer.transform_expression(actual_arg, param_map)
print("After transformation: ", gate_op.arguments[i])

@staticmethod
def get_branch_params(condition: Any) -> tuple[int, str]:
Expand Down
2 changes: 1 addition & 1 deletion qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def _visit_custom_gate_operation(
for formal_arg, actual_arg in zip(gate_definition.qubits, op_qubits)
}
param_map = {
formal_arg.name: actual_arg
formal_arg.name: Qasm3ExprEvaluator.evaluate_expression(actual_arg)
for formal_arg, actual_arg in zip(gate_definition.arguments, operation.arguments)
}

Expand Down
31 changes: 31 additions & 0 deletions tests/qasm3_qir/converter/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,37 @@ def test_three_qubit_qasm3_gates(circuit_name, request):
check_three_qubit_gate_op(generated_qir, 2, qubit_list, gate_name)


def test_gate_body_param_expression():
qasm3_str = """
OPENQASM 3;
include "stdgates.inc";
gate my_gate_2(p) q {
ry(p * 2) q;
}
gate my_gate(a, b, c) q {
rx(5 * a) q;
rz(2 * b / a) q;
my_gate_2(a) q;
rx(!a) q; // not a = False
rx(c) q;
}
qubit q;
int[32] m = 3;
float[32] n = 6.0;
bool o = true;
my_gate(m, n, o) q;
"""
result = qasm3_to_qir(qasm3_str)
generated_qir = str(result).splitlines()
check_attributes(generated_qir, 1, 0)
check_single_qubit_rotation_op(generated_qir, 3, [0, 0, 0], [5 * 3, 0.0, True], "rx")
check_single_qubit_rotation_op(generated_qir, 1, [0], [2 * 6.0 / 3], "rz")
check_single_qubit_rotation_op(generated_qir, 1, [0], [3 * 2], "ry")


def test_id_gate():
qasm3_string = """
OPENQASM 3;
Expand Down

0 comments on commit bdc221c

Please sign in to comment.