diff --git a/qbraid_qir/qasm3/convert.py b/qbraid_qir/qasm3/convert.py index d33bf5c..c05af29 100644 --- a/qbraid_qir/qasm3/convert.py +++ b/qbraid_qir/qasm3/convert.py @@ -50,13 +50,6 @@ def qasm3_to_qir( # Supported conversions qasm3 -> qiskit : # https://github.com/Qiskit/qiskit-qasm3-import/blob/main/src/qiskit_qasm3_import/converter.py - # PROPOSED SEMANTIC + DECOMPOSITION PASS - # qiskit_circuit = loads(program).decompose(reps=3) - # decomposed_qasm = Exporter().dumps(qiskit_circuit) - # PROPOSED SEMANTIC + DECOMPOSITION PASS - - # program = openqasm3.parse(decomposed_qasm) - program = openqasm3.parse(program) elif not isinstance(program, openqasm3.ast.Program): diff --git a/qbraid_qir/qasm3/preprocess.py b/qbraid_qir/qasm3/preprocess.py index 4adf30b..9bf5fdb 100644 --- a/qbraid_qir/qasm3/preprocess.py +++ b/qbraid_qir/qasm3/preprocess.py @@ -17,34 +17,3 @@ # can use rigetti simulator for qir programs and qiskit for qasm3 programs - - -# what are some problems which we face? - -# 1. let us just unfold the gates once. -# assumption - we do not define any variables inside the gate -# - gate parameters are directly passed on to other gate calls -# - no if else conditions inside the gate -# - no loops inside the gate -# - no function calls inside the gate - -# eg. gate x2(a,b,c) p, q { -# h p; -# h q; -# rx(a) p; -# ry(b) q; -# rz(c) p; -# cx p, q; -# } -# qubit[2] q; -# x2(1,2,3) q[0], q[1]; - -# will be converted to - -# qubit[2] q; -# h q[0]; -# h q[1]; -# rx(1) q[0]; -# ry(2) q[1]; -# rz(3) q[0]; -# cx q[0], q[1]; diff --git a/qbraid_qir/qasm3/visitor.py b/qbraid_qir/qasm3/visitor.py index 4a8e3a5..2769f5d 100644 --- a/qbraid_qir/qasm3/visitor.py +++ b/qbraid_qir/qasm3/visitor.py @@ -277,7 +277,6 @@ def _visit_measurement(self, statement: QuantumMeasurementStatement) -> None: target = statement.target source_id, target_id = None, None - # handle measurement operation source_name = source.name if isinstance(source, IndexedIdentifier): source_name = source.name.name @@ -318,7 +317,6 @@ def _build_qir_measurement( pyqir._native.mz(self._builder, source_qubit, result) if source_id is None and target_id is None: - # sizes should match if self._qreg_size_map[source_name] != self._creg_size_map[target_name]: raise ValueError( f"Register sizes of {source_name} and {target_name} do not match for measurement operation" @@ -365,7 +363,6 @@ def _visit_reset(self, statement: QuantumReset) -> None: qreg_size = self._qreg_size_map[qreg_name] qubit_ids = [self._qubit_labels[f"{qreg_name}_{i}"] for i in range(qreg_size)] - # generate pyqir reset equivalent for qid in qubit_ids: pyqir._native.reset(self._builder, pyqir.qubit(self._module.context, qid)) @@ -402,11 +399,8 @@ def _get_op_parameters(self, operation: QuantumGate) -> List[float]: """ param_list = [] for param in operation.arguments: - if not isinstance(param, (FloatLiteral, IntegerLiteral)): - raise ValueError( - f"Unsupported parameter type {type(param)} for operation {operation}" - ) - param_list.append(float(param.value)) + param_value = self._evaluate_expression(param) + param_list.append(param_value) if len(param_list) > 1: raise ValueError(f"Parameterized gate {operation} with > 1 params not supported") @@ -414,6 +408,14 @@ def _get_op_parameters(self, operation: QuantumGate) -> List[float]: return param_list def _visit_gate_definition(self, definition: QuantumGateDefinition) -> None: + """Visit a gate definition element. + + Args: + definition (QuantumGateDefinition): The gate definition to visit. + + Returns: + None + """ gate_name = definition.name.name if gate_name in self._custom_gates: raise ValueError(f"Duplicate gate definition for {gate_name}") @@ -429,9 +431,8 @@ def _visit_basic_gate_operation(self, operation: QuantumGate) -> None: None """ - # Currently handling the gates in the stdgates.inc file _log.debug("Visiting basic gate operation '%s'", str(operation)) - op_name = operation.name.name + op_name : str = operation.name.name op_qubits = self._get_op_qubits(operation) qir_func, op_qubit_count = map_qasm_op_to_pyqir_callable(op_name) op_parameters = None @@ -466,7 +467,6 @@ def _transform_gate_qubits(self, gate_op, qubit_map): for i, qubit in enumerate(gate_op.qubits): if isinstance(qubit, IndexedIdentifier): raise ValueError(f"Indexing {qubit} not supported in gate definition") - # now we have an Identifier gate_op.qubits[i] = qubit_map[qubit.name] def _transform_gate_params(self, gate_op, param_map): @@ -480,7 +480,6 @@ def _transform_gate_params(self, gate_op, param_map): None """ for i, param in enumerate(gate_op.arguments): - # replace only if we have an Identifier if isinstance(param, Identifier): gate_op.arguments[i] = param_map[param.name] @@ -493,8 +492,8 @@ def _visit_custom_gate_operation(self, operation: QuantumGate) -> None: Returns:in computation """ _log.debug("Visiting custom gate operation '%s'", str(operation)) - gate_name = operation.name.name - gate_definition = self._custom_gates[gate_name] + gate_name : str = operation.name.name + gate_definition : QuantumGateDefinition = self._custom_gates[gate_name] if len(operation.arguments) != len(gate_definition.arguments): raise ValueError( @@ -588,7 +587,7 @@ def _visit_classical_operation(self, statement: ClassicalDeclaration) -> None: else: raise ValueError(f"Unsupported classical type {decl_type} in {statement}") - def evaluate_expression(self, expression: Any) -> bool: + def _evaluate_expression(self, expression: Any) -> bool: """Evaluate an expression. Args: @@ -598,26 +597,32 @@ def evaluate_expression(self, expression: Any) -> bool: bool: The result of the evaluation. """ if isinstance(expression, (ImaginaryLiteral, DurationLiteral)): - raise ValueError(f"Unsupported expression type {type(expression)} in if condition") + raise ValueError(f"Unsupported expression type {type(expression)}") + elif isinstance(expression, Identifier): + # we need to check our scope and context to get the value of the identifier + # if it is a classical register, we can directly get the value + # how to get the value of the identifier in the QIR?? + # TO DO : extend this + raise ValueError(f"Unsupported expression type {type(expression)}") elif isinstance(expression, BooleanLiteral): return expression.value elif isinstance(expression, (IntegerLiteral, FloatLiteral)): - return int(expression.value) + return expression.value elif isinstance(expression, UnaryExpression): op = expression.op.name # can be '!', '~' or '-' if op == "!": - return not self.evaluate_expression(expression.expression) + return not self._evaluate_expression(expression.expression) elif op == "-": - return -1 * self.evaluate_expression(expression.expression) + return -1 * self._evaluate_expression(expression.expression) elif op == "~": - value = self.evaluate_expression(expression.expression) + value = self._evaluate_expression(expression.expression) if not isinstance(value, int): raise ValueError(f"Unsupported expression type {type(value)} in ~ operation") elif isinstance(expression, BinaryExpression): - lhs = self.evaluate_expression(expression.lhs) + lhs = self._evaluate_expression(expression.lhs) op = expression.op.name - rhs = self.evaluate_expression(expression.rhs) - return qasm3_expression_op_map[op](lhs, rhs) + rhs = self._evaluate_expression(expression.rhs) + return qasm3_expression_op_map(op, lhs, rhs) def _visit_branching_statement(self, statement: BranchingStatement) -> None: """Visit a branching statement element. diff --git a/tests/qasm3_qir/converter/test_expressions.py b/tests/qasm3_qir/converter/test_expressions.py new file mode 100644 index 0000000..20b1bb3 --- /dev/null +++ b/tests/qasm3_qir/converter/test_expressions.py @@ -0,0 +1,49 @@ +# Copyright (C) 2023 qBraid +# +# This file is part of the qBraid-SDK +# +# The qBraid-SDK is free software released under the GNU General Public License v3 +# or later. You can redistribute and/or modify it under the terms of the GPL v3. +# See the LICENSE file in the project root or . +# +# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3. + +""" +Module containing unit tests for QASM3 to QIR conversion functions. + +""" +import pytest + +from qbraid_qir.qasm3.convert import qasm3_to_qir +from tests.qir_utils import check_attributes, check_expressions + + +def test_correct_expressions(): + qasm_str = """OPENQASM 3; + qubit q; + + // supported + rx(1.57) q; + rz(3-2*3) q; + rz(3-2*3*(8/2)) q; + rx(-1.57) q; + rx(4%2) q; + """ + + result = qasm3_to_qir(qasm_str) + generated_qir = str(result).splitlines() + + check_attributes(generated_qir, 1, 0) + gates = ["rx", "rz", "rz", "rx", "rx"] + expression_values = [1.57, 3 - 2 * 3, 3 - 2 * 3 * (8 / 2), -1.57, 4 % 2] + qubits = [0, 0, 0, 0, 0] + check_expressions(generated_qir, 5, gates, expression_values, qubits) + + +def test_incorrect_expressions(): + with pytest.raises(ValueError, match=r"Unsupported expression type .*"): + qasm3_to_qir("OPENQASM 3; qubit q; rz(1 - 2 + 32im) q;") + with pytest.raises(ValueError, match=r"Unsupported expression type .* in ~ operation"): + qasm3_to_qir("OPENQASM 3; qubit q; rx(~1.3) q;") + with pytest.raises(ValueError, match=r"Unsupported expression type .* in ~ operation"): + qasm3_to_qir("OPENQASM 3; qubit q; rx(~1.3+5im) q;") diff --git a/tests/qasm3_qir/converter/test_gates.py b/tests/qasm3_qir/converter/test_gates.py index afc710b..d34253f 100644 --- a/tests/qasm3_qir/converter/test_gates.py +++ b/tests/qasm3_qir/converter/test_gates.py @@ -149,7 +149,7 @@ def test_incorrect_single_qubit_gates(): # Invalid use of variables in gate application - with pytest.raises(ValueError, match=r"Unsupported parameter type .* for operation .*"): + with pytest.raises(ValueError, match=r"Unsupported expression type .*"): _ = qasm3_to_qir( """ OPENQASM 3; diff --git a/tests/qasm3_qir/fixtures/resources/custom_gate_complex.qasm b/tests/qasm3_qir/fixtures/resources/custom_gate_complex.qasm index dd74d70..5652971 100644 --- a/tests/qasm3_qir/fixtures/resources/custom_gate_complex.qasm +++ b/tests/qasm3_qir/fixtures/resources/custom_gate_complex.qasm @@ -20,5 +20,4 @@ gate custom3(p, q) c, d { qubit[2] q; -custom3(0.1, 0.2) q[0], q[1]; -custom3(0.3, 0.4) q[0:]; \ No newline at end of file +custom3(0.1, 0.2) q[0:]; \ No newline at end of file diff --git a/tests/qasm3_qir/fixtures/resources/custom_gate_nested.qasm b/tests/qasm3_qir/fixtures/resources/custom_gate_nested.qasm index eda0dab..918c5d9 100644 --- a/tests/qasm3_qir/fixtures/resources/custom_gate_nested.qasm +++ b/tests/qasm3_qir/fixtures/resources/custom_gate_nested.qasm @@ -10,8 +10,8 @@ gate custom(a,b,c) p, q{ h p; cx p,q; rx(a) q; - ry(0.5) q; + ry(0.5/0.1) q; } qubit[2] q; -custom(0.1, 0.2, 0.3) q[0], q[1]; +custom(2 + 3 - 1/5, 0.1, 0.3) q[0], q[1]; diff --git a/tests/qasm3_qir/fixtures/resources/custom_gate_simple.qasm b/tests/qasm3_qir/fixtures/resources/custom_gate_simple.qasm index e84e3da..acd4a0c 100644 --- a/tests/qasm3_qir/fixtures/resources/custom_gate_simple.qasm +++ b/tests/qasm3_qir/fixtures/resources/custom_gate_simple.qasm @@ -9,4 +9,4 @@ gate custom(a) p, q { } qubit[2] q; -custom(0.1) q[0], q[1]; \ No newline at end of file +custom(0.1+1) q[0], q[1]; \ No newline at end of file diff --git a/tests/qir_utils.py b/tests/qir_utils.py index 050f1ae..d4d84a1 100644 --- a/tests/qir_utils.py +++ b/tests/qir_utils.py @@ -306,7 +306,7 @@ def _validate_simple_custom_op(entry_body: List[str]): initialize_call_string(), single_op_call_string("h", 0), single_op_call_string("z", 1), - rotation_call_string("rx", 0.1, 0), + rotation_call_string("rx", 1.1, 0), double_op_call_string("cnot", 0, 1), result_record_output_string(0), result_record_output_string(1), @@ -322,11 +322,11 @@ def _validate_nested_custom_op(entry_body: List[str]): nested_op_lines = [ initialize_call_string(), single_op_call_string("h", 1), - rotation_call_string("rz", 0.1, 1), + rotation_call_string("rz", 4.8, 1), single_op_call_string("h", 0), double_op_call_string("cnot", 0, 1), - rotation_call_string("rx", 0.1, 1), - rotation_call_string("ry", 0.5, 1), + rotation_call_string("rx", 4.8, 1), + rotation_call_string("ry", 5, 1), result_record_output_string(0), result_record_output_string(1), return_string(), @@ -338,8 +338,22 @@ def _validate_nested_custom_op(entry_body: List[str]): def _validate_complex_custom_op(entry_body: List[str]): - pass - # todo... + complex_op_lines = [ + initialize_call_string(), + single_op_call_string("h", 0), + single_op_call_string("x", 0), + rotation_call_string("rx", 0.5, 0), + rotation_call_string("ry", 0.1, 0), + rotation_call_string("rz", 0.2, 0), + double_op_call_string("cnot", 0, 1), + result_record_output_string(0), + result_record_output_string(1), + return_string(), + ] + + assert len(entry_body) == len(complex_op_lines), "Incorrect number of lines in complex op" + for i in range(len(entry_body)): + assert entry_body[i].strip() == complex_op_lines[i].strip(), "Incorrect complex op line" def check_custom_qasm_gate_op(qir: List[str], test_type: str): @@ -353,3 +367,25 @@ def check_custom_qasm_gate_op(qir: List[str], test_type: str): _validate_complex_custom_op(entry_body) else: assert False, f"Unknown test type {test_type} for custom ops" + + +def check_expressions( + qir: List[str], expected_ops: int, gates: List[str], expression_values, qubits: List[int] +): + entry_body = get_entry_point_body(qir) + op_count = 0 + q_id = 0 + + for line in entry_body: + if line.strip().startswith("call") and "qis__" in line: + assert line.strip() == rotation_call_string( + gates[q_id], expression_values[q_id], qubits[q_id] + ), f"Incorrect rotation gate call in qir - {line}" + op_count += 1 + q_id += 1 + + if op_count == expected_ops: + break + + if op_count != expected_ops: + assert False, f"Incorrect rotation gate count: {expected_ops} expected, {op_count} actual"