Skip to content

Commit

Permalink
add static types to qasm
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Aug 26, 2024
1 parent 1ca2d93 commit ba1cf7e
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ jobs:
run: |
python3 -m pip install --upgrade pip
python3 -m pip install tox>=4.2.0
- name: Check isort, black, headers
- name: Check isort, black, mypy, headers
run: |
tox -e format-check
24 changes: 24 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[mypy]

# Ignore missing imports
ignore_missing_imports = True

# Enable incremental mode
incremental = True

# Show error codes in output
show_error_codes = True

# Follow imports for type checking
follow_imports = normal

# Enable cache
cache_dir = .mypy_cache

# TODO: fix typing for cirq
[mypy-qbraid_qir.cirq.*]
ignore_errors = True

# TODO: fix typing for qasm3 visitor
[mypy-qbraid_qir.qasm3.visitor]
ignore_errors = True
47 changes: 30 additions & 17 deletions qbraid_qir/qasm3/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
Module with analysis functions for QASM3 visitor
"""
from typing import Any
from typing import Any, Optional, Union

from openqasm3.ast import (
BinaryExpression,
DiscreteSet,
Expression,
IndexExpression,
IntegerLiteral,
RangeDefinition,
Expand All @@ -30,7 +32,7 @@ class Qasm3Analyzer:
"""Class with utility functions for analyzing QASM3 elements"""

@staticmethod
def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> None:
def analyze_classical_indices(indices: list[Any], var: Variable) -> list:
"""Validate the indices for a classical variable.
Args:
Expand All @@ -45,18 +47,18 @@ def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> N
"""
indices_list = []
var_name = var.name
var_dimensions = var.dims
var_dimensions: Optional[list[int]] = var.dims

if not var_dimensions:
if var_dimensions is None or len(var_dimensions) == 0:
raise_qasm3_error(
message=f"Indexing error. Variable {var_name} is not an array",
err_type=Qasm3ConversionError,
span=indices[0].span,
)
if len(indices) != len(var_dimensions):
if len(indices) != len(var_dimensions): # type: ignore[arg-type]
raise_qasm3_error(
message=f"Invalid number of indices for variable {var_name}. "
f"Expected {len(var_dimensions)} but got {len(indices)}",
f"Expected {len(var_dimensions)} but got {len(indices)}", # type: ignore[arg-type]
err_type=Qasm3ConversionError,
span=indices[0].span,
)
Expand All @@ -78,7 +80,7 @@ def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> N
span=index.span,
)
index_value = index.value
curr_dimension = var_dimensions[i]
curr_dimension = var_dimensions[i] # type: ignore[index]

if index_value < 0 or index_value >= curr_dimension:
raise_qasm3_error(
Expand All @@ -92,29 +94,40 @@ def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> N
return indices_list

@staticmethod
def analyze_index_expression(index_expr: IndexExpression) -> tuple[str, list[list]]:
def analyze_index_expression(
index_expr: IndexExpression,
) -> tuple[str, list[Union[Any, Expression, RangeDefinition]]]:
"""analyze an index expression to get the variable name and indices.
Args:
index_expr (IndexExpression): The index expression to analyze.
Returns:
tuple[str, list[list]]: The variable name and indices.
tuple[str, list[Any]]: The variable name and indices.
"""
indices = []
var_name = None
indices: list[Any] = []
var_name = ""
comma_separated = False

if isinstance(index_expr.collection, IndexExpression):
while isinstance(index_expr, IndexExpression):
indices.append(index_expr.index[0])
index_expr = index_expr.collection
if isinstance(index_expr.index, list):
indices.append(index_expr.index[0])
index_expr = index_expr.collection
elif isinstance(index_expr.index, DiscreteSet):
raise_qasm3_error(

Check warning on line 119 in qbraid_qir/qasm3/analyzer.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/analyzer.py#L118-L119

Added lines #L118 - L119 were not covered by tests
message="DiscreteSet not supported in index expression",
err_type=Qasm3ConversionError,
span=index_expr.span,
)

else:
comma_separated = True
indices = index_expr.index

var_name = index_expr.collection.name if comma_separated else index_expr.name
indices = index_expr.index # type: ignore[assignment]
var_name = (
index_expr.collection.name if comma_separated else index_expr.name
) # type: ignore[attr-defined]
if not comma_separated:
indices = indices[::-1]

Expand Down Expand Up @@ -169,7 +182,7 @@ def analyse_branch_condition(condition: Any) -> bool:
err_type=Qasm3ConversionError,
span=condition.span,
)
return condition.rhs.value != 0
return condition.rhs.value != 0 # type: ignore[attr-defined]
if not isinstance(condition, IndexExpression):
raise_qasm3_error(
message=(
Expand Down
14 changes: 9 additions & 5 deletions qbraid_qir/qasm3/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
name: str,
base_type,
base_size: int,
dims: list[int] = None,
dims: Optional[list[int]] = None,
value: Optional[Union[int, float, list]] = None,
is_constant: bool = False,
):
Expand Down Expand Up @@ -166,20 +166,24 @@ def from_program(cls, program: Program, module: Optional[Module] = None):
Class method. Construct a Qasm3Module from a given openqasm3.ast.Program object
and an optional QIR Module.
"""
elements = []
elements: list[Union[_Register, _Statement]] = []

num_qubits = 0
num_clbits = 0
for statement in program.statements:
if isinstance(statement, QubitDeclaration):
size = 1 if statement.size is None else statement.size.value
size = 1
if statement.size:
size = statement.size.value # type: ignore[attr-defined]
num_qubits += size
elements.append(_Register(statement))

elif isinstance(statement, ClassicalDeclaration) and isinstance(
statement.type, BitType
):
size = 1 if statement.type.size is None else statement.type.size.value
size = 1
if statement.type.size:
size = statement.type.size.value # type: ignore[attr-defined]
num_clbits += size
elements.append(_Register(statement))
# as bit arrays are just 0 / 1 values, we can treat them as
Expand All @@ -191,7 +195,7 @@ def from_program(cls, program: Program, module: Optional[Module] = None):

if module is None:
# pylint: disable-next=too-many-function-args
module = Module(qirContext(), generate_module_id(program))
module = Module(qirContext(), generate_module_id())

Check warning on line 198 in qbraid_qir/qasm3/elements.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/elements.py#L198

Added line #L198 was not covered by tests

return cls(
name="main",
Expand Down
6 changes: 3 additions & 3 deletions qbraid_qir/qasm3/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Qasm3ExprEvaluator:
visitor_obj = None

@classmethod
def set_visitor_obj(cls, visitor_obj):
def set_visitor_obj(cls, visitor_obj) -> None:
cls.visitor_obj = visitor_obj

@classmethod
Expand Down Expand Up @@ -171,7 +171,7 @@ def evaluate_expression(cls, expression, const_expr: bool = False, reqd_type=Non
expression.span,
)

def _process_variable(var_name, indices=None):
def _process_variable(var_name: str, indices=None):
cls._check_var_in_scope(var_name, expression)
cls._check_var_constant(var_name, const_expr, expression)
cls._check_var_type(var_name, reqd_type, expression)
Expand Down Expand Up @@ -230,7 +230,7 @@ def _process_variable(var_name, indices=None):
# function will not return a reqd / const type
# Reference : https://openqasm.com/language/types.html#compile-time-constants
# para : 5
return cls.visitor_obj._visit_function_call(expression)
return cls.visitor_obj._visit_function_call(expression) # type: ignore[union-attr]

raise_qasm3_error(
f"Unsupported expression type {type(expression)}", Qasm3ConversionError, expression.span
Expand Down
15 changes: 11 additions & 4 deletions qbraid_qir/qasm3/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""


from typing import Union
from typing import Callable, Union

import numpy as np
import pyqir
Expand All @@ -36,7 +36,13 @@
from .exceptions import Qasm3ConversionError
from .linalg import kak_decomposition_angles

OPERATOR_MAP = {
# Define the type for the operator functions
OperatorFunction = Union[
Callable[[Union[int, float, bool]], Union[int, float, bool]],
Callable[[Union[int, float, bool], Union[int, float, bool]], Union[int, float, bool]],
]

OPERATOR_MAP: dict[str, OperatorFunction] = {
"+": lambda x, y: x + y,
"-": lambda x, y: x - y,
"*": lambda x, y: x * y,
Expand All @@ -61,7 +67,7 @@
}


def qasm3_expression_op_map(op_name: str, *args):
def qasm3_expression_op_map(op_name: str, *args) -> Union[float, int, bool]:
"""
Return the result of applying the given operator to the given operands.
Expand All @@ -75,7 +81,8 @@ def qasm3_expression_op_map(op_name: str, *args):
(Union[float, int, bool]): The result of applying the operator to the operands.
"""
try:
return OPERATOR_MAP[op_name](*args)
operator = OPERATOR_MAP[op_name]
return operator(*args)
except KeyError as exc:
raise Qasm3ConversionError(f"Unsupported / undeclared QASM operator: {op_name}") from exc

Expand Down
45 changes: 33 additions & 12 deletions qbraid_qir/qasm3/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Module with transformation functions for QASM3 visitor
"""
from typing import Any, Optional, Union
from typing import Any, Union

from openqasm3.ast import (
BinaryExpression,
Expand All @@ -33,14 +33,16 @@
from .expressions import Qasm3ExprEvaluator
from .validator import Qasm3Validator

# mypy: disable-error-code="attr-defined, union-attr"


class Qasm3Transformer:
"""Class with utility functions for transforming QASM3 elements"""

visitor_obj = None

@classmethod
def set_visitor_obj(cls, visitor_obj):
def set_visitor_obj(cls, visitor_obj) -> None:
cls.visitor_obj = visitor_obj

@staticmethod
Expand Down Expand Up @@ -128,8 +130,10 @@ def transform_gate_qubits(
f"Indexing '{qubit.name.name}' not supported in gate definition",
span=qubit.span,
)

gate_op.qubits[i] = qubit_map[qubit.name]
gate_qubit_name = qubit.name
if isinstance(gate_qubit_name, Identifier):
gate_qubit_name = gate_qubit_name.name

Check warning on line 135 in qbraid_qir/qasm3/transformer.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/transformer.py#L135

Added line #L135 was not covered by tests
gate_op.qubits[i] = qubit_map[gate_qubit_name]

@staticmethod
def transform_gate_params(
Expand All @@ -151,13 +155,12 @@ def transform_gate_params(
# TODO : update the arg value in expressions not just SINGLE identifiers

@staticmethod
def get_branch_params(condition: Any) -> tuple[Optional[int], Optional[str]]:
def get_branch_params(condition: Any) -> tuple[int, str]:
"""
Get the branch parameters from the branching condition
Args:
condition (Union[UnaryExpression, BinaryExpression, IndexExpression]): The condition
to analyze
condition (Any): The condition to analyze
Returns:
tuple[int, str]: The branch parameters
Expand All @@ -168,17 +171,35 @@ def get_branch_params(condition: Any) -> tuple[Optional[int], Optional[str]]:
condition.expression.collection.name,
)
if isinstance(condition, BinaryExpression):
return condition.lhs.index[0].value, condition.lhs.collection.name
return (
condition.lhs.index[0].value,
condition.lhs.collection.name,
)
if isinstance(condition, IndexExpression):
return condition.index[0].value, condition.collection.name
return None, None
if isinstance(condition.index, DiscreteSet):
raise_qasm3_error(

Check warning on line 180 in qbraid_qir/qasm3/transformer.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/transformer.py#L180

Added line #L180 was not covered by tests
message="DiscreteSet not supported in branching condition",
span=condition.span,
)
if isinstance(condition.index, list):
if isinstance(condition.index[0], RangeDefinition):
raise_qasm3_error(
message="RangeDefinition not supported in branching condition",
span=condition.span,
)
return (
condition.index[0].value,
condition.collection.name,
)
# default case
return -1, ""

Check warning on line 195 in qbraid_qir/qasm3/transformer.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/transformer.py#L195

Added line #L195 was not covered by tests

@classmethod
def transform_function_qubits(
cls,
q_op: Union[QuantumGate, QuantumBarrier, QuantumReset],
formal_qreg_sizes: dict[str:int],
qubit_map: dict[tuple:tuple],
formal_qreg_sizes: dict[str, int],
qubit_map: dict[tuple, tuple],
) -> list:
"""Transform the qubits of a function call to the actual qubits.
Expand Down
Loading

0 comments on commit ba1cf7e

Please sign in to comment.