Skip to content

Commit

Permalink
update: update code block to ```python {.marimo} (#3387)
Browse files Browse the repository at this point in the history
fixes #1451

I just pushed this branch up from tinkering around back in September.

Outstanding:
  - [x] tests
    - [x] Basic backwards compat
  - [x] confirm `sql` blocks
  - [x] code highlight for "```python {.marimo}" on frontend

---------

Co-authored-by: Myles Scolnick <myles@marimo.io>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 31, 2025
1 parent 07ac53d commit 7a61a27
Show file tree
Hide file tree
Showing 29 changed files with 1,230 additions and 195 deletions.
34 changes: 26 additions & 8 deletions marimo/_ast/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class CellImpl:
mod: ast.Module
defs: set[Name]
refs: set[Name]
# Variables that should only live for the duration of the cell
temporaries: set[Name]

# metadata about definitions
Expand Down Expand Up @@ -178,6 +179,9 @@ class CellImpl:
_sqls: ParsedSQLStatements = dataclasses.field(
default_factory=ParsedSQLStatements
)
_raw_sqls: ParsedSQLStatements = dataclasses.field(
default_factory=ParsedSQLStatements
)

def configure(self, update: dict[str, Any] | CellConfig) -> CellImpl:
"""Update the cell config.
Expand Down Expand Up @@ -205,6 +209,14 @@ def runtime_state(self) -> Optional[RuntimeStateType]:
def run_result_status(self) -> Optional[RunResultStatusType]:
return self._run_result_status.state

def _get_sqls(self, raw: bool = False) -> list[str]:
    """Collect the SQL statements found in this cell's code.

    Args:
        raw: Forwarded to ``SQLVisitor`` to select its raw mode.

    Returns:
        list[str]: The statements reported by the visitor, or an empty
        list when parsing or visiting the code raises.
    """
    try:
        visitor = SQLVisitor(raw=raw)
        visitor.visit(ast.parse(self.code))
        return visitor.get_sqls()
    except Exception:
        # Deliberately best-effort: any failure while parsing or walking
        # the code is reported as "no SQL" rather than propagated.
        return []

@property
def sqls(self) -> list[str]:
"""Returns parsed SQL statements from this cell.
Expand All @@ -215,16 +227,22 @@ def sqls(self) -> list[str]:
if self._sqls.parsed is not None:
return self._sqls.parsed

try:
visitor = SQLVisitor()
visitor.visit(ast.parse(self.code))
sqls = visitor.get_sqls()
self._sqls.parsed = sqls
except Exception:
self._sqls.parsed = []

self._sqls.parsed = self._get_sqls()
return self._sqls.parsed

@property
def raw_sqls(self) -> list[str]:
    """Returns unparsed SQL statements from this cell.

    The result is computed on first access and stored on
    ``self._raw_sqls.parsed``; later accesses return the stored list.

    Returns:
        list[str]: List of SQL statements verbatim from the cell code.
    """
    if self._raw_sqls.parsed is None:
        self._raw_sqls.parsed = self._get_sqls(raw=True)
    return self._raw_sqls.parsed

@property
def stale(self) -> bool:
return self._stale.state
Expand Down
19 changes: 15 additions & 4 deletions marimo/_ast/sql_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ast
import re
from dataclasses import dataclass, field
from textwrap import dedent
from typing import Any, List, Optional

from marimo import _loggers
Expand All @@ -18,9 +19,10 @@ class SQLVisitor(ast.NodeVisitor):
This should be inside a function called `.execute` or `.sql`.
"""

def __init__(self, raw: bool = False) -> None:
    """Create a visitor with an empty list of collected SQL strings.

    Args:
        raw: Stored as ``self._raw``. ``visit_Call`` reads it to decide
            whether an f-string argument is rebuilt from its unparsed
            source (raw) or passed through ``normalize_sql_f_string``.
    """
    super().__init__()
    self._sqls: list[str] = []
    self._raw = raw

def visit_Call(self, node: ast.Call) -> None:
# Check if the call is a method call and the method is named
Expand All @@ -35,9 +37,18 @@ def visit_Call(self, node: ast.Call) -> None:
first_arg = node.args[0]
sql: Optional[str] = None
if isinstance(first_arg, ast.Constant):
sql = first_arg.s
sql = first_arg.value
elif isinstance(first_arg, ast.JoinedStr):
sql = normalize_sql_f_string(first_arg)
if self._raw:
f_sql = ast.unparse(first_arg)
sql = dedent(
f_sql[1:]
.strip(f_sql[1])
.encode()
.decode("unicode_escape")
)
else:
sql = normalize_sql_f_string(first_arg)

if sql is not None:
# Append the SQL query to the list
Expand All @@ -64,7 +75,7 @@ def print_part(part: ast.expr) -> str:
elif isinstance(part, ast.JoinedStr):
return normalize_sql_f_string(part)
elif isinstance(part, ast.Constant):
return str(part.s)
return str(part.value)
else:
# Just add null as a placeholder for {...} expressions
return "null"
Expand Down
113 changes: 85 additions & 28 deletions marimo/_cli/convert/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import re
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union, cast

# Native to python
from xml.etree.ElementTree import Element, SubElement
Expand All @@ -22,22 +22,71 @@
SuperFencesCodeExtension,
)

from marimo import _loggers
from marimo._ast import codegen
from marimo._ast.app import App, InternalApp, _AppConfig
from marimo._ast.cell import Cell, CellConfig
from marimo._ast.compiler import compile_cell
from marimo._ast.names import DEFAULT_CELL_NAME
from marimo._convert.utils import markdown_to_marimo
from marimo._convert.utils import markdown_to_marimo, sql_to_marimo
from marimo._dependencies.dependencies import DependencyManager

# Module-level logger obtained from marimo's logging helpers.
LOGGER = _loggers.marimo_logger()

# Element tag names used for markdown and code cells in the parsed tree.
MARIMO_MD = "marimo-md"
MARIMO_CODE = "marimo-code"

# NOTE(review): presumably the output formats accepted by the parser;
# confirm against MarimoParser.
ConvertKeys = Union[Literal["marimo"], Literal["marimo-app"]]

# Regex captures loose yaml for frontmatter; group 1 is the text between
# the `---` delimiters.
# Should match the following:
# ---
# title: "Title"
# whatever
# ---
YAML_FRONT_MATTER_REGEX = re.compile(
r"^---\s*\n(.*?\n?)(?:---)\s*\n", re.UNICODE | re.DOTALL
)


def backwards_compatible_sanitization(line: str) -> str:
    """Return ``line`` unchanged.

    NOTE(review): this is currently an identity function; presumably it
    is kept as a hook for legacy fence formats. Confirm with callers.
    """
    return line


def extract_attribs(
    line: str, fence_start: Optional[re.Match[str]] = None
) -> dict[str, str]:
    """Extract ``key="value"`` attributes from a code fence header.

    Blocks are expected to look like this:
        python {.marimo disabled="true"}

    Args:
        line: The fence header. Only consulted when ``fence_start`` is
            not supplied.
        fence_start: Optional pre-computed match exposing an ``attrs``
            group. When omitted, ``RE_NESTED_FENCE_START`` is matched
            against ``line``.

    Returns:
        Mapping of attribute name to value. Empty when the header does
        not match or carries no ``key="value"`` pairs.
    """
    if fence_start is None:
        fence_start = RE_NESTED_FENCE_START.match(line)
    if not fence_start:
        return {}

    # "attrs" is a bit of a misnomer: the group holds the whole inner
    # text, e.g. `.python.marimo disabled="true"`.
    inner = fence_start.group("attrs")
    if not inner:
        return {}
    return dict(re.findall(r'(\w+)="([^"]*)"', inner))


def _is_code_tag(text: str) -> bool:
    """Return whether the first line of ``text`` is a marimo code fence.

    A header counts as legacy format when its braces mention ``python``
    or ``sql``. When the installed superfences meets the required
    version, a header whose braces mention ``marimo`` is accepted too.

    Args:
        text: The fenced block; only its first line is inspected.

    Returns:
        True if the header matches one of the recognised formats.
    """
    # Only the header matters, so stop splitting after the first newline.
    head = text.split("\n", 1)[0].strip()
    legacy_format = bool(re.search(r"\{.*python.*\}", head))
    legacy_format |= bool(re.search(r"\{.*sql.*\}", head))
    if DependencyManager.new_superfences.has_required_version(quiet=True):
        supported_format = bool(re.search(r".*\{.*marimo.*\}", head))
        return legacy_format or supported_format
    return legacy_format


def _get_language(text: str) -> str:
    """Return the language named in the fence header of ``text``.

    Args:
        text: The fenced block; only its first line is inspected.

    Returns:
        The ``lang`` group matched by ``RE_NESTED_FENCE_START`` on the
        header, or ``"python"`` when there is no match or no language.
    """
    # Only the first line is needed; avoid splitting the whole block.
    header = text.split("\n", 1)[0]
    match = RE_NESTED_FENCE_START.match(header)
    if match and match.group("lang"):
        return str(match.group("lang"))
    return "python"


def formatted_code_block(
Expand All @@ -46,14 +95,24 @@ def formatted_code_block(
"""Wraps code in a fenced code block with marimo attributes."""
if attributes is None:
attributes = {}
language = attributes.pop("language", "python")
attribute_str = " ".join(
[""] + [f'{key}="{value}"' for key, value in attributes.items()]
)
guard = "```"
while guard in code:
guard += "`"
if DependencyManager.new_superfences.has_required_version(quiet=True):
return "\n".join(
[
f"""{guard}{language} {{.marimo{attribute_str}}}""",
code,
guard,
"",
]
)
return "\n".join(
[f"""{guard}{{.python.marimo{attribute_str}}}""", code, guard, ""]
[f"""{guard}{{.{language}.marimo{attribute_str}}}""", code, guard, ""]
)


Expand Down Expand Up @@ -82,6 +141,8 @@ def get_source_from_tag(tag: Element) -> str:
if not (source and source.strip()):
return ""
source = markdown_to_marimo(source)
elif tag.attrib.get("language") == "sql":
source = sql_to_marimo(source, tag.attrib.get("query", "_df"))
else:
assert tag.tag == MARIMO_CODE, f"Unknown tag: {tag.tag}"
return source
Expand Down Expand Up @@ -150,7 +211,10 @@ def _tree_to_app(root: Element) -> str:
cell_config: list[CellConfig] = []
for child in root:
names.append(child.get("name", DEFAULT_CELL_NAME))
cell_config.append(get_cell_config_from_tag(child))
# Default to hiding markdown cells.
cell_config.append(
get_cell_config_from_tag(child, hide_code=child.tag == MARIMO_MD)
)
sources.append(get_source_from_tag(child))

return codegen.generate_filecontents(
Expand Down Expand Up @@ -272,15 +336,6 @@ def __init__(self, md: MarimoParser):
super().__init__(md)
self.md = md
self.md.meta = {}
# Regex captures loose yaml for frontmatter
# Should match the following:
# ---
# title: "Title"
# whatever
# ---
self.yaml_front_matter_regex = re.compile(
r"^---\s*\n(.*?\n?)(?:---)\s*\n", re.UNICODE | re.DOTALL
)

def run(self, lines: list[str]) -> list[str]:
import yaml
Expand All @@ -295,7 +350,7 @@ def run(self, lines: list[str]) -> list[str]:
return lines

doc = "\n".join(lines)
result = self.yaml_front_matter_regex.match(doc)
result = YAML_FRONT_MATTER_REGEX.match(doc)

if result:
yaml_content = result.group(1)
Expand Down Expand Up @@ -414,25 +469,27 @@ def add_paragraph() -> None:
code_block = SubElement(parent, MARIMO_CODE)
block_lines = code.split("\n")
code_block.text = "\n".join(block_lines[1:-1])
# Extract attributes from the code block.
# Blocks are expected to be like this:
# {.python.marimo disabled="true"}
fence_start = RE_NESTED_FENCE_START.match(block_lines[0])
if fence_start:
# attrs is a bit of a misnomer, matches
# .python.marimo disabled="true"
inner = fence_start.group("attrs")
if inner:
code_block.attrib = dict(
re.findall(r'(\w+)="([^"]*)"', inner)
)

attribs = extract_attribs(block_lines[0])
if attribs:
code_block.attrib = attribs

# Set after to prevent lang being flushed.
code_block.set("language", _get_language(code))

add_paragraph()
# Flush to indicate all blocks have been processed.
blocks.clear()


def convert_from_md_to_app(text: str) -> App:
    """Convert markdown text into an ``App``.

    Args:
        text: Markdown source. Blank or whitespace-only text yields a
            fresh ``App()`` instead of going through the parser.

    Returns:
        The resulting app, after ``ensure_one_cell()`` has been called on
        its cell manager.
    """
    if not text.strip():
        app = App()
    else:
        app = cast(App, MarimoParser(output_format="marimo-app").convert(text))

    app._cell_manager.ensure_one_cell()
    return app


def convert_from_md(text: str) -> str:
Expand Down
13 changes: 13 additions & 0 deletions marimo/_convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def markdown_to_marimo(source: str) -> str:
)


def sql_to_marimo(source: str, table: str) -> str:
    """Wrap a SQL query in source code for a ``mo.sql`` call.

    Args:
        source: SQL text placed inside the triple-quoted f-string.
        table: Name the result of the ``mo.sql`` call is assigned to.

    Returns:
        The generated assignment as one newline-joined string.
    """
    # NOTE(review): `source` is embedded as-is; a query containing `"""`
    # would presumably end the generated string early. Confirm upstream.
    return "\n".join(
        [
            f"{table} = mo.sql(",
            # f-string: expected for sql
            codegen.indent_text('f"""'),
            codegen.indent_text(source),
            codegen.indent_text('"""'),
            ")",
        ]
    )


def generate_from_sources(
*,
sources: list[str],
Expand Down
Loading

0 comments on commit 7a61a27

Please sign in to comment.