Skip to content

Commit

Permalink
api: move PunctuationSpelling to module-level (#2894)
Browse files Browse the repository at this point in the history
Turns out TypeAliases aren't allowed as ClassVars

Pylance errors 672 -> 306
  • Loading branch information
superlopuh authored Jul 17, 2024
1 parent f4c1820 commit 63ad0db
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 42 deletions.
14 changes: 7 additions & 7 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.utils.exceptions import ParseError, VerifyException
from xdsl.utils.lexer import Token
from xdsl.utils.lexer import PunctuationSpelling, Token
from xdsl.utils.str_enum import StrEnum

# pyright: reportPrivateUsage=false
Expand Down Expand Up @@ -599,7 +599,7 @@ def test_is_punctuation_false(punctuation: Token.Kind):
"punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().values())
)
def test_is_spelling_of_punctuation_true(punctuation: Token.Kind):
value = cast(Token.PunctuationSpelling, punctuation.value)
value = cast(PunctuationSpelling, punctuation.value)
assert Token.Kind.is_spelling_of_punctuation(value)


Expand All @@ -612,14 +612,14 @@ def test_is_spelling_of_punctuation_false(punctuation: str):
"punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().values())
)
def test_get_punctuation_kind(punctuation: Token.Kind):
value = cast(Token.PunctuationSpelling, punctuation.value)
value = cast(PunctuationSpelling, punctuation.value)
assert punctuation.get_punctuation_kind_from_spelling(value) == punctuation


@pytest.mark.parametrize(
"punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())
)
def test_parse_punctuation(punctuation: Token.PunctuationSpelling):
def test_parse_punctuation(punctuation: PunctuationSpelling):
parser = Parser(MLContext(), punctuation)

res = parser.parse_punctuation(punctuation)
Expand All @@ -630,7 +630,7 @@ def test_parse_punctuation(punctuation: Token.PunctuationSpelling):
@pytest.mark.parametrize(
"punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())
)
def test_parse_punctuation_fail(punctuation: Token.PunctuationSpelling):
def test_parse_punctuation_fail(punctuation: PunctuationSpelling):
parser = Parser(MLContext(), "e +")
with pytest.raises(ParseError) as e:
parser.parse_punctuation(punctuation, " in test")
Expand All @@ -641,7 +641,7 @@ def test_parse_punctuation_fail(punctuation: Token.PunctuationSpelling):
@pytest.mark.parametrize(
"punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())
)
def test_parse_optional_punctuation(punctuation: Token.PunctuationSpelling):
def test_parse_optional_punctuation(punctuation: PunctuationSpelling):
parser = Parser(MLContext(), punctuation)
res = parser.parse_optional_punctuation(punctuation)
assert res == punctuation
Expand All @@ -651,7 +651,7 @@ def test_parse_optional_punctuation(punctuation: Token.PunctuationSpelling):
@pytest.mark.parametrize(
"punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())
)
def test_parse_optional_punctuation_fail(punctuation: Token.PunctuationSpelling):
def test_parse_optional_punctuation_fail(punctuation: PunctuationSpelling):
parser = Parser(MLContext(), "e +")
assert parser.parse_optional_punctuation(punctuation) is None

Expand Down
4 changes: 2 additions & 2 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from xdsl.printer import Printer
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.lexer import Token
from xdsl.utils.lexer import PunctuationSpelling

OperandOrResult = Literal[VarIRConstruct.OPERAND, VarIRConstruct.RESULT]

Expand Down Expand Up @@ -835,7 +835,7 @@ class PunctuationDirective(OptionallyParsableDirective):
additionally neither `<`, `(`, `}`, `]`, if the last element was not a punctuation.
"""

punctuation: Token.PunctuationSpelling
punctuation: PunctuationSpelling
"""The punctuation that should be printed/parsed."""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
Expand Down
21 changes: 14 additions & 7 deletions xdsl/parser/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from typing import NoReturn, TypeVar, overload

from xdsl.utils.exceptions import ParseError
from xdsl.utils.lexer import Lexer, Position, Span, StringLiteral, Token
from xdsl.utils.lexer import (
Lexer,
Position,
PunctuationSpelling,
Span,
StringLiteral,
Token,
)
from xdsl.utils.str_enum import StrEnum


Expand Down Expand Up @@ -505,11 +512,11 @@ def parse_keyword(self, keyword: str, context_msg: str = "") -> str:
self.raise_error(error_msg)

def parse_optional_punctuation(
self, punctuation: Token.PunctuationSpelling
) -> Token.PunctuationSpelling | None:
self, punctuation: PunctuationSpelling
) -> PunctuationSpelling | None:
"""
Parse a punctuation, if it is present. Otherwise, return None.
Punctuations are defined by `Token.PunctuationSpelling`.
Punctuations are defined by `PunctuationSpelling`.
"""
# This check is only necessary to catch errors made by users that
# are not using pyright.
Expand All @@ -522,11 +529,11 @@ def parse_optional_punctuation(
return None

def parse_punctuation(
self, punctuation: Token.PunctuationSpelling, context_msg: str = ""
) -> Token.PunctuationSpelling:
self, punctuation: PunctuationSpelling, context_msg: str = ""
) -> PunctuationSpelling:
"""
Parse a punctuation. Punctuations are defined by
`Token.PunctuationSpelling`.
`PunctuationSpelling`.
"""
# This check is only necessary to catch errors made by users that
# are not using pyright.
Expand Down
53 changes: 27 additions & 26 deletions xdsl/utils/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from io import StringIO
from string import hexdigits
from typing import ClassVar, Literal, TypeAlias, TypeGuard, cast, overload
from typing import Literal, TypeAlias, TypeGuard, cast, overload

from xdsl.utils.exceptions import ParseError

Expand Down Expand Up @@ -202,6 +202,30 @@ def bytes_contents(self) -> bytes:
return bytes(bytes_contents)


PunctuationSpelling: TypeAlias = Literal[
"->",
":",
",",
"...",
"=",
">",
"{",
"(",
"[",
"<",
"-",
"+",
"?",
"}",
")",
"]",
"*",
"|",
"{-#",
"#-}",
]


@dataclass
class Token:
class Kind(Enum):
Expand Down Expand Up @@ -282,43 +306,20 @@ def is_punctuation(self) -> bool:
@staticmethod
def is_spelling_of_punctuation(
spelling: str,
) -> TypeGuard[Token.PunctuationSpelling]:
) -> TypeGuard[PunctuationSpelling]:
punctuation_dict = Token.Kind.get_punctuation_spelling_to_kind_dict()
return spelling in punctuation_dict.keys()

@staticmethod
def get_punctuation_kind_from_spelling(
spelling: Token.PunctuationSpelling,
spelling: PunctuationSpelling,
) -> Token.Kind:
assert Token.Kind.is_spelling_of_punctuation(spelling), (
"Kind.get_punctuation_kind_from_spelling: spelling is not a "
"valid punctuation spelling!"
)
return Token.Kind.get_punctuation_spelling_to_kind_dict()[spelling]

PunctuationSpelling: ClassVar[TypeAlias] = Literal[
"->",
":",
",",
"...",
"=",
">",
"{",
"(",
"[",
"<",
"-",
"+",
"?",
"}",
")",
"]",
"*",
"|",
"{-#",
"#-}",
]

kind: Kind

span: Span
Expand Down

0 comments on commit 63ad0db

Please sign in to comment.