From 9402b574238f6fcd249437d53e2644c39d69d394 Mon Sep 17 00:00:00 2001 From: Artem Pulkin Date: Sat, 20 Jan 2024 15:18:04 +0100 Subject: [PATCH] bulk: refactor disassembly --- CHANGELOG.md | 3 +- cython/frame.pyx | 4 +- pyteleport/bytecode/__init__.py | 1 + pyteleport/bytecode/minias.py | 787 ++++++++++++++++++++++ pyteleport/{ => bytecode}/opcodes.py | 19 + pyteleport/bytecode/primitives.py | 388 +++++++++++ pyteleport/bytecode/printing.py | 91 +++ pyteleport/bytecode/sequence_assembler.py | 113 ++++ pyteleport/bytecode/tests/__init__.py | 0 pyteleport/bytecode/tests/test_as.py | 82 +++ pyteleport/bytecode/util.py | 107 +++ pyteleport/minias.py | 483 ------------- pyteleport/morph.py | 235 ++++--- pyteleport/printtools.py | 80 --- pyteleport/snapshot.py | 18 +- pyteleport/{ => tests}/test_scripts.py | 12 +- pyteleport/util.py | 24 - 17 files changed, 1769 insertions(+), 678 deletions(-) create mode 100644 pyteleport/bytecode/__init__.py create mode 100644 pyteleport/bytecode/minias.py rename pyteleport/{ => bytecode}/opcodes.py (52%) create mode 100644 pyteleport/bytecode/primitives.py create mode 100644 pyteleport/bytecode/printing.py create mode 100644 pyteleport/bytecode/sequence_assembler.py create mode 100644 pyteleport/bytecode/tests/__init__.py create mode 100644 pyteleport/bytecode/tests/test_as.py create mode 100644 pyteleport/bytecode/util.py delete mode 100644 pyteleport/minias.py delete mode 100644 pyteleport/printtools.py rename pyteleport/{ => tests}/test_scripts.py (74%) diff --git a/CHANGELOG.md b/CHANGELOG.md index be4ca7e..3f7225f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,10 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- re-worked bytecode assembly and added tests - the use of sockets to transmit object data instead of saving it in a code object. 
As a result, there is no memory overhead related to serialized data while morph `pyc` files remain very small -- experimentation with AWS EC2 +- added experimentation with AWS EC2 ### Fixed diff --git a/cython/frame.pyx b/cython/frame.pyx index 1b2f0d1..d5ec8e4 100644 --- a/cython/frame.pyx +++ b/cython/frame.pyx @@ -23,7 +23,9 @@ cdef extern from "frameobject.h": cdef extern from *: # stack depth for different python versions """ #if PY_VERSION_HEX >= 0x030A0000 - static int _pyteleport_stackdepth(struct _frame* frame) {return frame->f_stackdepth;} + static int _pyteleport_stackdepth(struct _frame* frame) { + return frame->f_stackdepth; + } #elif PY_VERSION_HEX >= 0x03080000 static int _pyteleport_stackdepth(struct _frame* frame) { if (frame->f_stacktop) diff --git a/pyteleport/bytecode/__init__.py b/pyteleport/bytecode/__init__.py new file mode 100644 index 0000000..8a32f99 --- /dev/null +++ b/pyteleport/bytecode/__init__.py @@ -0,0 +1 @@ +from .minias import jump_multiplier, disassemble, ObjectBytecode as Bytecode diff --git a/pyteleport/bytecode/minias.py b/pyteleport/bytecode/minias.py new file mode 100644 index 0000000..98341b8 --- /dev/null +++ b/pyteleport/bytecode/minias.py @@ -0,0 +1,787 @@ +import sys +from collections import defaultdict +from dataclasses import dataclass +from dis import get_instructions as dis_get_instructions, _get_code_object, Instruction +from functools import partial +from io import StringIO +from opcode import EXTENDED_ARG, HAVE_ARGUMENT, opmap, hasjrel, hasjabs, hasconst, hasname, haslocal, hasfree, opname +from types import CodeType +from typing import Callable, Optional, Iterable, Iterator, Sequence + +from .primitives import AbstractBytecodePrintable, FixedCell, FloatingCell, EncodedInstruction, ReferencingInstruction, \ + NoArgInstruction, ConstInstruction, NameInstruction, jump_multiplier, no_step_opcodes +from .util import IndexStorage, NameStorage, Cell, log_iter +from .sequence_assembler import LookBackSequence, assemble 
as assemble_sequence +from .opcodes import guess_entering_stack_size, python_version, RETURN_VALUE + +NOP = opmap["NOP"] +if python_version > 0x030A: # 3.11 and above + from opcode import _inline_cache_entries +else: + _inline_cache_entries = {} + + +def offset_to_jump(opcode: int, offset: int, pos: Optional[int], x: int = jump_multiplier) -> int: + """ + Computes jump argument from the provided offset information. + + Parameters + ---------- + opcode + The jumping opcode. + offset + Jump destination. + pos + The jumping opcode offset. + x + The jump multiplier. + + Returns + ------- + The resulting argument. + """ + if opcode in hasjabs: + return offset // x + elif opcode in hasjrel: + return (offset - pos - 2) // x + else: + raise ValueError(f"{opcode=} {opname[opcode]} is not jumping") + + +def jump_to_offset(opcode: int, arg: int, pos: Optional[int], x: int = jump_multiplier) -> int: + """ + Computes jump destination offset from the provided jump argument. + + Parameters + ---------- + opcode + The jumping opcode. + arg + Jump argument. + pos + The jumping opcode offset. + x + The jump multiplier. + + Returns + ------- + The resulting offset. + """ + if opcode in hasjabs: + return arg * x + elif opcode in hasjrel: + return arg * x + pos + 2 + else: + raise ValueError(f"{opcode=} {opname[opcode]} is not jumping") + + +def iter_slots(source) -> Iterator[FixedCell]: + """ + Generates slots from the raw bytecode data. + + Parameters + ---------- + source + The source of instructions. + + Yields + ------ + Bytecode slots with instructions inside. + """ + for instruction in source: + yield FixedCell( + offset=instruction.offset, + is_jump_target=instruction.is_jump_target, + instruction=EncodedInstruction( + opcode=instruction.opcode, + arg=instruction.arg or 0, + ) + ) + + +def get_instructions(code: CodeType) -> Iterator[Instruction]: + """ + A replica of `dis.get_instructions` with minor + modifications. 
+ + Parameters + ---------- + code + The code to parse instructions from. + + Yields + ------ + Individual instructions. + """ + for instruction in dis_get_instructions(code): + if instruction.arg is None: + arg = code.co_code[instruction.offset + 1] + instruction = instruction._replace(arg=arg) + yield instruction + + +def iter_extract(source) -> tuple[Iterable[FixedCell], CodeType]: + """ + Iterates over bytecodes from the source. + + Parameters + ---------- + source + Anything with the bytecode. + + Returns + ------ + Bytecode iterator and the corresponding code object. + """ + code_obj = _get_code_object(source) + return iter_slots(get_instructions(code_obj)), code_obj + + +def filter_nop(source: Iterable[FixedCell], keep_nop: bool = False) -> Iterator[FixedCell]: + """ + Filters out NOP and EXT_ARG. + Corrects offsets but does not apply extended args. + + Parameters + ---------- + source + The source of bytecode instructions. + keep_nop + If True, keeps NOP opcodes. + + Yields + ------ + Every instruction, except no-op and extended + args. + """ + head = None + for slot in source: + if not keep_nop and slot.instruction.opcode == NOP: + continue + if slot.instruction.opcode == EXTENDED_ARG: + if head is None: + head = slot + else: + assert not slot.is_jump_target + else: + if head is not None: + assert not slot.is_jump_target + slot = FixedCell( + offset=head.offset, + is_jump_target=head.is_jump_target, + instruction=slot.instruction, + ) + head = None + yield slot + + +def iter_dis_jumps(source: Iterable[FixedCell]) -> Iterator[FloatingCell]: + """ + Computes jumps. + + Parameters + ---------- + source + The source of bytecode slots. + + Yields + ------ + FloatingCell + The resulting cell with the referencing information. 
+ """ + lookup: dict[int, FloatingCell] = {} + + stack_size = 0 + + for fixed_cell in source: + original = fixed_cell.instruction + + # determine the jump destination, if any + jump_destination = None + if original.opcode in hasjabs or original.opcode in hasjrel: + jump_destination = jump_to_offset( + original.opcode, + original.arg, + fixed_cell.offset, + ) + + # replace with the jump instruction + jumps_to = None + if jump_destination is not None: + try: + jumps_to = lookup[jump_destination] + except KeyError: + jumps_to = lookup[jump_destination] = FloatingCell( + instruction=None, + ) + + instruction = ReferencingInstruction( + opcode=original.opcode, + arg=jumps_to, + ) + else: + instruction = original + + floating_cell = None + if fixed_cell.is_jump_target: + try: + # check if already in lookup and replace (fw jump) + floating_cell = lookup[fixed_cell.offset] + except KeyError: + pass # add floating_cell to lookup later (bw jump) + else: + floating_cell.instruction = instruction + + if floating_cell is None: + floating_cell = FloatingCell( + instruction=instruction, + ) + stack_size += original.get_stack_effect(jump=False) + if fixed_cell.is_jump_target: + lookup[fixed_cell.offset] = floating_cell + + if jumps_to is not None: + jumps_to.referenced_by.append(floating_cell) + + yield floating_cell + + +def iter_dis_args( + source: Iterable[FloatingCell], + consts: Sequence[object], + names: Sequence[str], + varnames: Sequence[str], + cellnames: Sequence[str], +) -> Iterator[FloatingCell]: + """ + Pipes instructions from the input and computes object arguments. + + Parameters + ---------- + source + The source of bytecode instructions. + consts + A list of constants. + names + A list of names. + varnames + A list of local names. + cellnames + A list of cells. + + Yields + ------ + Instructions with computed args. 
+ """ + for slot in source: + instruction = slot.instruction + if isinstance(instruction, EncodedInstruction): + opcode = instruction.opcode + arg = instruction.arg + + if opcode < HAVE_ARGUMENT: + result = NoArgInstruction(opcode, arg) + else: + if opcode in hasconst: + result = ConstInstruction(opcode, consts[arg]) + elif opcode in hasname: + result = NameInstruction(opcode, names[arg]) + elif opcode in haslocal: + result = NameInstruction(opcode, varnames[arg]) + elif opcode in hasfree: + result = NameInstruction(opcode, cellnames[arg]) + else: + result = EncodedInstruction(opcode, arg) + + slot.instruction = result + yield slot + + +def iter_dis( + source: Iterable[FixedCell], + consts: Sequence[object], + names: Sequence[str], + varnames: Sequence[str], + cellnames: Sequence[str], + keep_nop: bool = False, + current_pos: int = None, +) -> Iterator[FloatingCell]: + """ + Disassembles encoded instructions. + The reverse of iter_as. + + Parameters + ---------- + source + The source of encoded instructions. + consts + names + varnames + cellnames + Constant and name collections. + keep_nop + If True, yields NOP as they are found + in the original bytecode. + current_pos + The position of currently executed + instruction. + + Yields + ------ + FloatingCell + The resulting cell with the referencing information. + """ + cell_fixed = Cell() + + for result in iter_dis_args( + iter_dis_jumps(filter_nop( + log_iter(source, cell_fixed), + keep_nop=keep_nop, + )), + consts, + names, + varnames, + cellnames, + ): + fixed: FixedCell = cell_fixed.value + result.metadata.source = fixed + if current_pos is not None: + result.metadata.mark_current = current_pos == fixed.offset + + yield result + + +def iter_as_args( + source: Iterable[FloatingCell], + consts: IndexStorage, + names: NameStorage, + varnames: NameStorage, + cellnames: NameStorage +) -> Iterator[FloatingCell]: + """ + Pipes instructions from the input and assembles their + object and name arguments. 
+ + Parameters + ---------- + source + The source of bytecode instructions. + consts + Constant storage (modified by this iterator). + names + varnames + cellnames + Name storages (modified by this iterator). + + Yields + ------ + Instructions with assembled args. + """ + for slot in source: + instruction = slot.instruction + opcode = instruction.opcode + + if isinstance(instruction, ConstInstruction): + result = EncodedInstruction(opcode, consts.store(instruction.arg)) + elif isinstance(instruction, NameInstruction): + if opcode in hasname: + result = EncodedInstruction(opcode, names.store(instruction.arg)) + elif opcode in haslocal: + result = EncodedInstruction(opcode, varnames.store(instruction.arg)) + elif opcode in hasfree: + result = EncodedInstruction(opcode, cellnames.store(instruction.arg)) + else: + raise ValueError(f"unknown name instruction to process: {instruction}") + elif isinstance(instruction, NoArgInstruction): + result = EncodedInstruction(opcode, instruction.arg) + elif isinstance(instruction, (ReferencingInstruction, EncodedInstruction)): + result = instruction + else: + raise ValueError(f"unknown instruction to process: {instruction}") + + slot.instruction = result + + yield slot + + +def as_jumps(source: Iterable[FloatingCell]) -> list[FixedCell]: + """ + Pipes instructions from the input and assembles + jump destinations. + + Parameters + ---------- + source + The source of bytecode instructions. + + Returns + ------- + Instructions with assembled jumps. 
+ """ + + class CellToken: + def __init__(self, cell: FloatingCell, cell_lookup: dict[FloatingCell, "CellToken"]): + instruction = cell.instruction + self.backward_reference_token = None + if isinstance(instruction, ReferencingInstruction): + try: + self.backward_reference_token = cell_lookup[instruction.arg] + except KeyError: + pass + instruction = EncodedInstruction( + opcode=instruction.opcode, + arg=0, + ) + + elif not isinstance(instruction, EncodedInstruction): + raise ValueError(f"cannot init with instruction: {instruction}") + + self.cell = FixedCell( + offset=0, + is_jump_target=bool(cell.referenced_by), + instruction=instruction, + ) + self.earlier_references_to_here = ref = [] + for i in cell.referenced_by: + try: + ref.append(cell_lookup[i]) + except KeyError: + pass + + def update_sequentially(self, prev: Optional["CellToken"]): + # update offset + if prev is None: + self.cell.offset = 0 + else: + self.cell.offset = prev.cell.offset + prev.cell.instruction.size_ext + # if jump: update arg + if self.backward_reference_token is not None: + self.update_jump(self.backward_reference_token) + + def update_jump(self, reference: "CellToken") -> bool: + if self.cell.instruction.opcode in hasjabs: + arg = reference.cell.offset // jump_multiplier + elif self.cell.instruction.opcode in hasjrel: + arg = (reference.cell.offset - self.cell.offset - self.cell.instruction.size_ext) // jump_multiplier + else: + raise ValueError(f"not a jump: {reference}") + old_size = self.cell.instruction.size_arg + self.cell.instruction = EncodedInstruction( + opcode=self.cell.instruction.opcode, + arg=arg, + ) + return self.cell.instruction.size_arg != old_size + + source = list(source) + lookup = {} + for floating in source: + lookup[floating] = CellToken(floating, lookup) + + result = LookBackSequence(lookup[i] for i in source) + assemble_sequence(result) + result.reset() + return list(i.cell for _, i in result) + + +def iter_as( + source: Iterable[FloatingCell], + consts: 
Optional[Sequence] = None, + names: Optional[Sequence] = None, + varnames: Optional[Sequence] = None, + cells: Optional[Sequence] = None, +) -> tuple[ + Iterable[FixedCell], + IndexStorage, + NameStorage, + NameStorage, + NameStorage, +]: + """ + Assembles decoded instructions. + The reverse of iter_dis. + + Parameters + ---------- + source + The source of decoded instructions. + consts + Initial constants. + names + varnames + cells + Initial names. + + Returns + ------- + The resulting bytecode, consts, names, varnames, and cellnames. + """ + consts = IndexStorage(consts or []) + names = NameStorage(names or []) + varnames = NameStorage(varnames or []) + cellnames = NameStorage(cells or []) + return as_jumps(iter_as_args( + source, + consts, + names, + varnames, + cellnames, + )), consts, names, varnames, cellnames + + +def assign_stack_size(source: list[FloatingCell], clean_start: bool = True) -> None: + """ + Computes and assigns stack size per instruction. + The computed values are available in `item.metadata.stack_size`. + + Parameters + ---------- + source + Bytecode instructions. + clean_start + If True, wipes previously computed stack sizes, if any. 
+ """ + if not len(source): + return + if clean_start: + for i in source: + i.metadata.stack_size = None + starting = source[0] + starting.metadata.stack_size = guess_entering_stack_size(starting.instruction.opcode) + + # figure out starting points + chains = [] + for i, (cell, nxt) in enumerate(zip(source[:-1], source[1:])): + if cell.metadata.stack_size is not None and nxt.metadata.stack_size is None: + chains.append(i) + + while chains: + new_chains = [] + + for starting_point in chains: + for cell, nxt in zip(source[starting_point:], source[starting_point + 1:]): + + if isinstance(cell.instruction, ReferencingInstruction): + distant_stack_size = cell.metadata.stack_size + cell.instruction.get_stack_effect(jump=True) + distant = cell.instruction.arg + if distant.metadata.stack_size is None: + distant.metadata.stack_size = distant_stack_size + new_chains.append(source.index(distant)) + else: + assert distant_stack_size == distant.metadata.stack_size, \ + f"stack size computed from {cell} to {distant} (jump) mismatch: " \ + f"{distant_stack_size} vs previous {distant.metadata.stack_size}" + + if cell.instruction.opcode not in no_step_opcodes: + next_stack_size = cell.metadata.stack_size + cell.instruction.get_stack_effect(jump=False) + if nxt.metadata.stack_size is None: + if nxt.instruction.opcode == RETURN_VALUE: + assert next_stack_size == 1, f"non-zero stack at RETURN_VALUE: {next_stack_size}" + try: + nxt.metadata.stack_size = next_stack_size + except ValueError as e: + raise ValueError( + f"Failed unwinding the stack size; bytecode following (failing instruction marked)\n" + f"{ObjectBytecode(source, current=nxt).to_string()}") from e + else: + assert next_stack_size == nxt.metadata.stack_size, \ + f"stack size computed from {cell} to {nxt} (step) mismatch: " \ + f"{next_stack_size} vs previous {nxt.metadata.stack_size}" + else: + break + + chains = new_chains + + +@dataclass +class AbstractBytecode: + """An abstract bytecode""" + instructions: 
list[AbstractBytecodePrintable] + + def get_marks(self): + raise NotImplementedError + + def print(self, line_printer: Callable = print) -> None: + """ + Prints the bytecode. + + Parameters + ---------- + line_printer + A function printing lines. + """ + marks = self.get_marks() + for i in self.instructions: + mark = marks.get(i, '').rjust(3) + line_printer(f"{mark} {i.pprint()}") + + def to_string(self) -> str: + """Prints the bytecode and returns the printed text""" + buffer = StringIO() + self.print(partial(print, file=buffer)) + return buffer.getvalue() + + +@dataclass +class ObjectBytecode(AbstractBytecode): + instructions: list[FloatingCell] + current: Optional[FloatingCell] = None + """ + An object bytecode. + + Parameters + ---------- + instructions + A list of opcode cells with object arguments. + current + Current bytecode operation. + """ + + def get_marks(self): + return {self.current: ">>>"} + + @classmethod + def from_iterable(cls, source: Iterable[FloatingCell], compute_stack_size: bool = True): + instructions = [] + current = None + for c in source: + instructions.append(c) + if c.metadata.mark_current: + current = c + + if compute_stack_size: + assign_stack_size(instructions) + + return cls( + instructions=instructions, + current=current, + ) + + def recompute_references(self): + """ + Re-computes references across the bytecode. + """ + references: dict[FloatingCell, list[FloatingCell]] = defaultdict(list) + for i in self.instructions: + if isinstance(i.instruction, ReferencingInstruction): + references[i.instruction.arg].append(i) + for i in self.instructions: + i.referenced_by = references[i] + + def assemble(self, **kwargs) -> "AssembledBytecode": + """ + Assembles the bytecode. + + Parameters + ---------- + kwargs + Arguments to `iter_as`. + + Returns + ------- + Assembled bytecode. 
+ """ + self.recompute_references() + code_iter, consts, names, varnames, cells = iter_as(self.instructions, **kwargs) + return AssembledBytecode( + list(code_iter), + consts, + names, + varnames, + cells, + ) + + +@dataclass +class AssembledBytecode(AbstractBytecode): + instructions: list[FixedCell] + consts: IndexStorage + names: NameStorage + varnames: NameStorage + cells: NameStorage + """ + An assembled bytecode. + + Parameters + ---------- + instructions + A list of opcode cells. + consts + names + varnames + cells + Object and name storage. + """ + + def get_marks(self): + return {} + + @classmethod + def from_code_object(cls, source): + """ + Turns code objects into assembled bytecode. + + Parameters + ---------- + source + The source for the bytecode. + + Returns + ------- + Assembled bytecode. + """ + cells, code_obj = iter_extract(source) + return AssembledBytecode( + list(cells), + IndexStorage(code_obj.co_consts), + NameStorage(code_obj.co_names), + NameStorage(code_obj.co_varnames), + NameStorage(code_obj.co_cellvars + code_obj.co_freevars), + ) + + def disassemble(self, pos: Optional[int] = None, **kwargs) -> ObjectBytecode: + """ + Disassembles the bytecode. + + Parameters + ---------- + pos + Current bytecode pos. + kwargs + Arguments to `iter_dis`. + + Returns + ------- + The disassembled bytecode. + """ + return ObjectBytecode.from_iterable( + iter_dis( + self.instructions, + self.consts, + self.names, + self.varnames, + self.cells, + current_pos=pos, + **kwargs + ) + ) + + def __bytes__(self): + return b''.join(bytes(i.instruction) for i in self.instructions) + + +def disassemble(source, **kwargs) -> ObjectBytecode: + """ + Disassembles any bytecode source. + + Parameters + ---------- + source + The bytecode source. + kwargs + Arguments to `AssembledBytecode.disassemble`. + + Returns + ------- + The disassembled bytecode. 
+ """ + return AssembledBytecode.from_code_object(source).disassemble(**kwargs) diff --git a/pyteleport/opcodes.py b/pyteleport/bytecode/opcodes.py similarity index 52% rename from pyteleport/opcodes.py rename to pyteleport/bytecode/opcodes.py index e1584ea..d644c89 100644 --- a/pyteleport/opcodes.py +++ b/pyteleport/bytecode/opcodes.py @@ -1,9 +1,11 @@ """ Extends opcode collections. """ +import sys from dis import opmap locals().update(opmap) # unpack opcodes here +python_version = sys.version_info.major * 0x100 + sys.version_info.minor # These unconditionally interrupt the normal bytecode flow interrupting = tuple( @@ -25,3 +27,20 @@ if i in opmap ) del opmap # cleanup + + +def guess_entering_stack_size(opcode: int) -> int: + """ + Figure out the starting stack size given the starting opcode. + This usually returns zero, except the special GEN_START case when it is one. + + Parameters + ---------- + opcode + The starting opcode. + + Returns + ------- + The stack size. + """ + return int(opcode in resuming) diff --git a/pyteleport/bytecode/primitives.py b/pyteleport/bytecode/primitives.py new file mode 100644 index 0000000..62a5430 --- /dev/null +++ b/pyteleport/bytecode/primitives.py @@ -0,0 +1,388 @@ +import opcode +from dataclasses import dataclass, field +from dis import opname as dis_opname, stack_effect +from math import ceil +from opcode import HAVE_ARGUMENT, EXTENDED_ARG +from sys import version_info +from typing import Optional + +if version_info[:2] <= (3, 10): + _inline_cache_entries = {} +else: + from opcode import _inline_cache_entries +from shutil import get_terminal_size + +from .printing import truncate, int_diff + +if version_info[:2] <= (3, 9): + jump_multiplier = 1 +else: + jump_multiplier = 2 + +max_opname_len = max(map(len, dis_opname)) +max_op_len = max_opname_len + 38 +no_step_opcodes = set() +for _name in "JUMP_ABSOLUTE", "JUMP_FORWARD", "JUMP_BACKWARD", "JUMP_BACKWARD_NO_INTERRUPT", "RETURN_VALUE", "RERAISE", "RAISE_VARARGS": + try: + 
no_step_opcodes.add(opcode.opname.index(_name)) + except ValueError: + pass + + +def byte_len(i: int) -> int: + return int(ceil((i or 1).bit_length() / 8)) + + +class AbstractBytecodePrintable: + """ + Anything that can be printed instruction-like. + """ + + def pprint(self): + """Subclasses ensure it is properly printable""" + raise NotImplementedError + + +@dataclass(frozen=True) +class AbstractInstruction(AbstractBytecodePrintable): + opcode: int + """ + A parent class for any instruction. + + Parameters + ---------- + opcode + Instruction opcode. + """ + + def __post_init__(self): + assert isinstance(self.opcode, int) + assert 0 <= self.opcode < 0x100 + + @property + def opname(self) -> str: + return dis_opname[self.opcode] + + @property + def size_tail(self): + return _inline_cache_entries.get(self.opcode, 0) + + @property + def size(self): + return self.size_tail + 2 + + def __str__(self): + return self.opname + + def pprint(self, width: int = max_opname_len): + return truncate(self.opname, width) + + +@dataclass(frozen=True) +class AbstractArgInstruction(AbstractInstruction): + opcode: int + """ + A base class for instructions with an argument. + Subclasses are required to have an arg field. + + Parameters + ---------- + opcode + Instruction opcode. + """ + + def get_stack_effect(self, jump: bool = False) -> int: + raise NotImplementedError + + def __str_arg__(self): + return str(self.arg) + + def __str__(self): + return f"{self.opname}({self.__str_arg__()})" + + def pprint(self, width: int = max_op_len, opname_width: int = max_opname_len): + arg_width = width - opname_width - 1 + if arg_width < 4: + return super().pprint(width) + return f"{truncate(self.opname, opname_width).ljust(opname_width)} {truncate(self.__str_arg__(), arg_width)}" + + +@dataclass(frozen=True) +class EncodedInstruction(AbstractArgInstruction): + arg: int + """ + Instruction with an encoded (positive integer) argument. + + Parameters + ---------- + opcode + Instruction opcode. 
+ arg + Integer argument. + """ + + def __post_init__(self): + super().__post_init__() + assert isinstance(self.arg, int) + assert 0 <= self.arg + + @property + def size_arg(self): + return byte_len(self.arg) + + @property + def size_ext(self): + return self.size + 2 * (self.size_arg - 1) + + def get_stack_effect(self, jump: bool = False) -> int: + if self.opcode < HAVE_ARGUMENT: + arg = None + else: + arg = self.arg + return stack_effect(self.opcode, arg, jump=jump) + + def __bytes__(self): + arg = self.arg.to_bytes(self.size_arg, 'big') + result = [] + for a in arg[:-1]: + result.append(EXTENDED_ARG) + result.append(a) + result.append(self.opcode) + result.append(arg[-1]) + return bytes(result) + b'\x00' * self.size_tail + + +@dataclass(eq=False) +class FixedCell(AbstractBytecodePrintable): + offset: int + is_jump_target: bool + instruction: Optional[EncodedInstruction] = None + """ + An instruction cell at a specific offset, possibly + occupied by an instruction. + + Parameters + ---------- + offset + Instruction offset. + is_jump_target + If True, indicates that this slot is referenced. + instruction + An instruction occupying this slot. + """ + + def __str__(self): + result = f"Cell(@{self.offset}, {str(self.instruction)})" + if self.is_jump_target: + result += "*" + return result + + def pprint(self, width: int = 0, offset_width: int = 4): + if width == 0: + width, _ = get_terminal_size() + instr_width = width - offset_width - 1 + if self.instruction is None: + inner = truncate("None", instr_width, left="<", right=">") + else: + inner = self.instruction.pprint(width=instr_width) + offset = truncate(str(self.offset), offset_width, suffix="..") + return f"{offset.rjust(offset_width)} {inner}" + + +@dataclass(frozen=True) +class NoArgInstruction(AbstractInstruction): + arg: int = 0 + """ + An instruction that does not require any + arguments. + + Parameters + ---------- + opcode + Instruction opcode. + arg + Instruction argument. 
Simply keeps + whatever argument provided, without + any meaning assigned. + """ + + def __post_init__(self): + super().__post_init__() + assert isinstance(self.arg, int) + assert 0 <= self.arg + + def get_stack_effect(self, jump: bool = False) -> int: + assert not jump + return stack_effect(self.opcode, None) + + +@dataclass(frozen=True) +class NameInstruction(AbstractArgInstruction): + arg: str + """ + Instruction with a name (string) argument. + + Parameters + ---------- + opcode + Instruction opcode. + arg + Name argument. + """ + + def __post_init__(self): + super().__post_init__() + assert isinstance(self.arg, str) + + def get_stack_effect(self, jump: bool = False) -> int: + assert not jump + return stack_effect(self.opcode, 0) + + +@dataclass(frozen=True) +class ConstInstruction(AbstractArgInstruction): + arg: object + """ + Instruction with a constant (object) argument. + + Parameters + ---------- + opcode + Instruction opcode. + arg + Object argument. + """ + + def get_stack_effect(self, jump: bool = False) -> int: + assert not jump + return stack_effect(self.opcode, 0) + + def __str_arg__(self): + return repr(self.arg) + + +@dataclass(eq=False) +class FloatingMetadata: + source: Optional[FixedCell] = None + _stack_size: Optional[int] = None + mark_current: Optional[bool] = None + """ + Metadata for the floating cell. + + Parameters + ---------- + source + The corresponding fixed cell for this floating cell. + stack_size + The number of items in the value stack *before* + the corresponding instruction is executed. + mark_current + Marks this instruction as "current" (i.e. to be + executed next). 
+ """ + + @property + def stack_size(self) -> int: + return self._stack_size + + @stack_size.setter + def stack_size(self, value: Optional[int]): + if value is not None and value < 0: + raise ValueError(f"trying to set stack_size to negative {value=}") + self._stack_size = value + + +@dataclass(eq=False) +class FloatingCell(AbstractBytecodePrintable): + instruction: Optional[AbstractInstruction] + referenced_by: list["FloatingCell"] = None + metadata: FloatingMetadata = field(default_factory=FloatingMetadata) + """ + A bytecode slot without any specific offset but + instead keeping track of references to this slot. + + Parameters + ---------- + instruction + The instruction in this cell. + referenced_by + A list of references. + metadata + Optional metadata for this instruction. + """ + + def __post_init__(self): + if self.referenced_by is None: + self.referenced_by = [] + + @property + def is_jump_target(self): + return bool(self.referenced_by) + + def swap_with(self, another: "FloatingCell") -> None: + """ + Reference another floating cell. + This makes `self` safe to delete. + + Parameters + ---------- + another + Another cell to reference. + """ + for i in set(self.referenced_by): + i.relink_to(another) + assert len(self.referenced_by) == 0 + + def __str__(self): + result = f"Cell({str(self.instruction)})" + if self.is_jump_target: + result += "*" + return result + + def pprint(self, width: int = 0, width_stack_size: int = 16): + if width == 0: + width, _ = get_terminal_size() + instr_width = width - width_stack_size - 1 + if self.instruction is None: + result = truncate("None", instr_width, left="<", right=">") + else: + result = self.instruction.pprint(width=instr_width) + if self.metadata.stack_size is None or self.instruction is None: + stack_size = " ?" 
+ else: + try: + delta = self.instruction.get_stack_effect(jump=False) + except AssertionError: + delta = 0 + stack_size = int_diff(self.metadata.stack_size, max(0, delta), max(0, -delta), width_stack_size) + return f"{result.ljust(instr_width)} {stack_size}" + + +@dataclass(frozen=True) +class ReferencingInstruction(AbstractArgInstruction): + arg: FloatingCell + """ + Instruction referencing another one through the + argument. + + Parameters + ---------- + opcode + Instruction opcode. + arg + Reference to another instruction. + """ + + def __post_init__(self): + super().__post_init__() + assert isinstance(self.arg, FloatingCell) + + def get_stack_effect(self, jump: bool = False) -> int: + if self.opcode in no_step_opcodes: + assert jump + return stack_effect(self.opcode, 0, jump=jump) + + def __str_arg__(self, size: Optional[int] = None): + if isinstance(self.arg.instruction, ReferencingInstruction): + return truncate(f"to {self.arg.instruction.opname}", size) + return truncate(f"to {str(self.arg.instruction)}", size) diff --git a/pyteleport/bytecode/printing.py b/pyteleport/bytecode/printing.py new file mode 100644 index 0000000..a0c40ae --- /dev/null +++ b/pyteleport/bytecode/printing.py @@ -0,0 +1,91 @@ +from typing import Optional + + +def truncate( + s: str, + size: Optional[int], + suffix: str = "...", + left: str = "", + right: str = "", +) -> str: + """ + Truncates a string. + + Parameters + ---------- + s + String to truncate. + size + Maximal resulting string size. + suffix + The suffix to use when indicating truncation. + left + An optional left bracket. + right + An optional right bracket. + + Returns + ------- + The truncated string. 
+ """ + if size is None: + size = float("+inf") + size -= len(left) + len(right) + assert size > 0 + if len(s) <= size: + return f"{left}{s}{right}" + else: + assert len(suffix) < size + return f"{left}{s[:size - len(suffix)]}{suffix}{right}" + + +def str_truncated(o, size: Optional[int]) -> str: + """ + Turn object into a string up to the specified length. + + Parameters + ---------- + o + Object to represent. + size + Maximal resulting string size. + + Returns + ------- + The string representation of an object. + """ + return truncate(str(o), size) + + +def repr_truncated(o, size: Optional[int]) -> str: + """ + Truncated representation (`repr`). + + Parameters + ---------- + o + Object to represent. + size + Maximal resulting string size. + + Returns + ------- + The string representation of an object. + """ + o_repr = repr(o) + if size is not None and len(o_repr) > size: + return truncate( + type(o).__name__, + size, + left="<", + right=" instance>", + ) + return o_repr + + +def int_diff(base: int, added: int, removed: int, limit: int) -> str: + if base + max(0, added - removed) <= limit: + changed = min(added, removed) + return "." * (base - removed) + "*" * changed + "-" * (removed - changed) + "+" * (added - changed) + else: + return f"{base}->{base - removed + added}" diff --git a/pyteleport/bytecode/sequence_assembler.py b/pyteleport/bytecode/sequence_assembler.py new file mode 100644 index 0000000..6dd8d18 --- /dev/null +++ b/pyteleport/bytecode/sequence_assembler.py @@ -0,0 +1,113 @@ +from typing import Optional, Iterable + + +class Token: + """ + A token interface represents the bytecode during assembly. + """ + def update_sequentially(self, prev: Optional["Token"]): + """ + Updates this token after the previous token update. + + Parameters + ---------- + prev + The previous token in the sequence. + """ + raise NotImplementedError + + @property + def earlier_references_to_here(self) -> list["Token"]: + """ + Earlier references to this token. 
+ """ + raise NotImplementedError + + def update_jump(self, reference: "Token") -> bool: + """ + Updates this token after the reference token update. + + Parameters + ---------- + reference + The reference that was updated. + + Returns + ------- + True to notify the size change; False otherwise. + """ + raise NotImplementedError + + +class LookBackSequence: + """ + A sequence with a possibility to look back. + + Parameters + ---------- + source + An iterable to source tokens from. + """ + def __init__(self, source: Iterable[Token]): + self.source = source + self.next = {} + self.current = None + + def __iter__(self): + return self + + def __next__(self) -> tuple[Optional[Token], Token]: + """Outputs previous and current tokens in a sequence.""" + k = self.current + try: + self.current = self.next[k] + except KeyError: + nxt = next(self.source) + if nxt in self.next or k is nxt: + raise RuntimeError(f"non-unique token received: {nxt}") + self.current = self.next[k] = nxt + return k, self.current + + def restart_from_earlier(self, from_token: Optional[Token]): + """ + Restarts iteration from one of the tokens yielded earlier. + + Parameters + ---------- + from_token + A token to restart from. + """ + assert from_token in self.next + self.current = from_token + + def reset(self): + """Resets iteration from the very beginning.""" + self.restart_from_earlier(None) + + +def assemble(source: LookBackSequence): + """ + Assembles the sequence. + + A very simple logic here: update tokens sequentially and rewind the sequence + whenever the jump update requires to do so. + + Modifies sequence tokens in-place. + + Parameters + ---------- + source + The sequence of tokens with the possibility to look back. 
+ """ + for previous, token in source: + token.update_sequentially(previous) + + # Update references + stale = None + for referencing_token in token.earlier_references_to_here: + if referencing_token.update_jump(token) and stale is None: + stale = referencing_token + + # If any reference indicates stale sequence state, restart from it + if stale is not None: + source.restart_from_earlier(stale) diff --git a/pyteleport/bytecode/tests/__init__.py b/pyteleport/bytecode/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyteleport/bytecode/tests/test_as.py b/pyteleport/bytecode/tests/test_as.py new file mode 100644 index 0000000..215c8c2 --- /dev/null +++ b/pyteleport/bytecode/tests/test_as.py @@ -0,0 +1,82 @@ +import dis +import random +from pathlib import Path +from random import choice +from dataclasses import dataclass + +import pytest + +from ...tests.test_scripts import test_cases +from ..minias import disassemble + + +def _test_back_forth(source_code: str): + code_obj = compile(source_code, "", "exec") + my_code = disassemble(code_obj, keep_nop=True).assemble( + consts=code_obj.co_consts, # ensures const/name order is preserved + names=code_obj.co_names, # same + varnames=code_obj.co_varnames, # same + ) + # ensure order is preserved + assert tuple(my_code.names) == code_obj.co_names + assert tuple(my_code.varnames) == code_obj.co_varnames + + ref = code_obj.co_code + test = bytes(my_code) + + def _present(_i): + return tuple(enumerate(zip(( + dis.opname[_j] + for _j in _i[::2] + ), _i[1::2]))) + + assert _present(test) == _present(ref) + + +@pytest.mark.parametrize("code", [ + "a = 'hello'", + "a = b", + "if a: pass", + "for i in something: 2 * i", + "def f(): pass", + "a, b = 3, 4", + "try: something()\nexcept Exception as e: caught(e)\nelse: otherwise()\nfinally: finalize()", + "class A(B): pass", +]) +def test_oneliners(code: str): + _test_back_forth(code) + + +@pytest.mark.parametrize("name", test_cases) +def test_script(name): + 
with open(Path(__file__).parent.parent.parent / "tests" / name, "r") as f: + _test_back_forth(f.read()) + + +@pytest.mark.parametrize("size", [1, 10, 100, 200]) +def test_random_branching(size: int): + @dataclass + class Tree: + children: list["Tree"] = None + + def __post_init__(self): + if self.children is None: + self.children = [] + + def lines(self): + yield "if something:" + for c in self.children: + for l in c.lines(): + yield " " + l + if len(self.children) == 0: + yield " pass" + + random.seed(0) + all_nodes = [Tree()] + for _ in range(size - 1): + new = Tree() + choice(all_nodes).children.append(new) + all_nodes.append(new) + + source_code = "\n".join(all_nodes[0].lines()) + _test_back_forth(source_code) diff --git a/pyteleport/bytecode/util.py b/pyteleport/bytecode/util.py new file mode 100644 index 0000000..d4e7470 --- /dev/null +++ b/pyteleport/bytecode/util.py @@ -0,0 +1,107 @@ +from collections.abc import Iterator, Iterable +from typing import TypeVar + +class IndexStorage(list): + """ + Collects objects and assigns indices. + """ + def store(self, x) -> int: + """ + Store an object and return its index. + + Parameters + ---------- + x + Object to store. + + Returns + ------- + The index of the object in this collection. + """ + try: + return self.index(x) + except ValueError: + self.append(x) + return len(self) - 1 + __call__ = store + + def copy(self): + return self.__class__(self) + + +class NameStorage(IndexStorage): + """ + Collects names and assigns indices. + """ + def store(self, s: str, derive_unique: bool = False) -> int: + """ + Store a name and return its index. + + Parameters + ---------- + s + Name to store. + derive_unique + If True, derives a different name + in case the provided name is + already present. + + Returns + ------- + The index of the name in this collection. 
+ """ + if derive_unique: + s = unique_name(s, self) + return super().store(s) + + +def unique_name(prefix: str, collection) -> str: + """ + Prepares a unique name that is not + in the collection (yet). + + Parameters + ---------- + prefix + The prefix to use. + collection + Name collection. + + Returns + ------- + A unique name. + """ + if prefix not in collection: + return prefix + for i in range(len(collection) + 1): + candidate = f"{prefix}{i:d}" + if candidate not in collection: + return candidate + + +class Cell: + def __init__(self): + self.value = None + + +T = TypeVar("T") + + +def log_iter(source: Iterable[T], cell: Cell) -> Iterator[T]: + """ + Saves yielded values into the provided cell. + + Parameters + ---------- + source + The source iterator. + cell + The cell to save to. + + Yields + ------ + Values from the source iterator. + """ + for v in source: + cell.value = v + yield v diff --git a/pyteleport/minias.py b/pyteleport/minias.py deleted file mode 100644 index 4efc246..0000000 --- a/pyteleport/minias.py +++ /dev/null @@ -1,483 +0,0 @@ -""" -Bytecode assembly. - -- `Bytecode.disassemble`: disassembles the bytecode; -""" -from dataclasses import dataclass -from types import FunctionType -from itertools import count -from functools import partial - -import dis -from dis import HAVE_ARGUMENT, stack_effect -import sys - -from .opcodes import ( - LOAD_CONST, - RETURN_VALUE, - EXTENDED_ARG, - NOP, - POP_JUMP_IF_FALSE, - interrupting, - resuming, -) -from .util import unique_name -from .printtools import text_table, repr_truncated - - -def long2bytes(n): - result = tuple(map(int, n.to_bytes((n.bit_length() + 7) // 8, byteorder="big"))) - assert len(result) < 5 - if len(result) == 0: - return 0, - return result - - -def get_jump_multiplier() -> int: - """ - Computes jump multiplier. - - Returns - ------- - Jump multiplier. 
- """ - def _challenge(): - if something: - pass - bytecode = _challenge.__code__.co_code - assert bytecode[2] == POP_JUMP_IF_FALSE - assert bytecode[-4:] == bytes((LOAD_CONST, 0, RETURN_VALUE, 0)) - arg = bytecode[3] - target = len(bytecode) - 4 - if arg == target: - return 1 - elif arg * 2 == target: - return 2 - else: - dis.dis(_challenge, file=sys.stderr) - sys.stderr.flush() - raise RuntimeError(f"Failed to determine jump multiplier with arg={arg} and target={target}") - - -jump_multiplier = get_jump_multiplier() - - -@dataclass -class Instruction: - """Represents a single opcode""" - opcode: int - arg: int - pos: int = 0 - len: int = 2 - jump_to: "Instruction" = None - stack_size: int = None - - def __post_init__(self): - if not isinstance(self.opcode, int): - raise ValueError(f"opcode={self.opcode} not an integer") - if not isinstance(self.arg, int) and self.arg is not None: - raise ValueError(f"arg={self.arg} not an integer") - - @property - def is_jrel(self): - return self.opcode in dis.hasjrel - - @property - def is_jump(self): - return self.opcode in dis.hasjabs - - @property - def is_any_jump(self): - return self.is_jrel or self.is_jump - - @property - def pos_last(self): - return self.pos + self.len - 2 - - def get_stack_effect(self, jump=None): - result = self.opcode in resuming - if self.opcode < HAVE_ARGUMENT: - return result + stack_effect(self.opcode) - else: - if self.is_any_jump: - return result + stack_effect(self.opcode, self.arg, jump=jump) - else: - return result + stack_effect(self.opcode, self.arg) - - def get_stack_after(self, jump=None): - return self.stack_size + self.get_stack_effect(jump=jump) - - @property - def bytes(self): - arg_bytes = long2bytes(max(self.arg, 0)) - result = [] - for i in arg_bytes[:-1]: - result.extend((EXTENDED_ARG, i)) - result.extend((self.opcode, arg_bytes[-1])) - return bytes(result) - - def compute_jump(self): - if self.is_jrel: - return self.arg * jump_multiplier + self.pos_last + 2 - elif self.is_jump: - 
return self.arg * jump_multiplier - else: - return None - - def assert_valid(self, prev, lookup=None): - assert self.arg >= 0, f"arg is negative: {self.arg}" - assert len(self.bytes) == self.len, f"len is invalid: len({repr(self.bytes)}) != {self.len}" - if prev is not None: - assert prev.pos + prev.len == self.pos, f"pos is invalid: {prev.pos}(prev.pos) + {prev.len}(prev.len) != {self.pos}(self.pos)" - else: - assert self.pos == 0, f"pos is non-zero: {self.pos}" - if lookup is not None and self.is_any_jump: - jump_points_to = lookup.get(self.compute_jump(), None) - assert jump_points_to is self.jump_to, f"jump_to is invalid: {repr(self.jump_to)} vs {repr(jump_points_to)}" - - def __repr__(self): - return f"{dis.opname[self.opcode]}({self.arg})@{self.pos}" - - -@dataclass -class Comment: - """Represents a comment""" - text: str - - def __repr__(self): - return f"Comment({self.text})" - - -class CList(list): - def index_store(self, x, create_new=False): - if create_new: - x = unique_name(x, self) - try: - return self.index(x) - except ValueError: - self.append(x) - return len(self) - 1 - __call__ = index_store - - def copy(self): - return CList(self) - - -def assign_jumps(instructions): - for i in instructions.values(): - if i.is_any_jump and i.arg is not None: - i.jump_to = instructions[i.compute_jump()] - - -class Bytecode(list): - def __init__(self, opcodes, co_names, co_varnames, co_consts, co_cellvars, co_freevars, - pos=None): - super().__init__(opcodes) - if pos is None: - pos = len(self) - self.pos = pos - self.co_names = CList(co_names) - self.co_varnames = CList(co_varnames) - self.co_consts = CList(co_consts) - self.co_cellvars = CList(co_cellvars) - self.co_freevars = CList(co_freevars) - - @property - def co_cell_and_free_vars(self): - return tuple(self.co_cellvars + self.co_freevars) - - def copy(self, constructor=None): - if constructor is None: - constructor = type(self) - return constructor( - self, - self.co_names.copy(), - self.co_varnames.copy(), 
- self.co_consts.copy(), - self.co_cellvars.copy(), - self.co_freevars.copy(), - self.pos, - ) - - @classmethod - def disassemble(cls, arg, **kwargs): - if isinstance(arg, FunctionType): - arg = arg.__code__ - code = arg.co_code - - # Attempt to read source code - marks = dis.findlinestarts(arg) - try: - lines = open(arg.co_filename, 'r').readlines() # [arg.co_firstlineno - 1:] ?? - marks = list((i_opcode, i_line, lines[i_line - 1].strip()) for (i_opcode, i_line) in marks) - except (TypeError, OSError, IndexError): - marks = None - - result = cls([], arg.co_names, arg.co_varnames, arg.co_consts, arg.co_cellvars, arg.co_freevars, **kwargs) - arg = 0 - _len = 0 - pos_lookup = {} - for pos, (opcode, _arg) in enumerate(zip(code[::2], code[1::2])): - arg = arg * 0x100 + _arg - _len += 2 - if opcode != EXTENDED_ARG: - start_pos = pos * 2 - _len + 2 - - if marks is not None: - for i_mark, (mark_opcode, mark_lineno, mark_text) in enumerate(marks): - if mark_opcode <= start_pos: - result.c(f"L{mark_lineno:<3d} {mark_text}") - else: - marks = marks[i_mark:] - break - else: - marks = [] - - instruction = Instruction(opcode, arg, pos=start_pos) - pos_lookup[start_pos] = instruction - result.i(instruction) - arg = _len = 0 - assign_jumps(pos_lookup) - result.eval_stack() - return result - - def jump_to(self, label): - for i, entry in enumerate(self): - if isinstance(entry, Comment): - if entry.text == label: - self.pos = i + 1 - return i - raise ValueError(f"label '{label}' not found") - - def i(self, opcode, arg=None, *args, **kwargs): - if isinstance(opcode, Instruction): - i = opcode - else: - if arg is None and opcode < HAVE_ARGUMENT: - arg = 0 - i = Instruction(opcode, arg, *args, **kwargs) - self.insert(self.pos, i) - self.pos += 1 - return i - - def c(self, text): - i = Comment(text) - self.insert(self.pos, i) - self.pos += 1 - return i - - def I(self, opcode, arg, *args, create_new=False, **kwargs): - if opcode in dis.hasconst: - return self.i(opcode, 
self.co_consts(arg, create_new=create_new), *args, **kwargs) - elif opcode in dis.hasname: - return self.i(opcode, self.co_names(arg, create_new=create_new), *args, **kwargs) - elif opcode in dis.haslocal: - return self.i(opcode, self.co_varnames(arg, create_new=create_new), *args, **kwargs) - elif opcode in dis.hasjrel + dis.hasjabs: - result = self.i(opcode, None) - result.jump_to = arg - return result - else: - raise ValueError(f"Unknown opcode: {dis.opname[opcode]}") - - def nop(self, arg): - arg = bytes(arg) - for i in arg: - self.i(NOP, int(i)) - - def iter_opcodes(self, start=None): - if start is None: - iterator = iter(self) - else: - iterator = iter(self[self.index(start):]) - for i in iterator: - if isinstance(i, Instruction): - yield i - - def by_pos(self, pos): - for i in self.iter_opcodes(): - if i.pos == pos: - return i - else: - raise ValueError(f"Instruction with pos={pos} not found") - - def eval_stack(self): - for i_i, i in enumerate(self.iter_opcodes()): - i.stack_size = 0 if i_i == 0 else None - - updated = True - - def _maybe_set_stack(_op: Instruction, _stack: int): - nonlocal updated - assert _stack >= 0, f"Negative stack size {_stack} at {_op.pos} for code\n{self}" - if _op.opcode == RETURN_VALUE: - assert _stack == 1, f"Stack size {_stack} != 1 for RETURN_VALUE" - if _op.stack_size is not None: - assert _op.stack_size == _stack, f"Failed to match stack_size={_stack} against previously assigned value {_op.stack_size} at pos {_op.pos} for code\n{self}" - else: - _op.stack_size = _stack - updated = True - - while updated: - updated = False - prev_instruction = None - for i in self.iter_opcodes(): - # no-jump - if prev_instruction is not None and prev_instruction.stack_size is not None and \ - prev_instruction.opcode not in interrupting: - _maybe_set_stack(i, prev_instruction.get_stack_after(jump=False)) - - # jump - if i.stack_size is not None and i.is_any_jump: - _maybe_set_stack(i.jump_to, i.get_stack_after(jump=True)) - - prev_instruction = 
i - - def assign_pos(self): - pos = 0 - for i in self.iter_opcodes(): - i.pos = pos - pos += i.len - - def assign_jump_args(self): - for i in self.iter_opcodes(): - if i.is_any_jump and i.jump_to is not None: - if i.is_jump: - i.arg = i.jump_to.pos // jump_multiplier - elif i.is_jrel: - i.arg = (i.jump_to.pos - i.pos_last - 2) // jump_multiplier - - def assign_len(self): - for i in self.iter_opcodes(): - i.len = len(i.bytes) - - def assert_valid(self): - prev = None - lookup = {i.pos: i for i in self.iter_opcodes()} - for opcode in self.iter_opcodes(): - opcode.assert_valid(prev, lookup=lookup) - prev = opcode - - def get_bytecode(self): - for i in range(5): - self.assign_jump_args() - self.assign_len() - self.assign_pos() - try: - self.assert_valid() - break - except AssertionError: - pass - else: - self.assert_valid() # re-raise - return b''.join(i.bytes for i in self.iter_opcodes()) - - def __str__(self): - lookup = {id(i): i_i for i_i, i in enumerate(self)} - connections = [] - for _ in self: - connections.append([]) - for i_i, i in enumerate(self): - if isinstance(i, Instruction) and i.jump_to is not None: - i_j = lookup[id(i.jump_to)] - i_mn = min(i_i, i_j) - i_mx = max(i_i, i_j) - occupied = set(sum(connections[i_mn:i_mx+1], [])) - for slot in count(): - if slot not in occupied: - break - for _ in range(i_mn, i_mx+1): - connections[_].append(slot) - lines = [] - for i_i, (i, c, c_prev, c_next) in enumerate(zip( - self, connections, [[]] + connections[:-1], connections[1:] + [[]])): - line = [] - lines.append(line) - - if isinstance(i, Instruction): - line.append(">" if i_i == self.pos else None) # pointer - line.append(str(i.pos)) # opcode - line.append(dis.opname[i.opcode]) # opcode - line.append(str(i.arg)) # argument - - elif isinstance(i, Comment): - line.extend([None, None, (i.text, 2)]) # span 3 columns - - else: - raise ValueError(f"unknown object {i}") - - if len(c) > 0: - _str = "" - for _ in range(max(c) + 1): - if _ in c: - if _ in c_prev and _ in 
c_next: - _str += "┃" - elif _ in c_prev and _ not in c_next: - _str += "┛" - elif _ not in c_prev and _ in c_next: - _str += "┓" - else: - _str += "@" - else: - _str += " " - line.append(_str) # jump vis - else: - line.append(None) - - if isinstance(i, Instruction): - represented, arg = _repr_arg(i.opcode, i.arg, self, repr=partial(repr_truncated, target=24)) - line.append(arg if represented else None) # oparg value - line.append(f"{i.stack_size:d}" if i.stack_size is not None else None) # stack size - return text_table(lines, [ - (2, "right"), - (4, "right"), - (18, "left"), - (16, "left"), - (10, "left"), - (24, "left"), - (8, "left"), - ]) - - -def _repr_arg(opcode, arg, code, repr=repr): - if opcode in dis.hasconst: - return True, repr(code.co_consts[arg]) - elif opcode in dis.haslocal: - return True, code.co_varnames[arg] - elif opcode in dis.hasname: - return True, code.co_names[arg] - elif opcode in dis.hasfree: - return True, code.co_cell_and_free_vars[arg] - else: - return False, arg - - -def _repr_opcode(opcode, arg, code): - head = f"{dis.opname[opcode]:>20} {arg: 3d}" - represented, val = _repr_arg(opcode, arg, code) - if represented: - return f"{head} {'(' + repr(val) + ')':<12}" - else: - return f"{head}" + " " * 13 - - -def _dis(code_obj, alt=None): - code = code_obj.co_code - if alt is None: - alt = code - result = list(zip(code[::2], code[1::2], alt[::2], alt[1::2])) - result_repr = [] - for i, (opc_old, arg_old, opc_new, arg_new) in enumerate(result): - i *= 2 - if (opc_new, arg_new) == (opc_old, arg_old): - result_repr.append((f"{i: 3d} {_repr_opcode(opc_new, arg_new, code_obj)}",)) - else: - result_repr.append(("\033[94m", f"{i: 3d} {_repr_opcode(opc_new, arg_new, code_obj)} {_repr_opcode(opc_old, arg_old, code_obj)}", "\033[0m")) - return result_repr - - -def cdis(code_obj, alt=None): - return "\n".join(''.join(i) for i in _dis(code_obj, alt=alt)) - - - diff --git a/pyteleport/morph.py b/pyteleport/morph.py index d166758..633314d 100644 --- 
a/pyteleport/morph.py +++ b/pyteleport/morph.py @@ -7,25 +7,28 @@ import dis import logging from types import CodeType, FunctionType +from typing import Optional from functools import partial -import sys +from dataclasses import dataclass -from .minias import Bytecode, jump_multiplier +from .bytecode import Bytecode, disassemble, jump_multiplier +from .bytecode.primitives import AbstractInstruction, NoArgInstruction, ConstInstruction, NameInstruction, EncodedInstruction, ReferencingInstruction, FloatingCell +from .bytecode.minias import assign_stack_size from .primitives import NULL -from .opcodes import ( - POP_TOP, UNPACK_SEQUENCE, BINARY_SUBSCR, +from .bytecode.opcodes import ( + POP_TOP, UNPACK_SEQUENCE, BINARY_SUBSCR, BUILD_TUPLE, LOAD_CONST, LOAD_FAST, LOAD_ATTR, LOAD_METHOD, LOAD_GLOBAL, STORE_FAST, STORE_NAME, STORE_GLOBAL, STORE_ATTR, - JUMP_ABSOLUTE, - CALL_FUNCTION, CALL_METHOD, + JUMP_FORWARD, + CALL_FUNCTION_EX, IMPORT_NAME, IMPORT_FROM, MAKE_FUNCTION, - RAISE_VARARGS, SETUP_FINALLY, + RAISE_VARARGS, + guess_entering_stack_size, python_version, ) from .util import log_bytecode from .storage import transmission_engine EXCEPT_HANDLER = 257 -python_version = sys.version_info.major * 0x100 + sys.version_info.minor # 3.9 code_object_args = ("argcount", "posonlyargcount", "kwonlyargcount", @@ -36,8 +39,10 @@ "freevars", "cellvars", ) -if python_version > 0x0309: # 3.10 and above - from .opcodes import GEN_START +if python_version == 0x030A: # 3.10 + from .bytecode.opcodes import GEN_START +if python_version <= 0x030A: # 3.10 and below + from .bytecode.opcodes import SETUP_FINALLY def _iter_stack(value_stack, block_stack): @@ -82,7 +87,70 @@ def _iter_stack(value_stack, block_stack): yield stack_item, True +NOTSET = object() + + +@dataclass class MorphCode(Bytecode): + editing: int = 0 + + def __post_init__(self): + self.__editing_history__ = [] + + def __enter__(self): + self.__editing_history__.append(self.editing) + + def __exit__(self, exc_type, exc_val, 
exc_tb): + self.editing = self.__editing_history__.pop() + + @classmethod + def from_bytecode(cls, code: Bytecode) -> "MorphCode": + return cls(code.instructions, current=code.current) + + def get_marks(self): + return {self.instructions[self.editing]: "✎✎✎", **super().get_marks()} + + def insert_cell(self, cell: FloatingCell, at: Optional[int] = None): + if at is not None: + self.instructions.insert(at, cell) + else: + self.instructions.insert(self.editing, cell) + self.editing += 1 + return cell + + def insert(self, instruction: AbstractInstruction, at: Optional[int] = None): + return self.insert_cell(FloatingCell(instruction), at=at) + + def i(self, opcode: int, arg=NOTSET) -> FloatingCell: + if opcode < dis.HAVE_ARGUMENT: + if arg is not NOTSET: + raise ValueError(f"no argument expected for {dis.opname[opcode]}; provided: {arg=}") + result = NoArgInstruction(opcode) + else: + if arg is NOTSET: + raise ValueError(f"argument expected for {dis.opname[opcode]}") + if opcode in dis.hasconst: + result = ConstInstruction(opcode, arg) + elif opcode in dis.hasname + dis.haslocal + dis.hasfree: + if not isinstance(arg, str): + raise ValueError(f"string argument expected for {dis.opname[opcode]}; provided: {arg=}") + result = NameInstruction(opcode, arg) + elif opcode in dis.hasjabs + dis.hasjrel: + if not isinstance(arg, FloatingCell): + raise ValueError(f"cell argument expected for {dis.opname[opcode]}; provided: {arg=}") + result = ReferencingInstruction(opcode, arg) + else: + if not isinstance(arg, int): + raise ValueError(f"integer argument expected for {dis.opname[opcode]}; provided: {arg=}") + result = EncodedInstruction(opcode, arg) + return self.insert(result) + + def c(self, *args): + pass + + def sign(self): + pass + def put_except_handler(self) -> None: """ Puts except handler and 3 items (NULL, NULL, None) on the stack. @@ -92,12 +160,12 @@ def put_except_handler(self) -> None: # except: # POP, POP, POP # ... 
- setup_finally = self.I(SETUP_FINALLY, None) + towards = FloatingCell(NoArgInstruction(POP_TOP)) + self.i(SETUP_FINALLY, towards) self.i(RAISE_VARARGS, 0) - for i in range(3): - pop_top = self.i(POP_TOP, 0) - if i == 0: - setup_finally.jump_to = pop_top + self.insert_cell(towards) + for i in range(2): + self.i(POP_TOP) def put_null(self) -> None: """ @@ -106,17 +174,9 @@ def put_null(self) -> None: # any unbound method will work here # property.fget # POP - self.I(LOAD_GLOBAL, "property") - self.I(LOAD_METHOD, "fget") - self.i(POP_TOP, 0) - - def sign(self, signature=b'mrph') -> None: - """ - Marks the code with a static signature. - """ - self.pos = len(self) - self.c("!signature") - self.nop(signature) + self.i(LOAD_GLOBAL, "property") + self.i(LOAD_METHOD, "fget") + self.i(POP_TOP) def put_unpack(self, object_storage_name: str, object_storage: dict, tos) -> None: """ @@ -137,8 +197,8 @@ def put_unpack(self, object_storage_name: str, object_storage: dict, tos) -> Non # storage_name[id(tos)] handle = id(tos) object_storage[handle] = tos - self.I(LOAD_GLOBAL, object_storage_name) - self.I(LOAD_CONST, handle) + self.i(LOAD_GLOBAL, object_storage_name) + self.i(LOAD_CONST, handle) self.i(BINARY_SUBSCR) def put_module(self, name: str, fromlist=None, level=0): @@ -154,11 +214,16 @@ def put_module(self, name: str, fromlist=None, level=0): level Import level (absolute or relative). """ - self.I(LOAD_CONST, level) - self.I(LOAD_CONST, fromlist) - self.I(IMPORT_NAME, name) - - def unpack_storage(self, object_storage_name: str, object_storage_protocol: transmission_engine) -> int: + self.i(LOAD_CONST, level) + self.i(LOAD_CONST, fromlist) + self.i(IMPORT_NAME, name) + + def unpack_storage( + self, + object_storage_name: str, + object_storage_protocol: transmission_engine, + object_data: bytes, + ) -> FloatingCell: """ Unpack the storage. 
@@ -166,9 +231,11 @@ def unpack_storage(self, object_storage_name: str, object_storage_protocol: tran ---------- object_storage_name The name of the storage in builtins. - object_storage_protocol : storage_protocol + object_storage_protocol A collection of functions governing initial serialization and de-serialization of the global storage dict. + object_data + The serialized data which object storage unpacks. Returns ------- @@ -177,16 +244,31 @@ def unpack_storage(self, object_storage_name: str, object_storage_protocol: tran expected. """ # storage.loads(data) (kinda) - self.I(LOAD_CONST, object_storage_protocol.load_from_code.__code__) - self.I(LOAD_CONST, "unpack") + self.i(LOAD_CONST, object_storage_protocol.load_from_code.__code__) + self.i(LOAD_CONST, "unpack") self.i(MAKE_FUNCTION, 0) - handle = self.I(LOAD_CONST, "", create_new=True).arg - self.i(CALL_FUNCTION, 1) + result = self.i(LOAD_CONST, object_data) + self.i(BUILD_TUPLE, 1) + self.i(CALL_FUNCTION_EX, 0) # import builtins self.put_module("builtins") # builtins.morph_data = ... - self.I(STORE_ATTR, object_storage_name) - return handle + self.i(STORE_ATTR, object_storage_name) + return result + + def i_print(self, what: str): + """ + Instruct to print something. + + Parameters + ---------- + what + The string to print. 
+ """ + self.i(LOAD_GLOBAL, "print") + self.i(LOAD_CONST, what) + self.i(CALL_FUNCTION, 1) + self.i(POP_TOP) def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name="morph_data", @@ -230,12 +312,16 @@ def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name= logging.debug(f" {object_storage=}") logging.debug(f" {object_storage_name=}") logging.debug(f" {object_storage_protocol=}") - code = Bytecode.disassemble(p.code).copy(MorphCode) - if python_version >= 0x030A and next(code.iter_opcodes()).opcode == GEN_START: + code = MorphCode.from_bytecode(disassemble(p.code, pos=p.pos)) + lookup_orig = { + i.metadata.source.offset: i + for i in code.instructions + } + if python_version >= 0x030A and code.instructions[0].instruction.opcode == GEN_START: # Leave the generator header on top - code.pos = 1 + code.editing = 1 else: - code.pos = 0 + code.editing = 0 f_code = p.code code.c("--------------") code.c("Morph preamble") @@ -245,21 +331,21 @@ def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name= if object_storage_protocol is not None: logging.debug(f"Storage will be loaded here into builtins as '{object_storage_name}'") code.c("!unpack object storage") - storage_future_data_handle = code.unpack_storage(object_storage_name, object_storage_protocol) + load_storage_handle = code.unpack_storage(object_storage_name, object_storage_protocol, b"to be replaced") put = partial(code.put_unpack, object_storage_name, object_storage) else: - put = partial(code.I, LOAD_CONST) + put = partial(code.i, LOAD_CONST) # locals - for obj_collection, known_as, store_opcode, name_list in [ - (p.v_locals, "locals", STORE_FAST, code.co_varnames), + for obj_collection, known_as, store_opcode in [ + (zip(p.code.co_varnames, p.v_locals), "locals", STORE_FAST), ]: code.c(f"!unpack {known_as}") - for i_obj_in_collection, obj_in_collection in enumerate(obj_collection): + for obj_name, obj_in_collection in obj_collection: if 
obj_in_collection is not NULL: put(obj_in_collection) - code.i(store_opcode, i_obj_in_collection) + code.i(store_opcode, obj_name) # globals for obj_collection, known_as, store_opcode in [ @@ -273,7 +359,7 @@ def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name= code.i(UNPACK_SEQUENCE, len(vlist)) for k in klist: # k = v - code.I(store_opcode, k) + code.i(store_opcode, k) # load block and value stacks code.c("!unpack stack") @@ -286,7 +372,7 @@ def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name= put(item) else: if item.type == SETUP_FINALLY: - code.i(SETUP_FINALLY, 0, jump_to=code.by_pos(item.handler * jump_multiplier)) + code.i(SETUP_FINALLY, lookup_orig[item.handler * jump_multiplier]) elif item.type == EXCEPT_HANDLER: assert next(stack_items) == (NULL, True) # traceback assert next(stack_items) == (NULL, True) # value @@ -301,20 +387,20 @@ def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name= if call_nxt: code.c("!call TOS") if isinstance(nxt, FunctionType): - code.i(CALL_FUNCTION, 0) + code.i(BUILD_TUPLE, 0) + code.i(CALL_FUNCTION_EX, 0) elif isinstance(nxt, CodeType): put(f"morph_into:{f_code.co_name}") code.i(MAKE_FUNCTION, 0) - code.i(CALL_FUNCTION, 0) + code.i(BUILD_TUPLE, 0) + code.i(CALL_FUNCTION_EX, 0) else: raise ValueError(f"cannot call {nxt}") # now jump to the previously saved position if p.current_opcode is not None: code.c("!final jump") - last_opcode = code.i(JUMP_ABSOLUTE, 0, jump_to=code.by_pos(p.pos + 2)) - else: - last_opcode = code.nop(0) + code.i(JUMP_FORWARD, code.instructions[code.instructions.index(code.current) + 1]) code.c("-----------------") code.c("Original bytecode") @@ -324,32 +410,31 @@ def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name= code.sign() if object_storage is not None and object_storage_protocol is not None: - code.co_consts[storage_future_data_handle] = object_storage_protocol.save_to_code(object_storage) + 
load_storage_handle.instruction = ConstInstruction( + load_storage_handle.instruction.opcode, + object_storage_protocol.save_to_code(object_storage), + ) # finalize - bytecode_data = code.get_bytecode() - - # determine the desired stack size - s = 0 - preamble_stack_size = 0 - for i in code.iter_opcodes(): - s += i.get_stack_effect(jump=True) - preamble_stack_size = max(preamble_stack_size, s) - if i is last_opcode: - break + starting = code.instructions[0] + starting.metadata.stack_size = guess_entering_stack_size(starting.instruction.opcode) + assign_stack_size(code.instructions, clean_start=False) + code.print(log_bytecode) + assembled = code.assemble() + bytecode_data = bytes(assembled) init_args = dict( argcount=0, posonlyargcount=0, kwonlyargcount=0, - nlocals=len(code.co_varnames), - stacksize=max(f_code.co_stacksize, preamble_stack_size), + nlocals=len(assembled.varnames), + stacksize=max(i.metadata.stack_size or 0 for i in code.instructions), flags=flags, code=bytecode_data, - consts=tuple(code.co_consts), - names=tuple(code.co_names), - varnames=tuple(code.co_varnames), - freevars=tuple(code.co_cellvars + code.co_freevars), + consts=tuple(assembled.consts), + names=tuple(assembled.names), + varnames=tuple(assembled.varnames), + freevars=tuple(assembled.cells), cellvars=tuple(), filename=f_code.co_filename, # TODO: something different should be here name=f_code.co_name, @@ -359,8 +444,6 @@ def morph_into(p, nxt, call_nxt=False, object_storage=None, object_storage_name= ) init_args = tuple(init_args[f"{i}"] for i in code_object_args) result = CodeType(*init_args) - for i in str(code).split("\n"): - log_bytecode(i) return FunctionType( result, diff --git a/pyteleport/printtools.py b/pyteleport/printtools.py deleted file mode 100644 index fc27dd7..0000000 --- a/pyteleport/printtools.py +++ /dev/null @@ -1,80 +0,0 @@ -def truncate(s: str, target: int, suffix: str = "...") -> str: - """ - Truncates a string. - - Parameters - ---------- - s - String to truncate. 
- target - Maximal resulting string size. - suffix - The suffix to use when indicating truncation. - - Returns - ------- - The truncated string. - """ - if len(s) <= target: - return s - else: - assert len(suffix) < target - return s[:target - len(suffix)] + suffix - - -def repr_truncated(o, target: int) -> str: - """ - Truncated representation (`repr`). - - Parameters - ---------- - o - Object to represent. - target - Maximal resulting string size. - - Returns - ------- - The string representation of an object. - """ - o_repr = repr(o) - if len(o_repr) > target: - o_repr = f"<{truncate(type(o).__name__, target - 11)} instance>" - return o_repr - - -def text_table(table, column_spec, delimiter: str = " ") -> str: - """ - Renders a text table. - - Parameters - ---------- - table - Table cell data as a nested list. - column_spec - Column specification (size and alignment). - delimiter - A string used as a vertical delimiter. - - Returns - ------- - The resulting table. - """ - result = [] - for line in table: - result_line = [] - cs_iter = iter(column_spec) - for cell, (size, align) in zip(line, cs_iter): - if isinstance(cell, tuple): - cell, x = cell - for i in range(x - 1): - _size, align = next(cs_iter) - size += _size + len(delimiter) - - if cell is None: - cell = " " * size - else: - cell = truncate(cell, size) - result_line.append({"left": str.ljust, "right": str.rjust}[align](cell, size)) - result.append(delimiter.join(result_line).rstrip()) # trailing spaces - return "\n".join(result) diff --git a/pyteleport/snapshot.py b/pyteleport/snapshot.py index 12b521d..137ef86 100644 --- a/pyteleport/snapshot.py +++ b/pyteleport/snapshot.py @@ -9,9 +9,9 @@ import logging from .frame import get_value_stack, get_block_stack, snapshot_value_stack, get_value_stack_size, get_locals -from .minias import Bytecode +from .bytecode import disassemble from .util import log_bytecode -from .opcodes import CALL_METHOD, CALL_FUNCTION, CALL_FUNCTION_KW, CALL_FUNCTION_EX, LOAD_CONST, 
YIELD_VALUE +from .bytecode.opcodes import CALL_METHOD, CALL_FUNCTION, CALL_FUNCTION_KW, CALL_FUNCTION_EX, LOAD_CONST, YIELD_VALUE from .primitives import NULL @@ -90,15 +90,13 @@ def predict_stack_size(frame): size : int The size of the value stack """ - code = Bytecode.disassemble(frame.f_code) - opcode = code.by_pos(frame.f_lasti + 2) - code.pos = code.index(opcode) # for presentation - logging.debug(f" predicting stack size for {opcode}: {opcode.stack_size}") - for i in str(code).split("\n"): - log_bytecode(i) - if opcode.stack_size is None: + code = disassemble(frame.f_code, pos=frame.f_lasti + 2) + code.print(log_bytecode) + stack_size = code.current.metadata.stack_size + logging.debug(f" predicted stack size at {code.current}: {stack_size}") + if stack_size is None: raise ValueError("Failed to predict stack size") - return opcode.stack_size - 1 # the returned value is not there yet + return stack_size - 1 # the returned value is not there yet def normalize_frames(topmost_frame): diff --git a/pyteleport/test_scripts.py b/pyteleport/tests/test_scripts.py similarity index 74% rename from pyteleport/test_scripts.py rename to pyteleport/tests/test_scripts.py index f216c36..4e042a9 100644 --- a/pyteleport/test_scripts.py +++ b/pyteleport/tests/test_scripts.py @@ -1,3 +1,4 @@ +import subprocess from subprocess import check_output, Popen, PIPE import sys from pathlib import Path @@ -20,7 +21,7 @@ def run_test(name, interactive=False, dry_run=False, timeout=2): return check_output([sys.executable, name, f"{dry_run=}"], stderr=PIPE, text=True, env={"PYTHONPATH": "."}, timeout=timeout) -test_cases = list(map(lambda x: x.name, (Path(__file__).parent / "tests").glob("_test_*.py"))) +test_cases = list(map(lambda x: x.name, Path(__file__).parent.glob("_test_*.py"))) @pytest.mark.parametrize("test", test_cases) @@ -32,11 +33,16 @@ def test_external(test, interactive, dry_run): pytest.skip(f"_test_teleport_ec2 requires ec2 setup") if test == "_test_teleport_ssh.py": 
pytest.skip(f"_test_teleport_ssh.py needs an ssh setup") - test = Path(__file__).parent / "tests" / test + test = Path(__file__).parent / test with open(test, 'r') as f: module_text = f.read() module = ast.parse(module_text) docstring = ast.get_docstring(module).format(dry_run=dry_run) - assert run_test(test, interactive=interactive, dry_run=dry_run).rstrip() == eval(f'f"""{docstring}"""') + try: + assert run_test(test, interactive=interactive, dry_run=dry_run).rstrip() == eval(f'f"""{docstring}"""') + except subprocess.CalledProcessError as e: + raise RuntimeError(f"The remote python process exited with code {e.returncode}\n" + f"--- stdout ---\n{e.stdout}\n" + f"--- stderr ---\n{e.stderr}") diff --git a/pyteleport/util.py b/pyteleport/util.py index 32e8c33..77a2a65 100644 --- a/pyteleport/util.py +++ b/pyteleport/util.py @@ -23,30 +23,6 @@ def is_python_interactive() -> bool: return "ps1" in dir(sys) -def unique_name(prefix: str, collection) -> str: - """ - Prepares a unique name that is not - in the collection (yet). - - Parameters - ---------- - prefix - The prefix to use. - collection - Name collection. - - Returns - ------- - A unique name. - """ - if prefix not in collection: - return prefix - for i in range(len(collection) + 1): - candidate = f"{prefix}{i:d}" - if candidate not in collection: - return candidate - - def exit(code: int = 0, flush_stdio: bool = True): """ Exits the interpreter.