Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update: update code block to ```python {.marimo} #3387

Merged
merged 9 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions marimo/_ast/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class CellImpl:
mod: ast.Module
defs: set[Name]
refs: set[Name]
# Variables that should only live for the duration of the cell
temporaries: set[Name]

# metadata about definitions
Expand Down Expand Up @@ -178,6 +179,9 @@ class CellImpl:
_sqls: ParsedSQLStatements = dataclasses.field(
default_factory=ParsedSQLStatements
)
_raw_sqls: ParsedSQLStatements = dataclasses.field(
default_factory=ParsedSQLStatements
)

def configure(self, update: dict[str, Any] | CellConfig) -> CellImpl:
"""Update the cell config.
Expand Down Expand Up @@ -205,6 +209,14 @@ def runtime_state(self) -> Optional[RuntimeStateType]:
def run_result_status(self) -> Optional[RunResultStatusType]:
return self._run_result_status.state

def _get_sqls(self, raw: bool = False) -> list[str]:
try:
visitor = SQLVisitor(raw=raw)
visitor.visit(ast.parse(self.code))
return visitor.get_sqls()
except Exception:
return []

@property
def sqls(self) -> list[str]:
"""Returns parsed SQL statements from this cell.
Expand All @@ -215,16 +227,22 @@ def sqls(self) -> list[str]:
if self._sqls.parsed is not None:
return self._sqls.parsed

try:
visitor = SQLVisitor()
visitor.visit(ast.parse(self.code))
sqls = visitor.get_sqls()
self._sqls.parsed = sqls
except Exception:
self._sqls.parsed = []

self._sqls.parsed = self._get_sqls()
return self._sqls.parsed

@property
def raw_sqls(self) -> list[str]:
"""Returns unparsed SQL statements from this cell.
Returns:
list[str]: List of SQL statements verbatim from the cell code.
"""
if self._raw_sqls.parsed is not None:
return self._raw_sqls.parsed

self._raw_sqls.parsed = self._get_sqls(raw=True)
return self._raw_sqls.parsed

@property
def stale(self) -> bool:
return self._stale.state
Expand Down
19 changes: 15 additions & 4 deletions marimo/_ast/sql_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ast
import re
from dataclasses import dataclass, field
from textwrap import dedent
from typing import Any, List, Optional

from marimo import _loggers
Expand All @@ -18,9 +19,10 @@ class SQLVisitor(ast.NodeVisitor):
This should be inside a function called `.execute` or `.sql`.
"""

def __init__(self) -> None:
def __init__(self, raw: bool = False) -> None:
super().__init__()
self._sqls: list[str] = []
self._raw = raw

def visit_Call(self, node: ast.Call) -> None:
# Check if the call is a method call and the method is named
Expand All @@ -35,9 +37,18 @@ def visit_Call(self, node: ast.Call) -> None:
first_arg = node.args[0]
sql: Optional[str] = None
if isinstance(first_arg, ast.Constant):
sql = first_arg.s
sql = first_arg.value
elif isinstance(first_arg, ast.JoinedStr):
sql = normalize_sql_f_string(first_arg)
if self._raw:
f_sql = ast.unparse(first_arg)
sql = dedent(
f_sql[1:]
.strip(f_sql[1])
.encode()
.decode("unicode_escape")
)
else:
sql = normalize_sql_f_string(first_arg)

if sql is not None:
# Append the SQL query to the list
Expand All @@ -64,7 +75,7 @@ def print_part(part: ast.expr) -> str:
elif isinstance(part, ast.JoinedStr):
return normalize_sql_f_string(part)
elif isinstance(part, ast.Constant):
return str(part.s)
return str(part.value)
else:
# Just add null as a placeholder for {...} expressions
return "null"
Expand Down
113 changes: 85 additions & 28 deletions marimo/_cli/convert/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import re
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union, cast

# Native to python
from xml.etree.ElementTree import Element, SubElement
Expand All @@ -22,22 +22,71 @@
SuperFencesCodeExtension,
)

from marimo import _loggers
from marimo._ast import codegen
from marimo._ast.app import App, InternalApp, _AppConfig
from marimo._ast.cell import Cell, CellConfig
from marimo._ast.compiler import compile_cell
from marimo._ast.names import DEFAULT_CELL_NAME
from marimo._convert.utils import markdown_to_marimo
from marimo._convert.utils import markdown_to_marimo, sql_to_marimo
from marimo._dependencies.dependencies import DependencyManager

LOGGER = _loggers.marimo_logger()

MARIMO_MD = "marimo-md"
MARIMO_CODE = "marimo-code"

ConvertKeys = Union[Literal["marimo"], Literal["marimo-app"]]

# Regex captures loose yaml for frontmatter
# Should match the following:
# ---
# title: "Title"
# whatever
# ---
YAML_FRONT_MATTER_REGEX = re.compile(
r"^---\s*\n(.*?\n?)(?:---)\s*\n", re.UNICODE | re.DOTALL
)


def backwards_compatible_sanitization(line: str) -> str:
return line


def extract_attribs(
line: str, fence_start: Optional[re.Match[str]] = None
) -> dict[str, str]:
# Extract attributes from the code block.
# Blocks are expected to be like this:
# python {.marimo disabled="true"}
if fence_start is None:
fence_start = RE_NESTED_FENCE_START.match(line)

if fence_start:
# attrs is a bit of a misnomer, matches
# .python.marimo disabled="true"
inner = fence_start.group("attrs")
if inner:
return dict(re.findall(r'(\w+)="([^"]*)"', inner))
return {}


def _is_code_tag(text: str) -> bool:
head = text.split("\n")[0].strip()
return bool(re.search(r"\{.*python.*\}", head))
legacy_format = bool(re.search(r"\{.*python.*\}", head))
legacy_format |= bool(re.search(r"\{.*sql.*\}", head))
if DependencyManager.new_superfences.has_required_version(quiet=True):
supported_format = bool(re.search(r".*\{.*marimo.*\}", head))
return legacy_format or supported_format
return legacy_format


def _get_language(text: str) -> str:
header = text.split("\n").pop(0)
match = RE_NESTED_FENCE_START.match(header)
if match and match.group("lang"):
return str(match.group("lang"))
return "python"


def formatted_code_block(
Expand All @@ -46,14 +95,24 @@ def formatted_code_block(
"""Wraps code in a fenced code block with marimo attributes."""
if attributes is None:
attributes = {}
language = attributes.pop("language", "python")
attribute_str = " ".join(
[""] + [f'{key}="{value}"' for key, value in attributes.items()]
)
guard = "```"
while guard in code:
guard += "`"
if DependencyManager.new_superfences.has_required_version(quiet=True):
return "\n".join(
[
f"""{guard}{language} {{.marimo{attribute_str}}}""",
code,
guard,
"",
]
)
return "\n".join(
[f"""{guard}{{.python.marimo{attribute_str}}}""", code, guard, ""]
[f"""{guard}{{.{language}.marimo{attribute_str}}}""", code, guard, ""]
)


Expand Down Expand Up @@ -82,6 +141,8 @@ def get_source_from_tag(tag: Element) -> str:
if not (source and source.strip()):
return ""
source = markdown_to_marimo(source)
elif tag.attrib.get("language") == "sql":
dmadisetti marked this conversation as resolved.
Show resolved Hide resolved
source = sql_to_marimo(source, tag.attrib.get("query", "_df"))
else:
assert tag.tag == MARIMO_CODE, f"Unknown tag: {tag.tag}"
return source
Expand Down Expand Up @@ -150,7 +211,10 @@ def _tree_to_app(root: Element) -> str:
cell_config: list[CellConfig] = []
for child in root:
names.append(child.get("name", DEFAULT_CELL_NAME))
cell_config.append(get_cell_config_from_tag(child))
# Default to hiding markdown cells.
cell_config.append(
get_cell_config_from_tag(child, hide_code=child.tag == MARIMO_MD)
)
sources.append(get_source_from_tag(child))

return codegen.generate_filecontents(
Expand Down Expand Up @@ -272,15 +336,6 @@ def __init__(self, md: MarimoParser):
super().__init__(md)
self.md = md
self.md.meta = {}
# Regex captures loose yaml for frontmatter
# Should match the following:
# ---
# title: "Title"
# whatever
# ---
self.yaml_front_matter_regex = re.compile(
r"^---\s*\n(.*?\n?)(?:---)\s*\n", re.UNICODE | re.DOTALL
)

def run(self, lines: list[str]) -> list[str]:
import yaml
Expand All @@ -295,7 +350,7 @@ def run(self, lines: list[str]) -> list[str]:
return lines

doc = "\n".join(lines)
result = self.yaml_front_matter_regex.match(doc)
result = YAML_FRONT_MATTER_REGEX.match(doc)

if result:
yaml_content = result.group(1)
Expand Down Expand Up @@ -414,25 +469,27 @@ def add_paragraph() -> None:
code_block = SubElement(parent, MARIMO_CODE)
block_lines = code.split("\n")
code_block.text = "\n".join(block_lines[1:-1])
# Extract attributes from the code block.
# Blocks are expected to be like this:
# {.python.marimo disabled="true"}
fence_start = RE_NESTED_FENCE_START.match(block_lines[0])
if fence_start:
# attrs is a bit of a misnomer, matches
# .python.marimo disabled="true"
inner = fence_start.group("attrs")
if inner:
code_block.attrib = dict(
re.findall(r'(\w+)="([^"]*)"', inner)
)

attribs = extract_attribs(block_lines[0])
if attribs:
code_block.attrib = attribs

# Set after to prevent lang being flushed.
code_block.set("language", _get_language(code))

add_paragraph()
# Flush to indicate all blocks have been processed.
blocks.clear()


def convert_from_md_to_app(text: str) -> App:
return MarimoParser(output_format="marimo-app").convert(text) # type: ignore[arg-type, return-value]
if not text.strip():
app = App()
else:
app = cast(App, MarimoParser(output_format="marimo-app").convert(text))

app._cell_manager.ensure_one_cell()
return app


def convert_from_md(text: str) -> str:
Expand Down
13 changes: 13 additions & 0 deletions marimo/_convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def markdown_to_marimo(source: str) -> str:
)


def sql_to_marimo(source: str, table: str) -> str:
return "\n".join(
[
f"{table} = mo.sql(",
# f-string: expected for sql
codegen.indent_text('f"""'),
codegen.indent_text(source),
codegen.indent_text('"""'),
")",
]
)


def generate_from_sources(
*,
sources: list[str],
Expand Down
Loading
Loading