Skip to content

Commit

Permalink
Add CodeExtractor trait
Browse files Browse the repository at this point in the history
  • Loading branch information
evanrittenhouse committed Jun 15, 2023
1 parent 66089e1 commit 11f9a46
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 136 deletions.
274 changes: 140 additions & 134 deletions crates/ruff/src/jupyter/notebook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::autofix::source_map::{SourceMap, SourceMarker};
use crate::jupyter::index::JupyterIndex;
use crate::jupyter::{Cell, CellType, RawNotebook, SourceValue};
use crate::rules::pycodestyle::rules::SyntaxError;
use crate::source_kind::CodeExtractor;
use crate::IOError;

pub const JUPYTER_NOTEBOOK_EXT: &str = "ipynb";
Expand All @@ -25,7 +26,7 @@ const MAGIC_PREFIX: [&str; 3] = ["%", "!", "?"];

/// Run round-trip source code generation on a given Jupyter notebook file path.
pub fn round_trip(path: &Path) -> anyhow::Result<String> {
let mut notebook = Notebook::read(path).map_err(|err| {
let mut notebook = Notebook::extract_code(path).map_err(|err| {
anyhow::anyhow!(
"Failed to read notebook file `{}`: {:?}",
path.display(),
Expand Down Expand Up @@ -97,133 +98,6 @@ pub struct Notebook {
}

impl Notebook {
/// See also the black implementation
/// <https://github.com/psf/black/blob/69ca0a4c7a365c5f5eea519a90980bab72cab764/src/black/__init__.py#L1017-L1046>
pub fn read(path: &Path) -> Result<Self, Box<Diagnostic>> {
let reader = BufReader::new(File::open(path).map_err(|err| {
Diagnostic::new(
IOError {
message: format!("{err}"),
},
TextRange::default(),
)
})?);
let notebook: RawNotebook = match serde_json::from_reader(reader) {
Ok(notebook) => notebook,
Err(err) => {
// Translate the error into a diagnostic
return Err(Box::new({
match err.classify() {
Category::Io => Diagnostic::new(
IOError {
message: format!("{err}"),
},
TextRange::default(),
),
Category::Syntax | Category::Eof => {
// Maybe someone saved the python sources (those with the `# %%` separator)
// as jupyter notebook instead. Let's help them.
let contents = std::fs::read_to_string(path).map_err(|err| {
Diagnostic::new(
IOError {
message: format!("{err}"),
},
TextRange::default(),
)
})?;
// Check if tokenizing was successful and the file is non-empty
if (ruff_rustpython::tokenize(&contents))
.last()
.map_or(true, Result::is_err)
{
Diagnostic::new(
SyntaxError {
message: format!(
"A Jupyter Notebook (.{JUPYTER_NOTEBOOK_EXT}) must internally be JSON, \
but this file isn't valid JSON: {err}"
),
},
TextRange::default(),
)
} else {
Diagnostic::new(
SyntaxError {
message: format!(
"Expected a Jupyter Notebook (.{JUPYTER_NOTEBOOK_EXT} extension), \
which must be internally stored as JSON, \
but found a Python source file: {err}"
),
},
TextRange::default(),
)
}
}
Category::Data => {
// We could try to read the schema version here but if this fails it's
// a bug anyway
Diagnostic::new(
SyntaxError {
message: format!(
"This file does not match the schema expected of Jupyter Notebooks: {err}"
),
},
TextRange::default(),
)
}
}
}));
}
};

// v4 is what everybody uses
if notebook.nbformat != 4 {
// bail because we should have already failed at the json schema stage
return Err(Box::new(Diagnostic::new(
SyntaxError {
message: format!(
"Expected Jupyter Notebook format 4, found {}",
notebook.nbformat
),
},
TextRange::default(),
)));
}

let valid_code_cells = notebook
.cells
.iter()
.enumerate()
.filter(|(_, cell)| cell.is_valid_code_cell())
.map(|(pos, _)| u32::try_from(pos).unwrap())
.collect::<Vec<_>>();

let mut contents = Vec::with_capacity(valid_code_cells.len());
let mut current_offset = TextSize::from(0);
let mut cell_offsets = Vec::with_capacity(notebook.cells.len());
cell_offsets.push(TextSize::from(0));

for &pos in &valid_code_cells {
let cell_contents = match &notebook.cells[pos as usize].source {
SourceValue::String(string) => string.clone(),
SourceValue::StringArray(string_array) => string_array.join(""),
};
current_offset += TextSize::of(&cell_contents) + TextSize::new(1);
contents.push(cell_contents);
cell_offsets.push(current_offset);
}

Ok(Self {
raw: notebook,
index: OnceCell::new(),
// The additional newline at the end is to maintain consistency for
// all cells. These newlines will be removed before updating the
// source code with the transformed content. Refer `update_cell_content`.
content: contents.join("\n") + "\n",
cell_offsets,
valid_code_cells,
})
}

/// Update the cell offsets as per the given [`SourceMap`].
fn update_cell_offsets(&mut self, source_map: &SourceMap) {
// When there are multiple cells without any edits, the offsets of those
Expand Down Expand Up @@ -402,6 +276,135 @@ impl Notebook {
}
}

impl CodeExtractor<Notebook> for Notebook {
/// See also the black implementation
/// <https://github.com/psf/black/blob/69ca0a4c7a365c5f5eea519a90980bab72cab764/src/black/__init__.py#L1017-L1046>
fn extract_code(path: &Path) -> Result<Self, Box<Diagnostic>> {
let reader = BufReader::new(File::open(path).map_err(|err| {
Diagnostic::new(
IOError {
message: format!("{err}"),
},
TextRange::default(),
)
})?);
let notebook: RawNotebook = match serde_json::from_reader(reader) {
Ok(notebook) => notebook,
Err(err) => {
// Translate the error into a diagnostic
return Err(Box::new({
match err.classify() {
Category::Io => Diagnostic::new(
IOError {
message: format!("{err}"),
},
TextRange::default(),
),
Category::Syntax | Category::Eof => {
// Maybe someone saved the python sources (those with the `# %%` separator)
// as jupyter notebook instead. Let's help them.
let contents = std::fs::read_to_string(path).map_err(|err| {
Diagnostic::new(
IOError {
message: format!("{err}"),
},
TextRange::default(),
)
})?;
// Check if tokenizing was successful and the file is non-empty
if (ruff_rustpython::tokenize(&contents))
.last()
.map_or(true, Result::is_err)
{
Diagnostic::new(
SyntaxError {
message: format!(
"A Jupyter Notebook (.{JUPYTER_NOTEBOOK_EXT}) must internally be JSON, \
but this file isn't valid JSON: {err}"
),
},
TextRange::default(),
)
} else {
Diagnostic::new(
SyntaxError {
message: format!(
"Expected a Jupyter Notebook (.{JUPYTER_NOTEBOOK_EXT} extension), \
which must be internally stored as JSON, \
but found a Python source file: {err}"
),
},
TextRange::default(),
)
}
}
Category::Data => {
// We could try to read the schema version here but if this fails it's
// a bug anyway
Diagnostic::new(
SyntaxError {
message: format!(
"This file does not match the schema expected of Jupyter Notebooks: {err}"
),
},
TextRange::default(),
)
}
}
}));
}
};

// v4 is what everybody uses
if notebook.nbformat != 4 {
// bail because we should have already failed at the json schema stage
return Err(Box::new(Diagnostic::new(
SyntaxError {
message: format!(
"Expected Jupyter Notebook format 4, found {}",
notebook.nbformat
),
},
TextRange::default(),
)));
}

let valid_code_cells = notebook
.cells
.iter()
.enumerate()
.filter(|(_, cell)| cell.is_valid_code_cell())
.map(|(pos, _)| u32::try_from(pos).unwrap())
.collect::<Vec<_>>();

let mut contents = Vec::with_capacity(valid_code_cells.len());
let mut current_offset = TextSize::from(0);
let mut cell_offsets = Vec::with_capacity(notebook.cells.len());
cell_offsets.push(TextSize::from(0));

for &pos in &valid_code_cells {
let cell_contents = match &notebook.cells[pos as usize].source {
SourceValue::String(string) => string.clone(),
SourceValue::StringArray(string_array) => string_array.join(""),
};
current_offset += TextSize::of(&cell_contents) + TextSize::new(1);
contents.push(cell_contents);
cell_offsets.push(current_offset);
}

Ok(Self {
raw: notebook,
index: OnceCell::new(),
// The additional newline at the end is to maintain consistency for
// all cells. These newlines will be removed before updating the
// source code with the transformed content. Refer `update_cell_content`.
content: contents.join("\n") + "\n",
cell_offsets,
valid_code_cells,
})
}
}

#[cfg(test)]
mod test {
use std::path::Path;
Expand All @@ -418,6 +421,9 @@ mod test {
use crate::test::{test_notebook_path, test_resource_path};
use crate::{assert_messages, settings};

use crate::source_kind::CodeExtractor;
use crate::test::test_resource_path;

/// Read a Jupyter cell from the `resources/test/fixtures/jupyter/cell` directory.
fn read_jupyter_cell(path: impl AsRef<Path>) -> Result<Cell> {
let path = test_resource_path("fixtures/jupyter/cell").join(path);
Expand All @@ -428,36 +434,36 @@ mod test {
#[test]
fn test_valid() {
    // A well-formed notebook fixture must load without producing a diagnostic.
    let fixture = Path::new("resources/test/fixtures/jupyter/valid.ipynb");
    let via_read = Notebook::read(fixture);
    assert!(via_read.is_ok());
    let via_extract = Notebook::extract_code(fixture);
    assert!(via_extract.is_ok());
}

#[test]
fn test_r() {
    // Loading an R notebook succeeds at this stage; it will be filtered out later.
    let fixture = Path::new("resources/test/fixtures/jupyter/R.ipynb");
    let via_read = Notebook::read(fixture);
    assert!(via_read.is_ok());
    let via_extract = Notebook::extract_code(fixture);
    assert!(via_extract.is_ok());
}

#[test]
fn test_invalid() {
let path = Path::new("resources/test/fixtures/jupyter/invalid_extension.ipynb");
assert_eq!(
Notebook::read(path).unwrap_err().kind.body,
Notebook::extract_code(path).unwrap_err().kind.body,
"SyntaxError: Expected a Jupyter Notebook (.ipynb extension), \
which must be internally stored as JSON, \
but found a Python source file: \
expected value at line 1 column 1"
);
let path = Path::new("resources/test/fixtures/jupyter/not_json.ipynb");
assert_eq!(
Notebook::read(path).unwrap_err().kind.body,
Notebook::extract_code(path).unwrap_err().kind.body,
"SyntaxError: A Jupyter Notebook (.ipynb) must internally be JSON, \
but this file isn't valid JSON: \
expected value at line 1 column 1"
);
let path = Path::new("resources/test/fixtures/jupyter/wrong_schema.ipynb");
assert_eq!(
Notebook::read(path).unwrap_err().kind.body,
Notebook::extract_code(path).unwrap_err().kind.body,
"SyntaxError: This file does not match the schema expected of Jupyter Notebooks: \
missing field `cells` at line 1 column 2"
);
Expand Down Expand Up @@ -485,7 +491,7 @@ mod test {
#[test]
fn test_concat_notebook() {
let path = Path::new("resources/test/fixtures/jupyter/valid.ipynb");
let notebook = Notebook::read(path).unwrap();
let notebook = Notebook::extract_code(path).unwrap();
assert_eq!(
notebook.content,
r#"def unused_variable():
Expand Down
Loading

0 comments on commit 11f9a46

Please sign in to comment.