Skip to content

Commit

Permalink
dialects: (stim) add qubit attribute and qubit coordinate attribute (#…
Browse files Browse the repository at this point in the history
…3114)

Add a base attribute to the stim dialect - StimAttr which is
StimPrintable so that all implementing attributes must implement
print_stim.

Add two initial attributes - QubitAttr to indicate qubits, and
QubitCoordAttr which provides a pair of a coordinate of a physical qubit
to a QubitAttr

---------

Co-authored-by: Emilien Bauer <papychacal@gmail.com>
Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
3 people authored Sep 30, 2024
1 parent d9da45c commit 34d3d11
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 15 deletions.
27 changes: 26 additions & 1 deletion tests/dialects/stim/test_stim_printer_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@
import pytest

from xdsl.dialects import stim
from xdsl.dialects.stim.stim_printer_parser import StimPrinter
from xdsl.dialects.stim.ops import QubitAttr, QubitMappingAttr
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.dialects.test import TestOp
from xdsl.ir import Block, Region

################################################################################
# Utils for this test file #
################################################################################


def check_stim_print(program: StimPrintable, expected_stim: str):
res_io = StringIO()
printer = StimPrinter(stream=res_io)
program.print_stim(printer)
assert expected_stim == res_io.getvalue()


def test_empty_circuit():
empty_block = Block()
Expand All @@ -27,3 +39,16 @@ def test_stim_circuit_ops_stim_printable():
printer = StimPrinter(stream=res_io)

module.print_stim(printer)


def test_print_stim_qubit_attr():
qubit = QubitAttr(0)
expected_stim = "0"
check_stim_print(qubit, expected_stim)


def test_print_stim_qubit_coord_attr():
qubit = QubitAttr(0)
qubit_coord = QubitMappingAttr([0, 0], qubit)
expected_stim = "(0, 0) 0"
check_stim_print(qubit_coord, expected_stim)
13 changes: 13 additions & 0 deletions tests/filecheck/dialects/stim/attrs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: XDSL_ROUNDTRIP

"test.op"() {
qubit = !stim.qubit<0>,
qubitcoord = #stim.qubit_coord<(0,0), !stim.qubit<0>>
} : () -> ()

%qubit0 = "test.op"() : () -> (!stim.qubit<0>)

// CHECK: builtin.module {
// CHECK-NEXT: "test.op"() {"qubit" = !stim.qubit<0>, "qubitcoord" = #stim.qubit_coord<(0, 0), !stim.qubit<0>>} : () -> ()
// CHECK-NEXT: %qubit0 = "test.op"() : () -> !stim.qubit<0>
// CHECK-NEXT: }
6 changes: 5 additions & 1 deletion xdsl/dialects/stim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from xdsl.ir import Dialect

from .ops import StimCircuitOp
from .ops import QubitAttr, QubitMappingAttr, StimCircuitOp

Stim = Dialect(
"stim",
[
StimCircuitOp,
],
[
QubitAttr,
QubitMappingAttr,
],
)
103 changes: 95 additions & 8 deletions xdsl/dialects/stim/ops.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,111 @@
from abc import ABC
from collections.abc import Sequence
from io import StringIO

from xdsl.ir import Region
from xdsl.dialects.builtin import ArrayAttr, IntAttr
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.ir import ParametrizedAttribute, Region, TypeAttribute
from xdsl.irdl import (
IRDLOperation,
PyRDLOpDefinitionError,
ParameterDef,
irdl_attr_definition,
irdl_op_definition,
region_def,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer

from .stim_printer_parser import StimPrintable, StimPrinter

@irdl_attr_definition
class QubitAttr(StimPrintable, ParametrizedAttribute, TypeAttribute):
"""
Type for a single qubit.
"""

name = "stim.qubit"

qubit: ParameterDef[IntAttr]

def __init__(self, qubit: int | IntAttr) -> None:
if not isinstance(qubit, IntAttr):
qubit = IntAttr(qubit)
super().__init__(parameters=[qubit])

@classmethod
def parse_parameters(cls, parser: AttrParser) -> Sequence[IntAttr]:
with parser.in_angle_brackets():
qubit = parser.parse_integer(allow_negative=False, allow_boolean=False)
return (IntAttr(qubit),)

def print_parameters(self, printer: Printer) -> None:
with printer.in_angle_brackets():
printer.print(self.qubit.data)

def print_stim(self, printer: StimPrinter):
printer.print_string(f"{self.qubit.data}")


@irdl_attr_definition
class QubitMappingAttr(StimPrintable, ParametrizedAttribute):
"""
This attribute provides a way to indicate the required connectivity or layout of `physical` qubits.
It consists of two parameters:
1. A co-ordinate array (currently it only anticipates a pair of qubits, but this is not fixed)
2. A value associated with a qubit referred to in a circuit.
The co-ordinates may be used as a physical address of a qubit, or the relative address with respect to some known physical address.
Operations that attach this as a property may represent the lattice-like structure of a physical quantum computer by having a property with an ArrayAttr[QubitCoordsAttr].
"""

name = "stim.qubit_coord"

class StimOp(IRDLOperation, ABC):
def print_stim(self, printer: StimPrinter) -> None:
raise (PyRDLOpDefinitionError("print_stim not implemented!"))
coords: ParameterDef[ArrayAttr[IntAttr]]
qubit_name: ParameterDef[QubitAttr]

def __init__(
self, coords: list[int] | ArrayAttr[IntAttr], qubit_name: int | QubitAttr
) -> None:
if not isinstance(qubit_name, QubitAttr):
qubit_name = QubitAttr(qubit_name)
if not isinstance(coords, ArrayAttr):
coords = ArrayAttr(IntAttr(c) for c in coords)
super().__init__(parameters=[coords, qubit_name])

@classmethod
def parse_parameters(
cls, parser: AttrParser
) -> tuple[ArrayAttr[IntAttr], QubitAttr]:
parser.parse_punctuation("<")
coords = parser.parse_comma_separated_list(
delimiter=parser.Delimiter.PAREN,
parse=lambda: IntAttr(parser.parse_integer(allow_boolean=False)),
)
parser.parse_punctuation(",")
qubit = parser.parse_attribute()
if not isinstance(qubit, QubitAttr):
parser.raise_error("Expected qubit attr", at_position=parser.pos)
parser.parse_punctuation(">")
return (ArrayAttr(coords), qubit)

def print_parameters(self, printer: Printer) -> None:
with printer.in_angle_brackets():
printer.print("(")
for i, elem in enumerate(self.coords):
if i:
printer.print_string(", ")
printer.print(elem.data)
printer.print("), ")
printer.print(self.qubit_name)

def print_stim(self, printer: StimPrinter):
printer.print_attribute(self.coords)
printer.print_string(" ")
self.qubit_name.print_stim(printer)


@irdl_op_definition
class StimCircuitOp(StimOp, IRDLOperation):
class StimCircuitOp(StimPrintable, IRDLOperation):
"""
Base operation containing a stim program
"""
Expand Down
39 changes: 34 additions & 5 deletions xdsl/dialects/stim/stim_printer_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import abc
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypeVar, cast

from xdsl.dialects.builtin import ArrayAttr, IntAttr
from xdsl.ir import Attribute


@dataclass(eq=False, repr=False)
Expand All @@ -14,10 +18,35 @@ def print_string(self, text: str) -> None:
@contextmanager
def in_braces(self):
self.print_string("{")
try:
yield
finally:
self.print_string("}")
yield
self.print_string("}")

@contextmanager
def in_parens(self):
self.print_string("(")
yield
self.print_string(")")

T = TypeVar("T")

def print_list(
self, elems: Iterable[T], print_fn: Callable[[T], Any], delimiter: str = ", "
) -> None:
for i, elem in enumerate(elems):
if i:
self.print_string(delimiter)
print_fn(elem)

def print_attribute(self, attribute: Attribute) -> None:
if isinstance(attribute, ArrayAttr):
attribute = cast(ArrayAttr[Attribute], attribute)
with self.in_parens():
self.print_list(attribute, self.print_attribute)
return
if isinstance(attribute, IntAttr):
self.print_string(f"{attribute.data}")
return
raise ValueError(f"Cannot print in stim format: {attribute}")


class StimPrintable(abc.ABC):
Expand Down

0 comments on commit 34d3d11

Please sign in to comment.