Skip to content

Commit

Permalink
fix: refactor codegen to use a statement tree
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Dec 7, 2023
1 parent ed97089 commit 53a4ba9
Showing 1 changed file with 47 additions and 32 deletions.
79 changes: 47 additions & 32 deletions bolt/codegen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"Codegen",
"Accumulator",
"CodegenStatement",
"ChildrenCollector",
"CommandCollector",
"RootCommandCollector",
Expand Down Expand Up @@ -89,18 +90,34 @@
from .module import CodegenResult, MacroLibrary


@dataclass(slots=True)
class CodegenStatement:
"""Python statement emitted by the codegen, which can recursively contain other statements."""

code: str
lineno: Optional[int] = None
children: List["CodegenStatement"] = field(default_factory=list)

def flatten(self, indent: str = "") -> Iterable[Tuple[str, Optional[int]]]:
"""Yield the indented statements with their associated line number."""
yield f"{indent}{self.code}", self.lineno
if self.children:
indent += 4 * " "
for child_statement in self.children:
yield from child_statement.flatten(indent)


@dataclass
class Accumulator:
"""Utility for generating python code."""

indentation: str = ""
refs: List[Any] = field(default_factory=list)
dependencies: Set[str] = field(default_factory=set)
prelude_imports: List[AstPrelude] = field(default_factory=list)
macros: MacroLibrary = field(default_factory=dict)
macro_ids: Dict[str, int] = field(default_factory=dict)
memo_index: Dict[AstMemo, int] = field(default_factory=dict)
lines: List[str] = field(default_factory=list)
statements: List["CodegenStatement"] = field(default_factory=list)
counter: int = 0
header: Dict[str, str] = field(default_factory=dict)
root_scope: bool = True
Expand All @@ -111,23 +128,21 @@ class Accumulator:

def get_source(self) -> str:
"""Return the source code."""
header = "".join(
f"{variable} = {expression}\n"
header = [
CodegenStatement(f"{variable} = {expression}")
for variable, expression in self.header.items()
)
]

lines: List[str] = ["_bolt_lineno = "]
numbers1: List[int] = [1]
numbers2: List[int] = [1]

for line in (header + "".join(self.lines)).splitlines():
if line.startswith("!lineno "):
current_line = int(line[8:])
if numbers2[-1] != current_line:
for statement in header + self.statements:
for code, lineno in statement.flatten():
if lineno and numbers2[-1] != lineno:
numbers1.append(len(lines) + 1)
numbers2.append(current_line)
else:
lines.append(line)
numbers2.append(lineno)
lines.append(code)

lines[0] += f"{numbers1}, {numbers2}"

Expand Down Expand Up @@ -208,27 +223,27 @@ def get_macro(self, name: str) -> str:
self.macro_ids[name] = len(self.macro_ids)
return f"_bolt_macro{self.macro_ids[name]}"

def lineno(self, lineno: Any):
"""Emit line number."""
def extract_lineno(self, lineno: Any):
"""Utility to extract the line number."""
if isinstance(lineno, AstNode) and not lineno.location.unknown:
lineno = lineno.location.lineno
if isinstance(lineno, int):
self.lines.append(f"!lineno {lineno}\n")
return lineno
return None

@contextmanager
def block(self):
"""Wrap statements in an indented block."""
previous_indentation = self.indentation
self.indentation += " "
previous_statements = self.statements
self.statements = self.statements[-1].children
try:
yield
finally:
self.indentation = previous_indentation
self.statements = previous_statements

def statement(self, code: str, *, lineno: Any = None):
"""Emit statement."""
self.lineno(lineno)
self.lines.append(f"{self.indentation}{code}\n")
self.statements.append(CodegenStatement(code, self.extract_lineno(lineno)))

@contextmanager
def function(self, name: str, *args: str, return_type: str = ""):
Expand Down Expand Up @@ -262,13 +277,13 @@ def else_statement(self):
with self.if_statement(self.condition_inverse):
yield

def enclose(self, code: str, from_index: int):
"""Enclose lines starting from the given index."""
self.lines[from_index:] = [
line if line.startswith("!") else f" {line}"
for line in self.lines[from_index:]
def enclose(self, code: str, from_index: int, *, lineno: Any = None):
"""Enclose statements starting from the given index."""
self.statements[from_index:] = [
CodegenStatement(
code, self.extract_lineno(lineno), self.statements[from_index:]
)
]
self.lines.insert(from_index, f"{self.indentation}{code}\n")


@dataclass
Expand Down Expand Up @@ -364,7 +379,7 @@ def visit_multiple(
"""Yield all the nodes and return a single result pointing to the new children."""
current_count = 0
collector: Optional[ChildrenCollector] = None
index = len(acc.lines)
index = len(acc.statements)

previous_siblings = acc.current_siblings
previous_sibling_index = acc.current_sibling_index
Expand All @@ -379,14 +394,14 @@ def visit_multiple(
if not collector:
collector = children_collector(acc, index)

lines = acc.lines[index:]
del acc.lines[index:]
statements = acc.statements[index:]
del acc.statements[index:]
collector.add_static(*children[current_count:i])
acc.lines.extend(lines)
acc.statements.extend(statements)
collector.add_dynamic(*result)

current_count = i + 1
index = len(acc.lines)
index = len(acc.statements)

acc.current_siblings = previous_siblings
acc.current_sibling_index = previous_sibling_index
Expand Down Expand Up @@ -536,7 +551,7 @@ def command(
return [acc.replace(acc.make_ref(node), arguments=arguments)]

arguments = yield from visit_multiple(node.arguments[:-1], acc)
nesting_index = len(acc.lines)
nesting_index = len(acc.statements)
nesting = yield from visit_single(node.arguments[-1])

if nesting is None:
Expand Down

0 comments on commit 53a4ba9

Please sign in to comment.