Skip to content

Commit

Permalink
complete support for arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Sep 2, 2024
1 parent d34df1a commit e42befd
Show file tree
Hide file tree
Showing 11 changed files with 678 additions and 435 deletions.
75 changes: 40 additions & 35 deletions qbraid_qir/qasm3/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
Module with analysis functions for QASM3 visitor
"""
# pylint: disable=cyclic-import, import-outside-toplevel

from typing import Any, Optional, Union

import numpy as np
from openqasm3.ast import (
BinaryExpression,
DiscreteSet,
Expression,
Identifier,
IndexExpression,
IntegerLiteral,
IntType,
RangeDefinition,
UnaryExpression,
)
Expand All @@ -31,13 +36,13 @@
class Qasm3Analyzer:
"""Class with utility functions for analyzing QASM3 elements"""

@staticmethod
def analyze_classical_indices(indices: list[Any], var: Variable) -> list:
@classmethod
def analyze_classical_indices(cls, indices: list[Any], var: Variable) -> list:
"""Validate the indices for a classical variable.
Args:
indices (list[list[Any]]): The indices to validate.
var_name (Variable): The variable to verify
var (Variable): The variable to verify
Raises:
Qasm3ConversionError: If the indices are invalid.
Expand All @@ -47,12 +52,11 @@ def analyze_classical_indices(indices: list[Any], var: Variable) -> list:
a list if the variable is a multi-dimensional array.
"""
indices_list = []
var_name = var.name
var_dimensions: Optional[list[int]] = var.dims

if var_dimensions is None or len(var_dimensions) == 0:
raise_qasm3_error(
message=f"Indexing error. Variable {var_name} is not an array",
message=f"Indexing error. Variable {var.name} is not an array",
err_type=Qasm3ConversionError,
span=indices[0].span,
)
Expand All @@ -61,7 +65,7 @@ def analyze_classical_indices(indices: list[Any], var: Variable) -> list:

if len(indices) != len(var_dimensions): # type: ignore[arg-type]
raise_qasm3_error(
message=f"Invalid number of indices for variable {var_name}. "
message=f"Invalid number of indices for variable {var.name}. "
f"Expected {len(var_dimensions)} but got {len(indices)}", # type: ignore[arg-type]
err_type=Qasm3ConversionError,
span=indices[0].span,
Expand All @@ -86,40 +90,44 @@ def _validate_step(start_id, end_id, step, span):
span=span,
)

from .expressions import Qasm3ExprEvaluator

for i, index in enumerate(indices):
if not isinstance(index, (RangeDefinition, IntegerLiteral)):
if not isinstance(index, (Identifier, Expression, RangeDefinition, IntegerLiteral)):
raise_qasm3_error(
message=f"Unsupported index type {type(index)} for "
f"classical variable {var_name}",
f"classical variable {var.name}",
err_type=Qasm3ConversionError,
span=index.span,
)

if isinstance(index, RangeDefinition):
range_def = index
# TODO : add support for identifiers here
assert var_dimensions is not None

start_id = (
range_def.start.value if isinstance(range_def.start, IntegerLiteral) else 0
)
end_id = (
range_def.end.value
if isinstance(range_def.end, IntegerLiteral)
else var_dimensions[i] - 1
)
step = range_def.step.value if isinstance(range_def.step, IntegerLiteral) else 1
start_id = 0
if index.start is not None:
start_id = Qasm3ExprEvaluator.evaluate_expression(
index.start, reqd_type=IntType
)

end_id = var_dimensions[i] - 1
if index.end is not None:
end_id = Qasm3ExprEvaluator.evaluate_expression(index.end, reqd_type=IntType)

_validate_index(start_id, var_dimensions[i], var_name, range_def.span, i)
_validate_index(end_id, var_dimensions[i], var_name, range_def.span, i)
_validate_step(start_id, end_id, step, range_def.span)
step = 1
if index.step is not None:
step = Qasm3ExprEvaluator.evaluate_expression(index.step, reqd_type=IntType)

Check warning on line 119 in qbraid_qir/qasm3/analyzer.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/analyzer.py#L119

Added line #L119 was not covered by tests

_validate_index(start_id, var_dimensions[i], var.name, index.span, i)
_validate_index(end_id, var_dimensions[i], var.name, index.span, i)
_validate_step(start_id, end_id, step, index.span)

indices_list.append((start_id, end_id, step))

if isinstance(index, IntegerLiteral):
index_value = index.value
if isinstance(index, (Identifier, IntegerLiteral, Expression)):
index_value = Qasm3ExprEvaluator.evaluate_expression(index, reqd_type=IntType)
curr_dimension = var_dimensions[i] # type: ignore[index]
_validate_index(index_value, curr_dimension, var_name, index.span, i)
_validate_index(index_value, curr_dimension, var.name, index.span, i)

indices_list.append((index_value, index_value, 1))

Expand Down Expand Up @@ -161,23 +169,20 @@ def analyze_index_expression(
return var_name, indices

@staticmethod
def find_array_element(multi_dim_arr: list[Any], indices: list[tuple[int, int, int]]) -> Any:
def find_array_element(multi_dim_arr: np.ndarray, indices: list[tuple[int, int, int]]) -> Any:
"""Find the value of an array at the specified indices.
Args:
multi_dim_arr (list): The multi-dimensional list to search.
indices (list[int]): The indices to search.
multi_dim_arr (np.ndarray): The multi-dimensional list to search.
indices (list[tuple[int,int,int]]): The indices to search.
Returns:
Any: The value at the specified indices.
"""
temp = multi_dim_arr
for index_element in indices:
start, end, step = index_element
temp = temp[start : end + 1 : step]
if start == end: # if literal index, we decrease the dimension by 1
temp = temp[0]
return temp
slicing = tuple(
slice(start, end + 1, step) if start != end else start for start, end, step in indices
)
return multi_dim_arr[slicing] # type: ignore[index]

@staticmethod
def analyse_branch_condition(condition: Any) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion qbraid_qir/qasm3/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from enum import Enum
from typing import Optional, Union

import numpy as np
from openqasm3.ast import BitType, ClassicalDeclaration, Program, QubitDeclaration, Statement
from pyqir import Context as qirContext
from pyqir import Module
Expand Down Expand Up @@ -60,6 +61,7 @@ class Variable:
dims (list[int]): Dimensions of the variable.
value (Optional[Union[int, float, list]]): Value of the variable.
is_constant (bool): Flag indicating if the variable is constant.
readonly(bool): Flag indicating if the variable is readonly.
"""

Expand All @@ -70,15 +72,17 @@ def __init__(
base_type,
base_size: int,
dims: Optional[list[int]] = None,
value: Optional[Union[int, float, list]] = None,
value: Optional[Union[int, float, np.ndarray]] = None,
is_constant: bool = False,
readonly: bool = False,
):
self.name = name
self.base_type = base_type
self.base_size = base_size
self.dims = dims
self.value = value
self.is_constant = is_constant
self.readonly = readonly


class _ProgramElement(metaclass=ABCMeta):
Expand Down
31 changes: 16 additions & 15 deletions qbraid_qir/qasm3/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,24 +667,25 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value):
ComplexType: complex,
# AngleType: None, # not sure
}
VARIABLE_TYPE_STR = {
BitType: "bit",
IntType: "int",
UintType: "uint",
BoolType: "bool",
FloatType: "float",
ComplexType: "complex",
AngleType: "angle",
}

# Reference: https://openqasm.com/language/types.html#allowed-casts
VARIABLE_TYPE_CAST_MAP = {
BoolType: [int, float, bool],
IntType: [bool, int, float],
BitType: [bool, int],
UintType: [bool, int, float],
FloatType: [bool, int, float],
AngleType: [float],
BoolType: (int, float, bool, np.int64, np.float64, np.bool_),
IntType: (bool, int, float, np.int64, np.float64, np.bool_),
BitType: (bool, int, np.int64, np.bool_),
UintType: (bool, int, float, np.int64, np.uint64, np.float64, np.bool_),
FloatType: (bool, int, float, np.int64, np.float64, np.bool_),
AngleType: (float, np.float64),
}

ARRAY_TYPE_MAP = {
BitType: np.bool_,
IntType: np.int64,
UintType: np.uint64,
FloatType: np.float64,
ComplexType: np.complex128,
BoolType: np.bool_,
AngleType: np.float64,
}


Expand Down
Loading

0 comments on commit e42befd

Please sign in to comment.