Skip to content

Commit

Permalink
start with warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Sep 24, 2024
1 parent 30399ed commit 70717b1
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 10 deletions.
2 changes: 2 additions & 0 deletions qbraid_qir/qasm3/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,7 @@ def semantic_check(
try:
visitor = BasicQasmVisitor(check_only=True, **kwargs)
module.accept(visitor)
visitor.finalize_check()

except (Qasm3ConversionError, TypeError, ValueError) as err:
raise err
1 change: 1 addition & 0 deletions qbraid_qir/qasm3/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def qasm3_to_qir(

visitor = BasicQasmVisitor(**kwargs)
module.accept(visitor)
visitor.finalize_check()

err = llvm_module.verify()
if err is not None:
Expand Down
1 change: 1 addition & 0 deletions qbraid_qir/qasm3/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
self.value = value
self.is_constant = is_constant
self.readonly = readonly
self.referenced = False


class _ProgramElement(metaclass=ABCMeta):
Expand Down
39 changes: 39 additions & 0 deletions qbraid_qir/qasm3/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,24 @@
"""
import logging
from enum import Enum
from typing import Optional, Type

from openqasm3.ast import Span

from qbraid_qir.exceptions import QirConversionError

logging.basicConfig(level=logging.WARNING, format="%(levelname)s - %(message)s")


class WarnType(Enum):
"""Enum for different qasm3 semantic warnings."""

UNUSED_VAR = "unused"
IMPLICIT_CAST = "implicit_cast"
UNUSED_FUNCTION = "unused_function"
UNUSED_GATE = "unused_gate"


class Qasm3ConversionError(QirConversionError):
"""Class for errors raised when converting an OpenQASM 3 program to QIR."""
Expand Down Expand Up @@ -49,3 +61,30 @@ def raise_qasm3_error(
if raised_from:
raise err_type(message) from raised_from
raise err_type(message)


def emit_qasm3_warning(
warning_type: WarnType,
message: Optional[str] = None,
span: Optional[Span] = None,
):
"""Emits a QASM3 conversion warning.
Args:
warning_type: The type of warning.
message: The warning message.
span: The span (location) in the QASM file where the warning occurred.
Returns:
None
"""
err_message = "No message provided."
if span:
err_message = (

Check warning on line 83 in qbraid_qir/qasm3/exceptions.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/exceptions.py#L83

Added line #L83 was not covered by tests
f"Warning at line {span.start_line}, column {span.start_column} in QASM file. "
)

if message:
err_message = message if not span else f"{err_message} {message}"

logging.warning(f"{warning_type} warning emitted: {err_message}")
12 changes: 4 additions & 8 deletions qbraid_qir/qasm3/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,13 @@ def _get_var_value(cls, var_name, indices, expression):
Returns:
var_value: The value of the variable.
"""

var_value = None
variable_obj = cls.visitor_obj._get_from_visible_scope(var_name)
if isinstance(expression, Identifier):
var_value = cls.visitor_obj._get_from_visible_scope(var_name).value
var_value = variable_obj.value
else:
validated_indices = Qasm3Analyzer.analyze_classical_indices(
indices, cls.visitor_obj._get_from_visible_scope(var_name)
)
var_value = Qasm3Analyzer.find_array_element(
cls.visitor_obj._get_from_visible_scope(var_name).value, validated_indices
)
validated_indices = Qasm3Analyzer.analyze_classical_indices(indices, variable_obj)
var_value = Qasm3Analyzer.find_array_element(variable_obj.value, validated_indices)
return var_value

@classmethod
Expand Down
74 changes: 72 additions & 2 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from .analyzer import Qasm3Analyzer
from .elements import Context, InversionOp, Qasm3Module, Variable
from .exceptions import Qasm3ConversionError, raise_qasm3_error
from .exceptions import Qasm3ConversionError, WarnType, emit_qasm3_warning, raise_qasm3_error
from .expressions import Qasm3ExprEvaluator
from .maps import (
ARRAY_TYPE_MAP,
Expand Down Expand Up @@ -66,11 +66,16 @@ class BasicQasmVisitor(ProgramElementVisitor):
"""

def __init__(
self, initialize_runtime: bool = True, record_output: bool = True, check_only: bool = False
self,
initialize_runtime: bool = True,
record_output: bool = True,
check_only: bool = False,
suppress_warnings: bool = False,
):
self._module: pyqir.Module
self._builder: pyqir.Builder
self._entry_point: str = ""
self._suppress_warnings: bool = suppress_warnings
self._scope: deque = deque([{}])
self._context: deque = deque([Context.GLOBAL])
self._qubit_labels: dict[str, int] = {}
Expand All @@ -80,7 +85,9 @@ def __init__(
self._function_qreg_transform_map: deque = deque([]) # for nested functions
self._global_creg_size_map: dict[str, int] = {}
self._custom_gates: dict[str, qasm3_ast.QuantumGateDefinition] = {}
self._custom_gates_usage: dict[str, bool] = {}
self._subroutine_defns: dict[str, qasm3_ast.SubroutineDefinition] = {}
self._subroutine_usage: dict[str, bool] = {}
self._initialize_runtime: bool = initialize_runtime
self._record_output: bool = record_output
self._check_only: bool = check_only
Expand Down Expand Up @@ -135,9 +142,44 @@ def _push_context(self, context: Context) -> None:
raise TypeError("Context must be an instance of Context")
self._context.append(context)

def _warn_unused_vars(self, scope: dict) -> None:
for var_name, var in scope.items():
if not var.referenced:
emit_qasm3_warning(
WarnType.UNUSED_VAR,
f"Variable '{var_name}' is declared but not used.",
span=None,
)

def _emit_unresolved_warnings(self) -> None:
if self._suppress_warnings:
return

Check warning on line 156 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L156

Added line #L156 was not covered by tests
# emit unused gates
for gate, usage in self._custom_gates_usage.items():
if not usage:
emit_qasm3_warning(

Check warning on line 160 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L160

Added line #L160 was not covered by tests
WarnType.UNUSED_GATE,
f"Gate '{gate}' is declared but not used.",
self._custom_gates[gate].span,
)

# emit unused functions
for func, usage in self._subroutine_usage.items():
if not usage:
emit_qasm3_warning(

Check warning on line 169 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L169

Added line #L169 was not covered by tests
WarnType.UNUSED_FUNCTION,
f"Function '{func}' is declared but not used.",
self._subroutine_defns[func].span,
)

# emit unused global variables
self._warn_unused_vars(self._scope[0])

def _pop_scope(self) -> None:
if len(self._scope) == 0:
raise IndexError("Scope list is empty, can not pop")
if not self._suppress_warnings:
self._warn_unused_vars(self._scope[-1])
self._scope.pop()

def _restore_context(self) -> None:
Expand Down Expand Up @@ -193,17 +235,24 @@ def _check_in_scope(self, var_name: str) -> bool:
global_scope = self._get_global_scope()
curr_scope = self._get_curr_scope()
if self._in_global_scope():
if var_name in global_scope:
global_scope[var_name].referenced = True
return var_name in global_scope
if self._in_function_scope() or self._in_gate_scope():
if var_name in curr_scope:
curr_scope[var_name].referenced = True
return True
if var_name in global_scope:
global_scope[var_name].referenced = True
return global_scope[var_name].is_constant
if self._in_block_scope():
for scope, context in zip(reversed(self._scope), reversed(self._context)):
if context != Context.BLOCK:
if var_name in scope:
scope[var_name].referenced = True
return var_name in scope
if var_name in scope:
scope[var_name].referenced = True
return True
return False

Expand All @@ -221,17 +270,24 @@ def _get_from_visible_scope(self, var_name: str) -> Union[Variable, None]:
curr_scope = self._get_curr_scope()

if self._in_global_scope():
if var_name in global_scope:
global_scope[var_name].referenced = True
return global_scope.get(var_name, None)
if self._in_function_scope() or self._in_gate_scope():
if var_name in curr_scope:
curr_scope[var_name].referenced = True
return curr_scope[var_name]
if var_name in global_scope and global_scope[var_name].is_constant:
global_scope[var_name].referenced = True
return global_scope[var_name]
if self._in_block_scope():
for scope, context in zip(reversed(self._scope), reversed(self._context)):
if context != Context.BLOCK:
if var_name in scope:
scope[var_name].referenced = True
return scope.get(var_name, None)
if var_name in scope:
scope[var_name].referenced = True
return scope[var_name]
# keep on checking
return None
Expand Down Expand Up @@ -417,6 +473,9 @@ def _get_op_bits(
span=operation.span,
)
self._check_if_name_in_scope(reg_name, operation)
# mark the variable as referenced
if self._get_from_visible_scope(reg_name):
self._get_from_visible_scope(reg_name).referenced = True # type: ignore[union-attr]

if isinstance(bit, qasm3_ast.IndexedIdentifier):
if isinstance(bit.indices[0], qasm3_ast.DiscreteSet):
Expand Down Expand Up @@ -617,6 +676,7 @@ def _visit_gate_definition(self, definition: qasm3_ast.QuantumGateDefinition) ->
if gate_name in self._custom_gates:
raise_qasm3_error(f"Duplicate gate definition for {gate_name}", span=definition.span)
self._custom_gates[gate_name] = definition
self._custom_gates_usage[gate_name] = False

def _visit_basic_gate_operation(
self, operation: qasm3_ast.QuantumGate, inverse: bool = False
Expand Down Expand Up @@ -808,6 +868,9 @@ def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> Non
)
# Applying the inverse first and then the power is same as
# apply the power first and then inverting the result
if operation.name.name in self._custom_gates:
self._custom_gates_usage[operation.name.name] = True

for _ in range(power_value):
if operation.name.name in self._custom_gates:
self._visit_custom_gate_operation(operation, inverse_value)
Expand Down Expand Up @@ -1185,6 +1248,7 @@ def _visit_subroutine_definition(self, statement: qasm3_ast.SubroutineDefinition
)

self._subroutine_defns[fn_name] = statement
self._subroutine_usage[fn_name] = False

# pylint: disable=too-many-locals, too-many-statements
def _visit_function_call(self, statement: qasm3_ast.FunctionCall) -> None:
Expand All @@ -1201,6 +1265,7 @@ def _visit_function_call(self, statement: qasm3_ast.FunctionCall) -> None:
raise_qasm3_error(f"Undefined subroutine '{fn_name}' was called", span=statement.span)

subroutine_def = self._subroutine_defns[fn_name]
self._subroutine_usage[fn_name] = True

if len(statement.arguments) != len(subroutine_def.arguments):
raise_qasm3_error(
Expand Down Expand Up @@ -1482,3 +1547,8 @@ def ir(self) -> str:

def bitcode(self) -> bytes:
return self._module.bitcode

def finalize_check(self):
if self._suppress_warnings:
return

Check warning on line 1553 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L1553

Added line #L1553 was not covered by tests
self._emit_unresolved_warnings()

0 comments on commit 70717b1

Please sign in to comment.