Skip to content

Commit

Permalink
add expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Mar 20, 2024
1 parent fa83620 commit 95c303e
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 73 deletions.
7 changes: 0 additions & 7 deletions qbraid_qir/qasm3/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,6 @@ def qasm3_to_qir(
# Supported conversions qasm3 -> qiskit :
# https://github.com/Qiskit/qiskit-qasm3-import/blob/main/src/qiskit_qasm3_import/converter.py

# PROPOSED SEMANTIC + DECOMPOSITION PASS
# qiskit_circuit = loads(program).decompose(reps=3)
# decomposed_qasm = Exporter().dumps(qiskit_circuit)
# PROPOSED SEMANTIC + DECOMPOSITION PASS

# program = openqasm3.parse(decomposed_qasm)

program = openqasm3.parse(program)

elif not isinstance(program, openqasm3.ast.Program):
Expand Down
31 changes: 0 additions & 31 deletions qbraid_qir/qasm3/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,3 @@


# can use rigetti simulator for qir programs and qiskit for qasm3 programs


# what are some problems which we face?

# 1. let us just unfold the gates once.
# assumption - we do not define any variables inside the gate
# - gate parameters are directly passed on to other gate calls
# - no if else conditions inside the gate
# - no loops inside the gate
# - no function calls inside the gate

# eg. gate x2(a,b,c) p, q {
# h p;
# h q;
# rx(a) p;
# ry(b) q;
# rz(c) p;
# cx p, q;
# }
# qubit[2] q;
# x2(1,2,3) q[0], q[1];

# will be converted to

# qubit[2] q;
# h q[0];
# h q[1];
# rx(1) q[0];
# ry(2) q[1];
# rz(3) q[0];
# cx q[0], q[1];
51 changes: 28 additions & 23 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def _visit_measurement(self, statement: QuantumMeasurementStatement) -> None:
target = statement.target
source_id, target_id = None, None

# handle measurement operation
source_name = source.name
if isinstance(source, IndexedIdentifier):
source_name = source.name.name
Expand Down Expand Up @@ -318,7 +317,6 @@ def _build_qir_measurement(
pyqir._native.mz(self._builder, source_qubit, result)

if source_id is None and target_id is None:
# sizes should match
if self._qreg_size_map[source_name] != self._creg_size_map[target_name]:
raise ValueError(
f"Register sizes of {source_name} and {target_name} do not match for measurement operation"
Expand Down Expand Up @@ -365,7 +363,6 @@ def _visit_reset(self, statement: QuantumReset) -> None:
qreg_size = self._qreg_size_map[qreg_name]
qubit_ids = [self._qubit_labels[f"{qreg_name}_{i}"] for i in range(qreg_size)]

# generate pyqir reset equivalent
for qid in qubit_ids:
pyqir._native.reset(self._builder, pyqir.qubit(self._module.context, qid))

Expand Down Expand Up @@ -402,18 +399,23 @@ def _get_op_parameters(self, operation: QuantumGate) -> List[float]:
"""
param_list = []
for param in operation.arguments:
if not isinstance(param, (FloatLiteral, IntegerLiteral)):
raise ValueError(
f"Unsupported parameter type {type(param)} for operation {operation}"
)
param_list.append(float(param.value))
param_value = self._evaluate_expression(param)
param_list.append(param_value)

if len(param_list) > 1:
raise ValueError(f"Parameterized gate {operation} with > 1 params not supported")

return param_list

def _visit_gate_definition(self, definition: QuantumGateDefinition) -> None:
"""Visit a gate definition element.
Args:
definition (QuantumGateDefinition): The gate definition to visit.
Returns:
None
"""
gate_name = definition.name.name
if gate_name in self._custom_gates:
raise ValueError(f"Duplicate gate definition for {gate_name}")
Expand All @@ -429,9 +431,8 @@ def _visit_basic_gate_operation(self, operation: QuantumGate) -> None:
None
"""

# Currently handling the gates in the stdgates.inc file
_log.debug("Visiting basic gate operation '%s'", str(operation))
op_name = operation.name.name
op_name : str = operation.name.name
op_qubits = self._get_op_qubits(operation)
qir_func, op_qubit_count = map_qasm_op_to_pyqir_callable(op_name)
op_parameters = None
Expand Down Expand Up @@ -466,7 +467,6 @@ def _transform_gate_qubits(self, gate_op, qubit_map):
for i, qubit in enumerate(gate_op.qubits):
if isinstance(qubit, IndexedIdentifier):
raise ValueError(f"Indexing {qubit} not supported in gate definition")
# now we have an Identifier
gate_op.qubits[i] = qubit_map[qubit.name]

def _transform_gate_params(self, gate_op, param_map):
Expand All @@ -480,7 +480,6 @@ def _transform_gate_params(self, gate_op, param_map):
None
"""
for i, param in enumerate(gate_op.arguments):
# replace only if we have an Identifier
if isinstance(param, Identifier):
gate_op.arguments[i] = param_map[param.name]

Expand All @@ -493,8 +492,8 @@ def _visit_custom_gate_operation(self, operation: QuantumGate) -> None:
Returns:in computation
"""
_log.debug("Visiting custom gate operation '%s'", str(operation))
gate_name = operation.name.name
gate_definition = self._custom_gates[gate_name]
gate_name : str = operation.name.name
gate_definition : QuantumGateDefinition = self._custom_gates[gate_name]

if len(operation.arguments) != len(gate_definition.arguments):
raise ValueError(
Expand Down Expand Up @@ -588,7 +587,7 @@ def _visit_classical_operation(self, statement: ClassicalDeclaration) -> None:
else:
raise ValueError(f"Unsupported classical type {decl_type} in {statement}")

def evaluate_expression(self, expression: Any) -> bool:
def _evaluate_expression(self, expression: Any) -> bool:
"""Evaluate an expression.
Args:
Expand All @@ -598,26 +597,32 @@ def evaluate_expression(self, expression: Any) -> bool:
bool: The result of the evaluation.
"""
if isinstance(expression, (ImaginaryLiteral, DurationLiteral)):
raise ValueError(f"Unsupported expression type {type(expression)} in if condition")
raise ValueError(f"Unsupported expression type {type(expression)}")
elif isinstance(expression, Identifier):
# we need to check our scope and context to get the value of the identifier
# if it is a classical register, we can directly get the value
# how to get the value of the identifier in the QIR??
# TO DO : extend this
raise ValueError(f"Unsupported expression type {type(expression)}")
elif isinstance(expression, BooleanLiteral):
return expression.value
elif isinstance(expression, (IntegerLiteral, FloatLiteral)):
return int(expression.value)
return expression.value
elif isinstance(expression, UnaryExpression):
op = expression.op.name # can be '!', '~' or '-'
if op == "!":
return not self.evaluate_expression(expression.expression)
return not self._evaluate_expression(expression.expression)
elif op == "-":
return -1 * self.evaluate_expression(expression.expression)
return -1 * self._evaluate_expression(expression.expression)
elif op == "~":
value = self.evaluate_expression(expression.expression)
value = self._evaluate_expression(expression.expression)
if not isinstance(value, int):
raise ValueError(f"Unsupported expression type {type(value)} in ~ operation")
elif isinstance(expression, BinaryExpression):
lhs = self.evaluate_expression(expression.lhs)
lhs = self._evaluate_expression(expression.lhs)
op = expression.op.name
rhs = self.evaluate_expression(expression.rhs)
return qasm3_expression_op_map[op](lhs, rhs)
rhs = self._evaluate_expression(expression.rhs)
return qasm3_expression_op_map(op, lhs, rhs)

def _visit_branching_statement(self, statement: BranchingStatement) -> None:
"""Visit a branching statement element.
Expand Down
49 changes: 49 additions & 0 deletions tests/qasm3_qir/converter/test_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (C) 2023 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module containing unit tests for QASM3 to QIR conversion functions.
"""
import pytest

from qbraid_qir.qasm3.convert import qasm3_to_qir
from tests.qir_utils import check_attributes, check_expressions


def test_correct_expressions():
qasm_str = """OPENQASM 3;
qubit q;
// supported
rx(1.57) q;
rz(3-2*3) q;
rz(3-2*3*(8/2)) q;
rx(-1.57) q;
rx(4%2) q;
"""

result = qasm3_to_qir(qasm_str)
generated_qir = str(result).splitlines()

check_attributes(generated_qir, 1, 0)
gates = ["rx", "rz", "rz", "rx", "rx"]
expression_values = [1.57, 3 - 2 * 3, 3 - 2 * 3 * (8 / 2), -1.57, 4 % 2]
qubits = [0, 0, 0, 0, 0]
check_expressions(generated_qir, 5, gates, expression_values, qubits)


def test_incorrect_expressions():
with pytest.raises(ValueError, match=r"Unsupported expression type .*"):
qasm3_to_qir("OPENQASM 3; qubit q; rz(1 - 2 + 32im) q;")
with pytest.raises(ValueError, match=r"Unsupported expression type .* in ~ operation"):
qasm3_to_qir("OPENQASM 3; qubit q; rx(~1.3) q;")
with pytest.raises(ValueError, match=r"Unsupported expression type .* in ~ operation"):
qasm3_to_qir("OPENQASM 3; qubit q; rx(~1.3+5im) q;")
2 changes: 1 addition & 1 deletion tests/qasm3_qir/converter/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_incorrect_single_qubit_gates():

# Invalid use of variables in gate application

with pytest.raises(ValueError, match=r"Unsupported parameter type .* for operation .*"):
with pytest.raises(ValueError, match=r"Unsupported expression type .*"):
_ = qasm3_to_qir(
"""
OPENQASM 3;
Expand Down
3 changes: 1 addition & 2 deletions tests/qasm3_qir/fixtures/resources/custom_gate_complex.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@ gate custom3(p, q) c, d {

qubit[2] q;

custom3(0.1, 0.2) q[0], q[1];
custom3(0.3, 0.4) q[0:];
custom3(0.1, 0.2) q[0:];
4 changes: 2 additions & 2 deletions tests/qasm3_qir/fixtures/resources/custom_gate_nested.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ gate custom(a,b,c) p, q{
h p;
cx p,q;
rx(a) q;
ry(0.5) q;
ry(0.5/0.1) q;
}

qubit[2] q;
custom(0.1, 0.2, 0.3) q[0], q[1];
custom(2 + 3 - 1/5, 0.1, 0.3) q[0], q[1];
2 changes: 1 addition & 1 deletion tests/qasm3_qir/fixtures/resources/custom_gate_simple.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ gate custom(a) p, q {
}

qubit[2] q;
custom(0.1) q[0], q[1];
custom(0.1+1) q[0], q[1];
48 changes: 42 additions & 6 deletions tests/qir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _validate_simple_custom_op(entry_body: List[str]):
initialize_call_string(),
single_op_call_string("h", 0),
single_op_call_string("z", 1),
rotation_call_string("rx", 0.1, 0),
rotation_call_string("rx", 1.1, 0),
double_op_call_string("cnot", 0, 1),
result_record_output_string(0),
result_record_output_string(1),
Expand All @@ -322,11 +322,11 @@ def _validate_nested_custom_op(entry_body: List[str]):
nested_op_lines = [
initialize_call_string(),
single_op_call_string("h", 1),
rotation_call_string("rz", 0.1, 1),
rotation_call_string("rz", 4.8, 1),
single_op_call_string("h", 0),
double_op_call_string("cnot", 0, 1),
rotation_call_string("rx", 0.1, 1),
rotation_call_string("ry", 0.5, 1),
rotation_call_string("rx", 4.8, 1),
rotation_call_string("ry", 5, 1),
result_record_output_string(0),
result_record_output_string(1),
return_string(),
Expand All @@ -338,8 +338,22 @@ def _validate_nested_custom_op(entry_body: List[str]):


def _validate_complex_custom_op(entry_body: List[str]):
pass
# todo...
complex_op_lines = [
initialize_call_string(),
single_op_call_string("h", 0),
single_op_call_string("x", 0),
rotation_call_string("rx", 0.5, 0),
rotation_call_string("ry", 0.1, 0),
rotation_call_string("rz", 0.2, 0),
double_op_call_string("cnot", 0, 1),
result_record_output_string(0),
result_record_output_string(1),
return_string(),
]

assert len(entry_body) == len(complex_op_lines), "Incorrect number of lines in complex op"
for i in range(len(entry_body)):
assert entry_body[i].strip() == complex_op_lines[i].strip(), "Incorrect complex op line"


def check_custom_qasm_gate_op(qir: List[str], test_type: str):
Expand All @@ -353,3 +367,25 @@ def check_custom_qasm_gate_op(qir: List[str], test_type: str):
_validate_complex_custom_op(entry_body)
else:
assert False, f"Unknown test type {test_type} for custom ops"


def check_expressions(
qir: List[str], expected_ops: int, gates: List[str], expression_values, qubits: List[int]
):
entry_body = get_entry_point_body(qir)
op_count = 0
q_id = 0

for line in entry_body:
if line.strip().startswith("call") and "qis__" in line:
assert line.strip() == rotation_call_string(
gates[q_id], expression_values[q_id], qubits[q_id]
), f"Incorrect rotation gate call in qir - {line}"
op_count += 1
q_id += 1

if op_count == expected_ops:
break

if op_count != expected_ops:
assert False, f"Incorrect rotation gate count: {expected_ops} expected, {op_count} actual"

0 comments on commit 95c303e

Please sign in to comment.