Skip to content

Commit

Permalink
Attempt to add static typing to API (#34)
Browse files Browse the repository at this point in the history
* Attempt to add static typing to API

* Include py.typed

* Add typed classifier

* All methods with static typing

* Less hacky handling of exit with mypy

* Refactor `type: ignore` out
  • Loading branch information
facelessuser authored Oct 17, 2021
1 parent 0a0dbfc commit 9b0f00e
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 75 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
2 changes: 1 addition & 1 deletion .pyspelling.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ matrix:
context_visible_first: true
delimiters:
# Ignore lint (noqa) and coverage (pragma) as well as shebang (#!)
- open: '^(?: *(?:noqa\b|pragma: no cover)|!)'
- open: '^(?: *(?:noqa\b|pragma: no cover)|!|type:)'
close: '$'
# Ignore Python encoding string -*- encoding stuff -*-
- open: '^ *-\*-'
Expand Down
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
recursive-include bracex *.py
recursive-include bracex *.py py.typed
recursive-include tests *.txt *.py
recursive-include docs/src/markdown *.md *.png *.gif *.html
recursive-include docs/src/dictionary *.txt
Expand Down
129 changes: 67 additions & 62 deletions bracex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
import itertools
import math
import re
from .__meta__ import __version_info__, __version__ # noqa: F401
from typing import TypeVar, List, Iterator, Pattern, Optional, Sequence, Union, Match
from . import __meta__

String = TypeVar('String', str, bytes)

__all__ = ('expand', 'iexpand')

__version__ = __meta__.__version__
__version_info__ = __meta__.__version_info__

_alpha = [chr(x) if x != 0x5c else '' for x in range(ord('A'), ord('z') + 1)]
_nalpha = list(reversed(_alpha))

Expand All @@ -38,46 +44,43 @@ class ExpansionLimitException(Exception):
"""Brace expansion limit exception."""


def expand(string, keep_escapes=False, limit=DEFAULT_LIMIT):
def expand(string: String, keep_escapes: bool = False, limit: int = DEFAULT_LIMIT) -> List[String]:
"""Expand braces."""

return list(iexpand(string, keep_escapes, limit))


def iexpand(string, keep_escapes=False, limit=DEFAULT_LIMIT):
def iexpand(string: String, keep_escapes: bool = False, limit: int = DEFAULT_LIMIT) -> Iterator[String]:
"""Expand braces and return an iterator."""

if isinstance(string, bytes):
is_bytes = True
string = string.decode('latin-1')

for entry in ExpandBrace(keep_escapes, limit).expand(string.decode('latin-1')):
yield entry.encode('latin-1')
else:
is_bytes = False

for entry in ExpandBrace(keep_escapes, limit).expand(string):
yield entry.encode('latin-1') if is_bytes else entry
for entry in ExpandBrace(keep_escapes, limit).expand(string):
yield entry


class StringIter(object):
class StringIter:
"""Preprocess replace tokens."""

def __init__(self, string):
def __init__(self, string: str) -> None:
"""Initialize."""

self._string = string
self._index = 0

def __iter__(self):
def __iter__(self) -> "StringIter": # pragma: no cover
"""Iterate."""

return self

def __next__(self):
def __next__(self) -> str:
"""Python 3 iterator compatible next."""

return self.iternext()

def match(self, pattern):
def match(self, pattern: Pattern[str]) -> Optional[Match[str]]:
"""Perform regex match at index."""

m = pattern.match(self._string, self._index)
Expand All @@ -86,30 +89,30 @@ def match(self, pattern):
return m

@property
def index(self):
def index(self) -> int:
"""Get current index."""

return self._index

def previous(self): # pragma: no cover
def previous(self) -> str: # pragma: no cover
"""Get previous char."""

return self._string[self._index - 1]

def advance(self, count):
def advance(self, count: int) -> None:
"""Advanced the index."""

self._index += count

def rewind(self, count):
def rewind(self, count: int) -> None:
"""Rewind index."""

if count > self._index: # pragma: no cover
raise ValueError("Can't rewind past beginning!")

self._index -= count

def iternext(self):
def iternext(self) -> str:
"""Iterate through characters of the string."""

try:
Expand All @@ -121,18 +124,18 @@ def iternext(self):
return char


class ExpandBrace(object):
class ExpandBrace:
"""Expand braces like in Bash."""

def __init__(self, keep_escapes=False, limit=DEFAULT_LIMIT):
def __init__(self, keep_escapes: bool = False, limit: int = DEFAULT_LIMIT) -> None:
"""Initialize."""

self.max_limit = limit
self.count = 0
self.expanding = False
self.keep_escapes = keep_escapes

def update_count(self, count):
def update_count(self, count: Union[int, List[int]]) -> None:
"""Update the count and assert if count exceeds the max limit."""

if isinstance(count, int):
Expand All @@ -148,26 +151,26 @@ def update_count(self, count):
'Brace expansion has exceeded the limit of {:d}'.format(self.max_limit)
)

def set_expanding(self):
def set_expanding(self) -> bool:
"""Set that we are expanding a sequence, and return whether a release is required by the caller."""

status = not self.expanding
if status:
self.expanding = True
return status

def is_expanding(self):
def is_expanding(self) -> bool:
"""Get status of whether we are expanding."""

return self.expanding

def release_expanding(self, release):
def release_expanding(self, release: bool) -> None:
"""Release the expand status."""

if release:
self.expanding = False

def get_escape(self, c, i):
def get_escape(self, c: str, i: StringIter) -> str:
"""Get an escape."""

try:
Expand All @@ -176,7 +179,7 @@ def get_escape(self, c, i):
escaped = ''
return c + escaped if self.keep_escapes else escaped

def squash(self, a, b):
def squash(self, a: Union[Iterator[str], Sequence[str]], b: Union[Iterator[str], Sequence[str]]) -> Iterator[str]:
"""
Returns a generator that squashes two iterables into one.
Expand All @@ -187,30 +190,31 @@ def squash(self, a, b):

return ((''.join(x) if isinstance(x, tuple) else x) for x in itertools.product(a, b))

def get_literals(self, c, i, depth):
def get_literals(self, c: str, i: StringIter, depth: int) -> Optional[Iterator[str]]:
"""
Get a string literal.
Gather all the literal chars up to opening curly or closing brace.
Also gather chars between braces and commas within a group (is_expanding).
"""

result = ['']
result = iter([''])
is_dollar = False

count = True
seq_count = []

try:
while c:
value = [c] # type: Union[Iterator[str], List[str]]
ignore_brace = is_dollar
is_dollar = False

if c == '$':
is_dollar = True

elif c == '\\':
c = [self.get_escape(c, i)]
value = [self.get_escape(c, i)]

elif not ignore_brace and c == '{':
# Try and get the group
Expand All @@ -223,7 +227,7 @@ def get_literals(self, c, i, depth):
diff = self.count - current_count
seq_count.append(diff)
count = False
c = seq
value = seq
except StopIteration:
# Searched to end of string
# and still didn't find it.
Expand All @@ -237,7 +241,7 @@ def get_literals(self, c, i, depth):
return (x for x in result)

# Squash the current set of literals.
result = self.squash(result, [c] if isinstance(c, str) else c)
result = self.squash(result, value)

c = next(i)
except StopIteration:
Expand All @@ -247,22 +251,22 @@ def get_literals(self, c, i, depth):
self.update_count(1 if count else seq_count)
return (x for x in result)

def combine(self, a, b):
def combine(self, a: Union[Iterator[str], Sequence[str]], b: Union[Iterator[str], Sequence[str]]) -> Iterator[str]:
"""A generator that combines two iterables."""

for l in (a, b):
for x in l:
yield x

def get_sequence(self, c, i, depth):
def get_sequence(self, c: str, i: StringIter, depth: int) -> Optional[Iterator[str]]:
"""
Get the sequence.
Get sequence between `{}`, such as: `{a,b}`, `{1..2[..inc]}`, etc.
It will basically crawl to the end or find a valid series.
"""

result = []
result = [] # type: Union[List[str], Iterator[str]]
release = self.set_expanding()
has_comma = False # Used to indicate validity of group (`{1..2}` are an exception).
is_empty = True # Tracks whether the current slot is empty `{slot,slot,slot}`.
Expand All @@ -276,7 +280,7 @@ def get_sequence(self, c, i, depth):
return (x for x in item)

try:
while c:
while True:
# Bash has some special top level logic. if `}` follows `{` but hasn't matched
# a group yet, keep going except when the first 2 bytes are `{}` which gets
# completely ignored.
Expand Down Expand Up @@ -319,11 +323,12 @@ def get_sequence(self, c, i, depth):
is_empty = False

c = next(i)

except StopIteration:
self.release_expanding(release)
raise

def get_range(self, i):
def get_range(self, i: StringIter) -> Optional[Iterator[str]]:
"""
Check and retrieve range if value is a valid range.
Expand All @@ -350,7 +355,7 @@ def get_range(self, i):

return None

def format_value(self, value, padding):
def format_value(self, value: int, padding: int) -> str:
"""Get padding adjusting for negative values."""

if padding:
Expand All @@ -359,17 +364,17 @@ def format_value(self, value, padding):
else:
return str(value)

def get_int_range(self, start, end, increment=None):
def get_int_range(self, start: str, end: str, increment: Optional[str] = None) -> Iterator[str]:
"""Get an integer range between start and end and increments of increment."""

first, last = int(start), int(end)
increment = int(increment) if increment is not None else 1
inc = int(increment) if increment is not None else 1
max_length = max(len(start), len(end))

# Zero doesn't make sense as an incrementer
# but like bash, just assume one
if increment == 0:
increment = 1
if inc == 0:
inc = 1

if start[0] == '-':
start = start[1:]
Expand All @@ -384,48 +389,48 @@ def get_int_range(self, start, end, increment=None):
padding = 0

if first < last:
self.update_count(math.ceil(abs(((last + 1) - first) / increment)))
r = range(first, last + 1, -increment if increment < 0 else increment)
self.update_count(math.ceil(abs(((last + 1) - first) / inc)))
r = range(first, last + 1, -inc if inc < 0 else inc)
else:
self.update_count(math.ceil(abs(((first + 1) - last) / increment)))
r = range(first, last - 1, increment if increment < 0 else -increment)
self.update_count(math.ceil(abs(((first + 1) - last) / inc)))
r = range(first, last - 1, inc if inc < 0 else -inc)

return (self.format_value(value, padding) for value in r)

def get_char_range(self, start, end, increment=None):
def get_char_range(self, start: str, end: str, increment: Optional[str] = None) -> Iterator[str]:
"""Get a range of alphabetic characters."""

increment = int(increment) if increment else 1
if increment < 0:
increment = -increment
inc = int(increment) if increment else 1
if inc < 0:
inc = -inc

# Zero doesn't make sense as an incrementer
# but like bash, just assume one
if increment == 0:
increment = 1
if inc == 0:
inc = 1

inverse = start > end
alpha = _nalpha if inverse else _alpha

start = alpha.index(start)
end = alpha.index(end)
first = alpha.index(start)
last = alpha.index(end)

if start < end:
self.update_count(math.ceil(((end + 1) - start) / increment))
return (c for c in alpha[start:end + 1:increment])
if first < last:
self.update_count(math.ceil(((last + 1) - first) / inc))
return (c for c in alpha[first:last + 1:inc])

else:
self.update_count(math.ceil(((start + 1) - end) / increment))
return (c for c in alpha[end:start + 1:increment])
self.update_count(math.ceil(((first + 1) - last) / inc))
return (c for c in alpha[last:first + 1:inc])

def expand(self, string):
def expand(self, string: str) -> Iterator[str]:
"""Expand."""

self.expanding = False
empties = []
found_literal = False
if string:
i = iter(StringIter(string))
i = StringIter(string)
value = self.get_literals(next(i), i, 0)
if value is not None:
for x in value:
Expand Down
Loading

0 comments on commit 9b0f00e

Please sign in to comment.