diff --git a/marimo/__init__.py b/marimo/__init__.py index d01b1e06aeb..3e873f19f80 100644 --- a/marimo/__init__.py +++ b/marimo/__init__.py @@ -48,6 +48,7 @@ "icon", "iframe", "image", + "import_guard", "latex", "lazy", "left", @@ -125,7 +126,7 @@ redirect_stderr, redirect_stdout, ) -from marimo._runtime.context.utils import running_in_notebook +from marimo._runtime.context.utils import import_guard, running_in_notebook from marimo._runtime.control_flow import MarimoStopError, stop from marimo._runtime.runtime import ( app_meta, diff --git a/marimo/_ast/app.py b/marimo/_ast/app.py index c086eed5cc5..3955a31d8ab 100644 --- a/marimo/_ast/app.py +++ b/marimo/_ast/app.py @@ -21,9 +21,12 @@ TypeVar, Union, cast, + overload, ) from uuid import uuid4 +from typing_extensions import ParamSpec, TypeAlias + from marimo import _loggers from marimo._ast.cell import Cell, CellConfig, CellId_t, CellImpl from marimo._ast.cell_manager import CellManager @@ -58,7 +61,9 @@ from marimo._runtime.context.types import ExecutionContext -Fn = TypeVar("Fn", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") +Fn: TypeAlias = Callable[P, R] LOGGER = _loggers.marimo_logger() @@ -266,13 +271,13 @@ def clone(self) -> App: def cell( self, - func: Fn | None = None, + func: Fn[P, R] | None = None, *, column: Optional[int] = None, disabled: bool = False, hide_code: bool = False, **kwargs: Any, - ) -> Cell | Callable[[Fn], Cell]: + ) -> Cell | Callable[[Fn[P, R]], Cell]: """A decorator to add a cell to the app. This decorator can be called with or without parentheses. Each of the @@ -302,21 +307,29 @@ def __(mo): del kwargs return cast( - Union[Cell, Callable[[Fn], Cell]], + Union[Cell, Callable[[Fn[P, R]], Cell]], self._cell_manager.cell_decorator( func, column, disabled, hide_code, app=InternalApp(self) ), ) + # Overloads are required to preserve the wrapped function's signature. + # mypy is not smart enough to carry transitive typing in this case. + @overload + def function(self, func: Fn[P, R]) -> Fn[P, R]: ... + + @overload + def function(self, **kwargs: Any) -> Callable[[Fn[P, R]], Fn[P, R]]: ... + def function( self, - func: Fn | None = None, + func: Fn[P, R] | None = None, *, column: Optional[int] = None, disabled: bool = False, hide_code: bool = False, **kwargs: Any, - ) -> Fn | Callable[[Fn], Fn]: + ) -> Fn[P, R] | Callable[[Fn[P, R]], Fn[P, R]]: """A decorator to wrap a callable function into a cell in the app. This decorator can be called with or without parentheses. Each of the @@ -348,7 +361,7 @@ def multiply(a: int, b: int) -> int: del kwargs return cast( - Union[Fn, Callable[[Fn], Fn]], + Union[Fn[P, R], Callable[[Fn[P, R]], Fn[P, R]]], self._cell_manager.cell_decorator( func, column, diff --git a/marimo/_ast/cell.py b/marimo/_ast/cell.py index ecced9c2430..68884a36ef9 100644 --- a/marimo/_ast/cell.py +++ b/marimo/_ast/cell.py @@ -289,8 +289,9 @@ def namespace_to_variable(self, namespace: str) -> Name | None: def is_coroutine(self) -> bool: return _is_coroutine(self.body) or _is_coroutine(self.last_expr) - @property - def is_toplevel_acceptable(self) -> bool: + def is_toplevel_acceptable( + self, allowed_refs: Optional[set[Name]] = None + ) -> bool: # Check no defs aside from the single function if len(self.defs) != 1: return False @@ -316,13 +317,16 @@ def is_toplevel_acceptable(self) -> bool: return False # No required_refs are allowed for now - # TODO: Allow imports and other toplevel functions refs = set().union( *[v.required_refs for v in self.variable_data[name]] ) - refs -= set(globals()["__builtins__"].keys()) + # NOTE: Builtins are allowed, but should be passed in under + # allowed_refs. Defers to allowed_refs because shadowed builtins + # are accounted for. + if allowed_refs is None: + allowed_refs = set(globals()["__builtins__"].keys()) # Allow recursion - refs -= {name} + refs -= {name} | allowed_refs if refs: return False diff --git a/marimo/_ast/cell_manager.py b/marimo/_ast/cell_manager.py index dbd72870191..90821ca294f 100644 --- a/marimo/_ast/cell_manager.py +++ b/marimo/_ast/cell_manager.py @@ -7,13 +7,14 @@ import string from typing import ( TYPE_CHECKING, - Any, Callable, Iterable, Optional, TypeVar, ) +from typing_extensions import ParamSpec, TypeAlias + from marimo._ast.cell import Cell, CellConfig, CellId_t from marimo._ast.compiler import cell_factory, toplevel_cell_factory from marimo._ast.models import CellData @@ -23,7 +24,9 @@ if TYPE_CHECKING: from marimo._ast.app import InternalApp - Fn = TypeVar("Fn", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") +Fn: TypeAlias = Callable[P, R] class CellManager: @@ -74,20 +77,20 @@ def create_cell_id(self) -> CellId_t: # TODO: maybe remove this, it is leaky def cell_decorator( self, - func: Fn | None, + func: Fn[P, R] | None, column: Optional[int], disabled: bool, hide_code: bool, app: InternalApp | None = None, *, top_level: bool = False, - ) -> Cell | Fn | Callable[[Fn], Cell | Fn]: + ) -> Cell | Fn[P, R] | Callable[[Fn[P, R]], Cell | Fn[P, R]]: """Create a cell decorator for marimo notebook cells.""" cell_config = CellConfig( column=column, disabled=disabled, hide_code=hide_code ) - def _register(func: Fn) -> Cell | Fn: + def _register(func: Fn[P, R]) -> Cell | Fn[P, R]: # Use PYTEST_VERSION here, opposed to PYTEST_CURRENT_TEST, in # order to allow execution during test collection. is_top_level_pytest = ( @@ -111,16 +114,15 @@ def _register(func: Fn) -> Cell | Fn: # Manually set the signature for pytest. if is_top_level_pytest: - func = wrap_fn_for_pytest(func, cell) # type: ignore - # NB. in place metadata update. - functools.wraps(func)(cell) + # NB. in place metadata update. + functools.wraps(wrap_fn_for_pytest(func, cell))(cell) return cell if func is None: # If the decorator was used with parentheses, func will be None, # and we return a decorator that takes the decorated function as an # argument - def decorator(func: Fn) -> Cell | Fn: + def decorator(func: Fn[P, R]) -> Cell | Fn[P, R]: return _register(func) return decorator diff --git a/marimo/_ast/codegen.py b/marimo/_ast/codegen.py index 28417c294a8..3744c41526c 100644 --- a/marimo/_ast/codegen.py +++ b/marimo/_ast/codegen.py @@ -6,36 +6,84 @@ import importlib.util import json import os -from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, cast +import re +import textwrap +from typing import Any, List, Literal, Optional, cast from marimo import __version__ from marimo._ast.app import App, _AppConfig from marimo._ast.cell import CellConfig, CellImpl from marimo._ast.compiler import compile_cell from marimo._ast.names import DEFAULT_CELL_NAME +from marimo._ast.transformers import RemoveImportTransformer from marimo._ast.visitor import Name -if TYPE_CHECKING: - from collections.abc import Sequence - INDENT = " " MAX_LINE_LENGTH = 80 +NOTICE = "\n".join( + [ + "# These imports are auto-generated by marimo.", + "# Try modifying the source cell definitions", + "# opposed to the following block.", + ] +) + def indent_text(text: str) -> str: - return "\n".join( - [INDENT + line if line else line for line in text.split("\n")] - ) + return textwrap.indent(text, INDENT) -def _multiline_tuple(elems: Sequence[str]) -> str: - if elems: - return "(" + "\n" + indent_text(",\n".join(elems)) + ",\n)" +def _format_arg(arg: Any) -> str: + if isinstance(arg, str): + return f'"{arg}"'.replace("\\", "\\\\") + elif isinstance(arg, list): + return f"[{', '.join([_format_arg(item) for item in arg])}]" else: - return "()" + return str(arg) + + +def format_tuple_elements( + code: str, + elems: tuple[str, ...], + indent: bool = False, + allowed_naked: bool = False, +) -> str: + """ + Replaces (...) with the elements in elems, formatted as a tuple. + Adjusts for multiple lines as needed. + """ + maybe_indent = indent_text if indent else (lambda x: x) + if not elems: + if allowed_naked: + return maybe_indent(code.replace("(...)", "").rstrip()) + return maybe_indent(code.replace("(...)", "()")) + + if allowed_naked and len(elems) == 1: + allowed_naked = False + elems = (f"{elems[0]},",) + + tuple_str = ", ".join(elems) + if allowed_naked: + attempt = code.replace("(...)", tuple_str).rstrip() + else: + attempt = code.replace("(...)", f"({tuple_str})") + + attempt = maybe_indent(attempt) + if len(attempt) < MAX_LINE_LENGTH: + return attempt + + # Edgecase for very long variables + if len(elems) == 1: + elems = (elems[0].strip(","),) + + multiline_tuple = "\n".join( + ["(", indent_text(",\n".join(elems)) + ",", ")"] + ) + return maybe_indent(code.replace("(...)", multiline_tuple)) -def _to_decorator( +def to_decorator( config: Optional[CellConfig], fn: Literal["cell", "function"] = "cell" ) -> str: if config is None: @@ -53,62 +101,96 @@ def _to_decorator( if config == CellConfig(): return f"@app.{fn}" else: - return ( - f"@app.{fn}(" - + ", ".join( - f"{key}={value}" for key, value in config.__dict__.items() - ) - + ")" + return format_tuple_elements( + f"@app.{fn}(...)", + tuple(f"{key}={value}" for key, value in config.__dict__.items()), ) +def build_import_section(import_blocks: list[str]) -> str: + from marimo._utils.formatter import Formatter, ruff + + stripped_block = RemoveImportTransformer("marimo").strip_imports( + "\n".join(import_blocks) + ) + if not stripped_block: + return "" + + code = "\n".join( + [ + "with marimo.import_guard():", + indent_text(NOTICE), + indent_text(stripped_block), + ] + ) + + formatted = Formatter(MAX_LINE_LENGTH).format({"code": code}) + if not formatted: + return code + tidied = ruff(formatted, "check", "--fix-only") + if not tidied: + return formatted["code"] + return tidied["code"] + "\n\n" + + def to_functiondef( - cell: CellImpl, name: str, unshadowed_builtins: Optional[set[Name]] = None + cell: CellImpl, + name: str, + allowed_refs: Optional[set[Name]] = None, + used_refs: Optional[set[Name]] = None, + fn: Literal["cell"] = "cell", ) -> str: + # allowed refs are a combination of top level imports and unshadowed + # builtins. # unshadowed builtins is the set of builtins that haven't been # overridden (shadowed) by other cells in the app. These names # should not be taken as args by a cell's functiondef (since they are # already in globals) - if unshadowed_builtins is None: - unshadowed_builtins = set(builtins.__dict__.keys()) - refs = [ref for ref in sorted(cell.refs) if ref not in unshadowed_builtins] - args = ", ".join(refs) + if allowed_refs is None: + allowed_refs = set(builtins.__dict__.keys()) + refs = tuple(ref for ref in sorted(cell.refs) if ref not in allowed_refs) + + decorator = to_decorator(cell.config, fn=fn) - decorator = _to_decorator(cell.config) - signature = f"def {name}({args}):" prefix = "" if not cell.is_coroutine() else "async " - if len(INDENT + prefix + signature) >= MAX_LINE_LENGTH: - signature = f"def {name}{_multiline_tuple(refs)}:" - signature = prefix + signature + signature = format_tuple_elements(f"{prefix}def {name}(...):", refs) - fndef = decorator + "\n" + signature + "\n" - body = indent_text(cell.code) - if body: - fndef += body + "\n" + definition_body = [decorator, signature] + if body := indent_text(cell.code): + definition_body.append(body) + # Used refs are a collection of all the references that cells make to some + # external call. We collect them such that we can determine if a variable + # def is actually ever used. This is a nice little trick such that mypy and + # other static analysis tools can capture unused variables across cells. + defs: tuple[str, ...] = tuple() if cell.defs: - defs = tuple(name for name in sorted(cell.defs)) - returns = INDENT + "return " - if len(cell.defs) == 1: - returns += f"({defs[0]},)" + if used_refs is None: + defs = tuple(name for name in sorted(cell.defs)) else: - returns += ", ".join(defs) - fndef += ( - returns - if len(INDENT + returns) <= MAX_LINE_LENGTH - else (indent_text("return " + _multiline_tuple(defs))) - ) - else: - fndef += INDENT + "return" - return fndef + defs = tuple( + name for name in sorted(cell.defs) if name in used_refs + ) + + returns = format_tuple_elements( + "return (...)", defs, indent=True, allowed_naked=True + ) + definition_body.append(returns) + return "\n".join(definition_body) -def to_top_functiondef(cell: CellImpl) -> str: +def to_top_functiondef( + cell: CellImpl, allowed_refs: Optional[set[str]] = None +) -> str: # For the top-level function criteria to be satisfied, # the cell, it must pass basic checks in the cell impl. - assert cell.is_toplevel_acceptable, "Cell is not a top-level function" + if allowed_refs is None: + allowed_refs = set(builtins.__dict__.keys()) + assert cell.is_toplevel_acceptable(allowed_refs), ( + "Cell is not a top-level function" + ) if cell.code: - decorator = _to_decorator(cell.config, fn="function") + decorator = to_decorator(cell.config, fn="function") return "\n".join([decorator, cell.code.strip()]) return "" @@ -116,35 +198,34 @@ def to_top_functiondef(cell: CellImpl) -> str: def generate_unparsable_cell( code: str, name: Optional[str], config: CellConfig ) -> str: + text = ["app._unparsable_cell("] # escape double quotes to not interfere with string quote_escaped_code = code.replace('"', '\\"') # use r-string to handle backslashes (don't want to write # escape characters, want to actually write backslash characters) code_as_str = f'r"""\n{quote_escaped_code}\n"""' - text = "app._unparsable_cell(\n" + indent_text(code_as_str) - if name is not None: - text += ",\n" + INDENT + f'name="{name}"' + + flags = {} if config != CellConfig(): - text += ( - ",\n" - + INDENT - + ", ".join( - f"{key}={value}" for key, value in config.__dict__.items() - ) - ) - text += "\n)" - return text + flags = dict(config.__dict__) + if name is not None: + flags["name"] = name -def generate_app_constructor(config: Optional[_AppConfig]) -> str: - def _format_arg(arg: Any) -> str: - if isinstance(arg, str): - return f'"{arg}"'.replace("\\", "\\\\") - elif isinstance(arg, list): - return "[" + ", ".join([_format_arg(item) for item in arg]) + "]" - else: - return str(arg) + kwargs = ", ".join( + [f"{key}={_format_arg(value)}" for key, value in flags.items()] + ) + if kwargs: + text.extend([indent_text(f"{code_as_str},"), indent_text(kwargs)]) + else: + text.append(indent_text(code_as_str)) + + text.append(")") + + return "\n".join(text) + +def generate_app_constructor(config: Optional[_AppConfig]) -> str: default_config = _AppConfig().asdict() updates = {} # only include a config setting if it's not a default setting, to @@ -154,12 +235,37 @@ def _format_arg(arg: Any) -> str: for key in default_config: if updates[key] == default_config[key]: updates.pop(key) + if config._toplevel_fn: + updates["_toplevel_fn"] = True - kwargs = [f"{key}={_format_arg(value)}" for key, value in updates.items()] - app_constructor = "app = marimo.App(" + ", ".join(kwargs) + ")" - if len(app_constructor) >= MAX_LINE_LENGTH: - app_constructor = "app = marimo.App" + _multiline_tuple(kwargs) - return app_constructor + kwargs = tuple( + f"{key}={_format_arg(value)}" for key, value in updates.items() + ) + return format_tuple_elements("app = marimo.App(...)", kwargs) + + +def _classic_export( + fndefs: list[str], + header_comments: Optional[str], + config: Optional[_AppConfig], +) -> str: + filecontents = "".join( + "import marimo" + + "\n\n" + + f'__generated_with = "{__version__}"' + + "\n" + + generate_app_constructor(config) + + "\n\n\n" + + "\n\n\n".join(fndefs) + + "\n\n\n" + + 'if __name__ == "__main__":' + + "\n" + + indent_text("app.run()") + ) + + if header_comments: + filecontents = header_comments.rstrip() + "\n\n" + filecontents + return filecontents + "\n" def generate_filecontents( @@ -174,59 +280,98 @@ def generate_filecontents( # Let's keep it disabled by default. toplevel_fn = config is not None and config._toplevel_fn - cell_data: list[Union[CellImpl, tuple[str, CellConfig]]] = [] + # Update old internal cell names to the new ones + for idx, name in enumerate(names): + if name == "__": + names[idx] = DEFAULT_CELL_NAME + + # We require 3 sweeps. + # - One for compilation and import collection + # - One for some basic static determination of top-level functions. + # - And a final sweep for cleaner argument requirements + # (since we now know what's top-level). defs: set[Name] = set() + toplevel_imports: set[Name] = set() + used_refs: Optional[set[Name]] = set() + import_blocks: list[str] = [] + + definition_stubs: list[Optional[CellImpl]] = [None] * len(codes) + definitions: list[Optional[str]] = [None] * len(codes) - cell_id = 0 - for code, cell_config in zip(codes, cell_configs): + cell: Optional[CellImpl] + for idx, (code, cell_config) in enumerate(zip(codes, cell_configs)): try: - cell = compile_cell(code, cell_id=str(cell_id)).configure( - cell_config - ) + cell = compile_cell(code, cell_id=str(idx)).configure(cell_config) defs |= cell.defs - cell_data.append(cell) + assert isinstance(used_refs, set) + used_refs |= cell.refs + if cell.import_workspace.is_import_block: + # maybe a bug, but import_workspace.imported_defs does not + # contain the information we need. + toplevel_imports |= cell.defs + definitions[idx] = to_functiondef(cell, names[idx]) + if toplevel_fn: + import_blocks.append(code.strip()) + else: + definition_stubs[idx] = cell except SyntaxError: - cell_data.append((code, cell_config)) - cell_id += 1 + definitions[idx] = generate_unparsable_cell( + code=code, config=cell_config, name=names[idx] + ) unshadowed_builtins = set(builtins.__dict__.keys()) - defs - fndefs: list[str] = [] - - # Update old internal cell names to the new ones - for idx, name in enumerate(names): - if name == "__": - names[idx] = DEFAULT_CELL_NAME - - for data, name in zip(cell_data, names): - if isinstance(data, CellImpl): - if toplevel_fn and data.is_toplevel_acceptable: - fndefs.append(to_top_functiondef(data)) - else: - fndefs.append(to_functiondef(data, name, unshadowed_builtins)) - else: - fndefs.append( - generate_unparsable_cell( - code=data[0], config=data[1], name=name - ) + allowed_refs = unshadowed_builtins | toplevel_imports + + for idx, cell in enumerate(definition_stubs): + # We actually don't care about the graph, we just want to see if we can + # render the top-level functions without a name error. + # Let graph issues be delegated to the runtime. + if cell and toplevel_fn and cell.is_toplevel_acceptable(allowed_refs): + definitions[idx] = to_top_functiondef(cell, allowed_refs) + definition_stubs[idx] = None + # Order does matter since feasibly, an app.function could be a + # decorator for another. + allowed_refs.add(names[idx]) + + # Let's hide the new behavior for now. + # Removing the toplevel_fn check may produce a bit of churn, + # so let's release the serialization changes all together. + if not toplevel_fn: + allowed_refs = unshadowed_builtins + used_refs = None + + for idx, cell in enumerate(definition_stubs): + if cell: + definitions[idx] = to_functiondef( + cell, names[idx], allowed_refs, used_refs, fn="cell" ) - filecontents = ( - "import marimo" - + "\n\n" - + f'__generated_with = "{__version__}"' - + "\n" - + generate_app_constructor(config) - + "\n\n\n" - + "\n\n\n".join(fndefs) - + "\n\n\n" - + 'if __name__ == "__main__":' - + "\n" - + indent_text("app.run()") + assert all(isinstance(d, str) for d in definitions) + cell_blocks: List[str] = cast(List[str], definitions) + if not toplevel_fn: + return _classic_export(cell_blocks, header_comments, config) + + filecontents = [] + if header_comments is not None: + filecontents = [header_comments.rstrip(), ""] + + filecontents.extend( + [ + "import marimo", + "", + build_import_section(import_blocks), + "", + f'__generated_with = "{__version__}"', + generate_app_constructor(config), + "\n", + "\n\n\n".join(cell_blocks), + "\n", + 'if __name__ == "__main__":', + indent_text("app.run()"), + "", + ] ) - - if header_comments: - filecontents = header_comments.rstrip() + "\n\n" + filecontents - return filecontents + "\n" + return "\n".join(filecontents) class MarimoFileError(Exception): @@ -290,9 +435,6 @@ def get_app(filename: Optional[str]) -> Optional[App]: return app -RECOVERY_CELL_MARKER = "ↁ" - - def recover(filename: str) -> str: """Generate a module for code recovered from a disconnected frontend""" with open(filename, "r", encoding="utf-8") as f: @@ -317,6 +459,13 @@ def recover(filename: str) -> str: ) +def is_multiline_comment(node: ast.stmt) -> bool: + """Checks if a node is a docstring or a multiline comment.""" + if isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant): + return True + return False + + def get_header_comments(filename: str) -> Optional[str]: """Gets the header comments from a file. Returns None if the file does not exist or the header is @@ -327,12 +476,6 @@ def get_header_comments(filename: str) -> Optional[str]: statement contains any non-comment code """ - def is_multiline_comment(node: ast.stmt) -> bool: - """Checks if a node is a docstring or a multiline comment.""" - if isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant): - return True - return False - if not os.path.exists(filename): return None @@ -341,8 +484,9 @@ def is_multiline_comment(node: ast.stmt) -> bool: if "import marimo" not in contents: return None - - header, _ = contents.split("import marimo", 1) + header, _ = re.split( + r"^import marimo", contents, maxsplit=1, flags=re.MULTILINE + ) # Ensure the header only contains non-executable code # ast parses out single line comments, so we only diff --git a/marimo/_ast/transformers.py b/marimo/_ast/transformers.py index c6e58e1bab2..6173c4cbabe 100644 --- a/marimo/_ast/transformers.py +++ b/marimo/_ast/transformers.py @@ -2,7 +2,7 @@ from __future__ import annotations import ast -from typing import Any +from typing import Any, Optional class NameTransformer(ast.NodeTransformer): @@ -72,3 +72,48 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign: "targets": new_targets, } ) + + +class RemoveImportTransformer(ast.NodeTransformer): + """Removes import that matches the given name. + e.g. given import_name = "bar": + ```python + from foo import bar # removed + from foo import bar as baz + import foo.bar + import foo.bar as baz + import foo.baz as bar # removed + ``` + To prevent module collisions in top level definitions. + """ + + def __init__(self, import_name: str) -> None: + super().__init__() + self.import_name = import_name + + def strip_imports(self, code: str) -> str: + tree = ast.parse(code) + tree = self.visit(tree) + return ast.unparse(tree).strip() + + def visit_Import(self, node: ast.Import) -> Optional[ast.Import]: + name = self.import_name + node.names = [ + alias + for alias in node.names + if (alias.asname and alias.asname != name) + or (not alias.asname and alias.name != name) + ] + return node if node.names else None + + def visit_ImportFrom( + self, node: ast.ImportFrom + ) -> Optional[ast.ImportFrom]: + name = self.import_name + node.names = [ + alias + for alias in node.names + if (alias.asname and alias.asname != name) + or (not alias.asname and alias.name != name) + ] + return node if node.names else None diff --git a/marimo/_runtime/context/utils.py b/marimo/_runtime/context/utils.py index 10982b145cd..518cec1b46d 100644 --- a/marimo/_runtime/context/utils.py +++ b/marimo/_runtime/context/utils.py @@ -2,13 +2,16 @@ from __future__ import annotations import os -from typing import Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional, Type from marimo._output.rich_help import mddoc from marimo._runtime.context import ContextNotInitializedError, get_context from marimo._server.model import SessionMode from marimo._utils.assert_never import assert_never +if TYPE_CHECKING: + from types import TracebackType + RunMode = Literal["run", "edit", "script", "test"] @@ -57,3 +60,23 @@ def get_mode() -> Optional[RunMode]: return "test" return None + + +class import_guard: + """ + A context manager that controls imports from being executed in top level code. + NB. May be replaced for `import_guard -> bool`, experimental stub pending + MEP-0008 (github:marimo-team/meps/pull/8) + """ + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exception: Optional[Type[BaseException]], + instance: Optional[BaseException], + _tracebacktype: Optional[TracebackType], + ) -> Literal[False]: + # Whether to suppress a given exception. + return False diff --git a/marimo/_utils/formatter.py b/marimo/_utils/formatter.py index c4cd2074843..28f46b93912 100644 --- a/marimo/_utils/formatter.py +++ b/marimo/_utils/formatter.py @@ -14,9 +14,43 @@ CellCodes = Dict[CellId_t, str] +def ruff(codes: CellCodes, *cmd: str) -> CellCodes: + ruff_cmd = [sys.executable, "-m", "ruff"] + process = subprocess.run([*ruff_cmd, "--help"], capture_output=True) + if process.returncode != 0: + LOGGER.warning( + "To enable code formatting, install ruff (pip install ruff)" + ) + return {} + + formatted_codes: CellCodes = {} + for key, code in codes.items(): + try: + process = subprocess.run( + [ + *ruff_cmd, + *cmd, + "-", + ], + input=code.encode(), + capture_output=True, + check=True, + ) + if process.returncode != 0: + raise FormatError("Failed to format code with ruff") + + formatted = process.stdout.decode() + formatted_codes[key] = formatted.strip() + except Exception as e: + LOGGER.error("Failed to format code with ruff") + LOGGER.debug(e) + continue + + return formatted_codes + + class Formatter: def __init__(self, line_length: int) -> None: - self.data = None self.line_length = line_length def format(self, codes: CellCodes) -> CellCodes: @@ -43,40 +77,7 @@ def format(self, codes: CellCodes) -> CellCodes: class RuffFormatter(Formatter): def format(self, codes: CellCodes) -> CellCodes: - ruff_cmd = [sys.executable, "-m", "ruff"] - process = subprocess.run([*ruff_cmd, "--help"], capture_output=True) - if process.returncode != 0: - LOGGER.warning( - "To enable code formatting, install ruff (pip install ruff)" - ) - return {} - - formatted_codes: CellCodes = {} - for key, code in codes.items(): - try: - process = subprocess.run( - [ - *ruff_cmd, - "format", - "--line-length", - str(self.line_length), - "-", - ], - input=code.encode(), - capture_output=True, - check=True, - ) - if process.returncode != 0: - raise FormatError("Failed to format code with ruff") - - formatted = process.stdout.decode() - formatted_codes[key] = formatted.strip() - except Exception as e: - LOGGER.error("Failed to format code with ruff") - LOGGER.debug(e) - continue - - return formatted_codes + return ruff(codes, "format", "--line-length", str(self.line_length)) class BlackFormatter(Formatter): diff --git a/tests/_ast/codegen_data/test_generate_filecontents_toplevel.py b/tests/_ast/codegen_data/test_generate_filecontents_toplevel.py new file mode 100644 index 00000000000..6e5b9f03b83 --- /dev/null +++ b/tests/_ast/codegen_data/test_generate_filecontents_toplevel.py @@ -0,0 +1,99 @@ +# This comment should be preserved. + +# The best way to regenerate this file is open it up +# in the editor. i.e: +# marimo edit tests/_ast/codegen_data/test_generate_filecontents_toplevel.py + +import marimo + +with marimo.import_guard(): + # These imports are auto-generated by marimo. + # Try modifying the source cell definitions + # opposed to the following block. + import io + import textwrap + import typing + from pathlib import Path + + import marimo as mo + + +__generated_with = "0.0.0" +app = marimo.App(_toplevel_fn=True) + + +@app.cell +def _(fun_that_uses_another_but_out_of_order): + shadow = 1 + globe = 1 + ( + fun_that_uses_mo(), + fun_that_uses_another(), + fun_that_uses_another_but_out_of_order(), + ) + return globe, shadow + + +@app.function +# Sanity check that base case works. +def addition(a, b): + return a + b + + +@app.function +def shadow_case(shadow): + shadow = 2 + return shadow + + +@app.cell +def _(shadow): + def reference_case(): + return shadow + return + + +@app.cell +def _(globe): + def global_case(): + global globe + return globe + return + + +@app.function +def fun_that_uses_mo(): + return mo.md("Hello there!") + + +@app.cell +def fun_that_uses_another_but_out_of_order(): + def fun_that_uses_another_but_out_of_order(): + return fun_that_uses_another() + return (fun_that_uses_another_but_out_of_order,) + + +@app.function +def fun_that_uses_another(): + return fun_that_uses_mo() + + +@app.cell +def _(): + import io + import textwrap + + # Comments are stripped out. + import marimo as mo + return io, mo, textwrap + + +@app.cell +def _(): + import typing + from pathlib import Path + return Path, typing + + +if __name__ == "__main__": + app.run() diff --git a/tests/_ast/codegen_data/test_generate_filecontents_toplevel_pytest.py b/tests/_ast/codegen_data/test_generate_filecontents_toplevel_pytest.py new file mode 120000 index 00000000000..05d34d6d32e --- /dev/null +++ b/tests/_ast/codegen_data/test_generate_filecontents_toplevel_pytest.py @@ -0,0 +1 @@ +../test_pytest_toplevel.py \ No newline at end of file diff --git a/tests/_ast/test_codegen.py b/tests/_ast/test_codegen.py index 96344235f1a..8ebddb69fdc 100644 --- a/tests/_ast/test_codegen.py +++ b/tests/_ast/test_codegen.py @@ -44,6 +44,10 @@ def get_filepath(name: str) -> str: return os.path.join(DIR_PATH, f"codegen_data/{name}.py") +def sanitized_version(output: str) -> str: + return output.replace(__version__, "0.0.0") + + def wrap_generate_filecontents( codes: list[str], names: list[str], @@ -62,6 +66,40 @@ def wrap_generate_filecontents( ) +def strip_blank_lines(text: str) -> str: + return "\n".join([line for line in text.splitlines() if line.strip()]) + + +def get_idempotent_marimo_source(name: str) -> str: + from marimo._utils.formatter import Formatter + + path = get_filepath(name) + app = codegen.get_app(path) + header_comments = codegen.get_header_comments(path) + generated_contents = codegen.generate_filecontents( + list(app._cell_manager.codes()), + list(app._cell_manager.names()), + list(app._cell_manager.configs()), + app._config, + header_comments, + ) + generated_contents = sanitized_version(generated_contents) + + with open(path) as f: + python_source = sanitized_version(f.read()) + + # TODO(dmadisetti): not totally idempotent for now. Revise; seems to strip + # on imports (possibly during compile?). + formatted = Formatter(codegen.MAX_LINE_LENGTH).format( + {"source": python_source, "generated": generated_contents} + ) + + assert strip_blank_lines(formatted["source"]) == strip_blank_lines( + formatted["generated"] + ) + return formatted["generated"] + + class TestGeneration: @staticmethod def test_generate_filecontents_empty() -> None: @@ -275,6 +313,23 @@ def test_generate_file_contents_overwrite_default_cell_names() -> None: assert "def _" in contents assert "def __" not in contents + @staticmethod + def test_generate_filecontents_toplevel() -> None: + source = get_idempotent_marimo_source( + "test_generate_filecontents_toplevel" + ) + assert "import marimo" in source + split = source.split("import marimo") + # The default one, the as mo in top level, in as mo in cell + assert len(split) == 4 + + @staticmethod + def test_generate_filecontents_toplevel_pytest() -> None: + source = get_idempotent_marimo_source( + "test_generate_filecontents_toplevel_pytest" + ) + assert "import marimo" in source + class TestGetCodes: @staticmethod @@ -656,3 +711,56 @@ def test_is_internal_cell_name() -> None: assert not is_internal_cell_name("___") assert not is_internal_cell_name("__1213123123") assert not is_internal_cell_name("foo") + + +def test_format_tuple_elements() -> None: + kv_case = codegen.format_tuple_elements( + "@app.fn(...)", + tuple(["a", "b", "c"]), + ) + assert kv_case == "@app.fn(a, b, c)" + + indent_case = codegen.format_tuple_elements( + "def fn(...):", tuple(["a", "b", "c"]), indent=True + ) + assert indent_case == " def fn(a, b, c):" + + multiline_case = codegen.format_tuple_elements( + "return (...)", + ( + "very", + "long", + "arglist", + "that", + "exceeds", + "maximum", + "characters", + "for", + "some", + "reason", + "or", + "the", + "other", + "wowza", + ), + allowed_naked=True, + ) + assert multiline_case == ( + "return (\n " + "very,\n long,\n arglist,\n that,\n exceeds,\n maximum,\n" + " characters,\n for,\n some,\n reason,\n" + " or,\n the,\n other,\n wowza,\n)" + ) + + long_case = codegen.format_tuple_elements( + "return (...)", + ( + "very_long_name_that_exceeds_76_characters_for_some_reason_or_the_other_woowee", + ), + allowed_naked=True, + ) + assert long_case == ( + "return (\n " + "very_long_name_that_exceeds_76_characters_for_some_reason_or_the_other_woowee," + "\n)" + ) diff --git a/tests/_ast/test_pytest_toplevel.py b/tests/_ast/test_pytest_toplevel.py new file mode 100644 index 00000000000..0372055b997 --- /dev/null +++ b/tests/_ast/test_pytest_toplevel.py @@ -0,0 +1,52 @@ +# Note that marimo is not repeated in the imports. + +import marimo + +with marimo.import_guard(): + # These imports are auto-generated by marimo. + # Try modifying the source cell definitions + # opposed to the following block. + import pytest + + +__generated_with = "0.0.0" +app = marimo.App(_toplevel_fn=True) + + +@app.function +# Sanity check that base case works. +def add(a, b): + return a + b + + +@app.function +@pytest.mark.parametrize(("a", "b", "c"), [(1, 1, 2), (1, 2, 3)]) +def test_add_good(a, b, c): + assert add(a, b) == c + + +@app.function +@pytest.mark.xfail( + reason=("Check test is actually called."), + raises=AssertionError, + strict=True, +) +@pytest.mark.parametrize(("a", "b", "c"), [(1, 1, 3), (2, 2, 5)]) +def test_add_bad(a, b, c): + assert add(a, b) == c + + +@app.cell +def _(): + # Invert the order intentionally. + import pytest + + # Comments are stripped. + import marimo # noqa: I001 + + # Final comments preserve total whitespace. + return marimo, pytest + + +if __name__ == "__main__": + app.run() diff --git a/tests/_ast/test_transformers.py b/tests/_ast/test_transformers.py index c2483754cae..81eba3d225d 100644 --- a/tests/_ast/test_transformers.py +++ b/tests/_ast/test_transformers.py @@ -5,7 +5,7 @@ import pytest -from marimo._ast.transformers import NameTransformer +from marimo._ast.transformers import NameTransformer, RemoveImportTransformer @pytest.mark.skipif( @@ -81,3 +81,16 @@ def test_name_transformer_no_changes() -> None: assert new_code.strip() == code.strip() assert not transformer.made_changes + + +def test_import_transformer_strip() -> None: + code = """ +import thing.marimo # Only line that's reasonable. +import marimo +import thing as marimo +from thing.thing import marimo +from thing import m as marimo + """ + + stripped = RemoveImportTransformer("marimo").strip_imports(code) + assert stripped == "import thing.marimo" diff --git a/tests/conftest.py b/tests/conftest.py index 91c5cdbba77..a8f9246fc79 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -619,15 +619,21 @@ def pytest_make_collect_report(collector): # Defined within the file does not seem to hook correctly, as such filter # for the test_pytest specific file here. if "test_pytest" in str(collector.path): - collected = {fn.name: fn for fn in collector.collect()} - from tests._ast.test_pytest import app + collected = {fn.originalname for fn in collector.collect()} + from tests._ast.test_pytest import app as app_pytest + from tests._ast.test_pytest_toplevel import app as app_toplevel + + app = { + "test_pytest": app_pytest, + "test_pytest_toplevel": app_toplevel, + }[collector.path.stem] invalid = [] for name in app._cell_manager.names(): if name.startswith("test_") and name not in collected: invalid.append(f"'{name}'") if invalid: - tests = ", ".join([f"'{test}'" for test in collected.keys()]) + tests = ", ".join([f"'{test}'" for test in collected]) report.outcome = "failed" report.longrepr = ( f"Cannot collect test(s) {', '.join(invalid)} from {tests}"