Skip to content

Commit

Permalink
DEV: Add Notebook AST Pre-Commit Hook (#417)
Browse files Browse the repository at this point in the history
- Adds a simple pre-commit check which iterates over code cells in each notebook file and checks that any lines which don't start with a ! are valid python, i.e. that they parse to an abstract syntax tree (AST).
- Also removes a broken and unused old hook.
  • Loading branch information
John-P authored Jul 21, 2022
1 parent fa97018 commit 1f3e07d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
13 changes: 5 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,6 @@ repos:
flake8-use-fstring, # Encourages use of f-strings vs old style
pep8-naming, # Check PEP8 class naming
]
- repo: local
hooks:
- id: pytest-changed-files
name: pytest-changed-files
entry: pytest
files: tests/.*\btest_\w*.py
language: system
stages: [push]
- repo: local
hooks:
- id: requirements-consistency
Expand All @@ -84,3 +76,8 @@ repos:
pass_filenames: false
language: python
additional_dependencies: [pyyaml]
- id: notebook-check-ast
name: check notebook ast
entry: python pre-commit/notebook_check_ast.py
types_or: [jupyter]
language: python
33 changes: 33 additions & 0 deletions pre-commit/notebook_check_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Simple check to ensure each code cell in a notebook is valid Python."""
import argparse
import ast
import json
import sys
from pathlib import Path
from typing import List


def main(files: List[Path]) -> bool:
"""Check each file in the list of files for valid Python."""
passed = True
for path in files:
with open(path) as fh:
notebook = json.load(fh)
for n, cell in enumerate(notebook["cells"]):
if cell["cell_type"] != "code":
continue
source = "".join([x for x in cell["source"] if not x.startswith("!")])
try:
ast.parse(source)
except SyntaxError as e:
passed = False
print(f"{path.name}: {e.msg} (cell {n}, line {e.lineno})")
break
return passed # noqa: R504


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Check notebook AST")
parser.add_argument("files", nargs="+", help="Path to notebook(s)", type=Path)
args = parser.parse_args()
sys.exit(1 - main(args.files))

0 comments on commit 1f3e07d

Please sign in to comment.