Skip to content

Commit

Permalink
Add autofix for flake8-type-checking
Browse files Browse the repository at this point in the history
This is actually sort of starting to work

This is actually sort of starting to work

Another hard case
  • Loading branch information
charliermarsh committed May 31, 2023
1 parent 4bd395a commit 2b9311b
Show file tree
Hide file tree
Showing 27 changed files with 1,203 additions and 42 deletions.
85 changes: 85 additions & 0 deletions crates/ruff/src/autofix/codemods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,88 @@ pub(crate) fn remove_imports<'a>(

Ok(Some(state.to_string()))
}

/// Given an import statement, remove any imports that are not specified in the `imports` slice.
///
/// Returns the modified import statement.
pub(crate) fn retain_imports(
imports: &[&str],
stmt: &Stmt,
locator: &Locator,
stylist: &Stylist,
) -> Result<String> {
let module_text = locator.slice(stmt.range());
let mut tree = match_statement(module_text)?;

let Statement::Simple(body) = &mut tree else {
bail!("Expected Statement::Simple");
};

let (aliases, import_module) = match body.body.first_mut() {
Some(SmallStatement::Import(import_body)) => (&mut import_body.names, None),
Some(SmallStatement::ImportFrom(import_body)) => {
if let ImportNames::Aliases(names) = &mut import_body.names {
(
names,
Some((&import_body.relative, import_body.module.as_ref())),
)
} else {
bail!("Expected: ImportNames::Aliases");
}
}
_ => bail!("Expected: SmallStatement::ImportFrom | SmallStatement::Import"),
};

// Preserve the trailing comma (or not) from the last entry.
let trailing_comma = aliases.last().and_then(|alias| alias.comma.clone());

aliases.retain(|alias| {
imports.iter().any(|import| {
let full_name = match import_module {
Some((relative, module)) => {
let module = module.map(compose_module_path);
let member = compose_module_path(&alias.name);
let mut full_name = String::with_capacity(
relative.len() + module.as_ref().map_or(0, String::len) + member.len() + 1,
);
for _ in 0..relative.len() {
full_name.push('.');
}
if let Some(module) = module {
full_name.push_str(&module);
full_name.push('.');
}
full_name.push_str(&member);
full_name
}
None => compose_module_path(&alias.name),
};
full_name == *import
})
});

// But avoid destroying any trailing comments.
if let Some(alias) = aliases.last_mut() {
let has_comment = if let Some(comma) = &alias.comma {
match &comma.whitespace_after {
ParenthesizableWhitespace::SimpleWhitespace(_) => false,
ParenthesizableWhitespace::ParenthesizedWhitespace(whitespace) => {
whitespace.first_line.comment.is_some()
}
}
} else {
false
};
if !has_comment {
alias.comma = trailing_comma;
}
}

let mut state = CodegenState {
default_newline: &stylist.line_ending(),
default_indent: stylist.indentation(),
..CodegenState::default()
};
tree.codegen(&mut state);
Ok(state.to_string())
}
2 changes: 1 addition & 1 deletion crates/ruff/src/autofix/edits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub(crate) fn delete_stmt(
}
}

/// Generate a `Fix` to remove any unused imports from an `import` statement.
/// Generate a `Fix` to remove the specified imports from an `import` statement.
pub(crate) fn remove_unused_imports<'a>(
unused_imports: impl Iterator<Item = &'a str>,
stmt: &Stmt,
Expand Down
2 changes: 0 additions & 2 deletions crates/ruff/src/importer/insertion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
//! Insert statements into Python code.
#![allow(dead_code)]

use ruff_text_size::TextSize;
use rustpython_parser::ast::{Ranged, Stmt};
use rustpython_parser::{lexer, Mode, Tok};
Expand Down
123 changes: 115 additions & 8 deletions crates/ruff/src/importer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute};
use ruff_text_size::TextSize;
use rustpython_parser::ast::{self, Ranged, Stmt, Suite};

use crate::autofix;
use ruff_diagnostics::Edit;
use ruff_python_ast::imports::{AnyImport, Import, ImportFrom};
use ruff_python_ast::source_code::{Locator, Stylist};
use ruff_python_semantic::model::SemanticModel;
use ruff_textwrap::indent;

use crate::cst::matchers::{match_aliases, match_import_from, match_statement};
use crate::importer::insertion::Insertion;
Expand Down Expand Up @@ -73,6 +75,55 @@ impl<'a> Importer<'a> {
}
}

/// Move an existing import into a `TYPE_CHECKING` block.
///
/// If there are no existing `TYPE_CHECKING` blocks, a new one will be added at the top
/// of the file. Otherwise, it will be added after the most recent top-level
/// `TYPE_CHECKING` block.
pub(crate) fn typing_import_edit(
&self,
import: &StmtImport,
at: TextSize,
semantic_model: &SemanticModel,
) -> Result<TypingImportEdit> {
// Generate the modified import statement.
let content = autofix::codemods::retain_imports(
&[import.full_name],
import.stmt,
self.locator,
self.stylist,
)?;

// Import the `TYPE_CHECKING` symbol from the typing module.
let (type_checking_edit, type_checking) = self.get_or_import_symbol(
&ImportRequest::import_from("typing", "TYPE_CHECKING"),
at,
semantic_model,
)?;

// Add the import to a `TYPE_CHECKING` block.
let add_import_edit = if let Some(block) = self.preceding_type_checking_block(at) {
// Add the import to the `TYPE_CHECKING` block.
self.add_to_type_checking_block(&content, block.start())
} else {
// Add the import to a new `TYPE_CHECKING` block.
self.add_type_checking_block(
&format!(
"{}if {type_checking}:{}{}",
self.stylist.line_ending().as_str(),
self.stylist.line_ending().as_str(),
indent(&content, self.stylist.indentation())
),
at,
)?
};

Ok(TypingImportEdit {
type_checking_edit,
add_import_edit,
})
}

/// Generate an [`Edit`] to reference the given symbol. Returns the [`Edit`] necessary to make
/// the symbol available in the current scope along with the bound name of the symbol.
///
Expand Down Expand Up @@ -204,14 +255,6 @@ impl<'a> Importer<'a> {
}
}

/// Return the import statement that precedes the given position, if any.
fn preceding_import(&self, at: TextSize) -> Option<&Stmt> {
self.runtime_imports
.partition_point(|stmt| stmt.start() < at)
.checked_sub(1)
.map(|idx| self.runtime_imports[idx])
}

/// Return the top-level [`Stmt`] that imports the given module using `Stmt::ImportFrom`
/// preceding the given position, if any.
fn find_import_from(&self, module: &str, at: TextSize) -> Option<&Stmt> {
Expand Down Expand Up @@ -258,6 +301,62 @@ impl<'a> Importer<'a> {
statement.codegen(&mut state);
Ok(Edit::range_replacement(state.to_string(), stmt.range()))
}

/// Add a `TYPE_CHECKING` block to the given module.
fn add_type_checking_block(&self, content: &str, at: TextSize) -> Result<Edit> {
let insertion = if let Some(stmt) = self.preceding_import(at) {
// Insert after the last top-level import.
Insertion::end_of_statement(stmt, self.locator, self.stylist)
} else {
// Insert at the start of the file.
Insertion::start_of_file(self.python_ast, self.locator, self.stylist)
};
if insertion.is_inline() {
Err(anyhow::anyhow!(
"Cannot insert `TYPE_CHECKING` block inline"
))
} else {
Ok(insertion.into_edit(content))
}
}

/// Add an import statement to an existing `TYPE_CHECKING` block.
fn add_to_type_checking_block(&self, content: &str, at: TextSize) -> Edit {
Insertion::start_of_block(at, self.locator, self.stylist).into_edit(content)
}

/// Return the import statement that precedes the given position, if any.
fn preceding_import(&self, at: TextSize) -> Option<&'a Stmt> {
self.runtime_imports
.partition_point(|stmt| stmt.start() < at)
.checked_sub(1)
.map(|idx| self.runtime_imports[idx])
}

/// Return the `TYPE_CHECKING` block that precedes the given position, if any.
fn preceding_type_checking_block(&self, at: TextSize) -> Option<&'a Stmt> {
let block = self.type_checking_blocks.first()?;
if block.start() <= at {
Some(block)
} else {
None
}
}
}

/// An edit to an import to a typing-only context.
#[derive(Debug)]
pub(crate) struct TypingImportEdit {
/// The edit to add the `TYPE_CHECKING` symbol to the module.
type_checking_edit: Edit,
/// The edit to add the import to a `TYPE_CHECKING` block.
add_import_edit: Edit,
}

impl TypingImportEdit {
pub(crate) fn into_edits(self) -> Vec<Edit> {
vec![self.type_checking_edit, self.add_import_edit]
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -301,6 +400,14 @@ impl<'a> ImportRequest<'a> {
}
}

/// An existing module or member import, located within an import statement.
pub(crate) struct StmtImport<'a> {
/// The import statement.
pub(crate) stmt: &'a Stmt,
/// The "full name" of the imported module or member.
pub(crate) full_name: &'a str,
}

/// The result of an [`Importer::get_or_import_symbol`] call.
#[derive(Debug)]
pub(crate) enum ResolutionError {
Expand Down
8 changes: 7 additions & 1 deletion crates/ruff/src/message/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::{Display, Formatter};
use std::num::NonZeroUsize;

use colored::{Color, ColoredString, Colorize, Styles};
use itertools::Itertools;
use ruff_text_size::{TextRange, TextSize};
use similar::{ChangeTag, TextDiff};

Expand Down Expand Up @@ -37,7 +38,12 @@ impl Display for Diff<'_> {
let mut output = String::with_capacity(self.source_code.source_text().len());
let mut last_end = TextSize::default();

for edit in self.fix.edits() {
for edit in self
.fix
.edits()
.iter()
.sorted_unstable_by_key(|edit| edit.start())
{
output.push_str(
self.source_code
.slice(TextRange::new(last_end, edit.start())),
Expand Down
Loading

0 comments on commit 2b9311b

Please sign in to comment.