diff --git a/qbraid_qir/qasm3/__init__.py b/qbraid_qir/qasm3/__init__.py index 6bff53f..c76b81d 100644 --- a/qbraid_qir/qasm3/__init__.py +++ b/qbraid_qir/qasm3/__init__.py @@ -41,9 +41,16 @@ Qasm3ConversionError """ +from .checker import semantic_check from .convert import qasm3_to_qir from .elements import Qasm3Module from .exceptions import Qasm3ConversionError from .visitor import BasicQasmVisitor -__all__ = ["qasm3_to_qir", "Qasm3Module", "Qasm3ConversionError", "BasicQasmVisitor"] +__all__ = [ + "semantic_check", + "qasm3_to_qir", + "Qasm3Module", + "Qasm3ConversionError", + "BasicQasmVisitor", +] diff --git a/qbraid_qir/qasm3/checker.py b/qbraid_qir/qasm3/checker.py new file mode 100644 index 0000000..de2e921 --- /dev/null +++ b/qbraid_qir/qasm3/checker.py @@ -0,0 +1,65 @@ +# Copyright (C) 2024 qBraid +# +# This file is part of the qBraid-SDK +# +# The qBraid-SDK is free software released under the GNU General Public License v3 +# or later. You can redistribute and/or modify it under the terms of the GPL v3. +# See the LICENSE file in the project root or . +# +# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3. + +""" +Module containing OpenQASM semantic checker function + +""" +from typing import Optional, Union + +import openqasm3 +from pyqir import Context, qir_module + +from .elements import Qasm3Module, generate_module_id +from .exceptions import Qasm3ConversionError +from .visitor import BasicQasmVisitor + + +def semantic_check( + program: Union[openqasm3.ast.Program, str], + name: Optional[str] = None, + **kwargs, +) -> tuple[bool, Optional[Exception]]: + """Validates a given OpenQASM 3 program for semantic correctness. + + Args: + program (openqasm3.ast.Program or str): The OpenQASM 3 program to convert. + name (str, optional): Identifier for created QIR module. Auto-generated if not provided. + + Keyword Args: + initialize_runtime (bool): Whether to perform quantum runtime environment initialization, + default `True`. + record_output (bool): Whether to record output calls for registers, default `True` + + Returns: + bool : True if the program is semantically correct, False otherwise. + + Raises: + Exception: If the input is not a valid OpenQASM 3 program. + """ + if isinstance(program, str): + program = openqasm3.parse(program) + + elif not isinstance(program, openqasm3.ast.Program): + raise TypeError("Input quantum program must be of type openqasm3.ast.Program or str.") + + if name is None: + name = generate_module_id() + + llvm_module = qir_module(Context(), name) + module = Qasm3Module.from_program(program, llvm_module) + + try: + visitor = BasicQasmVisitor(check_only=True, **kwargs) + module.accept(visitor) + except (Qasm3ConversionError, TypeError, ValueError) as err: + return (False, err) + + return (True, None) diff --git a/qbraid_qir/qasm3/visitor.py b/qbraid_qir/qasm3/visitor.py index ee6086a..b053b71 100644 --- a/qbraid_qir/qasm3/visitor.py +++ b/qbraid_qir/qasm3/visitor.py @@ -65,7 +65,9 @@ class BasicQasmVisitor(ProgramElementVisitor): record_output (bool): If True, output of the circuit will be recorded. Defaults to True. """ - def __init__(self, initialize_runtime: bool = True, record_output: bool = True): + def __init__( + self, initialize_runtime: bool = True, record_output: bool = True, check_only: bool = False + ): self._module: pyqir.Module self._builder: pyqir.Builder self._entry_point: str = "" @@ -81,6 +83,7 @@ def __init__(self, initialize_runtime: bool = True, record_output: bool = True): self._subroutine_defns: dict[str, qasm3_ast.SubroutineDefinition] = {} self._initialize_runtime: bool = initialize_runtime self._record_output: bool = record_output + self._check_only: bool = check_only self._curr_scope: int = 0 self._label_scope_level: dict[int, set] = {self._curr_scope: set()} @@ -521,9 +524,9 @@ def _visit_measurement(self, statement: qasm3_ast.QuantumMeasurementStatement) - "for measurement operation", span=statement.span, ) - - for src_id, tgt_id in zip(source_ids, target_ids): - pyqir._native.mz(self._builder, src_id, tgt_id) # type: ignore[arg-type] + if not self._check_only: + for src_id, tgt_id in zip(source_ids, target_ids): + pyqir._native.mz(self._builder, src_id, tgt_id) # type: ignore[arg-type] def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> None: """Visit a reset statement element. @@ -546,9 +549,10 @@ def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> None: ) qubit_ids = self._get_op_bits(statement, self._global_qreg_size_map, True) - for qid in qubit_ids: - # qid is of type Constant which is inherited from Value, so we ignore the type error - pyqir._native.reset(self._builder, qid) # type: ignore[arg-type] + if not self._check_only: + for qid in qubit_ids: + # qid is of type Constant which is inherited from Value, so we ignore the type error + pyqir._native.reset(self._builder, qid) # type: ignore[arg-type] def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> None: """Visit a barrier statement element. @@ -575,7 +579,8 @@ def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> None: barrier_qubits = self._get_op_bits(barrier, self._global_qreg_size_map) total_qubit_count = sum(self._global_qreg_size_map.values()) if len(barrier_qubits) == total_qubit_count: - pyqir._native.barrier(self._builder) + if not self._check_only: + pyqir._native.barrier(self._builder) else: raise_qasm3_error( "Barrier operation on a qubit subset is not supported in pyqir", @@ -662,13 +667,14 @@ def _visit_basic_gate_operation( if inverse_action == InversionOp.INVERT_ROTATION: op_parameters = [-1 * param for param in op_parameters] - for i in range(0, len(op_qubits), op_qubit_count): - # we apply the gate on the qubit subset linearly - qubit_subset = op_qubits[i : i + op_qubit_count] - if op_parameters is not None: - qir_func(self._builder, *op_parameters, *qubit_subset) - else: - qir_func(self._builder, *qubit_subset) + if not self._check_only: + for i in range(0, len(op_qubits), op_qubit_count): + # we apply the gate on the qubit subset linearly + qubit_subset = op_qubits[i : i + op_qubit_count] + if op_parameters is not None: + qir_func(self._builder, *op_parameters, *qubit_subset) + else: + qir_func(self._builder, *qubit_subset) def _visit_custom_gate_operation( self, operation: qasm3_ast.QuantumGate, inverse: bool = False @@ -1082,13 +1088,14 @@ def _visit_statement_block(block): for stmt in block: self.visit_statement(stmt) - # if the condition is true, we visit the if block - pyqir._native.if_result( - self._builder, - pyqir.result(self._module.context, self._clbit_labels[f"{reg_name}_{reg_id}"]), - zero=lambda: _visit_statement_block(else_block), - one=lambda: _visit_statement_block(if_block), - ) + if not self._check_only: + # if the condition is true, we visit the if block + pyqir._native.if_result( + self._builder, + pyqir.result(self._module.context, self._clbit_labels[f"{reg_name}_{reg_id}"]), + zero=lambda: _visit_statement_block(else_block), + one=lambda: _visit_statement_block(if_block), + ) del self._label_scope_level[self._curr_scope] self._curr_scope -= 1 @@ -1145,6 +1152,11 @@ def _visit_forin_loop(self, statement: qasm3_ast.ForInLoop) -> None: self._pop_scope() self._restore_context() + # as we are only checking compile time errors + # not runtime errors, we can break here + if self._check_only: + break + def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition) -> None: """Visit a subroutine definition element. Reference: https://openqasm.com/language/subroutines.html#subroutines diff --git a/tests/qasm3_qir/test_checker.py b/tests/qasm3_qir/test_checker.py new file mode 100644 index 0000000..30a8f68 --- /dev/null +++ b/tests/qasm3_qir/test_checker.py @@ -0,0 +1,51 @@ +# Copyright (C) 2023 qBraid +# +# This file is part of the qBraid-SDK +# +# The qBraid-SDK is free software released under the GNU General Public License v3 +# or later. You can redistribute and/or modify it under the terms of the GPL v3. +# See the LICENSE file in the project root or . +# +# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3. + +""" +Tests the checker module of qasm3 + +""" + +import pytest + +from qbraid_qir.qasm3.checker import semantic_check +from qbraid_qir.qasm3.exceptions import Qasm3ConversionError + + +def test_correct_check(): + check, err = semantic_check("OPENQASM 3; include 'stdgates.inc'; qubit q;") + assert check + assert err is None + + +def test_incorrect_check(): + check, err = semantic_check( + """ + //semantically incorrect program + OPENQASM 3; + include 'stdgates.inc'; + qubit q; + rx(3.14) q[2]; + """, + name="test", + ) + assert not check + assert err is not None + assert isinstance(err, Qasm3ConversionError) + + +def test_incorrect_program_type(): + with pytest.raises( + TypeError, match="Input quantum program must be of type openqasm3.ast.Program or str." + ): + check, err = semantic_check(1234) + assert not check + assert err is not None + assert isinstance(err, TypeError)