Skip to content

Commit

Permalink
add expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Mar 20, 2024
1 parent fa83620 commit d6ee260
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 26 deletions.
35 changes: 21 additions & 14 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,12 @@ def _get_op_parameters(self, operation: QuantumGate) -> List[float]:
"""
param_list = []
for param in operation.arguments:
if not isinstance(param, (FloatLiteral, IntegerLiteral)):
raise ValueError(
f"Unsupported parameter type {type(param)} for operation {operation}"
)
param_list.append(float(param.value))
# if not isinstance(param, (FloatLiteral, IntegerLiteral)):
# raise ValueError(
# f"Unsupported parameter type {type(param)} for operation {operation}"
# )
param_value = self._evaluate_expression(param)
param_list.append(param_value)

if len(param_list) > 1:
raise ValueError(f"Parameterized gate {operation} with > 1 params not supported")
Expand Down Expand Up @@ -588,7 +589,7 @@ def _visit_classical_operation(self, statement: ClassicalDeclaration) -> None:
else:
raise ValueError(f"Unsupported classical type {decl_type} in {statement}")

def evaluate_expression(self, expression: Any) -> bool:
def _evaluate_expression(self, expression: Any) -> bool:
"""Evaluate an expression.
Args:
Expand All @@ -598,26 +599,32 @@ def evaluate_expression(self, expression: Any) -> bool:
bool: The result of the evaluation.
"""
if isinstance(expression, (ImaginaryLiteral, DurationLiteral)):
raise ValueError(f"Unsupported expression type {type(expression)} in if condition")
raise ValueError(f"Unsupported expression type {type(expression)}")
elif isinstance(expression, Identifier):
# we need to check our scope and context to get the value of the identifier
# if it is a classical register, we can directly get the value
# how to get the value of the identifier in the QIR??
# TO DO : extend this
raise ValueError(f"Unsupported expression type {type(expression)}")
elif isinstance(expression, BooleanLiteral):
return expression.value
elif isinstance(expression, (IntegerLiteral, FloatLiteral)):
return int(expression.value)
return expression.value
elif isinstance(expression, UnaryExpression):
op = expression.op.name # can be '!', '~' or '-'
if op == "!":
return not self.evaluate_expression(expression.expression)
return not self._evaluate_expression(expression.expression)
elif op == "-":
return -1 * self.evaluate_expression(expression.expression)
return -1 * self._evaluate_expression(expression.expression)
elif op == "~":
value = self.evaluate_expression(expression.expression)
value = self._evaluate_expression(expression.expression)
if not isinstance(value, int):
raise ValueError(f"Unsupported expression type {type(value)} in ~ operation")
elif isinstance(expression, BinaryExpression):
lhs = self.evaluate_expression(expression.lhs)
lhs = self._evaluate_expression(expression.lhs)
op = expression.op.name
rhs = self.evaluate_expression(expression.rhs)
return qasm3_expression_op_map[op](lhs, rhs)
rhs = self._evaluate_expression(expression.rhs)
return qasm3_expression_op_map(op, lhs, rhs)

def _visit_branching_statement(self, statement: BranchingStatement) -> None:
"""Visit a branching statement element.
Expand Down
49 changes: 49 additions & 0 deletions tests/qasm3_qir/converter/test_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2023 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module containing unit tests for QASM3 to QIR conversion functions.
"""
import pytest

from qbraid_qir.qasm3.convert import qasm3_to_qir
from tests.qir_utils import check_attributes, check_expressions


def test_correct_expressions():
qasm_str = """OPENQASM 3;
qubit q;
// supported
rx(1.57) q;
rz(3-2*3) q;
rz(3-2*3*(8/2)) q;
rx(-1.57) q;
rx(4%2) q;
"""

result = qasm3_to_qir(qasm_str)
generated_qir = str(result).splitlines()

check_attributes(generated_qir, 1, 0)
gates = ["rx", "rz", "rz", "rx", "rx"]
expression_values = [1.57, 3 - 2 * 3, 3 - 2 * 3 * (8 / 2), -1.57, 4 % 2]
qubits = [0, 0, 0, 0, 0]
check_expressions(generated_qir, 5, gates, expression_values, qubits)


def test_incorrect_expressions():
with pytest.raises(ValueError, match=r"Unsupported expression type .*"):
qasm3_to_qir("OPENQASM 3; qubit q; rz(1 - 2 + 32im) q;")
with pytest.raises(ValueError, match=r"Unsupported expression type .* in ~ operation"):
qasm3_to_qir("OPENQASM 3; qubit q; rx(~1.3) q;")
with pytest.raises(ValueError, match=r"Unsupported expression type .* in ~ operation"):
qasm3_to_qir("OPENQASM 3; qubit q; rx(~1.3+5im) q;")
2 changes: 1 addition & 1 deletion tests/qasm3_qir/converter/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_incorrect_single_qubit_gates():

# Invalid use of variables in gate application

with pytest.raises(ValueError, match=r"Unsupported parameter type .* for operation .*"):
with pytest.raises(ValueError, match=r"Unsupported expression type .*"):
_ = qasm3_to_qir(
"""
OPENQASM 3;
Expand Down
3 changes: 1 addition & 2 deletions tests/qasm3_qir/fixtures/resources/custom_gate_complex.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@ gate custom3(p, q) c, d {

qubit[2] q;

custom3(0.1, 0.2) q[0], q[1];
custom3(0.3, 0.4) q[0:];
custom3(0.1, 0.2) q[0:];
4 changes: 2 additions & 2 deletions tests/qasm3_qir/fixtures/resources/custom_gate_nested.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ gate custom(a,b,c) p, q{
h p;
cx p,q;
rx(a) q;
ry(0.5) q;
ry(0.5/0.1) q;
}

qubit[2] q;
custom(0.1, 0.2, 0.3) q[0], q[1];
custom(2 + 3 - 1/5, 0.1, 0.3) q[0], q[1];
2 changes: 1 addition & 1 deletion tests/qasm3_qir/fixtures/resources/custom_gate_simple.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ gate custom(a) p, q {
}

qubit[2] q;
custom(0.1) q[0], q[1];
custom(0.1+1) q[0], q[1];
48 changes: 42 additions & 6 deletions tests/qir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _validate_simple_custom_op(entry_body: List[str]):
initialize_call_string(),
single_op_call_string("h", 0),
single_op_call_string("z", 1),
rotation_call_string("rx", 0.1, 0),
rotation_call_string("rx", 1.1, 0),
double_op_call_string("cnot", 0, 1),
result_record_output_string(0),
result_record_output_string(1),
Expand All @@ -322,11 +322,11 @@ def _validate_nested_custom_op(entry_body: List[str]):
nested_op_lines = [
initialize_call_string(),
single_op_call_string("h", 1),
rotation_call_string("rz", 0.1, 1),
rotation_call_string("rz", 4.8, 1),
single_op_call_string("h", 0),
double_op_call_string("cnot", 0, 1),
rotation_call_string("rx", 0.1, 1),
rotation_call_string("ry", 0.5, 1),
rotation_call_string("rx", 4.8, 1),
rotation_call_string("ry", 5, 1),
result_record_output_string(0),
result_record_output_string(1),
return_string(),
Expand All @@ -338,8 +338,22 @@ def _validate_nested_custom_op(entry_body: List[str]):


def _validate_complex_custom_op(entry_body: List[str]):
pass
# todo...
complex_op_lines = [
initialize_call_string(),
single_op_call_string("h", 0),
single_op_call_string("x", 0),
rotation_call_string("rx", 0.5, 0),
rotation_call_string("ry", 0.1, 0),
rotation_call_string("rz", 0.2, 0),
double_op_call_string("cnot", 0, 1),
result_record_output_string(0),
result_record_output_string(1),
return_string(),
]

assert len(entry_body) == len(complex_op_lines), "Incorrect number of lines in complex op"
for i in range(len(entry_body)):
assert entry_body[i].strip() == complex_op_lines[i].strip(), "Incorrect complex op line"


def check_custom_qasm_gate_op(qir: List[str], test_type: str):
Expand All @@ -353,3 +367,25 @@ def check_custom_qasm_gate_op(qir: List[str], test_type: str):
_validate_complex_custom_op(entry_body)
else:
assert False, f"Unknown test type {test_type} for custom ops"


def check_expressions(
qir: List[str], expected_ops: int, gates: List[str], expression_values, qubits: List[int]
):
entry_body = get_entry_point_body(qir)
op_count = 0
q_id = 0

for line in entry_body:
if line.strip().startswith("call") and "qis__" in line:
assert line.strip() == rotation_call_string(
gates[q_id], expression_values[q_id], qubits[q_id]
), f"Incorrect rotation gate call in qir - {line}"
op_count += 1
q_id += 1

if op_count == expected_ops:
break

if op_count != expected_ops:
assert False, f"Incorrect rotation gate count: {expected_ops} expected, {op_count} actual"

0 comments on commit d6ee260

Please sign in to comment.