Skip to content

Commit

Permalink
refactor!: Stop inlining array.__getitem__ and arrary.__setitem__ (#799)
Browse files Browse the repository at this point in the history
Closes #786

BREAKING CHANGE: `CompiledGlobals` renamed to `CompilerContext`
  • Loading branch information
tatiana-s authored Feb 13, 2025
1 parent 08dcae8 commit bb199a0
Show file tree
Hide file tree
Showing 20 changed files with 364 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,4 @@ devenv.local.nix
.pre-commit-config.yaml
/.envrc

**/.benchmarks/
**/.benchmarks/
12 changes: 6 additions & 6 deletions guppylang/compiler/cfg_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG, Row, Signature
from guppylang.checker.core import Place, Variable
from guppylang.compiler.core import (
CompiledGlobals,
CompilerContext,
DFContainer,
is_return_var,
return_var,
Expand All @@ -24,7 +24,7 @@ def compile_cfg(
cfg: CheckedCFG[Place],
container: DfBase[DP],
inputs: Sequence[Wire],
globals: CompiledGlobals,
ctx: CompilerContext,
) -> hc.Cfg:
"""Compiles a CFG to Hugr."""
# Patch the CFG with dummy return variables
Expand Down Expand Up @@ -54,7 +54,7 @@ def compile_cfg(

blocks: dict[CheckedBB[Place], ToNode] = {}
for bb in cfg.bbs:
blocks[bb] = compile_bb(bb, builder, bb == cfg.entry_bb, globals)
blocks[bb] = compile_bb(bb, builder, bb == cfg.entry_bb, ctx)
for bb in cfg.bbs:
for i, succ in enumerate(bb.successors):
builder.branch(blocks[bb][i], blocks[succ])
Expand All @@ -66,7 +66,7 @@ def compile_bb(
bb: CheckedBB[Place],
builder: hc.Cfg,
is_entry: bool,
globals: CompiledGlobals,
ctx: CompilerContext,
) -> ToNode:
"""Compiles a single basic block to Hugr.
Expand Down Expand Up @@ -94,12 +94,12 @@ def compile_bb(
dfg = DFContainer(block)
for v, wire in zip(inputs, block.input_node, strict=True):
dfg[v] = wire
dfg = StmtCompiler(globals).compile_stmts(bb.statements, dfg)
dfg = StmtCompiler(ctx).compile_stmts(bb.statements, dfg)

# If we branch, we also have to compile the branch predicate
if len(bb.successors) > 1:
assert bb.branch_pred is not None
branch_port = ExprCompiler(globals).compile(bb.branch_pred, dfg)
branch_port = ExprCompiler(ctx).compile(bb.branch_pred, dfg)
else:
# Even if we don't branch, we still have to add a `Sum(())` predicates
branch_port = dfg.builder.add_op(ops.Tag(0, ht.UnitSum(1)))
Expand Down
50 changes: 46 additions & 4 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import itertools
from abc import ABC
from dataclasses import dataclass, field
from typing import cast

from hugr import Wire, ops
from hugr import tys as ht
from hugr.build import function as hf
from hugr.build.dfg import DP, DefinitionBuilder, DfBase

from guppylang.checker.core import FieldAccess, Globals, Place, PlaceId, Variable
Expand All @@ -15,7 +18,23 @@
CompiledLocals = dict[PlaceId, Wire]


class CompiledGlobals:
@dataclass(frozen=True)
class GlobalConstId:
id: int
base_name: str

_fresh_ids = itertools.count()

@staticmethod
def fresh(base_name: str) -> "GlobalConstId":
return GlobalConstId(next(GlobalConstId._fresh_ids), base_name)

@property
def name(self) -> str:
return f"{self.base_name}.{self.id}"


class CompilerContext:
"""Compilation context containing all available definitions.
Maintains a `worklist` of definitions which have been used by other compiled code
Expand All @@ -28,6 +47,8 @@ class CompiledGlobals:
compiled: dict[DefId, CompiledDef]
worklist: set[DefId]

global_funcs: dict[GlobalConstId, hf.Function]

checked_globals: Globals

def __init__(
Expand All @@ -40,6 +61,7 @@ def __init__(
self.checked = checked
self.worklist = set()
self.compiled = {}
self.global_funcs = {}
self.checked_globals = checked_globals

def build_compiled_def(self, def_id: DefId) -> CompiledDef:
Expand Down Expand Up @@ -81,6 +103,26 @@ def get_instance_func(
assert isinstance(compiled_func, CompiledCallableDef)
return compiled_func

def declare_global_func(
self,
const_id: GlobalConstId,
func_ty: ht.PolyFuncType,
) -> tuple[hf.Function, bool]:
"""
Creates a function builder for a global function if it doesn't already exist,
else returns the existing one.
"""
if const_id in self.global_funcs:
return self.global_funcs[const_id], True
func = self.module.define_function(
name=const_id.name,
input_types=func_ty.body.input,
output_types=func_ty.body.output,
type_params=func_ty.params,
)
self.global_funcs[const_id] = func
return func, False


@dataclass
class DFContainer:
Expand Down Expand Up @@ -155,10 +197,10 @@ def __copy__(self) -> "DFContainer":
class CompilerBase(ABC):
"""Base class for the Guppy compiler."""

globals: CompiledGlobals
ctx: CompilerContext

def __init__(self, globals: CompiledGlobals) -> None:
self.globals = globals
def __init__(self, ctx: CompilerContext) -> None:
self.ctx = ctx


def return_var(n: int) -> str:
Expand Down
20 changes: 9 additions & 11 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,15 @@ def visit_PlaceNode(self, node: PlaceNode) -> Wire:
return self.dfg[node.place]

def visit_GlobalName(self, node: GlobalName) -> Wire:
defn = self.globals.build_compiled_def(node.def_id)
defn = self.ctx.build_compiled_def(node.def_id)
assert isinstance(defn, CompiledValueDef)
if isinstance(defn, CompiledCallableDef) and defn.ty.parametrized:
# TODO: This should be caught during checking
err = UnsupportedError(
node, "Polymorphic functions as dynamic higher-order values"
)
raise GuppyError(err)
return defn.load(self.dfg, self.globals, node)
return defn.load(self.dfg, self.ctx, node)

def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
match node.param.ty:
Expand Down Expand Up @@ -363,13 +363,11 @@ def _compile_tensor_with_leftovers(
raise InternalGuppyError("Tensor element wasn't function or tuple")

def visit_GlobalCall(self, node: GlobalCall) -> Wire:
func = self.globals.build_compiled_def(node.def_id)
func = self.ctx.build_compiled_def(node.def_id)
assert isinstance(func, CompiledCallableDef)

args = [self.visit(arg) for arg in node.args]
rets = func.compile_call(
args, list(node.type_args), self.dfg, self.globals, node
)
rets = func.compile_call(args, list(node.type_args), self.dfg, self.ctx, node)
if isinstance(func, CustomFunctionDef) and not func.has_signature:
func_ty = FunctionType(
[FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args],
Expand Down Expand Up @@ -397,7 +395,7 @@ def visit_TypeApply(self, node: TypeApply) -> Wire:
# For now, we can only TypeApply global FunctionDefs/Decls.
if not isinstance(node.value, GlobalName):
raise InternalGuppyError("Dynamic TypeApply not supported yet!")
defn = self.globals.build_compiled_def(node.value.def_id)
defn = self.ctx.build_compiled_def(node.value.def_id)
assert isinstance(defn, CompiledCallableDef)

# We have to be very careful here: If we instantiate `foo: forall T. T -> T`
Expand All @@ -413,7 +411,7 @@ def visit_TypeApply(self, node: TypeApply) -> Wire:
)
raise GuppyError(err)

return defn.load_with_args(node.inst, self.dfg, self.globals, node)
return defn.load_with_args(node.inst, self.dfg, self.ctx, node)

def visit_UnaryOp(self, node: ast.UnaryOp) -> Wire:
# The only case that is not desugared by the type checker is the `not` operation
Expand Down Expand Up @@ -547,9 +545,9 @@ def _build_method_call(
args: list[Wire],
type_args: Sequence[Argument] | None = None,
) -> CallReturnWires:
func = self.globals.get_instance_func(ty, method)
func = self.ctx.get_instance_func(ty, method)
assert func is not None
return func.compile_call(args, type_args or [], self.dfg, self.globals, node)
return func.compile_call(args, type_args or [], self.dfg, self.ctx, node)

@contextmanager
def _build_generators(
Expand All @@ -562,7 +560,7 @@ def _build_generators(
"""
from guppylang.compiler.stmt_compiler import StmtCompiler

compiler = StmtCompiler(self.globals)
compiler = StmtCompiler(self.ctx)
with ExitStack() as stack:
for gen in gens:
# Build the generator
Expand Down
14 changes: 7 additions & 7 deletions guppylang/compiler/func_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from hugr.build.function import Function

from guppylang.compiler.cfg_compiler import compile_cfg
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.compiler.core import CompilerContext, DFContainer
from guppylang.compiler.hugr_extension import PartialOp
from guppylang.nodes import CheckedNestedFunctionDef

Expand All @@ -16,17 +16,17 @@
def compile_global_func_def(
func: "CheckedFunctionDef",
builder: Function,
globals: CompiledGlobals,
ctx: CompilerContext,
) -> None:
"""Compiles a top-level function definition to Hugr."""
cfg = compile_cfg(func.cfg, builder, builder.inputs(), globals)
cfg = compile_cfg(func.cfg, builder, builder.inputs(), ctx)
builder.set_outputs(*cfg)


def compile_local_func_def(
func: CheckedNestedFunctionDef,
dfg: DFContainer,
globals: CompiledGlobals,
ctx: CompilerContext,
) -> Wire:
"""Compiles a local (nested) function definition to Hugr and loads it into a value.
Expand Down Expand Up @@ -64,13 +64,13 @@ def compile_local_func_def(
func.cfg.input_tys.append(func.ty)

# Compile the CFG
cfg = compile_cfg(func.cfg, func_builder, call_args, globals)
cfg = compile_cfg(func.cfg, func_builder, call_args, ctx)
func_builder.set_outputs(*cfg)
else:
# Otherwise, we treat the function like a normal global variable
from guppylang.definition.function import CompiledFunctionDef

globals.compiled[func.def_id] = CompiledFunctionDef(
ctx.compiled[func.def_id] = CompiledFunctionDef(
func.def_id,
func.name,
func,
Expand All @@ -80,7 +80,7 @@ def compile_local_func_def(
func.cfg,
func_builder,
)
globals.worklist.add(func.def_id) # will compile the CFG later
ctx.worklist.add(func.def_id) # will compile the CFG later

# Finally, load the function into the local data-flow graph
loaded = dfg.builder.load_function(func_builder, closure_ty)
Expand Down
10 changes: 5 additions & 5 deletions guppylang/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from guppylang.checker.core import SubscriptAccess, Variable
from guppylang.checker.linearity_checker import contains_subscript
from guppylang.compiler.core import (
CompiledGlobals,
CompilerBase,
CompilerContext,
DFContainer,
return_var,
)
Expand Down Expand Up @@ -41,9 +41,9 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):

dfg: DFContainer

def __init__(self, globals: CompiledGlobals):
super().__init__(globals)
self.expr_compiler = ExprCompiler(globals)
def __init__(self, ctx: CompilerContext):
super().__init__(ctx)
self.expr_compiler = ExprCompiler(ctx)

def compile_stmts(
self,
Expand Down Expand Up @@ -197,5 +197,5 @@ def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None
from guppylang.compiler.func_compiler import compile_local_func_def

var = Variable(node.name, node.ty, node)
loaded_func = compile_local_func_def(node, self.dfg, self.globals)
loaded_func = compile_local_func_def(node, self.dfg, self.ctx)
self.dfg[var] = loaded_func
2 changes: 1 addition & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def compile_function(self, f_def: RawFunctionDef) -> FuncDefnPointer:
raise ValueError("Function definition must belong to a module")
compiled_module = module.compile()
assert module._compiled is not None, "Module should be compiled"
globs = module._compiled.globs
globs = module._compiled.context
assert globs is not None
compiled_def = globs.build_compiled_def(f_def.id)
assert isinstance(compiled_def, CompiledFunctionDef)
Expand Down
4 changes: 2 additions & 2 deletions guppylang/definition/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompiledGlobals
from guppylang.compiler.core import CompilerContext
from guppylang.module import GuppyModule


Expand Down Expand Up @@ -152,7 +152,7 @@ class CompiledDef(Definition):
defined_at: The AST node where the definition was defined.
"""

def compile_inner(self, globals: "CompiledGlobals") -> None:
def compile_inner(self, ctx: "CompilerContext") -> None:
"""Optional hook that is called to fill in the content of the Hugr node.
Opposed to `CompilableDef.compile()`, we have access to all other compiled
Expand Down
4 changes: 2 additions & 2 deletions guppylang/definition/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from guppylang.ast_util import AstNode
from guppylang.checker.core import Globals
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.compiler.core import CompilerContext, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.value import CompiledValueDef, ValueDef
from guppylang.span import SourceMap
Expand Down Expand Up @@ -58,6 +58,6 @@ class CompiledConstDef(ConstDef, CompiledValueDef):

const_node: Node

def load(self, dfg: DFContainer, globals: CompiledGlobals, node: AstNode) -> Wire:
def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:
"""Loads the extern value into a local Hugr dataflow graph."""
return dfg.builder.load(self.const_node)
Loading

0 comments on commit bb199a0

Please sign in to comment.