Skip to content

Commit

Permalink
add semantic checker for qasm
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Sep 16, 2024
1 parent 3ae9480 commit c35b860
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 23 deletions.
9 changes: 8 additions & 1 deletion qbraid_qir/qasm3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@
Qasm3ConversionError
"""
from .checker import semantic_check
from .convert import qasm3_to_qir
from .elements import Qasm3Module
from .exceptions import Qasm3ConversionError
from .visitor import BasicQasmVisitor

__all__ = ["qasm3_to_qir", "Qasm3Module", "Qasm3ConversionError", "BasicQasmVisitor"]
__all__ = [
"semantic_check",
"qasm3_to_qir",
"Qasm3Module",
"Qasm3ConversionError",
"BasicQasmVisitor",
]
65 changes: 65 additions & 0 deletions qbraid_qir/qasm3/checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (C) 2024 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module containing OpenQASM semantic checker function
"""
from typing import Optional, Union

import openqasm3
from pyqir import Context, qir_module

from .elements import Qasm3Module, generate_module_id
from .exceptions import Qasm3ConversionError
from .visitor import BasicQasmVisitor


def semantic_check(
program: Union[openqasm3.ast.Program, str],
name: Optional[str] = None,
**kwargs,
) -> tuple[bool, Optional[Exception]]:
"""Validates a given OpenQASM 3 program for semantic correctness.
Args:
program (openqasm3.ast.Program or str): The OpenQASM 3 program to convert.
name (str, optional): Identifier for created QIR module. Auto-generated if not provided.
Keyword Args:
initialize_runtime (bool): Whether to perform quantum runtime environment initialization,
default `True`.
record_output (bool): Whether to record output calls for registers, default `True`
Returns:
bool : True if the program is semantically correct, False otherwise.
Raises:
Exception: If the input is not a valid OpenQASM 3 program.
"""
if isinstance(program, str):
program = openqasm3.parse(program)

elif not isinstance(program, openqasm3.ast.Program):
raise TypeError("Input quantum program must be of type openqasm3.ast.Program or str.")

if name is None:
name = generate_module_id()

llvm_module = qir_module(Context(), name)
module = Qasm3Module.from_program(program, llvm_module)

try:
visitor = BasicQasmVisitor(check_only=True, **kwargs)
module.accept(visitor)
except (Qasm3ConversionError, TypeError, ValueError) as err:
return (False, err)

return (True, None)
56 changes: 34 additions & 22 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ class BasicQasmVisitor(ProgramElementVisitor):
record_output (bool): If True, output of the circuit will be recorded. Defaults to True.
"""

def __init__(self, initialize_runtime: bool = True, record_output: bool = True):
def __init__(
self, initialize_runtime: bool = True, record_output: bool = True, check_only: bool = False
):
self._module: pyqir.Module
self._builder: pyqir.Builder
self._entry_point: str = ""
Expand All @@ -81,6 +83,7 @@ def __init__(self, initialize_runtime: bool = True, record_output: bool = True):
self._subroutine_defns: dict[str, qasm3_ast.SubroutineDefinition] = {}
self._initialize_runtime: bool = initialize_runtime
self._record_output: bool = record_output
self._check_only: bool = check_only
self._curr_scope: int = 0
self._label_scope_level: dict[int, set] = {self._curr_scope: set()}

Expand Down Expand Up @@ -521,9 +524,9 @@ def _visit_measurement(self, statement: qasm3_ast.QuantumMeasurementStatement) -
"for measurement operation",
span=statement.span,
)

for src_id, tgt_id in zip(source_ids, target_ids):
pyqir._native.mz(self._builder, src_id, tgt_id) # type: ignore[arg-type]
if not self._check_only:
for src_id, tgt_id in zip(source_ids, target_ids):
pyqir._native.mz(self._builder, src_id, tgt_id) # type: ignore[arg-type]

def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> None:
"""Visit a reset statement element.
Expand All @@ -546,9 +549,10 @@ def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> None:
)
qubit_ids = self._get_op_bits(statement, self._global_qreg_size_map, True)

for qid in qubit_ids:
# qid is of type Constant which is inherited from Value, so we ignore the type error
pyqir._native.reset(self._builder, qid) # type: ignore[arg-type]
if not self._check_only:
for qid in qubit_ids:
# qid is of type Constant which is inherited from Value, so we ignore the type error
pyqir._native.reset(self._builder, qid) # type: ignore[arg-type]

def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> None:
"""Visit a barrier statement element.
Expand All @@ -575,7 +579,8 @@ def _visit_barrier(self, barrier: qasm3_ast.QuantumBarrier) -> None:
barrier_qubits = self._get_op_bits(barrier, self._global_qreg_size_map)
total_qubit_count = sum(self._global_qreg_size_map.values())
if len(barrier_qubits) == total_qubit_count:
pyqir._native.barrier(self._builder)
if not self._check_only:
pyqir._native.barrier(self._builder)
else:
raise_qasm3_error(
"Barrier operation on a qubit subset is not supported in pyqir",
Expand Down Expand Up @@ -662,13 +667,14 @@ def _visit_basic_gate_operation(
if inverse_action == InversionOp.INVERT_ROTATION:
op_parameters = [-1 * param for param in op_parameters]

for i in range(0, len(op_qubits), op_qubit_count):
# we apply the gate on the qubit subset linearly
qubit_subset = op_qubits[i : i + op_qubit_count]
if op_parameters is not None:
qir_func(self._builder, *op_parameters, *qubit_subset)
else:
qir_func(self._builder, *qubit_subset)
if not self._check_only:
for i in range(0, len(op_qubits), op_qubit_count):
# we apply the gate on the qubit subset linearly
qubit_subset = op_qubits[i : i + op_qubit_count]
if op_parameters is not None:
qir_func(self._builder, *op_parameters, *qubit_subset)
else:
qir_func(self._builder, *qubit_subset)

def _visit_custom_gate_operation(
self, operation: qasm3_ast.QuantumGate, inverse: bool = False
Expand Down Expand Up @@ -1082,13 +1088,14 @@ def _visit_statement_block(block):
for stmt in block:
self.visit_statement(stmt)

# if the condition is true, we visit the if block
pyqir._native.if_result(
self._builder,
pyqir.result(self._module.context, self._clbit_labels[f"{reg_name}_{reg_id}"]),
zero=lambda: _visit_statement_block(else_block),
one=lambda: _visit_statement_block(if_block),
)
if not self._check_only:
# if the condition is true, we visit the if block
pyqir._native.if_result(
self._builder,
pyqir.result(self._module.context, self._clbit_labels[f"{reg_name}_{reg_id}"]),
zero=lambda: _visit_statement_block(else_block),
one=lambda: _visit_statement_block(if_block),
)

del self._label_scope_level[self._curr_scope]
self._curr_scope -= 1
Expand Down Expand Up @@ -1145,6 +1152,11 @@ def _visit_forin_loop(self, statement: qasm3_ast.ForInLoop) -> None:
self._pop_scope()
self._restore_context()

# as we are only checking compile time errors
# not runtime errors, we can break here
if self._check_only:
break

Check warning on line 1158 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L1158

Added line #L1158 was not covered by tests

def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition) -> None:
"""Visit a subroutine definition element.
Reference: https://openqasm.com/language/subroutines.html#subroutines
Expand Down
51 changes: 51 additions & 0 deletions tests/qasm3_qir/test_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (C) 2023 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Tests the checker module of qasm3
"""

import pytest

from qbraid_qir.qasm3.checker import semantic_check
from qbraid_qir.qasm3.exceptions import Qasm3ConversionError


def test_correct_check():
check, err = semantic_check("OPENQASM 3; include 'stdgates.inc'; qubit q;")
assert check
assert err is None


def test_incorrect_check():
check, err = semantic_check(
"""
//semantically incorrect program
OPENQASM 3;
include 'stdgates.inc';
qubit q;
rx(3.14) q[2];
""",
name="test",
)
assert not check
assert err is not None
assert isinstance(err, Qasm3ConversionError)


def test_incorrect_program_type():
with pytest.raises(
TypeError, match="Input quantum program must be of type openqasm3.ast.Program or str."
):
check, err = semantic_check(1234)
assert not check
assert err is not None
assert isinstance(err, TypeError)

0 comments on commit c35b860

Please sign in to comment.